diff --git a/rosenpass/src/app_server.rs b/rosenpass/src/app_server.rs index c1914fe..fc8df35 100644 --- a/rosenpass/src/app_server.rs +++ b/rosenpass/src/app_server.rs @@ -565,50 +565,7 @@ impl AppServer { ReceivedMessage(len, endpoint) => { let msg_result = match self.under_load { DoSOperation::UnderLoad { last_under_load: _ } => { - //TODO: Lookup peer through addresses (hash) - let index = self - .peers - .iter() - .enumerate() - .find(|(_num, p)| { - if let Some(ep) = p.endpoint() { - ep.addresses() == endpoint.addresses() - } else { - false - } - }) - .ok_or(anyhow::anyhow!("Received message from unknown endpoint"))? - .0; - - let socket_addr = endpoint - .addresses() - .first() - .ok_or(anyhow::anyhow!("No socket address for endpoint"))?; - - let mut len = 0; - let mut ip_addr_port = [0u8; 18]; - - match socket_addr.ip() { - std::net::IpAddr::V4(ipv4) => { - ip_addr_port[0..4].copy_from_slice(&ipv4.octets()); - len += 4; - } - std::net::IpAddr::V6(ipv6) => { - ip_addr_port[0..16].copy_from_slice(&ipv6.octets()); - len += 16; - } - }; - - ip_addr_port[len..len + 2] - .copy_from_slice(&socket_addr.port().to_be_bytes()); - len += 2; - - self.crypt.handle_msg_under_load( - &rx[..len], - &mut *tx, - PeerPtr(index), - &ip_addr_port[..len], - ) + self.handle_msg_under_load(&endpoint, &rx[..len], &mut *tx) } DoSOperation::Normal => self.crypt.handle_msg(&rx[..len], &mut *tx), }; @@ -647,6 +604,53 @@ impl AppServer { } } + fn handle_msg_under_load( + &mut self, + endpoint: &Endpoint, + rx: &[u8], + tx: &mut [u8], + ) -> Result { + //TODO: Lookup peer through addresses (hash) + let index = self + .peers + .iter() + .enumerate() + .find(|(_num, p)| { + if let Some(ep) = p.endpoint() { + ep.addresses() == endpoint.addresses() + } else { + false + } + }) + .ok_or(anyhow::anyhow!("Received message from unknown endpoint"))? + .0; + + let socket_addr = endpoint + .addresses() + .first() + .ok_or(anyhow::anyhow!("No socket address for endpoint"))?; + + let mut len = 0; + let mut ip_addr_port = [0u8; 18]; + + match socket_addr.ip() { + std::net::IpAddr::V4(ipv4) => { + ip_addr_port[0..4].copy_from_slice(&ipv4.octets()); + len += 4; + } + std::net::IpAddr::V6(ipv6) => { + ip_addr_port[0..16].copy_from_slice(&ipv6.octets()); + len += 16; + } + }; + + ip_addr_port[len..len + 2].copy_from_slice(&socket_addr.port().to_be_bytes()); + len += 2; + + self.crypt + .handle_msg_under_load(&rx[..len], &mut *tx, PeerPtr(index), &ip_addr_port[..len]) + } + pub fn output_key( &self, peer: AppPeerPtr, diff --git a/rosenpass/src/protocol.rs b/rosenpass/src/protocol.rs index 85333d9..3f74080 100644 --- a/rosenpass/src/protocol.rs +++ b/rosenpass/src/protocol.rs @@ -2029,11 +2029,10 @@ impl CryptoServer { &cr.all_bytes()[4..], )?; - peer.get_mut(self).handshake.as_mut().unwrap().cookie_tau = - CookieSecret::Some { - value: cookie_value, - last_updated: Timebase::default(), - }; + peer.get_mut(self).handshake.as_mut().unwrap().cookie_tau = CookieSecret::Some { + value: cookie_value, + last_updated: Timebase::default(), + }; Ok(peer) } else { bail!( @@ -2197,7 +2196,12 @@ mod test { .into_value()[..16] .to_vec(); assert_eq!( - a.peers[0].handshake.as_ref().unwrap().cookie_tau.get(PEER_COOKIE_TAU_EXP), + a.peers[0] + .handshake + .as_ref() + .unwrap() + .cookie_tau + .get(PEER_COOKIE_TAU_EXP), Some(&expected_cookie_tau[..]) );