diff --git a/rosenpass/src/lib.rs b/rosenpass/src/lib.rs index 46bc1ba..df02553 100644 --- a/rosenpass/src/lib.rs +++ b/rosenpass/src/lib.rs @@ -19,7 +19,7 @@ pub enum RosenpassError { Oqs, #[error("error from external library while calling OQS")] OqsExternalLib, - #[error("buffer size mismatch, required {required_size} but only found {actual_size}")] + #[error("buffer size mismatch, required {required_size} but found {actual_size}")] BufferSizeMismatch { required_size: usize, actual_size: usize, diff --git a/rosenpass/src/msgs.rs b/rosenpass/src/msgs.rs index d3b6ca9..9108f1b 100644 --- a/rosenpass/src/msgs.rs +++ b/rosenpass/src/msgs.rs @@ -143,7 +143,7 @@ macro_rules! data_lense( pub fn check_size(len: usize) -> Result<(), RosenpassError>{ let required_size = $( $len + )+ 0; let actual_size = len; - if required_size < actual_size { + if required_size != actual_size { Err(RosenpassError::BufferSizeMismatch { required_size, actual_size, @@ -199,23 +199,53 @@ macro_rules! data_lense( type __ContainerType; /// Create a lense to the byte slice - fn [< $type:snake >] $(< $($generic),* >)? (self) -> Result< $type, RosenpassError>; + fn [< $type:snake >] $(< $($generic : LenseView),* >)? (self) -> Result< $type, RosenpassError>; + + /// Create a lense to the byte slice, automatically truncating oversized buffers + fn [< $type:snake _ truncating >] $(< $($generic : LenseView),* >)? (self) -> Result< $type, RosenpassError>; } impl<'a> [< $type Ext >] for &'a [u8] { type __ContainerType = &'a [u8]; - fn [< $type:snake >] $(< $($generic),* >)? (self) -> Result< $type, RosenpassError> { + fn [< $type:snake >] $(< $($generic : LenseView),* >)? (self) -> Result< $type, RosenpassError> { + $type::::check_size(self.len())?; Ok($type ( self, $( $( ::core::marker::PhantomData::<$generic> ),+ )? )) } + + fn [< $type:snake _ truncating >] $(< $($generic : LenseView),* >)? (self) -> Result< $type, RosenpassError> { + let required_size = $( $len + )+ 0; + let actual_size = self.len(); + if actual_size < required_size { + return Err(RosenpassError::BufferSizeMismatch { + required_size, + actual_size, + }); + } + + [< $type Ext >]::[< $type:snake >](&self[..required_size]) + } } impl<'a> [< $type Ext >] for &'a mut [u8] { type __ContainerType = &'a mut [u8]; - - fn [< $type:snake >] $(< $($generic),* >)? (self) -> Result< $type, RosenpassError> { + fn [< $type:snake >] $(< $($generic : LenseView),* >)? (self) -> Result< $type, RosenpassError> { + $type::::check_size(self.len())?; Ok($type ( self, $( $( ::core::marker::PhantomData::<$generic> ),+ )? )) } + + fn [< $type:snake _ truncating >] $(< $($generic : LenseView),* >)? (self) -> Result< $type, RosenpassError> { + let required_size = $( $len + )+ 0; + let actual_size = self.len(); + if actual_size < required_size { + return Err(RosenpassError::BufferSizeMismatch { + required_size, + actual_size, + }); + } + + [< $type Ext >]::[< $type:snake >](&mut self[..required_size]) + } } }); ); diff --git a/rosenpass/src/protocol.rs b/rosenpass/src/protocol.rs index 092c8ea..a0d4c5b 100644 --- a/rosenpass/src/protocol.rs +++ b/rosenpass/src/protocol.rs @@ -736,7 +736,7 @@ impl CryptoServer { // TODO remove unnecessary copying between global tx_buf and per-peer buf // TODO move retransmission storage to io server pub fn initiate_handshake(&mut self, peer: PeerPtr, tx_buf: &mut [u8]) -> Result { - let mut msg = tx_buf.envelope::>()?; // Envelope::::default(); // TODO + let mut msg = tx_buf.envelope_truncating::>()?; // Envelope::::default(); // TODO self.handle_initiation(peer, msg.payload_mut().init_hello()?)?; let len = self.seal_and_commit_msg(peer, MsgType::InitHello, msg)?; peer.hs() @@ -793,7 +793,7 @@ impl CryptoServer { let msg_in = rx_buf.envelope::>()?; ensure!(msg_in.check_seal(self)?, seal_broken); - let mut msg_out = tx_buf.envelope::>()?; + let mut msg_out = tx_buf.envelope_truncating::>()?; let peer = self.handle_init_hello( msg_in.payload().init_hello()?, msg_out.payload_mut().resp_hello()?, @@ -805,7 +805,7 @@ impl CryptoServer { let msg_in = rx_buf.envelope::>()?; ensure!(msg_in.check_seal(self)?, seal_broken); - let mut msg_out = tx_buf.envelope::>()?; + let mut msg_out = tx_buf.envelope_truncating::>()?; let peer = self.handle_resp_hello( msg_in.payload().resp_hello()?, msg_out.payload_mut().init_conf()?, @@ -820,7 +820,7 @@ impl CryptoServer { let msg_in = rx_buf.envelope::>()?; ensure!(msg_in.check_seal(self)?, seal_broken); - let mut msg_out = tx_buf.envelope::>()?; + let mut msg_out = tx_buf.envelope_truncating::>()?; let peer = self.handle_init_conf( msg_in.payload().init_conf()?, msg_out.payload_mut().empty_data()?, @@ -1733,31 +1733,94 @@ impl CryptoServer { mod test { use super::*; - fn init_crypto_server() -> CryptoServer { - // always init libsodium before anything + #[test] + /// Ensure that the protocol implementation can deal with truncated + /// messages and with overlong messages. + /// + /// This test performs a complete handshake between two randomly generated + /// servers; instead of delivering the message correctly at first messages + /// of length zero through about 1.2 times the correct message size are delivered. + /// + /// Producing an error is expected on each of these messages. + /// + /// Finally the correct message is delivered and the same process + /// starts again in the other direction. + /// + /// Through all this, the handshake should still successfully terminate; + /// i.e. an exchanged key must be produced in both servers. + fn handles_incorrect_size_messages() { crate::sodium::sodium_init().unwrap(); - // initialize secret and public key for the crypto server - let (mut sk, mut pk) = (SSk::zero(), SPk::zero()); - - // Guranteed to have 8MiB of stack size stacker::grow(8 * 1024 * 1024, || { - StaticKEM::keygen(sk.secret_mut(), pk.secret_mut()).expect("unable to generate keys"); - }); + const OVERSIZED_MESSAGE: usize = ((MAX_MESSAGE_LEN as f32) * 1.2) as usize; + type MsgBufPlus = Public; - CryptoServer::new(sk, pk) + const PEER0: PeerPtr = PeerPtr(0); + + let (mut me, mut they) = make_server_pair().unwrap(); + let (mut msgbuf, mut resbuf) = (MsgBufPlus::zero(), MsgBufPlus::zero()); + + // Process the entire handshake + let mut msglen = Some(me.initiate_handshake(PEER0, &mut *resbuf).unwrap()); + loop { + if let Some(l) = msglen { + std::mem::swap(&mut me, &mut they); + std::mem::swap(&mut msgbuf, &mut resbuf); + msglen = test_incorrect_sizes_for_msg(&mut me, &*msgbuf, l, &mut *resbuf); + } else { + break; + } + } + + assert_eq!( + me.osk(PEER0).unwrap().secret(), + they.osk(PEER0).unwrap().secret() + ); + }); } - /// The determination of the message type relies on reading the first byte of the message. Only - /// after that the length of the message is checked against the specified message type. This - /// test ensures that nothing breaks in the case of an empty message. - #[test] - #[should_panic = "called `Result::unwrap()` on an `Err` value: received empty message, ignoring it"] - fn handle_empty_message() { - let mut crypt = init_crypto_server(); - let empty_rx_buf = [0u8; 0]; - let mut tx_buf = [0u8; 0]; + /// Used in handles_incorrect_size_messages() to first deliver many truncated + /// and overlong messages, finally the correct message is delivered and the response + /// returned. + fn test_incorrect_sizes_for_msg( + srv: &mut CryptoServer, + msgbuf: &[u8], + msglen: usize, + resbuf: &mut [u8], + ) -> Option { + resbuf.fill(0); - crypt.handle_msg(&empty_rx_buf, &mut tx_buf).unwrap(); + for l in 0..(((msglen as f32) * 1.2) as usize) { + if l == msglen { + continue; + } + + let res = srv.handle_msg(&msgbuf[..l], resbuf); + assert!(matches!(res, Err(_))); // handle_msg should raise an error + assert!(!resbuf.iter().find(|x| **x != 0).is_some()); // resbuf should not have been changed + } + + // Apply the proper handle_msg operation + srv.handle_msg(&msgbuf[..msglen], resbuf).unwrap().resp + } + + fn keygen() -> Result<(SSk, SPk)> { + // TODO: Copied from the benchmark; deduplicate + let (mut sk, mut pk) = (SSk::zero(), SPk::zero()); + StaticKEM::keygen(sk.secret_mut(), pk.secret_mut())?; + Ok((sk, pk)) + } + + fn make_server_pair() -> Result<(CryptoServer, CryptoServer)> { + // TODO: Copied from the benchmark; deduplicate + let psk = SymKey::random(); + let ((ska, pka), (skb, pkb)) = (keygen()?, keygen()?); + let (mut a, mut b) = ( + CryptoServer::new(ska, pka.clone()), + CryptoServer::new(skb, pkb.clone()), + ); + a.add_peer(Some(psk.clone()), pkb)?; + b.add_peer(Some(psk), pka)?; + Ok((a, b)) } }