diff --git a/rosenpass/src/app_server.rs b/rosenpass/src/app_server.rs index 90a4b11..ecdfdab 100644 --- a/rosenpass/src/app_server.rs +++ b/rosenpass/src/app_server.rs @@ -8,6 +8,7 @@ use mio::Interest; use mio::Token; use rosenpass_secret_memory::Public; use rosenpass_secret_memory::Secret; +use rosenpass_util::build::ConstructionSite; use rosenpass_util::file::StoreValueB64; use rosenpass_wireguard_broker::WireguardBrokerMio; use rosenpass_wireguard_broker::{WireguardBrokerCfg, WG_KEY_LEN}; @@ -31,6 +32,7 @@ use std::slice; use std::time::Duration; use std::time::Instant; +use crate::protocol::BuildCryptoServer; use crate::protocol::HostIdentification; use crate::{ config::Verbosity, @@ -147,7 +149,7 @@ pub struct AppServerTest { // TODO add user control via unix domain socket and stdin/stdout #[derive(Debug)] pub struct AppServer { - pub crypt: Option, + pub crypto_site: ConstructionSite, pub sockets: Vec, pub events: mio::Events, pub mio_poll: mio::Poll, @@ -606,7 +608,7 @@ impl AppServer { // TODO use mio::net::UnixStream together with std::os::unix::net::UnixStream for Linux Ok(Self { - crypt: Some(CryptoServer::new(sk, pk)), + crypto_site: ConstructionSite::from_product(CryptoServer::new(sk, pk)), peers: Vec::new(), verbosity, sockets, @@ -627,14 +629,14 @@ impl AppServer { } pub fn crypto_server(&self) -> anyhow::Result<&CryptoServer> { - self.crypt - .as_ref() + self.crypto_site + .product_ref() .context("Cryptography handler not initialized") } pub fn crypto_server_mut(&mut self) -> anyhow::Result<&mut CryptoServer> { - self.crypt - .as_mut() + self.crypto_site + .product_mut() .context("Cryptography handler not initialized") } @@ -688,8 +690,13 @@ impl AppServer { broker_peer: Option, hostname: Option, ) -> anyhow::Result { - let PeerPtr(pn) = self.crypto_server_mut()?.add_peer(psk, pk)?; + let PeerPtr(pn) = match &mut self.crypto_site { + ConstructionSite::Void => bail!("Crypto server construction site is void"), + ConstructionSite::Builder(builder) => builder.add_peer(psk, pk), + ConstructionSite::Product(srv) => srv.add_peer(psk, pk)?, + }; assert!(pn == self.peers.len()); + let initial_endpoint = hostname .map(Endpoint::discovery_from_hostname) .transpose()?; @@ -774,16 +781,31 @@ impl AppServer { } } - match self.poll(&mut *rx)? { - #[allow(clippy::redundant_closure_call)] - SendInitiation(peer) => tx_maybe_with!(peer, || self + enum CryptoSrv { + Avail, + Missing, + } + + let poll_result = self.poll(&mut *rx)?; + let have_crypto = match self.crypto_site.is_available() { + true => CryptoSrv::Avail, + false => CryptoSrv::Missing, + }; + + #[allow(clippy::redundant_closure_call)] + match (have_crypto, poll_result) { + (CryptoSrv::Missing, SendInitiation(_)) => {} + (CryptoSrv::Avail, SendInitiation(peer)) => tx_maybe_with!(peer, || self .crypto_server_mut()? .initiate_handshake(peer.lower(), &mut *tx))?, - #[allow(clippy::redundant_closure_call)] - SendRetransmission(peer) => tx_maybe_with!(peer, || self + + (CryptoSrv::Missing, SendRetransmission(_)) => {} + (CryptoSrv::Avail, SendRetransmission(peer)) => tx_maybe_with!(peer, || self .crypto_server_mut()? .retransmit_handshake(peer.lower(), &mut *tx))?, - DeleteKey(peer) => { + + (CryptoSrv::Missing, DeleteKey(_)) => {} + (CryptoSrv::Avail, DeleteKey(peer)) => { self.output_key(peer, Stale, &SymKey::random())?; // There was a loss of connection apparently; restart host discovery @@ -797,7 +819,8 @@ impl AppServer { ); } - ReceivedMessage(len, endpoint) => { + (CryptoSrv::Missing, ReceivedMessage(_, _)) => {} + (CryptoSrv::Avail, ReceivedMessage(len, endpoint)) => { let msg_result = match self.under_load { DoSOperation::UnderLoad => { self.handle_msg_under_load(&endpoint, &rx[..len], &mut *tx) @@ -910,17 +933,32 @@ impl AppServer { pub fn poll(&mut self, rx_buf: &mut [u8]) -> anyhow::Result { use crate::protocol::PollResult as C; use AppPollResult as A; - loop { - return Ok(match self.crypto_server_mut()?.poll()? { - C::DeleteKey(PeerPtr(no)) => A::DeleteKey(AppPeerPtr(no)), - C::SendInitiation(PeerPtr(no)) => A::SendInitiation(AppPeerPtr(no)), - C::SendRetransmission(PeerPtr(no)) => A::SendRetransmission(AppPeerPtr(no)), - C::Sleep(timeout) => match self.try_recv(rx_buf, timeout)? { - Some((len, addr)) => A::ReceivedMessage(len, addr), - None => continue, - }, - }); - } + let res = loop { + // Call CryptoServer's poll (if available) + let crypto_poll = self + .crypto_site + .product_mut() + .map(|crypto| crypto.poll()) + .transpose()?; + + // Map crypto server's poll result to our poll result + let io_poll_timeout = match crypto_poll { + Some(C::DeleteKey(PeerPtr(no))) => break A::DeleteKey(AppPeerPtr(no)), + Some(C::SendInitiation(PeerPtr(no))) => break A::SendInitiation(AppPeerPtr(no)), + Some(C::SendRetransmission(PeerPtr(no))) => { + break A::SendRetransmission(AppPeerPtr(no)) + } + Some(C::Sleep(timeout)) => timeout, // No event from crypto-server, do IO + None => crate::protocol::UNENDING, // Crypto server is uninitialized, do IO + }; + + // Perform IO (look for a message) + if let Some((len, addr)) = self.try_recv(rx_buf, io_poll_timeout)? { + break A::ReceivedMessage(len, addr); + } + }; + + Ok(res) } /// Tries to receive a new message diff --git a/rosenpass/src/protocol/build_crypto_server.rs b/rosenpass/src/protocol/build_crypto_server.rs new file mode 100644 index 0000000..224fbf0 --- /dev/null +++ b/rosenpass/src/protocol/build_crypto_server.rs @@ -0,0 +1,127 @@ +use rosenpass_util::{ + build::Build, + mem::{DiscardResultExt, SwapWithDefaultExt}, + result::ensure_or, +}; +use thiserror::Error; + +use super::{CryptoServer, PeerPtr, SPk, SSk, SymKey}; + +#[derive(Debug, Clone)] +pub struct Keypair { + pub sk: SSk, + pub pk: SPk, +} + +// TODO: We need a named tuple derive +impl Keypair { + pub fn new(sk: SSk, pk: SPk) -> Self { + Self { sk, pk } + } + + pub fn zero() -> Self { + Self::new(SSk::zero(), SPk::zero()) + } + + pub fn random() -> Self { + Self::new(SSk::random(), SPk::random()) + } + + pub fn from_parts(parts: (SSk, SPk)) -> Self { + Self::new(parts.0, parts.1) + } + + pub fn into_parts(self) -> (SSk, SPk) { + (self.sk, self.pk) + } +} + +#[derive(Error, Debug)] +#[error("PSK already set in BuildCryptoServer")] +pub struct PskAlreadySet; + +#[derive(Error, Debug)] +#[error("Keypair already set in BuildCryptoServer")] +pub struct KeypairAlreadySet; + +#[derive(Error, Debug)] +#[error("Can not construct CryptoServer: Missing keypair")] +pub struct MissingKeypair; + +#[derive(Debug, Default)] +pub struct BuildCryptoServer { + pub keypair: Option, + pub peers: Vec, +} + +impl Build for BuildCryptoServer { + type Error = anyhow::Error; + + fn build(self) -> Result { + let Some(Keypair { sk, pk }) = self.keypair else { + return Err(MissingKeypair)?; + }; + + let mut srv = CryptoServer::new(sk, pk); + + for (idx, PeerParams { psk, pk }) in self.peers.into_iter().enumerate() { + let PeerPtr(idx2) = srv.add_peer(psk, pk)?; + assert!(idx == idx2, "Peer id changed during CryptoServer construction from {idx} to {idx2}. This is a developer error.") + } + + Ok(srv) + } +} + +#[derive(Debug)] +pub struct PeerParams { + pub psk: Option, + pub pk: SPk, +} + +impl BuildCryptoServer { + pub fn new(keypair: Option, peers: Vec) -> Self { + Self { keypair, peers } + } + + pub fn empty() -> Self { + Self::new(None, Vec::new()) + } + + pub fn from_parts(parts: (Option, Vec)) -> Self { + Self { + keypair: parts.0, + peers: parts.1, + } + } + + pub fn take_parts(&mut self) -> (Option, Vec) { + (self.keypair.take(), self.peers.swap_with_default()) + } + + pub fn into_parts(mut self) -> (Option, Vec) { + self.take_parts() + } + + pub fn with_keypair(&mut self, keypair: Keypair) -> Result<&mut Self, KeypairAlreadySet> { + ensure_or(self.keypair.is_none(), KeypairAlreadySet)?; + self.keypair.insert(keypair).discard_result(); + Ok(self) + } + + pub fn with_added_peer(&mut self, psk: Option, pk: SPk) -> &mut Self { + // TODO: Check here already whether peer was already added + self.peers.push(PeerParams { psk, pk }); + self + } + + pub fn add_peer(&mut self, psk: Option, pk: SPk) -> PeerPtr { + let id = PeerPtr(self.peers.len()); + self.with_added_peer(psk, pk); + id + } + + pub fn emancipate(&mut self) -> Self { + Self::from_parts(self.take_parts()) + } +} diff --git a/rosenpass/src/protocol/mod.rs b/rosenpass/src/protocol/mod.rs new file mode 100644 index 0000000..68d48bb --- /dev/null +++ b/rosenpass/src/protocol/mod.rs @@ -0,0 +1,6 @@ +mod build_crypto_server; +#[allow(clippy::module_inception)] +mod protocol; + +pub use build_crypto_server::*; +pub use protocol::*; diff --git a/rosenpass/src/protocol.rs b/rosenpass/src/protocol/protocol.rs similarity index 100% rename from rosenpass/src/protocol.rs rename to rosenpass/src/protocol/protocol.rs diff --git a/util/src/build.rs b/util/src/build.rs new file mode 100644 index 0000000..b268c4c --- /dev/null +++ b/util/src/build.rs @@ -0,0 +1,169 @@ +use crate::{ + functional::ApplyExt, + mem::{SwapWithDefaultExt, SwapWithExt}, +}; + +#[derive(thiserror::Error, Debug)] +pub enum ConstructionSiteErectError { + #[error("Construction site is void")] + IsVoid, + #[error("Construction is already built")] + AlreadyBuilt, + #[error("Other construction site error {0:?}")] + Other(#[from] E), +} + +pub trait Build: Sized { + type Error; + fn build(self) -> Result; +} + +#[derive(Debug)] +pub enum ConstructionSite +where + Builder: Build, +{ + Void, + Builder(Builder), + Product(T), +} + +impl Default for ConstructionSite +where + Builder: Build, +{ + fn default() -> Self { + Self::Void + } +} + +impl ConstructionSite +where + Builder: Build, +{ + pub fn void() -> Self { + Self::Void + } + + pub fn new(builder: Builder) -> Self { + Self::Builder(builder) + } + + pub fn from_product(value: T) -> Self { + Self::Product(value) + } + + pub fn take(&mut self) -> Self { + self.swap_with_default() + } + + pub fn modify_taken_with_return(&mut self, f: F) -> R + where + F: FnOnce(Self) -> (Self, R), + { + let (site, res) = self.take().apply(f); + self.swap_with(site); + res + } + + pub fn modify_taken(&mut self, f: F) + where + F: FnOnce(Self) -> Self, + { + self.take().apply(f).swap_with_mut(self) + } + + #[allow(clippy::result_unit_err)] + pub fn erect(&mut self) -> Result<(), ConstructionSiteErectError> { + self.modify_taken_with_return(|site| { + let builder = match site { + site @ Self::Void => return (site, Err(ConstructionSiteErectError::IsVoid)), + site @ Self::Product(_) => { + return (site, Err(ConstructionSiteErectError::AlreadyBuilt)) + } + Self::Builder(builder) => builder, + }; + + let product = match builder.build() { + Err(e) => { + return (Self::void(), Err(ConstructionSiteErectError::Other(e))); + } + Ok(p) => p, + }; + + (Self::from_product(product), Ok(())) + }) + } + + /// Returns `true` if the construction site is [`Void`]. + /// + /// [`Void`]: ConstructionSite::Void + #[must_use] + pub fn is_void(&self) -> bool { + matches!(self, Self::Void) + } + + /// Returns `true` if the construction site is [`InProgress`]. + /// + /// [`InProgress`]: ConstructionSite::InProgress + #[must_use] + pub fn in_progess(&self) -> bool { + matches!(self, Self::Builder(..)) + } + + /// Returns `true` if the construction site is [`Done`]. + /// + /// [`Done`]: ConstructionSite::Done + #[must_use] + pub fn is_available(&self) -> bool { + matches!(self, Self::Product(..)) + } + + pub fn into_builder(self) -> Option { + use ConstructionSite as S; + match self { + S::Builder(v) => Some(v), + _ => None, + } + } + + pub fn builder_ref(&self) -> Option<&Builder> { + use ConstructionSite as S; + match self { + S::Builder(v) => Some(v), + _ => None, + } + } + + pub fn builder_mut(&mut self) -> Option<&mut Builder> { + use ConstructionSite as S; + match self { + S::Builder(v) => Some(v), + _ => None, + } + } + + pub fn into_product(self) -> Option { + use ConstructionSite as S; + match self { + S::Product(v) => Some(v), + _ => None, + } + } + + pub fn product_ref(&self) -> Option<&T> { + use ConstructionSite as S; + match self { + S::Product(v) => Some(v), + _ => None, + } + } + + pub fn product_mut(&mut self) -> Option<&mut T> { + use ConstructionSite as S; + match self { + S::Product(v) => Some(v), + _ => None, + } + } +} diff --git a/util/src/functional.rs b/util/src/functional.rs index 10c5f94..54ebe8c 100644 --- a/util/src/functional.rs +++ b/util/src/functional.rs @@ -6,6 +6,32 @@ where v } +pub trait MutatingExt { + fn mutating(self, f: F) -> Self + where + F: Fn(&mut Self); + fn mutating_mut(&mut self, f: F) -> &mut Self + where + F: Fn(&mut Self); +} + +impl MutatingExt for T { + fn mutating(self, f: F) -> Self + where + F: Fn(&mut Self), + { + mutating(self, f) + } + + fn mutating_mut(&mut self, f: F) -> &mut Self + where + F: Fn(&mut Self), + { + f(self); + self + } +} + pub fn sideeffect(v: T, f: F) -> T where F: Fn(&T), @@ -14,6 +40,58 @@ where v } +pub trait SideffectExt { + fn sideeffect(self, f: F) -> Self + where + F: Fn(&Self); + fn sideeffect_ref(&self, f: F) -> &Self + where + F: Fn(&Self); + fn sideeffect_mut(&mut self, f: F) -> &mut Self + where + F: Fn(&Self); +} + +impl SideffectExt for T { + fn sideeffect(self, f: F) -> Self + where + F: Fn(&Self), + { + sideeffect(self, f) + } + + fn sideeffect_ref(&self, f: F) -> &Self + where + F: Fn(&Self), + { + f(self); + self + } + + fn sideeffect_mut(&mut self, f: F) -> &mut Self + where + F: Fn(&Self), + { + f(self); + self + } +} + pub fn run R>(f: F) -> R { f() } + +pub trait ApplyExt: Sized { + fn apply(self, f: F) -> R + where + F: FnOnce(Self) -> R; +} + +impl ApplyExt for T { + fn apply(self, f: F) -> R + where + F: FnOnce(Self) -> R, + { + f(self) + } +} diff --git a/util/src/lib.rs b/util/src/lib.rs index c4d3f31..eeaa920 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -1,6 +1,7 @@ #![recursion_limit = "256"] pub mod b64; +pub mod build; pub mod fd; pub mod file; pub mod functional; diff --git a/util/src/mem.rs b/util/src/mem.rs index 620e3d8..2cc5785 100644 --- a/util/src/mem.rs +++ b/util/src/mem.rs @@ -92,3 +92,47 @@ impl Drop for Forgetting { forget(value) } } + +pub trait DiscardResultExt { + fn discard_result(self); +} + +impl DiscardResultExt for T { + fn discard_result(self) {} +} + +pub trait ForgetExt { + fn forget(self); +} + +impl ForgetExt for T { + fn forget(self) { + std::mem::forget(self) + } +} + +pub trait SwapWithExt { + fn swap_with(&mut self, other: Self) -> Self; + fn swap_with_mut(&mut self, other: &mut Self); +} + +impl SwapWithExt for T { + fn swap_with(&mut self, mut other: Self) -> Self { + self.swap_with_mut(&mut other); + other + } + + fn swap_with_mut(&mut self, other: &mut Self) { + std::mem::swap(self, other) + } +} + +pub trait SwapWithDefaultExt { + fn swap_with_default(&mut self) -> Self; +} + +impl SwapWithDefaultExt for T { + fn swap_with_default(&mut self) -> Self { + self.swap_with(Self::default()) + } +} diff --git a/util/src/result.rs b/util/src/result.rs index 8493857..041adbd 100644 --- a/util/src/result.rs +++ b/util/src/result.rs @@ -8,6 +8,16 @@ macro_rules! attempt { }; } +pub trait OkExt: Sized { + fn ok(self) -> Result; +} + +impl OkExt for T { + fn ok(self) -> Result { + Ok(self) + } +} + /// Trait for container types that guarantee successful unwrapping. /// /// The `.guaranteed()` function can be used over unwrap to show that