diff --git a/Cargo.lock b/Cargo.lock index b4f59a0..0e7fd41 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -179,7 +179,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn", + "syn 2.0.39", "which", ] @@ -216,6 +216,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "bytes" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" + [[package]] name = "cast" version = "0.3.0" @@ -328,7 +334,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn", + "syn 2.0.39", ] [[package]] @@ -361,6 +367,16 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" +[[package]] +name = "command-fds" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f190f3c954f7bca3c6296d0ec561c739bdbe6c7e990294ed168d415f6e1b5b01" +dependencies = [ + "nix", + "thiserror", +] + [[package]] name = "core2" version = "0.4.0" @@ -448,6 +464,41 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "darling" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f2c43f534ea4b0b049015d00269734195e6d3f0f6635cb692251aca6f9f8b3c" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e91455b86830a1c21799d94524df0845183fa55bafd9aa137b01c7d1065fa36" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 1.0.109", +] + +[[package]] +name = "darling_macro" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29b5acf0dea37a7f66f7b25d2c5e93fd46f8f6968b1a5d7a3e02e97768afc95a" +dependencies = [ + "darling_core", + "quote", + "syn 1.0.109", +] + [[package]] name = "dary_heap" version = "0.3.6" @@ -462,7 +513,38 @@ checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.39", +] + +[[package]] +name = "derive_builder" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d13202debe11181040ae9063d739fa32cfcaaebe2275fe387703460ae2365b30" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66e616858f6187ed828df7c64a6d71720d83767a7f19740b2d1b6fe6327b36e5" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "derive_builder_macro" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "58a94ace95092c5acb1e97a7e846b310cfbd499652f72297da7493f618a98d73" +dependencies = [ + "derive_builder_core", + "syn 1.0.109", ] [[package]] @@ -514,7 +596,7 @@ checksum = "d4029edd3e734da6fe05b6cd7bd2960760a616bd2ddd0d59a0124746d6272af0" dependencies = [ "cfg-if", "libc", - "redox_syscall", + "redox_syscall 0.3.5", "windows-sys 0.48.0", ] @@ -528,6 +610,12 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -623,6 +711,12 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "0.5.0" @@ -783,6 +877,16 @@ version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" +[[package]] +name = "lock_api" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.20" @@ -837,6 +941,42 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "neli" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1805440578ced23f85145d00825c0a831e43c587132a90e100552172543ae30" +dependencies = [ + "byteorder", + "libc", + "log", + "neli-proc-macros", +] + +[[package]] +name = "neli-proc-macros" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c168194d373b1e134786274020dae7fc5513d565ea2ebb9bc9ff17ffb69106d4" +dependencies = [ + "either", + "proc-macro2", + "quote", + "serde", + "syn 1.0.109", +] + +[[package]] +name = "nix" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053" +dependencies = [ + "bitflags 2.4.1", + "cfg-if", + "libc", +] + [[package]] name = "nom" version = "7.1.3" @@ -856,6 +996,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi 0.3.3", + "libc", +] + [[package]] name = "object" version = "0.32.1" @@ -895,6 +1045,29 @@ version = "6.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1" +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall 0.4.1", + "smallvec", + "windows-targets 0.48.5", +] + [[package]] name = "paste" version = "1.0.14" @@ -913,6 +1086,12 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "pin-project-lite" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" + [[package]] name = "pkg-config" version = "0.3.27" @@ -960,7 +1139,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae005bd773ab59b4725093fd7df83fd7892f7d8eafb48dbd7de6e024e4215f9d" dependencies = [ "proc-macro2", - "syn", + "syn 2.0.39", ] [[package]] @@ -1049,6 +1228,15 @@ dependencies = [ "bitflags 1.3.2", ] +[[package]] +name = "redox_syscall" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "regex" version = "1.10.2" @@ -1104,6 +1292,7 @@ version = "0.2.1" dependencies = [ "anyhow", "clap 4.4.10", + "command-fds", "criterion", "env_logger", "libsodium-sys-stable", @@ -1120,6 +1309,8 @@ dependencies = [ "rosenpass-sodium", "rosenpass-to", "rosenpass-util", + "rosenpass-wireguard-broker", + "rustix", "serde", "stacker", "static_assertions", @@ -1225,6 +1416,25 @@ version = "0.1.0" dependencies = [ "anyhow", "base64", + "rustix", +] + +[[package]] +name = "rosenpass-wireguard-broker" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap 4.4.10", + "env_logger", + "log", + "mio", + "paste", + "rosenpass-lenses", + "rosenpass-to", + "rosenpass-util", + "thiserror", + "tokio", + "wireguard-uapi", ] [[package]] @@ -1241,9 +1451,9 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustix" -version = "0.38.26" +version = "0.38.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9470c4bf8246c8daf25f9598dca807fb6510347b1e1cfa55749113850c79d88a" +checksum = "bfeae074e687625746172d639330f1de242a178bf3189b51e35a7a21573513ac" dependencies = [ "bitflags 2.4.1", "errno", @@ -1322,7 +1532,7 @@ checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.39", ] [[package]] @@ -1351,6 +1561,31 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7cee0529a6d40f580e7a5e6c495c8fbfe21b7b52795ed4bb5e62cdf92bc6380" +[[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + +[[package]] +name = "smallvec" +version = "1.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" + +[[package]] +name = "socket2" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" +dependencies = [ + "libc", + "windows-sys 0.48.0", +] + [[package]] name = "spin" version = "0.9.8" @@ -1382,6 +1617,17 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.39" @@ -1442,7 +1688,7 @@ checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.39", ] [[package]] @@ -1470,6 +1716,36 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tokio" +version = "1.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0c014766411e834f7af5b8f4cf46257aab4036ca95e9d2c144a10f59ad6f5b9" +dependencies = [ + "backtrace", + "bytes", + "libc", + "mio", + "num_cpus", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys 0.48.0", +] + +[[package]] +name = "tokio-macros" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.39", +] + [[package]] name = "toml" version = "0.7.8" @@ -1612,7 +1888,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn", + "syn 2.0.39", "wasm-bindgen-shared", ] @@ -1634,7 +1910,7 @@ checksum = "c5353b8dab669f5e10f5bd76df26a9360c748f054f862ff5f3f8aae0c7fb3907" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.39", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -1845,6 +2121,18 @@ dependencies = [ "memchr", ] +[[package]] +name = "wireguard-uapi" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89ba4e9811befc20af3b6efb15924a7238ee5e8e8706a196576462a00b9f1af1" +dependencies = [ + "derive_builder", + "libc", + "neli", + "thiserror", +] + [[package]] name = "xattr" version = "1.0.1" @@ -1871,7 +2159,7 @@ checksum = "c2f140bda219a26ccc0cdb03dba58af72590c53b22642577d88a927bc5c87d6b" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.39", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 3ab316f..6945e6c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,10 +13,12 @@ members = [ "fuzz", "secret-memory", "lenses", + "wireguard-broker", ] default-members = [ - "rosenpass" + "rosenpass", + "wireguard-broker" ] [workspace.metadata.release] @@ -34,6 +36,7 @@ rosenpass-to = { path = "to" } rosenpass-secret-memory = { path = "secret-memory" } rosenpass-oqs = { path = "oqs" } rosenpass-lenses = { path = "lenses" } +rosenpass-wireguard-broker = { path = "wireguard-broker" } criterion = "0.4.0" test_bin = "0.4.0" libfuzzer-sys = "0.4" @@ -50,6 +53,10 @@ toml = "0.7.8" static_assertions = "1.1.0" allocator-api2 = "0.2.16" rand = "0.8.5" +wireguard-uapi = "3.0.0" +command-fds = "0.2.3" +rustix = { version = "0.38.27", features = ["net"] } +tokio = { version = "1.34.0", features = ["sync", "full", "mio"] } log = { version = "0.4.20" } clap = { version = "4.4.10", features = ["derive"] } serde = { version = "1.0.193", features = ["derive"] } diff --git a/rosenpass/Cargo.toml b/rosenpass/Cargo.toml index 2cad94d..39aa4bb 100644 --- a/rosenpass/Cargo.toml +++ b/rosenpass/Cargo.toml @@ -1,10 +1,10 @@ [package] name = "rosenpass" +description = "Build post-quantum-secure VPNs with WireGuard!" version = "0.2.1" authors = ["Karolin Varner ", "wucke13 "] edition = "2021" license = "MIT OR Apache-2.0" -description = "Build post-quantum-secure VPNs with WireGuard!" homepage = "https://rosenpass.eu/" repository = "https://github.com/rosenpass/rosenpass" readme = "readme.md" @@ -22,6 +22,7 @@ rosenpass-cipher-traits = { workspace = true } rosenpass-to = { workspace = true } rosenpass-secret-memory = { workspace = true } rosenpass-lenses = { workspace = true } +rosenpass-wireguard-broker = { workspace = true } anyhow = { workspace = true } static_assertions = { workspace = true } memoffset = { workspace = true } @@ -35,6 +36,8 @@ toml = { workspace = true } clap = { workspace = true } mio = { workspace = true } rand = { workspace = true } +command-fds = { workspace = true } +rustix = { workspace = true } [build-dependencies] anyhow = { workspace = true } diff --git a/rosenpass/src/app_server.rs b/rosenpass/src/app_server.rs index d047f9b..44c377f 100644 --- a/rosenpass/src/app_server.rs +++ b/rosenpass/src/app_server.rs @@ -1,38 +1,26 @@ -use anyhow::bail; - -use anyhow::Result; -use log::{debug, error, info, warn}; -use mio::Interest; -use mio::Token; -use rosenpass_util::file::fopen_w; - -use std::cell::Cell; -use std::io::Write; - -use std::io::ErrorKind; -use std::net::Ipv4Addr; -use std::net::Ipv6Addr; -use std::net::SocketAddr; -use std::net::SocketAddrV4; -use std::net::SocketAddrV6; -use std::net::ToSocketAddrs; +use std::cell::{Cell, RefCell}; +use std::io::{ErrorKind, Write}; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}; +use std::os::unix::net::UnixStream; use std::path::PathBuf; -use std::process::Command; -use std::process::Stdio; use std::slice; -use std::thread; use std::time::Duration; -use crate::{ - config::Verbosity, - protocol::{CryptoServer, MsgBuf, PeerPtr, SPk, SSk, SymKey, Timing}, -}; -use rosenpass_util::attempt; +use anyhow::{bail, Result}; +use log::{error, info, warn}; +use mio::{Interest, Token}; + +use rosenpass_secret_memory::Public; use rosenpass_util::b64::{b64_writer, fmt_b64}; +use rosenpass_util::{attempt, file::fopen_w}; +use rosenpass_wireguard_broker::api::mio_client::MioBrokerClient as PskBroker; +use rosenpass_wireguard_broker::WireGuardBroker; + +use crate::config::Verbosity; +use crate::protocol::{CryptoServer, MsgBuf, PeerPtr, SPk, SSk, SymKey, Timing}; const IPV4_ANY_ADDR: Ipv4Addr = Ipv4Addr::new(0, 0, 0, 0); const IPV6_ANY_ADDR: Ipv6Addr = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0); - fn ipv4_any_binding() -> SocketAddr { // addr, port SocketAddr::V4(SocketAddrV4::new(IPV4_ANY_ADDR, 0)) @@ -43,6 +31,19 @@ fn ipv6_any_binding() -> SocketAddr { SocketAddr::V6(SocketAddrV6::new(IPV6_ANY_ADDR, 0, 0, 0)) } +#[derive(Default)] +struct MioTokenDispenser { + counter: usize, +} + +impl MioTokenDispenser { + fn get_token(&mut self) -> Token { + let r = self.counter; + self.counter += 1; + Token(r) + } +} + #[derive(Default, Debug)] pub struct AppPeer { pub outfile: Option, @@ -59,14 +60,24 @@ impl AppPeer { } } -#[derive(Default, Debug)] +#[derive(Debug)] pub struct WireguardOut { // impl KeyOutput pub dev: String, - pub pk: String, + pub pk: Public<32>, pub extra_params: Vec, } +impl Default for WireguardOut { + fn default() -> Self { + Self { + dev: Default::default(), + pk: Public::zero(), + extra_params: Default::default(), + } + } +} + /// Holds the state of the application, namely the external IO /// /// Responsible for file IO, network IO @@ -77,6 +88,7 @@ pub struct AppServer { pub sockets: Vec, pub events: mio::Events, pub mio_poll: mio::Poll, + pub psk_broker: RefCell, pub peers: Vec, pub verbosity: Verbosity, pub all_sockets_drained: bool, @@ -341,11 +353,24 @@ impl AppServer { sk: SSk, pk: SPk, addrs: Vec, + psk_broker_socket: UnixStream, verbosity: Verbosity, ) -> anyhow::Result { // setup mio let mio_poll = mio::Poll::new()?; let events = mio::Events::with_capacity(8); + let mut dispenser = MioTokenDispenser::default(); + + // Create the Wireguard broker connection + let psk_broker = { + let mut sock = mio::net::UnixStream::from_std(psk_broker_socket); + mio_poll.registry().register( + &mut sock, + dispenser.get_token(), + Interest::READABLE | Interest::WRITABLE, + )?; + PskBroker::new(sock) + }; // bind each SocketAddr to a socket let maybe_sockets: Result, _> = @@ -430,6 +455,7 @@ impl AppServer { Ok(Self { crypt: CryptoServer::new(sk, pk), peers: Vec::new(), + psk_broker: RefCell::new(psk_broker), verbosity, sockets, events, @@ -624,31 +650,9 @@ impl AppServer { } if let Some(owg) = ap.outwg.as_ref() { - let mut child = Command::new("wg") - .arg("set") - .arg(&owg.dev) - .arg("peer") - .arg(&owg.pk) - .arg("preshared-key") - .arg("/dev/stdin") - .stdin(Stdio::piped()) - .args(&owg.extra_params) - .spawn()?; - b64_writer(child.stdin.take().unwrap()).write_all(key.secret())?; - - thread::spawn(move || { - let status = child.wait(); - - if let Ok(status) = status { - if status.success() { - debug!("successfully passed psk to wg") - } else { - error!("could not pass psk to wg {:?}", status) - } - } else { - error!("wait failed: {:?}", status) - } - }); + self.psk_broker + .borrow_mut() + .set_psk(&owg.dev, owg.pk.value, *key.secret())?; } Ok(()) @@ -706,9 +710,16 @@ impl AppServer { // only poll if we drained all sockets before if self.all_sockets_drained { - self.mio_poll.poll(&mut self.events, Some(timeout))?; + self.mio_poll + .poll(&mut self.events, Some(timeout)) + .or_else(|e| match e.kind() { + ErrorKind::Interrupted | ErrorKind::WouldBlock => Ok(()), + _ => Err(e), + })?; } + self.psk_broker.get_mut().poll()?; + let mut would_block_count = 0; for (sock_no, socket) in self.sockets.iter_mut().enumerate() { match socket.recv_from(buf) { diff --git a/rosenpass/src/cli.rs b/rosenpass/src/cli.rs index dc6b596..75b8212 100644 --- a/rosenpass/src/cli.rs +++ b/rosenpass/src/cli.rs @@ -1,10 +1,22 @@ -use anyhow::{bail, ensure}; +use std::io::{BufReader, Read}; +use std::os::unix::net::UnixStream; +use std::path::PathBuf; +use std::process::Command; +use std::thread; + +use anyhow::{bail, ensure, Context}; use clap::Parser; +use command_fds::{CommandFdExt, FdMapping}; +use log::{error, info}; +use rustix::fd::AsRawFd; +use rustix::net::{socketpair, AddressFamily, SocketFlags, SocketType}; + use rosenpass_cipher_traits::Kem; use rosenpass_ciphers::kem::StaticKem; use rosenpass_secret_memory::file::StoreSecret; +use rosenpass_secret_memory::Public; +use rosenpass_util::b64::b64_reader; use rosenpass_util::file::{LoadValue, LoadValueB64}; -use std::path::PathBuf; use crate::app_server; use crate::app_server::AppServer; @@ -62,6 +74,7 @@ pub enum Cli { config_file: PathBuf, /// Forcefully overwrite existing config file + /// - [ ] Janepie #[clap(short, long)] force: bool, }, @@ -220,11 +233,53 @@ impl Cli { let sk = SSk::load(&config.secret_key)?; let pk = SPk::load(&config.public_key)?; + // Spawn the psk broker and use socketpair(2) to connect with them + let psk_broker_socket = { + let (ours, theirs) = socketpair( + AddressFamily::UNIX, + SocketType::STREAM, + SocketFlags::empty(), + None, + )?; + + // Setup our end of the socketpair + let ours = UnixStream::from(ours); + ours.set_nonblocking(true)?; + + // Start the PSK broker + let mut child = Command::new("rosenpass-wireguard-broker-socket-handler") + .args(["--stream-fd", "3"]) + .fd_mappings(vec![FdMapping { + parent_fd: theirs.as_raw_fd(), + child_fd: 3, + }])? + .spawn()?; + + // Handle the PSK broker crashing + thread::spawn(move || { + let status = child.wait(); + + if let Ok(status) = status { + if status.success() { + // Maybe they are doing double forking? + info!("PSK broker exited."); + } else { + error!("PSK broker exited with an error ({status:?})"); + } + } else { + error!("Wait on PSK broker process failed ({status:?})"); + } + }); + + ours + }; + // start an application server let mut srv = std::boxed::Box::::new(AppServer::new( sk, pk, config.listen, + psk_broker_socket, config.verbosity, )?); @@ -234,11 +289,24 @@ impl Cli { cfg_peer.pre_shared_key.map(SymKey::load_b64).transpose()?, SPk::load(&cfg_peer.public_key)?, cfg_peer.key_out, - cfg_peer.wg.map(|cfg| app_server::WireguardOut { - dev: cfg.device, - pk: cfg.peer, - extra_params: cfg.extra_params, - }), + cfg_peer + .wg + .map(|cfg| -> anyhow::Result<_> { + let b64pk = &cfg.peer; + let mut pk = Public::zero(); + b64_reader(BufReader::new(b64pk.as_bytes())) + .read_exact(&mut pk.value) + .with_context(|| { + format!("Could not decode base64 public key: '{b64pk}'") + })?; + + Ok(app_server::WireguardOut { + pk, + dev: cfg.device, + extra_params: cfg.extra_params, + }) + }) + .transpose()?, cfg_peer.endpoint.clone(), )?; } diff --git a/util/Cargo.toml b/util/Cargo.toml index d1df0f1..b4dd70b 100644 --- a/util/Cargo.toml +++ b/util/Cargo.toml @@ -14,3 +14,4 @@ readme = "readme.md" [dependencies] base64 = { workspace = true } anyhow = { workspace = true } +rustix = { workspace = true } diff --git a/util/src/fd.rs b/util/src/fd.rs new file mode 100644 index 0000000..cf5eaf9 --- /dev/null +++ b/util/src/fd.rs @@ -0,0 +1,12 @@ +use std::os::fd::{OwnedFd, RawFd}; + +/// Clone some file descriptor +/// +/// If the file descriptor is invalid, an error will be raised. +pub fn claim_fd(fd: RawFd) -> anyhow::Result { + use rustix::{fd::BorrowedFd, io::dup}; + + // This is safe since [dup] will simply raise + let fd = unsafe { dup(BorrowedFd::borrow_raw(fd))? }; + Ok(fd) +} diff --git a/util/src/lib.rs b/util/src/lib.rs index 49ded9c..121cbf4 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -1,4 +1,5 @@ pub mod b64; +pub mod fd; pub mod file; pub mod functional; pub mod mem; diff --git a/wireguard-broker/Cargo.toml b/wireguard-broker/Cargo.toml new file mode 100644 index 0000000..a91046a --- /dev/null +++ b/wireguard-broker/Cargo.toml @@ -0,0 +1,42 @@ +[package] +name = "rosenpass-wireguard-broker" +authors = ["Karolin Varner ", "wucke13 "] +version = "0.1.0" +edition = "2021" +license = "MIT OR Apache-2.0" +description = "Rosenpass internal broker that runs as root and supplies exchanged keys to the kernel." +homepage = "https://rosenpass.eu/" +repository = "https://github.com/rosenpass/rosenpass" +readme = "readme.md" + +[dependencies] +thiserror = { workspace = true } +rosenpass-lenses = { workspace = true } +paste = { workspace = true } # TODO: Using lenses should not necessitate importing paste + +# Privileged only +wireguard-uapi = { workspace = true } + +# Socket handler only +rosenpass-to = { workspace = true } +tokio = { workspace = true } +anyhow = { workspace = true } +clap = { workspace = true } +env_logger = { workspace = true } +log = { workspace = true } + +# Mio broker client +mio = { workspace = true } +rosenpass-util = { workspace = true } + +[[bin]] +name = "rosenpass-wireguard-broker-privileged" +path = "src/bin/priviledged.rs" +test = false +doc = false + +[[bin]] +name = "rosenpass-wireguard-broker-socket-handler" +test = false +path = "src/bin/socket_handler.rs" +doc = false diff --git a/wireguard-broker/readme.md b/wireguard-broker/readme.md new file mode 100644 index 0000000..9da4b8c --- /dev/null +++ b/wireguard-broker/readme.md @@ -0,0 +1,5 @@ +# Rosenpass internal broker supplying WireGuard with keys. + +This crate contains a small application purpose-built to supply WireGuard in the linux kernel with pre-shared keys. + +This is an internal library; not guarantee is made about its API at this point in time. diff --git a/wireguard-broker/src/api/client.rs b/wireguard-broker/src/api/client.rs new file mode 100644 index 0000000..a5898dc --- /dev/null +++ b/wireguard-broker/src/api/client.rs @@ -0,0 +1,152 @@ +use std::{borrow::BorrowMut, marker::PhantomData}; + +use rosenpass_lenses::LenseView; + +use crate::{ + api::msgs::{self, EnvelopeExt, SetPskRequestExt, SetPskResponseExt}, + WireGuardBroker, +}; + +#[derive(thiserror::Error, Debug, Clone, Eq, PartialEq)] +pub enum BrokerClientPollResponseError { + #[error(transparent)] + IoError(RecvError), + #[error("Invalid message.")] + InvalidMessage, +} + +impl From for BrokerClientPollResponseError { + fn from(value: msgs::InvalidMessageTypeError) -> Self { + let msgs::InvalidMessageTypeError = value; // Assert that this is a unit type + BrokerClientPollResponseError::::InvalidMessage + } +} + +fn io_pollerr(e: RecvError) -> BrokerClientPollResponseError { + BrokerClientPollResponseError::::IoError(e) +} + +fn invalid_msg_pollerr() -> BrokerClientPollResponseError { + BrokerClientPollResponseError::::InvalidMessage +} + +#[derive(thiserror::Error, Debug, Clone, Eq, PartialEq)] +pub enum BrokerClientSetPskError { + #[error(transparent)] + IoError(SendError), + #[error("Interface name out of bounds")] + IfaceOutOfBounds, +} + +pub trait BrokerClientIo { + type SendError; + type RecvError; + + fn send_msg(&mut self, buf: &[u8]) -> Result<(), Self::SendError>; + fn recv_msg(&mut self) -> Result, Self::RecvError>; +} + +#[derive(Debug)] +pub struct BrokerClient<'a, Io, IoRef> +where + Io: BrokerClientIo, + IoRef: 'a + BorrowMut, +{ + io: IoRef, + _phantom_io: PhantomData<&'a mut Io>, +} + +impl<'a, Io, IoRef> BrokerClient<'a, Io, IoRef> +where + Io: BrokerClientIo, + IoRef: 'a + BorrowMut, +{ + pub fn new(io: IoRef) -> Self { + Self { + io, + _phantom_io: PhantomData, + } + } + + pub fn io(&self) -> &IoRef { + &self.io + } + + pub fn io_mut(&mut self) -> &mut IoRef { + &mut self.io + } + + pub fn poll_response( + &mut self, + ) -> Result, BrokerClientPollResponseError> { + let res: &[u8] = match self.io.borrow_mut().recv_msg().map_err(io_pollerr)? { + Some(r) => r, + None => return Ok(None), + }; + + let typ = res.get(0).ok_or(invalid_msg_pollerr())?; + let typ = msgs::MsgType::try_from(*typ)?; + let msgs::MsgType::SetPsk = typ; // Assert type + + let res: msgs::Envelope<_, msgs::SetPskResponse<&[u8]>> = res + .envelope_truncating() + .map_err(|_| invalid_msg_pollerr())?; + let res: msgs::SetPskResponse<&[u8]> = res + .payload() + .set_psk_response() + .map_err(|_| invalid_msg_pollerr())?; + let res: msgs::SetPskResponseReturnCode = res.return_code()[0] + .try_into() + .map_err(|_| invalid_msg_pollerr())?; + let res: msgs::SetPskResult = res.into(); + + Ok(Some(res)) + } +} + +impl<'a, Io, IoRef> WireGuardBroker for BrokerClient<'a, Io, IoRef> +where + Io: BrokerClientIo, + IoRef: 'a + BorrowMut, +{ + type Error = BrokerClientSetPskError; + + fn set_psk( + &mut self, + iface: &str, + peer_id: [u8; 32], + psk: [u8; 32], + ) -> Result<(), Self::Error> { + use BrokerClientSetPskError::*; + const BUF_SIZE: usize = > as LenseView>::LEN; + + // Allocate message + let mut req = [0u8; BUF_SIZE]; + + // Construct message view + let mut req: msgs::Envelope<_, msgs::SetPskRequest<&mut [u8]>> = + (&mut req as &mut [u8]).envelope_truncating().unwrap(); + + // Populate envelope + req.msg_type_mut() + .copy_from_slice(&[msgs::MsgType::SetPsk as u8]); + { + // Derived payload + let mut req: msgs::SetPskRequest<&mut [u8]> = + req.payload_mut().set_psk_request().unwrap(); + + // Populate payload + req.peer_id_mut().copy_from_slice(&peer_id); + req.psk_mut().copy_from_slice(&psk); + req.set_iface(iface).ok_or(IfaceOutOfBounds)?; + } + + // Send message + self.io + .borrow_mut() + .send_msg(req.all_bytes()) + .map_err(IoError)?; + + Ok(()) + } +} diff --git a/wireguard-broker/src/api/mio_client.rs b/wireguard-broker/src/api/mio_client.rs new file mode 100644 index 0000000..1e1be6c --- /dev/null +++ b/wireguard-broker/src/api/mio_client.rs @@ -0,0 +1,204 @@ +use std::collections::VecDeque; +use std::io::{ErrorKind, Read, Write}; + +use anyhow::{bail, ensure}; + +use crate::WireGuardBroker; + +use super::client::{ + BrokerClient, BrokerClientIo, BrokerClientPollResponseError, BrokerClientSetPskError, +}; +use super::msgs; + +#[derive(Debug)] +pub struct MioBrokerClient { + inner: BrokerClient<'static, MioBrokerClientIo, MioBrokerClientIo>, +} + +#[derive(Debug)] +struct MioBrokerClientIo { + socket: mio::net::UnixStream, + send_buf: VecDeque, + receiving_size: bool, + recv_buf: Vec, + recv_off: usize, +} + +impl MioBrokerClient { + pub fn new(socket: mio::net::UnixStream) -> Self { + let io = MioBrokerClientIo { + socket, + send_buf: VecDeque::new(), + receiving_size: false, + recv_buf: Vec::new(), + recv_off: 0, + }; + let inner = BrokerClient::new(io); + Self { inner } + } + + pub fn poll(&mut self) -> anyhow::Result> { + self.inner.io_mut().flush()?; + + // This sucks + match self.inner.poll_response() { + Ok(res) => { + return Ok(res); + } + Err(BrokerClientPollResponseError::IoError(e)) => { + return Err(e); + } + Err(BrokerClientPollResponseError::InvalidMessage) => { + bail!("Invalid message"); + } + }; + } +} + +impl WireGuardBroker for MioBrokerClient { + type Error = anyhow::Error; + + fn set_psk(&mut self, iface: &str, peer_id: [u8; 32], psk: [u8; 32]) -> anyhow::Result<()> { + use BrokerClientSetPskError::*; + let e = self.inner.set_psk(iface, peer_id, psk); + match e { + Ok(()) => Ok(()), + Err(IoError(e)) => Err(e), + Err(IfaceOutOfBounds) => bail!("Interface name size is out of bounds."), + } + } +} + +impl BrokerClientIo for MioBrokerClientIo { + type SendError = anyhow::Error; + type RecvError = anyhow::Error; + + fn send_msg(&mut self, buf: &[u8]) -> Result<(), Self::SendError> { + self.flush()?; + self.send_or_buffer(&(buf.len() as u64).to_le_bytes())?; + self.send_or_buffer(&buf)?; + self.flush()?; + + Ok(()) + } + + fn recv_msg(&mut self) -> Result, Self::RecvError> { + // Stale message in receive buffer. Reset! + if self.recv_off == self.recv_buf.len() { + self.receiving_size = true; + self.recv_off = 0; + self.recv_buf.resize(8, 0); + } + + // Try filling the receive buffer + self.recv_off += raw_recv(&self.socket, &mut self.recv_buf[self.recv_off..])?; + if self.recv_off < self.recv_buf.len() { + return Ok(None); + } + + // Received size, now start receiving + if self.receiving_size { + // Received the size + // Parse the received length + let len: &[u8; 8] = self.recv_buf[..].try_into().unwrap(); + let len: usize = u64::from_le_bytes(*len) as usize; + + ensure!( + len <= msgs::RESPONSE_MSG_BUFFER_SIZE, + "Oversized buffer ({len}) in psk buffer response." + ); + + // Prepare the message buffer for receiving an actual message of the given size + self.receiving_size = false; + self.recv_off = 0; + self.recv_buf.resize(len, 0); + + // Try to receive the message + return self.recv_msg(); + } + + // Received an actual message + return Ok(Some(&self.recv_buf[..])); + } +} + +impl MioBrokerClientIo { + fn flush(&mut self) -> anyhow::Result<()> { + let (fst, snd) = self.send_buf.as_slices(); + + let (written, res) = match raw_send(&self.socket, fst) { + Ok(w1) if w1 >= fst.len() => match raw_send(&self.socket, snd) { + Ok(w2) => (w1 + w2, Ok(())), + Err(e) => (w1, Err(e)), + }, + Ok(w1) => (w1, Ok(())), + Err(e) => (0, Err(e)), + }; + + self.send_buf.drain(..written); + + (&self.socket).try_io(|| (&self.socket).flush())?; + + res + } + + fn send_or_buffer(&mut self, buf: &[u8]) -> anyhow::Result<()> { + let mut off = 0; + + if self.send_buf.is_empty() { + off += raw_send(&self.socket, buf)?; + } + + self.send_buf.extend((&buf[off..]).iter()); + + Ok(()) + } +} + +fn raw_send(mut socket: &mio::net::UnixStream, data: &[u8]) -> anyhow::Result { + let mut off = 0; + + socket.try_io(|| { + loop { + if off == data.len() { + return Ok(()); + } + match socket.write(&data[off..]) { + Ok(n) => { + off += n; + } + Err(e) if e.kind() == ErrorKind::Interrupted => { + // pass – retry + } + Err(e) if off > 0 || e.kind() == ErrorKind::WouldBlock => return Ok(()), + Err(e) => return Err(e), + } + } + })?; + + return Ok(off); +} + +fn raw_recv(mut socket: &mio::net::UnixStream, out: &mut [u8]) -> anyhow::Result { + let mut off = 0; + + socket.try_io(|| { + loop { + if off == out.len() { + return Ok(()); + } + match socket.read(&mut out[off..]) { + Ok(n) => { + off += n; + } + Err(e) if e.kind() == ErrorKind::Interrupted => { + // pass – retry + } + Err(e) if off > 0 || e.kind() == ErrorKind::WouldBlock => return Ok(()), + Err(e) => return Err(e), + } + } + })?; + + return Ok(off); +} diff --git a/wireguard-broker/src/api/mod.rs b/wireguard-broker/src/api/mod.rs new file mode 100644 index 0000000..386eb7d --- /dev/null +++ b/wireguard-broker/src/api/mod.rs @@ -0,0 +1,4 @@ +pub mod client; +pub mod mio_client; +pub mod msgs; +pub mod server; diff --git a/wireguard-broker/src/api/msgs.rs b/wireguard-broker/src/api/msgs.rs new file mode 100644 index 0000000..efdb97f --- /dev/null +++ b/wireguard-broker/src/api/msgs.rs @@ -0,0 +1,140 @@ +use std::result::Result; +use std::str::{from_utf8, Utf8Error}; + +use rosenpass_lenses::{lense, LenseView}; + +pub const REQUEST_MSG_BUFFER_SIZE: usize = > as LenseView>::LEN; +pub const RESPONSE_MSG_BUFFER_SIZE: usize = > as LenseView>::LEN; + +lense! { Envelope := + /// [MsgType] of this message + msg_type: 1, + /// Reserved for future use + reserved: 3, + /// The actual Paylod + payload: M::LEN +} + +lense! { SetPskRequest := + peer_id: 32, + psk: 32, + iface_size: 1, // TODO: We should have variable length strings in lenses + iface_buf: 255 +} + +impl SetPskRequest<&[u8]> { + pub fn iface_bin(&self) -> &[u8] { + let len = self.iface_size()[0] as usize; + &self.iface_buf()[..len] + } + + pub fn iface(&self) -> Result<&str, Utf8Error> { + from_utf8(self.iface_bin()) + } +} + +impl SetPskRequest<&mut [u8]> { + pub fn set_iface_bin(&mut self, iface: &[u8]) -> Option<()> { + (iface.len() < 256).then_some(())?; // Assert iface.len() < 256 + + self.iface_size_mut()[0] = iface.len() as u8; + + self.iface_buf_mut().fill(0); + (&mut self.iface_buf_mut()[..iface.len()]).copy_from_slice(iface); + + Some(()) + } + + pub fn set_iface(&mut self, iface: &str) -> Option<()> { + self.set_iface_bin(iface.as_bytes()) + } +} + +lense! { SetPskResponse := + return_code: 1 +} + +#[derive(thiserror::Error, Debug, Clone, Eq, PartialEq)] +pub enum SetPskError { + #[error("The wireguard pre-shared-key assignment broker experienced an internal error.")] + InternalError, + #[error("The indicated wireguard interface does not exist")] + NoSuchInterface, + #[error("The indicated peer does not exist on the wireguard interface")] + NoSuchPeer, +} + +pub type SetPskResult = Result<(), SetPskError>; + +#[repr(u8)] +#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Copy)] +pub enum SetPskResponseReturnCode { + Success = 0x00, + InternalError = 0x01, + NoSuchInterface = 0x02, + NoSuchPeer = 0x03, +} + +#[derive(Eq, PartialEq, Debug, Clone)] +pub struct InvalidSetPskResponseError; + +impl TryFrom for SetPskResponseReturnCode { + type Error = InvalidSetPskResponseError; + + fn try_from(value: u8) -> Result { + use SetPskResponseReturnCode::*; + match value { + 0x00 => Ok(Success), + 0x01 => Ok(InternalError), + 0x02 => Ok(NoSuchInterface), + 0x03 => Ok(NoSuchPeer), + _ => Err(InvalidSetPskResponseError), + } + } +} + +impl From for SetPskResult { + fn from(value: SetPskResponseReturnCode) -> Self { + use SetPskError as E; + use SetPskResponseReturnCode as C; + match value { + C::Success => Ok(()), + C::InternalError => Err(E::InternalError), + C::NoSuchInterface => Err(E::NoSuchInterface), + C::NoSuchPeer => Err(E::NoSuchPeer), + } + } +} + +impl From for SetPskResponseReturnCode { + fn from(value: SetPskResult) -> Self { + use SetPskError as E; + use SetPskResponseReturnCode as C; + match value { + Ok(()) => C::Success, + Err(E::InternalError) => C::InternalError, + Err(E::NoSuchInterface) => C::NoSuchInterface, + Err(E::NoSuchPeer) => C::NoSuchPeer, + } + } +} + +#[repr(u8)] +#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Copy)] +pub enum MsgType { + SetPsk = 0x01, +} + +#[derive(Eq, PartialEq, Debug, Clone)] +pub struct InvalidMessageTypeError; + +impl TryFrom for MsgType { + type Error = InvalidMessageTypeError; + + fn try_from(value: u8) -> Result { + match value { + 0x01 => Ok(MsgType::SetPsk), + _ => Err(InvalidMessageTypeError), + } + } +} diff --git a/wireguard-broker/src/api/server.rs b/wireguard-broker/src/api/server.rs new file mode 100644 index 0000000..725b03f --- /dev/null +++ b/wireguard-broker/src/api/server.rs @@ -0,0 +1,99 @@ +use std::borrow::BorrowMut; +use std::marker::PhantomData; +use std::result::Result; + +use rosenpass_lenses::LenseError; + +use crate::api::msgs::{self, EnvelopeExt, SetPskRequestExt, SetPskResponseExt}; +use crate::WireGuardBroker; + +#[derive(thiserror::Error, Debug, Clone, Eq, PartialEq)] +pub enum BrokerServerError { + #[error("No such request type: {}", .0)] + NoSuchRequestType(u8), + #[error("Invalid message received.")] + InvalidMessage, +} + +impl From for BrokerServerError { + fn from(value: LenseError) -> Self { + use BrokerServerError as Be; + use LenseError as Le; + match value { + Le::BufferSizeMismatch => Be::InvalidMessage, + } + } +} + +impl From for BrokerServerError { + fn from(value: msgs::InvalidMessageTypeError) -> Self { + let msgs::InvalidMessageTypeError = value; // Assert that this is a unit type + BrokerServerError::InvalidMessage + } +} + +pub struct BrokerServer<'a, Err, Inner, Ref> +where + msgs::SetPskError: From, + Inner: WireGuardBroker, + Ref: BorrowMut + 'a, +{ + inner: Ref, + _phantom: PhantomData<&'a mut Inner>, +} + +impl<'a, Err, Inner, Ref> BrokerServer<'a, Err, Inner, Ref> +where + msgs::SetPskError: From, + Inner: WireGuardBroker, + Ref: 'a + BorrowMut, +{ + pub fn new(inner: Ref) -> Self { + Self { + inner, + _phantom: PhantomData, + } + } + + pub fn handle_message( + &mut self, + req: &[u8], + res: &mut [u8; msgs::RESPONSE_MSG_BUFFER_SIZE], + ) -> Result { + use BrokerServerError::*; + + let typ = req.get(0).ok_or(InvalidMessage)?; + let typ = msgs::MsgType::try_from(*typ)?; + let msgs::MsgType::SetPsk = typ; // Assert type + + let req: msgs::Envelope<_, msgs::SetPskRequest<&[u8]>> = req.envelope_truncating()?; + let mut res: msgs::Envelope<_, msgs::SetPskResponse<&mut [u8]>> = + (res as &mut [u8]).envelope_truncating()?; + (&mut res).msg_type_mut()[0] = msgs::MsgType::SetPsk as u8; + self.handle_set_psk( + req.payload().set_psk_request()?, + res.payload_mut().set_psk_response()?, + )?; + Ok(res.all_bytes().len()) + } + + fn handle_set_psk( + &mut self, + req: msgs::SetPskRequest<&[u8]>, + mut res: msgs::SetPskResponse<&mut [u8]>, + ) -> Result<(), BrokerServerError> { + // Using unwrap here since lenses can not return fixed-size arrays + // TODO: Slices should give access to fixed size arrays + let r: Result<(), Err> = self.inner.borrow_mut().set_psk( + req.iface() + .map_err(|_e| BrokerServerError::InvalidMessage)?, + req.peer_id().try_into().unwrap(), + req.psk().try_into().unwrap(), + ); + let r: msgs::SetPskResult = r.map_err(|e| e.into()); + let r: msgs::SetPskResponseReturnCode = r.into(); + res.return_code_mut()[0] = r as u8; + + Ok(()) + } +} diff --git a/wireguard-broker/src/bin/priviledged.rs b/wireguard-broker/src/bin/priviledged.rs new file mode 100644 index 0000000..e1fff72 --- /dev/null +++ b/wireguard-broker/src/bin/priviledged.rs @@ -0,0 +1,56 @@ +use std::io::{stdin, stdout, Read, Write}; +use std::result::Result; + +use rosenpass_wireguard_broker::api::msgs; +use rosenpass_wireguard_broker::api::server::BrokerServer; +use rosenpass_wireguard_broker::netlink as wg; + +#[derive(thiserror::Error, Debug)] +pub enum BrokerAppError { + #[error(transparent)] + IoError(#[from] std::io::Error), + #[error(transparent)] + WgConnectError(#[from] wg::ConnectError), + #[error(transparent)] + WgSetPskError(#[from] wg::SetPskError), + #[error("Oversized message {}; something about the request is fatally wrong", .0)] + OversizedMessage(u64), +} + +fn main() -> Result<(), BrokerAppError> { + let mut broker = BrokerServer::new(wg::NetlinkWireGuardBroker::new()?); + + let mut stdin = stdin().lock(); + let mut stdout = stdout().lock(); + loop { + // Read the message length + let mut len = [0u8; 8]; + stdin.read_exact(&mut len)?; + + // Parse the message length + let len = u64::from_le_bytes(len); + if (len as usize) > msgs::REQUEST_MSG_BUFFER_SIZE { + return Err(BrokerAppError::OversizedMessage(len)); + } + + // Read the message itself + let mut req_buf = [0u8; msgs::REQUEST_MSG_BUFFER_SIZE]; + let req_buf = &mut req_buf[..(len as usize)]; + stdin.read_exact(req_buf)?; + + // Process the message + let mut res_buf = [0u8; msgs::RESPONSE_MSG_BUFFER_SIZE]; + let res = match broker.handle_message(req_buf, &mut res_buf) { + Ok(len) => &res_buf[..len], + Err(e) => { + eprintln!("Error processing message for wireguard PSK broker: {e:?}"); + continue; + } + }; + + // Write the response + stdout.write_all(&(res.len() as u64).to_le_bytes())?; + stdout.write_all(&res)?; + stdout.flush()?; + } +} diff --git a/wireguard-broker/src/bin/socket_handler.rs b/wireguard-broker/src/bin/socket_handler.rs new file mode 100644 index 0000000..e1693c5 --- /dev/null +++ b/wireguard-broker/src/bin/socket_handler.rs @@ -0,0 +1,191 @@ +use std::process::Stdio; + +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{UnixListener, UnixStream}; +use tokio::process::Command; +use tokio::sync::{mpsc, oneshot}; +use tokio::task; + +use anyhow::{bail, ensure, Result}; +use clap::{ArgGroup, Parser}; + +use rosenpass_util::fd::claim_fd; +use rosenpass_wireguard_broker::api::msgs; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +#[clap(group( + ArgGroup::new("socket") + .required(true) + .args(&["listen_path", "listen_fd", "stream_fd"]), + ))] +struct Args { + /// Where in the file-system to create the unix socket this broker will be listening for + /// connections on + #[arg(long)] + listen_path: Option, + + /// When this broker is called from another process, the other process can open and bind the + /// unix socket to use themselves, passing it to this process. In Rust this can be achieved + /// using the [command-fds](https://docs.rs/command-fds/latest/command_fds/) crate. + #[arg(long)] + listen_fd: Option, + + /// When this broker is called from another process, the other process can connect the unix socket + /// themselves, for instance using the `socketpair(2)` system call. + #[arg(long)] + stream_fd: Option, + + /// The underlying broker, accepting commands through stdin and sending results through stdout. + #[arg( + last = true, + allow_hyphen_values = true, + default_value = "rosenpass-wireguard-broker-privileged" + )] + command: Vec, +} + +struct BrokerRequest { + reply_to: oneshot::Sender, + request: Vec, +} + +struct BrokerResponse { + response: Vec, +} + +#[tokio::main] +async fn main() -> Result<()> { + env_logger::init(); + + let args = Args::parse(); + + let (proc_tx, proc_rx) = mpsc::channel(100); + + // Start the inner broker handler + task::spawn(async move { + if let Err(e) = direct_broker_process(proc_rx, args.command).await { + log::error!("Error in broker command handler: {e}"); + panic!("Can not proceed without underlying broker process"); + } + }); + + // Listen for incoming requests + if let Some(path) = args.listen_path { + let sock = UnixListener::bind(path)?; + listen_for_clients(proc_tx, sock).await + } else if let Some(fd) = args.listen_fd { + let sock = std::os::unix::net::UnixListener::from(claim_fd(fd)?); + sock.set_nonblocking(true)?; + listen_for_clients(proc_tx, UnixListener::from_std(sock)?).await + } else if let Some(fd) = args.stream_fd { + let stream = std::os::unix::net::UnixStream::from(claim_fd(fd)?); + stream.set_nonblocking(true)?; + on_accept(proc_tx, UnixStream::from_std(stream)?).await + } else { + unreachable!(); + } +} + +async fn direct_broker_process( + mut queue: mpsc::Receiver, + cmd: Vec, +) -> Result<()> { + let proc = Command::new(&cmd[0]) + .args(&cmd[1..]) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn()?; + + let mut stdin = proc.stdin.unwrap(); + let mut stdout = proc.stdout.unwrap(); + + loop { + let BrokerRequest { reply_to, request } = queue.recv().await.unwrap(); + + stdin + .write_all(&(request.len() as u64).to_le_bytes()) + .await?; + stdin.write_all(&request[..]).await?; + + // Read the response length + let mut len = [0u8; 8]; + stdout.read_exact(&mut len).await?; + + // Parse the response length + let len = u64::from_le_bytes(len) as usize; + ensure!( + len <= msgs::RESPONSE_MSG_BUFFER_SIZE, + "Oversized buffer ({len}) in broker stdout." + ); + + // Read the message itself + let mut res_buf = request; // Avoid allocating memory if we don't have to + res_buf.resize(len as usize, 0); + stdout.read_exact(&mut res_buf[..len]).await?; + + // Return to the unix socket connection worker + reply_to + .send(BrokerResponse { response: res_buf }) + .or_else(|_| bail!("Unable to send respnse to unix socket worker."))?; + } +} + +async fn listen_for_clients(queue: mpsc::Sender, sock: UnixListener) -> Result<()> { + loop { + let (stream, _addr) = sock.accept().await?; + let queue = queue.clone(); + task::spawn(async move { + if let Err(e) = on_accept(queue, stream).await { + log::error!("Error during connection processing: {e}"); + } + }); + } + + // NOTE: If loop can ever terminate we need to join the spawned tasks +} + +async fn on_accept(queue: mpsc::Sender, mut stream: UnixStream) -> Result<()> { + let mut req_buf = Vec::new(); + + loop { + stream.readable().await?; + + // Read the message length + let mut len = [0u8; 8]; + stream.read_exact(&mut len).await?; + + // Parse the message length + let len = u64::from_le_bytes(len) as usize; + ensure!( + len <= msgs::REQUEST_MSG_BUFFER_SIZE, + "Oversized buffer ({len}) in unix socket input." + ); + + // Read the message itself + req_buf.resize(len as usize, 0); + stream.read_exact(&mut req_buf[..len]).await?; + + // Handle the message + let (reply_tx, reply_rx) = oneshot::channel(); + queue + .send(BrokerRequest { + reply_to: reply_tx, + request: req_buf, + }) + .await?; + + // Wait for the reply + let BrokerResponse { response } = reply_rx.await.unwrap(); + + // Write reply back to unix socket + stream + .write_all(&(response.len() as u64).to_le_bytes()) + .await?; + stream.write_all(&response[..]).await?; + stream.flush().await?; + + // Reuse the same memory for the next message + req_buf = response; + } +} diff --git a/wireguard-broker/src/lib.rs b/wireguard-broker/src/lib.rs new file mode 100644 index 0000000..6c13dfa --- /dev/null +++ b/wireguard-broker/src/lib.rs @@ -0,0 +1,15 @@ +use std::result::Result; + +pub trait WireGuardBroker { + type Error; + + fn set_psk( + &mut self, + interface: &str, + peer_id: [u8; 32], + psk: [u8; 32], + ) -> Result<(), Self::Error>; +} + +pub mod api; +pub mod netlink; diff --git a/wireguard-broker/src/netlink.rs b/wireguard-broker/src/netlink.rs new file mode 100644 index 0000000..c8ebdab --- /dev/null +++ b/wireguard-broker/src/netlink.rs @@ -0,0 +1,103 @@ +use wireguard_uapi::linux as wg; + +use crate::api::msgs; +use crate::WireGuardBroker; + +#[derive(thiserror::Error, Debug)] +pub enum ConnectError { + #[error(transparent)] + ConnectError(#[from] wg::err::ConnectError), +} + +#[derive(thiserror::Error, Debug)] +pub enum NetlinkError { + #[error(transparent)] + SetDevice(#[from] wg::err::SetDeviceError), + #[error(transparent)] + GetDevice(#[from] wg::err::GetDeviceError), +} + +#[derive(thiserror::Error, Debug)] +pub enum SetPskError { + #[error("The indicated wireguard interface does not exist")] + NoSuchInterface, + #[error("The indicated peer does not exist on the wireguard interface")] + NoSuchPeer, + #[error(transparent)] + NetlinkError(#[from] NetlinkError), +} + +impl From for SetPskError { + fn from(err: wg::err::SetDeviceError) -> Self { + NetlinkError::from(err).into() + } +} + +impl From for SetPskError { + fn from(err: wg::err::GetDeviceError) -> Self { + NetlinkError::from(err).into() + } +} + +use msgs::SetPskError as SetPskMsgsError; +use SetPskError as SetPskNetlinkError; +impl From for SetPskMsgsError { + fn from(err: SetPskError) -> Self { + match err { + SetPskNetlinkError::NoSuchPeer => SetPskMsgsError::NoSuchPeer, + _ => SetPskMsgsError::InternalError, + } + } +} + +pub struct NetlinkWireGuardBroker { + sock: wg::WgSocket, +} + +impl NetlinkWireGuardBroker { + pub fn new() -> Result { + let sock = wg::WgSocket::connect()?; + Ok(Self { sock }) + } +} + +impl WireGuardBroker for NetlinkWireGuardBroker { + type Error = SetPskError; + + fn set_psk( + &mut self, + interface: &str, + peer_id: [u8; 32], + psk: [u8; 32], + ) -> Result<(), Self::Error> { + // Ensure that the peer exists by querying the device configuration + // TODO: Use InvalidInterfaceError + let state = self + .sock + .get_device(wg::DeviceInterface::from_name(interface.to_owned()))?; + + if state + .peers + .iter() + .find(|p| &p.public_key == &peer_id) + .is_none() + { + return Err(SetPskError::NoSuchPeer); + } + + // Peer update description + let mut set_peer = wireguard_uapi::set::Peer::from_public_key(&peer_id); + set_peer + .flags + .push(wireguard_uapi::linux::set::WgPeerF::UpdateOnly); + set_peer.preshared_key = Some(&psk); + + // Device update description + let mut set_dev = wireguard_uapi::set::Device::from_ifname(interface.to_owned()); + set_dev.peers.push(set_peer); + + self.sock.set_device(set_dev)?; + + Ok(()) + } +}