diff --git a/rosenpass/src/app_server.rs b/rosenpass/src/app_server.rs index d33dc5a..5b0b998 100644 --- a/rosenpass/src/app_server.rs +++ b/rosenpass/src/app_server.rs @@ -1,6 +1,7 @@ use anyhow::bail; use anyhow::Result; +use clap::builder; use derive_builder::Builder; use log::{debug, error, info, warn}; use mio::Interest; @@ -17,6 +18,7 @@ use std::net::SocketAddr; use std::net::SocketAddrV4; use std::net::SocketAddrV6; use std::net::ToSocketAddrs; +use std::path::Display; use std::path::PathBuf; use std::process::Command; use std::process::Stdio; @@ -25,6 +27,7 @@ use std::thread; use std::time::Duration; use std::time::Instant; +use crate::protocol::HostIdentification; use crate::{ config::Verbosity, protocol::{CryptoServer, MsgBuf, PeerPtr, SPk, SSk, SymKey, Timing}, @@ -82,9 +85,11 @@ pub enum DoSOperation { #[builder(pattern = "owned")] pub struct AppServerTest { /// Enable DoS operation permanently + #[builder(default = "false")] pub enable_dos_permanently: bool, /// Terminate application signal - pub terminate: Option>, + #[builder(default = "None")] + pub termination_handler: Option>, } /// Holds the state of the application, namely the external IO @@ -196,13 +201,24 @@ pub enum Endpoint { Discovery(HostPathDiscoveryEndpoint), } +impl std::fmt::Display for Endpoint { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Endpoint::SocketBoundAddress(host) => write!(f, "{}", host), + Endpoint::Discovery(host) => write!(f, "{}", host), + } + } +} + #[derive(Debug)] pub struct SocketBoundEndpoint { /// The socket the address can be reached under; this is generally /// determined when we actually receive an RespHello message - pub socket: SocketPtr, + socket: SocketPtr, /// Just the address - pub addr: SocketAddr, + addr: SocketAddr, + /// identifier + bytes: (usize,[u8;SocketBoundEndpoint::BUFFER_SIZE]) } impl std::fmt::Display for SocketBoundEndpoint { @@ -221,19 +237,29 @@ impl SocketBoundEndpoint { + SocketBoundEndpoint::IPV6_SIZE + SocketBoundEndpoint::PORT_SIZE + SocketBoundEndpoint::SCOPE_ID_SIZE; - pub fn to_bytes(&self) -> (usize, [u8; SocketBoundEndpoint::BUFFER_SIZE]) { + + pub fn new(socket: SocketPtr, addr: SocketAddr) -> Self { + let bytes = Self::to_bytes(&socket, &addr); + Self { + socket, + addr, + bytes, + } + } + + fn to_bytes(socket: &SocketPtr, addr: &SocketAddr) -> (usize, [u8; SocketBoundEndpoint::BUFFER_SIZE]) { let mut buf = [0u8; SocketBoundEndpoint::BUFFER_SIZE]; - let addr = match self.addr { + let addr = match addr { SocketAddr::V4(addr) => { //Map IPv4-mapped to IPv6 addresses let ip = addr.ip().to_ipv6_mapped(); SocketAddrV6::new(ip, addr.port(), 0, 0) } - SocketAddr::V6(addr) => addr, + SocketAddr::V6(addr) => addr.clone(), }; let mut len: usize = 0; buf[len..len + SocketBoundEndpoint::SOCKET_SIZE] - .copy_from_slice(&self.socket.0.to_be_bytes()); + .copy_from_slice(&socket.0.to_be_bytes()); len += SocketBoundEndpoint::SOCKET_SIZE; buf[len..len + SocketBoundEndpoint::IPV6_SIZE].copy_from_slice(&addr.ip().octets()); len += SocketBoundEndpoint::IPV6_SIZE; @@ -246,6 +272,12 @@ impl SocketBoundEndpoint { } } +impl HostIdentification for SocketBoundEndpoint { + fn encode(&self) -> &[u8] { + &self.bytes.1[0..self.bytes.0] + } +} + impl Endpoint { /// Start discovery from some addresses pub fn discovery_from_addresses(addresses: Vec) -> Self { @@ -332,6 +364,12 @@ pub struct HostPathDiscoveryEndpoint { addresses: Vec, } +impl std::fmt::Display for HostPathDiscoveryEndpoint { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.addresses) + } +} + impl HostPathDiscoveryEndpoint { pub fn from_addresses(addresses: Vec) -> Self { let scouting_state = Cell::new((0, 0)); @@ -604,7 +642,7 @@ impl AppServer { use KeyOutputReason::*; if let Some(AppServerTest { - terminate: Some(terminate), + termination_handler: Some(terminate), .. }) = &self.test_helpers { @@ -647,7 +685,7 @@ impl AppServer { Err(ref e) => { self.verbose().then(|| { info!( - "error processing incoming message from {:?}: {:?} {}", + "error processing incoming message from {}: {:?} {}", endpoint, e, e.backtrace() @@ -686,9 +724,8 @@ impl AppServer { ) -> Result { match endpoint { Endpoint::SocketBoundAddress(socket) => { - let (hi_len, host_identification) = socket.to_bytes(); self.crypt - .handle_msg_under_load(&rx, &mut *tx, &host_identification[0..hi_len]) + .handle_msg_under_load(&rx, &mut *tx, socket) } Endpoint::Discovery(_) => { anyhow::bail!("Host-path discovery is not supported under load") @@ -876,10 +913,10 @@ impl AppServer { self.all_sockets_drained = false; return Ok(Some(( n, - Endpoint::SocketBoundAddress(SocketBoundEndpoint { - socket: SocketPtr(sock_no), + Endpoint::SocketBoundAddress(SocketBoundEndpoint::new( + SocketPtr(sock_no), addr, - }), + )), ))); } Err(e) if e.kind() == ErrorKind::WouldBlock => { diff --git a/rosenpass/src/protocol.rs b/rosenpass/src/protocol.rs index 1f15092..2c93166 100644 --- a/rosenpass/src/protocol.rs +++ b/rosenpass/src/protocol.rs @@ -65,14 +65,18 @@ //! # } //! ``` -use std::collections::hash_map::{ - Entry::{Occupied, Vacant}, - HashMap, -}; use std::convert::Infallible; use std::mem::size_of; +use std::{ + collections::hash_map::{ + Entry::{Occupied, Vacant}, + HashMap, + }, + fmt::Display, +}; use anyhow::{bail, ensure, Context, Result}; +use mio::net::SocketAddr; use rand::Fill as Randomize; use memoffset::span_of; @@ -902,6 +906,12 @@ pub struct HandleMsgResult { pub resp: Option, } +/// Trait for host identification types +pub trait HostIdentification: Display { + // Byte slice representing the host identification encoding + fn encode(&self) -> &[u8]; +} + impl CryptoServer { /// Process a message under load /// This is one of the main entry point for the protocol. @@ -911,16 +921,18 @@ impl CryptoServer { /// message for sender to process and verify for messages part of the handshake phase /// (i.e. InitHello, InitConf messages only). Bails on messages sent by responder and /// non-handshake messages. - pub fn handle_msg_under_load( + + pub fn handle_msg_under_load( &mut self, rx_buf: &[u8], tx_buf: &mut [u8], - host_identification: &[u8], + host_identification: &H, ) -> Result { let mut active_cookie_value: Option<[u8; COOKIE_SIZE]> = None; let mut rx_cookie = [0u8; COOKIE_SIZE]; let mut rx_mac = [0u8; MAC_SIZE]; let mut rx_sid = [0u8; 4]; + let msg_type : Result = rx_buf[0].try_into(); for cookie_secret in self.active_or_retired_cookie_secrets() { if let Some(cookie_secret) = cookie_secret { @@ -929,7 +941,7 @@ impl CryptoServer { cookie_value.copy_from_slice( &hash_domains::cookie_value()? .mix(cookie_secret)? - .mix(&host_identification)? + .mix(host_identification.encode())? .into_value()[..16], ); @@ -951,6 +963,7 @@ impl CryptoServer { //If valid cookie is found, process message if constant_time::memcmp(&rx_cookie, &expected) { + log::debug!("Rx {:?} from {} under load, valid cookie", msg_type, host_identification); let result = self.handle_msg(rx_buf, tx_buf)?; return Ok(result); } @@ -964,6 +977,8 @@ impl CryptoServer { bail!("No active cookie value found"); } + log::debug!("Rx {:?} from {} under load, tx cookie reply message", msg_type, host_identification); + let cookie_value = active_cookie_value.unwrap(); let cookie_key = hash_domains::cookie_key()? .mix(self.spkm.secret())? @@ -1083,7 +1098,11 @@ impl CryptoServer { ensure!(!rx_buf.is_empty(), "received empty message, ignoring it"); - let peer = match rx_buf[0].try_into() { + let msg_type = rx_buf[0].try_into(); + + log::debug!("Rx {:?}, processing", msg_type); + + let peer = match msg_type { Ok(MsgType::InitHello) => { let msg_in: Ref<&[u8], Envelope> = Ref::new(rx_buf).ok_or(RosenpassError::BufferSizeMismatch)?; @@ -2134,6 +2153,26 @@ mod test { use super::*; + struct VecHostIdentifier(Vec); + + impl HostIdentification for VecHostIdentifier { + fn encode(&self) -> &[u8] { + &self.0 + } + } + + impl Display for VecHostIdentifier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.0) + } + } + + impl From> for VecHostIdentifier { + fn from(v: Vec) -> Self { + VecHostIdentifier(v) + } + } + #[test] /// Ensure that the protocol implementation can deal with truncated /// messages and with overlong messages. @@ -2243,6 +2282,8 @@ mod test { ip_addr_port_a.extend_from_slice(&socket_addr_a.port().to_be_bytes()); + let ip_addr_port_a: VecHostIdentifier = ip_addr_port_a.into(); + //B handles handshake under load, should send cookie reply message with invalid cookie let HandleMsgResult { resp, .. } = b .handle_msg_under_load( @@ -2270,7 +2311,7 @@ mod test { .secret(), ) .unwrap() - .mix(&ip_addr_port_a) + .mix(&ip_addr_port_a.encode()) .unwrap() .into_value()[..16] .to_vec(); @@ -2355,12 +2396,15 @@ mod test { .copy_from_slice(&socket_addr_b.port().to_be_bytes()); ip_addr_port_b_len += 2; + let ip_addr_port_b: VecHostIdentifier = + ip_addr_port_b[..ip_addr_port_b_len].to_vec().into(); + //A handles RespHello message under load, should not send cookie reply assert!(a .handle_msg_under_load( &b_to_a_buf[..resp_hello_len], &mut *a_to_b_buf, - &ip_addr_port_b[..ip_addr_port_b_len] + &ip_addr_port_b ) .is_err()); }); diff --git a/rosenpass/tests/integration_test.rs b/rosenpass/tests/integration_test.rs index 0a895a9..fd6893e 100644 --- a/rosenpass/tests/integration_test.rs +++ b/rosenpass/tests/integration_test.rs @@ -1,14 +1,9 @@ -use std::{ - fs, net::UdpSocket, os::unix::thread::JoinHandleExt, path::PathBuf, process::Stdio, - time::Duration, -}; +use std::{fs::{self, write}, net::UdpSocket, path::PathBuf, process::Stdio, time::Duration}; use clap::Parser; -use rosenpass::{ - app_server::{AppServerTest, AppServerTestBuilder}, - cli::CliArgs, -}; +use rosenpass::{app_server::AppServerTestBuilder, cli::CliArgs}; use serial_test::serial; +use std::io::Write; const BIN: &str = "rosenpass"; @@ -139,7 +134,15 @@ fn check_exchange_under_normal() { #[test] #[serial] fn check_exchange_under_dos() { - procspawn::init(); + let mut log_builder = env_logger::Builder::from_default_env(); // sets log level filter from environment (or defaults) + log_builder.filter_level(log::LevelFilter::Debug); + log_builder.format_timestamp_nanos(); + log_builder.format(|buf, record| { + let ts_format = buf.timestamp_nanos().to_string(); + writeln!(buf, "\x1b[1m{:?}\x1b[0m {}: {}", std::thread::current().id(), &ts_format[18..], record.args()) + }); + + let _ = log_builder.try_init(); //Generate binary with responder with feature integration_test let tmpdir = PathBuf::from(env!("CARGO_TARGET_TMPDIR")).join("exchange-dos"); @@ -185,26 +188,21 @@ fn check_exchange_under_dos() { .arg("outfile") .arg(&shared_key_paths[0]); - let server_cmd: Vec = - server_cmd - .get_args() - .into_iter() - .fold(vec![BIN.to_string()], |mut acc, x| { - if let Some(s) = x.to_str() { - acc.push(s.to_string()); - } - acc - }); - let (server_terminate, server_terminate_rx) = std::sync::mpsc::channel(); - let mut server = std::thread::spawn(move || { - let cli = CliArgs::try_parse_from(server_cmd.iter()).unwrap(); + let cli = CliArgs::try_parse_from( + [server_cmd.get_program()] + .into_iter() + .chain(server_cmd.get_args()), + ) + .unwrap(); + + std::thread::spawn(move || { cli.command .run(Some( AppServerTestBuilder::default() .enable_dos_permanently(true) - .terminate(Some(server_terminate_rx)) + .termination_handler(Some(server_terminate_rx)) .build() .unwrap(), )) @@ -214,7 +212,8 @@ fn check_exchange_under_dos() { std::thread::sleep(Duration::from_millis(500)); // start second process, the client - let mut client = test_bin::get_test_bin(BIN) + let mut client_cmd = std::process::Command::new(BIN); + client_cmd .args(["exchange", "secret-key"]) .arg(&secret_key_paths[1]) .arg("public-key") @@ -223,16 +222,34 @@ fn check_exchange_under_dos() { .arg(&public_key_paths[0]) .args(["endpoint", &listen_addr]) .arg("outfile") - .arg(&shared_key_paths[1]) - .spawn() - .expect("Failed to start {BIN}"); + .arg(&shared_key_paths[1]); + + let (client_terminate, client_terminate_rx) = std::sync::mpsc::channel(); + + let cli = CliArgs::try_parse_from( + [client_cmd.get_program()] + .into_iter() + .chain(client_cmd.get_args()), + ) + .unwrap(); + + std::thread::spawn(move || { + cli.command + .run(Some( + AppServerTestBuilder::default() + .termination_handler(Some(client_terminate_rx)) + .build() + .unwrap(), + )) + .unwrap(); + }); // give them some time to do the key exchange under load std::thread::sleep(Duration::from_secs(10)); // time's up, kill the childs server_terminate.send(()).unwrap(); - client.kill().unwrap(); + client_terminate.send(()).unwrap(); // read the shared keys they created let shared_keys: Vec<_> = shared_key_paths