mirror of
https://github.com/rosenpass/rosenpass.git
synced 2026-02-27 22:13:12 -08:00
chore(API): Infrastructure to use endpoints with fd. passing
This commit is contained in:
11
Cargo.lock
generated
11
Cargo.lock
generated
@@ -1828,6 +1828,7 @@ dependencies = [
|
|||||||
"test_bin",
|
"test_bin",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"toml",
|
"toml",
|
||||||
|
"uds",
|
||||||
"zerocopy",
|
"zerocopy",
|
||||||
"zeroize",
|
"zeroize",
|
||||||
]
|
]
|
||||||
@@ -1923,6 +1924,7 @@ dependencies = [
|
|||||||
"tempfile",
|
"tempfile",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"typenum",
|
"typenum",
|
||||||
|
"uds",
|
||||||
"zerocopy",
|
"zerocopy",
|
||||||
"zeroize",
|
"zeroize",
|
||||||
]
|
]
|
||||||
@@ -2396,6 +2398,15 @@ version = "1.17.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
|
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "uds"
|
||||||
|
version = "0.4.2"
|
||||||
|
source = "git+https://github.com/rosenpass/uds#b47934fe52422e559f7278938875f9105f91c5a2"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
"mio",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode-ident"
|
name = "unicode-ident"
|
||||||
version = "1.0.12"
|
version = "1.0.12"
|
||||||
|
|||||||
@@ -69,6 +69,7 @@ hex-literal = { version = "0.4.1" }
|
|||||||
hex = { version = "0.4.3" }
|
hex = { version = "0.4.3" }
|
||||||
heck = { version = "0.5.0" }
|
heck = { version = "0.5.0" }
|
||||||
libc = { version = "0.2" }
|
libc = { version = "0.2" }
|
||||||
|
uds = { git = "https://github.com/rosenpass/uds" }
|
||||||
|
|
||||||
#Dev dependencies
|
#Dev dependencies
|
||||||
serial_test = "3.1.1"
|
serial_test = "3.1.1"
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ hex = { workspace = true, optional = true }
|
|||||||
heck = { workspace = true, optional = true }
|
heck = { workspace = true, optional = true }
|
||||||
command-fds = { workspace = true, optional = true }
|
command-fds = { workspace = true, optional = true }
|
||||||
rustix = { workspace = true }
|
rustix = { workspace = true }
|
||||||
|
uds = { workspace = true, optional = true, features = ["mio_1xx"] }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
@@ -71,6 +72,6 @@ tempfile = { workspace = true }
|
|||||||
experiment_memfd_secret = ["rosenpass-wireguard-broker/experiment_memfd_secret"]
|
experiment_memfd_secret = ["rosenpass-wireguard-broker/experiment_memfd_secret"]
|
||||||
experiment_broker_api = ["rosenpass-wireguard-broker/experiment_broker_api", "command-fds"]
|
experiment_broker_api = ["rosenpass-wireguard-broker/experiment_broker_api", "command-fds"]
|
||||||
experiment_libcrux = ["rosenpass-ciphers/experiment_libcrux"]
|
experiment_libcrux = ["rosenpass-ciphers/experiment_libcrux"]
|
||||||
experiment_api = ["hex-literal"]
|
experiment_api = ["hex-literal", "uds", "rosenpass-util/experiment_file_descriptor_passing"]
|
||||||
internal_testing = []
|
internal_testing = []
|
||||||
internal_bin_gen_ipc_msg_types = ["hex", "heck"]
|
internal_bin_gen_ipc_msg_types = ["hex", "heck"]
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
use std::{borrow::BorrowMut, collections::VecDeque, os::fd::OwnedFd};
|
||||||
|
|
||||||
use rosenpass_to::{ops::copy_slice, To};
|
use rosenpass_to::{ops::copy_slice, To};
|
||||||
|
|
||||||
use crate::app_server::AppServer;
|
use crate::app_server::AppServer;
|
||||||
@@ -30,6 +32,7 @@ where
|
|||||||
fn ping(
|
fn ping(
|
||||||
&mut self,
|
&mut self,
|
||||||
req: &super::PingRequest,
|
req: &super::PingRequest,
|
||||||
|
_req_fds: &mut VecDeque<OwnedFd>,
|
||||||
res: &mut super::PingResponse,
|
res: &mut super::PingResponse,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
let (req, res) = (&req.payload, &mut res.payload);
|
let (req, res) = (&req.payload, &mut res.payload);
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ use super::{Message, RawMsgType, RequestMsgType, ResponseMsgType};
|
|||||||
/// Size required to fit any message in binary form
|
/// Size required to fit any message in binary form
|
||||||
pub const MAX_REQUEST_LEN: usize = 2500; // TODO fix this
|
pub const MAX_REQUEST_LEN: usize = 2500; // TODO fix this
|
||||||
pub const MAX_RESPONSE_LEN: usize = 2500; // TODO fix this
|
pub const MAX_RESPONSE_LEN: usize = 2500; // TODO fix this
|
||||||
|
pub const MAX_REQUEST_FDS: usize = 2;
|
||||||
|
|
||||||
#[repr(packed)]
|
#[repr(packed)]
|
||||||
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
|
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
|
||||||
|
|||||||
@@ -1,24 +1,35 @@
|
|||||||
|
use super::{ByteSliceRefExt, Message, PingRequest, PingResponse, RequestRef, RequestResponsePair};
|
||||||
|
use std::{collections::VecDeque, os::fd::OwnedFd};
|
||||||
use zerocopy::{ByteSlice, ByteSliceMut};
|
use zerocopy::{ByteSlice, ByteSliceMut};
|
||||||
|
|
||||||
use super::{ByteSliceRefExt, Message, PingRequest, PingResponse, RequestRef, RequestResponsePair};
|
|
||||||
|
|
||||||
pub trait Server {
|
pub trait Server {
|
||||||
fn ping(&mut self, req: &PingRequest, res: &mut PingResponse) -> anyhow::Result<()>;
|
fn ping(
|
||||||
|
&mut self,
|
||||||
|
req: &PingRequest,
|
||||||
|
req_fds: &mut VecDeque<OwnedFd>,
|
||||||
|
res: &mut PingResponse,
|
||||||
|
) -> anyhow::Result<()>;
|
||||||
|
|
||||||
fn dispatch<ReqBuf, ResBuf>(
|
fn dispatch<ReqBuf, ResBuf>(
|
||||||
&mut self,
|
&mut self,
|
||||||
p: &mut RequestResponsePair<ReqBuf, ResBuf>,
|
p: &mut RequestResponsePair<ReqBuf, ResBuf>,
|
||||||
|
req_fds: &mut VecDeque<OwnedFd>,
|
||||||
) -> anyhow::Result<()>
|
) -> anyhow::Result<()>
|
||||||
where
|
where
|
||||||
ReqBuf: ByteSlice,
|
ReqBuf: ByteSlice,
|
||||||
ResBuf: ByteSliceMut,
|
ResBuf: ByteSliceMut,
|
||||||
{
|
{
|
||||||
match p {
|
match p {
|
||||||
RequestResponsePair::Ping((req, res)) => self.ping(req, res),
|
RequestResponsePair::Ping((req, res)) => self.ping(req, req_fds, res),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_message<ReqBuf, ResBuf>(&mut self, req: ReqBuf, res: ResBuf) -> anyhow::Result<usize>
|
fn handle_message<ReqBuf, ResBuf>(
|
||||||
|
&mut self,
|
||||||
|
req: ReqBuf,
|
||||||
|
req_fds: &mut VecDeque<OwnedFd>,
|
||||||
|
res: ResBuf,
|
||||||
|
) -> anyhow::Result<usize>
|
||||||
where
|
where
|
||||||
ReqBuf: ByteSlice,
|
ReqBuf: ByteSlice,
|
||||||
ResBuf: ByteSliceMut,
|
ResBuf: ByteSliceMut,
|
||||||
@@ -32,7 +43,7 @@ pub trait Server {
|
|||||||
RequestResponsePair::Ping((req, res))
|
RequestResponsePair::Ping((req, res))
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
self.dispatch(&mut pair)?;
|
self.dispatch(&mut pair, req_fds)?;
|
||||||
|
|
||||||
let res_len = pair.response().bytes().len();
|
let res_len = pair.response().bytes().len();
|
||||||
Ok(res_len)
|
Ok(res_len)
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
use std::borrow::{Borrow, BorrowMut};
|
use std::borrow::{Borrow, BorrowMut};
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use std::os::fd::OwnedFd;
|
||||||
|
|
||||||
use mio::net::UnixStream;
|
use mio::net::UnixStream;
|
||||||
use rosenpass_secret_memory::Secret;
|
use rosenpass_secret_memory::Secret;
|
||||||
|
use rosenpass_util::mio::ReadWithFileDescriptors;
|
||||||
use rosenpass_util::{
|
use rosenpass_util::{
|
||||||
io::{IoResultKindHintExt, TryIoResultKindHintExt},
|
io::{IoResultKindHintExt, TryIoResultKindHintExt},
|
||||||
length_prefix_encoding::{
|
length_prefix_encoding::{
|
||||||
@@ -12,6 +15,7 @@ use rosenpass_util::{
|
|||||||
};
|
};
|
||||||
use zeroize::Zeroize;
|
use zeroize::Zeroize;
|
||||||
|
|
||||||
|
use crate::api::MAX_REQUEST_FDS;
|
||||||
use crate::{api::Server, app_server::AppServer};
|
use crate::{api::Server, app_server::AppServer};
|
||||||
|
|
||||||
use super::super::{ApiHandler, ApiHandlerContext};
|
use super::super::{ApiHandler, ApiHandlerContext};
|
||||||
@@ -39,11 +43,13 @@ impl<const N: usize> BorrowMut<[u8]> for SecretBuffer<N> {
|
|||||||
// TODO: Unfortunately, zerocopy is quite particular about alignment, hence the 4096
|
// TODO: Unfortunately, zerocopy is quite particular about alignment, hence the 4096
|
||||||
type ReadBuffer = LengthPrefixDecoder<SecretBuffer<4096>>;
|
type ReadBuffer = LengthPrefixDecoder<SecretBuffer<4096>>;
|
||||||
type WriteBuffer = LengthPrefixEncoder<SecretBuffer<4096>>;
|
type WriteBuffer = LengthPrefixEncoder<SecretBuffer<4096>>;
|
||||||
|
type ReadFdBuffer = VecDeque<OwnedFd>;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct MioConnectionBuffers {
|
struct MioConnectionBuffers {
|
||||||
read_buffer: ReadBuffer,
|
read_buffer: ReadBuffer,
|
||||||
write_buffer: WriteBuffer,
|
write_buffer: WriteBuffer,
|
||||||
|
read_fd_buffer: ReadFdBuffer,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@@ -65,9 +71,11 @@ impl MioConnection {
|
|||||||
let invalid_read = false;
|
let invalid_read = false;
|
||||||
let read_buffer = LengthPrefixDecoder::new(SecretBuffer::new());
|
let read_buffer = LengthPrefixDecoder::new(SecretBuffer::new());
|
||||||
let write_buffer = LengthPrefixEncoder::from_buffer(SecretBuffer::new());
|
let write_buffer = LengthPrefixEncoder::from_buffer(SecretBuffer::new());
|
||||||
|
let read_fd_buffer = VecDeque::new();
|
||||||
let buffers = Some(MioConnectionBuffers {
|
let buffers = Some(MioConnectionBuffers {
|
||||||
read_buffer,
|
read_buffer,
|
||||||
write_buffer,
|
write_buffer,
|
||||||
|
read_fd_buffer,
|
||||||
});
|
});
|
||||||
let api_state = ApiHandler::new();
|
let api_state = ApiHandler::new();
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
@@ -106,20 +114,22 @@ pub trait MioConnectionContext {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn handle_incoming_message(&mut self) -> anyhow::Result<Option<()>> {
|
fn handle_incoming_message(&mut self) -> anyhow::Result<Option<()>> {
|
||||||
self.with_buffers_stolen(|this, read_buf, write_buf| {
|
self.with_buffers_stolen(|this, bufs| {
|
||||||
// Acquire request & response. Caller is responsible to make sure
|
// Acquire request & response. Caller is responsible to make sure
|
||||||
// that read buffer holds a message and that write buffer is cleared.
|
// that read buffer holds a message and that write buffer is cleared.
|
||||||
// Hence the unwraps and assertions
|
// Hence the unwraps and assertions
|
||||||
assert!(write_buf.exhausted());
|
assert!(bufs.write_buffer.exhausted());
|
||||||
let req = read_buf.message().unwrap().unwrap();
|
let req = bufs.read_buffer.message().unwrap().unwrap();
|
||||||
let res = write_buf.buffer_bytes_mut();
|
let req_fds = &mut bufs.read_fd_buffer;
|
||||||
|
let res = bufs.write_buffer.buffer_bytes_mut();
|
||||||
|
|
||||||
// Call API handler
|
// Call API handler
|
||||||
// Transitive trait implementations: MioConnectionContext -> ApiHandlerContext -> as ApiServer
|
// Transitive trait implementations: MioConnectionContext -> ApiHandlerContext -> as ApiServer
|
||||||
let response_len = this.handle_message(req, res)?;
|
let response_len = this.handle_message(req, req_fds, res)?;
|
||||||
|
|
||||||
write_buf.restart_write_with_new_message(response_len)?;
|
bufs.write_buffer
|
||||||
read_buf.zeroize(); // clear for new message to read
|
.restart_write_with_new_message(response_len)?;
|
||||||
|
bufs.read_buffer.zeroize(); // clear for new message to read
|
||||||
|
|
||||||
Ok(Some(()))
|
Ok(Some(()))
|
||||||
})
|
})
|
||||||
@@ -130,36 +140,37 @@ pub trait MioConnectionContext {
|
|||||||
return Ok(Some(()));
|
return Ok(Some(()));
|
||||||
}
|
}
|
||||||
|
|
||||||
self.with_buffers_stolen(|this, _read_buf, write_buf| {
|
use lpe_encoder::WriteToIoReturn as Ret;
|
||||||
use lpe_encoder::WriteToIoReturn as Ret;
|
use std::io::ErrorKind as K;
|
||||||
use std::io::ErrorKind as K;
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
match write_buf
|
let conn = self.mio_connection_mut();
|
||||||
.write_to_stdio(&this.mio_connection_mut().io)
|
let bufs = conn.buffers.as_mut().unwrap();
|
||||||
.io_err_kind_hint()
|
|
||||||
{
|
|
||||||
// Done
|
|
||||||
Ok(Ret { done: true, .. }) => {
|
|
||||||
write_buf.zeroize(); // clear for new message to write
|
|
||||||
break Ok(Some(()));
|
|
||||||
},
|
|
||||||
|
|
||||||
// Would block
|
let sock = &conn.io;
|
||||||
Ok(Ret {
|
let write_buf = &mut bufs.write_buffer;
|
||||||
bytes_written: 0, ..
|
|
||||||
}) => break Ok(None),
|
|
||||||
Err((_e, K::WouldBlock)) => break Ok(None),
|
|
||||||
|
|
||||||
// Just continue
|
match write_buf.write_to_stdio(sock).io_err_kind_hint() {
|
||||||
Ok(_) => continue, /* Ret { bytes_written > 0, done = false } acc. to previous cases*/
|
// Done
|
||||||
Err((_e, K::Interrupted)) => continue,
|
Ok(Ret { done: true, .. }) => {
|
||||||
|
write_buf.zeroize(); // clear for new message to write
|
||||||
// Other errors
|
break Ok(Some(()));
|
||||||
Err((e, _ek)) => Err(e)?,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Would block
|
||||||
|
Ok(Ret {
|
||||||
|
bytes_written: 0, ..
|
||||||
|
}) => break Ok(None),
|
||||||
|
Err((_e, K::WouldBlock)) => break Ok(None),
|
||||||
|
|
||||||
|
// Just continue
|
||||||
|
Ok(_) => continue, /* Ret { bytes_written > 0, done = false } acc. to previous cases*/
|
||||||
|
Err((_e, K::Interrupted)) => continue,
|
||||||
|
|
||||||
|
// Other errors
|
||||||
|
Err((e, _ek)) => Err(e)?,
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn recv(&mut self) -> anyhow::Result<Option<()>> {
|
fn recv(&mut self) -> anyhow::Result<Option<()>> {
|
||||||
@@ -167,49 +178,68 @@ pub trait MioConnectionContext {
|
|||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
self.with_buffers_stolen(|this, read_buf, _write_buf| {
|
use lpe_decoder::{ReadFromIoError as E, ReadFromIoReturn as Ret};
|
||||||
use lpe_decoder::{ReadFromIoError as E, ReadFromIoReturn as Ret};
|
use std::io::ErrorKind as K;
|
||||||
use std::io::ErrorKind as K;
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
match read_buf
|
let conn = self.mio_connection_mut();
|
||||||
.read_from_stdio(&this.mio_connection_mut().io)
|
let bufs = conn.buffers.as_mut().unwrap();
|
||||||
.try_io_err_kind_hint()
|
|
||||||
{
|
|
||||||
// We actually received a proper message
|
|
||||||
// (Impl below match to appease borrow checker)
|
|
||||||
Ok(Ret {
|
|
||||||
message: Some(_msg),
|
|
||||||
..
|
|
||||||
}) => break Ok(Some(())),
|
|
||||||
|
|
||||||
// Message does not fit in buffer
|
let read_buf = &mut bufs.read_buffer;
|
||||||
Err((e @ E::MessageTooLargeError { .. }, _)) => {
|
let read_fd_buf = &mut bufs.read_fd_buffer;
|
||||||
log::warn!("Received message on API that was too big to fit in our buffers; \
|
|
||||||
|
let sock = &conn.io;
|
||||||
|
let fd_passing_sock = ReadWithFileDescriptors::<MAX_REQUEST_FDS, UnixStream, _, _>::new(
|
||||||
|
sock,
|
||||||
|
read_fd_buf,
|
||||||
|
);
|
||||||
|
|
||||||
|
match read_buf
|
||||||
|
.read_from_stdio(fd_passing_sock)
|
||||||
|
.try_io_err_kind_hint()
|
||||||
|
{
|
||||||
|
// We actually received a proper message
|
||||||
|
// (Impl below match to appease borrow checker)
|
||||||
|
Ok(Ret {
|
||||||
|
message: Some(_msg),
|
||||||
|
..
|
||||||
|
}) => break Ok(Some(())),
|
||||||
|
|
||||||
|
// Message does not fit in buffer
|
||||||
|
Err((e @ E::MessageTooLargeError { .. }, _)) => {
|
||||||
|
log::warn!("Received message on API that was too big to fit in our buffers; \
|
||||||
looks like the client is broken. Stopping to process messages of the client.\n\
|
looks like the client is broken. Stopping to process messages of the client.\n\
|
||||||
Error: {e:?}");
|
Error: {e:?}");
|
||||||
// TODO: We should properly close down the socket in this case, but to do that,
|
// TODO: We should properly close down the socket in this case, but to do that,
|
||||||
// we need to have the facilities in the Rosenpass IO handling system to close
|
// we need to have the facilities in the Rosenpass IO handling system to close
|
||||||
// open connections.
|
// open connections.
|
||||||
// Just leaving the API connections dangling for now.
|
// Just leaving the API connections dangling for now.
|
||||||
// This should be fixed for non-experimental use of the API.
|
// This should be fixed for non-experimental use of the API.
|
||||||
this.mio_connection_mut().invalid_read = true;
|
conn.invalid_read = true;
|
||||||
break Ok(None);
|
break Ok(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Would block
|
// Would block
|
||||||
Ok(Ret { bytes_read: 0, .. }) => break Ok(None),
|
Ok(Ret { bytes_read: 0, .. }) => break Ok(None),
|
||||||
Err((_, Some(K::WouldBlock))) => break Ok(None),
|
Err((_, Some(K::WouldBlock))) => break Ok(None),
|
||||||
|
|
||||||
// Just keep going
|
// Just keep going
|
||||||
Ok(Ret { bytes_read: _, .. }) => continue,
|
Ok(Ret { bytes_read: _, .. }) => continue,
|
||||||
Err((_, Some(K::Interrupted))) => continue,
|
Err((_, Some(K::Interrupted))) => continue,
|
||||||
|
|
||||||
// Other IO Error (just pass on to the caller)
|
// Other IO Error (just pass on to the caller)
|
||||||
Err((E::IoError(e), _)) => Err(e)?,
|
Err((E::IoError(e), _)) => {
|
||||||
};
|
log::warn!(
|
||||||
}
|
"IO error while trying to read message from API socket. \
|
||||||
})
|
The connection is broken. Stopping to process messages of the client.\n\
|
||||||
|
Error: {e:?}"
|
||||||
|
);
|
||||||
|
// TODO: Same as above
|
||||||
|
conn.invalid_read = true;
|
||||||
|
break Err(e.into());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -224,32 +254,23 @@ trait MioConnectionContextPrivate: MioConnectionContext {
|
|||||||
let _ = opt.insert(buffers);
|
let _ = opt.insert(buffers);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn with_buffers_stolen<R, F: FnOnce(&mut Self, &mut ReadBuffer, &mut WriteBuffer) -> R>(
|
fn with_buffers_stolen<R, F: FnOnce(&mut Self, &mut MioConnectionBuffers) -> R>(
|
||||||
&mut self,
|
&mut self,
|
||||||
f: F,
|
f: F,
|
||||||
) -> R {
|
) -> R {
|
||||||
let mut bufs = self.steal_buffers();
|
let mut bufs = self.steal_buffers();
|
||||||
let res = f(self, &mut bufs.read_buffer, &mut bufs.write_buffer);
|
let res = f(self, &mut bufs);
|
||||||
self.return_buffers(bufs);
|
self.return_buffers(bufs);
|
||||||
res
|
res
|
||||||
}
|
}
|
||||||
|
|
||||||
fn both_buffers_mut(&mut self) -> (&mut ReadBuffer, &mut WriteBuffer) {
|
|
||||||
let bufs = self.mio_connection_mut().buffers.as_mut().unwrap();
|
|
||||||
let MioConnectionBuffers {
|
|
||||||
ref mut read_buffer,
|
|
||||||
ref mut write_buffer,
|
|
||||||
} = bufs;
|
|
||||||
(read_buffer, write_buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
fn read_buf_mut(&mut self) -> &mut ReadBuffer {
|
|
||||||
self.both_buffers_mut().0
|
|
||||||
}
|
|
||||||
|
|
||||||
fn write_buf_mut(&mut self) -> &mut WriteBuffer {
|
fn write_buf_mut(&mut self) -> &mut WriteBuffer {
|
||||||
self.both_buffers_mut().1
|
self.mio_connection_mut()
|
||||||
|
.buffers
|
||||||
|
.as_mut()
|
||||||
|
.unwrap()
|
||||||
|
.write_buffer
|
||||||
|
.borrow_mut()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -22,3 +22,8 @@ zerocopy = { workspace = true }
|
|||||||
thiserror = { workspace = true }
|
thiserror = { workspace = true }
|
||||||
mio = { workspace = true }
|
mio = { workspace = true }
|
||||||
tempfile = { workspace = true }
|
tempfile = { workspace = true }
|
||||||
|
uds = { workspace = true, optional = true, features = ["mio_1xx"] }
|
||||||
|
|
||||||
|
|
||||||
|
[features]
|
||||||
|
experiment_file_descriptor_passing = ["uds"]
|
||||||
|
|||||||
50
util/src/controlflow.rs
Normal file
50
util/src/controlflow.rs
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
#[macro_export]
|
||||||
|
macro_rules! repeat {
|
||||||
|
($times:expr, $body:expr) => {
|
||||||
|
for _ in 0..($times) {
|
||||||
|
$body
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! return_if {
|
||||||
|
($cond:expr) => {
|
||||||
|
if $cond {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
($cond:expr, $val:expr) => {
|
||||||
|
if $cond {
|
||||||
|
return $val;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! break_if {
|
||||||
|
($cond:expr) => {
|
||||||
|
if $cond {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
($cond:expr, $val:expr) => {
|
||||||
|
if $cond {
|
||||||
|
break $val;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! continue_if {
|
||||||
|
($cond:expr) => {
|
||||||
|
if $cond {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
($cond:expr, $val:expr) => {
|
||||||
|
if $cond {
|
||||||
|
break $val;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -25,6 +25,18 @@ pub fn claim_fd(fd: RawFd) -> rustix::io::Result<OwnedFd> {
|
|||||||
Ok(new)
|
Ok(new)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Prepare a file descriptor for use in Rust code.
|
||||||
|
///
|
||||||
|
/// Checks if the file descriptor is valid.
|
||||||
|
///
|
||||||
|
/// Unlike [claim_fd], this will reuse the same file descriptor identifier instead of masking it.
|
||||||
|
pub fn claim_fd_inplace(fd: RawFd) -> rustix::io::Result<OwnedFd> {
|
||||||
|
let mut new = unsafe { OwnedFd::from_raw_fd(fd) };
|
||||||
|
let tmp = clone_fd_cloexec(&new)?;
|
||||||
|
clone_fd_to_cloexec(tmp, &mut new)?;
|
||||||
|
Ok(new)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn mask_fd(fd: RawFd) -> rustix::io::Result<()> {
|
pub fn mask_fd(fd: RawFd) -> rustix::io::Result<()> {
|
||||||
// Safety: because the OwnedFd resulting from OwnedFd::from_raw_fd is wrapped in a Forgetting,
|
// Safety: because the OwnedFd resulting from OwnedFd::from_raw_fd is wrapped in a Forgetting,
|
||||||
// it never gets dropped, meaning that fd is never closed and thus outlives the OwnedFd
|
// it never gets dropped, meaning that fd is never closed and thus outlives the OwnedFd
|
||||||
@@ -58,6 +70,47 @@ pub fn open_nullfd() -> rustix::io::Result<OwnedFd> {
|
|||||||
open("/dev/null", OFlags::CLOEXEC, Mode::empty())
|
open("/dev/null", OFlags::CLOEXEC, Mode::empty())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Convert low level errors into std::io::Error
|
||||||
|
pub trait IntoStdioErr {
|
||||||
|
type Target;
|
||||||
|
fn into_stdio_err(self) -> Self::Target;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoStdioErr for rustix::io::Errno {
|
||||||
|
type Target = std::io::Error;
|
||||||
|
|
||||||
|
fn into_stdio_err(self) -> Self::Target {
|
||||||
|
std::io::Error::from_raw_os_error(self.raw_os_error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> IntoStdioErr for rustix::io::Result<T> {
|
||||||
|
type Target = std::io::Result<T>;
|
||||||
|
|
||||||
|
fn into_stdio_err(self) -> Self::Target {
|
||||||
|
self.map_err(IntoStdioErr::into_stdio_err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Read and write directly from a file descriptor
|
||||||
|
pub struct FdIo<Fd: AsFd>(pub Fd);
|
||||||
|
|
||||||
|
impl<Fd: AsFd> std::io::Read for FdIo<Fd> {
|
||||||
|
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||||
|
rustix::io::read(&self.0, buf).into_stdio_err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Fd: AsFd> std::io::Write for FdIo<Fd> {
|
||||||
|
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||||
|
rustix::io::write(&self.0, buf).into_stdio_err()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn flush(&mut self) -> std::io::Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|||||||
12
util/src/mio/mod.rs
Normal file
12
util/src/mio/mod.rs
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
#[allow(clippy::module_inception)]
|
||||||
|
mod mio;
|
||||||
|
pub use mio::*;
|
||||||
|
|
||||||
|
#[cfg(feature = "experiment_file_descriptor_passing")]
|
||||||
|
mod uds_recv_fd;
|
||||||
|
#[cfg(feature = "experiment_file_descriptor_passing")]
|
||||||
|
mod uds_send_fd;
|
||||||
|
#[cfg(feature = "experiment_file_descriptor_passing")]
|
||||||
|
pub use uds_recv_fd::*;
|
||||||
|
#[cfg(feature = "experiment_file_descriptor_passing")]
|
||||||
|
pub use uds_send_fd::*;
|
||||||
123
util/src/mio/uds_recv_fd.rs
Normal file
123
util/src/mio/uds_recv_fd.rs
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
use std::{
|
||||||
|
borrow::{Borrow, BorrowMut},
|
||||||
|
collections::VecDeque,
|
||||||
|
io::Read,
|
||||||
|
marker::PhantomData,
|
||||||
|
os::fd::OwnedFd,
|
||||||
|
};
|
||||||
|
use uds::UnixStreamExt as FdPassingExt;
|
||||||
|
|
||||||
|
use crate::fd::{claim_fd_inplace, IntoStdioErr};
|
||||||
|
|
||||||
|
pub struct ReadWithFileDescriptors<const MAX_FDS: usize, Sock, BorrowSock, BorrowFds>
|
||||||
|
where
|
||||||
|
Sock: FdPassingExt,
|
||||||
|
BorrowSock: Borrow<Sock>,
|
||||||
|
BorrowFds: BorrowMut<VecDeque<OwnedFd>>,
|
||||||
|
{
|
||||||
|
socket: BorrowSock,
|
||||||
|
fds: BorrowFds,
|
||||||
|
_sock_dummy: PhantomData<Sock>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const MAX_FDS: usize, Sock, BorrowSock, BorrowFds>
|
||||||
|
ReadWithFileDescriptors<MAX_FDS, Sock, BorrowSock, BorrowFds>
|
||||||
|
where
|
||||||
|
Sock: FdPassingExt,
|
||||||
|
BorrowSock: Borrow<Sock>,
|
||||||
|
BorrowFds: BorrowMut<VecDeque<OwnedFd>>,
|
||||||
|
{
|
||||||
|
pub fn new(socket: BorrowSock, fds: BorrowFds) -> Self {
|
||||||
|
let _sock_dummy = PhantomData;
|
||||||
|
Self {
|
||||||
|
socket,
|
||||||
|
fds,
|
||||||
|
_sock_dummy,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_parts(self) -> (BorrowSock, BorrowFds) {
|
||||||
|
let Self { socket, fds, .. } = self;
|
||||||
|
(socket, fds)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn socket(&self) -> &Sock {
|
||||||
|
self.socket.borrow()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fds(&self) -> &VecDeque<OwnedFd> {
|
||||||
|
self.fds.borrow()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fds_mut(&mut self) -> &mut VecDeque<OwnedFd> {
|
||||||
|
self.fds.borrow_mut()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const MAX_FDS: usize, Sock, BorrowSock, BorrowFds>
|
||||||
|
ReadWithFileDescriptors<MAX_FDS, Sock, BorrowSock, BorrowFds>
|
||||||
|
where
|
||||||
|
Sock: FdPassingExt,
|
||||||
|
BorrowSock: BorrowMut<Sock>,
|
||||||
|
BorrowFds: BorrowMut<VecDeque<OwnedFd>>,
|
||||||
|
{
|
||||||
|
pub fn socket_mut(&mut self) -> &mut Sock {
|
||||||
|
self.socket.borrow_mut()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const MAX_FDS: usize, Sock, BorrowSock, BorrowFds> Read
|
||||||
|
for ReadWithFileDescriptors<MAX_FDS, Sock, BorrowSock, BorrowFds>
|
||||||
|
where
|
||||||
|
Sock: FdPassingExt,
|
||||||
|
BorrowSock: Borrow<Sock>,
|
||||||
|
BorrowFds: BorrowMut<VecDeque<OwnedFd>>,
|
||||||
|
{
|
||||||
|
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||||
|
// Calculate space for additional file descriptors
|
||||||
|
let have_fds_before_read = self.fds().len();
|
||||||
|
let free_fd_slots = MAX_FDS.saturating_sub(have_fds_before_read);
|
||||||
|
|
||||||
|
// Allocate a buffer for file descriptors
|
||||||
|
let mut fd_buf = [0; MAX_FDS];
|
||||||
|
let fd_buf = &mut fd_buf[..free_fd_slots];
|
||||||
|
|
||||||
|
// Read from the unix socket
|
||||||
|
let (bytes_read, fds_read) = self.socket.borrow().recv_fds(buf, fd_buf)?;
|
||||||
|
let fd_buf = &fd_buf[..fds_read];
|
||||||
|
|
||||||
|
// Process the file descriptors
|
||||||
|
let mut fd_iter = fd_buf.iter();
|
||||||
|
|
||||||
|
// Try claiming all the file descriptors
|
||||||
|
let mut claim_fd_result = Ok(bytes_read);
|
||||||
|
self.fds_mut().reserve(fd_buf.len());
|
||||||
|
for fd in fd_iter.by_ref() {
|
||||||
|
match claim_fd_inplace(*fd) {
|
||||||
|
Ok(owned) => self.fds_mut().push_back(owned),
|
||||||
|
Err(e) => {
|
||||||
|
// Abort on error and pass to error handler
|
||||||
|
// Note that claim_fd_inplace is responsible for closing this particular
|
||||||
|
// file descriptor if claiming it fails
|
||||||
|
claim_fd_result = Err(e.into_stdio_err());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return if we where able to claim all file descriptors
|
||||||
|
if claim_fd_result.is_ok() {
|
||||||
|
return claim_fd_result;
|
||||||
|
};
|
||||||
|
|
||||||
|
// An error occurred while claiming fds
|
||||||
|
self.fds_mut().truncate(have_fds_before_read); // Close fds successfully claimed
|
||||||
|
|
||||||
|
// Close the remaining fds
|
||||||
|
for fd in fd_iter {
|
||||||
|
unsafe { rustix::io::close(*fd) };
|
||||||
|
}
|
||||||
|
|
||||||
|
claim_fd_result
|
||||||
|
}
|
||||||
|
}
|
||||||
121
util/src/mio/uds_send_fd.rs
Normal file
121
util/src/mio/uds_send_fd.rs
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
use rustix::fd::{AsFd, AsRawFd};
|
||||||
|
use std::{
|
||||||
|
borrow::{Borrow, BorrowMut},
|
||||||
|
cmp::min,
|
||||||
|
collections::VecDeque,
|
||||||
|
io::Write,
|
||||||
|
marker::PhantomData,
|
||||||
|
};
|
||||||
|
use uds::UnixStreamExt as FdPassingExt;
|
||||||
|
|
||||||
|
use crate::{repeat, return_if};
|
||||||
|
|
||||||
|
pub struct WriteWithFileDescriptors<Sock, Fd, BorrowSock, BorrowFds>
|
||||||
|
where
|
||||||
|
Sock: FdPassingExt,
|
||||||
|
Fd: AsFd,
|
||||||
|
BorrowSock: Borrow<Sock>,
|
||||||
|
BorrowFds: BorrowMut<VecDeque<Fd>>,
|
||||||
|
{
|
||||||
|
socket: BorrowSock,
|
||||||
|
fds: BorrowFds,
|
||||||
|
_sock_dummy: PhantomData<Sock>,
|
||||||
|
_fd_dummy: PhantomData<Fd>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Sock, Fd, BorrowSock, BorrowFds> WriteWithFileDescriptors<Sock, Fd, BorrowSock, BorrowFds>
|
||||||
|
where
|
||||||
|
Sock: FdPassingExt,
|
||||||
|
Fd: AsFd,
|
||||||
|
BorrowSock: Borrow<Sock>,
|
||||||
|
BorrowFds: BorrowMut<VecDeque<Fd>>,
|
||||||
|
{
|
||||||
|
pub fn new(socket: BorrowSock, fds: BorrowFds) -> Self {
|
||||||
|
let _sock_dummy = PhantomData;
|
||||||
|
let _fd_dummy = PhantomData;
|
||||||
|
Self {
|
||||||
|
socket,
|
||||||
|
fds,
|
||||||
|
_sock_dummy,
|
||||||
|
_fd_dummy,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_parts(self) -> (BorrowSock, BorrowFds) {
|
||||||
|
let Self { socket, fds, .. } = self;
|
||||||
|
(socket, fds)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn socket(&self) -> &Sock {
|
||||||
|
self.socket.borrow()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fds(&self) -> &VecDeque<Fd> {
|
||||||
|
self.fds.borrow()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fds_mut(&mut self) -> &mut VecDeque<Fd> {
|
||||||
|
self.fds.borrow_mut()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Sock, Fd, BorrowSock, BorrowFds> WriteWithFileDescriptors<Sock, Fd, BorrowSock, BorrowFds>
|
||||||
|
where
|
||||||
|
Sock: FdPassingExt,
|
||||||
|
Fd: AsFd,
|
||||||
|
BorrowSock: BorrowMut<Sock>,
|
||||||
|
BorrowFds: BorrowMut<VecDeque<Fd>>,
|
||||||
|
{
|
||||||
|
pub fn socket_mut(&mut self) -> &mut Sock {
|
||||||
|
self.socket.borrow_mut()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Sock, Fd, BorrowSock, BorrowFds> Write
|
||||||
|
for WriteWithFileDescriptors<Sock, Fd, BorrowSock, BorrowFds>
|
||||||
|
where
|
||||||
|
Sock: FdPassingExt,
|
||||||
|
Fd: AsFd,
|
||||||
|
BorrowSock: Borrow<Sock>,
|
||||||
|
BorrowFds: BorrowMut<VecDeque<Fd>>,
|
||||||
|
{
|
||||||
|
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||||
|
// At least one byte of real data should be sent when sending ancillary data. -- unix(7)
|
||||||
|
return_if!(buf.is_empty(), Ok(0));
|
||||||
|
|
||||||
|
// The kernel constant SCM_MAX_FD defines a limit on the number of file descriptors
|
||||||
|
// in the array. Attempting to send an array larger than this limit causes
|
||||||
|
// sendmsg(2) to fail with the error EINVAL. SCM_MAX_FD has the value 253 (or 255
|
||||||
|
// before Linux 2.6.38).
|
||||||
|
// -- unix(7)
|
||||||
|
const SCM_MAX_FD: usize = 253;
|
||||||
|
let buf = match self.fds().len() <= SCM_MAX_FD {
|
||||||
|
false => &buf[..1], // Force caller to immediately call write() again to send its data
|
||||||
|
true => buf,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Allocate the buffer for the file descriptor array
|
||||||
|
let fd_no = min(SCM_MAX_FD, self.fds().len());
|
||||||
|
let mut fd_buf = [0; SCM_MAX_FD]; // My kingdom for alloca(3)
|
||||||
|
let fd_buf = &mut fd_buf[..fd_no];
|
||||||
|
|
||||||
|
// Fill the file descriptor array
|
||||||
|
for (raw, fancy) in fd_buf.iter_mut().zip(self.fds().iter()) {
|
||||||
|
*raw = fancy.as_fd().as_raw_fd();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send data and file descriptors
|
||||||
|
let bytes_written = self.socket().send_fds(buf, fd_buf)?;
|
||||||
|
|
||||||
|
// Drop the file descriptors from the Deque
|
||||||
|
repeat!(fd_no, {
|
||||||
|
self.fds_mut().pop_front();
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(bytes_written)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn flush(&mut self) -> std::io::Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user