use crate::{ data_director::{DocRegistry, RegMsg, Register}, message::Message, mtterror::MTTError, name::NameType, }; use std::{ collections::HashMap, sync::{ mpsc::{channel, Sender}, Arc, RwLock, }, }; use uuid::Uuid; struct Router { doc_registry: Sender, senders: HashMap>, } impl Router { fn new(tx: Sender) -> Self { Self { doc_registry: tx, senders: HashMap::new(), } } fn add_sender(&mut self, sender: Sender) -> Uuid { let mut id = Uuid::new_v4(); while self.senders.contains_key(&id) { id = Uuid::new_v4(); } self.senders.insert(id.clone(), sender); id } fn remove_sender(&mut self, id: &Uuid) { let action = Register::new(Uuid::nil(), RegMsg::RemoveSender(id.clone())); self.doc_registry .send(Message::new(NameType::None, action)) .unwrap(); self.senders.remove(id); } fn forward(&self, id: &Uuid, msg: Message) { if id == &Uuid::nil() { return; } match self.senders.get(id) { Some(sender) => sender.send(msg).unwrap(), None => {} } } fn send(&self, msg: Message) { self.doc_registry.send(msg).unwrap(); } } #[derive(Clone)] pub struct Queue { router: Arc>, } impl Queue { pub fn new() -> Self { let (tx, rx) = channel(); let output = Self { router: Arc::new(RwLock::new(Router::new(tx))), }; DocRegistry::start(output.clone(), rx); output } pub fn add_sender(&mut self, sender: Sender) -> Uuid { let mut router = self.router.write().unwrap(); router.add_sender(sender) } pub fn remove_sender(&mut self, id: &Uuid) { let mut router = self.router.write().unwrap(); router.remove_sender(id); } pub fn forward(&self, id: &Uuid, msg: Message) { let router = self.router.read().unwrap(); router.forward(id, msg); } pub fn send(&self, msg: Message) -> Result<(), MTTError> { let router = self.router.read().unwrap(); router.send(msg.clone()); Ok(()) } } #[cfg(test)] mod routers { use super::*; use crate::{ message::{MsgAction, Query}, name::Name, support_tests::TIMEOUT, }; use std::{ collections::HashSet, sync::mpsc::{Receiver, RecvTimeoutError}, }; struct Setup { test_mod: Router, rx: Receiver, } impl Setup { fn new() -> Self { let (tx, rx) = channel(); Self { test_mod: Router::new(tx), rx: rx, } } fn get_router(&self) -> &Router { &self.test_mod } fn get_router_mut(&mut self) -> &mut Router { &mut self.test_mod } fn recv(&self) -> Result { self.rx.recv_timeout(TIMEOUT) } } #[test] fn can_pass_message() { let setup = Setup::new(); let router = setup.get_router(); let msg = Message::new(Name::english("task"), Query::new()); router.send(msg.clone()); let result = setup.recv().unwrap(); assert_eq!(result.get_message_id(), msg.get_message_id()); } #[test] fn can_forward_message() { let mut setup = Setup::new(); let router = setup.get_router_mut(); let mut receivers: HashMap> = HashMap::new(); for _ in 0..10 { let (tx, rx) = channel(); let id = router.add_sender(tx); receivers.insert(id, rx); } for (id, recv) in receivers.iter() { let msg = Message::new(Name::english(id.to_string().as_str()), Query::new()); router.forward(id, msg.clone()); let result = recv.recv_timeout(TIMEOUT).unwrap(); assert_eq!(result.get_message_id(), msg.get_message_id()); } } #[test] fn sender_ids_are_unique() { let mut setup = Setup::new(); let router = setup.get_router_mut(); let count = 10; let mut holder: HashSet = HashSet::new(); for _ in 0..count { let (tx, _) = channel(); holder.insert(router.add_sender(tx)); } assert_eq!(holder.len(), count, "had duplicate keys"); } #[test] fn can_remove_sender() { let mut setup = Setup::new(); let router = setup.get_router_mut(); let mut receivers: HashMap> = HashMap::new(); for _ in 0..10 { let (tx, rx) = channel(); let id = router.add_sender(tx); receivers.insert(id, rx); } let removed = receivers.keys().last().unwrap().clone(); router.remove_sender(&removed); let router = setup.get_router(); let removed_recv = receivers.remove(&removed).unwrap(); router.forward(&removed, Message::new(NameType::None, Query::new())); match removed_recv.recv_timeout(TIMEOUT) { Err(err) => match err { RecvTimeoutError::Disconnected => {} _ => unreachable!("got {:?}, should have been disconnected", err), }, _ => unreachable!("should have returned an error"), } let announce = setup.recv().unwrap(); let action = announce.get_action(); match action { MsgAction::Register(data) => { let output = data.get_msg(); match output { RegMsg::RemoveSender(id) => assert_eq!(id, &removed), _ => unreachable!("got {:?} should have been sender removal", output), } } _ => unreachable!("got {:?}, should have been register", action), } for (id, recv) in receivers.iter() { let msg = Message::new(Name::english(id.to_string().as_str()), Query::new()); router.forward(id, msg.clone()); let result = recv.recv_timeout(TIMEOUT).unwrap(); assert_eq!(result.get_message_id(), msg.get_message_id()); } } #[test] fn ignores_bad_id_removals() { let mut setup = Setup::new(); let router = setup.get_router_mut(); let removed = Uuid::new_v4(); router.remove_sender(&removed); assert_eq!(router.senders.len(), 0, "should have no senders."); let announce = setup.recv().unwrap(); let action = announce.get_action(); match action { MsgAction::Register(data) => { let output = data.get_msg(); match output { RegMsg::RemoveSender(id) => assert_eq!(id, &removed), _ => unreachable!("got {:?} should have been sender removal", output), } } _ => unreachable!("got {:?}, should have been register", action), } } } #[cfg(test)] mod queues { use super::*; use crate::{ data_director::{Include, Path}, message::MsgAction, name::Name, support_tests::TIMEOUT, }; use std::sync::mpsc::{Receiver, RecvTimeoutError}; struct Setup { test_mod: Queue, rx: Receiver, rx_id: Uuid, } impl Setup { fn new() -> Self { let mut queue = Queue::new(); let (tx, rx) = channel(); let id = queue.add_sender(tx); Self { test_mod: queue, rx: rx, rx_id: id, } } fn send_reg_msg(&self, msg: RegMsg) { let reg_msg = Register::new(self.rx_id.clone(), msg); self.test_mod .send(Message::new(NameType::None, reg_msg)) .unwrap(); } fn recv(&self) -> Result { self.rx.recv_timeout(TIMEOUT) } } #[test] fn can_add_names_registry() { let setup = Setup::new(); let name = Name::english(Uuid::new_v4().to_string().as_str()); let reg = RegMsg::AddDocName(vec![name.clone()]); setup.send_reg_msg(reg); let result = setup.recv().unwrap(); let action = result.get_action(); match action { MsgAction::Register(data) => { let regmsg = data.get_msg(); match data.get_msg() { RegMsg::DocumentNameID(_) => {} _ => unreachable!("got {:?} should have been document id", regmsg), } } _ => unreachable!("got {:?} should have been register", action), } } #[test] fn returns_error_when_document_name_not_found() { let setup = Setup::new(); let name = Name::english(Uuid::new_v4().to_string().as_str()); let path = Path::new( Include::All, Include::Just(name.clone().into()), Include::All, ); let reg = RegMsg::AddRoute(path); setup.send_reg_msg(reg); let result = setup.recv().unwrap(); let action = result.get_action(); match action { MsgAction::Register(data) => { let regmsg = data.get_msg(); match data.get_msg() { RegMsg::Error(err) => match err { MTTError::NameNotFound(failed_name) => assert_eq!(failed_name, &name), _ => unreachable!("got {:?} should have been missing name", err), }, _ => unreachable!("got {:?} should have been error", regmsg), } } _ => unreachable!("got {:?} should have been register", action), } } }