From 3498ab2d7b31f51f74c15c8e74eb27c7516059d4 Mon Sep 17 00:00:00 2001 From: Prabhpreet Dua <615318+prabhpreet@users.noreply.github.com> Date: Sun, 4 Feb 2024 11:39:34 +0530 Subject: [PATCH] Checkpoint --- rosenpass/src/app_server.rs | 79 +++++++++++++-------- rosenpass/tests/integration_test.rs | 104 +++++++++++++++++++++++++++- 2 files changed, 155 insertions(+), 28 deletions(-) diff --git a/rosenpass/src/app_server.rs b/rosenpass/src/app_server.rs index c3d5edd..d0876e7 100644 --- a/rosenpass/src/app_server.rs +++ b/rosenpass/src/app_server.rs @@ -34,9 +34,10 @@ use rosenpass_util::b64::{b64_writer, fmt_b64}; const IPV4_ANY_ADDR: Ipv4Addr = Ipv4Addr::new(0, 0, 0, 0); const IPV6_ANY_ADDR: Ipv6Addr = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0); -// Using values from Linux Kernel implementation -// TODO: Customize values for rosenpass -const MAX_QUEUED_INCOMING_HANDSHAKES_THRESHOLD: usize = 4096; +const NORMAL_OPERATION_THRESHOLD: usize = 5; +const UNDER_LOAD_THRESHOLD: usize = 10; +const RESET_DURATION: Duration = Duration::from_secs(1); + const LAST_UNDER_LOAD_WINDOW: Duration = Duration::from_secs(1); fn ipv4_any_binding() -> SocketAddr { @@ -76,7 +77,7 @@ pub struct WireguardOut { #[derive(Debug)] pub enum DoSOperation { UnderLoad { last_under_load: Instant }, - Normal, + Normal{blocked_polls: usize}, } /// Holds the state of the application, namely the external IO @@ -199,15 +200,17 @@ impl std::fmt::Display for SocketBoundEndpoint { } impl SocketBoundEndpoint { - - const SOCKET_SIZE: usize = usize::BITS as usize/8; + const SOCKET_SIZE: usize = usize::BITS as usize / 8; const IPV6_SIZE: usize = 16; const PORT_SIZE: usize = 2; const SCOPE_ID_SIZE: usize = 4; - const BUFFER_SIZE: usize = SocketBoundEndpoint::SOCKET_SIZE + SocketBoundEndpoint::IPV6_SIZE + SocketBoundEndpoint::PORT_SIZE + SocketBoundEndpoint::SCOPE_ID_SIZE; - pub fn to_bytes(&self) -> (usize,[u8; SocketBoundEndpoint::BUFFER_SIZE]) { - let mut buf = [0u8;SocketBoundEndpoint::BUFFER_SIZE]; + const BUFFER_SIZE: usize = SocketBoundEndpoint::SOCKET_SIZE + + SocketBoundEndpoint::IPV6_SIZE + + SocketBoundEndpoint::PORT_SIZE + + SocketBoundEndpoint::SCOPE_ID_SIZE; + pub fn to_bytes(&self) -> (usize, [u8; SocketBoundEndpoint::BUFFER_SIZE]) { + let mut buf = [0u8; SocketBoundEndpoint::BUFFER_SIZE]; let addr = match self.addr { SocketAddr::V4(addr) => { //Map IPv4-mapped to IPv6 addresses @@ -217,15 +220,17 @@ impl SocketBoundEndpoint { SocketAddr::V6(addr) => addr, }; let mut len: usize = 0; - buf[len..len+SocketBoundEndpoint::SOCKET_SIZE].copy_from_slice(&self.socket.0.to_be_bytes()); + buf[len..len + SocketBoundEndpoint::SOCKET_SIZE] + .copy_from_slice(&self.socket.0.to_be_bytes()); len += SocketBoundEndpoint::SOCKET_SIZE; - buf[len..len+SocketBoundEndpoint::IPV6_SIZE].copy_from_slice(&addr.ip().octets()); + buf[len..len + SocketBoundEndpoint::IPV6_SIZE].copy_from_slice(&addr.ip().octets()); len += SocketBoundEndpoint::IPV6_SIZE; - buf[len..len+SocketBoundEndpoint::PORT_SIZE].copy_from_slice(&addr.port().to_be_bytes()); + buf[len..len + SocketBoundEndpoint::PORT_SIZE].copy_from_slice(&addr.port().to_be_bytes()); len += SocketBoundEndpoint::PORT_SIZE; - buf[len..len+SocketBoundEndpoint::SCOPE_ID_SIZE].copy_from_slice(&addr.scope_id().to_be_bytes()); + buf[len..len + SocketBoundEndpoint::SCOPE_ID_SIZE] + .copy_from_slice(&addr.scope_id().to_be_bytes()); len += SocketBoundEndpoint::SCOPE_ID_SIZE; - (len,buf) + (len, buf) } } @@ -398,7 +403,7 @@ impl AppServer { ) -> anyhow::Result { // setup mio let mio_poll = mio::Poll::new()?; - let events = mio::Events::with_capacity(8); + let events = mio::Events::with_capacity(20); // bind each SocketAddr to a socket let maybe_sockets: Result, _> = @@ -488,7 +493,7 @@ impl AppServer { events, mio_poll, all_sockets_drained: false, - under_load: DoSOperation::Normal, + under_load: DoSOperation::Normal{ blocked_polls: 0}, }) } @@ -605,9 +610,13 @@ impl AppServer { ReceivedMessage(len, endpoint) => { let msg_result = match self.under_load { DoSOperation::UnderLoad { last_under_load: _ } => { + println!("Processing msg under load"); self.handle_msg_under_load(&endpoint, &rx[..len], &mut *tx) } - DoSOperation::Normal => self.crypt.handle_msg(&rx[..len], &mut *tx), + DoSOperation::Normal { blocked_polls: _} => { + println!("Processing msg normally"); + self.crypt.handle_msg(&rx[..len], &mut *tx) + } }; match msg_result { Err(ref e) => { @@ -652,7 +661,7 @@ impl AppServer { ) -> Result { match endpoint { Endpoint::SocketBoundAddress(socket) => { - let (hi_len,host_identification )= socket.to_bytes(); + let (hi_len, host_identification) = socket.to_bytes(); self.crypt .handle_msg_under_load(&rx, &mut *tx, &host_identification[0..hi_len]) } @@ -784,20 +793,36 @@ impl AppServer { // only poll if we drained all sockets before if self.all_sockets_drained { - self.mio_poll.poll(&mut self.events, Some(timeout))?; + //Non blocked polling + self.mio_poll.poll(&mut self.events, Some(Duration::from_secs(0)))?; - let queue_length = self.events.iter().peekable().count(); - - if queue_length > MAX_QUEUED_INCOMING_HANDSHAKES_THRESHOLD { - self.under_load = DoSOperation::UnderLoad { - last_under_load: Instant::now(), + if self.events.iter().peekable().peek().is_none() { + // if there are no events, then we can just return + match self.under_load { + DoSOperation::Normal { blocked_polls } => { + self.under_load = DoSOperation::Normal { + blocked_polls: blocked_polls + 1, + } + } + _ => {} } + + self.mio_poll.poll(&mut self.events, Some(timeout))?; } } - if let DoSOperation::UnderLoad { last_under_load } = self.under_load { - if last_under_load.elapsed() > LAST_UNDER_LOAD_WINDOW { - self.under_load = DoSOperation::Normal; + match self.under_load { + DoSOperation::Normal { blocked_polls } => { + if blocked_polls > NORMAL_OPERATION_THRESHOLD { + self.under_load = DoSOperation::UnderLoad { + last_under_load: Instant::now(), + } + } + } + DoSOperation::UnderLoad { last_under_load } => { + if last_under_load.elapsed() > RESET_DURATION { + self.under_load = DoSOperation::Normal { blocked_polls: 0 }; + } } } diff --git a/rosenpass/tests/integration_test.rs b/rosenpass/tests/integration_test.rs index e786df3..703594b 100644 --- a/rosenpass/tests/integration_test.rs +++ b/rosenpass/tests/integration_test.rs @@ -39,7 +39,7 @@ fn find_udp_socket() -> u16 { // check that we can exchange keys #[test] -fn check_exchange() { +fn check_exchange_under_normal() { let tmpdir = PathBuf::from(env!("CARGO_TARGET_TMPDIR")).join("exchange"); fs::create_dir_all(&tmpdir).unwrap(); @@ -117,3 +117,105 @@ fn check_exchange() { // cleanup fs::remove_dir_all(&tmpdir).unwrap(); } + +// check that we can exchange keys +#[test] +fn check_exchange_under_dos() { + let tmpdir = PathBuf::from(env!("CARGO_TARGET_TMPDIR")).join("exchange-dos"); + fs::create_dir_all(&tmpdir).unwrap(); + + let secret_key_paths = [tmpdir.join("secret-key-0"), tmpdir.join("secret-key-1")]; + let public_key_paths = [tmpdir.join("public-key-0"), tmpdir.join("public-key-1")]; + let shared_key_paths = [tmpdir.join("shared-key-0"), tmpdir.join("shared-key-1")]; + + // generate key pairs + for (secret_key_path, pub_key_path) in secret_key_paths.iter().zip(public_key_paths.iter()) { + let output = test_bin::get_test_bin(BIN) + .args(["gen-keys", "--secret-key"]) + .arg(secret_key_path) + .arg("--public-key") + .arg(pub_key_path) + .output() + .expect("Failed to start {BIN}"); + + assert_eq!(String::from_utf8_lossy(&output.stdout), ""); + assert!(secret_key_path.is_file()); + assert!(pub_key_path.is_file()); + } + + // start first process, the server + let port = find_udp_socket(); + let listen_addr = format!("localhost:{port}"); + let mut server = test_bin::get_test_bin(BIN) + .args(["exchange", "secret-key"]) + .arg(&secret_key_paths[0]) + .arg("public-key") + .arg(&public_key_paths[0]) + .args(["listen", &listen_addr, "verbose", "peer", "public-key"]) + .arg(&public_key_paths[1]) + .arg("outfile") + .arg(&shared_key_paths[0]) + //.stdout(Stdio::null()) + //.stderr(Stdio::null()) + .spawn() + .expect("Failed to start {BIN}"); + + std::thread::sleep(Duration::from_millis(500)); + + //DoS Sender + //Create a UDP socket + let socket = UdpSocket::bind("127.0.0.1:0").expect("couldn't bind to address"); + //Spawn a thread to send DoS packets + let server_addr = listen_addr.clone(); + + //Create thread safe atomic bool to stop the DoS attack + let stop_dos = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); + let stop_dos_handle = stop_dos.clone(); + + let dos_attack = std::thread::spawn(move || { + while stop_dos.load(std::sync::atomic::Ordering::Relaxed) == false { + let buf = [0; 10]; + socket + .send_to(&buf, &server_addr) + .expect("couldn't send data"); + } + }); + + // start second process, the client + let mut client = test_bin::get_test_bin(BIN) + .args(["exchange", "secret-key"]) + .arg(&secret_key_paths[1]) + .arg("public-key") + .arg(&public_key_paths[1]) + .args(["verbose", "peer", "public-key"]) + .arg(&public_key_paths[0]) + .args(["endpoint", &listen_addr]) + .arg("outfile") + .arg(&shared_key_paths[1]) + //.stdout(Stdio::null()) + //.stderr(Stdio::null()) + .spawn() + .expect("Failed to start {BIN}"); + + // give them some time to do the key exchange + std::thread::sleep(Duration::from_secs(2)); + + // time's up, kill the childs + server.kill().unwrap(); + client.kill().unwrap(); + stop_dos_handle.store(true, std::sync::atomic::Ordering::Relaxed); + dos_attack.join().unwrap(); + + // read the shared keys they created + let shared_keys: Vec<_> = shared_key_paths + .iter() + .map(|p| fs::read_to_string(p).unwrap()) + .collect(); + + // check that they created two equal keys + assert_eq!(shared_keys.len(), 2); + assert_eq!(shared_keys[0], shared_keys[1]); + + // cleanup + fs::remove_dir_all(&tmpdir).unwrap(); +}