use crate::{client::Request, field::Field}; use std::{ collections::HashMap, sync::{mpsc::Sender, Arc, RwLock}, }; use uuid::Uuid; #[derive(Clone, Debug, Eq, Hash, PartialEq)] pub enum MsgType { ClientRequest, NoOp, SessionValidate, Session, } #[derive(Clone)] pub struct Message { id: Uuid, class: MsgType, data: HashMap, } impl Message { pub fn new(msg_type: MsgType) -> Self { Self { id: Uuid::new_v4(), class: msg_type, data: HashMap::new(), } } pub fn reply(&self, data: MsgType) -> Message { Self { id: self.id.clone(), class: data, data: HashMap::new(), } } pub fn get_class(&self) -> &MsgType { &self.class } pub fn add_data(&mut self, name: S, data: F) where S: Into, F: Into, { self.data.insert(name.into(), data.into()); } pub fn get_data(&self) -> &HashMap { &self.data } pub fn get_id(&self) -> Uuid { self.id.clone() } } impl From for Message { fn from(_value: Request) -> Self { let msg = Message::new(MsgType::ClientRequest); msg.reply(MsgType::ClientRequest) } } #[cfg(test)] mod messages { use super::*; #[test] fn new_message() { let msg = Message::new(MsgType::NoOp); match msg.class { MsgType::NoOp => (), _ => unreachable!("new defaults to noop"), } assert!(msg.data.is_empty()); } #[test] fn message_ids_are_random() { let mut ids: Vec = Vec::new(); for _ in 0..10 { let msg = Message::new(MsgType::NoOp); let id = msg.id.clone(); assert!(!ids.contains(&id), "{} is a duplicate", id); ids.push(id); } } #[test] fn create_reply() { let id = Uuid::new_v4(); let mut msg = Message::new(MsgType::NoOp); msg.id = id.clone(); msg.add_data("test", "test"); let data = MsgType::ClientRequest; let result = msg.reply(data); assert_eq!(result.id, id); match result.class { MsgType::ClientRequest => {} _ => unreachable!("should have been a registration request"), } assert!(result.data.is_empty()); } #[test] fn get_message_type() { let msg = Message::new(MsgType::NoOp); match msg.get_class() { MsgType::NoOp => {} _ => unreachable!("should have bneen noopn"), } } #[test] fn add_data() { let mut msg = Message::new(MsgType::NoOp); let one = "one"; let two = "two".to_string(); msg.add_data(one, one); msg.add_data(two.clone(), two.clone()); let result = msg.get_data(); assert_eq!(result.get(one).unwrap().to_string(), one); assert_eq!(result.get(&two).unwrap().to_string(), two); } #[test] fn get_message_id() { let msg = Message::new(MsgType::Session); assert_eq!(msg.get_id(), msg.id); } } #[derive(Clone)] pub struct Queue { store: Arc>>>>, } impl Queue { pub fn new() -> Self { Self { store: Arc::new(RwLock::new(HashMap::new())), } } pub fn add(&self, tx: Sender, msg_types: Vec) { let mut store = self.store.write().unwrap(); for msg_type in msg_types.into_iter() { if !store.contains_key(&msg_type) { store.insert(msg_type.clone(), Vec::new()); } let senders = store.get_mut(&msg_type).unwrap(); senders.push(tx.clone()); } } pub fn send(&self, msg: Message) { let store = self.store.read().unwrap(); match store.get(&msg.get_class()) { Some(senders) => { for sender in senders.into_iter() { sender.send(msg.clone()).unwrap(); } } None => {} } } } #[cfg(test)] mod queues { use super::*; use std::{ sync::mpsc::{channel, RecvTimeoutError}, time::Duration, }; static TIMEOUT: Duration = Duration::from_millis(500); #[test] fn create_queue() { let queue = Queue::new(); let (tx1, rx1) = channel(); let (tx2, rx2) = channel(); queue.add(tx1, [MsgType::NoOp].to_vec()); queue.add(tx2, [MsgType::NoOp].to_vec()); queue.send(Message::new(MsgType::NoOp)); rx1.recv().unwrap(); rx2.recv().unwrap(); } #[test] fn messages_are_routed() { let queue = Queue::new(); let (tx1, rx1) = channel(); let (tx2, rx2) = channel(); queue.add(tx1, [MsgType::SessionValidate].to_vec()); queue.add(tx2, [MsgType::Session].to_vec()); queue.send(Message::new(MsgType::SessionValidate)); let result = rx1.recv().unwrap(); match result.get_class() { MsgType::SessionValidate => {} _ => unreachable!( "received {:?}, should have been session vvalidate", result.get_class() ), } match rx2.recv_timeout(TIMEOUT) { Ok(_) => unreachable!("should not have received anything"), Err(err) => match err { RecvTimeoutError::Timeout => {} _ => unreachable!("{:?}", err), }, } queue.send(Message::new(MsgType::Session)); let result = rx2.recv().unwrap(); match result.get_class() { MsgType::Session => {} _ => unreachable!( "received {:?}, should have been session vvalidate", result.get_class() ), } match rx1.recv_timeout(TIMEOUT) { Ok(_) => unreachable!("should not have received anything"), Err(err) => match err { RecvTimeoutError::Timeout => {} _ => unreachable!("{:?}", err), }, } } #[test] fn assign_sender_multiple_message_types() { let queue = Queue::new(); let (tx, rx) = channel(); queue.add(tx, [MsgType::Session, MsgType::SessionValidate].to_vec()); queue.send(Message::new(MsgType::SessionValidate)); let msg = rx.recv().unwrap(); assert_eq!(msg.get_class(), &MsgType::SessionValidate); queue.send(Message::new(MsgType::Session)); let msg = rx.recv().unwrap(); assert_eq!(msg.get_class(), &MsgType::Session); } #[test] fn unassigned_message_should_not_panic() { let queue = Queue::new(); queue.send(Message::new(MsgType::Session)); } }