use crate::{ field::Field, queue::{Message, MsgType, Queue}, }; use chrono::prelude::*; use std::{ collections::HashMap, sync::mpsc::{channel, Receiver}, thread::spawn, time::Duration, }; use uuid::Uuid; const EXPIRE_IN: Duration = Duration::from_secs(60 * 60); const RESPONDS_TO: [MsgType; 2] = [MsgType::SessionValidate, MsgType::Time]; struct SessionData { expire_on: DateTime, } impl SessionData { fn new() -> Self { Self { expire_on: Utc::now() + EXPIRE_IN, } } fn extend(&mut self) { self.expire_on = Utc::now() + EXPIRE_IN; } fn is_expired(&self, now: &DateTime) -> bool { now > &self.expire_on } } #[cfg(test)] mod sessiondatas { use super::*; #[test] fn create_session_data() { let expire = Utc::now() + EXPIRE_IN; let data = SessionData::new(); assert!( data.expire_on > expire, "{:?} should be greater than {:?}", data.expire_on, expire ); } #[test] fn extend_usage_time() { let mut data = SessionData::new(); let expire = Utc::now() + EXPIRE_IN; data.extend(); assert!( data.expire_on > expire, "{:?} should be greater than {:?}", data.expire_on, expire ); } #[test] fn is_expired() { let data = SessionData::new(); let expire = Utc::now() + EXPIRE_IN; assert!(data.is_expired(&expire), "should be expired"); } #[test] fn is_not_expired() { let expire = Utc::now() + EXPIRE_IN; let data = SessionData::new(); assert!(!data.is_expired(&expire), "should be not expired"); } } pub struct Session { data: HashMap, queue: Queue, rx: Receiver, } impl Session { fn new(queue: Queue, rx: Receiver) -> Self { Self { data: HashMap::new(), queue: queue, rx: rx, } } pub fn start(queue: Queue) { let (tx, rx) = channel(); let mut session = Session::new(queue, rx); session.queue.add(tx, RESPONDS_TO.to_vec()); spawn(move || { session.listen(); }); } fn listen(&mut self) { loop { let msg = self.rx.recv().unwrap(); match msg.get_msg_type() { MsgType::SessionValidate => self.validate(msg), MsgType::Time => self.expire(msg), _ => unreachable!("received unknown message"), }; } } fn validate(&mut self, msg: Message) { match msg.get_data("sess_id") { Some(sid) => match sid { Field::Uuid(sess_id) => match self.data.get_mut(&sess_id) { Some(sess_data) => { sess_data.extend(); let reply = msg.reply_with_data(MsgType::SessionValidated); self.queue.send(reply).unwrap(); } None => self.new_session(msg), }, _ => self.new_session(msg), }, None => self.new_session(msg), } } fn new_session(&mut self, msg: Message) { let mut id = Uuid::new_v4(); while self.data.contains_key(&id) { id = Uuid::new_v4(); } self.data.insert(id.clone(), SessionData::new()); let mut reply = msg.reply_with_data(MsgType::SessionValidated); reply.add_data("sess_id", id); self.queue.send(reply).unwrap(); } fn expire(&mut self, msg: Message) { let now = msg.get_data("time").unwrap().to_datetime().unwrap(); let mut expired: Vec = Vec::new(); for (id, data) in self.data.iter() { if data.is_expired(&now) { expired.push(id.clone()); } } for id in expired.iter() { self.data.remove(id); } } } #[cfg(test)] mod sessions { use super::*; use crate::queue::{Message, MsgType}; use std::{sync::mpsc::channel, time::Duration}; static TIMEOUT: Duration = Duration::from_millis(500); fn setup_session(listen_for: Vec) -> (Queue, Receiver) { let queue = Queue::new(); let (tx, rx) = channel(); queue.add(tx, listen_for); Session::start(queue.clone()); (queue, rx) } fn create_session(queue: &Queue, rx: &Receiver) -> Uuid { let msg = Message::new(MsgType::SessionValidate); queue.send(msg.clone()).unwrap(); let holder = rx.recv_timeout(TIMEOUT).unwrap(); holder.get_data("sess_id").unwrap().to_uuid().unwrap() } #[test] fn get_new_session() { let id = Uuid::new_v4(); let listen_for = [MsgType::SessionValidated]; let (queue, rx) = setup_session(listen_for.to_vec()); let mut msg = Message::new(MsgType::SessionValidate); msg.add_data(id, id); queue.send(msg.clone()).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap(); match result.get_msg_type() { MsgType::SessionValidated => {} _ => unreachable!( "received {:?}, should have been a session", result.get_msg_type() ), } assert_eq!(result.get_id(), msg.get_id()); assert_eq!(result.get_data(id).unwrap().to_uuid().unwrap(), id); } #[test] fn session_id_is_unique() { let listen_for = [MsgType::SessionValidated]; let (queue, rx) = setup_session(listen_for.to_vec()); let msg = Message::new(MsgType::SessionValidate); let mut ids: Vec = Vec::new(); for _ in 0..10 { queue.send(msg.clone()).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap(); let id = result.get_data("sess_id").unwrap().to_uuid().unwrap(); assert!(!ids.contains(&id), "{} is a duplicate id", id); ids.push(id); } } #[test] fn existing_id_is_returned() { let add_data = Uuid::new_v4(); let listen_for = [MsgType::SessionValidated]; let (queue, rx) = setup_session(listen_for.to_vec()); let id = create_session(&queue, &rx); let mut msg = Message::new(MsgType::SessionValidate); msg.add_data("sess_id", id.clone()); msg.add_data(add_data, add_data); queue.send(msg).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap(); let output = result.get_data("sess_id").unwrap().to_uuid().unwrap(); assert_eq!(output, id); assert_eq!( result.get_data(add_data).unwrap().to_uuid().unwrap(), add_data ); } #[test] fn issue_new_if_validated_doe_not_exist() { let id = Uuid::new_v4(); let listen_for = [MsgType::SessionValidated]; let (queue, rx) = setup_session(listen_for.to_vec()); let mut msg = Message::new(MsgType::SessionValidate); msg.add_data("sess_id", id.clone()); queue.send(msg).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap(); let output = result.get_data("sess_id").unwrap().to_uuid().unwrap(); assert_ne!(output, id); } #[test] fn new_for_bad_uuid() { let id = "bad uuid"; let listen_for = [MsgType::SessionValidated]; let (queue, rx) = setup_session(listen_for.to_vec()); let mut msg = Message::new(MsgType::SessionValidate); msg.add_data("sess_id", id); queue.send(msg).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap(); let output = result.get_data("sess_id").unwrap().to_string(); assert_ne!(output, id); } #[test] fn timer_does_nothing_to_unexpired() { let expire = Utc::now() + EXPIRE_IN; let listen_for = [MsgType::SessionValidated]; let (queue, rx) = setup_session(listen_for.to_vec()); let id = create_session(&queue, &rx); let mut time_msg = Message::new(MsgType::Time); time_msg.add_data("time", expire); queue.send(time_msg).unwrap(); let mut validate_msg = Message::new(MsgType::SessionValidate); validate_msg.add_data("sess_id", id.clone()); queue.send(validate_msg).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap(); assert_eq!(result.get_data("sess_id").unwrap().to_uuid().unwrap(), id); } #[test] fn timer_removes_expired() { let listen_for = [MsgType::SessionValidated]; let (queue, rx) = setup_session(listen_for.to_vec()); let id = create_session(&queue, &rx); let expire = Utc::now() + EXPIRE_IN; let mut time_msg = Message::new(MsgType::Time); time_msg.add_data("time", expire); queue.send(time_msg).unwrap(); let mut validate_msg = Message::new(MsgType::SessionValidate); validate_msg.add_data("sess_id", id.clone()); queue.send(validate_msg).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap(); assert_ne!(result.get_data("sess_id").unwrap().to_uuid().unwrap(), id); } #[test] fn validate_extends_session() { let listen_for = [MsgType::SessionValidated]; let (queue, rx) = setup_session(listen_for.to_vec()); let id = create_session(&queue, &rx); let mut validate_msg = Message::new(MsgType::SessionValidate); validate_msg.add_data("sess_id", id.clone()); let expire = Utc::now() + EXPIRE_IN; let mut time_msg = Message::new(MsgType::Time); time_msg.add_data("time", expire); queue.send(validate_msg.clone()).unwrap(); queue.send(time_msg).unwrap(); queue.send(validate_msg).unwrap(); rx.recv_timeout(TIMEOUT).unwrap(); let result = rx.recv_timeout(TIMEOUT).unwrap(); assert_eq!(result.get_data("sess_id").unwrap().to_uuid().unwrap(), id); } }