From 77760d71df101ed71066f74626cdee506bcdd8c6 Mon Sep 17 00:00:00 2001 From: Karolin Varner Date: Sun, 18 Aug 2024 21:19:44 +0200 Subject: [PATCH] feat(API): Use mio::Token based polling Avoid polling every single IO source to collect events, poll those specific IO sources mio tells us about. --- rosenpass/Cargo.toml | 1 + rosenpass/src/api/api_handler.rs | 13 +- rosenpass/src/api/mio/connection.rs | 24 +- rosenpass/src/api/mio/manager.rs | 105 +++++-- rosenpass/src/app_server.rs | 262 ++++++++++++++---- .../tests/api-integration-tests-api-setup.rs | 1 + rosenpass/tests/integration_test.rs | 9 +- util/src/io.rs | 50 ++++ util/src/lib.rs | 1 + util/src/option.rs | 7 + wireguard-broker/src/brokers/mio_client.rs | 12 +- wireguard-broker/src/brokers/native_unix.rs | 14 +- wireguard-broker/src/lib.rs | 2 + 13 files changed, 403 insertions(+), 98 deletions(-) create mode 100644 util/src/option.rs diff --git a/rosenpass/Cargo.toml b/rosenpass/Cargo.toml index 0ab070b..70c9577 100644 --- a/rosenpass/Cargo.toml +++ b/rosenpass/Cargo.toml @@ -74,6 +74,7 @@ tempfile = { workspace = true } rustix = {workspace = true} [features] +default = ["experiment_api"] experiment_memfd_secret = ["rosenpass-wireguard-broker/experiment_memfd_secret"] experiment_libcrux = ["rosenpass-ciphers/experiment_libcrux"] experiment_api = ["hex-literal", "uds", "command-fds", "rosenpass-util/experiment_file_descriptor_passing", "rosenpass-wireguard-broker/experiment_api"] diff --git a/rosenpass/src/api/api_handler.rs b/rosenpass/src/api/api_handler.rs index 24f036b..b7e041c 100644 --- a/rosenpass/src/api/api_handler.rs +++ b/rosenpass/src/api/api_handler.rs @@ -203,7 +203,7 @@ where mio::net::UdpSocket::from_std(sock).ok() }); - let mut sock = match sock_res { + let sock = match sock_res { Ok(sock) => sock, Err(e) => { log::debug!("Error processing AddListenSocket API request: {e:?}"); @@ -213,16 +213,7 @@ where }; // Register socket - let reg_result = run(|| -> anyhow::Result<()> { - let srv = self.app_server_mut(); - srv.mio_poll.registry().register( - &mut sock, - srv.mio_token_dispenser.dispense(), - mio::Interest::READABLE, - )?; - srv.sockets.push(sock); - Ok(()) - }); + let reg_result = self.app_server_mut().register_listen_socket(sock); if let Err(internal_error) = reg_result { log::warn!("Internal error processing AddListenSocket API request: {internal_error:?}"); diff --git a/rosenpass/src/api/mio/connection.rs b/rosenpass/src/api/mio/connection.rs index 52de31b..a3c7a4e 100644 --- a/rosenpass/src/api/mio/connection.rs +++ b/rosenpass/src/api/mio/connection.rs @@ -55,6 +55,7 @@ struct MioConnectionBuffers { #[derive(Debug)] pub struct MioConnection { io: UnixStream, + mio_token: mio::Token, invalid_read: bool, buffers: Option, api_handler: ApiHandler, @@ -62,11 +63,11 @@ pub struct MioConnection { impl MioConnection { pub fn new(app_server: &mut AppServer, mut io: UnixStream) -> std::io::Result { - app_server.mio_poll.registry().register( - &mut io, - app_server.mio_token_dispenser.dispense(), - MIO_RW, - )?; + let mio_token = app_server.mio_token_dispenser.dispense(); + app_server + .mio_poll + .registry() + .register(&mut io, mio_token, MIO_RW)?; let invalid_read = false; let read_buffer = LengthPrefixDecoder::new(SecretBuffer::new()); @@ -80,6 +81,7 @@ impl MioConnection { let api_state = ApiHandler::new(); Ok(Self { io, + mio_token, invalid_read, buffers, api_handler: api_state, @@ -99,6 +101,10 @@ impl MioConnection { app_server.mio_poll.registry().deregister(&mut self.io)?; Ok(()) } + + pub fn mio_token(&self) -> mio::Token { + self.mio_token + } } pub trait MioConnectionContext { @@ -250,6 +256,14 @@ pub trait MioConnectionContext { }; } } + + fn mio_token(&self) -> mio::Token { + self.mio_connection().mio_token() + } + + fn should_close(&self) -> bool { + self.mio_connection().shoud_close() + } } trait MioConnectionContextPrivate: MioConnectionContext { diff --git a/rosenpass/src/api/mio/manager.rs b/rosenpass/src/api/mio/manager.rs index bbf4265..cd6d276 100644 --- a/rosenpass/src/api/mio/manager.rs +++ b/rosenpass/src/api/mio/manager.rs @@ -1,20 +1,25 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - io, -}; +use std::{borrow::BorrowMut, io}; use mio::net::{UnixListener, UnixStream}; -use rosenpass_util::{io::nonblocking_handle_io_errors, mio::interest::RW as MIO_RW}; +use rosenpass_util::{ + functional::ApplyExt, io::nonblocking_handle_io_errors, mio::interest::RW as MIO_RW, +}; -use crate::app_server::AppServer; +use crate::app_server::{AppServer, AppServerIoSource}; use super::{MioConnection, MioConnectionContext}; #[derive(Default, Debug)] pub struct MioManager { listeners: Vec, - connections: Vec, + connections: Vec>, +} + +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum MioManagerIoSource { + Listener(usize), + Connection(usize), } impl MioManager { @@ -42,18 +47,49 @@ pub trait MioManagerContext { fn add_listener(&mut self, mut listener: UnixListener) -> io::Result<()> { let srv = self.app_server_mut(); - srv.mio_poll.registry().register( - &mut listener, - srv.mio_token_dispenser.dispense(), - MIO_RW, - )?; + let mio_token = srv.mio_token_dispenser.dispense(); + srv.mio_poll + .registry() + .register(&mut listener, mio_token, MIO_RW)?; + let io_source = self + .mio_manager() + .listeners + .len() + .apply(MioManagerIoSource::Listener) + .apply(AppServerIoSource::MioManager); self.mio_manager_mut().listeners.push(listener); + self.app_server_mut() + .register_io_source(mio_token, io_source); + Ok(()) } fn add_connection(&mut self, connection: UnixStream) -> io::Result<()> { let connection = MioConnection::new(self.app_server_mut(), connection)?; - self.mio_manager_mut().connections.push(connection); + let mio_token = connection.mio_token(); + let conns: &mut Vec> = + self.mio_manager_mut().connections.borrow_mut(); + let idx = conns + .iter_mut() + .enumerate() + .find(|(_, slot)| slot.is_some()) + .map(|(idx, _)| idx) + .unwrap_or(conns.len()); + conns.insert(idx, Some(connection)); + let io_source = idx + .apply(MioManagerIoSource::Listener) + .apply(AppServerIoSource::MioManager); + self.app_server_mut() + .register_io_source(mio_token, io_source); + Ok(()) + } + + fn poll_particular(&mut self, io_source: MioManagerIoSource) -> anyhow::Result<()> { + use MioManagerIoSource as S; + match io_source { + S::Listener(idx) => self.accept_from(idx)?, + S::Connection(idx) => self.poll_particular_connection(idx)?, + }; Ok(()) } @@ -87,27 +123,38 @@ pub trait MioManagerContext { } fn poll_connections(&mut self) -> anyhow::Result<()> { - let mut idx = 0; - while idx < self.mio_manager().connections.len() { - if self.mio_manager().connections[idx].shoud_close() { - let conn = self.mio_manager_mut().connections.remove(idx); - if let Err(e) = conn.close(self.app_server_mut()) { - log::warn!("Error while closing API connection {e:?}"); - }; - continue; - } - - MioConnectionFocus::new(self, idx).poll()?; - - idx += 1; + for idx in 0..self.mio_manager().connections.len() { + self.poll_particular_connection(idx)?; } Ok(()) } + + fn poll_particular_connection(&mut self, idx: usize) -> anyhow::Result<()> { + if self.mio_manager().connections[idx].is_none() { + return Ok(()); + } + + let mut conn = MioConnectionFocus::new(self, idx); + conn.poll()?; + + if conn.should_close() { + let conn = self.mio_manager_mut().connections[idx].take().unwrap(); + let mio_token = conn.mio_token(); + if let Err(e) = conn.close(self.app_server_mut()) { + log::warn!("Error while closing API connection {e:?}"); + }; + self.app_server_mut().unregister_io_source(mio_token); + } + + Ok(()) + } } impl MioConnectionContext for MioConnectionFocus<'_, T> { fn mio_connection(&self) -> &MioConnection { - self.ctx.mio_manager().connections[self.conn_idx].borrow() + self.ctx.mio_manager().connections[self.conn_idx] + .as_ref() + .unwrap() } fn app_server(&self) -> &AppServer { @@ -115,7 +162,9 @@ impl MioConnectionContext for MioConnectionFocus< } fn mio_connection_mut(&mut self) -> &mut MioConnection { - self.ctx.mio_manager_mut().connections[self.conn_idx].borrow_mut() + self.ctx.mio_manager_mut().connections[self.conn_idx] + .as_mut() + .unwrap() } fn app_server_mut(&mut self) -> &mut AppServer { diff --git a/rosenpass/src/app_server.rs b/rosenpass/src/app_server.rs index 77526d8..c999d46 100644 --- a/rosenpass/src/app_server.rs +++ b/rosenpass/src/app_server.rs @@ -10,6 +10,12 @@ use rosenpass_secret_memory::Public; use rosenpass_secret_memory::Secret; use rosenpass_util::build::ConstructionSite; use rosenpass_util::file::StoreValueB64; +use rosenpass_util::functional::run; +use rosenpass_util::functional::ApplyExt; +use rosenpass_util::io::IoResultKindHintExt; +use rosenpass_util::io::SubstituteForIoErrorKindExt; +use rosenpass_util::option::SomeExt; +use rosenpass_util::result::OkExt; use rosenpass_wireguard_broker::WireguardBrokerMio; use rosenpass_wireguard_broker::{WireguardBrokerCfg, WG_KEY_LEN}; use zerocopy::AsBytes; @@ -17,7 +23,9 @@ use zerocopy::AsBytes; use std::cell::Cell; use std::collections::HashMap; +use std::collections::VecDeque; use std::fmt::Debug; +use std::io; use std::io::stdout; use std::io::ErrorKind; use std::io::Write; @@ -143,6 +151,17 @@ pub struct AppServerTest { pub termination_handler: Option>, } +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum AppServerIoSource { + Socket(usize), + #[cfg(feature = "experiment_api")] + PskBroker(Public), + #[cfg(feature = "experiment_api")] + MioManager(crate::api::mio::MioManagerIoSource), +} + +const EVENT_CAPACITY: usize = 20; + /// Holds the state of the application, namely the external IO /// /// Responsible for file IO, network IO @@ -152,6 +171,9 @@ pub struct AppServer { pub crypto_site: ConstructionSite, pub sockets: Vec, pub events: mio::Events, + pub short_poll_queue: VecDeque, + pub performed_long_poll: bool, + pub io_source_index: HashMap, pub mio_poll: mio::Poll, pub mio_token_dispenser: MioTokenDispenser, pub brokers: BrokerStore, @@ -521,7 +543,7 @@ impl AppServer { ) -> anyhow::Result { // setup mio let mio_poll = mio::Poll::new()?; - let events = mio::Events::with_capacity(20); + let events = mio::Events::with_capacity(EVENT_CAPACITY); let mut mio_token_dispenser = MioTokenDispenser::default(); // bind each SocketAddr to a socket @@ -596,12 +618,14 @@ impl AppServer { } // register all sockets to mio - for socket in sockets.iter_mut() { - mio_poll.registry().register( - socket, - mio_token_dispenser.dispense(), - Interest::READABLE, - )?; + let mut io_source_index = HashMap::new(); + for (idx, socket) in sockets.iter_mut().enumerate() { + let mio_token = mio_token_dispenser.dispense(); + mio_poll + .registry() + .register(socket, mio_token, Interest::READABLE)?; + let prev = io_source_index.insert(mio_token, AppServerIoSource::Socket(idx)); + assert!(prev.is_none()); } let crypto_site = match keypair { @@ -615,6 +639,9 @@ impl AppServer { verbosity, sockets, events, + short_poll_queue: Default::default(), + performed_long_poll: false, + io_source_index, mio_poll, mio_token_dispenser, brokers: BrokerStore::default(), @@ -646,41 +673,57 @@ impl AppServer { matches!(self.verbosity, Verbosity::Verbose) } + pub fn register_listen_socket(&mut self, mut sock: mio::net::UdpSocket) -> anyhow::Result<()> { + let mio_token = self.mio_token_dispenser.dispense(); + self.mio_poll + .registry() + .register(&mut sock, mio_token, mio::Interest::READABLE)?; + let io_source = self.sockets.len().apply(AppServerIoSource::Socket); + self.sockets.push(sock); + self.register_io_source(mio_token, io_source); + Ok(()) + } + + pub fn register_io_source(&mut self, token: mio::Token, io_source: AppServerIoSource) { + let prev = self.io_source_index.insert(token, io_source); + assert!(prev.is_none()); + } + + pub fn unregister_io_source(&mut self, token: mio::Token) { + let value = self.io_source_index.remove(&token); + assert!(value.is_some(), "Removed IO source that does not exist"); + } + pub fn register_broker( &mut self, broker: Box>, ) -> Result { let ptr = Public::from_slice((self.brokers.store.len() as u64).as_bytes()); - if self.brokers.store.insert(ptr, broker).is_some() { bail!("Broker already registered"); } + + let mio_token = self.mio_token_dispenser.dispense(); + let io_source = ptr.apply(AppServerIoSource::PskBroker); //Register broker self.brokers .store .get_mut(&ptr) .ok_or(anyhow::format_err!("Broker wasn't added to registry"))? - .register( - self.mio_poll.registry(), - self.mio_token_dispenser.dispense(), - )?; + .register(self.mio_poll.registry(), mio_token)?; + self.register_io_source(mio_token, io_source); Ok(BrokerStorePtr(ptr)) } pub fn unregister_broker(&mut self, ptr: BrokerStorePtr) -> Result<()> { - //Unregister broker - self.brokers - .store - .get_mut(&ptr.0) - .ok_or_else(|| anyhow::anyhow!("Broker not found"))? - .unregister(self.mio_poll.registry())?; - - //Remove broker from store - self.brokers + let mut broker = self + .brokers .store .remove(&ptr.0) - .ok_or_else(|| anyhow::anyhow!("Broker not found"))?; + .context("Broker not found")?; + self.unregister_io_source(broker.mio_token().unwrap()); + broker.unregister(self.mio_poll.registry())?; Ok(()) } @@ -998,22 +1041,33 @@ impl AppServer { // readiness event seems to be good enough™ for now. // only poll if we drained all sockets before - if self.all_sockets_drained { - //Non blocked polling - self.mio_poll - .poll(&mut self.events, Some(Duration::from_secs(0)))?; - - if self.events.iter().peekable().peek().is_none() { - // if there are no events, then add to blocking poll count - self.blocking_polls_count += 1; - //Execute blocking poll - self.mio_poll.poll(&mut self.events, Some(timeout))?; - } else { - self.non_blocking_polls_count += 1; + run(|| -> anyhow::Result<()> { + if !self.all_sockets_drained || !self.short_poll_queue.is_empty() { + self.unpolled_count += 1; + return Ok(()); } - } else { - self.unpolled_count += 1; - } + + self.perform_mio_poll_and_register_events(Duration::from_secs(0))?; // Non-blocking poll + if !self.short_poll_queue.is_empty() { + // Got some events in non-blocking mode + self.non_blocking_polls_count += 1; + return Ok(()); + } + + if !self.performed_long_poll { + // pass – go perform a full long poll before we enter blocking poll mode + // to make sure our experimental short poll feature did not miss any events + // due to being buggy. + return Ok(()); + } + + // Perform and register blocking poll + self.blocking_polls_count += 1; + self.perform_mio_poll_and_register_events(timeout)?; + self.performed_long_poll = false; + + Ok(()) + })?; if let Some(AppServerTest { enable_dos_permanently: true, @@ -1048,26 +1102,58 @@ impl AppServer { } } + // Focused polling – i.e. actually using mio::Token – is experimental for now. + // The reason for this is that we need to figure out how to integrate load detection + // and focused polling for one. Mio event-based polling also does not play nice with + // the current function signature and its reentrant design which is focused around receiving UDP socket packages + // for processing by the crypto protocol server. + // Besides that, there are also some parts of the code which intentionally block + // despite available data. This is the correct behavior; e.g. api::mio::Connection blocks + // further reads from its unix socket until the write buffer is flushed. In other words + // the connection handler makes sure that there is a buffer to put the response in while + // before reading further request. + // The potential problem with this behavior is that we end up ignoring instructions from + // epoll() to read from the particular sockets, so epoll will return information about that + // particular – blocked – file descriptor every call. We have only so many event slots and + // in theory, the event array could fill up entirely with intentionally blocked sockets. + // We need to figure out how to deal with this situation. + // Mio uses uses epoll in level-triggered mode, so we could handle taint-tracking for ignored + // sockets ourselves. The facilities are available in epoll and Mio, but we need to figure out how mio uses those + // facilities and how we can integrate them here. + // This will involve rewriting a lot of IO code and we should probably have integration + // tests before we approach that. + // + // This hybrid approach is not without merit though; the short poll implementation covers + // all our IO sources, so under contention, rosenpass should generally not hit the long + // poll mode below. We keep short polling and calling epoll() in non-blocking mode (timeout + // of zero) until we run out of IO events processed. Then, just before we would perform a + // blocking poll, we go through all available IO sources to see if we missed anything. + { + while let Some(ev) = self.short_poll_queue.pop_front() { + if let Some(v) = self.try_recv_from_mio_token(buf, ev.token())? { + return Ok(Some(v)); + } + } + } + // drain all sockets let mut would_block_count = 0; - for (sock_no, socket) in self.sockets.iter_mut().enumerate() { - match socket.recv_from(buf) { - Ok((n, addr)) => { + for sock_no in 0..self.sockets.len() { + match self + .try_recv_from_listen_socket(buf, sock_no) + .io_err_kind_hint() + { + Ok(None) => continue, + Ok(Some(v)) => { // at least one socket was not drained... self.all_sockets_drained = false; - return Ok(Some(( - n, - Endpoint::SocketBoundAddress(SocketBoundEndpoint::new( - SocketPtr(sock_no), - addr, - )), - ))); + return Ok(Some(v)); } - Err(e) if e.kind() == ErrorKind::WouldBlock => { + Err((_, ErrorKind::WouldBlock)) => { would_block_count += 1; } // TODO if one socket continuously returns an error, then we never poll, thus we never wait for a timeout, thus we have a spin-lock - Err(e) => return Err(e.into()), + Err((e, _)) => return Err(e)?, } } @@ -1087,9 +1173,87 @@ impl AppServer { MioManagerFocus(self).poll()?; } + self.performed_long_poll = true; + Ok(None) } + fn perform_mio_poll_and_register_events(&mut self, timeout: Duration) -> io::Result<()> { + self.mio_poll.poll(&mut self.events, Some(timeout))?; + // Fill the short poll buffer with the acquired events + self.events + .iter() + .cloned() + .for_each(|v| self.short_poll_queue.push_back(v)); + Ok(()) + } + + fn try_recv_from_mio_token( + &mut self, + buf: &mut [u8], + token: mio::Token, + ) -> anyhow::Result> { + let io_source = match self.io_source_index.get(&token) { + Some(io_source) => *io_source, + None => { + log::warn!("No IO source assiociated with mio token ({token:?}). Polling using mio tokens directly is an experimental feature and IO handler should recover when all available io sources are polled. This is a developer error. Please report it."); + return Ok(None); + } + }; + + self.try_recv_from_io_source(buf, io_source) + } + + fn try_recv_from_io_source( + &mut self, + buf: &mut [u8], + io_source: AppServerIoSource, + ) -> anyhow::Result> { + use crate::api::mio::MioManagerContext; + + match io_source { + AppServerIoSource::Socket(idx) => self + .try_recv_from_listen_socket(buf, idx) + .substitute_for_ioerr_wouldblock(None)? + .ok(), + + #[cfg(feature = "experiment_api")] + AppServerIoSource::PskBroker(key) => self + .brokers + .store + .get_mut(&key) + .with_context(|| format!("No PSK broker under key {key:?}"))? + .process_poll() + .map(|_| None), + + #[cfg(feature = "experiment_api")] + AppServerIoSource::MioManager(mmio_src) => MioManagerFocus(self) + .poll_particular(mmio_src) + .map(|_| None), + } + } + + fn try_recv_from_listen_socket( + &mut self, + buf: &mut [u8], + idx: usize, + ) -> io::Result> { + use std::io::ErrorKind as K; + let (n, addr) = loop { + match self.sockets[idx].recv_from(buf).io_err_kind_hint() { + Ok(v) => break v, + Err((_, K::Interrupted)) => continue, + Err((e, _)) => return Err(e)?, + } + }; + SocketPtr(idx) + .apply(|sp| SocketBoundEndpoint::new(sp, addr)) + .apply(Endpoint::SocketBoundAddress) + .apply(|ep| (n, ep)) + .some() + .ok() + } + #[cfg(feature = "experiment_api")] pub fn add_api_connection(&mut self, connection: mio::net::UnixStream) -> std::io::Result<()> { use crate::api::mio::MioManagerContext; diff --git a/rosenpass/tests/api-integration-tests-api-setup.rs b/rosenpass/tests/api-integration-tests-api-setup.rs index b45f130..72eebbe 100644 --- a/rosenpass/tests/api-integration-tests-api-setup.rs +++ b/rosenpass/tests/api-integration-tests-api-setup.rs @@ -141,6 +141,7 @@ fn api_integration_api_setup() -> anyhow::Result<()> { peer_b.config_file_path.to_str().context("")?, ]) .stdin(Stdio::null()) + .stderr(Stdio::null()) .stdout(Stdio::piped()) .spawn()?; diff --git a/rosenpass/tests/integration_test.rs b/rosenpass/tests/integration_test.rs index 989adb0..623a1ba 100644 --- a/rosenpass/tests/integration_test.rs +++ b/rosenpass/tests/integration_test.rs @@ -293,6 +293,7 @@ struct MockBrokerInner { #[derive(Debug, Default)] struct MockBroker { inner: Arc>, + mio_token: Option, } impl WireguardBrokerMio for MockBroker { @@ -301,8 +302,9 @@ impl WireguardBrokerMio for MockBroker { fn register( &mut self, _registry: &mio::Registry, - _token: mio::Token, + token: mio::Token, ) -> Result<(), Self::MioError> { + self.mio_token = Some(token); Ok(()) } @@ -311,8 +313,13 @@ impl WireguardBrokerMio for MockBroker { } fn unregister(&mut self, _registry: &mio::Registry) -> Result<(), Self::MioError> { + self.mio_token = None; Ok(()) } + + fn mio_token(&self) -> Option { + self.mio_token + } } impl rosenpass_wireguard_broker::WireGuardBroker for MockBroker { diff --git a/util/src/io.rs b/util/src/io.rs index 31ab95f..d5ee812 100644 --- a/util/src/io.rs +++ b/util/src/io.rs @@ -52,6 +52,56 @@ impl TryIoResultKindHintExt for Result { } } +pub trait SubstituteForIoErrorKindExt: Sized { + type Error; + fn substitute_for_ioerr_kind_with T>( + self, + kind: io::ErrorKind, + f: F, + ) -> Result; + fn substitute_for_ioerr_kind(self, kind: io::ErrorKind, v: T) -> Result { + self.substitute_for_ioerr_kind_with(kind, || v) + } + + fn substitute_for_ioerr_interrupted_with T>( + self, + f: F, + ) -> Result { + self.substitute_for_ioerr_kind_with(io::ErrorKind::Interrupted, f) + } + + fn substitute_for_ioerr_interrupted(self, v: T) -> Result { + self.substitute_for_ioerr_interrupted_with(|| v) + } + + fn substitute_for_ioerr_wouldblock_with T>( + self, + f: F, + ) -> Result { + self.substitute_for_ioerr_kind_with(io::ErrorKind::WouldBlock, f) + } + + fn substitute_for_ioerr_wouldblock(self, v: T) -> Result { + self.substitute_for_ioerr_wouldblock_with(|| v) + } +} + +impl SubstituteForIoErrorKindExt for Result { + type Error = E; + + fn substitute_for_ioerr_kind_with T>( + self, + kind: io::ErrorKind, + f: F, + ) -> Result { + match self.try_io_err_kind_hint() { + Ok(v) => Ok(v), + Err((_, Some(k))) if k == kind => Ok(f()), + Err((e, _)) => Err(e), + } + } +} + /// Automatically handles `std::io::ErrorKind::Interrupted`. /// /// - If there is no error (i.e. on `Ok(r)`), the function will return `Ok(Some(r))` diff --git a/util/src/lib.rs b/util/src/lib.rs index 7eb7c6d..bd6889f 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -10,6 +10,7 @@ pub mod io; pub mod length_prefix_encoding; pub mod mem; pub mod mio; +pub mod option; pub mod ord; pub mod result; pub mod time; diff --git a/util/src/option.rs b/util/src/option.rs new file mode 100644 index 0000000..6b743df --- /dev/null +++ b/util/src/option.rs @@ -0,0 +1,7 @@ +pub trait SomeExt: Sized { + fn some(self) -> Option { + Some(self) + } +} + +impl SomeExt for T {} diff --git a/wireguard-broker/src/brokers/mio_client.rs b/wireguard-broker/src/brokers/mio_client.rs index 8af7fcf..8cbbf7f 100644 --- a/wireguard-broker/src/brokers/mio_client.rs +++ b/wireguard-broker/src/brokers/mio_client.rs @@ -16,6 +16,7 @@ use crate::{SerializedBrokerConfig, WireGuardBroker, WireguardBrokerMio}; #[derive(Debug)] pub struct MioBrokerClient { inner: BrokerClient, + mio_token: Option, } #[derive(Debug)] @@ -59,7 +60,10 @@ impl MioBrokerClient { write_buffer, }; let inner = BrokerClient::new(io); - Self { inner } + Self { + inner, + mio_token: None, + } } fn poll(&mut self) -> anyhow::Result<()> { @@ -104,6 +108,7 @@ impl WireguardBrokerMio for MioBrokerClient { registry: &mio::Registry, token: mio::Token, ) -> Result<(), Self::MioError> { + self.mio_token = Some(token); registry.register( &mut self.inner.io_mut().socket, token, @@ -118,9 +123,14 @@ impl WireguardBrokerMio for MioBrokerClient { } fn unregister(&mut self, registry: &mio::Registry) -> Result<(), Self::MioError> { + self.mio_token = None; registry.deregister(&mut self.inner.io_mut().socket)?; Ok(()) } + + fn mio_token(&self) -> Option { + self.mio_token + } } impl BrokerClientIo for MioBrokerClientIo { diff --git a/wireguard-broker/src/brokers/native_unix.rs b/wireguard-broker/src/brokers/native_unix.rs index b56413f..1dad5df 100644 --- a/wireguard-broker/src/brokers/native_unix.rs +++ b/wireguard-broker/src/brokers/native_unix.rs @@ -16,7 +16,9 @@ const MAX_B64_KEY_SIZE: usize = WG_KEY_LEN * 5 / 3; const MAX_B64_PEER_ID_SIZE: usize = WG_PEER_LEN * 5 / 3; #[derive(Debug)] -pub struct NativeUnixBroker {} +pub struct NativeUnixBroker { + mio_token: Option, +} impl Default for NativeUnixBroker { fn default() -> Self { @@ -26,7 +28,7 @@ impl Default for NativeUnixBroker { impl NativeUnixBroker { pub fn new() -> Self { - Self {} + Self { mio_token: None } } } @@ -88,8 +90,9 @@ impl WireguardBrokerMio for NativeUnixBroker { fn register( &mut self, _registry: &mio::Registry, - _token: mio::Token, + token: mio::Token, ) -> Result<(), Self::MioError> { + self.mio_token = Some(token); Ok(()) } @@ -98,8 +101,13 @@ impl WireguardBrokerMio for NativeUnixBroker { } fn unregister(&mut self, _registry: &mio::Registry) -> Result<(), Self::MioError> { + self.mio_token = None; Ok(()) } + + fn mio_token(&self) -> Option { + self.mio_token + } } #[derive(Debug, Builder)] diff --git a/wireguard-broker/src/lib.rs b/wireguard-broker/src/lib.rs index 3a206ba..ecb47bd 100644 --- a/wireguard-broker/src/lib.rs +++ b/wireguard-broker/src/lib.rs @@ -28,6 +28,8 @@ pub trait WireguardBrokerMio: WireGuardBroker { registry: &mio::Registry, token: mio::Token, ) -> Result<(), Self::MioError>; + fn mio_token(&self) -> Option; + /// Run after a mio::poll operation fn process_poll(&mut self) -> Result<(), Self::MioError>;