Add HostIdentification trait, add logging

This commit is contained in:
Prabhpreet Dua
2024-04-16 11:17:03 +05:30
parent e7de4848fb
commit 8420d953eb
3 changed files with 150 additions and 52 deletions

View File

@@ -1,6 +1,7 @@
use anyhow::bail; use anyhow::bail;
use anyhow::Result; use anyhow::Result;
use clap::builder;
use derive_builder::Builder; use derive_builder::Builder;
use log::{debug, error, info, warn}; use log::{debug, error, info, warn};
use mio::Interest; use mio::Interest;
@@ -17,6 +18,7 @@ use std::net::SocketAddr;
use std::net::SocketAddrV4; use std::net::SocketAddrV4;
use std::net::SocketAddrV6; use std::net::SocketAddrV6;
use std::net::ToSocketAddrs; use std::net::ToSocketAddrs;
use std::path::Display;
use std::path::PathBuf; use std::path::PathBuf;
use std::process::Command; use std::process::Command;
use std::process::Stdio; use std::process::Stdio;
@@ -25,6 +27,7 @@ use std::thread;
use std::time::Duration; use std::time::Duration;
use std::time::Instant; use std::time::Instant;
use crate::protocol::HostIdentification;
use crate::{ use crate::{
config::Verbosity, config::Verbosity,
protocol::{CryptoServer, MsgBuf, PeerPtr, SPk, SSk, SymKey, Timing}, protocol::{CryptoServer, MsgBuf, PeerPtr, SPk, SSk, SymKey, Timing},
@@ -82,9 +85,11 @@ pub enum DoSOperation {
#[builder(pattern = "owned")] #[builder(pattern = "owned")]
pub struct AppServerTest { pub struct AppServerTest {
/// Enable DoS operation permanently /// Enable DoS operation permanently
#[builder(default = "false")]
pub enable_dos_permanently: bool, pub enable_dos_permanently: bool,
/// Terminate application signal /// Terminate application signal
pub terminate: Option<std::sync::mpsc::Receiver<()>>, #[builder(default = "None")]
pub termination_handler: Option<std::sync::mpsc::Receiver<()>>,
} }
/// Holds the state of the application, namely the external IO /// Holds the state of the application, namely the external IO
@@ -196,13 +201,24 @@ pub enum Endpoint {
Discovery(HostPathDiscoveryEndpoint), Discovery(HostPathDiscoveryEndpoint),
} }
impl std::fmt::Display for Endpoint {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Endpoint::SocketBoundAddress(host) => write!(f, "{}", host),
Endpoint::Discovery(host) => write!(f, "{}", host),
}
}
}
#[derive(Debug)] #[derive(Debug)]
pub struct SocketBoundEndpoint { pub struct SocketBoundEndpoint {
/// The socket the address can be reached under; this is generally /// The socket the address can be reached under; this is generally
/// determined when we actually receive an RespHello message /// determined when we actually receive an RespHello message
pub socket: SocketPtr, socket: SocketPtr,
/// Just the address /// Just the address
pub addr: SocketAddr, addr: SocketAddr,
/// identifier
bytes: (usize,[u8;SocketBoundEndpoint::BUFFER_SIZE])
} }
impl std::fmt::Display for SocketBoundEndpoint { impl std::fmt::Display for SocketBoundEndpoint {
@@ -221,19 +237,29 @@ impl SocketBoundEndpoint {
+ SocketBoundEndpoint::IPV6_SIZE + SocketBoundEndpoint::IPV6_SIZE
+ SocketBoundEndpoint::PORT_SIZE + SocketBoundEndpoint::PORT_SIZE
+ SocketBoundEndpoint::SCOPE_ID_SIZE; + SocketBoundEndpoint::SCOPE_ID_SIZE;
pub fn to_bytes(&self) -> (usize, [u8; SocketBoundEndpoint::BUFFER_SIZE]) {
pub fn new(socket: SocketPtr, addr: SocketAddr) -> Self {
let bytes = Self::to_bytes(&socket, &addr);
Self {
socket,
addr,
bytes,
}
}
fn to_bytes(socket: &SocketPtr, addr: &SocketAddr) -> (usize, [u8; SocketBoundEndpoint::BUFFER_SIZE]) {
let mut buf = [0u8; SocketBoundEndpoint::BUFFER_SIZE]; let mut buf = [0u8; SocketBoundEndpoint::BUFFER_SIZE];
let addr = match self.addr { let addr = match addr {
SocketAddr::V4(addr) => { SocketAddr::V4(addr) => {
//Map IPv4-mapped to IPv6 addresses //Map IPv4-mapped to IPv6 addresses
let ip = addr.ip().to_ipv6_mapped(); let ip = addr.ip().to_ipv6_mapped();
SocketAddrV6::new(ip, addr.port(), 0, 0) SocketAddrV6::new(ip, addr.port(), 0, 0)
} }
SocketAddr::V6(addr) => addr, SocketAddr::V6(addr) => addr.clone(),
}; };
let mut len: usize = 0; let mut len: usize = 0;
buf[len..len + SocketBoundEndpoint::SOCKET_SIZE] buf[len..len + SocketBoundEndpoint::SOCKET_SIZE]
.copy_from_slice(&self.socket.0.to_be_bytes()); .copy_from_slice(&socket.0.to_be_bytes());
len += SocketBoundEndpoint::SOCKET_SIZE; len += SocketBoundEndpoint::SOCKET_SIZE;
buf[len..len + SocketBoundEndpoint::IPV6_SIZE].copy_from_slice(&addr.ip().octets()); buf[len..len + SocketBoundEndpoint::IPV6_SIZE].copy_from_slice(&addr.ip().octets());
len += SocketBoundEndpoint::IPV6_SIZE; len += SocketBoundEndpoint::IPV6_SIZE;
@@ -246,6 +272,12 @@ impl SocketBoundEndpoint {
} }
} }
impl HostIdentification for SocketBoundEndpoint {
fn encode(&self) -> &[u8] {
&self.bytes.1[0..self.bytes.0]
}
}
impl Endpoint { impl Endpoint {
/// Start discovery from some addresses /// Start discovery from some addresses
pub fn discovery_from_addresses(addresses: Vec<SocketAddr>) -> Self { pub fn discovery_from_addresses(addresses: Vec<SocketAddr>) -> Self {
@@ -332,6 +364,12 @@ pub struct HostPathDiscoveryEndpoint {
addresses: Vec<SocketAddr>, addresses: Vec<SocketAddr>,
} }
impl std::fmt::Display for HostPathDiscoveryEndpoint {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.addresses)
}
}
impl HostPathDiscoveryEndpoint { impl HostPathDiscoveryEndpoint {
pub fn from_addresses(addresses: Vec<SocketAddr>) -> Self { pub fn from_addresses(addresses: Vec<SocketAddr>) -> Self {
let scouting_state = Cell::new((0, 0)); let scouting_state = Cell::new((0, 0));
@@ -604,7 +642,7 @@ impl AppServer {
use KeyOutputReason::*; use KeyOutputReason::*;
if let Some(AppServerTest { if let Some(AppServerTest {
terminate: Some(terminate), termination_handler: Some(terminate),
.. ..
}) = &self.test_helpers }) = &self.test_helpers
{ {
@@ -647,7 +685,7 @@ impl AppServer {
Err(ref e) => { Err(ref e) => {
self.verbose().then(|| { self.verbose().then(|| {
info!( info!(
"error processing incoming message from {:?}: {:?} {}", "error processing incoming message from {}: {:?} {}",
endpoint, endpoint,
e, e,
e.backtrace() e.backtrace()
@@ -686,9 +724,8 @@ impl AppServer {
) -> Result<crate::protocol::HandleMsgResult> { ) -> Result<crate::protocol::HandleMsgResult> {
match endpoint { match endpoint {
Endpoint::SocketBoundAddress(socket) => { Endpoint::SocketBoundAddress(socket) => {
let (hi_len, host_identification) = socket.to_bytes();
self.crypt self.crypt
.handle_msg_under_load(&rx, &mut *tx, &host_identification[0..hi_len]) .handle_msg_under_load(&rx, &mut *tx, socket)
} }
Endpoint::Discovery(_) => { Endpoint::Discovery(_) => {
anyhow::bail!("Host-path discovery is not supported under load") anyhow::bail!("Host-path discovery is not supported under load")
@@ -876,10 +913,10 @@ impl AppServer {
self.all_sockets_drained = false; self.all_sockets_drained = false;
return Ok(Some(( return Ok(Some((
n, n,
Endpoint::SocketBoundAddress(SocketBoundEndpoint { Endpoint::SocketBoundAddress(SocketBoundEndpoint::new(
socket: SocketPtr(sock_no), SocketPtr(sock_no),
addr, addr,
}), )),
))); )));
} }
Err(e) if e.kind() == ErrorKind::WouldBlock => { Err(e) if e.kind() == ErrorKind::WouldBlock => {

View File

@@ -65,14 +65,18 @@
//! # } //! # }
//! ``` //! ```
use std::collections::hash_map::{
Entry::{Occupied, Vacant},
HashMap,
};
use std::convert::Infallible; use std::convert::Infallible;
use std::mem::size_of; use std::mem::size_of;
use std::{
collections::hash_map::{
Entry::{Occupied, Vacant},
HashMap,
},
fmt::Display,
};
use anyhow::{bail, ensure, Context, Result}; use anyhow::{bail, ensure, Context, Result};
use mio::net::SocketAddr;
use rand::Fill as Randomize; use rand::Fill as Randomize;
use memoffset::span_of; use memoffset::span_of;
@@ -902,6 +906,12 @@ pub struct HandleMsgResult {
pub resp: Option<usize>, pub resp: Option<usize>,
} }
/// Trait for host identification types
pub trait HostIdentification: Display {
// Byte slice representing the host identification encoding
fn encode(&self) -> &[u8];
}
impl CryptoServer { impl CryptoServer {
/// Process a message under load /// Process a message under load
/// This is one of the main entry point for the protocol. /// This is one of the main entry point for the protocol.
@@ -911,16 +921,18 @@ impl CryptoServer {
/// message for sender to process and verify for messages part of the handshake phase /// message for sender to process and verify for messages part of the handshake phase
/// (i.e. InitHello, InitConf messages only). Bails on messages sent by responder and /// (i.e. InitHello, InitConf messages only). Bails on messages sent by responder and
/// non-handshake messages. /// non-handshake messages.
pub fn handle_msg_under_load(
pub fn handle_msg_under_load<H: HostIdentification>(
&mut self, &mut self,
rx_buf: &[u8], rx_buf: &[u8],
tx_buf: &mut [u8], tx_buf: &mut [u8],
host_identification: &[u8], host_identification: &H,
) -> Result<HandleMsgResult> { ) -> Result<HandleMsgResult> {
let mut active_cookie_value: Option<[u8; COOKIE_SIZE]> = None; let mut active_cookie_value: Option<[u8; COOKIE_SIZE]> = None;
let mut rx_cookie = [0u8; COOKIE_SIZE]; let mut rx_cookie = [0u8; COOKIE_SIZE];
let mut rx_mac = [0u8; MAC_SIZE]; let mut rx_mac = [0u8; MAC_SIZE];
let mut rx_sid = [0u8; 4]; let mut rx_sid = [0u8; 4];
let msg_type : Result<MsgType,_> = rx_buf[0].try_into();
for cookie_secret in self.active_or_retired_cookie_secrets() { for cookie_secret in self.active_or_retired_cookie_secrets() {
if let Some(cookie_secret) = cookie_secret { if let Some(cookie_secret) = cookie_secret {
@@ -929,7 +941,7 @@ impl CryptoServer {
cookie_value.copy_from_slice( cookie_value.copy_from_slice(
&hash_domains::cookie_value()? &hash_domains::cookie_value()?
.mix(cookie_secret)? .mix(cookie_secret)?
.mix(&host_identification)? .mix(host_identification.encode())?
.into_value()[..16], .into_value()[..16],
); );
@@ -951,6 +963,7 @@ impl CryptoServer {
//If valid cookie is found, process message //If valid cookie is found, process message
if constant_time::memcmp(&rx_cookie, &expected) { if constant_time::memcmp(&rx_cookie, &expected) {
log::debug!("Rx {:?} from {} under load, valid cookie", msg_type, host_identification);
let result = self.handle_msg(rx_buf, tx_buf)?; let result = self.handle_msg(rx_buf, tx_buf)?;
return Ok(result); return Ok(result);
} }
@@ -964,6 +977,8 @@ impl CryptoServer {
bail!("No active cookie value found"); bail!("No active cookie value found");
} }
log::debug!("Rx {:?} from {} under load, tx cookie reply message", msg_type, host_identification);
let cookie_value = active_cookie_value.unwrap(); let cookie_value = active_cookie_value.unwrap();
let cookie_key = hash_domains::cookie_key()? let cookie_key = hash_domains::cookie_key()?
.mix(self.spkm.secret())? .mix(self.spkm.secret())?
@@ -1083,7 +1098,11 @@ impl CryptoServer {
ensure!(!rx_buf.is_empty(), "received empty message, ignoring it"); ensure!(!rx_buf.is_empty(), "received empty message, ignoring it");
let peer = match rx_buf[0].try_into() { let msg_type = rx_buf[0].try_into();
log::debug!("Rx {:?}, processing", msg_type);
let peer = match msg_type {
Ok(MsgType::InitHello) => { Ok(MsgType::InitHello) => {
let msg_in: Ref<&[u8], Envelope<InitHello>> = let msg_in: Ref<&[u8], Envelope<InitHello>> =
Ref::new(rx_buf).ok_or(RosenpassError::BufferSizeMismatch)?; Ref::new(rx_buf).ok_or(RosenpassError::BufferSizeMismatch)?;
@@ -2134,6 +2153,26 @@ mod test {
use super::*; use super::*;
struct VecHostIdentifier(Vec<u8>);
impl HostIdentification for VecHostIdentifier {
fn encode(&self) -> &[u8] {
&self.0
}
}
impl Display for VecHostIdentifier {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.0)
}
}
impl From<Vec<u8>> for VecHostIdentifier {
fn from(v: Vec<u8>) -> Self {
VecHostIdentifier(v)
}
}
#[test] #[test]
/// Ensure that the protocol implementation can deal with truncated /// Ensure that the protocol implementation can deal with truncated
/// messages and with overlong messages. /// messages and with overlong messages.
@@ -2243,6 +2282,8 @@ mod test {
ip_addr_port_a.extend_from_slice(&socket_addr_a.port().to_be_bytes()); ip_addr_port_a.extend_from_slice(&socket_addr_a.port().to_be_bytes());
let ip_addr_port_a: VecHostIdentifier = ip_addr_port_a.into();
//B handles handshake under load, should send cookie reply message with invalid cookie //B handles handshake under load, should send cookie reply message with invalid cookie
let HandleMsgResult { resp, .. } = b let HandleMsgResult { resp, .. } = b
.handle_msg_under_load( .handle_msg_under_load(
@@ -2270,7 +2311,7 @@ mod test {
.secret(), .secret(),
) )
.unwrap() .unwrap()
.mix(&ip_addr_port_a) .mix(&ip_addr_port_a.encode())
.unwrap() .unwrap()
.into_value()[..16] .into_value()[..16]
.to_vec(); .to_vec();
@@ -2355,12 +2396,15 @@ mod test {
.copy_from_slice(&socket_addr_b.port().to_be_bytes()); .copy_from_slice(&socket_addr_b.port().to_be_bytes());
ip_addr_port_b_len += 2; ip_addr_port_b_len += 2;
let ip_addr_port_b: VecHostIdentifier =
ip_addr_port_b[..ip_addr_port_b_len].to_vec().into();
//A handles RespHello message under load, should not send cookie reply //A handles RespHello message under load, should not send cookie reply
assert!(a assert!(a
.handle_msg_under_load( .handle_msg_under_load(
&b_to_a_buf[..resp_hello_len], &b_to_a_buf[..resp_hello_len],
&mut *a_to_b_buf, &mut *a_to_b_buf,
&ip_addr_port_b[..ip_addr_port_b_len] &ip_addr_port_b
) )
.is_err()); .is_err());
}); });

View File

@@ -1,14 +1,9 @@
use std::{ use std::{fs::{self, write}, net::UdpSocket, path::PathBuf, process::Stdio, time::Duration};
fs, net::UdpSocket, os::unix::thread::JoinHandleExt, path::PathBuf, process::Stdio,
time::Duration,
};
use clap::Parser; use clap::Parser;
use rosenpass::{ use rosenpass::{app_server::AppServerTestBuilder, cli::CliArgs};
app_server::{AppServerTest, AppServerTestBuilder},
cli::CliArgs,
};
use serial_test::serial; use serial_test::serial;
use std::io::Write;
const BIN: &str = "rosenpass"; const BIN: &str = "rosenpass";
@@ -139,7 +134,15 @@ fn check_exchange_under_normal() {
#[test] #[test]
#[serial] #[serial]
fn check_exchange_under_dos() { fn check_exchange_under_dos() {
procspawn::init(); let mut log_builder = env_logger::Builder::from_default_env(); // sets log level filter from environment (or defaults)
log_builder.filter_level(log::LevelFilter::Debug);
log_builder.format_timestamp_nanos();
log_builder.format(|buf, record| {
let ts_format = buf.timestamp_nanos().to_string();
writeln!(buf, "\x1b[1m{:?}\x1b[0m {}: {}", std::thread::current().id(), &ts_format[18..], record.args())
});
let _ = log_builder.try_init();
//Generate binary with responder with feature integration_test //Generate binary with responder with feature integration_test
let tmpdir = PathBuf::from(env!("CARGO_TARGET_TMPDIR")).join("exchange-dos"); let tmpdir = PathBuf::from(env!("CARGO_TARGET_TMPDIR")).join("exchange-dos");
@@ -185,26 +188,21 @@ fn check_exchange_under_dos() {
.arg("outfile") .arg("outfile")
.arg(&shared_key_paths[0]); .arg(&shared_key_paths[0]);
let server_cmd: Vec<String> =
server_cmd
.get_args()
.into_iter()
.fold(vec![BIN.to_string()], |mut acc, x| {
if let Some(s) = x.to_str() {
acc.push(s.to_string());
}
acc
});
let (server_terminate, server_terminate_rx) = std::sync::mpsc::channel(); let (server_terminate, server_terminate_rx) = std::sync::mpsc::channel();
let mut server = std::thread::spawn(move || { let cli = CliArgs::try_parse_from(
let cli = CliArgs::try_parse_from(server_cmd.iter()).unwrap(); [server_cmd.get_program()]
.into_iter()
.chain(server_cmd.get_args()),
)
.unwrap();
std::thread::spawn(move || {
cli.command cli.command
.run(Some( .run(Some(
AppServerTestBuilder::default() AppServerTestBuilder::default()
.enable_dos_permanently(true) .enable_dos_permanently(true)
.terminate(Some(server_terminate_rx)) .termination_handler(Some(server_terminate_rx))
.build() .build()
.unwrap(), .unwrap(),
)) ))
@@ -214,7 +212,8 @@ fn check_exchange_under_dos() {
std::thread::sleep(Duration::from_millis(500)); std::thread::sleep(Duration::from_millis(500));
// start second process, the client // start second process, the client
let mut client = test_bin::get_test_bin(BIN) let mut client_cmd = std::process::Command::new(BIN);
client_cmd
.args(["exchange", "secret-key"]) .args(["exchange", "secret-key"])
.arg(&secret_key_paths[1]) .arg(&secret_key_paths[1])
.arg("public-key") .arg("public-key")
@@ -223,16 +222,34 @@ fn check_exchange_under_dos() {
.arg(&public_key_paths[0]) .arg(&public_key_paths[0])
.args(["endpoint", &listen_addr]) .args(["endpoint", &listen_addr])
.arg("outfile") .arg("outfile")
.arg(&shared_key_paths[1]) .arg(&shared_key_paths[1]);
.spawn()
.expect("Failed to start {BIN}"); let (client_terminate, client_terminate_rx) = std::sync::mpsc::channel();
let cli = CliArgs::try_parse_from(
[client_cmd.get_program()]
.into_iter()
.chain(client_cmd.get_args()),
)
.unwrap();
std::thread::spawn(move || {
cli.command
.run(Some(
AppServerTestBuilder::default()
.termination_handler(Some(client_terminate_rx))
.build()
.unwrap(),
))
.unwrap();
});
// give them some time to do the key exchange under load // give them some time to do the key exchange under load
std::thread::sleep(Duration::from_secs(10)); std::thread::sleep(Duration::from_secs(10));
// time's up, kill the childs // time's up, kill the childs
server_terminate.send(()).unwrap(); server_terminate.send(()).unwrap();
client.kill().unwrap(); client_terminate.send(()).unwrap();
// read the shared keys they created // read the shared keys they created
let shared_keys: Vec<_> = shared_key_paths let shared_keys: Vec<_> = shared_key_paths