From 5df672398bf135710e47db946d118b9e9bfcf12c Mon Sep 17 00:00:00 2001 From: william Date: Mon, 3 Apr 2023 23:30:00 -0400 Subject: [PATCH] Messages --- client/src/main.rs | 17 ++- messages/Cargo.toml | 4 - messages/src/client_registration.rs | 43 ++++--- messages/src/lib.rs | 18 +-- messages/src/serialization.rs | 86 +++++++++++++ server/src/main.rs | 61 ++++++++- server/src/{ => net}/epoll.rs | 46 ++++--- server/src/net/mod.rs | 3 + server/src/net/tcp_server.rs | 192 ++++++++++++++++++++++++++++ server/src/tcp_server.rs | 103 --------------- 10 files changed, 401 insertions(+), 172 deletions(-) create mode 100644 messages/src/serialization.rs rename server/src/{ => net}/epoll.rs (68%) create mode 100644 server/src/net/mod.rs create mode 100644 server/src/net/tcp_server.rs delete mode 100644 server/src/tcp_server.rs diff --git a/client/src/main.rs b/client/src/main.rs index af27585..815b812 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -1,9 +1,11 @@ use std::io; use std::io::Write; use std::net::{SocketAddr, TcpStream}; +use std::thread::sleep; +use std::time::Duration; -use messages::{any_as_u8_slice, Serializable}; use messages::client_registration::ClientRegistration; +use messages::serialization::SerializeMessage; fn main() -> io::Result<()> { let addr = SocketAddr::from(([127, 0, 0, 1], 4433)); @@ -15,8 +17,17 @@ fn main() -> io::Result<()> { name: "My new client :)".to_string(), }; - let mut buf = [0u8; 1024]; - registration.serialize(&mut buf); + let buf = registration.serialize(); + + stream.write(&buf)?; + stream.flush().unwrap(); + + sleep(Duration::from_secs(2)); + + stream.write(&buf)?; + stream.flush().unwrap(); + + sleep(Duration::from_secs(2)); stream.write(&buf)?; diff --git a/messages/Cargo.toml b/messages/Cargo.toml index daa857d..85d014f 100644 --- a/messages/Cargo.toml +++ b/messages/Cargo.toml @@ -2,7 +2,3 @@ name = "messages" version = "0.1.0" edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] diff --git a/messages/src/client_registration.rs b/messages/src/client_registration.rs index 4e9e2ee..afaf933 100644 --- a/messages/src/client_registration.rs +++ b/messages/src/client_registration.rs @@ -1,6 +1,6 @@ -use crate::{DeserializationError, Serializable}; +use crate::serialization::{deserialize_str_at_end, DeserializeMessage, MessageType, serialize_message, SerializeMessage}; -const CR_BUFFER_MIN_CAPACITY: usize = 3; +const CR_STACK_SIZE: usize = 2; // Contains the version using semantic versioning // The patch version is omitted, because it should not affect the communication between the server and the client @@ -11,27 +11,28 @@ pub struct ClientRegistration { pub name: String, } -impl Serializable for ClientRegistration { - fn serialize(&self, buf: &mut [u8]) { - let mut vec_buf: Vec = Vec::with_capacity(buf.len()); - vec_buf.insert(0, self.major_version); - vec_buf.insert(1, self.minor_version); - - let name_bytes = self.name.as_bytes(); - vec_buf.extend_from_slice(name_bytes); - - buf[..vec_buf.len()].copy_from_slice(&vec_buf); +impl SerializeMessage for ClientRegistration { + fn get_message_type() -> MessageType { + MessageType::NewConnection } - fn deserialize(buf: &[u8]) -> Result { - if buf.len() < CR_BUFFER_MIN_CAPACITY { - return Err(DeserializationError::MissingData); - } + fn serialize(&self) -> Vec { + let mut buf: Vec = Vec::with_capacity(CR_STACK_SIZE); - Ok(ClientRegistration { - major_version: buf[0], - minor_version: buf[1], - name: String::from_utf8_lossy(&buf[2..]).into_owned(), - }) + buf.insert(0, self.major_version); + buf.insert(1, self.minor_version); + buf.extend_from_slice(self.name.as_bytes()); + + return serialize_message::(&buf); + } +} + +impl DeserializeMessage for ClientRegistration { + fn deserialize(buf: &[u8]) -> Self { + ClientRegistration { + major_version: buf[0], + minor_version: buf[1], + name: deserialize_str_at_end(&buf[2..]), + } } } diff --git a/messages/src/lib.rs b/messages/src/lib.rs index c0873f4..6783f9a 100644 --- a/messages/src/lib.rs +++ b/messages/src/lib.rs @@ -1,18 +1,2 @@ pub mod client_registration; - -pub trait Serializable where Self: Sized { - fn serialize(&self, buf: &mut [u8]); - fn deserialize(buf: &[u8]) -> Result; -} - -// From: https://stackoverflow.com/questions/28127165/how-to-convert-struct-to-u8 -pub unsafe fn any_as_u8_slice(p: &T) -> &[u8] { - ::core::slice::from_raw_parts( - (p as *const T) as *const u8, - ::core::mem::size_of::(), - ) -} - -pub enum DeserializationError { - MissingData -} +pub mod serialization; diff --git a/messages/src/serialization.rs b/messages/src/serialization.rs new file mode 100644 index 0000000..29f7272 --- /dev/null +++ b/messages/src/serialization.rs @@ -0,0 +1,86 @@ +const MESSAGE_METADATA_LENGTH: usize = 5; + +pub trait SerializeMessage { + fn get_message_type() -> MessageType; + fn serialize(&self) -> Vec; +} + +pub trait DeserializeMessage where Self: Sized { + fn deserialize(buf: &[u8]) -> Self; +} + +#[repr(u8)] +#[derive(Debug)] +pub enum MessageType { + NewConnection = 1 +} + +impl MessageType { + fn from(type_id: u8) -> Option { + match type_id { + 1 => Some(MessageType::NewConnection), + _ => None + } + } +} + +#[derive(Debug)] +pub enum DeserializationError { + InvalidDataLength(usize, usize), + UnexpectedMessageType(u8), +} + +pub fn serialize_message(data: &[u8]) -> Vec { + let mut buf = Vec::with_capacity(MESSAGE_METADATA_LENGTH); + + serialize_message_metadata::(&mut buf, data.len()); + buf.extend_from_slice(data); + + buf +} + +fn serialize_message_metadata(buf: &mut Vec, data_length: usize) { + let message_type = T::get_message_type() as u8; + + buf.insert(0, message_type); + buf.insert(1, (data_length >> 24) as u8); + buf.insert(2, (data_length >> 16) as u8); + buf.insert(3, (data_length >> 8) as u8); + buf.insert(4, data_length as u8); +} + +pub fn read_message_data(data: &[u8]) -> Result<(MessageType, &[u8]), DeserializationError> { + let message_type_id = data[0]; + let message_type = MessageType::from(message_type_id); + if message_type.is_none() { + return Err(DeserializationError::UnexpectedMessageType(message_type_id)); + } + + let data_length = deserialize_u32(&data[1..=MESSAGE_METADATA_LENGTH]) as usize; + if data.len() < MESSAGE_METADATA_LENGTH + data_length { + return Err(DeserializationError::InvalidDataLength(data_length, data.len())); + } + + let data = &data[MESSAGE_METADATA_LENGTH..data_length]; + + Ok((message_type.unwrap(), data)) +} + +pub fn deserialize_u32(data: &[u8]) -> u32 { + let mut val = 0u32; + for i in 0..4 { + let offset = (3 - i) * 8; + val += (data[i] as u32) << offset; + } + val +} + +/// Deserialize a string at the end of a buffer, removing empty bytes at the end. +pub fn deserialize_str_at_end(buf: &[u8]) -> String { + let mut str_end = buf.len() - 1; + while buf[str_end] == 0 { + str_end -= 1; + } + + String::from_utf8_lossy(&buf[..=str_end]).into_owned() +} diff --git a/server/src/main.rs b/server/src/main.rs index 670d717..264531d 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,17 +1,68 @@ +use std::collections::VecDeque; use std::io; -use std::net::{SocketAddr, TcpStream}; +use std::net::SocketAddr; -use crate::tcp_server::TcpServer; +use messages::client_registration::ClientRegistration; +use messages::serialization::{DeserializeMessage, read_message_data, SerializeMessage}; -mod epoll; -mod tcp_server; +use crate::net::tcp_server::{NextIntent, TcpClient, TcpServer}; + +mod net; mod client; fn main() -> io::Result<()> { let addr = SocketAddr::from(([127, 0, 0, 1], 4433)); - let mut server = TcpServer::new(addr)?; + let mut server: TcpServer = TcpServer::new(addr)?; server.listen()?; Ok(()) } + +static mut AUTO_ID: u8 = 0; + +struct KvmClient { + id: u8, + write_buffer: VecDeque>, +} + +impl TcpClient for KvmClient { + fn new() -> Self { + let id = unsafe { + AUTO_ID += 1; + AUTO_ID + }; + + println!("New client ({id})"); + + KvmClient { + id, + write_buffer: VecDeque::new(), + } + } + + fn read(&self, buf: &[u8]) -> NextIntent { + let (message_type, data) = read_message_data(buf).expect("Received invalid message"); + + println!("Received {:?} message with {} bytes", message_type, data.len()); + + NextIntent::Read + } + + fn get_next_write(&mut self) -> (Vec, NextIntent) { + let data = self.write_buffer.pop_front().unwrap(); + (data, NextIntent::Read) + } + + fn close(&self) { + println!("Client {} disconnected", self.id); + } +} + +impl KvmClient { + pub fn write(&mut self, data: &T) + where T: SerializeMessage { + let buf = data.serialize(); + self.write_buffer.push_back(buf); + } +} diff --git a/server/src/epoll.rs b/server/src/net/epoll.rs similarity index 68% rename from server/src/epoll.rs rename to server/src/net/epoll.rs index ff22f86..c7ff9b2 100644 --- a/server/src/epoll.rs +++ b/server/src/net/epoll.rs @@ -1,19 +1,24 @@ use std::io; use std::os::fd::RawFd; -use libc::epoll_event; +use libc::{epoll_event}; + +const EVENTS_CAPACITY: usize = 1024; +const WAIT_MAX_EVENTS: i32 = 1024; +const WAIT_TIMEOUT: i32 = 1000; pub struct Epoll { fd: RawFd, pub events: Vec, } -const READ_FLAGS: i32 = libc::EPOLLONESHOT | libc::EPOLLIN; -const WRITE_FLAGS: i32 = libc::EPOLLONESHOT | libc::EPOLLOUT; - -const EVENTS_CAPACITY: usize = 1024; -const WAIT_MAX_EVENTS: i32 = 1024; -const WAIT_TIMEOUT: i32 = 1000; +#[derive(Copy, Clone)] +#[repr(i32)] +pub enum EpollEvent { + Read = libc::EPOLLIN, + Write = libc::EPOLLOUT, + Disconnect = libc::EPOLLRDHUP, +} impl Epoll { pub fn create() -> io::Result { @@ -26,12 +31,12 @@ impl Epoll { } } - pub fn add_read_interest(&self, fd: RawFd, key: u16) -> io::Result<()> { - add_interest(self.fd, fd, listener_read_event(key)) + pub fn add_interest(&self, fd: RawFd, key: u16, events: &[EpollEvent]) -> io::Result<()> { + add_interest(self.fd, fd, create_oneshot_epoll_event(key, events)) } - pub fn modify_read_interest(&self, fd: RawFd, key: u16) -> io::Result<()> { - modify_interest(self.fd, fd, listener_read_event(key)) + pub fn modify_interest(&self, fd: RawFd, key: u16, events: &[EpollEvent]) -> io::Result<()> { + modify_interest(self.fd, fd, create_oneshot_epoll_event(key, events)) } pub fn wait(&mut self) -> io::Result<()> { @@ -47,12 +52,9 @@ impl Epoll { } } -pub fn is_read_event(event: u32) -> bool { - event as i32 & libc::EPOLLIN == libc::EPOLLIN -} - -pub fn is_write_event(event: u32) -> bool { - event as i32 & libc::EPOLLOUT == libc::EPOLLOUT +pub fn match_epoll_event(event: u32, expected_event: EpollEvent) -> bool { + let expected_event = expected_event as i32; + event as i32 & expected_event == expected_event } macro_rules! syscall { @@ -97,9 +99,15 @@ fn modify_interest(epoll_fd: RawFd, fd: RawFd, mut event: epoll_event) -> io::Re Ok(()) } -fn listener_read_event(key: u16) -> epoll_event { +fn create_oneshot_epoll_event(key: u16, events: &[EpollEvent]) -> epoll_event { epoll_event { - events: READ_FLAGS as u32, + events: get_oneshot_events_flag(events), u64: key as u64, } } + +fn get_oneshot_events_flag(events: &[EpollEvent]) -> u32 { + let mut flag: i32 = libc::EPOLLONESHOT; + events.into_iter().for_each(|e| flag = flag | *e as i32); + flag as u32 +} diff --git a/server/src/net/mod.rs b/server/src/net/mod.rs new file mode 100644 index 0000000..07adc27 --- /dev/null +++ b/server/src/net/mod.rs @@ -0,0 +1,3 @@ +pub mod tcp_server; + +mod epoll; diff --git a/server/src/net/tcp_server.rs b/server/src/net/tcp_server.rs new file mode 100644 index 0000000..c39c262 --- /dev/null +++ b/server/src/net/tcp_server.rs @@ -0,0 +1,192 @@ +use std::collections::HashMap; +use std::io; +use std::io::{Read, Write}; +use std::net::{SocketAddr, TcpListener, TcpStream}; +use std::os::fd::{AsRawFd, RawFd}; + +use crate::net::epoll::{Epoll, EpollEvent, match_epoll_event}; + +// Based on: https://www.zupzup.org/epoll-with-rust/index.html + +const KEY_NEW_CONNECTION: u16 = 0; + +pub trait TcpClient { + fn new() -> Self; + fn read(&self, buf: &[u8]) -> NextIntent; + fn get_next_write(&mut self) -> (Vec, NextIntent); + fn close(&self); +} + +pub enum NextIntent { + Read, + Write, +} + +pub struct TcpServer + where T: TcpClient { + addr: SocketAddr, + listener: TcpListener, + listener_fd: RawFd, + epoll: Epoll, + key: u16, + request_contexts: HashMap>, +} + +struct TcpContext + where T: TcpClient { + key: u16, + stream: TcpStream, + client: T, +} + +impl TcpServer + where T: TcpClient { + pub fn new(addr: SocketAddr) -> io::Result { + let listener = TcpListener::bind(addr)?; + listener.set_nonblocking(true)?; + + let listener_fd = listener.as_raw_fd(); + + let epoll = Epoll::create()?; + epoll.add_interest(listener_fd, KEY_NEW_CONNECTION, &[EpollEvent::Read])?; + + Ok(TcpServer { + addr, + listener, + listener_fd, + epoll, + key: 0, + request_contexts: HashMap::new(), + }) + } + + pub fn listen(&mut self) -> io::Result<()> { + println!("Listening on {}", self.addr); + + loop { + self.epoll.wait().expect("Failed to wait for epoll event"); + + let events = &self.epoll.events.iter() + .map(|event| (event.events, event.u64)) + .collect::>(); + + let mut to_remove = Vec::new(); + + for (events, u64) in events { + match *u64 as u16 { + KEY_NEW_CONNECTION => self.accept_connection()?, + key => { + if let Some(context) = self.request_contexts.get_mut(&key) { + if handle_event(&self.epoll, context, *events) { + context.client.close(); + to_remove.push(context.key); + } + } + } + } + } + + for key_to_remove in to_remove { + self.request_contexts.remove(&key_to_remove); + } + } + } + + fn accept_connection(&mut self) -> io::Result<()> { + match self.listener.accept() { + Ok((stream, addr)) => { + stream.set_nonblocking(true)?; + println!("Accepted connection from {addr}"); + + let client = T::new(); + let key = self.get_next_key(); + let fd = stream.as_raw_fd(); + + let context = TcpContext { key, stream, client }; + + self.epoll.add_interest(fd, key, &[EpollEvent::Read, EpollEvent::Disconnect])?; + self.request_contexts.insert(key, context); + } + Err(e) => eprintln!("Couldn't accept: {e}") + }; + + self.epoll.modify_interest(self.listener_fd, KEY_NEW_CONNECTION, &[EpollEvent::Read]) + } + + fn get_next_key(&mut self) -> u16 { + self.key += 1; + self.key + } +} + +fn handle_event(epoll: &Epoll, context: &mut TcpContext, event: u32) -> bool + where T: TcpClient { + match event { + v if match_epoll_event(v, EpollEvent::Read) => { + println!("Read"); + return handle_read_event(epoll, context); + } + v if match_epoll_event(v, EpollEvent::Write) => { + println!("Write"); + return handle_write_event(epoll, context); + } + v if match_epoll_event(v, EpollEvent::Disconnect) => { + println!("Disconnect"); + return true; + } + v => println!("Unexpected event: {v}"), + }; + + false +} + +fn handle_read_event(epoll: &Epoll, context: &mut TcpContext) -> bool + where T: TcpClient { + let mut data: Vec = Vec::new(); + let mut buf = [0u8; 2048]; + + let read_length = context.stream.read(&mut buf).expect("Failed to read stream"); + if read_length == 0 { + return true; + } + + let next_interest = context.client.read(&buf); + set_interest(epoll, context, &next_interest); + + false +} + +fn handle_write_event(epoll: &Epoll, context: &mut TcpContext) -> bool + where T: TcpClient { + let (data, next_interest) = context.client.get_next_write(); + + let data = trim_end(&data); + context.stream.write(data).expect("Failed to write to stream"); + + set_interest(epoll, context, &next_interest); + false +} + + +fn set_interest(epoll: &Epoll, context: &TcpContext, next_intent: &NextIntent) + where T: TcpClient { + let event = match next_intent { + NextIntent::Read => EpollEvent::Read, + NextIntent::Write => EpollEvent::Write, + }; + + epoll.modify_interest( + context.stream.as_raw_fd(), + context.key, + &[event, EpollEvent::Disconnect]) + .unwrap(); +} + +fn trim_end(buf: &[u8]) -> &[u8] { + let mut end = buf.len() - 1; + while buf[end] == 0 { + end -= 1; + } + + &buf[..=end] +} diff --git a/server/src/tcp_server.rs b/server/src/tcp_server.rs deleted file mode 100644 index aadda79..0000000 --- a/server/src/tcp_server.rs +++ /dev/null @@ -1,103 +0,0 @@ -use std::collections::HashMap; -use std::io; -use std::io::Read; -use std::net::{SocketAddr, TcpListener}; -use std::os::fd::{AsRawFd, RawFd}; -use messages::client_registration::ClientRegistration; -use messages::Serializable; -use crate::client::Client; - -use crate::epoll::{Epoll, is_read_event, is_write_event}; - -// Based on: https://www.zupzup.org/epoll-with-rust/index.html - -const KEY_NEW_CONNECTION: u16 = u16::MAX; - -pub struct TcpServer { - addr: SocketAddr, - listener: TcpListener, - listener_fd: RawFd, - epoll: Epoll, - request_contexts: HashMap, -} - -impl TcpServer { - pub fn new(addr: SocketAddr) -> io::Result { - let listener = TcpListener::bind(addr)?; - listener.set_nonblocking(true)?; - - let listener_fd = listener.as_raw_fd(); - - let epoll = Epoll::create()?; - epoll.add_read_interest(listener_fd, KEY_NEW_CONNECTION)?; - - Ok(TcpServer { - addr, - listener, - listener_fd, - epoll, - request_contexts: HashMap::new(), - }) - } - - pub fn listen(&mut self) -> io::Result<()> { - println!("Listening on {}", self.addr); - - loop { - self.epoll.wait().expect("Failed to wait for epoll event"); - - let events = &self.epoll.events.iter() - .map(|event| (event.events, event.u64)) - .collect::>(); - - for (events, u64) in events { - match *u64 as u16 { - KEY_NEW_CONNECTION => self.accept_connection()?, - key => self.handle_event(*events, key as u8) - } - } - } - } - - fn accept_connection(&mut self) -> io::Result<()> { - match self.listener.accept() { - Ok((stream, addr)) => { - stream.set_nonblocking(true)?; - - let fd = stream.as_raw_fd(); - let client = Client::new(stream); - println!("New client: {addr} ({})", client.id); - - self.epoll.add_read_interest(fd, client.id as u16)?; - self.request_contexts.insert(client.id, client); - } - Err(e) => eprintln!("Couldn't accept: {e}") - }; - - self.epoll.modify_read_interest(self.listener_fd, KEY_NEW_CONNECTION) - } - - fn handle_event(&mut self, events: u32, key: u8) { - let mut to_delete = None; - if let Some(client) = self.request_contexts.get_mut(&key) { - match events { - v if is_read_event(v) => { - let mut buf = [0u8; 1024]; - let read_length = client.stream.read(&mut buf).expect("Failed to read stream"); - let registration = ClientRegistration::deserialize(&buf[..read_length]); - - println!("Test"); - } - v if is_write_event(v) => { - println!("Write Event"); - // context.write_cb(key, epoll_fd)?; - to_delete = Some(key); - } - v => println!("Unexpected event: {v}"), - } - } - if let Some(key) = to_delete { - self.request_contexts.remove(&key); - } - } -}