use crate::{ data_director::{DocRegistry, RegMsg, Register}, message::Message, mtterror::MTTError, name::NameType, }; use std::{ collections::HashMap, sync::{ mpsc::{channel, Sender}, Arc, RwLock, }, }; use uuid::Uuid; struct Router { doc_registry: Sender, senders: HashMap>, } impl Router { fn new(tx: Sender) -> Self { Self { doc_registry: tx, senders: HashMap::new(), } } fn add_sender(&mut self, sender: Sender) -> Uuid { let mut id = Uuid::new_v4(); while self.senders.contains_key(&id) { id = Uuid::new_v4(); } self.senders.insert(id.clone(), sender); id } fn remove_sender(&mut self, id: &Uuid) { let action = Register::new(Uuid::nil(), RegMsg::RemoveSender(id.clone())); self.doc_registry .send(Message::new(NameType::None, action)) .unwrap(); self.senders.remove(id); } fn forward(&self, id: &Uuid, msg: Message) { if id == &Uuid::nil() { return; } self.senders.get(id).unwrap().send(msg).unwrap(); } fn send(&self, msg: Message) { self.doc_registry.send(msg).unwrap(); } } #[derive(Clone)] pub struct Queue { router: Arc>, } impl Queue { pub fn new() -> Self { let (tx, rx) = channel(); let output = Self { router: Arc::new(RwLock::new(Router::new(tx))), }; DocRegistry::start(output.clone(), rx); output } pub fn add_sender(&mut self, sender: Sender) -> Uuid { let mut router = self.router.write().unwrap(); router.add_sender(sender) } pub fn remove_sender(&mut self, id: &Uuid) { let mut router = self.router.write().unwrap(); router.remove_sender(id); } pub fn forward(&self, id: &Uuid, msg: Message) { let router = self.router.read().unwrap(); router.forward(id, msg); } pub fn send(&self, msg: Message) -> Result<(), MTTError> { let router = self.router.read().unwrap(); router.send(msg.clone()); Ok(()) } } #[cfg(test)] mod routers { use super::*; use crate::{ message::{MsgAction, Query}, name::Name, support_tests::TIMEOUT, }; use std::collections::HashSet; #[test] fn can_pass_message() { let (tx, rx) = channel(); let router = Router::new(tx); let msg = Message::new(Name::english("task"), Query::new()); router.send(msg.clone()); let result = rx.recv_timeout(TIMEOUT).unwrap(); assert_eq!(result.get_message_id(), msg.get_message_id()); } #[test] fn can_forward_message() { let (tx, _) = channel(); let mut router = Router::new(tx); let (sender, receiver) = channel(); let id = router.add_sender(sender); let msg = Message::new(Name::english("wiki"), Query::new()); router.forward(&id, msg.clone()); let result = receiver.recv_timeout(TIMEOUT).unwrap(); assert_eq!(result.get_message_id(), msg.get_message_id()); } #[test] fn sender_ids_are_unique() { let (tx, _) = channel(); let mut router = Router::new(tx); let count = 10; let mut holder: HashSet = HashSet::new(); for _ in 0..count { let (tx, _) = channel(); holder.insert(router.add_sender(tx)); } assert_eq!(holder.len(), count, "had duplicate keys"); } #[test] fn can_remove_sender() { let (tx, rx) = channel(); let mut router = Router::new(tx); let (data, _) = channel(); let id = router.add_sender(data); assert_eq!(router.senders.len(), 1, "should have only one sender"); router.remove_sender(&id); assert_eq!(router.senders.len(), 0, "should have no senders."); let result = rx.recv_timeout(TIMEOUT).unwrap(); let action = result.get_action(); match action { MsgAction::Register(reg_msg) => { let reg_action = reg_msg.get_msg(); match reg_action { RegMsg::RemoveSender(result) => assert_eq!(result, &id), _ => unreachable!("got {:?}, should have been remove sender", reg_action), } } _ => unreachable!("got {:?}, should have been registry message", action), } } #[test] fn ignores_bad_id_removals() { let (tx, rx) = channel(); let mut router = Router::new(tx); router.remove_sender(&Uuid::new_v4()); assert_eq!(router.senders.len(), 0, "should have no senders."); rx.recv_timeout(TIMEOUT).unwrap(); } }