Checkpoint

This commit is contained in:
Prabhpreet Dua
2024-02-04 11:39:34 +05:30
parent efd0ce51cb
commit 3498ab2d7b
2 changed files with 155 additions and 28 deletions

View File

@@ -34,9 +34,10 @@ use rosenpass_util::b64::{b64_writer, fmt_b64};
const IPV4_ANY_ADDR: Ipv4Addr = Ipv4Addr::new(0, 0, 0, 0); const IPV4_ANY_ADDR: Ipv4Addr = Ipv4Addr::new(0, 0, 0, 0);
const IPV6_ANY_ADDR: Ipv6Addr = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0); const IPV6_ANY_ADDR: Ipv6Addr = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0);
// Using values from Linux Kernel implementation const NORMAL_OPERATION_THRESHOLD: usize = 5;
// TODO: Customize values for rosenpass const UNDER_LOAD_THRESHOLD: usize = 10;
const MAX_QUEUED_INCOMING_HANDSHAKES_THRESHOLD: usize = 4096; const RESET_DURATION: Duration = Duration::from_secs(1);
const LAST_UNDER_LOAD_WINDOW: Duration = Duration::from_secs(1); const LAST_UNDER_LOAD_WINDOW: Duration = Duration::from_secs(1);
fn ipv4_any_binding() -> SocketAddr { fn ipv4_any_binding() -> SocketAddr {
@@ -76,7 +77,7 @@ pub struct WireguardOut {
#[derive(Debug)] #[derive(Debug)]
pub enum DoSOperation { pub enum DoSOperation {
UnderLoad { last_under_load: Instant }, UnderLoad { last_under_load: Instant },
Normal, Normal{blocked_polls: usize},
} }
/// Holds the state of the application, namely the external IO /// Holds the state of the application, namely the external IO
@@ -199,15 +200,17 @@ impl std::fmt::Display for SocketBoundEndpoint {
} }
impl SocketBoundEndpoint { impl SocketBoundEndpoint {
const SOCKET_SIZE: usize = usize::BITS as usize / 8;
const SOCKET_SIZE: usize = usize::BITS as usize/8;
const IPV6_SIZE: usize = 16; const IPV6_SIZE: usize = 16;
const PORT_SIZE: usize = 2; const PORT_SIZE: usize = 2;
const SCOPE_ID_SIZE: usize = 4; const SCOPE_ID_SIZE: usize = 4;
const BUFFER_SIZE: usize = SocketBoundEndpoint::SOCKET_SIZE + SocketBoundEndpoint::IPV6_SIZE + SocketBoundEndpoint::PORT_SIZE + SocketBoundEndpoint::SCOPE_ID_SIZE; const BUFFER_SIZE: usize = SocketBoundEndpoint::SOCKET_SIZE
pub fn to_bytes(&self) -> (usize,[u8; SocketBoundEndpoint::BUFFER_SIZE]) { + SocketBoundEndpoint::IPV6_SIZE
let mut buf = [0u8;SocketBoundEndpoint::BUFFER_SIZE]; + SocketBoundEndpoint::PORT_SIZE
+ SocketBoundEndpoint::SCOPE_ID_SIZE;
pub fn to_bytes(&self) -> (usize, [u8; SocketBoundEndpoint::BUFFER_SIZE]) {
let mut buf = [0u8; SocketBoundEndpoint::BUFFER_SIZE];
let addr = match self.addr { let addr = match self.addr {
SocketAddr::V4(addr) => { SocketAddr::V4(addr) => {
//Map IPv4-mapped to IPv6 addresses //Map IPv4-mapped to IPv6 addresses
@@ -217,15 +220,17 @@ impl SocketBoundEndpoint {
SocketAddr::V6(addr) => addr, SocketAddr::V6(addr) => addr,
}; };
let mut len: usize = 0; let mut len: usize = 0;
buf[len..len+SocketBoundEndpoint::SOCKET_SIZE].copy_from_slice(&self.socket.0.to_be_bytes()); buf[len..len + SocketBoundEndpoint::SOCKET_SIZE]
.copy_from_slice(&self.socket.0.to_be_bytes());
len += SocketBoundEndpoint::SOCKET_SIZE; len += SocketBoundEndpoint::SOCKET_SIZE;
buf[len..len+SocketBoundEndpoint::IPV6_SIZE].copy_from_slice(&addr.ip().octets()); buf[len..len + SocketBoundEndpoint::IPV6_SIZE].copy_from_slice(&addr.ip().octets());
len += SocketBoundEndpoint::IPV6_SIZE; len += SocketBoundEndpoint::IPV6_SIZE;
buf[len..len+SocketBoundEndpoint::PORT_SIZE].copy_from_slice(&addr.port().to_be_bytes()); buf[len..len + SocketBoundEndpoint::PORT_SIZE].copy_from_slice(&addr.port().to_be_bytes());
len += SocketBoundEndpoint::PORT_SIZE; len += SocketBoundEndpoint::PORT_SIZE;
buf[len..len+SocketBoundEndpoint::SCOPE_ID_SIZE].copy_from_slice(&addr.scope_id().to_be_bytes()); buf[len..len + SocketBoundEndpoint::SCOPE_ID_SIZE]
.copy_from_slice(&addr.scope_id().to_be_bytes());
len += SocketBoundEndpoint::SCOPE_ID_SIZE; len += SocketBoundEndpoint::SCOPE_ID_SIZE;
(len,buf) (len, buf)
} }
} }
@@ -398,7 +403,7 @@ impl AppServer {
) -> anyhow::Result<Self> { ) -> anyhow::Result<Self> {
// setup mio // setup mio
let mio_poll = mio::Poll::new()?; let mio_poll = mio::Poll::new()?;
let events = mio::Events::with_capacity(8); let events = mio::Events::with_capacity(20);
// bind each SocketAddr to a socket // bind each SocketAddr to a socket
let maybe_sockets: Result<Vec<_>, _> = let maybe_sockets: Result<Vec<_>, _> =
@@ -488,7 +493,7 @@ impl AppServer {
events, events,
mio_poll, mio_poll,
all_sockets_drained: false, all_sockets_drained: false,
under_load: DoSOperation::Normal, under_load: DoSOperation::Normal{ blocked_polls: 0},
}) })
} }
@@ -605,9 +610,13 @@ impl AppServer {
ReceivedMessage(len, endpoint) => { ReceivedMessage(len, endpoint) => {
let msg_result = match self.under_load { let msg_result = match self.under_load {
DoSOperation::UnderLoad { last_under_load: _ } => { DoSOperation::UnderLoad { last_under_load: _ } => {
println!("Processing msg under load");
self.handle_msg_under_load(&endpoint, &rx[..len], &mut *tx) self.handle_msg_under_load(&endpoint, &rx[..len], &mut *tx)
} }
DoSOperation::Normal => self.crypt.handle_msg(&rx[..len], &mut *tx), DoSOperation::Normal { blocked_polls: _} => {
println!("Processing msg normally");
self.crypt.handle_msg(&rx[..len], &mut *tx)
}
}; };
match msg_result { match msg_result {
Err(ref e) => { Err(ref e) => {
@@ -652,7 +661,7 @@ impl AppServer {
) -> Result<crate::protocol::HandleMsgResult> { ) -> Result<crate::protocol::HandleMsgResult> {
match endpoint { match endpoint {
Endpoint::SocketBoundAddress(socket) => { Endpoint::SocketBoundAddress(socket) => {
let (hi_len,host_identification )= socket.to_bytes(); let (hi_len, host_identification) = socket.to_bytes();
self.crypt self.crypt
.handle_msg_under_load(&rx, &mut *tx, &host_identification[0..hi_len]) .handle_msg_under_load(&rx, &mut *tx, &host_identification[0..hi_len])
} }
@@ -784,20 +793,36 @@ impl AppServer {
// only poll if we drained all sockets before // only poll if we drained all sockets before
if self.all_sockets_drained { if self.all_sockets_drained {
self.mio_poll.poll(&mut self.events, Some(timeout))?; //Non blocked polling
self.mio_poll.poll(&mut self.events, Some(Duration::from_secs(0)))?;
let queue_length = self.events.iter().peekable().count(); if self.events.iter().peekable().peek().is_none() {
// if there are no events, then we can just return
if queue_length > MAX_QUEUED_INCOMING_HANDSHAKES_THRESHOLD { match self.under_load {
self.under_load = DoSOperation::UnderLoad { DoSOperation::Normal { blocked_polls } => {
last_under_load: Instant::now(), self.under_load = DoSOperation::Normal {
blocked_polls: blocked_polls + 1,
}
}
_ => {}
} }
self.mio_poll.poll(&mut self.events, Some(timeout))?;
} }
} }
if let DoSOperation::UnderLoad { last_under_load } = self.under_load { match self.under_load {
if last_under_load.elapsed() > LAST_UNDER_LOAD_WINDOW { DoSOperation::Normal { blocked_polls } => {
self.under_load = DoSOperation::Normal; if blocked_polls > NORMAL_OPERATION_THRESHOLD {
self.under_load = DoSOperation::UnderLoad {
last_under_load: Instant::now(),
}
}
}
DoSOperation::UnderLoad { last_under_load } => {
if last_under_load.elapsed() > RESET_DURATION {
self.under_load = DoSOperation::Normal { blocked_polls: 0 };
}
} }
} }

View File

@@ -39,7 +39,7 @@ fn find_udp_socket() -> u16 {
// check that we can exchange keys // check that we can exchange keys
#[test] #[test]
fn check_exchange() { fn check_exchange_under_normal() {
let tmpdir = PathBuf::from(env!("CARGO_TARGET_TMPDIR")).join("exchange"); let tmpdir = PathBuf::from(env!("CARGO_TARGET_TMPDIR")).join("exchange");
fs::create_dir_all(&tmpdir).unwrap(); fs::create_dir_all(&tmpdir).unwrap();
@@ -117,3 +117,105 @@ fn check_exchange() {
// cleanup // cleanup
fs::remove_dir_all(&tmpdir).unwrap(); fs::remove_dir_all(&tmpdir).unwrap();
} }
// check that we can exchange keys
#[test]
fn check_exchange_under_dos() {
let tmpdir = PathBuf::from(env!("CARGO_TARGET_TMPDIR")).join("exchange-dos");
fs::create_dir_all(&tmpdir).unwrap();
let secret_key_paths = [tmpdir.join("secret-key-0"), tmpdir.join("secret-key-1")];
let public_key_paths = [tmpdir.join("public-key-0"), tmpdir.join("public-key-1")];
let shared_key_paths = [tmpdir.join("shared-key-0"), tmpdir.join("shared-key-1")];
// generate key pairs
for (secret_key_path, pub_key_path) in secret_key_paths.iter().zip(public_key_paths.iter()) {
let output = test_bin::get_test_bin(BIN)
.args(["gen-keys", "--secret-key"])
.arg(secret_key_path)
.arg("--public-key")
.arg(pub_key_path)
.output()
.expect("Failed to start {BIN}");
assert_eq!(String::from_utf8_lossy(&output.stdout), "");
assert!(secret_key_path.is_file());
assert!(pub_key_path.is_file());
}
// start first process, the server
let port = find_udp_socket();
let listen_addr = format!("localhost:{port}");
let mut server = test_bin::get_test_bin(BIN)
.args(["exchange", "secret-key"])
.arg(&secret_key_paths[0])
.arg("public-key")
.arg(&public_key_paths[0])
.args(["listen", &listen_addr, "verbose", "peer", "public-key"])
.arg(&public_key_paths[1])
.arg("outfile")
.arg(&shared_key_paths[0])
//.stdout(Stdio::null())
//.stderr(Stdio::null())
.spawn()
.expect("Failed to start {BIN}");
std::thread::sleep(Duration::from_millis(500));
//DoS Sender
//Create a UDP socket
let socket = UdpSocket::bind("127.0.0.1:0").expect("couldn't bind to address");
//Spawn a thread to send DoS packets
let server_addr = listen_addr.clone();
//Create thread safe atomic bool to stop the DoS attack
let stop_dos = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
let stop_dos_handle = stop_dos.clone();
let dos_attack = std::thread::spawn(move || {
while stop_dos.load(std::sync::atomic::Ordering::Relaxed) == false {
let buf = [0; 10];
socket
.send_to(&buf, &server_addr)
.expect("couldn't send data");
}
});
// start second process, the client
let mut client = test_bin::get_test_bin(BIN)
.args(["exchange", "secret-key"])
.arg(&secret_key_paths[1])
.arg("public-key")
.arg(&public_key_paths[1])
.args(["verbose", "peer", "public-key"])
.arg(&public_key_paths[0])
.args(["endpoint", &listen_addr])
.arg("outfile")
.arg(&shared_key_paths[1])
//.stdout(Stdio::null())
//.stderr(Stdio::null())
.spawn()
.expect("Failed to start {BIN}");
// give them some time to do the key exchange
std::thread::sleep(Duration::from_secs(2));
// time's up, kill the childs
server.kill().unwrap();
client.kill().unwrap();
stop_dos_handle.store(true, std::sync::atomic::Ordering::Relaxed);
dos_attack.join().unwrap();
// read the shared keys they created
let shared_keys: Vec<_> = shared_key_paths
.iter()
.map(|p| fs::read_to_string(p).unwrap())
.collect();
// check that they created two equal keys
assert_eq!(shared_keys.len(), 2);
assert_eq!(shared_keys[0], shared_keys[1]);
// cleanup
fs::remove_dir_all(&tmpdir).unwrap();
}