use crate::{client::Request, field::Field}; use std::{ collections::HashMap, sync::{mpsc::Sender, Arc, RwLock}, }; use uuid::Uuid; #[derive(Clone, Debug, Eq, Hash, PartialEq)] pub enum MsgType { ClientRequest, Document, DocumentRequest, SessionValidate, SessionValidated, Time, } #[derive(Clone)] pub struct Message { id: Uuid, msg_type: MsgType, data: HashMap, } impl Message { pub fn new(msg_type: MsgType) -> Self { Self { id: Uuid::new_v4(), msg_type: msg_type, data: HashMap::new(), } } pub fn reply(&self, data: MsgType) -> Message { Self { id: self.id.clone(), msg_type: data, data: HashMap::new(), } } pub fn reply_with_data(&self, msg_type: MsgType) -> Message { Self { id: self.id.clone(), msg_type: msg_type, data: self.data.clone(), } } pub fn get_msg_type(&self) -> &MsgType { &self.msg_type } pub fn add_data(&mut self, name: S, data: F) where S: Into, F: Into, { self.data.insert(name.into(), data.into()); } pub fn get_data(&self, name: S) -> Option<&Field> where S: Into, { let field_name = name.into(); self.data.get(&field_name) } pub fn get_id(&self) -> Uuid { self.id.clone() } } impl From for Message { fn from(value: Request) -> Self { let mut msg = Message::new(MsgType::ClientRequest); match value.session { Some(id) => msg.add_data("sess_id", id), None => {} } msg } } #[cfg(test)] mod messages { use super::*; use crate::client::requests::{get_root_document, get_root_document_eith_session}; #[test] fn new_message() { let msg = Message::new(MsgType::SessionValidate); match msg.msg_type { MsgType::SessionValidate => (), _ => unreachable!("new defaults to noop"), } assert!(msg.data.is_empty()); } #[test] fn message_ids_are_random() { let mut ids: Vec = Vec::new(); for _ in 0..10 { let msg = Message::new(MsgType::SessionValidate); let id = msg.id.clone(); assert!(!ids.contains(&id), "{} is a duplicate", id); ids.push(id); } } #[test] fn create_reply() { let id = Uuid::new_v4(); let mut msg = Message::new(MsgType::SessionValidate); msg.id = id.clone(); msg.add_data("test", "test"); let data = MsgType::ClientRequest; let result = msg.reply(data); assert_eq!(result.id, id); match result.msg_type { MsgType::ClientRequest => {} _ => unreachable!("should have been a registration request"), } assert!(result.data.is_empty()); } #[test] fn get_message_type() { let msg = Message::new(MsgType::SessionValidate); match msg.get_msg_type() { MsgType::SessionValidate => {} _ => unreachable!("should have bneen noopn"), } } #[test] fn add_data() { let mut msg = Message::new(MsgType::SessionValidate); let one = "one"; let two = "two".to_string(); msg.add_data(one, one); msg.add_data(two.clone(), two.clone()); assert_eq!(msg.get_data(one).unwrap().to_string(), one); assert_eq!(msg.get_data(&two).unwrap().to_string(), two); } #[test] fn get_data_into_string() { let id = Uuid::new_v4(); let mut msg = Message::new(MsgType::SessionValidate); msg.add_data(id, id); assert_eq!(msg.get_data(id).unwrap().to_uuid().unwrap(), id); } #[test] fn copy_data_with_reply() { let id = Uuid::new_v4(); let reply_type = MsgType::SessionValidated; let mut msg = Message::new(MsgType::SessionValidate); msg.add_data(id, id); let reply = msg.reply_with_data(reply_type.clone()); assert_eq!(reply.id, msg.id); match reply.get_msg_type() { MsgType::SessionValidated => {} _ => unreachable!( "Got {:?} should have been {:?}", msg.get_msg_type(), reply_type ), } assert_eq!(reply.data.len(), msg.data.len()); let output = reply.get_data(&id.to_string()).unwrap().to_uuid().unwrap(); let expected = msg.get_data(&id.to_string()).unwrap().to_uuid().unwrap(); assert_eq!(output, expected); } #[test] fn get_message_id() { let msg = Message::new(MsgType::SessionValidated); assert_eq!(msg.get_id(), msg.id); } #[test] fn from_request_no_session() { let req = get_root_document(); let msg: Message = req.into(); assert!( msg.get_data("sess_id").is_none(), "should not have a session id" ) } #[test] fn from_request_with_session() { let id = Uuid::new_v4(); let req = get_root_document_eith_session(id.clone()); let msg: Message = req.into(); match msg.get_data("sess_id") { Some(result) => assert_eq!(result.to_uuid().unwrap(), id), None => unreachable!("should return an id"), } } } #[derive(Clone)] pub struct Queue { store: Arc>>>>, } impl Queue { pub fn new() -> Self { Self { store: Arc::new(RwLock::new(HashMap::new())), } } pub fn add(&self, tx: Sender, msg_types: Vec) { let mut store = self.store.write().unwrap(); for msg_type in msg_types.into_iter() { if !store.contains_key(&msg_type) { store.insert(msg_type.clone(), Vec::new()); } let senders = store.get_mut(&msg_type).unwrap(); senders.push(tx.clone()); } } pub fn send(&self, msg: Message) -> Result<(), String> { let store = self.store.read().unwrap(); match store.get(&msg.get_msg_type()) { Some(senders) => { for sender in senders.into_iter() { sender.send(msg.clone()).unwrap(); } Ok(()) } None => Err(format!("no listeners for {:?}", msg.get_msg_type())), } } } #[cfg(test)] mod queues { use super::*; use std::{ sync::mpsc::{channel, RecvTimeoutError}, time::Duration, }; static TIMEOUT: Duration = Duration::from_millis(500); #[test] fn create_queue() { let queue = Queue::new(); let (tx1, rx1) = channel(); let (tx2, rx2) = channel(); queue.add(tx1, [MsgType::SessionValidate].to_vec()); queue.add(tx2, [MsgType::SessionValidate].to_vec()); queue.send(Message::new(MsgType::SessionValidate)).unwrap(); rx1.recv().unwrap(); rx2.recv().unwrap(); } #[test] fn messages_are_routed() { let queue = Queue::new(); let (tx1, rx1) = channel(); let (tx2, rx2) = channel(); queue.add(tx1, [MsgType::SessionValidate].to_vec()); queue.add(tx2, [MsgType::SessionValidated].to_vec()); queue.send(Message::new(MsgType::SessionValidate)).unwrap(); let result = rx1.recv().unwrap(); match result.get_msg_type() { MsgType::SessionValidate => {} _ => unreachable!( "received {:?}, should have been session vvalidate", result.get_msg_type() ), } match rx2.recv_timeout(TIMEOUT) { Ok(_) => unreachable!("should not have received anything"), Err(err) => match err { RecvTimeoutError::Timeout => {} _ => unreachable!("{:?}", err), }, } queue.send(Message::new(MsgType::SessionValidated)).unwrap(); let result = rx2.recv().unwrap(); match result.get_msg_type() { MsgType::SessionValidated => {} _ => unreachable!( "received {:?}, should have been session vvalidate", result.get_msg_type() ), } match rx1.recv_timeout(TIMEOUT) { Ok(_) => unreachable!("should not have received anything"), Err(err) => match err { RecvTimeoutError::Timeout => {} _ => unreachable!("{:?}", err), }, } } #[test] fn assign_sender_multiple_message_types() { let queue = Queue::new(); let (tx, rx) = channel(); queue.add( tx, [MsgType::SessionValidated, MsgType::SessionValidate].to_vec(), ); queue.send(Message::new(MsgType::SessionValidate)).unwrap(); let msg = rx.recv().unwrap(); assert_eq!(msg.get_msg_type(), &MsgType::SessionValidate); queue.send(Message::new(MsgType::SessionValidated)).unwrap(); let msg = rx.recv().unwrap(); assert_eq!(msg.get_msg_type(), &MsgType::SessionValidated); } #[test] fn unassigned_message_should_return_error() { let queue = Queue::new(); match queue.send(Message::new(MsgType::SessionValidated)) { Ok(_) => unreachable!("should return error"), Err(_) => {} } } }