use crate::{ field::Field, queue::{Message, MsgType, Queue}, }; use chrono::prelude::*; use isolang::Language; use std::{ collections::HashMap, sync::mpsc::{channel, Receiver}, thread::spawn, time::Duration, }; use uuid::Uuid; const EXPIRE_IN: Duration = Duration::from_secs(60 * 60); const RESPONDS_TO: [MsgType; 3] = [MsgType::SessionGet, MsgType::SessionValidate, MsgType::Time]; const DEFAULT_LANG: Language = Language::Eng; struct SessionData { expire_on: DateTime, language: Language, } impl SessionData { fn new(lang: Option) -> Self { let session_lang = match lang { Some(data) => data.clone(), None => DEFAULT_LANG, }; Self { expire_on: Utc::now() + EXPIRE_IN, language: session_lang, } } fn extend(&mut self) { self.expire_on = Utc::now() + EXPIRE_IN; } fn is_expired(&self, now: &DateTime) -> bool { now > &self.expire_on } } #[cfg(test)] mod sessiondatas { use super::*; #[test] fn create_session_data() { let expire = Utc::now() + EXPIRE_IN; let data = SessionData::new(None); assert!( data.expire_on > expire, "{:?} should be greater than {:?}", data.expire_on, expire ); } #[test] fn extend_usage_time() { let mut data = SessionData::new(None); let expire = Utc::now() + EXPIRE_IN; data.extend(); assert!( data.expire_on > expire, "{:?} should be greater than {:?}", data.expire_on, expire ); } #[test] fn is_expired() { let data = SessionData::new(None); let expire = Utc::now() + EXPIRE_IN; assert!(data.is_expired(&expire), "should be expired"); } #[test] fn is_not_expired() { let expire = Utc::now() + EXPIRE_IN; let data = SessionData::new(None); assert!(!data.is_expired(&expire), "should be not expired"); } #[test] fn english_is_the_default_language() { let data = SessionData::new(None); assert_eq!(data.language, DEFAULT_LANG); } #[test] fn assign_language() { let langs = [Language::Jpn, Language::Deu]; for lang in langs.into_iter() { let data = SessionData::new(Some(lang.clone())); assert_eq!(data.language, lang); } } } pub struct Session { data: HashMap, queue: Queue, rx: Receiver, } impl Session { fn new(queue: Queue, rx: Receiver) -> Self { Self { data: HashMap::new(), queue: queue, rx: rx, } } pub fn start(queue: Queue) { let (tx, rx) = channel(); let mut session = Session::new(queue, rx); session.queue.add(tx, RESPONDS_TO.to_vec()); spawn(move || { session.listen(); }); } fn listen(&mut self) { loop { let msg = self.rx.recv().unwrap(); match msg.get_msg_type() { MsgType::SessionGet => self.get(msg), MsgType::SessionValidate => self.validate(msg), MsgType::Time => self.expire(msg), _ => unreachable!("received unknown message"), }; } } fn validate(&mut self, msg: Message) { match msg.get_data("sess_id") { Some(sid) => match sid { Field::Uuid(sess_id) => match self.data.get_mut(&sess_id) { Some(sess_data) => { sess_data.extend(); let reply = msg.reply_with_data(MsgType::SessionValidated); self.queue.send(reply).unwrap(); } None => self.new_session(msg), }, _ => self.new_session(msg), }, None => self.new_session(msg), } } fn new_session(&mut self, msg: Message) { let mut id = Uuid::new_v4(); while self.data.contains_key(&id) { id = Uuid::new_v4(); } let req_lang = match msg.get_data("language") { Some(data) => Some(data.to_language().unwrap().clone()), None => None, }; self.data.insert(id.clone(), SessionData::new(req_lang)); let mut reply = msg.reply_with_data(MsgType::SessionValidated); reply.add_data("sess_id", id); self.queue.send(reply).unwrap(); } fn expire(&mut self, msg: Message) { let now = msg.get_data("time").unwrap().to_datetime().unwrap(); let mut expired: Vec = Vec::new(); for (id, data) in self.data.iter() { if data.is_expired(&now) { expired.push(id.clone()); } } for id in expired.iter() { self.data.remove(id); } } fn get(&self, msg: Message) { let sess_id = msg.get_data("sess_id").unwrap().to_uuid().unwrap(); let sess_data = self.data.get(&sess_id).unwrap(); let mut reply = msg.reply(MsgType::Session); reply.add_data("language", sess_data.language.clone()); self.queue.send(reply); } } #[cfg(test)] pub mod sessions { use super::*; use crate::queue::{Message, MsgType}; use std::{sync::mpsc::channel, time::Duration}; static TIMEOUT: Duration = Duration::from_millis(500); pub fn create_validated_reply(msg: Message) -> Message { let mut reply = msg.reply(MsgType::SessionValidated); reply.add_data("sess_id", Uuid::new_v4()); reply } fn setup_session() -> (Queue, Receiver) { let queue = Queue::new(); let (tx, rx) = channel(); let listen_for = [MsgType::Session, MsgType::SessionValidated].to_vec(); queue.add(tx, listen_for); Session::start(queue.clone()); (queue, rx) } fn create_session(queue: &Queue, rx: &Receiver, lang: Option) -> Uuid { let mut msg = Message::new(MsgType::SessionValidate); match lang { Some(data) => msg.add_data("language", data.clone()), None => {} } queue.send(msg.clone()).unwrap(); let holder = rx.recv_timeout(TIMEOUT).unwrap(); holder.get_data("sess_id").unwrap().to_uuid().unwrap() } #[test] fn get_new_session() { let id = Uuid::new_v4(); let (queue, rx) = setup_session(); let mut msg = Message::new(MsgType::SessionValidate); msg.add_data(id, id); queue.send(msg.clone()).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap(); match result.get_msg_type() { MsgType::SessionValidated => {} _ => unreachable!( "received {:?}, should have been a session", result.get_msg_type() ), } assert_eq!(result.get_id(), msg.get_id()); assert_eq!(result.get_data(id).unwrap().to_uuid().unwrap(), id); } #[test] fn session_id_is_unique() { let (queue, rx) = setup_session(); let msg = Message::new(MsgType::SessionValidate); let mut ids: Vec = Vec::new(); for _ in 0..10 { queue.send(msg.clone()).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap(); let id = result.get_data("sess_id").unwrap().to_uuid().unwrap(); assert!(!ids.contains(&id), "{} is a duplicate id", id); ids.push(id); } } #[test] fn existing_id_is_returned() { let add_data = Uuid::new_v4(); let (queue, rx) = setup_session(); let id = create_session(&queue, &rx, None); let mut msg = Message::new(MsgType::SessionValidate); msg.add_data("sess_id", id.clone()); msg.add_data(add_data, add_data); queue.send(msg).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap(); let output = result.get_data("sess_id").unwrap().to_uuid().unwrap(); assert_eq!(output, id); assert_eq!( result.get_data(add_data).unwrap().to_uuid().unwrap(), add_data ); } #[test] fn issue_new_if_validated_doe_not_exist() { let id = Uuid::new_v4(); let (queue, rx) = setup_session(); let mut msg = Message::new(MsgType::SessionValidate); msg.add_data("sess_id", id.clone()); queue.send(msg).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap(); let output = result.get_data("sess_id").unwrap().to_uuid().unwrap(); assert_ne!(output, id); } #[test] fn new_for_bad_uuid() { let id = "bad uuid"; let (queue, rx) = setup_session(); let mut msg = Message::new(MsgType::SessionValidate); msg.add_data("sess_id", id); queue.send(msg).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap(); let output = result.get_data("sess_id").unwrap().to_string(); assert_ne!(output, id); } #[test] fn timer_does_nothing_to_unexpired() { let expire = Utc::now() + EXPIRE_IN; let (queue, rx) = setup_session(); let id = create_session(&queue, &rx, None); let mut time_msg = Message::new(MsgType::Time); time_msg.add_data("time", expire); queue.send(time_msg).unwrap(); let mut validate_msg = Message::new(MsgType::SessionValidate); validate_msg.add_data("sess_id", id.clone()); queue.send(validate_msg).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap(); assert_eq!(result.get_data("sess_id").unwrap().to_uuid().unwrap(), id); } #[test] fn timer_removes_expired() { let (queue, rx) = setup_session(); let id = create_session(&queue, &rx, None); let expire = Utc::now() + EXPIRE_IN; let mut time_msg = Message::new(MsgType::Time); time_msg.add_data("time", expire); queue.send(time_msg).unwrap(); let mut validate_msg = Message::new(MsgType::SessionValidate); validate_msg.add_data("sess_id", id.clone()); queue.send(validate_msg).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap(); assert_ne!(result.get_data("sess_id").unwrap().to_uuid().unwrap(), id); } #[test] fn validate_extends_session() { let (queue, rx) = setup_session(); let id = create_session(&queue, &rx, None); let mut validate_msg = Message::new(MsgType::SessionValidate); validate_msg.add_data("sess_id", id.clone()); let expire = Utc::now() + EXPIRE_IN; let mut time_msg = Message::new(MsgType::Time); time_msg.add_data("time", expire); queue.send(validate_msg.clone()).unwrap(); queue.send(time_msg).unwrap(); queue.send(validate_msg).unwrap(); rx.recv_timeout(TIMEOUT).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap(); assert_eq!(result.get_data("sess_id").unwrap().to_uuid().unwrap(), id); } #[test] fn get_session_information() { let (queue, rx) = setup_session(); let id = create_session(&queue, &rx, None); let mut msg = Message::new(MsgType::SessionGet); msg.add_data("sess_id", id.clone()); queue.send(msg.clone()); let reply = rx.recv_timeout(TIMEOUT).unwrap(); assert_eq!(reply.get_id(), msg.get_id()); assert_eq!(reply.get_msg_type(), &MsgType::Session); assert_eq!( reply.get_data("language").unwrap().to_language().unwrap(), DEFAULT_LANG ); } #[test] fn get_requested_langaages() { let langs = [Language::Jpn, Language::Deu]; let (queue, rx) = setup_session(); for lang in langs.into_iter() { let id = create_session(&queue, &rx, Some(lang.clone())); let mut msg = Message::new(MsgType::SessionGet); msg.add_data("sess_id", id.clone()); queue.send(msg.clone()); let reply = rx.recv_timeout(TIMEOUT).unwrap(); assert_eq!(reply.get_id(), msg.get_id()); assert_eq!(reply.get_msg_type(), &MsgType::Session); assert_eq!( reply.get_data("language").unwrap().to_language().unwrap(), lang ); } } }