diff --git a/rosenpass/Cargo.toml b/rosenpass/Cargo.toml index 298cfe3..1e1d489 100644 --- a/rosenpass/Cargo.toml +++ b/rosenpass/Cargo.toml @@ -23,7 +23,7 @@ name = "api-integration-tests" required-features = ["experiment_api", "internal_testing"] [[test]] -name = "api-integration-tests-supply-keypair" +name = "api-integration-tests-api-setup" required-features = ["experiment_api", "internal_testing"] [[bench]] diff --git a/rosenpass/src/api/api_handler.rs b/rosenpass/src/api/api_handler.rs index 984ccba..d3088b2 100644 --- a/rosenpass/src/api/api_handler.rs +++ b/rosenpass/src/api/api_handler.rs @@ -2,9 +2,13 @@ use std::{borrow::BorrowMut, collections::VecDeque, os::fd::OwnedFd}; use anyhow::Context; use rosenpass_to::{ops::copy_slice, To}; -use rosenpass_util::{fd::FdIo, functional::run, io::ReadExt, mem::DiscardResultExt}; +use rosenpass_util::{ + fd::FdIo, functional::run, io::ReadExt, mem::DiscardResultExt, result::OkExt, +}; -use crate::{app_server::AppServer, protocol::BuildCryptoServer}; +use crate::{ + api::add_listen_socket_response_status, app_server::AppServer, protocol::BuildCryptoServer, +}; use super::{supply_keypair_response_status, Server as ApiServer}; @@ -171,4 +175,54 @@ where Ok(()) } + + fn add_listen_socket( + &mut self, + _req: &super::boilerplate::AddListenSocketRequest, + req_fds: &mut VecDeque, + res: &mut super::boilerplate::AddListenSocketResponse, + ) -> anyhow::Result<()> { + // Retrieve file descriptor + let sock_res = run(|| -> anyhow::Result { + let sock = req_fds + .pop_front() + .context("Invalid request – socket missing.")?; + // TODO: We need to have this outside linux + #[cfg(target_os = "linux")] + rosenpass_util::fd::GetSocketProtocol::demand_udp_socket(&sock)?; + let sock = std::net::UdpSocket::from(sock); + sock.set_nonblocking(true)?; + mio::net::UdpSocket::from_std(sock).ok() + }); + + let mut sock = match sock_res { + Ok(sock) => sock, + Err(e) => { + log::debug!("Error processing AddListenSocket API request: {e:?}"); + res.payload.status = add_listen_socket_response_status::INVALID_REQUEST; + return Ok(()); + } + }; + + // Register socket + let reg_result = run(|| -> anyhow::Result<()> { + let srv = self.app_server_mut(); + srv.mio_poll.registry().register( + &mut sock, + srv.mio_token_dispenser.dispense(), + mio::Interest::READABLE, + )?; + srv.sockets.push(sock); + Ok(()) + }); + + if let Err(internal_error) = reg_result { + log::warn!("Internal error processing AddListenSocket API request: {internal_error:?}"); + res.payload.status = add_listen_socket_response_status::INTERNAL_ERROR; + return Ok(()); + }; + + res.payload.status = add_listen_socket_response_status::OK; + Ok(()) + } } diff --git a/rosenpass/src/api/boilerplate/byte_slice_ext.rs b/rosenpass/src/api/boilerplate/byte_slice_ext.rs index 93af764..de2ff31 100644 --- a/rosenpass/src/api/boilerplate/byte_slice_ext.rs +++ b/rosenpass/src/api/boilerplate/byte_slice_ext.rs @@ -143,6 +143,44 @@ pub trait ByteSliceRefExt: ByteSlice { ) -> anyhow::Result> { self.zk_parse_suffix() } + + fn add_listen_socket_request(self) -> anyhow::Result> { + self.zk_parse() + } + + fn add_listen_socket_request_from_prefix( + self, + ) -> anyhow::Result> { + self.zk_parse_prefix() + } + + fn add_listen_socket_request_from_suffix( + self, + ) -> anyhow::Result> { + self.zk_parse_suffix() + } + + fn add_listen_socket_response_maker(self) -> RefMaker { + self.zk_ref_maker() + } + + fn add_listen_socket_response( + self, + ) -> anyhow::Result> { + self.zk_parse() + } + + fn add_listen_socket_response_from_prefix( + self, + ) -> anyhow::Result> { + self.zk_parse_prefix() + } + + fn add_listen_socket_response_from_suffix( + self, + ) -> anyhow::Result> { + self.zk_parse_suffix() + } } impl ByteSliceRefExt for B {} diff --git a/rosenpass/src/api/boilerplate/message_type.rs b/rosenpass/src/api/boilerplate/message_type.rs index a867344..e76a54d 100644 --- a/rosenpass/src/api/boilerplate/message_type.rs +++ b/rosenpass/src/api/boilerplate/message_type.rs @@ -21,6 +21,13 @@ const SUPPLY_KEYPAIR_REQUEST: RawMsgType = const SUPPLY_KEYPAIR_RESPONSE: RawMsgType = RawMsgType::from_le_bytes(hex!("f2dc 49bd e261 5f10 40b7 3c16 ec61 edb9")); +// hash domain hash of: Rosenpass IPC API -> Rosenpass Protocol Server -> Add Listen Socket Request +const ADD_LISTEN_SOCKET_REQUEST: RawMsgType = + RawMsgType::from_le_bytes(hex!("3f21 434f 87cc a08c 02c4 61e4 0816 c7da")); +// hash domain hash of: Rosenpass IPC API -> Rosenpass Protocol Server -> Add Listen Socket Response +const ADD_LISTEN_SOCKET_RESPONSE: RawMsgType = + RawMsgType::from_le_bytes(hex!("45d5 0f0d 93f0 6105 98f2 9469 5dfd 5f36")); + pub trait MessageAttributes { fn message_size(&self) -> usize; } @@ -29,12 +36,14 @@ pub trait MessageAttributes { pub enum RequestMsgType { Ping, SupplyKeypair, + AddListenSocket, } #[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Copy)] pub enum ResponseMsgType { Ping, SupplyKeypair, + AddListenSocket, } impl MessageAttributes for RequestMsgType { @@ -42,6 +51,7 @@ impl MessageAttributes for RequestMsgType { match self { Self::Ping => std::mem::size_of::(), Self::SupplyKeypair => std::mem::size_of::(), + Self::AddListenSocket => std::mem::size_of::(), } } } @@ -51,6 +61,7 @@ impl MessageAttributes for ResponseMsgType { match self { Self::Ping => std::mem::size_of::(), Self::SupplyKeypair => std::mem::size_of::(), + Self::AddListenSocket => std::mem::size_of::(), } } } @@ -63,6 +74,7 @@ impl TryFrom for RequestMsgType { Ok(match value { self::PING_REQUEST => E::Ping, self::SUPPLY_KEYPAIR_REQUEST => E::SupplyKeypair, + self::ADD_LISTEN_SOCKET_REQUEST => E::AddListenSocket, _ => return Err(InvalidApiMessageType(value)), }) } @@ -74,6 +86,7 @@ impl From for RawMsgType { match val { E::Ping => self::PING_REQUEST, E::SupplyKeypair => self::SUPPLY_KEYPAIR_REQUEST, + E::AddListenSocket => self::ADD_LISTEN_SOCKET_REQUEST, } } } @@ -86,6 +99,7 @@ impl TryFrom for ResponseMsgType { Ok(match value { self::PING_RESPONSE => E::Ping, self::SUPPLY_KEYPAIR_RESPONSE => E::SupplyKeypair, + self::ADD_LISTEN_SOCKET_RESPONSE => E::AddListenSocket, _ => return Err(InvalidApiMessageType(value)), }) } @@ -97,6 +111,7 @@ impl From for RawMsgType { match val { E::Ping => self::PING_RESPONSE, E::SupplyKeypair => self::SUPPLY_KEYPAIR_RESPONSE, + E::AddListenSocket => self::ADD_LISTEN_SOCKET_RESPONSE, } } } diff --git a/rosenpass/src/api/boilerplate/payload.rs b/rosenpass/src/api/boilerplate/payload.rs index cbf6332..a55fb8d 100644 --- a/rosenpass/src/api/boilerplate/payload.rs +++ b/rosenpass/src/api/boilerplate/payload.rs @@ -181,3 +181,87 @@ impl Message for SupplyKeypairResponse { self.msg_type = Self::MESSAGE_TYPE.into(); } } + +#[repr(packed)] +#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)] +pub struct AddListenSocketRequestPayload {} + +pub type AddListenSocketRequest = RequestEnvelope; + +impl Default for AddListenSocketRequest { + fn default() -> Self { + Self::new() + } +} + +impl AddListenSocketRequest { + pub fn new() -> Self { + Self::from_payload(AddListenSocketRequestPayload {}) + } +} + +impl Message for AddListenSocketRequest { + type Payload = AddListenSocketRequestPayload; + type MessageClass = RequestMsgType; + const MESSAGE_TYPE: Self::MessageClass = RequestMsgType::AddListenSocket; + + fn from_payload(payload: Self::Payload) -> Self { + Self { + msg_type: Self::MESSAGE_TYPE.into(), + payload, + } + } + + fn setup(buf: B) -> anyhow::Result> { + let mut r: Ref = buf.zk_zeroized()?; + r.init(); + Ok(r) + } + + fn init(&mut self) { + self.msg_type = Self::MESSAGE_TYPE.into(); + } +} + +pub mod add_listen_socket_response_status { + pub const OK: u128 = 0; + pub const INVALID_REQUEST: u128 = 1; + pub const INTERNAL_ERROR: u128 = 2; +} + +#[repr(packed)] +#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)] +pub struct AddListenSocketResponsePayload { + pub status: u128, +} + +pub type AddListenSocketResponse = ResponseEnvelope; + +impl AddListenSocketResponse { + pub fn new(status: u128) -> Self { + Self::from_payload(AddListenSocketResponsePayload { status }) + } +} + +impl Message for AddListenSocketResponse { + type Payload = AddListenSocketResponsePayload; + type MessageClass = ResponseMsgType; + const MESSAGE_TYPE: Self::MessageClass = ResponseMsgType::AddListenSocket; + + fn from_payload(payload: Self::Payload) -> Self { + Self { + msg_type: Self::MESSAGE_TYPE.into(), + payload, + } + } + + fn setup(buf: B) -> anyhow::Result> { + let mut r: Ref = buf.zk_zeroized()?; + r.init(); + Ok(r) + } + + fn init(&mut self) { + self.msg_type = Self::MESSAGE_TYPE.into(); + } +} diff --git a/rosenpass/src/api/boilerplate/request_ref.rs b/rosenpass/src/api/boilerplate/request_ref.rs index 198342e..3161e8c 100644 --- a/rosenpass/src/api/boilerplate/request_ref.rs +++ b/rosenpass/src/api/boilerplate/request_ref.rs @@ -26,6 +26,7 @@ impl RequestRef { match self { Self::Ping(_) => RequestMsgType::Ping, Self::SupplyKeypair(_) => RequestMsgType::SupplyKeypair, + Self::AddListenSocket(_) => RequestMsgType::AddListenSocket, } } } @@ -36,6 +37,18 @@ impl From> for RequestRef { } } +impl From> for RequestRef { + fn from(v: Ref) -> Self { + Self::SupplyKeypair(v) + } +} + +impl From> for RequestRef { + fn from(v: Ref) -> Self { + Self::AddListenSocket(v) + } +} + impl RequestRefMaker { fn new(buf: B) -> anyhow::Result { let msg_type = buf.deref().request_msg_type_from_prefix()?; @@ -52,6 +65,9 @@ impl RequestRefMaker { RequestMsgType::SupplyKeypair => { RequestRef::SupplyKeypair(self.buf.supply_keypair_request()?) } + RequestMsgType::AddListenSocket => { + RequestRef::AddListenSocket(self.buf.add_listen_socket_request()?) + } }) } @@ -87,6 +103,7 @@ impl RequestRefMaker { pub enum RequestRef { Ping(Ref), SupplyKeypair(Ref), + AddListenSocket(Ref), } impl RequestRef @@ -97,6 +114,7 @@ where match self { Self::Ping(r) => r.bytes(), Self::SupplyKeypair(r) => r.bytes(), + Self::AddListenSocket(r) => r.bytes(), } } } @@ -109,6 +127,7 @@ where match self { Self::Ping(r) => r.bytes_mut(), Self::SupplyKeypair(r) => r.bytes_mut(), + Self::AddListenSocket(r) => r.bytes_mut(), } } } diff --git a/rosenpass/src/api/boilerplate/request_response.rs b/rosenpass/src/api/boilerplate/request_response.rs index 69e0d34..148c51d 100644 --- a/rosenpass/src/api/boilerplate/request_response.rs +++ b/rosenpass/src/api/boilerplate/request_response.rs @@ -50,15 +50,28 @@ impl ResponseMsg for super::SupplyKeypairResponse { type RequestMsg = super::SupplyKeypairRequest; } +impl RequestMsg for super::AddListenSocketRequest { + type ResponseMsg = super::AddListenSocketResponse; +} + +impl ResponseMsg for super::AddListenSocketResponse { + type RequestMsg = super::AddListenSocketRequest; +} + pub type PingPair = (Ref, Ref); pub type SupplyKeypairPair = ( Ref, Ref, ); +pub type AddListenSocketPair = ( + Ref, + Ref, +); pub enum RequestResponsePair { Ping(PingPair), SupplyKeypair(SupplyKeypairPair), + AddListenSocket(AddListenSocketPair), } impl From> for RequestResponsePair { @@ -73,6 +86,12 @@ impl From> for RequestResponsePair { } } +impl From> for RequestResponsePair { + fn from(v: AddListenSocketPair) -> Self { + RequestResponsePair::AddListenSocket(v) + } +} + impl RequestResponsePair where B1: ByteSlice, @@ -90,6 +109,11 @@ where let res = ResponseRef::SupplyKeypair(res.emancipate()); (req, res) } + Self::AddListenSocket((req, res)) => { + let req = RequestRef::AddListenSocket(req.emancipate()); + let res = ResponseRef::AddListenSocket(res.emancipate()); + (req, res) + } } } @@ -119,6 +143,11 @@ where let res = ResponseRef::SupplyKeypair(res.emancipate_mut()); (req, res) } + Self::AddListenSocket((req, res)) => { + let req = RequestRef::AddListenSocket(req.emancipate_mut()); + let res = ResponseRef::AddListenSocket(res.emancipate_mut()); + (req, res) + } } } diff --git a/rosenpass/src/api/boilerplate/response_ref.rs b/rosenpass/src/api/boilerplate/response_ref.rs index d028685..32b6307 100644 --- a/rosenpass/src/api/boilerplate/response_ref.rs +++ b/rosenpass/src/api/boilerplate/response_ref.rs @@ -27,6 +27,7 @@ impl ResponseRef { match self { Self::Ping(_) => ResponseMsgType::Ping, Self::SupplyKeypair(_) => ResponseMsgType::SupplyKeypair, + Self::AddListenSocket(_) => ResponseMsgType::AddListenSocket, } } } @@ -43,6 +44,12 @@ impl From> for ResponseRef { } } +impl From> for ResponseRef { + fn from(v: Ref) -> Self { + Self::AddListenSocket(v) + } +} + impl ResponseRefMaker { fn new(buf: B) -> anyhow::Result { let msg_type = buf.deref().response_msg_type_from_prefix()?; @@ -59,6 +66,9 @@ impl ResponseRefMaker { ResponseMsgType::SupplyKeypair => { ResponseRef::SupplyKeypair(self.buf.supply_keypair_response()?) } + ResponseMsgType::AddListenSocket => { + ResponseRef::AddListenSocket(self.buf.add_listen_socket_response()?) + } }) } @@ -94,6 +104,7 @@ impl ResponseRefMaker { pub enum ResponseRef { Ping(Ref), SupplyKeypair(Ref), + AddListenSocket(Ref), } impl ResponseRef @@ -104,6 +115,7 @@ where match self { Self::Ping(r) => r.bytes(), Self::SupplyKeypair(r) => r.bytes(), + Self::AddListenSocket(r) => r.bytes(), } } } @@ -116,6 +128,7 @@ where match self { Self::Ping(r) => r.bytes_mut(), Self::SupplyKeypair(r) => r.bytes_mut(), + Self::AddListenSocket(r) => r.bytes_mut(), } } } diff --git a/rosenpass/src/api/boilerplate/server.rs b/rosenpass/src/api/boilerplate/server.rs index 92aefea..cb033c4 100644 --- a/rosenpass/src/api/boilerplate/server.rs +++ b/rosenpass/src/api/boilerplate/server.rs @@ -17,6 +17,13 @@ pub trait Server { res: &mut super::SupplyKeypairResponse, ) -> anyhow::Result<()>; + fn add_listen_socket( + &mut self, + req: &super::AddListenSocketRequest, + req_fds: &mut VecDeque, + res: &mut super::AddListenSocketResponse, + ) -> anyhow::Result<()>; + fn dispatch( &mut self, p: &mut RequestResponsePair, @@ -31,6 +38,9 @@ pub trait Server { RequestResponsePair::SupplyKeypair((req, res)) => { self.supply_keypair(req, req_fds, res) } + RequestResponsePair::AddListenSocket((req, res)) => { + self.add_listen_socket(req, req_fds, res) + } } } @@ -57,6 +67,11 @@ pub trait Server { res.init(); RequestResponsePair::SupplyKeypair((req, res)) } + RequestRef::AddListenSocket(req) => { + let mut res = res.add_listen_socket_response_from_prefix()?; + res.init(); + RequestResponsePair::AddListenSocket((req, res)) + } }; self.dispatch(&mut pair, req_fds)?; diff --git a/rosenpass/src/bin/gen-ipc-msg-types.rs b/rosenpass/src/bin/gen-ipc-msg-types.rs index 9d36801..22eef47 100644 --- a/rosenpass/src/bin/gen-ipc-msg-types.rs +++ b/rosenpass/src/bin/gen-ipc-msg-types.rs @@ -78,6 +78,8 @@ fn main() -> Result<()> { Tree::Leaf("Ping Response".to_owned()), Tree::Leaf("Supply Keypair Request".to_owned()), Tree::Leaf("Supply Keypair Response".to_owned()), + Tree::Leaf("Add Listen Socket Request".to_owned()), + Tree::Leaf("Add Listen Socket Response".to_owned()), ], )], ); diff --git a/rosenpass/tests/api-integration-tests-supply-keypair.rs b/rosenpass/tests/api-integration-tests-api-setup.rs similarity index 79% rename from rosenpass/tests/api-integration-tests-supply-keypair.rs rename to rosenpass/tests/api-integration-tests-api-setup.rs index da25391..97e7cfb 100644 --- a/rosenpass/tests/api-integration-tests-supply-keypair.rs +++ b/rosenpass/tests/api-integration-tests-api-setup.rs @@ -1,6 +1,5 @@ use std::{ io::{BufRead, BufReader}, - net::ToSocketAddrs, os::unix::net::UnixStream, process::Stdio, thread::sleep, @@ -8,19 +7,21 @@ use std::{ }; use anyhow::{bail, Context}; -use rosenpass::api::{self, supply_keypair_response_status}; +use rosenpass::api::{self, add_listen_socket_response_status, supply_keypair_response_status}; use rosenpass_util::{ file::LoadValueB64, length_prefix_encoding::{decoder::LengthPrefixDecoder, encoder::LengthPrefixEncoder}, + mio::WriteWithFileDescriptors, + zerocopy::ZerocopySliceExt, }; -use rosenpass_util::{mio::WriteWithFileDescriptors, zerocopy::ZerocopySliceExt}; +use rustix::fd::AsFd; use tempfile::TempDir; use zerocopy::AsBytes; use rosenpass::protocol::SymKey; #[test] -fn api_integration_test() -> anyhow::Result<()> { +fn api_integration_api_setup() -> anyhow::Result<()> { rosenpass_secret_memory::policy::secret_policy_use_only_malloc_secrets(); let dir = TempDir::with_prefix("rosenpass-api-integration-test")?; @@ -33,17 +34,20 @@ fn api_integration_test() -> anyhow::Result<()> { }} } - let peer_a_endpoint = "[::1]:61424"; + let peer_a_endpoint = "[::1]:0"; let peer_a_osk = tempfile!("a.osk"); let peer_b_osk = tempfile!("b.osk"); + let peer_a_listen = std::net::UdpSocket::bind(peer_a_endpoint)?; + let peer_a_endpoint = format!("{}", peer_a_listen.local_addr()?); + use rosenpass::config; let peer_a_keypair = config::Keypair::new(tempfile!("a.pk"), tempfile!("a.sk")); let peer_a = config::Rosenpass { config_file_path: tempfile!("a.config"), - keypair: Some(peer_a_keypair.clone()), - listen: peer_a_endpoint.to_socket_addrs()?.collect(), // TODO: This could collide by accident + keypair: None, + listen: vec![], // TODO: This could collide by accident verbosity: config::Verbosity::Verbose, api: api::config::ApiConfig { listen_path: vec![tempfile!("a.sock")], @@ -62,7 +66,7 @@ fn api_integration_test() -> anyhow::Result<()> { let peer_b_keypair = config::Keypair::new(tempfile!("b.pk"), tempfile!("b.sk")); let peer_b = config::Rosenpass { config_file_path: tempfile!("b.config"), - keypair: None, + keypair: Some(peer_b_keypair.clone()), listen: vec![], verbosity: config::Verbosity::Verbose, api: api::config::ApiConfig { @@ -118,7 +122,7 @@ fn api_integration_test() -> anyhow::Result<()> { let mut out_b = BufReader::new(proc_b.stdout.context("")?).lines(); // Now connect to the peers - let api_path = peer_b.api.listen_path[0].as_path(); + let api_path = peer_a.api.listen_path[0].as_path(); // Wait for the socket to be created let attempt = 0; @@ -132,11 +136,34 @@ fn api_integration_test() -> anyhow::Result<()> { let api = UnixStream::connect(api_path)?; + // Send AddListenSocket request + { + let fd = peer_a_listen.as_fd(); + + let mut fds = vec![&fd].into(); + let mut api = WriteWithFileDescriptors::::new(&api, &mut fds); + LengthPrefixEncoder::from_message(api::AddListenSocketRequest::new().as_bytes()) + .write_all_to_stdio(&mut api)?; + assert!(fds.is_empty(), "Failed to write all file descriptors"); + std::mem::forget(peer_a_listen); + } + + // Read response + { + let mut decoder = LengthPrefixDecoder::new([0u8; api::MAX_RESPONSE_LEN]); + let res = decoder.read_all_from_stdio(&api)?; + let res = res.zk_parse::()?; + assert_eq!( + *res, + api::AddListenSocketResponse::new(add_listen_socket_response_status::OK) + ); + } + // Send SupplyKeypairRequest { use rustix::fs::{open, Mode, OFlags}; - let sk = open(peer_b_keypair.secret_key, OFlags::RDONLY, Mode::empty())?; - let pk = open(peer_b_keypair.public_key, OFlags::RDONLY, Mode::empty())?; + let sk = open(peer_a_keypair.secret_key, OFlags::RDONLY, Mode::empty())?; + let pk = open(peer_a_keypair.public_key, OFlags::RDONLY, Mode::empty())?; let mut fds = vec![&sk, &pk].into(); let mut api = WriteWithFileDescriptors::::new(&api, &mut fds); @@ -147,7 +174,6 @@ fn api_integration_test() -> anyhow::Result<()> { // Read response { - //sleep(Duration::from_secs(10)); let mut decoder = LengthPrefixDecoder::new([0u8; api::MAX_RESPONSE_LEN]); let res = decoder.read_all_from_stdio(api)?; let res = res.zk_parse::()?; diff --git a/util/src/controlflow.rs b/util/src/controlflow.rs index 73626b2..211a032 100644 --- a/util/src/controlflow.rs +++ b/util/src/controlflow.rs @@ -7,6 +7,20 @@ macro_rules! repeat { }; } +#[macro_export] +macro_rules! return_unless { + ($cond:expr) => { + if !($cond) { + return; + } + }; + ($cond:expr, $val:expr) => { + if !($cond) { + return $val; + } + }; +} + #[macro_export] macro_rules! return_if { ($cond:expr) => { diff --git a/util/src/fd.rs b/util/src/fd.rs index 9a44564..6dd113d 100644 --- a/util/src/fd.rs +++ b/util/src/fd.rs @@ -1,12 +1,10 @@ +use anyhow::bail; use rustix::{ fd::{AsFd, BorrowedFd, FromRawFd, OwnedFd, RawFd}, io::fcntl_dupfd_cloexec, }; -#[cfg(target_os = "linux")] -use rustix::io::DupFlags; - -use crate::mem::Forgetting; +use crate::{mem::Forgetting, result::OkExt}; /// Prepare a file descriptor for use in Rust code. /// @@ -51,7 +49,7 @@ pub fn clone_fd_cloexec(fd: Fd) -> rustix::io::Result { #[cfg(target_os = "linux")] pub fn clone_fd_to_cloexec(fd: Fd, new: &mut OwnedFd) -> rustix::io::Result<()> { - use rustix::io::dup3; + use rustix::io::{dup3, DupFlags}; dup3(fd, new, DupFlags::CLOEXEC) } @@ -111,6 +109,85 @@ impl std::io::Write for FdIo { } } +pub trait StatExt { + fn is_socket(&self) -> bool; +} + +impl StatExt for rustix::fs::Stat { + fn is_socket(&self) -> bool { + use rustix::fs::FileType; + let ft = FileType::from_raw_mode(self.st_mode); + matches!(ft, FileType::Socket) + } +} + +pub trait TryStatExt { + type Error; + fn is_socket(&self) -> Result; +} + +impl TryStatExt for T +where + T: AsFd, +{ + type Error = rustix::io::Errno; + + fn is_socket(&self) -> Result { + rustix::fs::fstat(self)?.is_socket().ok() + } +} + +pub trait GetSocketType { + type Error; + fn socket_type(&self) -> Result; + fn is_datagram_socket(&self) -> Result { + use rustix::net::SocketType; + matches!(self.socket_type()?, SocketType::DGRAM).ok() + } +} + +impl GetSocketType for T +where + T: AsFd, +{ + type Error = rustix::io::Errno; + + fn socket_type(&self) -> Result { + rustix::net::sockopt::get_socket_type(self) + } +} + +#[cfg(target_os = "linux")] +pub trait GetSocketProtocol { + fn socket_protocol(&self) -> Result, rustix::io::Errno>; + fn is_udp_socket(&self) -> Result { + self.socket_protocol()? + .map(|p| p == rustix::net::ipproto::UDP) + .unwrap_or(false) + .ok() + } + fn demand_udp_socket(&self) -> anyhow::Result<()> { + match self.socket_protocol() { + Ok(Some(rustix::net::ipproto::UDP)) => Ok(()), + Ok(Some(other_proto)) => { + bail!("Not a udp socket, instead socket protocol is: {other_proto:?}") + } + Ok(None) => bail!("getsockopt() returned empty value"), + Err(errno) => Err(errno.into_stdio_err())?, + } + } +} + +#[cfg(target_os = "linux")] +impl GetSocketProtocol for T +where + T: AsFd, +{ + fn socket_protocol(&self) -> Result, rustix::io::Errno> { + rustix::net::sockopt::get_socket_protocol(self) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/util/src/mio/mio.rs b/util/src/mio/mio.rs index 10757ee..f70de92 100644 --- a/util/src/mio/mio.rs +++ b/util/src/mio/mio.rs @@ -1,7 +1,7 @@ use mio::net::{UnixListener, UnixStream}; -use rustix::fd::RawFd; +use rustix::fd::{OwnedFd, RawFd}; -use crate::fd::claim_fd; +use crate::{fd::claim_fd, result::OkExt}; pub mod interest { use mio::Interest; @@ -25,15 +25,20 @@ impl UnixListenerExt for UnixListener { } pub trait UnixStreamExt: Sized { + fn from_fd(fd: OwnedFd) -> anyhow::Result; fn claim_fd(fd: RawFd) -> anyhow::Result; } impl UnixStreamExt for UnixStream { - fn claim_fd(fd: RawFd) -> anyhow::Result { + fn from_fd(fd: OwnedFd) -> anyhow::Result { use std::os::unix::net::UnixStream as StdUnixStream; - let sock = StdUnixStream::from(claim_fd(fd)?); + let sock = StdUnixStream::from(fd); sock.set_nonblocking(true)?; - Ok(UnixStream::from_std(sock)) + UnixStream::from_std(sock).ok() + } + + fn claim_fd(fd: RawFd) -> anyhow::Result { + Self::from_fd(claim_fd(fd)?) } }