From 7a31b572273bf4202ea89986569ee8df073dccc9 Mon Sep 17 00:00:00 2001 From: Karolin Varner Date: Sat, 10 Aug 2024 17:50:04 +0200 Subject: [PATCH] chore(API): Infrastructure to use endpoints with fd. passing --- Cargo.lock | 11 ++ Cargo.toml | 1 + rosenpass/Cargo.toml | 3 +- rosenpass/src/api/api_handler.rs | 3 + rosenpass/src/api/boilerplate/payload.rs | 1 + rosenpass/src/api/boilerplate/server.rs | 23 ++- rosenpass/src/api/mio/connection.rs | 191 +++++++++++++---------- util/Cargo.toml | 5 + util/src/controlflow.rs | 50 ++++++ util/src/fd.rs | 53 +++++++ util/src/{ => mio}/mio.rs | 0 util/src/mio/mod.rs | 12 ++ util/src/mio/uds_recv_fd.rs | 123 +++++++++++++++ util/src/mio/uds_send_fd.rs | 121 ++++++++++++++ 14 files changed, 505 insertions(+), 92 deletions(-) create mode 100644 util/src/controlflow.rs rename util/src/{ => mio}/mio.rs (100%) create mode 100644 util/src/mio/mod.rs create mode 100644 util/src/mio/uds_recv_fd.rs create mode 100644 util/src/mio/uds_send_fd.rs diff --git a/Cargo.lock b/Cargo.lock index d3d16fc..52211b2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1828,6 +1828,7 @@ dependencies = [ "test_bin", "thiserror", "toml", + "uds", "zerocopy", "zeroize", ] @@ -1923,6 +1924,7 @@ dependencies = [ "tempfile", "thiserror", "typenum", + "uds", "zerocopy", "zeroize", ] @@ -2396,6 +2398,15 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "uds" +version = "0.4.2" +source = "git+https://github.com/rosenpass/uds#b47934fe52422e559f7278938875f9105f91c5a2" +dependencies = [ + "libc", + "mio", +] + [[package]] name = "unicode-ident" version = "1.0.12" diff --git a/Cargo.toml b/Cargo.toml index 31b5228..f419e04 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,6 +69,7 @@ hex-literal = { version = "0.4.1" } hex = { version = "0.4.3" } heck = { version = "0.5.0" } libc = { version = "0.2" } +uds = { git = "https://github.com/rosenpass/uds" } #Dev dependencies serial_test = "3.1.1" diff --git a/rosenpass/Cargo.toml b/rosenpass/Cargo.toml index 3788434..2099d92 100644 --- a/rosenpass/Cargo.toml +++ b/rosenpass/Cargo.toml @@ -55,6 +55,7 @@ hex = { workspace = true, optional = true } heck = { workspace = true, optional = true } command-fds = { workspace = true, optional = true } rustix = { workspace = true } +uds = { workspace = true, optional = true, features = ["mio_1xx"] } [build-dependencies] anyhow = { workspace = true } @@ -71,6 +72,6 @@ tempfile = { workspace = true } experiment_memfd_secret = ["rosenpass-wireguard-broker/experiment_memfd_secret"] experiment_broker_api = ["rosenpass-wireguard-broker/experiment_broker_api", "command-fds"] experiment_libcrux = ["rosenpass-ciphers/experiment_libcrux"] -experiment_api = ["hex-literal"] +experiment_api = ["hex-literal", "uds", "rosenpass-util/experiment_file_descriptor_passing"] internal_testing = [] internal_bin_gen_ipc_msg_types = ["hex", "heck"] diff --git a/rosenpass/src/api/api_handler.rs b/rosenpass/src/api/api_handler.rs index 827c285..fee379d 100644 --- a/rosenpass/src/api/api_handler.rs +++ b/rosenpass/src/api/api_handler.rs @@ -1,3 +1,5 @@ +use std::{borrow::BorrowMut, collections::VecDeque, os::fd::OwnedFd}; + use rosenpass_to::{ops::copy_slice, To}; use crate::app_server::AppServer; @@ -30,6 +32,7 @@ where fn ping( &mut self, req: &super::PingRequest, + _req_fds: &mut VecDeque, res: &mut super::PingResponse, ) -> anyhow::Result<()> { let (req, res) = (&req.payload, &mut res.payload); diff --git a/rosenpass/src/api/boilerplate/payload.rs b/rosenpass/src/api/boilerplate/payload.rs index 7537ee5..ea67939 100644 --- a/rosenpass/src/api/boilerplate/payload.rs +++ b/rosenpass/src/api/boilerplate/payload.rs @@ -6,6 +6,7 @@ use super::{Message, RawMsgType, RequestMsgType, ResponseMsgType}; /// Size required to fit any message in binary form pub const MAX_REQUEST_LEN: usize = 2500; // TODO fix this pub const MAX_RESPONSE_LEN: usize = 2500; // TODO fix this +pub const MAX_REQUEST_FDS: usize = 2; #[repr(packed)] #[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)] diff --git a/rosenpass/src/api/boilerplate/server.rs b/rosenpass/src/api/boilerplate/server.rs index 7506aaa..a89ea03 100644 --- a/rosenpass/src/api/boilerplate/server.rs +++ b/rosenpass/src/api/boilerplate/server.rs @@ -1,24 +1,35 @@ +use super::{ByteSliceRefExt, Message, PingRequest, PingResponse, RequestRef, RequestResponsePair}; +use std::{collections::VecDeque, os::fd::OwnedFd}; use zerocopy::{ByteSlice, ByteSliceMut}; -use super::{ByteSliceRefExt, Message, PingRequest, PingResponse, RequestRef, RequestResponsePair}; - pub trait Server { - fn ping(&mut self, req: &PingRequest, res: &mut PingResponse) -> anyhow::Result<()>; + fn ping( + &mut self, + req: &PingRequest, + req_fds: &mut VecDeque, + res: &mut PingResponse, + ) -> anyhow::Result<()>; fn dispatch( &mut self, p: &mut RequestResponsePair, + req_fds: &mut VecDeque, ) -> anyhow::Result<()> where ReqBuf: ByteSlice, ResBuf: ByteSliceMut, { match p { - RequestResponsePair::Ping((req, res)) => self.ping(req, res), + RequestResponsePair::Ping((req, res)) => self.ping(req, req_fds, res), } } - fn handle_message(&mut self, req: ReqBuf, res: ResBuf) -> anyhow::Result + fn handle_message( + &mut self, + req: ReqBuf, + req_fds: &mut VecDeque, + res: ResBuf, + ) -> anyhow::Result where ReqBuf: ByteSlice, ResBuf: ByteSliceMut, @@ -32,7 +43,7 @@ pub trait Server { RequestResponsePair::Ping((req, res)) } }; - self.dispatch(&mut pair)?; + self.dispatch(&mut pair, req_fds)?; let res_len = pair.response().bytes().len(); Ok(res_len) diff --git a/rosenpass/src/api/mio/connection.rs b/rosenpass/src/api/mio/connection.rs index f0c891c..1697d25 100644 --- a/rosenpass/src/api/mio/connection.rs +++ b/rosenpass/src/api/mio/connection.rs @@ -1,7 +1,10 @@ use std::borrow::{Borrow, BorrowMut}; +use std::collections::VecDeque; +use std::os::fd::OwnedFd; use mio::net::UnixStream; use rosenpass_secret_memory::Secret; +use rosenpass_util::mio::ReadWithFileDescriptors; use rosenpass_util::{ io::{IoResultKindHintExt, TryIoResultKindHintExt}, length_prefix_encoding::{ @@ -12,6 +15,7 @@ use rosenpass_util::{ }; use zeroize::Zeroize; +use crate::api::MAX_REQUEST_FDS; use crate::{api::Server, app_server::AppServer}; use super::super::{ApiHandler, ApiHandlerContext}; @@ -39,11 +43,13 @@ impl BorrowMut<[u8]> for SecretBuffer { // TODO: Unfortunately, zerocopy is quite particular about alignment, hence the 4096 type ReadBuffer = LengthPrefixDecoder>; type WriteBuffer = LengthPrefixEncoder>; +type ReadFdBuffer = VecDeque; #[derive(Debug)] struct MioConnectionBuffers { read_buffer: ReadBuffer, write_buffer: WriteBuffer, + read_fd_buffer: ReadFdBuffer, } #[derive(Debug)] @@ -65,9 +71,11 @@ impl MioConnection { let invalid_read = false; let read_buffer = LengthPrefixDecoder::new(SecretBuffer::new()); let write_buffer = LengthPrefixEncoder::from_buffer(SecretBuffer::new()); + let read_fd_buffer = VecDeque::new(); let buffers = Some(MioConnectionBuffers { read_buffer, write_buffer, + read_fd_buffer, }); let api_state = ApiHandler::new(); Ok(Self { @@ -106,20 +114,22 @@ pub trait MioConnectionContext { } fn handle_incoming_message(&mut self) -> anyhow::Result> { - self.with_buffers_stolen(|this, read_buf, write_buf| { + self.with_buffers_stolen(|this, bufs| { // Acquire request & response. Caller is responsible to make sure // that read buffer holds a message and that write buffer is cleared. // Hence the unwraps and assertions - assert!(write_buf.exhausted()); - let req = read_buf.message().unwrap().unwrap(); - let res = write_buf.buffer_bytes_mut(); + assert!(bufs.write_buffer.exhausted()); + let req = bufs.read_buffer.message().unwrap().unwrap(); + let req_fds = &mut bufs.read_fd_buffer; + let res = bufs.write_buffer.buffer_bytes_mut(); // Call API handler // Transitive trait implementations: MioConnectionContext -> ApiHandlerContext -> as ApiServer - let response_len = this.handle_message(req, res)?; + let response_len = this.handle_message(req, req_fds, res)?; - write_buf.restart_write_with_new_message(response_len)?; - read_buf.zeroize(); // clear for new message to read + bufs.write_buffer + .restart_write_with_new_message(response_len)?; + bufs.read_buffer.zeroize(); // clear for new message to read Ok(Some(())) }) @@ -130,36 +140,37 @@ pub trait MioConnectionContext { return Ok(Some(())); } - self.with_buffers_stolen(|this, _read_buf, write_buf| { - use lpe_encoder::WriteToIoReturn as Ret; - use std::io::ErrorKind as K; + use lpe_encoder::WriteToIoReturn as Ret; + use std::io::ErrorKind as K; - loop { - match write_buf - .write_to_stdio(&this.mio_connection_mut().io) - .io_err_kind_hint() - { - // Done - Ok(Ret { done: true, .. }) => { - write_buf.zeroize(); // clear for new message to write - break Ok(Some(())); - }, + loop { + let conn = self.mio_connection_mut(); + let bufs = conn.buffers.as_mut().unwrap(); - // Would block - Ok(Ret { - bytes_written: 0, .. - }) => break Ok(None), - Err((_e, K::WouldBlock)) => break Ok(None), + let sock = &conn.io; + let write_buf = &mut bufs.write_buffer; - // Just continue - Ok(_) => continue, /* Ret { bytes_written > 0, done = false } acc. to previous cases*/ - Err((_e, K::Interrupted)) => continue, - - // Other errors - Err((e, _ek)) => Err(e)?, + match write_buf.write_to_stdio(sock).io_err_kind_hint() { + // Done + Ok(Ret { done: true, .. }) => { + write_buf.zeroize(); // clear for new message to write + break Ok(Some(())); } + + // Would block + Ok(Ret { + bytes_written: 0, .. + }) => break Ok(None), + Err((_e, K::WouldBlock)) => break Ok(None), + + // Just continue + Ok(_) => continue, /* Ret { bytes_written > 0, done = false } acc. to previous cases*/ + Err((_e, K::Interrupted)) => continue, + + // Other errors + Err((e, _ek)) => Err(e)?, } - }) + } } fn recv(&mut self) -> anyhow::Result> { @@ -167,49 +178,68 @@ pub trait MioConnectionContext { return Ok(None); } - self.with_buffers_stolen(|this, read_buf, _write_buf| { - use lpe_decoder::{ReadFromIoError as E, ReadFromIoReturn as Ret}; - use std::io::ErrorKind as K; + use lpe_decoder::{ReadFromIoError as E, ReadFromIoReturn as Ret}; + use std::io::ErrorKind as K; - loop { - match read_buf - .read_from_stdio(&this.mio_connection_mut().io) - .try_io_err_kind_hint() - { - // We actually received a proper message - // (Impl below match to appease borrow checker) - Ok(Ret { - message: Some(_msg), - .. - }) => break Ok(Some(())), + loop { + let conn = self.mio_connection_mut(); + let bufs = conn.buffers.as_mut().unwrap(); - // Message does not fit in buffer - Err((e @ E::MessageTooLargeError { .. }, _)) => { - log::warn!("Received message on API that was too big to fit in our buffers; \ + let read_buf = &mut bufs.read_buffer; + let read_fd_buf = &mut bufs.read_fd_buffer; + + let sock = &conn.io; + let fd_passing_sock = ReadWithFileDescriptors::::new( + sock, + read_fd_buf, + ); + + match read_buf + .read_from_stdio(fd_passing_sock) + .try_io_err_kind_hint() + { + // We actually received a proper message + // (Impl below match to appease borrow checker) + Ok(Ret { + message: Some(_msg), + .. + }) => break Ok(Some(())), + + // Message does not fit in buffer + Err((e @ E::MessageTooLargeError { .. }, _)) => { + log::warn!("Received message on API that was too big to fit in our buffers; \ looks like the client is broken. Stopping to process messages of the client.\n\ Error: {e:?}"); - // TODO: We should properly close down the socket in this case, but to do that, - // we need to have the facilities in the Rosenpass IO handling system to close - // open connections. - // Just leaving the API connections dangling for now. - // This should be fixed for non-experimental use of the API. - this.mio_connection_mut().invalid_read = true; - break Ok(None); - } + // TODO: We should properly close down the socket in this case, but to do that, + // we need to have the facilities in the Rosenpass IO handling system to close + // open connections. + // Just leaving the API connections dangling for now. + // This should be fixed for non-experimental use of the API. + conn.invalid_read = true; + break Ok(None); + } - // Would block - Ok(Ret { bytes_read: 0, .. }) => break Ok(None), - Err((_, Some(K::WouldBlock))) => break Ok(None), + // Would block + Ok(Ret { bytes_read: 0, .. }) => break Ok(None), + Err((_, Some(K::WouldBlock))) => break Ok(None), - // Just keep going - Ok(Ret { bytes_read: _, .. }) => continue, - Err((_, Some(K::Interrupted))) => continue, + // Just keep going + Ok(Ret { bytes_read: _, .. }) => continue, + Err((_, Some(K::Interrupted))) => continue, - // Other IO Error (just pass on to the caller) - Err((E::IoError(e), _)) => Err(e)?, - }; - } - }) + // Other IO Error (just pass on to the caller) + Err((E::IoError(e), _)) => { + log::warn!( + "IO error while trying to read message from API socket. \ + The connection is broken. Stopping to process messages of the client.\n\ + Error: {e:?}" + ); + // TODO: Same as above + conn.invalid_read = true; + break Err(e.into()); + } + }; + } } } @@ -224,32 +254,23 @@ trait MioConnectionContextPrivate: MioConnectionContext { let _ = opt.insert(buffers); } - fn with_buffers_stolen R>( + fn with_buffers_stolen R>( &mut self, f: F, ) -> R { let mut bufs = self.steal_buffers(); - let res = f(self, &mut bufs.read_buffer, &mut bufs.write_buffer); + let res = f(self, &mut bufs); self.return_buffers(bufs); res } - fn both_buffers_mut(&mut self) -> (&mut ReadBuffer, &mut WriteBuffer) { - let bufs = self.mio_connection_mut().buffers.as_mut().unwrap(); - let MioConnectionBuffers { - ref mut read_buffer, - ref mut write_buffer, - } = bufs; - (read_buffer, write_buffer) - } - - #[allow(dead_code)] - fn read_buf_mut(&mut self) -> &mut ReadBuffer { - self.both_buffers_mut().0 - } - fn write_buf_mut(&mut self) -> &mut WriteBuffer { - self.both_buffers_mut().1 + self.mio_connection_mut() + .buffers + .as_mut() + .unwrap() + .write_buffer + .borrow_mut() } } diff --git a/util/Cargo.toml b/util/Cargo.toml index 4ad42a7..d6bdf00 100644 --- a/util/Cargo.toml +++ b/util/Cargo.toml @@ -22,3 +22,8 @@ zerocopy = { workspace = true } thiserror = { workspace = true } mio = { workspace = true } tempfile = { workspace = true } +uds = { workspace = true, optional = true, features = ["mio_1xx"] } + + +[features] +experiment_file_descriptor_passing = ["uds"] diff --git a/util/src/controlflow.rs b/util/src/controlflow.rs new file mode 100644 index 0000000..73626b2 --- /dev/null +++ b/util/src/controlflow.rs @@ -0,0 +1,50 @@ +#[macro_export] +macro_rules! repeat { + ($times:expr, $body:expr) => { + for _ in 0..($times) { + $body + } + }; +} + +#[macro_export] +macro_rules! return_if { + ($cond:expr) => { + if $cond { + return; + } + }; + ($cond:expr, $val:expr) => { + if $cond { + return $val; + } + }; +} + +#[macro_export] +macro_rules! break_if { + ($cond:expr) => { + if $cond { + break; + } + }; + ($cond:expr, $val:expr) => { + if $cond { + break $val; + } + }; +} + +#[macro_export] +macro_rules! continue_if { + ($cond:expr) => { + if $cond { + continue; + } + }; + ($cond:expr, $val:expr) => { + if $cond { + break $val; + } + }; +} diff --git a/util/src/fd.rs b/util/src/fd.rs index 3d7af75..9a44564 100644 --- a/util/src/fd.rs +++ b/util/src/fd.rs @@ -25,6 +25,18 @@ pub fn claim_fd(fd: RawFd) -> rustix::io::Result { Ok(new) } +/// Prepare a file descriptor for use in Rust code. +/// +/// Checks if the file descriptor is valid. +/// +/// Unlike [claim_fd], this will reuse the same file descriptor identifier instead of masking it. +pub fn claim_fd_inplace(fd: RawFd) -> rustix::io::Result { + let mut new = unsafe { OwnedFd::from_raw_fd(fd) }; + let tmp = clone_fd_cloexec(&new)?; + clone_fd_to_cloexec(tmp, &mut new)?; + Ok(new) +} + pub fn mask_fd(fd: RawFd) -> rustix::io::Result<()> { // Safety: because the OwnedFd resulting from OwnedFd::from_raw_fd is wrapped in a Forgetting, // it never gets dropped, meaning that fd is never closed and thus outlives the OwnedFd @@ -58,6 +70,47 @@ pub fn open_nullfd() -> rustix::io::Result { open("/dev/null", OFlags::CLOEXEC, Mode::empty()) } +/// Convert low level errors into std::io::Error +pub trait IntoStdioErr { + type Target; + fn into_stdio_err(self) -> Self::Target; +} + +impl IntoStdioErr for rustix::io::Errno { + type Target = std::io::Error; + + fn into_stdio_err(self) -> Self::Target { + std::io::Error::from_raw_os_error(self.raw_os_error()) + } +} + +impl IntoStdioErr for rustix::io::Result { + type Target = std::io::Result; + + fn into_stdio_err(self) -> Self::Target { + self.map_err(IntoStdioErr::into_stdio_err) + } +} + +/// Read and write directly from a file descriptor +pub struct FdIo(pub Fd); + +impl std::io::Read for FdIo { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + rustix::io::read(&self.0, buf).into_stdio_err() + } +} + +impl std::io::Write for FdIo { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + rustix::io::write(&self.0, buf).into_stdio_err() + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/util/src/mio.rs b/util/src/mio/mio.rs similarity index 100% rename from util/src/mio.rs rename to util/src/mio/mio.rs diff --git a/util/src/mio/mod.rs b/util/src/mio/mod.rs new file mode 100644 index 0000000..251e211 --- /dev/null +++ b/util/src/mio/mod.rs @@ -0,0 +1,12 @@ +#[allow(clippy::module_inception)] +mod mio; +pub use mio::*; + +#[cfg(feature = "experiment_file_descriptor_passing")] +mod uds_recv_fd; +#[cfg(feature = "experiment_file_descriptor_passing")] +mod uds_send_fd; +#[cfg(feature = "experiment_file_descriptor_passing")] +pub use uds_recv_fd::*; +#[cfg(feature = "experiment_file_descriptor_passing")] +pub use uds_send_fd::*; diff --git a/util/src/mio/uds_recv_fd.rs b/util/src/mio/uds_recv_fd.rs new file mode 100644 index 0000000..8e83b83 --- /dev/null +++ b/util/src/mio/uds_recv_fd.rs @@ -0,0 +1,123 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + collections::VecDeque, + io::Read, + marker::PhantomData, + os::fd::OwnedFd, +}; +use uds::UnixStreamExt as FdPassingExt; + +use crate::fd::{claim_fd_inplace, IntoStdioErr}; + +pub struct ReadWithFileDescriptors +where + Sock: FdPassingExt, + BorrowSock: Borrow, + BorrowFds: BorrowMut>, +{ + socket: BorrowSock, + fds: BorrowFds, + _sock_dummy: PhantomData, +} + +impl + ReadWithFileDescriptors +where + Sock: FdPassingExt, + BorrowSock: Borrow, + BorrowFds: BorrowMut>, +{ + pub fn new(socket: BorrowSock, fds: BorrowFds) -> Self { + let _sock_dummy = PhantomData; + Self { + socket, + fds, + _sock_dummy, + } + } + + pub fn into_parts(self) -> (BorrowSock, BorrowFds) { + let Self { socket, fds, .. } = self; + (socket, fds) + } + + pub fn socket(&self) -> &Sock { + self.socket.borrow() + } + + pub fn fds(&self) -> &VecDeque { + self.fds.borrow() + } + + pub fn fds_mut(&mut self) -> &mut VecDeque { + self.fds.borrow_mut() + } +} + +impl + ReadWithFileDescriptors +where + Sock: FdPassingExt, + BorrowSock: BorrowMut, + BorrowFds: BorrowMut>, +{ + pub fn socket_mut(&mut self) -> &mut Sock { + self.socket.borrow_mut() + } +} + +impl Read + for ReadWithFileDescriptors +where + Sock: FdPassingExt, + BorrowSock: Borrow, + BorrowFds: BorrowMut>, +{ + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + // Calculate space for additional file descriptors + let have_fds_before_read = self.fds().len(); + let free_fd_slots = MAX_FDS.saturating_sub(have_fds_before_read); + + // Allocate a buffer for file descriptors + let mut fd_buf = [0; MAX_FDS]; + let fd_buf = &mut fd_buf[..free_fd_slots]; + + // Read from the unix socket + let (bytes_read, fds_read) = self.socket.borrow().recv_fds(buf, fd_buf)?; + let fd_buf = &fd_buf[..fds_read]; + + // Process the file descriptors + let mut fd_iter = fd_buf.iter(); + + // Try claiming all the file descriptors + let mut claim_fd_result = Ok(bytes_read); + self.fds_mut().reserve(fd_buf.len()); + for fd in fd_iter.by_ref() { + match claim_fd_inplace(*fd) { + Ok(owned) => self.fds_mut().push_back(owned), + Err(e) => { + // Abort on error and pass to error handler + // Note that claim_fd_inplace is responsible for closing this particular + // file descriptor if claiming it fails + claim_fd_result = Err(e.into_stdio_err()); + break; + } + } + } + + // Return if we where able to claim all file descriptors + if claim_fd_result.is_ok() { + return claim_fd_result; + }; + + // An error occurred while claiming fds + self.fds_mut().truncate(have_fds_before_read); // Close fds successfully claimed + + // Close the remaining fds + for fd in fd_iter { + unsafe { rustix::io::close(*fd) }; + } + + claim_fd_result + } +} diff --git a/util/src/mio/uds_send_fd.rs b/util/src/mio/uds_send_fd.rs new file mode 100644 index 0000000..07d5d6f --- /dev/null +++ b/util/src/mio/uds_send_fd.rs @@ -0,0 +1,121 @@ +use rustix::fd::{AsFd, AsRawFd}; +use std::{ + borrow::{Borrow, BorrowMut}, + cmp::min, + collections::VecDeque, + io::Write, + marker::PhantomData, +}; +use uds::UnixStreamExt as FdPassingExt; + +use crate::{repeat, return_if}; + +pub struct WriteWithFileDescriptors +where + Sock: FdPassingExt, + Fd: AsFd, + BorrowSock: Borrow, + BorrowFds: BorrowMut>, +{ + socket: BorrowSock, + fds: BorrowFds, + _sock_dummy: PhantomData, + _fd_dummy: PhantomData, +} + +impl WriteWithFileDescriptors +where + Sock: FdPassingExt, + Fd: AsFd, + BorrowSock: Borrow, + BorrowFds: BorrowMut>, +{ + pub fn new(socket: BorrowSock, fds: BorrowFds) -> Self { + let _sock_dummy = PhantomData; + let _fd_dummy = PhantomData; + Self { + socket, + fds, + _sock_dummy, + _fd_dummy, + } + } + + pub fn into_parts(self) -> (BorrowSock, BorrowFds) { + let Self { socket, fds, .. } = self; + (socket, fds) + } + + pub fn socket(&self) -> &Sock { + self.socket.borrow() + } + + pub fn fds(&self) -> &VecDeque { + self.fds.borrow() + } + + pub fn fds_mut(&mut self) -> &mut VecDeque { + self.fds.borrow_mut() + } +} + +impl WriteWithFileDescriptors +where + Sock: FdPassingExt, + Fd: AsFd, + BorrowSock: BorrowMut, + BorrowFds: BorrowMut>, +{ + pub fn socket_mut(&mut self) -> &mut Sock { + self.socket.borrow_mut() + } +} + +impl Write + for WriteWithFileDescriptors +where + Sock: FdPassingExt, + Fd: AsFd, + BorrowSock: Borrow, + BorrowFds: BorrowMut>, +{ + fn write(&mut self, buf: &[u8]) -> std::io::Result { + // At least one byte of real data should be sent when sending ancillary data. -- unix(7) + return_if!(buf.is_empty(), Ok(0)); + + // The kernel constant SCM_MAX_FD defines a limit on the number of file descriptors + // in the array. Attempting to send an array larger than this limit causes + // sendmsg(2) to fail with the error EINVAL. SCM_MAX_FD has the value 253 (or 255 + // before Linux 2.6.38). + // -- unix(7) + const SCM_MAX_FD: usize = 253; + let buf = match self.fds().len() <= SCM_MAX_FD { + false => &buf[..1], // Force caller to immediately call write() again to send its data + true => buf, + }; + + // Allocate the buffer for the file descriptor array + let fd_no = min(SCM_MAX_FD, self.fds().len()); + let mut fd_buf = [0; SCM_MAX_FD]; // My kingdom for alloca(3) + let fd_buf = &mut fd_buf[..fd_no]; + + // Fill the file descriptor array + for (raw, fancy) in fd_buf.iter_mut().zip(self.fds().iter()) { + *raw = fancy.as_fd().as_raw_fd(); + } + + // Send data and file descriptors + let bytes_written = self.socket().send_fds(buf, fd_buf)?; + + // Drop the file descriptors from the Deque + repeat!(fd_no, { + self.fds_mut().pop_front(); + }); + + Ok(bytes_written) + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } +}