diff --git a/rosenpass/src/app_server.rs b/rosenpass/src/app_server.rs index b245c64..c1914fe 100644 --- a/rosenpass/src/app_server.rs +++ b/rosenpass/src/app_server.rs @@ -579,16 +579,35 @@ impl AppServer { }) .ok_or(anyhow::anyhow!("Received message from unknown endpoint"))? .0; + let socket_addr = endpoint .addresses() .first() - .copied() .ok_or(anyhow::anyhow!("No socket address for endpoint"))?; + + let mut len = 0; + let mut ip_addr_port = [0u8; 18]; + + match socket_addr.ip() { + std::net::IpAddr::V4(ipv4) => { + ip_addr_port[0..4].copy_from_slice(&ipv4.octets()); + len += 4; + } + std::net::IpAddr::V6(ipv6) => { + ip_addr_port[0..16].copy_from_slice(&ipv6.octets()); + len += 16; + } + }; + + ip_addr_port[len..len + 2] + .copy_from_slice(&socket_addr.port().to_be_bytes()); + len += 2; + self.crypt.handle_msg_under_load( &rx[..len], &mut *tx, PeerPtr(index), - socket_addr, + &ip_addr_port[..len], ) } DoSOperation::Normal => self.crypt.handle_msg(&rx[..len], &mut *tx), diff --git a/rosenpass/src/protocol.rs b/rosenpass/src/protocol.rs index 54b57f8..17299d6 100644 --- a/rosenpass/src/protocol.rs +++ b/rosenpass/src/protocol.rs @@ -82,7 +82,6 @@ use std::collections::hash_map::{ HashMap, }; use std::convert::Infallible; -use std::net::{IpAddr, SocketAddr}; // CONSTANTS & SETTINGS ////////////////////////// @@ -870,15 +869,16 @@ impl CryptoServer { rx_buf: &[u8], tx_buf: &mut [u8], peer: PeerPtr, - socket_addr: SocketAddr, + ip_addr_port: &[u8], ) -> Result { - //Check cookie value + /* let mut ip_addr_port = match socket_addr.ip() { IpAddr::V4(ipv4) => ipv4.octets().to_vec(), IpAddr::V6(ipv6) => ipv6.octets().to_vec(), }; ip_addr_port.extend_from_slice(&socket_addr.port().to_be_bytes()); + */ let (rx_bytes_til_cookie, rx_cookie, rx_mac, rx_sid) = match rx_buf[0].try_into() { Ok(MsgType::InitHello) => { @@ -2154,6 +2154,13 @@ mod test { let _ip_b: SocketAddrV4 = "127.0.0.1:8081".parse().unwrap(); let init_hello_len = a.initiate_handshake(PeerPtr(0), &mut *a_to_b_buf).unwrap(); + let socket_addr_a = std::net::SocketAddr::V4(ip_a); + let mut ip_addr_port_a = match socket_addr_a.ip() { + std::net::IpAddr::V4(ipv4) => ipv4.octets().to_vec(), + std::net::IpAddr::V6(ipv6) => ipv6.octets().to_vec(), + }; + + ip_addr_port_a.extend_from_slice(&socket_addr_a.port().to_be_bytes()); //B handles handshake under load, should send cookie reply message with invalid cookie let HandleMsgResult { resp, .. } = b @@ -2161,7 +2168,7 @@ mod test { &a_to_b_buf.as_slice()[..init_hello_len], &mut *b_to_a_buf, PeerPtr(0), - SocketAddr::V4(ip_a), + &ip_addr_port_a, ) .unwrap(); @@ -2207,7 +2214,7 @@ mod test { &a_to_b_buf.as_slice()[..retx_init_hello_len], &mut *b_to_a_buf, PeerPtr(0), - SocketAddr::V4(ip_a), + &ip_addr_port_a ) .unwrap(); @@ -2245,13 +2252,30 @@ mod test { let resp_msg_type: MsgType = b_to_a_buf.value[0].try_into().unwrap(); assert_eq!(resp_msg_type, MsgType::RespHello); + let socket_addr_b = std::net::SocketAddr::V4(ip_b); + let mut ip_addr_port_b = [0u8; 18]; + let mut ip_addr_port_b_len = 0; + match socket_addr_b.ip() { + std::net::IpAddr::V4(ipv4) => { + ip_addr_port_b[0..4].copy_from_slice(&ipv4.octets()); + ip_addr_port_b_len += 4; + } + std::net::IpAddr::V6(ipv6) => { + ip_addr_port_b[0..16].copy_from_slice(&ipv6.octets()); + ip_addr_port_b_len += 16; + } + }; + + ip_addr_port_b[ip_addr_port_b_len..ip_addr_port_b_len+2].copy_from_slice(&socket_addr_b.port().to_be_bytes()); + ip_addr_port_b_len += 2; + //A handles RespHello message under load, should not send cookie reply assert!(a .handle_msg_under_load( &b_to_a_buf[..resp_hello_len], &mut *a_to_b_buf, PeerPtr(0), - SocketAddr::V4(ip_b), + &ip_addr_port_b[..ip_addr_port_b_len] ) .is_err()); });