Add HostIdentification trait, add logging

This commit is contained in:
Prabhpreet Dua
2024-04-16 11:17:03 +05:30
parent e7de4848fb
commit 8420d953eb
3 changed files with 150 additions and 52 deletions

View File

@@ -1,6 +1,7 @@
use anyhow::bail;
use anyhow::Result;
use clap::builder;
use derive_builder::Builder;
use log::{debug, error, info, warn};
use mio::Interest;
@@ -17,6 +18,7 @@ use std::net::SocketAddr;
use std::net::SocketAddrV4;
use std::net::SocketAddrV6;
use std::net::ToSocketAddrs;
use std::path::Display;
use std::path::PathBuf;
use std::process::Command;
use std::process::Stdio;
@@ -25,6 +27,7 @@ use std::thread;
use std::time::Duration;
use std::time::Instant;
use crate::protocol::HostIdentification;
use crate::{
config::Verbosity,
protocol::{CryptoServer, MsgBuf, PeerPtr, SPk, SSk, SymKey, Timing},
@@ -82,9 +85,11 @@ pub enum DoSOperation {
#[builder(pattern = "owned")]
pub struct AppServerTest {
/// Enable DoS operation permanently
#[builder(default = "false")]
pub enable_dos_permanently: bool,
/// Terminate application signal
pub terminate: Option<std::sync::mpsc::Receiver<()>>,
#[builder(default = "None")]
pub termination_handler: Option<std::sync::mpsc::Receiver<()>>,
}
/// Holds the state of the application, namely the external IO
@@ -196,13 +201,24 @@ pub enum Endpoint {
Discovery(HostPathDiscoveryEndpoint),
}
impl std::fmt::Display for Endpoint {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Endpoint::SocketBoundAddress(host) => write!(f, "{}", host),
Endpoint::Discovery(host) => write!(f, "{}", host),
}
}
}
#[derive(Debug)]
pub struct SocketBoundEndpoint {
/// The socket the address can be reached under; this is generally
/// determined when we actually receive an RespHello message
pub socket: SocketPtr,
socket: SocketPtr,
/// Just the address
pub addr: SocketAddr,
addr: SocketAddr,
/// identifier
bytes: (usize,[u8;SocketBoundEndpoint::BUFFER_SIZE])
}
impl std::fmt::Display for SocketBoundEndpoint {
@@ -221,19 +237,29 @@ impl SocketBoundEndpoint {
+ SocketBoundEndpoint::IPV6_SIZE
+ SocketBoundEndpoint::PORT_SIZE
+ SocketBoundEndpoint::SCOPE_ID_SIZE;
pub fn to_bytes(&self) -> (usize, [u8; SocketBoundEndpoint::BUFFER_SIZE]) {
pub fn new(socket: SocketPtr, addr: SocketAddr) -> Self {
let bytes = Self::to_bytes(&socket, &addr);
Self {
socket,
addr,
bytes,
}
}
fn to_bytes(socket: &SocketPtr, addr: &SocketAddr) -> (usize, [u8; SocketBoundEndpoint::BUFFER_SIZE]) {
let mut buf = [0u8; SocketBoundEndpoint::BUFFER_SIZE];
let addr = match self.addr {
let addr = match addr {
SocketAddr::V4(addr) => {
//Map IPv4-mapped to IPv6 addresses
let ip = addr.ip().to_ipv6_mapped();
SocketAddrV6::new(ip, addr.port(), 0, 0)
}
SocketAddr::V6(addr) => addr,
SocketAddr::V6(addr) => addr.clone(),
};
let mut len: usize = 0;
buf[len..len + SocketBoundEndpoint::SOCKET_SIZE]
.copy_from_slice(&self.socket.0.to_be_bytes());
.copy_from_slice(&socket.0.to_be_bytes());
len += SocketBoundEndpoint::SOCKET_SIZE;
buf[len..len + SocketBoundEndpoint::IPV6_SIZE].copy_from_slice(&addr.ip().octets());
len += SocketBoundEndpoint::IPV6_SIZE;
@@ -246,6 +272,12 @@ impl SocketBoundEndpoint {
}
}
impl HostIdentification for SocketBoundEndpoint {
fn encode(&self) -> &[u8] {
&self.bytes.1[0..self.bytes.0]
}
}
impl Endpoint {
/// Start discovery from some addresses
pub fn discovery_from_addresses(addresses: Vec<SocketAddr>) -> Self {
@@ -332,6 +364,12 @@ pub struct HostPathDiscoveryEndpoint {
addresses: Vec<SocketAddr>,
}
impl std::fmt::Display for HostPathDiscoveryEndpoint {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.addresses)
}
}
impl HostPathDiscoveryEndpoint {
pub fn from_addresses(addresses: Vec<SocketAddr>) -> Self {
let scouting_state = Cell::new((0, 0));
@@ -604,7 +642,7 @@ impl AppServer {
use KeyOutputReason::*;
if let Some(AppServerTest {
terminate: Some(terminate),
termination_handler: Some(terminate),
..
}) = &self.test_helpers
{
@@ -647,7 +685,7 @@ impl AppServer {
Err(ref e) => {
self.verbose().then(|| {
info!(
"error processing incoming message from {:?}: {:?} {}",
"error processing incoming message from {}: {:?} {}",
endpoint,
e,
e.backtrace()
@@ -686,9 +724,8 @@ impl AppServer {
) -> Result<crate::protocol::HandleMsgResult> {
match endpoint {
Endpoint::SocketBoundAddress(socket) => {
let (hi_len, host_identification) = socket.to_bytes();
self.crypt
.handle_msg_under_load(&rx, &mut *tx, &host_identification[0..hi_len])
.handle_msg_under_load(&rx, &mut *tx, socket)
}
Endpoint::Discovery(_) => {
anyhow::bail!("Host-path discovery is not supported under load")
@@ -876,10 +913,10 @@ impl AppServer {
self.all_sockets_drained = false;
return Ok(Some((
n,
Endpoint::SocketBoundAddress(SocketBoundEndpoint {
socket: SocketPtr(sock_no),
Endpoint::SocketBoundAddress(SocketBoundEndpoint::new(
SocketPtr(sock_no),
addr,
}),
)),
)));
}
Err(e) if e.kind() == ErrorKind::WouldBlock => {

View File

@@ -65,14 +65,18 @@
//! # }
//! ```
use std::collections::hash_map::{
Entry::{Occupied, Vacant},
HashMap,
};
use std::convert::Infallible;
use std::mem::size_of;
use std::{
collections::hash_map::{
Entry::{Occupied, Vacant},
HashMap,
},
fmt::Display,
};
use anyhow::{bail, ensure, Context, Result};
use mio::net::SocketAddr;
use rand::Fill as Randomize;
use memoffset::span_of;
@@ -902,6 +906,12 @@ pub struct HandleMsgResult {
pub resp: Option<usize>,
}
/// Trait for host identification types
pub trait HostIdentification: Display {
// Byte slice representing the host identification encoding
fn encode(&self) -> &[u8];
}
impl CryptoServer {
/// Process a message under load
/// This is one of the main entry point for the protocol.
@@ -911,16 +921,18 @@ impl CryptoServer {
/// message for sender to process and verify for messages part of the handshake phase
/// (i.e. InitHello, InitConf messages only). Bails on messages sent by responder and
/// non-handshake messages.
pub fn handle_msg_under_load(
pub fn handle_msg_under_load<H: HostIdentification>(
&mut self,
rx_buf: &[u8],
tx_buf: &mut [u8],
host_identification: &[u8],
host_identification: &H,
) -> Result<HandleMsgResult> {
let mut active_cookie_value: Option<[u8; COOKIE_SIZE]> = None;
let mut rx_cookie = [0u8; COOKIE_SIZE];
let mut rx_mac = [0u8; MAC_SIZE];
let mut rx_sid = [0u8; 4];
let msg_type : Result<MsgType,_> = rx_buf[0].try_into();
for cookie_secret in self.active_or_retired_cookie_secrets() {
if let Some(cookie_secret) = cookie_secret {
@@ -929,7 +941,7 @@ impl CryptoServer {
cookie_value.copy_from_slice(
&hash_domains::cookie_value()?
.mix(cookie_secret)?
.mix(&host_identification)?
.mix(host_identification.encode())?
.into_value()[..16],
);
@@ -951,6 +963,7 @@ impl CryptoServer {
//If valid cookie is found, process message
if constant_time::memcmp(&rx_cookie, &expected) {
log::debug!("Rx {:?} from {} under load, valid cookie", msg_type, host_identification);
let result = self.handle_msg(rx_buf, tx_buf)?;
return Ok(result);
}
@@ -964,6 +977,8 @@ impl CryptoServer {
bail!("No active cookie value found");
}
log::debug!("Rx {:?} from {} under load, tx cookie reply message", msg_type, host_identification);
let cookie_value = active_cookie_value.unwrap();
let cookie_key = hash_domains::cookie_key()?
.mix(self.spkm.secret())?
@@ -1083,7 +1098,11 @@ impl CryptoServer {
ensure!(!rx_buf.is_empty(), "received empty message, ignoring it");
let peer = match rx_buf[0].try_into() {
let msg_type = rx_buf[0].try_into();
log::debug!("Rx {:?}, processing", msg_type);
let peer = match msg_type {
Ok(MsgType::InitHello) => {
let msg_in: Ref<&[u8], Envelope<InitHello>> =
Ref::new(rx_buf).ok_or(RosenpassError::BufferSizeMismatch)?;
@@ -2134,6 +2153,26 @@ mod test {
use super::*;
struct VecHostIdentifier(Vec<u8>);
impl HostIdentification for VecHostIdentifier {
fn encode(&self) -> &[u8] {
&self.0
}
}
impl Display for VecHostIdentifier {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.0)
}
}
impl From<Vec<u8>> for VecHostIdentifier {
fn from(v: Vec<u8>) -> Self {
VecHostIdentifier(v)
}
}
#[test]
/// Ensure that the protocol implementation can deal with truncated
/// messages and with overlong messages.
@@ -2243,6 +2282,8 @@ mod test {
ip_addr_port_a.extend_from_slice(&socket_addr_a.port().to_be_bytes());
let ip_addr_port_a: VecHostIdentifier = ip_addr_port_a.into();
//B handles handshake under load, should send cookie reply message with invalid cookie
let HandleMsgResult { resp, .. } = b
.handle_msg_under_load(
@@ -2270,7 +2311,7 @@ mod test {
.secret(),
)
.unwrap()
.mix(&ip_addr_port_a)
.mix(&ip_addr_port_a.encode())
.unwrap()
.into_value()[..16]
.to_vec();
@@ -2355,12 +2396,15 @@ mod test {
.copy_from_slice(&socket_addr_b.port().to_be_bytes());
ip_addr_port_b_len += 2;
let ip_addr_port_b: VecHostIdentifier =
ip_addr_port_b[..ip_addr_port_b_len].to_vec().into();
//A handles RespHello message under load, should not send cookie reply
assert!(a
.handle_msg_under_load(
&b_to_a_buf[..resp_hello_len],
&mut *a_to_b_buf,
&ip_addr_port_b[..ip_addr_port_b_len]
&ip_addr_port_b
)
.is_err());
});

View File

@@ -1,14 +1,9 @@
use std::{
fs, net::UdpSocket, os::unix::thread::JoinHandleExt, path::PathBuf, process::Stdio,
time::Duration,
};
use std::{fs::{self, write}, net::UdpSocket, path::PathBuf, process::Stdio, time::Duration};
use clap::Parser;
use rosenpass::{
app_server::{AppServerTest, AppServerTestBuilder},
cli::CliArgs,
};
use rosenpass::{app_server::AppServerTestBuilder, cli::CliArgs};
use serial_test::serial;
use std::io::Write;
const BIN: &str = "rosenpass";
@@ -139,7 +134,15 @@ fn check_exchange_under_normal() {
#[test]
#[serial]
fn check_exchange_under_dos() {
procspawn::init();
let mut log_builder = env_logger::Builder::from_default_env(); // sets log level filter from environment (or defaults)
log_builder.filter_level(log::LevelFilter::Debug);
log_builder.format_timestamp_nanos();
log_builder.format(|buf, record| {
let ts_format = buf.timestamp_nanos().to_string();
writeln!(buf, "\x1b[1m{:?}\x1b[0m {}: {}", std::thread::current().id(), &ts_format[18..], record.args())
});
let _ = log_builder.try_init();
//Generate binary with responder with feature integration_test
let tmpdir = PathBuf::from(env!("CARGO_TARGET_TMPDIR")).join("exchange-dos");
@@ -185,26 +188,21 @@ fn check_exchange_under_dos() {
.arg("outfile")
.arg(&shared_key_paths[0]);
let server_cmd: Vec<String> =
server_cmd
.get_args()
.into_iter()
.fold(vec![BIN.to_string()], |mut acc, x| {
if let Some(s) = x.to_str() {
acc.push(s.to_string());
}
acc
});
let (server_terminate, server_terminate_rx) = std::sync::mpsc::channel();
let mut server = std::thread::spawn(move || {
let cli = CliArgs::try_parse_from(server_cmd.iter()).unwrap();
let cli = CliArgs::try_parse_from(
[server_cmd.get_program()]
.into_iter()
.chain(server_cmd.get_args()),
)
.unwrap();
std::thread::spawn(move || {
cli.command
.run(Some(
AppServerTestBuilder::default()
.enable_dos_permanently(true)
.terminate(Some(server_terminate_rx))
.termination_handler(Some(server_terminate_rx))
.build()
.unwrap(),
))
@@ -214,7 +212,8 @@ fn check_exchange_under_dos() {
std::thread::sleep(Duration::from_millis(500));
// start second process, the client
let mut client = test_bin::get_test_bin(BIN)
let mut client_cmd = std::process::Command::new(BIN);
client_cmd
.args(["exchange", "secret-key"])
.arg(&secret_key_paths[1])
.arg("public-key")
@@ -223,16 +222,34 @@ fn check_exchange_under_dos() {
.arg(&public_key_paths[0])
.args(["endpoint", &listen_addr])
.arg("outfile")
.arg(&shared_key_paths[1])
.spawn()
.expect("Failed to start {BIN}");
.arg(&shared_key_paths[1]);
let (client_terminate, client_terminate_rx) = std::sync::mpsc::channel();
let cli = CliArgs::try_parse_from(
[client_cmd.get_program()]
.into_iter()
.chain(client_cmd.get_args()),
)
.unwrap();
std::thread::spawn(move || {
cli.command
.run(Some(
AppServerTestBuilder::default()
.termination_handler(Some(client_terminate_rx))
.build()
.unwrap(),
))
.unwrap();
});
// give them some time to do the key exchange under load
std::thread::sleep(Duration::from_secs(10));
// time's up, kill the childs
server_terminate.send(()).unwrap();
client.kill().unwrap();
client_terminate.send(()).unwrap();
// read the shared keys they created
let shared_keys: Vec<_> = shared_key_paths