feat(API): AddListenSocket endpoint

This commit is contained in:
Karolin Varner
2024-08-11 13:56:07 +02:00
parent 1d2fa7d038
commit 24eebe29a1
14 changed files with 416 additions and 25 deletions

View File

@@ -23,7 +23,7 @@ name = "api-integration-tests"
required-features = ["experiment_api", "internal_testing"]
[[test]]
name = "api-integration-tests-supply-keypair"
name = "api-integration-tests-api-setup"
required-features = ["experiment_api", "internal_testing"]
[[bench]]

View File

@@ -2,9 +2,13 @@ use std::{borrow::BorrowMut, collections::VecDeque, os::fd::OwnedFd};
use anyhow::Context;
use rosenpass_to::{ops::copy_slice, To};
use rosenpass_util::{fd::FdIo, functional::run, io::ReadExt, mem::DiscardResultExt};
use rosenpass_util::{
fd::FdIo, functional::run, io::ReadExt, mem::DiscardResultExt, result::OkExt,
};
use crate::{app_server::AppServer, protocol::BuildCryptoServer};
use crate::{
api::add_listen_socket_response_status, app_server::AppServer, protocol::BuildCryptoServer,
};
use super::{supply_keypair_response_status, Server as ApiServer};
@@ -171,4 +175,54 @@ where
Ok(())
}
fn add_listen_socket(
&mut self,
_req: &super::boilerplate::AddListenSocketRequest,
req_fds: &mut VecDeque<OwnedFd>,
res: &mut super::boilerplate::AddListenSocketResponse,
) -> anyhow::Result<()> {
// Retrieve file descriptor
let sock_res = run(|| -> anyhow::Result<mio::net::UdpSocket> {
let sock = req_fds
.pop_front()
.context("Invalid request socket missing.")?;
// TODO: We need to have this outside linux
#[cfg(target_os = "linux")]
rosenpass_util::fd::GetSocketProtocol::demand_udp_socket(&sock)?;
let sock = std::net::UdpSocket::from(sock);
sock.set_nonblocking(true)?;
mio::net::UdpSocket::from_std(sock).ok()
});
let mut sock = match sock_res {
Ok(sock) => sock,
Err(e) => {
log::debug!("Error processing AddListenSocket API request: {e:?}");
res.payload.status = add_listen_socket_response_status::INVALID_REQUEST;
return Ok(());
}
};
// Register socket
let reg_result = run(|| -> anyhow::Result<()> {
let srv = self.app_server_mut();
srv.mio_poll.registry().register(
&mut sock,
srv.mio_token_dispenser.dispense(),
mio::Interest::READABLE,
)?;
srv.sockets.push(sock);
Ok(())
});
if let Err(internal_error) = reg_result {
log::warn!("Internal error processing AddListenSocket API request: {internal_error:?}");
res.payload.status = add_listen_socket_response_status::INTERNAL_ERROR;
return Ok(());
};
res.payload.status = add_listen_socket_response_status::OK;
Ok(())
}
}

View File

@@ -143,6 +143,44 @@ pub trait ByteSliceRefExt: ByteSlice {
) -> anyhow::Result<Ref<Self, SupplyKeypairResponse>> {
self.zk_parse_suffix()
}
fn add_listen_socket_request(self) -> anyhow::Result<Ref<Self, super::AddListenSocketRequest>> {
self.zk_parse()
}
fn add_listen_socket_request_from_prefix(
self,
) -> anyhow::Result<Ref<Self, super::AddListenSocketRequest>> {
self.zk_parse_prefix()
}
fn add_listen_socket_request_from_suffix(
self,
) -> anyhow::Result<Ref<Self, super::AddListenSocketRequest>> {
self.zk_parse_suffix()
}
fn add_listen_socket_response_maker(self) -> RefMaker<Self, super::AddListenSocketResponse> {
self.zk_ref_maker()
}
fn add_listen_socket_response(
self,
) -> anyhow::Result<Ref<Self, super::AddListenSocketResponse>> {
self.zk_parse()
}
fn add_listen_socket_response_from_prefix(
self,
) -> anyhow::Result<Ref<Self, super::AddListenSocketResponse>> {
self.zk_parse_prefix()
}
fn add_listen_socket_response_from_suffix(
self,
) -> anyhow::Result<Ref<Self, super::AddListenSocketResponse>> {
self.zk_parse_suffix()
}
}
impl<B: ByteSlice> ByteSliceRefExt for B {}

View File

@@ -21,6 +21,13 @@ const SUPPLY_KEYPAIR_REQUEST: RawMsgType =
const SUPPLY_KEYPAIR_RESPONSE: RawMsgType =
RawMsgType::from_le_bytes(hex!("f2dc 49bd e261 5f10 40b7 3c16 ec61 edb9"));
// hash domain hash of: Rosenpass IPC API -> Rosenpass Protocol Server -> Add Listen Socket Request
const ADD_LISTEN_SOCKET_REQUEST: RawMsgType =
RawMsgType::from_le_bytes(hex!("3f21 434f 87cc a08c 02c4 61e4 0816 c7da"));
// hash domain hash of: Rosenpass IPC API -> Rosenpass Protocol Server -> Add Listen Socket Response
const ADD_LISTEN_SOCKET_RESPONSE: RawMsgType =
RawMsgType::from_le_bytes(hex!("45d5 0f0d 93f0 6105 98f2 9469 5dfd 5f36"));
pub trait MessageAttributes {
fn message_size(&self) -> usize;
}
@@ -29,12 +36,14 @@ pub trait MessageAttributes {
pub enum RequestMsgType {
Ping,
SupplyKeypair,
AddListenSocket,
}
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Copy)]
pub enum ResponseMsgType {
Ping,
SupplyKeypair,
AddListenSocket,
}
impl MessageAttributes for RequestMsgType {
@@ -42,6 +51,7 @@ impl MessageAttributes for RequestMsgType {
match self {
Self::Ping => std::mem::size_of::<super::PingRequest>(),
Self::SupplyKeypair => std::mem::size_of::<super::SupplyKeypairRequest>(),
Self::AddListenSocket => std::mem::size_of::<super::AddListenSocketRequest>(),
}
}
}
@@ -51,6 +61,7 @@ impl MessageAttributes for ResponseMsgType {
match self {
Self::Ping => std::mem::size_of::<super::PingResponse>(),
Self::SupplyKeypair => std::mem::size_of::<super::SupplyKeypairResponse>(),
Self::AddListenSocket => std::mem::size_of::<super::AddListenSocketResponse>(),
}
}
}
@@ -63,6 +74,7 @@ impl TryFrom<RawMsgType> for RequestMsgType {
Ok(match value {
self::PING_REQUEST => E::Ping,
self::SUPPLY_KEYPAIR_REQUEST => E::SupplyKeypair,
self::ADD_LISTEN_SOCKET_REQUEST => E::AddListenSocket,
_ => return Err(InvalidApiMessageType(value)),
})
}
@@ -74,6 +86,7 @@ impl From<RequestMsgType> for RawMsgType {
match val {
E::Ping => self::PING_REQUEST,
E::SupplyKeypair => self::SUPPLY_KEYPAIR_REQUEST,
E::AddListenSocket => self::ADD_LISTEN_SOCKET_REQUEST,
}
}
}
@@ -86,6 +99,7 @@ impl TryFrom<RawMsgType> for ResponseMsgType {
Ok(match value {
self::PING_RESPONSE => E::Ping,
self::SUPPLY_KEYPAIR_RESPONSE => E::SupplyKeypair,
self::ADD_LISTEN_SOCKET_RESPONSE => E::AddListenSocket,
_ => return Err(InvalidApiMessageType(value)),
})
}
@@ -97,6 +111,7 @@ impl From<ResponseMsgType> for RawMsgType {
match val {
E::Ping => self::PING_RESPONSE,
E::SupplyKeypair => self::SUPPLY_KEYPAIR_RESPONSE,
E::AddListenSocket => self::ADD_LISTEN_SOCKET_RESPONSE,
}
}
}

View File

@@ -181,3 +181,87 @@ impl Message for SupplyKeypairResponse {
self.msg_type = Self::MESSAGE_TYPE.into();
}
}
#[repr(packed)]
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
pub struct AddListenSocketRequestPayload {}
pub type AddListenSocketRequest = RequestEnvelope<AddListenSocketRequestPayload>;
impl Default for AddListenSocketRequest {
fn default() -> Self {
Self::new()
}
}
impl AddListenSocketRequest {
pub fn new() -> Self {
Self::from_payload(AddListenSocketRequestPayload {})
}
}
impl Message for AddListenSocketRequest {
type Payload = AddListenSocketRequestPayload;
type MessageClass = RequestMsgType;
const MESSAGE_TYPE: Self::MessageClass = RequestMsgType::AddListenSocket;
fn from_payload(payload: Self::Payload) -> Self {
Self {
msg_type: Self::MESSAGE_TYPE.into(),
payload,
}
}
fn setup<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
let mut r: Ref<B, Self> = buf.zk_zeroized()?;
r.init();
Ok(r)
}
fn init(&mut self) {
self.msg_type = Self::MESSAGE_TYPE.into();
}
}
pub mod add_listen_socket_response_status {
pub const OK: u128 = 0;
pub const INVALID_REQUEST: u128 = 1;
pub const INTERNAL_ERROR: u128 = 2;
}
#[repr(packed)]
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
pub struct AddListenSocketResponsePayload {
pub status: u128,
}
pub type AddListenSocketResponse = ResponseEnvelope<AddListenSocketResponsePayload>;
impl AddListenSocketResponse {
pub fn new(status: u128) -> Self {
Self::from_payload(AddListenSocketResponsePayload { status })
}
}
impl Message for AddListenSocketResponse {
type Payload = AddListenSocketResponsePayload;
type MessageClass = ResponseMsgType;
const MESSAGE_TYPE: Self::MessageClass = ResponseMsgType::AddListenSocket;
fn from_payload(payload: Self::Payload) -> Self {
Self {
msg_type: Self::MESSAGE_TYPE.into(),
payload,
}
}
fn setup<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
let mut r: Ref<B, Self> = buf.zk_zeroized()?;
r.init();
Ok(r)
}
fn init(&mut self) {
self.msg_type = Self::MESSAGE_TYPE.into();
}
}

View File

@@ -26,6 +26,7 @@ impl<B: ByteSlice> RequestRef<B> {
match self {
Self::Ping(_) => RequestMsgType::Ping,
Self::SupplyKeypair(_) => RequestMsgType::SupplyKeypair,
Self::AddListenSocket(_) => RequestMsgType::AddListenSocket,
}
}
}
@@ -36,6 +37,18 @@ impl<B> From<Ref<B, PingRequest>> for RequestRef<B> {
}
}
impl<B> From<Ref<B, super::SupplyKeypairRequest>> for RequestRef<B> {
fn from(v: Ref<B, super::SupplyKeypairRequest>) -> Self {
Self::SupplyKeypair(v)
}
}
impl<B> From<Ref<B, super::AddListenSocketRequest>> for RequestRef<B> {
fn from(v: Ref<B, super::AddListenSocketRequest>) -> Self {
Self::AddListenSocket(v)
}
}
impl<B: ByteSlice> RequestRefMaker<B> {
fn new(buf: B) -> anyhow::Result<Self> {
let msg_type = buf.deref().request_msg_type_from_prefix()?;
@@ -52,6 +65,9 @@ impl<B: ByteSlice> RequestRefMaker<B> {
RequestMsgType::SupplyKeypair => {
RequestRef::SupplyKeypair(self.buf.supply_keypair_request()?)
}
RequestMsgType::AddListenSocket => {
RequestRef::AddListenSocket(self.buf.add_listen_socket_request()?)
}
})
}
@@ -87,6 +103,7 @@ impl<B: ByteSlice> RequestRefMaker<B> {
pub enum RequestRef<B> {
Ping(Ref<B, PingRequest>),
SupplyKeypair(Ref<B, super::SupplyKeypairRequest>),
AddListenSocket(Ref<B, super::AddListenSocketRequest>),
}
impl<B> RequestRef<B>
@@ -97,6 +114,7 @@ where
match self {
Self::Ping(r) => r.bytes(),
Self::SupplyKeypair(r) => r.bytes(),
Self::AddListenSocket(r) => r.bytes(),
}
}
}
@@ -109,6 +127,7 @@ where
match self {
Self::Ping(r) => r.bytes_mut(),
Self::SupplyKeypair(r) => r.bytes_mut(),
Self::AddListenSocket(r) => r.bytes_mut(),
}
}
}

View File

@@ -50,15 +50,28 @@ impl ResponseMsg for super::SupplyKeypairResponse {
type RequestMsg = super::SupplyKeypairRequest;
}
impl RequestMsg for super::AddListenSocketRequest {
type ResponseMsg = super::AddListenSocketResponse;
}
impl ResponseMsg for super::AddListenSocketResponse {
type RequestMsg = super::AddListenSocketRequest;
}
pub type PingPair<B1, B2> = (Ref<B1, PingRequest>, Ref<B2, PingResponse>);
pub type SupplyKeypairPair<B1, B2> = (
Ref<B1, super::SupplyKeypairRequest>,
Ref<B2, super::SupplyKeypairResponse>,
);
pub type AddListenSocketPair<B1, B2> = (
Ref<B1, super::AddListenSocketRequest>,
Ref<B2, super::AddListenSocketResponse>,
);
pub enum RequestResponsePair<B1, B2> {
Ping(PingPair<B1, B2>),
SupplyKeypair(SupplyKeypairPair<B1, B2>),
AddListenSocket(AddListenSocketPair<B1, B2>),
}
impl<B1, B2> From<PingPair<B1, B2>> for RequestResponsePair<B1, B2> {
@@ -73,6 +86,12 @@ impl<B1, B2> From<SupplyKeypairPair<B1, B2>> for RequestResponsePair<B1, B2> {
}
}
impl<B1, B2> From<AddListenSocketPair<B1, B2>> for RequestResponsePair<B1, B2> {
fn from(v: AddListenSocketPair<B1, B2>) -> Self {
RequestResponsePair::AddListenSocket(v)
}
}
impl<B1, B2> RequestResponsePair<B1, B2>
where
B1: ByteSlice,
@@ -90,6 +109,11 @@ where
let res = ResponseRef::SupplyKeypair(res.emancipate());
(req, res)
}
Self::AddListenSocket((req, res)) => {
let req = RequestRef::AddListenSocket(req.emancipate());
let res = ResponseRef::AddListenSocket(res.emancipate());
(req, res)
}
}
}
@@ -119,6 +143,11 @@ where
let res = ResponseRef::SupplyKeypair(res.emancipate_mut());
(req, res)
}
Self::AddListenSocket((req, res)) => {
let req = RequestRef::AddListenSocket(req.emancipate_mut());
let res = ResponseRef::AddListenSocket(res.emancipate_mut());
(req, res)
}
}
}

View File

@@ -27,6 +27,7 @@ impl<B: ByteSlice> ResponseRef<B> {
match self {
Self::Ping(_) => ResponseMsgType::Ping,
Self::SupplyKeypair(_) => ResponseMsgType::SupplyKeypair,
Self::AddListenSocket(_) => ResponseMsgType::AddListenSocket,
}
}
}
@@ -43,6 +44,12 @@ impl<B> From<Ref<B, super::SupplyKeypairResponse>> for ResponseRef<B> {
}
}
impl<B> From<Ref<B, super::AddListenSocketResponse>> for ResponseRef<B> {
fn from(v: Ref<B, super::AddListenSocketResponse>) -> Self {
Self::AddListenSocket(v)
}
}
impl<B: ByteSlice> ResponseRefMaker<B> {
fn new(buf: B) -> anyhow::Result<Self> {
let msg_type = buf.deref().response_msg_type_from_prefix()?;
@@ -59,6 +66,9 @@ impl<B: ByteSlice> ResponseRefMaker<B> {
ResponseMsgType::SupplyKeypair => {
ResponseRef::SupplyKeypair(self.buf.supply_keypair_response()?)
}
ResponseMsgType::AddListenSocket => {
ResponseRef::AddListenSocket(self.buf.add_listen_socket_response()?)
}
})
}
@@ -94,6 +104,7 @@ impl<B: ByteSlice> ResponseRefMaker<B> {
pub enum ResponseRef<B> {
Ping(Ref<B, PingResponse>),
SupplyKeypair(Ref<B, super::SupplyKeypairResponse>),
AddListenSocket(Ref<B, super::AddListenSocketResponse>),
}
impl<B> ResponseRef<B>
@@ -104,6 +115,7 @@ where
match self {
Self::Ping(r) => r.bytes(),
Self::SupplyKeypair(r) => r.bytes(),
Self::AddListenSocket(r) => r.bytes(),
}
}
}
@@ -116,6 +128,7 @@ where
match self {
Self::Ping(r) => r.bytes_mut(),
Self::SupplyKeypair(r) => r.bytes_mut(),
Self::AddListenSocket(r) => r.bytes_mut(),
}
}
}

View File

@@ -17,6 +17,13 @@ pub trait Server {
res: &mut super::SupplyKeypairResponse,
) -> anyhow::Result<()>;
fn add_listen_socket(
&mut self,
req: &super::AddListenSocketRequest,
req_fds: &mut VecDeque<OwnedFd>,
res: &mut super::AddListenSocketResponse,
) -> anyhow::Result<()>;
fn dispatch<ReqBuf, ResBuf>(
&mut self,
p: &mut RequestResponsePair<ReqBuf, ResBuf>,
@@ -31,6 +38,9 @@ pub trait Server {
RequestResponsePair::SupplyKeypair((req, res)) => {
self.supply_keypair(req, req_fds, res)
}
RequestResponsePair::AddListenSocket((req, res)) => {
self.add_listen_socket(req, req_fds, res)
}
}
}
@@ -57,6 +67,11 @@ pub trait Server {
res.init();
RequestResponsePair::SupplyKeypair((req, res))
}
RequestRef::AddListenSocket(req) => {
let mut res = res.add_listen_socket_response_from_prefix()?;
res.init();
RequestResponsePair::AddListenSocket((req, res))
}
};
self.dispatch(&mut pair, req_fds)?;

View File

@@ -78,6 +78,8 @@ fn main() -> Result<()> {
Tree::Leaf("Ping Response".to_owned()),
Tree::Leaf("Supply Keypair Request".to_owned()),
Tree::Leaf("Supply Keypair Response".to_owned()),
Tree::Leaf("Add Listen Socket Request".to_owned()),
Tree::Leaf("Add Listen Socket Response".to_owned()),
],
)],
);

View File

@@ -1,6 +1,5 @@
use std::{
io::{BufRead, BufReader},
net::ToSocketAddrs,
os::unix::net::UnixStream,
process::Stdio,
thread::sleep,
@@ -8,19 +7,21 @@ use std::{
};
use anyhow::{bail, Context};
use rosenpass::api::{self, supply_keypair_response_status};
use rosenpass::api::{self, add_listen_socket_response_status, supply_keypair_response_status};
use rosenpass_util::{
file::LoadValueB64,
length_prefix_encoding::{decoder::LengthPrefixDecoder, encoder::LengthPrefixEncoder},
mio::WriteWithFileDescriptors,
zerocopy::ZerocopySliceExt,
};
use rosenpass_util::{mio::WriteWithFileDescriptors, zerocopy::ZerocopySliceExt};
use rustix::fd::AsFd;
use tempfile::TempDir;
use zerocopy::AsBytes;
use rosenpass::protocol::SymKey;
#[test]
fn api_integration_test() -> anyhow::Result<()> {
fn api_integration_api_setup() -> anyhow::Result<()> {
rosenpass_secret_memory::policy::secret_policy_use_only_malloc_secrets();
let dir = TempDir::with_prefix("rosenpass-api-integration-test")?;
@@ -33,17 +34,20 @@ fn api_integration_test() -> anyhow::Result<()> {
}}
}
let peer_a_endpoint = "[::1]:61424";
let peer_a_endpoint = "[::1]:0";
let peer_a_osk = tempfile!("a.osk");
let peer_b_osk = tempfile!("b.osk");
let peer_a_listen = std::net::UdpSocket::bind(peer_a_endpoint)?;
let peer_a_endpoint = format!("{}", peer_a_listen.local_addr()?);
use rosenpass::config;
let peer_a_keypair = config::Keypair::new(tempfile!("a.pk"), tempfile!("a.sk"));
let peer_a = config::Rosenpass {
config_file_path: tempfile!("a.config"),
keypair: Some(peer_a_keypair.clone()),
listen: peer_a_endpoint.to_socket_addrs()?.collect(), // TODO: This could collide by accident
keypair: None,
listen: vec![], // TODO: This could collide by accident
verbosity: config::Verbosity::Verbose,
api: api::config::ApiConfig {
listen_path: vec![tempfile!("a.sock")],
@@ -62,7 +66,7 @@ fn api_integration_test() -> anyhow::Result<()> {
let peer_b_keypair = config::Keypair::new(tempfile!("b.pk"), tempfile!("b.sk"));
let peer_b = config::Rosenpass {
config_file_path: tempfile!("b.config"),
keypair: None,
keypair: Some(peer_b_keypair.clone()),
listen: vec![],
verbosity: config::Verbosity::Verbose,
api: api::config::ApiConfig {
@@ -118,7 +122,7 @@ fn api_integration_test() -> anyhow::Result<()> {
let mut out_b = BufReader::new(proc_b.stdout.context("")?).lines();
// Now connect to the peers
let api_path = peer_b.api.listen_path[0].as_path();
let api_path = peer_a.api.listen_path[0].as_path();
// Wait for the socket to be created
let attempt = 0;
@@ -132,11 +136,34 @@ fn api_integration_test() -> anyhow::Result<()> {
let api = UnixStream::connect(api_path)?;
// Send AddListenSocket request
{
let fd = peer_a_listen.as_fd();
let mut fds = vec![&fd].into();
let mut api = WriteWithFileDescriptors::<UnixStream, _, _, _>::new(&api, &mut fds);
LengthPrefixEncoder::from_message(api::AddListenSocketRequest::new().as_bytes())
.write_all_to_stdio(&mut api)?;
assert!(fds.is_empty(), "Failed to write all file descriptors");
std::mem::forget(peer_a_listen);
}
// Read response
{
let mut decoder = LengthPrefixDecoder::new([0u8; api::MAX_RESPONSE_LEN]);
let res = decoder.read_all_from_stdio(&api)?;
let res = res.zk_parse::<api::AddListenSocketResponse>()?;
assert_eq!(
*res,
api::AddListenSocketResponse::new(add_listen_socket_response_status::OK)
);
}
// Send SupplyKeypairRequest
{
use rustix::fs::{open, Mode, OFlags};
let sk = open(peer_b_keypair.secret_key, OFlags::RDONLY, Mode::empty())?;
let pk = open(peer_b_keypair.public_key, OFlags::RDONLY, Mode::empty())?;
let sk = open(peer_a_keypair.secret_key, OFlags::RDONLY, Mode::empty())?;
let pk = open(peer_a_keypair.public_key, OFlags::RDONLY, Mode::empty())?;
let mut fds = vec![&sk, &pk].into();
let mut api = WriteWithFileDescriptors::<UnixStream, _, _, _>::new(&api, &mut fds);
@@ -147,7 +174,6 @@ fn api_integration_test() -> anyhow::Result<()> {
// Read response
{
//sleep(Duration::from_secs(10));
let mut decoder = LengthPrefixDecoder::new([0u8; api::MAX_RESPONSE_LEN]);
let res = decoder.read_all_from_stdio(api)?;
let res = res.zk_parse::<api::SupplyKeypairResponse>()?;

View File

@@ -7,6 +7,20 @@ macro_rules! repeat {
};
}
#[macro_export]
macro_rules! return_unless {
($cond:expr) => {
if !($cond) {
return;
}
};
($cond:expr, $val:expr) => {
if !($cond) {
return $val;
}
};
}
#[macro_export]
macro_rules! return_if {
($cond:expr) => {

View File

@@ -1,12 +1,10 @@
use anyhow::bail;
use rustix::{
fd::{AsFd, BorrowedFd, FromRawFd, OwnedFd, RawFd},
io::fcntl_dupfd_cloexec,
};
#[cfg(target_os = "linux")]
use rustix::io::DupFlags;
use crate::mem::Forgetting;
use crate::{mem::Forgetting, result::OkExt};
/// Prepare a file descriptor for use in Rust code.
///
@@ -51,7 +49,7 @@ pub fn clone_fd_cloexec<Fd: AsFd>(fd: Fd) -> rustix::io::Result<OwnedFd> {
#[cfg(target_os = "linux")]
pub fn clone_fd_to_cloexec<Fd: AsFd>(fd: Fd, new: &mut OwnedFd) -> rustix::io::Result<()> {
use rustix::io::dup3;
use rustix::io::{dup3, DupFlags};
dup3(fd, new, DupFlags::CLOEXEC)
}
@@ -111,6 +109,85 @@ impl<Fd: AsFd> std::io::Write for FdIo<Fd> {
}
}
pub trait StatExt {
fn is_socket(&self) -> bool;
}
impl StatExt for rustix::fs::Stat {
fn is_socket(&self) -> bool {
use rustix::fs::FileType;
let ft = FileType::from_raw_mode(self.st_mode);
matches!(ft, FileType::Socket)
}
}
pub trait TryStatExt {
type Error;
fn is_socket(&self) -> Result<bool, Self::Error>;
}
impl<T> TryStatExt for T
where
T: AsFd,
{
type Error = rustix::io::Errno;
fn is_socket(&self) -> Result<bool, Self::Error> {
rustix::fs::fstat(self)?.is_socket().ok()
}
}
pub trait GetSocketType {
type Error;
fn socket_type(&self) -> Result<rustix::net::SocketType, Self::Error>;
fn is_datagram_socket(&self) -> Result<bool, Self::Error> {
use rustix::net::SocketType;
matches!(self.socket_type()?, SocketType::DGRAM).ok()
}
}
impl<T> GetSocketType for T
where
T: AsFd,
{
type Error = rustix::io::Errno;
fn socket_type(&self) -> Result<rustix::net::SocketType, Self::Error> {
rustix::net::sockopt::get_socket_type(self)
}
}
#[cfg(target_os = "linux")]
pub trait GetSocketProtocol {
fn socket_protocol(&self) -> Result<Option<rustix::net::Protocol>, rustix::io::Errno>;
fn is_udp_socket(&self) -> Result<bool, rustix::io::Errno> {
self.socket_protocol()?
.map(|p| p == rustix::net::ipproto::UDP)
.unwrap_or(false)
.ok()
}
fn demand_udp_socket(&self) -> anyhow::Result<()> {
match self.socket_protocol() {
Ok(Some(rustix::net::ipproto::UDP)) => Ok(()),
Ok(Some(other_proto)) => {
bail!("Not a udp socket, instead socket protocol is: {other_proto:?}")
}
Ok(None) => bail!("getsockopt() returned empty value"),
Err(errno) => Err(errno.into_stdio_err())?,
}
}
}
#[cfg(target_os = "linux")]
impl<T> GetSocketProtocol for T
where
T: AsFd,
{
fn socket_protocol(&self) -> Result<Option<rustix::net::Protocol>, rustix::io::Errno> {
rustix::net::sockopt::get_socket_protocol(self)
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -1,7 +1,7 @@
use mio::net::{UnixListener, UnixStream};
use rustix::fd::RawFd;
use rustix::fd::{OwnedFd, RawFd};
use crate::fd::claim_fd;
use crate::{fd::claim_fd, result::OkExt};
pub mod interest {
use mio::Interest;
@@ -25,15 +25,20 @@ impl UnixListenerExt for UnixListener {
}
pub trait UnixStreamExt: Sized {
fn from_fd(fd: OwnedFd) -> anyhow::Result<Self>;
fn claim_fd(fd: RawFd) -> anyhow::Result<Self>;
}
impl UnixStreamExt for UnixStream {
fn claim_fd(fd: RawFd) -> anyhow::Result<Self> {
fn from_fd(fd: OwnedFd) -> anyhow::Result<Self> {
use std::os::unix::net::UnixStream as StdUnixStream;
let sock = StdUnixStream::from(claim_fd(fd)?);
let sock = StdUnixStream::from(fd);
sock.set_nonblocking(true)?;
Ok(UnixStream::from_std(sock))
UnixStream::from_std(sock).ok()
}
fn claim_fd(fd: RawFd) -> anyhow::Result<Self> {
Self::from_fd(claim_fd(fd)?)
}
}