feat(rosenpass): Add wireguard-broker interface in AppServer (#303)

Dynamically dispatch WireguardBrokerMio trait in AppServer. Also allows for mio event registration and poll processing, logic from dev/broker-architecture branch

Co-authored-by: Prabhpreet Dua <615318+prabhpreet@users.noreply.github.com>
Co-authored-by: Karolin Varner <karo@cupdev.net>
This commit is contained in:
Prabhpreet Dua
2024-05-20 18:12:42 +05:30
committed by GitHub
parent ae7577c641
commit c1abfbfd14
26 changed files with 693 additions and 177 deletions

View File

@@ -0,0 +1,259 @@
use anyhow::{bail, ensure};
use mio::Interest;
use std::collections::VecDeque;
use std::io::{ErrorKind, Read, Write};
use crate::{SerializedBrokerConfig, WireGuardBroker, WireguardBrokerMio};
use crate::api::client::{
BrokerClient, BrokerClientIo, BrokerClientPollResponseError, BrokerClientSetPskError,
};
use crate::api::msgs::{self, RESPONSE_MSG_BUFFER_SIZE};
#[derive(Debug)]
pub struct MioBrokerClient {
inner: BrokerClient<MioBrokerClientIo>,
}
const LEN_SIZE: usize = 8;
const RECV_BUF_SIZE: usize = RESPONSE_MSG_BUFFER_SIZE;
#[derive(Debug)]
struct MioBrokerClientIo {
socket: mio::net::UnixStream,
send_buf: VecDeque<u8>,
recv_state: RxState,
expected_state: RxState,
recv_buf: [u8; RECV_BUF_SIZE],
}
#[derive(Debug, Clone, Copy)]
enum RxState {
//Recieving size with buffer offset
RxSize(usize),
RxBuffer(usize),
}
impl MioBrokerClient {
pub fn new(socket: mio::net::UnixStream) -> Self {
let io = MioBrokerClientIo {
socket,
send_buf: VecDeque::new(),
recv_state: RxState::RxSize(0),
recv_buf: [0u8; RECV_BUF_SIZE],
expected_state: RxState::RxSize(LEN_SIZE),
};
let inner = BrokerClient::new(io);
Self { inner }
}
fn poll(&mut self) -> anyhow::Result<Option<msgs::SetPskResult>> {
self.inner.io_mut().flush()?;
// This sucks
match self.inner.poll_response() {
Ok(res) => {
return Ok(res);
}
Err(BrokerClientPollResponseError::IoError(e)) => {
return Err(e);
}
Err(BrokerClientPollResponseError::InvalidMessage) => {
bail!("Invalid message");
}
};
}
}
impl WireGuardBroker for MioBrokerClient {
type Error = anyhow::Error;
fn set_psk<'a>(&mut self, config: SerializedBrokerConfig<'a>) -> anyhow::Result<()> {
use BrokerClientSetPskError::*;
let e = self.inner.set_psk(config);
match e {
Ok(()) => Ok(()),
Err(IoError(e)) => Err(e),
Err(IfaceOutOfBounds) => bail!("Interface name size is out of bounds."),
Err(MsgError) => bail!("Error with encoding/decoding message."),
Err(BrokerError(e)) => bail!("Broker error: {:?}", e),
}
}
}
impl WireguardBrokerMio for MioBrokerClient {
type MioError = anyhow::Error;
fn register(
&mut self,
registry: &mio::Registry,
token: mio::Token,
) -> Result<(), Self::MioError> {
registry.register(
&mut self.inner.io_mut().socket,
token,
Interest::READABLE | Interest::WRITABLE,
)?;
Ok(())
}
fn process_poll(&mut self) -> Result<(), Self::MioError> {
self.poll()?;
Ok(())
}
fn unregister(&mut self, registry: &mio::Registry) -> Result<(), Self::MioError> {
registry.deregister(&mut self.inner.io_mut().socket)?;
Ok(())
}
}
impl BrokerClientIo for MioBrokerClientIo {
type SendError = anyhow::Error;
type RecvError = anyhow::Error;
fn send_msg(&mut self, buf: &[u8]) -> Result<(), Self::SendError> {
self.flush()?;
self.send_or_buffer(&(buf.len() as u64).to_le_bytes())?;
self.send_or_buffer(&buf)?;
self.flush()?;
Ok(())
}
fn recv_msg(&mut self) -> Result<Option<&[u8]>, Self::RecvError> {
loop {
match (self.recv_state, self.expected_state) {
//Stale Buffer state or recieved everything
(RxState::RxSize(x), RxState::RxSize(y))
| (RxState::RxBuffer(x), RxState::RxBuffer(y))
if x == y =>
{
match self.recv_state {
RxState::RxSize(s) => {
let len: &[u8; LEN_SIZE] = self.recv_buf[0..s].try_into().unwrap();
let len: usize = u64::from_le_bytes(*len) as usize;
ensure!(
len <= msgs::RESPONSE_MSG_BUFFER_SIZE,
"Oversized buffer ({len}) in psk buffer response."
);
self.recv_state = RxState::RxBuffer(0);
self.expected_state = RxState::RxBuffer(len);
continue;
}
RxState::RxBuffer(s) => {
self.recv_state = RxState::RxSize(0);
self.expected_state = RxState::RxSize(LEN_SIZE);
return Ok(Some(&self.recv_buf[0..s]));
}
}
}
//Recieve if x < y
(RxState::RxSize(x), RxState::RxSize(y))
| (RxState::RxBuffer(x), RxState::RxBuffer(y))
if x < y =>
{
let bytes = raw_recv(&self.socket, &mut self.recv_buf[x..y])?;
if x + bytes == y {
return Ok(Some(&self.recv_buf[0..y]));
}
//We didn't recieve everything so let's assume something went wrong
self.recv_state = RxState::RxSize(0);
self.expected_state = RxState::RxSize(LEN_SIZE);
bail!("Invalid state");
}
_ => {
//Reset states
self.recv_state = RxState::RxSize(0);
self.expected_state = RxState::RxSize(LEN_SIZE);
bail!("Invalid state");
}
};
}
}
}
impl MioBrokerClientIo {
fn flush(&mut self) -> anyhow::Result<()> {
let (fst, snd) = self.send_buf.as_slices();
let (written, res) = match raw_send(&self.socket, fst) {
Ok(w1) if w1 >= fst.len() => match raw_send(&self.socket, snd) {
Ok(w2) => (w1 + w2, Ok(())),
Err(e) => (w1, Err(e)),
},
Ok(w1) => (w1, Ok(())),
Err(e) => (0, Err(e)),
};
self.send_buf.drain(..written);
(&self.socket).try_io(|| (&self.socket).flush())?;
res
}
fn send_or_buffer(&mut self, buf: &[u8]) -> anyhow::Result<()> {
let mut off = 0;
if self.send_buf.is_empty() {
off += raw_send(&self.socket, buf)?;
}
self.send_buf.extend((&buf[off..]).iter());
Ok(())
}
}
fn raw_send(mut socket: &mio::net::UnixStream, data: &[u8]) -> anyhow::Result<usize> {
let mut off = 0;
socket.try_io(|| {
loop {
if off == data.len() {
return Ok(());
}
match socket.write(&data[off..]) {
Ok(n) => {
off += n;
}
Err(e) if e.kind() == ErrorKind::Interrupted => {
// pass retry
}
Err(e) if off > 0 || e.kind() == ErrorKind::WouldBlock => return Ok(()),
Err(e) => return Err(e),
}
}
})?;
return Ok(off);
}
fn raw_recv(mut socket: &mio::net::UnixStream, out: &mut [u8]) -> anyhow::Result<usize> {
let mut off = 0;
socket.try_io(|| {
loop {
if off == out.len() {
return Ok(());
}
match socket.read(&mut out[off..]) {
Ok(n) => {
off += n;
}
Err(e) if e.kind() == ErrorKind::Interrupted => {
// pass retry
}
Err(e) if off > 0 || e.kind() == ErrorKind::WouldBlock => return Ok(()),
Err(e) => return Err(e),
}
}
})?;
return Ok(off);
}

View File

@@ -0,0 +1,6 @@
#[cfg(feature = "enable_broker_api")]
pub mod mio_client;
#[cfg(feature = "enable_broker_api")]
pub mod netlink;
pub mod native_unix;

View File

@@ -0,0 +1,177 @@
use std::fmt::Debug;
use std::process::{Command, Stdio};
use std::thread;
use derive_builder::Builder;
use log::{debug, error};
use postcard::{from_bytes, to_allocvec};
use rosenpass_secret_memory::{Public, Secret};
use rosenpass_util::b64::b64_decode;
use rosenpass_util::{b64::B64Display, file::StoreValueB64Writer};
use crate::{SerializedBrokerConfig, WireGuardBroker, WireguardBrokerCfg, WireguardBrokerMio};
use crate::{WG_KEY_LEN, WG_PEER_LEN};
const MAX_B64_KEY_SIZE: usize = WG_KEY_LEN * 5 / 3;
const MAX_B64_PEER_ID_SIZE: usize = WG_PEER_LEN * 5 / 3;
#[derive(Debug)]
pub struct NativeUnixBroker {}
impl Default for NativeUnixBroker {
fn default() -> Self {
Self::new()
}
}
impl NativeUnixBroker {
pub fn new() -> Self {
Self {}
}
}
impl WireGuardBroker for NativeUnixBroker {
type Error = anyhow::Error;
fn set_psk(&mut self, config: SerializedBrokerConfig<'_>) -> Result<(), Self::Error> {
let config: NativeUnixBrokerConfig = config.try_into()?;
let peer_id = format!("{}", config.peer_id.fmt_b64::<MAX_B64_PEER_ID_SIZE>());
let mut child = match Command::new("wg")
.arg("set")
.arg(config.interface)
.arg("peer")
.arg(peer_id)
.arg("preshared-key")
.arg("/dev/stdin")
.stdin(Stdio::piped())
.args(config.extra_params)
.spawn()
{
Ok(x) => x,
Err(e) => {
if e.kind() == std::io::ErrorKind::NotFound {
anyhow::bail!("Could not find wg command");
} else {
return Err(anyhow::Error::new(e));
}
}
};
if let Err(e) = config
.psk
.store_b64_writer::<MAX_B64_KEY_SIZE, _>(child.stdin.take().unwrap())
{
error!("could not write psk to wg: {:?}", e);
}
thread::spawn(move || {
let status = child.wait();
if let Ok(status) = status {
if status.success() {
debug!("successfully passed psk to wg")
} else {
error!("could not pass psk to wg {:?}", status)
}
} else {
error!("wait failed: {:?}", status)
}
});
Ok(())
}
}
impl WireguardBrokerMio for NativeUnixBroker {
type MioError = anyhow::Error;
fn register(
&mut self,
_registry: &mio::Registry,
_token: mio::Token,
) -> Result<(), Self::MioError> {
Ok(())
}
fn process_poll(&mut self) -> Result<(), Self::MioError> {
Ok(())
}
fn unregister(&mut self, _registry: &mio::Registry) -> Result<(), Self::MioError> {
Ok(())
}
}
#[derive(Debug, Builder)]
#[builder(pattern = "mutable")]
pub struct NativeUnixBrokerConfigBase {
pub interface: String,
pub peer_id: Public<WG_PEER_LEN>,
#[builder(private)]
pub extra_params: Vec<u8>,
}
impl NativeUnixBrokerConfigBaseBuilder {
pub fn peer_id_b64(
&mut self,
peer_id: &str,
) -> Result<&mut Self, NativeUnixBrokerConfigBaseBuilderError> {
let mut peer_id_b64 = Public::<WG_PEER_LEN>::zero();
b64_decode(peer_id.as_bytes(), &mut peer_id_b64.value).map_err(|_e| {
NativeUnixBrokerConfigBaseBuilderError::ValidationError(
"Failed to parse peer id b64".to_string(),
)
})?;
Ok(self.peer_id(peer_id_b64))
}
pub fn extra_params_ser(
&mut self,
extra_params: &Vec<String>,
) -> Result<&mut Self, NativeUnixBrokerConfigBuilderError> {
let params = to_allocvec(extra_params).map_err(|_e| {
NativeUnixBrokerConfigBuilderError::ValidationError(
"Failed to parse extra params".to_string(),
)
})?;
Ok(self.extra_params(params))
}
}
impl WireguardBrokerCfg for NativeUnixBrokerConfigBase {
fn create_config<'a>(&'a self, psk: &'a Secret<WG_KEY_LEN>) -> SerializedBrokerConfig<'a> {
SerializedBrokerConfig {
interface: self.interface.as_bytes(),
peer_id: &self.peer_id,
psk,
additional_params: &self.extra_params,
}
}
}
#[derive(Debug, Builder)]
#[builder(pattern = "mutable")]
pub struct NativeUnixBrokerConfig<'a> {
pub interface: &'a str,
pub peer_id: &'a Public<WG_PEER_LEN>,
pub psk: &'a Secret<WG_KEY_LEN>,
pub extra_params: Vec<String>,
}
impl<'a> TryFrom<SerializedBrokerConfig<'a>> for NativeUnixBrokerConfig<'a> {
type Error = anyhow::Error;
fn try_from(value: SerializedBrokerConfig<'a>) -> Result<Self, Self::Error> {
let iface = std::str::from_utf8(value.interface)
.map_err(|_| anyhow::Error::msg("Interface UTF8 decoding error"))?;
let extra_params: Vec<String> =
from_bytes(value.additional_params).map_err(anyhow::Error::new)?;
Ok(Self {
interface: iface,
peer_id: value.peer_id,
psk: value.psk,
extra_params,
})
}
}

View File

@@ -0,0 +1,112 @@
use std::fmt::Debug;
use wireguard_uapi::linux as wg;
use crate::api::config::NetworkBrokerConfig;
use crate::api::msgs;
use crate::{SerializedBrokerConfig, WireGuardBroker};
#[derive(thiserror::Error, Debug)]
pub enum ConnectError {
#[error(transparent)]
ConnectError(#[from] wg::err::ConnectError),
}
#[derive(thiserror::Error, Debug)]
pub enum NetlinkError {
#[error(transparent)]
SetDevice(#[from] wg::err::SetDeviceError),
#[error(transparent)]
GetDevice(#[from] wg::err::GetDeviceError),
}
#[derive(thiserror::Error, Debug)]
pub enum SetPskError {
#[error("The indicated wireguard interface does not exist")]
NoSuchInterface,
#[error("The indicated peer does not exist on the wireguard interface")]
NoSuchPeer,
#[error(transparent)]
NetlinkError(#[from] NetlinkError),
}
impl From<wg::err::SetDeviceError> for SetPskError {
fn from(err: wg::err::SetDeviceError) -> Self {
NetlinkError::from(err).into()
}
}
impl From<wg::err::GetDeviceError> for SetPskError {
fn from(err: wg::err::GetDeviceError) -> Self {
NetlinkError::from(err).into()
}
}
use msgs::SetPskError as SetPskMsgsError;
use SetPskError as SetPskNetlinkError;
impl From<SetPskNetlinkError> for SetPskMsgsError {
fn from(err: SetPskError) -> Self {
match err {
SetPskNetlinkError::NoSuchPeer => SetPskMsgsError::NoSuchPeer,
_ => SetPskMsgsError::InternalError,
}
}
}
pub struct NetlinkWireGuardBroker {
sock: wg::WgSocket,
}
impl NetlinkWireGuardBroker {
pub fn new() -> Result<Self, ConnectError> {
let sock = wg::WgSocket::connect()?;
Ok(Self { sock })
}
}
impl Debug for NetlinkWireGuardBroker {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
//TODO: Add useful info in Debug
f.debug_struct("NetlinkWireGuardBroker").finish()
}
}
impl WireGuardBroker for NetlinkWireGuardBroker {
type Error = SetPskError;
fn set_psk(&mut self, config: SerializedBrokerConfig) -> Result<(), Self::Error> {
let config: NetworkBrokerConfig = config
.try_into()
.map_err(|e| SetPskError::NoSuchInterface)?;
// Ensure that the peer exists by querying the device configuration
// TODO: Use InvalidInterfaceError
let state = self
.sock
.get_device(wg::DeviceInterface::from_name(config.iface))?;
if state
.peers
.iter()
.find(|p| &p.public_key == &config.peer_id.value)
.is_none()
{
return Err(SetPskError::NoSuchPeer);
}
// Peer update description
let mut set_peer = wireguard_uapi::set::Peer::from_public_key(&config.peer_id);
set_peer
.flags
.push(wireguard_uapi::linux::set::WgPeerF::UpdateOnly);
set_peer.preshared_key = Some(&config.psk.secret());
// Device update description
let mut set_dev = wireguard_uapi::set::Device::from_ifname(config.iface);
set_dev.peers.push(set_peer);
self.sock.set_device(set_dev)?;
Ok(())
}
}