use crate::{mtterror::MTTError, message::{Action, DocRegistry, Message, RegMsg, Register}, name::NameType}; use uuid::Uuid; use std::{collections::HashMap, sync::{mpsc::{Sender, channel}, Arc, RwLock}}; #[derive(Clone, Debug, Eq, Hash)] pub enum Include { All, Just(T), } impl PartialEq for Include { fn eq(&self, other: &Self) -> bool { match self { Include::All => true, Include::Just(data) => match other { Include::All => true, Include::Just(other_data) => data == other_data, }, } } } #[cfg(test)] mod includes { use super::*; #[test] fn does_all_equal_evberything() { let a: Include = Include::All; let b: Include = Include::Just(5); let c: Include = Include::Just(7); assert!(a == a, "all should equal all"); assert!(a == b, "all should equal just"); assert!(b == a, "just should equal all"); assert!(b == b, "same just should equal"); assert!(b != c, "different justs do not equal"); } } #[derive(Clone, Debug)] pub struct Path { pub msg_id: Include, pub doc: Include, pub action: Include, } impl Path { pub fn new(id: Include, doc: Include, action: Include) -> Self { Self { msg_id: id, doc: doc, action: action, } } } struct Router { doc_registry: Sender, senders: HashMap>, } impl Router { fn new(tx: Sender) -> Self { Self { doc_registry: tx, senders: HashMap::new(), } } fn add_sender(&mut self, sender: Sender) -> Uuid { let mut id = Uuid::new_v4(); while self.senders.contains_key(&id) { id = Uuid::new_v4(); } self.senders.insert(id.clone(), sender); id } fn remove_sender(&mut self, id: &Uuid) { let action = Register::new(Uuid::nil(), RegMsg::RemoveSender(id.clone())); self.doc_registry .send(Message::new(NameType::None, action)) .unwrap(); self.senders.remove(id); } fn forward(&self, id: &Uuid, msg: Message) { if id == &Uuid::nil() { return; } self.senders.get(id).unwrap().send(msg).unwrap(); } fn send(&self, msg: Message) { self.doc_registry.send(msg).unwrap(); } } #[derive(Clone)] pub struct Queue { router: Arc>, } impl Queue { pub fn new() -> Self { let (tx, rx) = channel(); let output = Self { router: Arc::new(RwLock::new(Router::new(tx))), }; DocRegistry::start(output.clone(), rx); output } pub fn add_sender(&mut self, sender: Sender) -> Uuid { let mut router = self.router.write().unwrap(); router.add_sender(sender) } pub fn remove_sender(&mut self, id: &Uuid) { let mut router = self.router.write().unwrap(); router.remove_sender(id); } pub fn forward(&self, id: &Uuid, msg: Message) { let router = self.router.read().unwrap(); router.forward(id, msg); } pub fn send(&self, msg: Message) -> Result<(), MTTError> { let router = self.router.read().unwrap(); router.send(msg.clone()); Ok(()) } } #[cfg(test)] mod routers { use crate::{message::{MsgAction, Query}, name::Name, support_tests::TIMEOUT}; use std::collections::HashSet; use super::*; #[test] fn can_pass_message() { let (tx, rx) = channel(); let router = Router::new(tx); let msg = Message::new(Name::english("task"), Query::new()); router.send(msg.clone()); let result = rx.recv_timeout(TIMEOUT).unwrap(); assert_eq!(result.get_message_id(), msg.get_message_id()); } #[test] fn can_forward_message() { let (tx, _) = channel(); let mut router = Router::new(tx); let (sender, receiver) = channel(); let id = router.add_sender(sender); let msg = Message::new(Name::english("wiki"), Query::new()); router.forward(&id, msg.clone()); let result = receiver.recv_timeout(TIMEOUT).unwrap(); assert_eq!(result.get_message_id(), msg.get_message_id()); } #[test] fn sender_ids_are_unique() { let (tx, _) = channel(); let mut router = Router::new(tx); let count = 10; let mut holder: HashSet = HashSet::new(); for _ in 0..count { let (tx, _) = channel(); holder.insert(router.add_sender(tx)); } assert_eq!(holder.len(), count, "had duplicate keys"); } #[test] fn can_remove_sender() { let (tx, rx) = channel(); let mut router = Router::new(tx); let (data, _) = channel(); let id = router.add_sender(data); assert_eq!(router.senders.len(), 1, "should have only one sender"); router.remove_sender(&id); assert_eq!(router.senders.len(), 0, "should have no senders."); let result = rx.recv_timeout(TIMEOUT).unwrap(); let action = result.get_action(); match action { MsgAction::Register(reg_msg) => { let reg_action = reg_msg.get_msg(); match reg_action { RegMsg::RemoveSender(result) => assert_eq!(result, &id), _ => unreachable!("got {:?}, should have been remove sender", reg_action), } } _ => unreachable!("got {:?}, should have been registry message", action), } } #[test] fn ignores_bad_id_removals() { let (tx, rx) = channel(); let mut router = Router::new(tx); router.remove_sender(&Uuid::new_v4()); assert_eq!(router.senders.len(), 0, "should have no senders."); rx.recv_timeout(TIMEOUT).unwrap(); } }