feat(wireguard-broker): merge from dev/broker-architecture, fixes, test

* wireguard-broker: merge from dev/broker-architecture
* use zerocopy instead of lenses
* Require use_broker feature flag to comile broker binaries
* Remove PhantomData from BrokerServer & BrokerClient
* Modify mio client rx to be non-recursive, add integration test

Co-authored-by: Karolin Varner <karo@cupdev.net>
Co-authored-by: Prabhpreet Dua <615318+prabhpreet@users.noreply.github.com>
This commit is contained in:
Prabhpreet Dua
2024-05-07 12:23:35 +05:30
committed by GitHub
parent e6d114c557
commit 2bac991305
18 changed files with 1373 additions and 50 deletions

245
Cargo.lock generated
View File

@@ -60,47 +60,48 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
[[package]]
name = "anstream"
version = "0.6.13"
version = "0.6.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d96bd03f33fe50a863e394ee9718a706f988b9079b20c3784fb726e7678b62fb"
checksum = "418c75fa768af9c03be99d17643f93f79bbba589895012a80e3452a19ddda15b"
dependencies = [
"anstyle",
"anstyle-parse",
"anstyle-query",
"anstyle-wincon",
"colorchoice",
"is_terminal_polyfill",
"utf8parse",
]
[[package]]
name = "anstyle"
version = "1.0.6"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc"
checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b"
[[package]]
name = "anstyle-parse"
version = "0.2.3"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c75ac65da39e5fe5ab759307499ddad880d724eed2f6ce5b5e8a26f4f387928c"
checksum = "c03a11a9034d92058ceb6ee011ce58af4a9bf61491aa7e1e59ecd24bd40d22d4"
dependencies = [
"utf8parse",
]
[[package]]
name = "anstyle-query"
version = "1.0.2"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e28923312444cdd728e4738b3f9c9cac739500909bb3d3c94b43551b16517648"
checksum = "a64c907d4e79225ac72e2a354c9ce84d50ebb4586dee56c82b3ee73004f537f5"
dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "anstyle-wincon"
version = "3.0.2"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1cd54b81ec8d6180e24654d0b371ad22fc3dd083b6ff8ba325b72e00c87660a7"
checksum = "61a38449feb7068f52bb06c12759005cf459ee52bb4adc1d5a7c4322d716fb19"
dependencies = [
"anstyle",
"windows-sys 0.52.0",
@@ -146,9 +147,9 @@ dependencies = [
[[package]]
name = "autocfg"
version = "1.2.0"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80"
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
[[package]]
name = "backtrace"
@@ -190,7 +191,7 @@ dependencies = [
"regex",
"rustc-hash",
"shlex",
"syn",
"syn 2.0.60",
"which",
]
@@ -259,9 +260,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cc"
version = "1.0.96"
version = "1.0.97"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "065a29261d53ba54260972629f9ca6bffa69bac13cd1fed61420f7fa68b9f8bd"
checksum = "099a5357d84c4c61eb35fc8eafa9a79a902c2f76911e5747ced4e032edd8d9b4"
dependencies = [
"jobserver",
"libc",
@@ -399,7 +400,7 @@ dependencies = [
"heck",
"proc-macro2",
"quote",
"syn",
"syn 2.0.60",
]
[[package]]
@@ -428,9 +429,9 @@ dependencies = [
[[package]]
name = "colorchoice"
version = "1.0.0"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422"
[[package]]
name = "cpufeatures"
@@ -559,7 +560,17 @@ checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.60",
]
[[package]]
name = "darling"
version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f2c43f534ea4b0b049015d00269734195e6d3f0f6635cb692251aca6f9f8b3c"
dependencies = [
"darling_core 0.12.4",
"darling_macro 0.12.4",
]
[[package]]
@@ -568,8 +579,22 @@ version = "0.20.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54e36fcd13ed84ffdfda6f5be89b31287cbb80c439841fe69e04841435464391"
dependencies = [
"darling_core",
"darling_macro",
"darling_core 0.20.8",
"darling_macro 0.20.8",
]
[[package]]
name = "darling_core"
version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e91455b86830a1c21799d94524df0845183fa55bafd9aa137b01c7d1065fa36"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim 0.10.0",
"syn 1.0.109",
]
[[package]]
@@ -583,7 +608,18 @@ dependencies = [
"proc-macro2",
"quote",
"strsim 0.10.0",
"syn",
"syn 2.0.60",
]
[[package]]
name = "darling_macro"
version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29b5acf0dea37a7f66f7b25d2c5e93fd46f8f6968b1a5d7a3e02e97768afc95a"
dependencies = [
"darling_core 0.12.4",
"quote",
"syn 1.0.109",
]
[[package]]
@@ -592,9 +628,9 @@ version = "0.20.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f"
dependencies = [
"darling_core",
"darling_core 0.20.8",
"quote",
"syn",
"syn 2.0.60",
]
[[package]]
@@ -605,7 +641,16 @@ checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.60",
]
[[package]]
name = "derive_builder"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d13202debe11181040ae9063d739fa32cfcaaebe2275fe387703460ae2365b30"
dependencies = [
"derive_builder_macro 0.10.2",
]
[[package]]
@@ -614,7 +659,19 @@ version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0350b5cb0331628a5916d6c5c0b72e97393b8b6b03b47a9284f4e7f5a405ffd7"
dependencies = [
"derive_builder_macro",
"derive_builder_macro 0.20.0",
]
[[package]]
name = "derive_builder_core"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "66e616858f6187ed828df7c64a6d71720d83767a7f19740b2d1b6fe6327b36e5"
dependencies = [
"darling 0.12.4",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
@@ -623,10 +680,20 @@ version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d48cda787f839151732d396ac69e3473923d54312c070ee21e9effcaa8ca0b1d"
dependencies = [
"darling",
"darling 0.20.8",
"proc-macro2",
"quote",
"syn",
"syn 2.0.60",
]
[[package]]
name = "derive_builder_macro"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "58a94ace95092c5acb1e97a7e846b310cfbd499652f72297da7493f618a98d73"
dependencies = [
"derive_builder_core 0.10.2",
"syn 1.0.109",
]
[[package]]
@@ -635,8 +702,8 @@ version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b"
dependencies = [
"derive_builder_core",
"syn",
"derive_builder_core 0.20.0",
"syn 2.0.60",
]
[[package]]
@@ -765,7 +832,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.60",
]
[[package]]
@@ -826,9 +893,9 @@ dependencies = [
[[package]]
name = "getrandom"
version = "0.2.14"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if",
"libc",
@@ -973,6 +1040,12 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800"
[[package]]
name = "itertools"
version = "0.10.5"
@@ -1129,6 +1202,31 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "neli"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1805440578ced23f85145d00825c0a831e43c587132a90e100552172543ae30"
dependencies = [
"byteorder",
"libc",
"log",
"neli-proc-macros",
]
[[package]]
name = "neli-proc-macros"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c168194d373b1e134786274020dae7fc5513d565ea2ebb9bc9ff17ffb69106d4"
dependencies = [
"either",
"proc-macro2",
"quote",
"serde",
"syn 1.0.109",
]
[[package]]
name = "netlink-packet-core"
version = "0.7.0"
@@ -1256,9 +1354,9 @@ dependencies = [
[[package]]
name = "num-traits"
version = "0.2.18"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
]
@@ -1430,7 +1528,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ac2cf0f2e4f42b49f5ffd07dae8d746508ef7526c13940e5f524012ae6c6550"
dependencies = [
"proc-macro2",
"syn",
"syn 2.0.60",
]
[[package]]
@@ -1555,7 +1653,7 @@ dependencies = [
"anyhow",
"clap 4.5.4",
"criterion",
"derive_builder",
"derive_builder 0.20.0",
"env_logger",
"home",
"log",
@@ -1662,11 +1760,30 @@ version = "0.1.0"
dependencies = [
"anyhow",
"base64ct",
"rustix",
"static_assertions",
"typenum",
"zeroize",
]
[[package]]
name = "rosenpass-wireguard-broker"
version = "0.1.0"
dependencies = [
"anyhow",
"clap 4.5.4",
"env_logger",
"log",
"mio",
"rand",
"rosenpass-to",
"rosenpass-util",
"thiserror",
"tokio",
"wireguard-uapi",
"zerocopy",
]
[[package]]
name = "rp"
version = "0.2.1"
@@ -1762,9 +1879,9 @@ dependencies = [
[[package]]
name = "scc"
version = "2.1.0"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec96560eea317a9cc4e0bb1f6a2c93c09a19b8c4fc5cb3fcc0ec1c094cd783e2"
checksum = "76ad2bbb0ae5100a07b7a6f2ed7ab5fd0045551a4c507989b7a620046ea3efdc"
dependencies = [
"sdd",
]
@@ -1804,7 +1921,7 @@ checksum = "856f046b9400cee3c8c94ed572ecdb752444c24528c035cd35882aad6f492bcb"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.60",
]
[[package]]
@@ -1849,7 +1966,7 @@ checksum = "82fe9db325bcef1fbcde82e078a5cc4efdf787e96b3b9cf45b50b529f2083d67"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.60",
]
[[package]]
@@ -1858,6 +1975,15 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "signal-hook-registry"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1"
dependencies = [
"libc",
]
[[package]]
name = "slab"
version = "0.4.9"
@@ -1935,6 +2061,17 @@ version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc"
[[package]]
name = "syn"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "syn"
version = "2.0.60"
@@ -1996,7 +2133,7 @@ checksum = "d1cd413b5d558b4c5bf3680e324a6fa5014e7b7c067a51e69dbdf47eb7148b66"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.60",
]
[[package]]
@@ -2020,7 +2157,9 @@ dependencies = [
"libc",
"mio",
"num_cpus",
"parking_lot",
"pin-project-lite",
"signal-hook-registry",
"socket2",
"tokio-macros",
"windows-sys 0.48.0",
@@ -2034,7 +2173,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.60",
]
[[package]]
@@ -2142,7 +2281,7 @@ dependencies = [
"once_cell",
"proc-macro2",
"quote",
"syn",
"syn 2.0.60",
"wasm-bindgen-shared",
]
@@ -2164,7 +2303,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.60",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
@@ -2442,6 +2581,18 @@ dependencies = [
"memchr",
]
[[package]]
name = "wireguard-uapi"
version = "3.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89ba4e9811befc20af3b6efb15924a7238ee5e8e8706a196576462a00b9f1af1"
dependencies = [
"derive_builder 0.10.2",
"libc",
"neli",
"thiserror",
]
[[package]]
name = "x25519-dalek"
version = "2.0.1"
@@ -2472,7 +2623,7 @@ checksum = "6f4b6c273f496d8fd4eaf18853e6b448760225dc030ff2c485a786859aea6393"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.60",
]
[[package]]
@@ -2492,5 +2643,5 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.60",
]

View File

@@ -12,11 +12,13 @@ members = [
"fuzz",
"secret-memory",
"rp",
"wireguard-broker"
]
default-members = [
"rosenpass",
"rp",
"wireguard-broker",
]
[workspace.metadata.release]
@@ -57,6 +59,7 @@ chacha20poly1305 = { version = "0.10.1", default-features = false, features = [
zerocopy = { version = "0.7.33", features = ["derive"] }
home = "0.5.9"
derive_builder = "0.20.0"
tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] }
#Dev dependencies
serial_test = "3.1.1"
@@ -65,4 +68,9 @@ stacker = "0.1.15"
libfuzzer-sys = "0.4"
test_bin = "0.4.0"
criterion = "0.4.0"
allocator-api2-tests = "0.2.15"
allocator-api2-tests = "0.2.15"
#Broker dependencies (might need cleanup or changes)
wireguard-uapi = "3.0.0"
command-fds = "0.2.3"
rustix = { version = "0.38.27", features = ["net"] }

View File

@@ -21,7 +21,7 @@ rosenpass-cipher-traits = { workspace = true }
rosenpass-secret-memory = { workspace = true }
rosenpass-util = { workspace = true }
tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] }
tokio = {workspace = true}
[target.'cfg(any(target_os = "linux", target_os = "freebsd"))'.dependencies]
ctrlc-async = "3.2"

View File

@@ -16,4 +16,5 @@ base64ct = { workspace = true }
anyhow = { workspace = true }
typenum = { workspace = true }
static_assertions = { workspace = true }
zeroize = {workspace = true}
rustix = {workspace = true}
zeroize = {workspace = true}

12
util/src/fd.rs Normal file
View File

@@ -0,0 +1,12 @@
use std::os::fd::{OwnedFd, RawFd};
/// Clone some file descriptor
///
/// If the file descriptor is invalid, an error will be raised.
pub fn claim_fd(fd: RawFd) -> anyhow::Result<OwnedFd> {
use rustix::{fd::BorrowedFd, io::dup};
// This is safe since [dup] will simply raise
let fd = unsafe { dup(BorrowedFd::borrow_raw(fd))? };
Ok(fd)
}

View File

@@ -1,6 +1,7 @@
#![recursion_limit = "256"]
pub mod b64;
pub mod fd;
pub mod file;
pub mod functional;
pub mod mem;

View File

@@ -0,0 +1,49 @@
[package]
name = "rosenpass-wireguard-broker"
authors = ["Karolin Varner <karo@cupdev.net>", "wucke13 <wucke13@gmail.com>"]
version = "0.1.0"
edition = "2021"
license = "MIT OR Apache-2.0"
description = "Rosenpass internal broker that runs as root and supplies exchanged keys to the kernel."
homepage = "https://rosenpass.eu/"
repository = "https://github.com/rosenpass/rosenpass"
readme = "readme.md"
[dependencies]
thiserror = { workspace = true }
zerocopy = { workspace = true }
# Privileged only
wireguard-uapi = { workspace = true }
# Socket handler only
rosenpass-to = { workspace = true }
tokio = { version = "1.34.0", features = ["sync", "full", "mio"] }
anyhow = { workspace = true }
clap = { workspace = true }
env_logger = { workspace = true }
log = { workspace = true }
# Mio broker client
mio = { workspace = true }
rosenpass-util = { workspace = true }
[dev-dependencies]
rand = {workspace = true}
[features]
enable_broker=[]
[[bin]]
name = "rosenpass-wireguard-broker-privileged"
path = "src/bin/priviledged.rs"
test = false
doc = false
required-features=["enable_broker"]
[[bin]]
name = "rosenpass-wireguard-broker-socket-handler"
test = false
path = "src/bin/socket_handler.rs"
doc = false
required-features=["enable_broker"]

View File

@@ -0,0 +1,5 @@
# Rosenpass internal broker supplying WireGuard with keys.
This crate contains a small application purpose-built to supply WireGuard in the linux kernel with pre-shared keys.
This is an internal library; not guarantee is made about its API at this point in time.

View File

@@ -0,0 +1,142 @@
use std::borrow::BorrowMut;
use crate::{
api::msgs::{self, REQUEST_MSG_BUFFER_SIZE},
WireGuardBroker,
};
use super::msgs::{Envelope, SetPskResponse};
#[derive(thiserror::Error, Debug, Clone, Eq, PartialEq)]
pub enum BrokerClientPollResponseError<RecvError> {
#[error(transparent)]
IoError(RecvError),
#[error("Invalid message.")]
InvalidMessage,
}
impl<RecvError> From<msgs::InvalidMessageTypeError> for BrokerClientPollResponseError<RecvError> {
fn from(value: msgs::InvalidMessageTypeError) -> Self {
let msgs::InvalidMessageTypeError = value; // Assert that this is a unit type
BrokerClientPollResponseError::<RecvError>::InvalidMessage
}
}
fn io_poller<RecvError>(e: RecvError) -> BrokerClientPollResponseError<RecvError> {
BrokerClientPollResponseError::<RecvError>::IoError(e)
}
fn invalid_msg_poller<RecvError>() -> BrokerClientPollResponseError<RecvError> {
BrokerClientPollResponseError::<RecvError>::InvalidMessage
}
#[derive(thiserror::Error, Debug, Clone, Eq, PartialEq)]
pub enum BrokerClientSetPskError<SendError> {
#[error("Error with encoding/decoding message")]
MsgError,
#[error(transparent)]
IoError(SendError),
#[error("Interface name out of bounds")]
IfaceOutOfBounds,
}
pub trait BrokerClientIo {
type SendError;
type RecvError;
fn send_msg(&mut self, buf: &[u8]) -> Result<(), Self::SendError>;
fn recv_msg(&mut self) -> Result<Option<&[u8]>, Self::RecvError>;
}
#[derive(Debug)]
pub struct BrokerClient<Io>
where
Io: BrokerClientIo,
{
io: Io,
}
impl<Io> BrokerClient<Io>
where
Io: BrokerClientIo,
{
pub fn new(io: Io) -> Self {
Self { io }
}
pub fn io(&self) -> &Io {
&self.io
}
pub fn io_mut(&mut self) -> &mut Io {
&mut self.io
}
pub fn poll_response(
&mut self,
) -> Result<Option<msgs::SetPskResult>, BrokerClientPollResponseError<Io::RecvError>> {
let res: &[u8] = match self.io.borrow_mut().recv_msg().map_err(io_poller)? {
Some(r) => r,
None => return Ok(None),
};
let typ = res.get(0).ok_or(invalid_msg_poller())?;
let typ = msgs::MsgType::try_from(*typ)?;
let msgs::MsgType::SetPsk = typ; // Assert type
let res = zerocopy::Ref::<&[u8], Envelope<SetPskResponse>>::new(res)
.ok_or(invalid_msg_poller())?;
let res: &msgs::SetPskResponse = &res.payload;
let res: msgs::SetPskResponseReturnCode = res
.return_code
.try_into()
.map_err(|_| invalid_msg_poller())?;
let res: msgs::SetPskResult = res.into();
Ok(Some(res))
}
}
impl<Io> WireGuardBroker for BrokerClient<Io>
where
Io: BrokerClientIo,
{
type Error = BrokerClientSetPskError<Io::SendError>;
fn set_psk(
&mut self,
iface: &str,
peer_id: [u8; 32],
psk: [u8; 32],
) -> Result<(), Self::Error> {
use BrokerClientSetPskError::*;
const BUF_SIZE: usize = REQUEST_MSG_BUFFER_SIZE;
// Allocate message
let mut req = [0u8; BUF_SIZE];
// Construct message view
let mut req = zerocopy::Ref::<&mut [u8], Envelope<msgs::SetPskRequest>>::new(&mut req)
.ok_or(MsgError)?;
// Populate envelope
req.msg_type = msgs::MsgType::SetPsk as u8;
{
// Derived payload
let req = &mut req.payload;
// Populate payload
req.peer_id.copy_from_slice(&peer_id);
req.psk.copy_from_slice(&psk);
req.set_iface(iface).ok_or(IfaceOutOfBounds)?;
}
// Send message
self.io
.borrow_mut()
.send_msg(req.bytes())
.map_err(IoError)?;
Ok(())
}
}

View File

@@ -0,0 +1,231 @@
use std::collections::VecDeque;
use std::io::{ErrorKind, Read, Write};
use anyhow::{bail, ensure};
use crate::WireGuardBroker;
use super::client::{
BrokerClient, BrokerClientIo, BrokerClientPollResponseError, BrokerClientSetPskError,
};
use super::msgs::{self, RESPONSE_MSG_BUFFER_SIZE};
#[derive(Debug)]
pub struct MioBrokerClient {
inner: BrokerClient<MioBrokerClientIo>,
}
const LEN_SIZE: usize = 8;
const RECV_BUF_SIZE: usize = RESPONSE_MSG_BUFFER_SIZE;
#[derive(Debug)]
struct MioBrokerClientIo {
socket: mio::net::UnixStream,
send_buf: VecDeque<u8>,
recv_state: RxState,
expected_state: RxState,
recv_buf: [u8; RECV_BUF_SIZE],
}
#[derive(Debug, Clone, Copy)]
enum RxState {
//Recieving size with buffer offset
RxSize(usize),
RxBuffer(usize),
}
impl MioBrokerClient {
pub fn new(socket: mio::net::UnixStream) -> Self {
let io = MioBrokerClientIo {
socket,
send_buf: VecDeque::new(),
recv_state: RxState::RxSize(0),
recv_buf: [0u8; RECV_BUF_SIZE],
expected_state: RxState::RxSize(LEN_SIZE),
};
let inner = BrokerClient::new(io);
Self { inner }
}
pub fn poll(&mut self) -> anyhow::Result<Option<msgs::SetPskResult>> {
self.inner.io_mut().flush()?;
// This sucks
match self.inner.poll_response() {
Ok(res) => {
return Ok(res);
}
Err(BrokerClientPollResponseError::IoError(e)) => {
return Err(e);
}
Err(BrokerClientPollResponseError::InvalidMessage) => {
bail!("Invalid message");
}
};
}
}
impl WireGuardBroker for MioBrokerClient {
type Error = anyhow::Error;
fn set_psk(&mut self, iface: &str, peer_id: [u8; 32], psk: [u8; 32]) -> anyhow::Result<()> {
use BrokerClientSetPskError::*;
let e = self.inner.set_psk(iface, peer_id, psk);
match e {
Ok(()) => Ok(()),
Err(IoError(e)) => Err(e),
Err(IfaceOutOfBounds) => bail!("Interface name size is out of bounds."),
Err(MsgError) => bail!("Error with encoding/decoding message."),
}
}
}
impl BrokerClientIo for MioBrokerClientIo {
type SendError = anyhow::Error;
type RecvError = anyhow::Error;
fn send_msg(&mut self, buf: &[u8]) -> Result<(), Self::SendError> {
self.flush()?;
self.send_or_buffer(&(buf.len() as u64).to_le_bytes())?;
self.send_or_buffer(&buf)?;
self.flush()?;
Ok(())
}
fn recv_msg(&mut self) -> Result<Option<&[u8]>, Self::RecvError> {
loop {
match (self.recv_state, self.expected_state) {
//Stale Buffer state or recieved everything
(RxState::RxSize(x), RxState::RxSize(y))
| (RxState::RxBuffer(x), RxState::RxBuffer(y))
if x == y =>
{
match self.recv_state {
RxState::RxSize(s) => {
let len: &[u8; LEN_SIZE] = self.recv_buf[0..s].try_into().unwrap();
let len: usize = u64::from_le_bytes(*len) as usize;
ensure!(
len <= msgs::RESPONSE_MSG_BUFFER_SIZE,
"Oversized buffer ({len}) in psk buffer response."
);
self.recv_state = RxState::RxBuffer(0);
self.expected_state = RxState::RxBuffer(len);
continue;
}
RxState::RxBuffer(s) => {
self.recv_state = RxState::RxSize(0);
self.expected_state = RxState::RxSize(LEN_SIZE);
return Ok(Some(&self.recv_buf[0..s]));
}
}
}
//Recieve if x < y
(RxState::RxSize(x), RxState::RxSize(y))
| (RxState::RxBuffer(x), RxState::RxBuffer(y))
if x < y =>
{
let bytes = raw_recv(&self.socket, &mut self.recv_buf[x..y])?;
if x + bytes == y {
return Ok(Some(&self.recv_buf[0..y]));
}
//We didn't recieve everything so let's assume something went wrong
self.recv_state = RxState::RxSize(0);
self.expected_state = RxState::RxSize(LEN_SIZE);
bail!("Invalid state");
}
_ => {
//Reset states
self.recv_state = RxState::RxSize(0);
self.expected_state = RxState::RxSize(LEN_SIZE);
bail!("Invalid state");
}
};
}
}
}
impl MioBrokerClientIo {
fn flush(&mut self) -> anyhow::Result<()> {
let (fst, snd) = self.send_buf.as_slices();
let (written, res) = match raw_send(&self.socket, fst) {
Ok(w1) if w1 >= fst.len() => match raw_send(&self.socket, snd) {
Ok(w2) => (w1 + w2, Ok(())),
Err(e) => (w1, Err(e)),
},
Ok(w1) => (w1, Ok(())),
Err(e) => (0, Err(e)),
};
self.send_buf.drain(..written);
(&self.socket).try_io(|| (&self.socket).flush())?;
res
}
fn send_or_buffer(&mut self, buf: &[u8]) -> anyhow::Result<()> {
let mut off = 0;
if self.send_buf.is_empty() {
off += raw_send(&self.socket, buf)?;
}
self.send_buf.extend((&buf[off..]).iter());
Ok(())
}
}
fn raw_send(mut socket: &mio::net::UnixStream, data: &[u8]) -> anyhow::Result<usize> {
let mut off = 0;
socket.try_io(|| {
loop {
if off == data.len() {
return Ok(());
}
match socket.write(&data[off..]) {
Ok(n) => {
off += n;
}
Err(e) if e.kind() == ErrorKind::Interrupted => {
// pass retry
}
Err(e) if off > 0 || e.kind() == ErrorKind::WouldBlock => return Ok(()),
Err(e) => return Err(e),
}
}
})?;
return Ok(off);
}
fn raw_recv(mut socket: &mio::net::UnixStream, out: &mut [u8]) -> anyhow::Result<usize> {
let mut off = 0;
socket.try_io(|| {
loop {
if off == out.len() {
return Ok(());
}
match socket.read(&mut out[off..]) {
Ok(n) => {
off += n;
}
Err(e) if e.kind() == ErrorKind::Interrupted => {
// pass retry
}
Err(e) if off > 0 || e.kind() == ErrorKind::WouldBlock => return Ok(()),
Err(e) => return Err(e),
}
}
})?;
return Ok(off);
}

View File

@@ -0,0 +1,4 @@
pub mod client;
pub mod mio_client;
pub mod msgs;
pub mod server;

View File

@@ -0,0 +1,145 @@
use std::result::Result;
use std::str::{from_utf8, Utf8Error};
use zerocopy::{AsBytes, FromBytes, FromZeroes};
pub const ENVELOPE_OVERHEAD: usize = 1 + 3;
pub const REQUEST_MSG_BUFFER_SIZE: usize = ENVELOPE_OVERHEAD + 32 + 32 + 1 + 255;
pub const RESPONSE_MSG_BUFFER_SIZE: usize = ENVELOPE_OVERHEAD + 1;
#[repr(packed)]
#[derive(AsBytes, FromBytes, FromZeroes)]
pub struct Envelope<M: AsBytes + FromBytes> {
/// [MsgType] of this message
pub msg_type: u8,
/// Reserved for future use
pub reserved: [u8; 3],
/// The actual Paylod
pub payload: M,
}
#[repr(packed)]
#[derive(AsBytes, FromBytes, FromZeroes)]
pub struct SetPskRequest {
pub peer_id: [u8; 32],
pub psk: [u8; 32],
pub iface_size: u8, // TODO: We should have variable length strings in lenses
pub iface_buf: [u8; 255],
}
impl SetPskRequest {
pub fn iface_bin(&self) -> &[u8] {
let len = self.iface_size as usize;
&self.iface_buf[..len]
}
pub fn iface(&self) -> Result<&str, Utf8Error> {
from_utf8(self.iface_bin())
}
pub fn set_iface_bin(&mut self, iface: &[u8]) -> Option<()> {
(iface.len() < 256).then_some(())?; // Assert iface.len() < 256
self.iface_size = iface.len() as u8;
self.iface_buf = [0; 255];
(&mut self.iface_buf[..iface.len()]).copy_from_slice(iface);
Some(())
}
pub fn set_iface(&mut self, iface: &str) -> Option<()> {
self.set_iface_bin(iface.as_bytes())
}
}
#[repr(packed)]
#[derive(AsBytes, FromBytes, FromZeroes)]
pub struct SetPskResponse {
pub return_code: u8,
}
#[derive(thiserror::Error, Debug, Clone, Eq, PartialEq)]
pub enum SetPskError {
#[error("The wireguard pre-shared-key assignment broker experienced an internal error.")]
InternalError,
#[error("The indicated wireguard interface does not exist")]
NoSuchInterface,
#[error("The indicated peer does not exist on the wireguard interface")]
NoSuchPeer,
}
pub type SetPskResult = Result<(), SetPskError>;
#[repr(u8)]
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Copy)]
pub enum SetPskResponseReturnCode {
Success = 0x00,
InternalError = 0x01,
NoSuchInterface = 0x02,
NoSuchPeer = 0x03,
}
#[derive(Eq, PartialEq, Debug, Clone)]
pub struct InvalidSetPskResponseError;
impl TryFrom<u8> for SetPskResponseReturnCode {
type Error = InvalidSetPskResponseError;
fn try_from(value: u8) -> Result<Self, Self::Error> {
use SetPskResponseReturnCode::*;
match value {
0x00 => Ok(Success),
0x01 => Ok(InternalError),
0x02 => Ok(NoSuchInterface),
0x03 => Ok(NoSuchPeer),
_ => Err(InvalidSetPskResponseError),
}
}
}
impl From<SetPskResponseReturnCode> for SetPskResult {
fn from(value: SetPskResponseReturnCode) -> Self {
use SetPskError as E;
use SetPskResponseReturnCode as C;
match value {
C::Success => Ok(()),
C::InternalError => Err(E::InternalError),
C::NoSuchInterface => Err(E::NoSuchInterface),
C::NoSuchPeer => Err(E::NoSuchPeer),
}
}
}
impl From<SetPskResult> for SetPskResponseReturnCode {
fn from(value: SetPskResult) -> Self {
use SetPskError as E;
use SetPskResponseReturnCode as C;
match value {
Ok(()) => C::Success,
Err(E::InternalError) => C::InternalError,
Err(E::NoSuchInterface) => C::NoSuchInterface,
Err(E::NoSuchPeer) => C::NoSuchPeer,
}
}
}
#[repr(u8)]
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Copy)]
pub enum MsgType {
SetPsk = 0x01,
}
#[derive(Eq, PartialEq, Debug, Clone)]
pub struct InvalidMessageTypeError;
impl TryFrom<u8> for MsgType {
type Error = InvalidMessageTypeError;
fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
0x01 => Ok(MsgType::SetPsk),
_ => Err(InvalidMessageTypeError),
}
}
}

View File

@@ -0,0 +1,79 @@
use std::borrow::BorrowMut;
use std::result::Result;
use crate::api::msgs::{self, Envelope, SetPskRequest, SetPskResponse};
use crate::WireGuardBroker;
#[derive(thiserror::Error, Debug, Clone, Eq, PartialEq)]
pub enum BrokerServerError {
#[error("No such request type: {}", .0)]
NoSuchRequestType(u8),
#[error("Invalid message received.")]
InvalidMessage,
}
impl From<msgs::InvalidMessageTypeError> for BrokerServerError {
fn from(value: msgs::InvalidMessageTypeError) -> Self {
let msgs::InvalidMessageTypeError = value; // Assert that this is a unit type
BrokerServerError::InvalidMessage
}
}
pub struct BrokerServer<Err, Inner>
where
msgs::SetPskError: From<Err>,
Inner: WireGuardBroker<Error = Err>,
{
inner: Inner,
}
impl<Err, Inner> BrokerServer<Err, Inner>
where
msgs::SetPskError: From<Err>,
Inner: WireGuardBroker<Error = Err>,
{
pub fn new(inner: Inner) -> Self {
Self { inner }
}
pub fn handle_message(
&mut self,
req: &[u8],
res: &mut [u8; msgs::RESPONSE_MSG_BUFFER_SIZE],
) -> Result<usize, BrokerServerError> {
use BrokerServerError::*;
let typ = req.get(0).ok_or(InvalidMessage)?;
let typ = msgs::MsgType::try_from(*typ)?;
let msgs::MsgType::SetPsk = typ; // Assert type
let req = zerocopy::Ref::<&[u8], Envelope<SetPskRequest>>::new(req)
.ok_or(BrokerServerError::InvalidMessage)?;
let mut res = zerocopy::Ref::<&mut [u8], Envelope<SetPskResponse>>::new(res)
.ok_or(BrokerServerError::InvalidMessage)?;
res.payload.return_code = msgs::MsgType::SetPsk as u8;
self.handle_set_psk(&req.payload, &mut res.payload)?;
Ok(res.bytes().len())
}
fn handle_set_psk(
&mut self,
req: &SetPskRequest,
res: &mut SetPskResponse,
) -> Result<(), BrokerServerError> {
// Using unwrap here since lenses can not return fixed-size arrays
// TODO: Slices should give access to fixed size arrays
let r: Result<(), Err> = self.inner.borrow_mut().set_psk(
req.iface()
.map_err(|_e| BrokerServerError::InvalidMessage)?,
req.peer_id.try_into().unwrap(),
req.psk.try_into().unwrap(),
);
let r: msgs::SetPskResult = r.map_err(|e| e.into());
let r: msgs::SetPskResponseReturnCode = r.into();
res.return_code = r as u8;
Ok(())
}
}

View File

@@ -0,0 +1,56 @@
use std::io::{stdin, stdout, Read, Write};
use std::result::Result;
use rosenpass_wireguard_broker::api::msgs;
use rosenpass_wireguard_broker::api::server::BrokerServer;
use rosenpass_wireguard_broker::netlink as wg;
#[derive(thiserror::Error, Debug)]
pub enum BrokerAppError {
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error(transparent)]
WgConnectError(#[from] wg::ConnectError),
#[error(transparent)]
WgSetPskError(#[from] wg::SetPskError),
#[error("Oversized message {}; something about the request is fatally wrong", .0)]
OversizedMessage(u64),
}
fn main() -> Result<(), BrokerAppError> {
let mut broker = BrokerServer::new(wg::NetlinkWireGuardBroker::new()?);
let mut stdin = stdin().lock();
let mut stdout = stdout().lock();
loop {
// Read the message length
let mut len = [0u8; 8];
stdin.read_exact(&mut len)?;
// Parse the message length
let len = u64::from_le_bytes(len);
if (len as usize) > msgs::REQUEST_MSG_BUFFER_SIZE {
return Err(BrokerAppError::OversizedMessage(len));
}
// Read the message itself
let mut req_buf = [0u8; msgs::REQUEST_MSG_BUFFER_SIZE];
let req_buf = &mut req_buf[..(len as usize)];
stdin.read_exact(req_buf)?;
// Process the message
let mut res_buf = [0u8; msgs::RESPONSE_MSG_BUFFER_SIZE];
let res = match broker.handle_message(req_buf, &mut res_buf) {
Ok(len) => &res_buf[..len],
Err(e) => {
eprintln!("Error processing message for wireguard PSK broker: {e:?}");
continue;
}
};
// Write the response
stdout.write_all(&(res.len() as u64).to_le_bytes())?;
stdout.write_all(&res)?;
stdout.flush()?;
}
}

View File

@@ -0,0 +1,191 @@
use std::process::Stdio;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{UnixListener, UnixStream};
use tokio::process::Command;
use tokio::sync::{mpsc, oneshot};
use tokio::task;
use anyhow::{bail, ensure, Result};
use clap::{ArgGroup, Parser};
use rosenpass_util::fd::claim_fd;
use rosenpass_wireguard_broker::api::msgs;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
#[clap(group(
ArgGroup::new("socket")
.required(true)
.args(&["listen_path", "listen_fd", "stream_fd"]),
))]
struct Args {
/// Where in the file-system to create the unix socket this broker will be listening for
/// connections on
#[arg(long)]
listen_path: Option<String>,
/// When this broker is called from another process, the other process can open and bind the
/// unix socket to use themselves, passing it to this process. In Rust this can be achieved
/// using the [command-fds](https://docs.rs/command-fds/latest/command_fds/) crate.
#[arg(long)]
listen_fd: Option<i32>,
/// When this broker is called from another process, the other process can connect the unix socket
/// themselves, for instance using the `socketpair(2)` system call.
#[arg(long)]
stream_fd: Option<i32>,
/// The underlying broker, accepting commands through stdin and sending results through stdout.
#[arg(
last = true,
allow_hyphen_values = true,
default_value = "rosenpass-wireguard-broker-privileged"
)]
command: Vec<String>,
}
struct BrokerRequest {
reply_to: oneshot::Sender<BrokerResponse>,
request: Vec<u8>,
}
struct BrokerResponse {
response: Vec<u8>,
}
#[tokio::main]
async fn main() -> Result<()> {
env_logger::init();
let args = Args::parse();
let (proc_tx, proc_rx) = mpsc::channel(100);
// Start the inner broker handler
task::spawn(async move {
if let Err(e) = direct_broker_process(proc_rx, args.command).await {
log::error!("Error in broker command handler: {e}");
panic!("Can not proceed without underlying broker process");
}
});
// Listen for incoming requests
if let Some(path) = args.listen_path {
let sock = UnixListener::bind(path)?;
listen_for_clients(proc_tx, sock).await
} else if let Some(fd) = args.listen_fd {
let sock = std::os::unix::net::UnixListener::from(claim_fd(fd)?);
sock.set_nonblocking(true)?;
listen_for_clients(proc_tx, UnixListener::from_std(sock)?).await
} else if let Some(fd) = args.stream_fd {
let stream = std::os::unix::net::UnixStream::from(claim_fd(fd)?);
stream.set_nonblocking(true)?;
on_accept(proc_tx, UnixStream::from_std(stream)?).await
} else {
unreachable!();
}
}
async fn direct_broker_process(
mut queue: mpsc::Receiver<BrokerRequest>,
cmd: Vec<String>,
) -> Result<()> {
let proc = Command::new(&cmd[0])
.args(&cmd[1..])
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.spawn()?;
let mut stdin = proc.stdin.unwrap();
let mut stdout = proc.stdout.unwrap();
loop {
let BrokerRequest { reply_to, request } = queue.recv().await.unwrap();
stdin
.write_all(&(request.len() as u64).to_le_bytes())
.await?;
stdin.write_all(&request[..]).await?;
// Read the response length
let mut len = [0u8; 8];
stdout.read_exact(&mut len).await?;
// Parse the response length
let len = u64::from_le_bytes(len) as usize;
ensure!(
len <= msgs::RESPONSE_MSG_BUFFER_SIZE,
"Oversized buffer ({len}) in broker stdout."
);
// Read the message itself
let mut res_buf = request; // Avoid allocating memory if we don't have to
res_buf.resize(len as usize, 0);
stdout.read_exact(&mut res_buf[..len]).await?;
// Return to the unix socket connection worker
reply_to
.send(BrokerResponse { response: res_buf })
.or_else(|_| bail!("Unable to send respnse to unix socket worker."))?;
}
}
async fn listen_for_clients(queue: mpsc::Sender<BrokerRequest>, sock: UnixListener) -> Result<()> {
loop {
let (stream, _addr) = sock.accept().await?;
let queue = queue.clone();
task::spawn(async move {
if let Err(e) = on_accept(queue, stream).await {
log::error!("Error during connection processing: {e}");
}
});
}
// NOTE: If loop can ever terminate we need to join the spawned tasks
}
async fn on_accept(queue: mpsc::Sender<BrokerRequest>, mut stream: UnixStream) -> Result<()> {
let mut req_buf = Vec::new();
loop {
stream.readable().await?;
// Read the message length
let mut len = [0u8; 8];
stream.read_exact(&mut len).await?;
// Parse the message length
let len = u64::from_le_bytes(len) as usize;
ensure!(
len <= msgs::REQUEST_MSG_BUFFER_SIZE,
"Oversized buffer ({len}) in unix socket input."
);
// Read the message itself
req_buf.resize(len as usize, 0);
stream.read_exact(&mut req_buf[..len]).await?;
// Handle the message
let (reply_tx, reply_rx) = oneshot::channel();
queue
.send(BrokerRequest {
reply_to: reply_tx,
request: req_buf,
})
.await?;
// Wait for the reply
let BrokerResponse { response } = reply_rx.await.unwrap();
// Write reply back to unix socket
stream
.write_all(&(response.len() as u64).to_le_bytes())
.await?;
stream.write_all(&response[..]).await?;
stream.flush().await?;
// Reuse the same memory for the next message
req_buf = response;
}
}

View File

@@ -0,0 +1,19 @@
#[cfg(feature = "enable_broker")]
use std::result::Result;
#[cfg(feature = "enable_broker")]
pub trait WireGuardBroker {
type Error;
fn set_psk(
&mut self,
interface: &str,
peer_id: [u8; 32],
psk: [u8; 32],
) -> Result<(), Self::Error>;
}
#[cfg(feature = "enable_broker")]
pub mod api;
#[cfg(feature = "enable_broker")]
pub mod netlink;

View File

@@ -0,0 +1,103 @@
use wireguard_uapi::linux as wg;
use crate::api::msgs;
use crate::WireGuardBroker;
#[derive(thiserror::Error, Debug)]
pub enum ConnectError {
#[error(transparent)]
ConnectError(#[from] wg::err::ConnectError),
}
#[derive(thiserror::Error, Debug)]
pub enum NetlinkError {
#[error(transparent)]
SetDevice(#[from] wg::err::SetDeviceError),
#[error(transparent)]
GetDevice(#[from] wg::err::GetDeviceError),
}
#[derive(thiserror::Error, Debug)]
pub enum SetPskError {
#[error("The indicated wireguard interface does not exist")]
NoSuchInterface,
#[error("The indicated peer does not exist on the wireguard interface")]
NoSuchPeer,
#[error(transparent)]
NetlinkError(#[from] NetlinkError),
}
impl From<wg::err::SetDeviceError> for SetPskError {
fn from(err: wg::err::SetDeviceError) -> Self {
NetlinkError::from(err).into()
}
}
impl From<wg::err::GetDeviceError> for SetPskError {
fn from(err: wg::err::GetDeviceError) -> Self {
NetlinkError::from(err).into()
}
}
use msgs::SetPskError as SetPskMsgsError;
use SetPskError as SetPskNetlinkError;
impl From<SetPskNetlinkError> for SetPskMsgsError {
fn from(err: SetPskError) -> Self {
match err {
SetPskNetlinkError::NoSuchPeer => SetPskMsgsError::NoSuchPeer,
_ => SetPskMsgsError::InternalError,
}
}
}
pub struct NetlinkWireGuardBroker {
sock: wg::WgSocket,
}
impl NetlinkWireGuardBroker {
pub fn new() -> Result<Self, ConnectError> {
let sock = wg::WgSocket::connect()?;
Ok(Self { sock })
}
}
impl WireGuardBroker for NetlinkWireGuardBroker {
type Error = SetPskError;
fn set_psk(
&mut self,
interface: &str,
peer_id: [u8; 32],
psk: [u8; 32],
) -> Result<(), Self::Error> {
// Ensure that the peer exists by querying the device configuration
// TODO: Use InvalidInterfaceError
let state = self
.sock
.get_device(wg::DeviceInterface::from_name(interface.to_owned()))?;
if state
.peers
.iter()
.find(|p| &p.public_key == &peer_id)
.is_none()
{
return Err(SetPskError::NoSuchPeer);
}
// Peer update description
let mut set_peer = wireguard_uapi::set::Peer::from_public_key(&peer_id);
set_peer
.flags
.push(wireguard_uapi::linux::set::WgPeerF::UpdateOnly);
set_peer.preshared_key = Some(&psk);
// Device update description
let mut set_dev = wireguard_uapi::set::Device::from_ifname(interface.to_owned());
set_dev.peers.push(set_peer);
self.sock.set_device(set_dev)?;
Ok(())
}
}

View File

@@ -0,0 +1,126 @@
#[cfg(feature = "enable_broker")]
#[cfg(test)]
mod integration_tests {
use rand::Rng;
use rosenpass_wireguard_broker::api::mio_client::MioBrokerClient;
use rosenpass_wireguard_broker::api::msgs::{
SetPskError, REQUEST_MSG_BUFFER_SIZE, RESPONSE_MSG_BUFFER_SIZE,
};
use rosenpass_wireguard_broker::api::server::{BrokerServer, BrokerServerError};
use rosenpass_wireguard_broker::WireGuardBroker;
use std::io::Read;
use std::sync::{Arc, Mutex};
#[derive(Default)]
struct MockServerBrokerInner {
psk: Option<[u8; 32]>,
peer_id: Option<[u8; 32]>,
interface: Option<String>,
}
struct MockServerBroker {
inner: Arc<Mutex<MockServerBrokerInner>>,
}
impl MockServerBroker {
fn new(inner: Arc<Mutex<MockServerBrokerInner>>) -> Self {
Self { inner }
}
}
impl WireGuardBroker for MockServerBroker {
type Error = SetPskError;
fn set_psk(
&mut self,
interface: &str,
peer_id: [u8; 32],
psk: [u8; 32],
) -> Result<(), Self::Error> {
loop {
let mut lock = self.inner.try_lock();
if let Ok(ref mut mutex) = lock {
**mutex = MockServerBrokerInner {
psk: Some(psk),
peer_id: Some(peer_id),
interface: Some(interface.to_string()),
};
break;
}
}
Ok(())
}
}
#[test]
fn test_psk_exchanges() {
const TEST_RUNS: usize = 100;
let server_broker_inner = Arc::new(Mutex::new(MockServerBrokerInner::default()));
// Create a mock BrokerServer
let server_broker = MockServerBroker::new(server_broker_inner.clone());
let mut server = BrokerServer::<SetPskError, MockServerBroker>::new(server_broker);
let (client_socket, mut server_socket) = mio::net::UnixStream::pair().unwrap();
// Spawn a new thread to connect to the unix socket
let handle = std::thread::spawn(move || {
for _ in 0..TEST_RUNS {
// Wait for 8 bytes of length to come in
let mut length_buffer = [0; 8];
while let Err(_err) = server_socket.read_exact(&mut length_buffer) {}
let length = u64::from_le_bytes(length_buffer) as usize;
// Read the amount of length bytes into a buffer
let mut data_buffer = [0; REQUEST_MSG_BUFFER_SIZE];
while let Err(_err) = server_socket.read_exact(&mut data_buffer[0..length]) {}
let mut response = [0; RESPONSE_MSG_BUFFER_SIZE];
server.handle_message(&data_buffer[0..length], &mut response)?;
}
Ok::<(), BrokerServerError>(())
});
// Create a MioBrokerClient and send a psk
let mut client = MioBrokerClient::new(client_socket);
for _ in 0..TEST_RUNS {
//Create psk of random 32 bytes
let mut psk: [u8; 32] = [0; 32];
rand::thread_rng().fill(&mut psk);
let mut peer_id: [u8; 32] = [0; 32];
rand::thread_rng().fill(&mut peer_id);
let interface = "test";
client.set_psk(&interface, peer_id, psk).unwrap();
//Sleep for a while to allow the server to process the message
std::thread::sleep(std::time::Duration::from_millis(
rand::thread_rng().gen_range(100..500),
));
loop {
let mut lock = server_broker_inner.try_lock();
if let Ok(ref mut inner) = lock {
// Check if the psk is received by the server
let received_psk = inner.psk;
assert_eq!(received_psk, Some(psk));
let recieved_peer_id = inner.peer_id;
assert_eq!(recieved_peer_id, Some(peer_id));
let target_interface = &inner.interface;
assert_eq!(target_interface.as_deref(), Some(interface));
break;
}
}
}
handle.join().unwrap().unwrap();
}
}