Checkpoint

This commit is contained in:
Prabhpreet Dua
2024-02-04 11:39:34 +05:30
parent efd0ce51cb
commit 3498ab2d7b
2 changed files with 155 additions and 28 deletions

View File

@@ -34,9 +34,10 @@ use rosenpass_util::b64::{b64_writer, fmt_b64};
const IPV4_ANY_ADDR: Ipv4Addr = Ipv4Addr::new(0, 0, 0, 0);
const IPV6_ANY_ADDR: Ipv6Addr = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0);
// Using values from Linux Kernel implementation
// TODO: Customize values for rosenpass
const MAX_QUEUED_INCOMING_HANDSHAKES_THRESHOLD: usize = 4096;
const NORMAL_OPERATION_THRESHOLD: usize = 5;
const UNDER_LOAD_THRESHOLD: usize = 10;
const RESET_DURATION: Duration = Duration::from_secs(1);
const LAST_UNDER_LOAD_WINDOW: Duration = Duration::from_secs(1);
fn ipv4_any_binding() -> SocketAddr {
@@ -76,7 +77,7 @@ pub struct WireguardOut {
#[derive(Debug)]
pub enum DoSOperation {
UnderLoad { last_under_load: Instant },
Normal,
Normal{blocked_polls: usize},
}
/// Holds the state of the application, namely the external IO
@@ -199,15 +200,17 @@ impl std::fmt::Display for SocketBoundEndpoint {
}
impl SocketBoundEndpoint {
const SOCKET_SIZE: usize = usize::BITS as usize/8;
const SOCKET_SIZE: usize = usize::BITS as usize / 8;
const IPV6_SIZE: usize = 16;
const PORT_SIZE: usize = 2;
const SCOPE_ID_SIZE: usize = 4;
const BUFFER_SIZE: usize = SocketBoundEndpoint::SOCKET_SIZE + SocketBoundEndpoint::IPV6_SIZE + SocketBoundEndpoint::PORT_SIZE + SocketBoundEndpoint::SCOPE_ID_SIZE;
pub fn to_bytes(&self) -> (usize,[u8; SocketBoundEndpoint::BUFFER_SIZE]) {
let mut buf = [0u8;SocketBoundEndpoint::BUFFER_SIZE];
const BUFFER_SIZE: usize = SocketBoundEndpoint::SOCKET_SIZE
+ SocketBoundEndpoint::IPV6_SIZE
+ SocketBoundEndpoint::PORT_SIZE
+ SocketBoundEndpoint::SCOPE_ID_SIZE;
pub fn to_bytes(&self) -> (usize, [u8; SocketBoundEndpoint::BUFFER_SIZE]) {
let mut buf = [0u8; SocketBoundEndpoint::BUFFER_SIZE];
let addr = match self.addr {
SocketAddr::V4(addr) => {
//Map IPv4-mapped to IPv6 addresses
@@ -217,15 +220,17 @@ impl SocketBoundEndpoint {
SocketAddr::V6(addr) => addr,
};
let mut len: usize = 0;
buf[len..len+SocketBoundEndpoint::SOCKET_SIZE].copy_from_slice(&self.socket.0.to_be_bytes());
buf[len..len + SocketBoundEndpoint::SOCKET_SIZE]
.copy_from_slice(&self.socket.0.to_be_bytes());
len += SocketBoundEndpoint::SOCKET_SIZE;
buf[len..len+SocketBoundEndpoint::IPV6_SIZE].copy_from_slice(&addr.ip().octets());
buf[len..len + SocketBoundEndpoint::IPV6_SIZE].copy_from_slice(&addr.ip().octets());
len += SocketBoundEndpoint::IPV6_SIZE;
buf[len..len+SocketBoundEndpoint::PORT_SIZE].copy_from_slice(&addr.port().to_be_bytes());
buf[len..len + SocketBoundEndpoint::PORT_SIZE].copy_from_slice(&addr.port().to_be_bytes());
len += SocketBoundEndpoint::PORT_SIZE;
buf[len..len+SocketBoundEndpoint::SCOPE_ID_SIZE].copy_from_slice(&addr.scope_id().to_be_bytes());
buf[len..len + SocketBoundEndpoint::SCOPE_ID_SIZE]
.copy_from_slice(&addr.scope_id().to_be_bytes());
len += SocketBoundEndpoint::SCOPE_ID_SIZE;
(len,buf)
(len, buf)
}
}
@@ -398,7 +403,7 @@ impl AppServer {
) -> anyhow::Result<Self> {
// setup mio
let mio_poll = mio::Poll::new()?;
let events = mio::Events::with_capacity(8);
let events = mio::Events::with_capacity(20);
// bind each SocketAddr to a socket
let maybe_sockets: Result<Vec<_>, _> =
@@ -488,7 +493,7 @@ impl AppServer {
events,
mio_poll,
all_sockets_drained: false,
under_load: DoSOperation::Normal,
under_load: DoSOperation::Normal{ blocked_polls: 0},
})
}
@@ -605,9 +610,13 @@ impl AppServer {
ReceivedMessage(len, endpoint) => {
let msg_result = match self.under_load {
DoSOperation::UnderLoad { last_under_load: _ } => {
println!("Processing msg under load");
self.handle_msg_under_load(&endpoint, &rx[..len], &mut *tx)
}
DoSOperation::Normal => self.crypt.handle_msg(&rx[..len], &mut *tx),
DoSOperation::Normal { blocked_polls: _} => {
println!("Processing msg normally");
self.crypt.handle_msg(&rx[..len], &mut *tx)
}
};
match msg_result {
Err(ref e) => {
@@ -652,7 +661,7 @@ impl AppServer {
) -> Result<crate::protocol::HandleMsgResult> {
match endpoint {
Endpoint::SocketBoundAddress(socket) => {
let (hi_len,host_identification )= socket.to_bytes();
let (hi_len, host_identification) = socket.to_bytes();
self.crypt
.handle_msg_under_load(&rx, &mut *tx, &host_identification[0..hi_len])
}
@@ -784,20 +793,36 @@ impl AppServer {
// only poll if we drained all sockets before
if self.all_sockets_drained {
//Non blocked polling
self.mio_poll.poll(&mut self.events, Some(Duration::from_secs(0)))?;
if self.events.iter().peekable().peek().is_none() {
// if there are no events, then we can just return
match self.under_load {
DoSOperation::Normal { blocked_polls } => {
self.under_load = DoSOperation::Normal {
blocked_polls: blocked_polls + 1,
}
}
_ => {}
}
self.mio_poll.poll(&mut self.events, Some(timeout))?;
}
}
let queue_length = self.events.iter().peekable().count();
if queue_length > MAX_QUEUED_INCOMING_HANDSHAKES_THRESHOLD {
match self.under_load {
DoSOperation::Normal { blocked_polls } => {
if blocked_polls > NORMAL_OPERATION_THRESHOLD {
self.under_load = DoSOperation::UnderLoad {
last_under_load: Instant::now(),
}
}
}
if let DoSOperation::UnderLoad { last_under_load } = self.under_load {
if last_under_load.elapsed() > LAST_UNDER_LOAD_WINDOW {
self.under_load = DoSOperation::Normal;
DoSOperation::UnderLoad { last_under_load } => {
if last_under_load.elapsed() > RESET_DURATION {
self.under_load = DoSOperation::Normal { blocked_polls: 0 };
}
}
}

View File

@@ -39,7 +39,7 @@ fn find_udp_socket() -> u16 {
// check that we can exchange keys
#[test]
fn check_exchange() {
fn check_exchange_under_normal() {
let tmpdir = PathBuf::from(env!("CARGO_TARGET_TMPDIR")).join("exchange");
fs::create_dir_all(&tmpdir).unwrap();
@@ -117,3 +117,105 @@ fn check_exchange() {
// cleanup
fs::remove_dir_all(&tmpdir).unwrap();
}
// check that we can exchange keys
#[test]
fn check_exchange_under_dos() {
let tmpdir = PathBuf::from(env!("CARGO_TARGET_TMPDIR")).join("exchange-dos");
fs::create_dir_all(&tmpdir).unwrap();
let secret_key_paths = [tmpdir.join("secret-key-0"), tmpdir.join("secret-key-1")];
let public_key_paths = [tmpdir.join("public-key-0"), tmpdir.join("public-key-1")];
let shared_key_paths = [tmpdir.join("shared-key-0"), tmpdir.join("shared-key-1")];
// generate key pairs
for (secret_key_path, pub_key_path) in secret_key_paths.iter().zip(public_key_paths.iter()) {
let output = test_bin::get_test_bin(BIN)
.args(["gen-keys", "--secret-key"])
.arg(secret_key_path)
.arg("--public-key")
.arg(pub_key_path)
.output()
.expect("Failed to start {BIN}");
assert_eq!(String::from_utf8_lossy(&output.stdout), "");
assert!(secret_key_path.is_file());
assert!(pub_key_path.is_file());
}
// start first process, the server
let port = find_udp_socket();
let listen_addr = format!("localhost:{port}");
let mut server = test_bin::get_test_bin(BIN)
.args(["exchange", "secret-key"])
.arg(&secret_key_paths[0])
.arg("public-key")
.arg(&public_key_paths[0])
.args(["listen", &listen_addr, "verbose", "peer", "public-key"])
.arg(&public_key_paths[1])
.arg("outfile")
.arg(&shared_key_paths[0])
//.stdout(Stdio::null())
//.stderr(Stdio::null())
.spawn()
.expect("Failed to start {BIN}");
std::thread::sleep(Duration::from_millis(500));
//DoS Sender
//Create a UDP socket
let socket = UdpSocket::bind("127.0.0.1:0").expect("couldn't bind to address");
//Spawn a thread to send DoS packets
let server_addr = listen_addr.clone();
//Create thread safe atomic bool to stop the DoS attack
let stop_dos = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
let stop_dos_handle = stop_dos.clone();
let dos_attack = std::thread::spawn(move || {
while stop_dos.load(std::sync::atomic::Ordering::Relaxed) == false {
let buf = [0; 10];
socket
.send_to(&buf, &server_addr)
.expect("couldn't send data");
}
});
// start second process, the client
let mut client = test_bin::get_test_bin(BIN)
.args(["exchange", "secret-key"])
.arg(&secret_key_paths[1])
.arg("public-key")
.arg(&public_key_paths[1])
.args(["verbose", "peer", "public-key"])
.arg(&public_key_paths[0])
.args(["endpoint", &listen_addr])
.arg("outfile")
.arg(&shared_key_paths[1])
//.stdout(Stdio::null())
//.stderr(Stdio::null())
.spawn()
.expect("Failed to start {BIN}");
// give them some time to do the key exchange
std::thread::sleep(Duration::from_secs(2));
// time's up, kill the childs
server.kill().unwrap();
client.kill().unwrap();
stop_dos_handle.store(true, std::sync::atomic::Ordering::Relaxed);
dos_attack.join().unwrap();
// read the shared keys they created
let shared_keys: Vec<_> = shared_key_paths
.iter()
.map(|p| fs::read_to_string(p).unwrap())
.collect();
// check that they created two equal keys
assert_eq!(shared_keys.len(), 2);
assert_eq!(shared_keys[0], shared_keys[1]);
// cleanup
fs::remove_dir_all(&tmpdir).unwrap();
}