mirror of
https://github.com/rosenpass/rosenpass.git
synced 2026-02-28 14:33:37 -08:00
feat(rosenpass): Add wireguard-broker interface in AppServer (#303)
Dynamically dispatch WireguardBrokerMio trait in AppServer. Also allows for mio event registration and poll processing, logic from dev/broker-architecture branch Co-authored-by: Prabhpreet Dua <615318+prabhpreet@users.noreply.github.com> Co-authored-by: Karolin Varner <karo@cupdev.net>
This commit is contained in:
259
wireguard-broker/src/brokers/mio_client.rs
Normal file
259
wireguard-broker/src/brokers/mio_client.rs
Normal file
@@ -0,0 +1,259 @@
|
||||
use anyhow::{bail, ensure};
|
||||
use mio::Interest;
|
||||
use std::collections::VecDeque;
|
||||
use std::io::{ErrorKind, Read, Write};
|
||||
|
||||
use crate::{SerializedBrokerConfig, WireGuardBroker, WireguardBrokerMio};
|
||||
|
||||
use crate::api::client::{
|
||||
BrokerClient, BrokerClientIo, BrokerClientPollResponseError, BrokerClientSetPskError,
|
||||
};
|
||||
use crate::api::msgs::{self, RESPONSE_MSG_BUFFER_SIZE};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MioBrokerClient {
|
||||
inner: BrokerClient<MioBrokerClientIo>,
|
||||
}
|
||||
|
||||
const LEN_SIZE: usize = 8;
|
||||
const RECV_BUF_SIZE: usize = RESPONSE_MSG_BUFFER_SIZE;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MioBrokerClientIo {
|
||||
socket: mio::net::UnixStream,
|
||||
send_buf: VecDeque<u8>,
|
||||
recv_state: RxState,
|
||||
expected_state: RxState,
|
||||
recv_buf: [u8; RECV_BUF_SIZE],
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum RxState {
|
||||
//Recieving size with buffer offset
|
||||
RxSize(usize),
|
||||
RxBuffer(usize),
|
||||
}
|
||||
|
||||
impl MioBrokerClient {
|
||||
pub fn new(socket: mio::net::UnixStream) -> Self {
|
||||
let io = MioBrokerClientIo {
|
||||
socket,
|
||||
send_buf: VecDeque::new(),
|
||||
recv_state: RxState::RxSize(0),
|
||||
recv_buf: [0u8; RECV_BUF_SIZE],
|
||||
expected_state: RxState::RxSize(LEN_SIZE),
|
||||
};
|
||||
let inner = BrokerClient::new(io);
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
fn poll(&mut self) -> anyhow::Result<Option<msgs::SetPskResult>> {
|
||||
self.inner.io_mut().flush()?;
|
||||
|
||||
// This sucks
|
||||
match self.inner.poll_response() {
|
||||
Ok(res) => {
|
||||
return Ok(res);
|
||||
}
|
||||
Err(BrokerClientPollResponseError::IoError(e)) => {
|
||||
return Err(e);
|
||||
}
|
||||
Err(BrokerClientPollResponseError::InvalidMessage) => {
|
||||
bail!("Invalid message");
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
impl WireGuardBroker for MioBrokerClient {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn set_psk<'a>(&mut self, config: SerializedBrokerConfig<'a>) -> anyhow::Result<()> {
|
||||
use BrokerClientSetPskError::*;
|
||||
let e = self.inner.set_psk(config);
|
||||
match e {
|
||||
Ok(()) => Ok(()),
|
||||
Err(IoError(e)) => Err(e),
|
||||
Err(IfaceOutOfBounds) => bail!("Interface name size is out of bounds."),
|
||||
Err(MsgError) => bail!("Error with encoding/decoding message."),
|
||||
Err(BrokerError(e)) => bail!("Broker error: {:?}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WireguardBrokerMio for MioBrokerClient {
|
||||
type MioError = anyhow::Error;
|
||||
|
||||
fn register(
|
||||
&mut self,
|
||||
registry: &mio::Registry,
|
||||
token: mio::Token,
|
||||
) -> Result<(), Self::MioError> {
|
||||
registry.register(
|
||||
&mut self.inner.io_mut().socket,
|
||||
token,
|
||||
Interest::READABLE | Interest::WRITABLE,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn process_poll(&mut self) -> Result<(), Self::MioError> {
|
||||
self.poll()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn unregister(&mut self, registry: &mio::Registry) -> Result<(), Self::MioError> {
|
||||
registry.deregister(&mut self.inner.io_mut().socket)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl BrokerClientIo for MioBrokerClientIo {
|
||||
type SendError = anyhow::Error;
|
||||
type RecvError = anyhow::Error;
|
||||
|
||||
fn send_msg(&mut self, buf: &[u8]) -> Result<(), Self::SendError> {
|
||||
self.flush()?;
|
||||
self.send_or_buffer(&(buf.len() as u64).to_le_bytes())?;
|
||||
self.send_or_buffer(&buf)?;
|
||||
self.flush()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn recv_msg(&mut self) -> Result<Option<&[u8]>, Self::RecvError> {
|
||||
loop {
|
||||
match (self.recv_state, self.expected_state) {
|
||||
//Stale Buffer state or recieved everything
|
||||
(RxState::RxSize(x), RxState::RxSize(y))
|
||||
| (RxState::RxBuffer(x), RxState::RxBuffer(y))
|
||||
if x == y =>
|
||||
{
|
||||
match self.recv_state {
|
||||
RxState::RxSize(s) => {
|
||||
let len: &[u8; LEN_SIZE] = self.recv_buf[0..s].try_into().unwrap();
|
||||
let len: usize = u64::from_le_bytes(*len) as usize;
|
||||
|
||||
ensure!(
|
||||
len <= msgs::RESPONSE_MSG_BUFFER_SIZE,
|
||||
"Oversized buffer ({len}) in psk buffer response."
|
||||
);
|
||||
|
||||
self.recv_state = RxState::RxBuffer(0);
|
||||
self.expected_state = RxState::RxBuffer(len);
|
||||
continue;
|
||||
}
|
||||
RxState::RxBuffer(s) => {
|
||||
self.recv_state = RxState::RxSize(0);
|
||||
self.expected_state = RxState::RxSize(LEN_SIZE);
|
||||
return Ok(Some(&self.recv_buf[0..s]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//Recieve if x < y
|
||||
(RxState::RxSize(x), RxState::RxSize(y))
|
||||
| (RxState::RxBuffer(x), RxState::RxBuffer(y))
|
||||
if x < y =>
|
||||
{
|
||||
let bytes = raw_recv(&self.socket, &mut self.recv_buf[x..y])?;
|
||||
|
||||
if x + bytes == y {
|
||||
return Ok(Some(&self.recv_buf[0..y]));
|
||||
}
|
||||
//We didn't recieve everything so let's assume something went wrong
|
||||
self.recv_state = RxState::RxSize(0);
|
||||
self.expected_state = RxState::RxSize(LEN_SIZE);
|
||||
bail!("Invalid state");
|
||||
}
|
||||
_ => {
|
||||
//Reset states
|
||||
self.recv_state = RxState::RxSize(0);
|
||||
self.expected_state = RxState::RxSize(LEN_SIZE);
|
||||
bail!("Invalid state");
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MioBrokerClientIo {
|
||||
fn flush(&mut self) -> anyhow::Result<()> {
|
||||
let (fst, snd) = self.send_buf.as_slices();
|
||||
|
||||
let (written, res) = match raw_send(&self.socket, fst) {
|
||||
Ok(w1) if w1 >= fst.len() => match raw_send(&self.socket, snd) {
|
||||
Ok(w2) => (w1 + w2, Ok(())),
|
||||
Err(e) => (w1, Err(e)),
|
||||
},
|
||||
Ok(w1) => (w1, Ok(())),
|
||||
Err(e) => (0, Err(e)),
|
||||
};
|
||||
|
||||
self.send_buf.drain(..written);
|
||||
|
||||
(&self.socket).try_io(|| (&self.socket).flush())?;
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
fn send_or_buffer(&mut self, buf: &[u8]) -> anyhow::Result<()> {
|
||||
let mut off = 0;
|
||||
|
||||
if self.send_buf.is_empty() {
|
||||
off += raw_send(&self.socket, buf)?;
|
||||
}
|
||||
|
||||
self.send_buf.extend((&buf[off..]).iter());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn raw_send(mut socket: &mio::net::UnixStream, data: &[u8]) -> anyhow::Result<usize> {
|
||||
let mut off = 0;
|
||||
|
||||
socket.try_io(|| {
|
||||
loop {
|
||||
if off == data.len() {
|
||||
return Ok(());
|
||||
}
|
||||
match socket.write(&data[off..]) {
|
||||
Ok(n) => {
|
||||
off += n;
|
||||
}
|
||||
Err(e) if e.kind() == ErrorKind::Interrupted => {
|
||||
// pass – retry
|
||||
}
|
||||
Err(e) if off > 0 || e.kind() == ErrorKind::WouldBlock => return Ok(()),
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
return Ok(off);
|
||||
}
|
||||
|
||||
fn raw_recv(mut socket: &mio::net::UnixStream, out: &mut [u8]) -> anyhow::Result<usize> {
|
||||
let mut off = 0;
|
||||
|
||||
socket.try_io(|| {
|
||||
loop {
|
||||
if off == out.len() {
|
||||
return Ok(());
|
||||
}
|
||||
match socket.read(&mut out[off..]) {
|
||||
Ok(n) => {
|
||||
off += n;
|
||||
}
|
||||
Err(e) if e.kind() == ErrorKind::Interrupted => {
|
||||
// pass – retry
|
||||
}
|
||||
Err(e) if off > 0 || e.kind() == ErrorKind::WouldBlock => return Ok(()),
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
return Ok(off);
|
||||
}
|
||||
6
wireguard-broker/src/brokers/mod.rs
Normal file
6
wireguard-broker/src/brokers/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
#[cfg(feature = "enable_broker_api")]
|
||||
pub mod mio_client;
|
||||
#[cfg(feature = "enable_broker_api")]
|
||||
pub mod netlink;
|
||||
|
||||
pub mod native_unix;
|
||||
177
wireguard-broker/src/brokers/native_unix.rs
Normal file
177
wireguard-broker/src/brokers/native_unix.rs
Normal file
@@ -0,0 +1,177 @@
|
||||
use std::fmt::Debug;
|
||||
use std::process::{Command, Stdio};
|
||||
use std::thread;
|
||||
|
||||
use derive_builder::Builder;
|
||||
use log::{debug, error};
|
||||
use postcard::{from_bytes, to_allocvec};
|
||||
use rosenpass_secret_memory::{Public, Secret};
|
||||
use rosenpass_util::b64::b64_decode;
|
||||
use rosenpass_util::{b64::B64Display, file::StoreValueB64Writer};
|
||||
|
||||
use crate::{SerializedBrokerConfig, WireGuardBroker, WireguardBrokerCfg, WireguardBrokerMio};
|
||||
use crate::{WG_KEY_LEN, WG_PEER_LEN};
|
||||
|
||||
const MAX_B64_KEY_SIZE: usize = WG_KEY_LEN * 5 / 3;
|
||||
const MAX_B64_PEER_ID_SIZE: usize = WG_PEER_LEN * 5 / 3;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct NativeUnixBroker {}
|
||||
|
||||
impl Default for NativeUnixBroker {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl NativeUnixBroker {
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
}
|
||||
|
||||
impl WireGuardBroker for NativeUnixBroker {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn set_psk(&mut self, config: SerializedBrokerConfig<'_>) -> Result<(), Self::Error> {
|
||||
let config: NativeUnixBrokerConfig = config.try_into()?;
|
||||
|
||||
let peer_id = format!("{}", config.peer_id.fmt_b64::<MAX_B64_PEER_ID_SIZE>());
|
||||
|
||||
let mut child = match Command::new("wg")
|
||||
.arg("set")
|
||||
.arg(config.interface)
|
||||
.arg("peer")
|
||||
.arg(peer_id)
|
||||
.arg("preshared-key")
|
||||
.arg("/dev/stdin")
|
||||
.stdin(Stdio::piped())
|
||||
.args(config.extra_params)
|
||||
.spawn()
|
||||
{
|
||||
Ok(x) => x,
|
||||
Err(e) => {
|
||||
if e.kind() == std::io::ErrorKind::NotFound {
|
||||
anyhow::bail!("Could not find wg command");
|
||||
} else {
|
||||
return Err(anyhow::Error::new(e));
|
||||
}
|
||||
}
|
||||
};
|
||||
if let Err(e) = config
|
||||
.psk
|
||||
.store_b64_writer::<MAX_B64_KEY_SIZE, _>(child.stdin.take().unwrap())
|
||||
{
|
||||
error!("could not write psk to wg: {:?}", e);
|
||||
}
|
||||
|
||||
thread::spawn(move || {
|
||||
let status = child.wait();
|
||||
|
||||
if let Ok(status) = status {
|
||||
if status.success() {
|
||||
debug!("successfully passed psk to wg")
|
||||
} else {
|
||||
error!("could not pass psk to wg {:?}", status)
|
||||
}
|
||||
} else {
|
||||
error!("wait failed: {:?}", status)
|
||||
}
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl WireguardBrokerMio for NativeUnixBroker {
|
||||
type MioError = anyhow::Error;
|
||||
|
||||
fn register(
|
||||
&mut self,
|
||||
_registry: &mio::Registry,
|
||||
_token: mio::Token,
|
||||
) -> Result<(), Self::MioError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn process_poll(&mut self) -> Result<(), Self::MioError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn unregister(&mut self, _registry: &mio::Registry) -> Result<(), Self::MioError> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Builder)]
|
||||
#[builder(pattern = "mutable")]
|
||||
pub struct NativeUnixBrokerConfigBase {
|
||||
pub interface: String,
|
||||
pub peer_id: Public<WG_PEER_LEN>,
|
||||
#[builder(private)]
|
||||
pub extra_params: Vec<u8>,
|
||||
}
|
||||
|
||||
impl NativeUnixBrokerConfigBaseBuilder {
|
||||
pub fn peer_id_b64(
|
||||
&mut self,
|
||||
peer_id: &str,
|
||||
) -> Result<&mut Self, NativeUnixBrokerConfigBaseBuilderError> {
|
||||
let mut peer_id_b64 = Public::<WG_PEER_LEN>::zero();
|
||||
b64_decode(peer_id.as_bytes(), &mut peer_id_b64.value).map_err(|_e| {
|
||||
NativeUnixBrokerConfigBaseBuilderError::ValidationError(
|
||||
"Failed to parse peer id b64".to_string(),
|
||||
)
|
||||
})?;
|
||||
Ok(self.peer_id(peer_id_b64))
|
||||
}
|
||||
|
||||
pub fn extra_params_ser(
|
||||
&mut self,
|
||||
extra_params: &Vec<String>,
|
||||
) -> Result<&mut Self, NativeUnixBrokerConfigBuilderError> {
|
||||
let params = to_allocvec(extra_params).map_err(|_e| {
|
||||
NativeUnixBrokerConfigBuilderError::ValidationError(
|
||||
"Failed to parse extra params".to_string(),
|
||||
)
|
||||
})?;
|
||||
Ok(self.extra_params(params))
|
||||
}
|
||||
}
|
||||
|
||||
impl WireguardBrokerCfg for NativeUnixBrokerConfigBase {
|
||||
fn create_config<'a>(&'a self, psk: &'a Secret<WG_KEY_LEN>) -> SerializedBrokerConfig<'a> {
|
||||
SerializedBrokerConfig {
|
||||
interface: self.interface.as_bytes(),
|
||||
peer_id: &self.peer_id,
|
||||
psk,
|
||||
additional_params: &self.extra_params,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Builder)]
|
||||
#[builder(pattern = "mutable")]
|
||||
pub struct NativeUnixBrokerConfig<'a> {
|
||||
pub interface: &'a str,
|
||||
pub peer_id: &'a Public<WG_PEER_LEN>,
|
||||
pub psk: &'a Secret<WG_KEY_LEN>,
|
||||
pub extra_params: Vec<String>,
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<SerializedBrokerConfig<'a>> for NativeUnixBrokerConfig<'a> {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: SerializedBrokerConfig<'a>) -> Result<Self, Self::Error> {
|
||||
let iface = std::str::from_utf8(value.interface)
|
||||
.map_err(|_| anyhow::Error::msg("Interface UTF8 decoding error"))?;
|
||||
|
||||
let extra_params: Vec<String> =
|
||||
from_bytes(value.additional_params).map_err(anyhow::Error::new)?;
|
||||
Ok(Self {
|
||||
interface: iface,
|
||||
peer_id: value.peer_id,
|
||||
psk: value.psk,
|
||||
extra_params,
|
||||
})
|
||||
}
|
||||
}
|
||||
112
wireguard-broker/src/brokers/netlink.rs
Normal file
112
wireguard-broker/src/brokers/netlink.rs
Normal file
@@ -0,0 +1,112 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use wireguard_uapi::linux as wg;
|
||||
|
||||
use crate::api::config::NetworkBrokerConfig;
|
||||
use crate::api::msgs;
|
||||
use crate::{SerializedBrokerConfig, WireGuardBroker};
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ConnectError {
|
||||
#[error(transparent)]
|
||||
ConnectError(#[from] wg::err::ConnectError),
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum NetlinkError {
|
||||
#[error(transparent)]
|
||||
SetDevice(#[from] wg::err::SetDeviceError),
|
||||
#[error(transparent)]
|
||||
GetDevice(#[from] wg::err::GetDeviceError),
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum SetPskError {
|
||||
#[error("The indicated wireguard interface does not exist")]
|
||||
NoSuchInterface,
|
||||
#[error("The indicated peer does not exist on the wireguard interface")]
|
||||
NoSuchPeer,
|
||||
#[error(transparent)]
|
||||
NetlinkError(#[from] NetlinkError),
|
||||
}
|
||||
|
||||
impl From<wg::err::SetDeviceError> for SetPskError {
|
||||
fn from(err: wg::err::SetDeviceError) -> Self {
|
||||
NetlinkError::from(err).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<wg::err::GetDeviceError> for SetPskError {
|
||||
fn from(err: wg::err::GetDeviceError) -> Self {
|
||||
NetlinkError::from(err).into()
|
||||
}
|
||||
}
|
||||
|
||||
use msgs::SetPskError as SetPskMsgsError;
|
||||
use SetPskError as SetPskNetlinkError;
|
||||
impl From<SetPskNetlinkError> for SetPskMsgsError {
|
||||
fn from(err: SetPskError) -> Self {
|
||||
match err {
|
||||
SetPskNetlinkError::NoSuchPeer => SetPskMsgsError::NoSuchPeer,
|
||||
_ => SetPskMsgsError::InternalError,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NetlinkWireGuardBroker {
|
||||
sock: wg::WgSocket,
|
||||
}
|
||||
|
||||
impl NetlinkWireGuardBroker {
|
||||
pub fn new() -> Result<Self, ConnectError> {
|
||||
let sock = wg::WgSocket::connect()?;
|
||||
Ok(Self { sock })
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for NetlinkWireGuardBroker {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
//TODO: Add useful info in Debug
|
||||
f.debug_struct("NetlinkWireGuardBroker").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl WireGuardBroker for NetlinkWireGuardBroker {
|
||||
type Error = SetPskError;
|
||||
|
||||
fn set_psk(&mut self, config: SerializedBrokerConfig) -> Result<(), Self::Error> {
|
||||
let config: NetworkBrokerConfig = config
|
||||
.try_into()
|
||||
.map_err(|e| SetPskError::NoSuchInterface)?;
|
||||
// Ensure that the peer exists by querying the device configuration
|
||||
// TODO: Use InvalidInterfaceError
|
||||
|
||||
let state = self
|
||||
.sock
|
||||
.get_device(wg::DeviceInterface::from_name(config.iface))?;
|
||||
|
||||
if state
|
||||
.peers
|
||||
.iter()
|
||||
.find(|p| &p.public_key == &config.peer_id.value)
|
||||
.is_none()
|
||||
{
|
||||
return Err(SetPskError::NoSuchPeer);
|
||||
}
|
||||
|
||||
// Peer update description
|
||||
let mut set_peer = wireguard_uapi::set::Peer::from_public_key(&config.peer_id);
|
||||
set_peer
|
||||
.flags
|
||||
.push(wireguard_uapi::linux::set::WgPeerF::UpdateOnly);
|
||||
set_peer.preshared_key = Some(&config.psk.secret());
|
||||
|
||||
// Device update description
|
||||
let mut set_dev = wireguard_uapi::set::Device::from_ifname(config.iface);
|
||||
set_dev.peers.push(set_peer);
|
||||
|
||||
self.sock.set_device(set_dev)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user