feat: First version of broker based WireGuard PSK interface

This allows us to run with minimal priviledges in the Rosenpass process itself
This commit is contained in:
Karolin Varner
2023-12-09 19:42:15 +01:00
parent 3a0ebd2cbc
commit f3590645e9
19 changed files with 1478 additions and 76 deletions

View File

@@ -0,0 +1,42 @@
[package]
name = "rosenpass-wireguard-broker"
authors = ["Karolin Varner <karo@cupdev.net>", "wucke13 <wucke13@gmail.com>"]
version = "0.1.0"
edition = "2021"
license = "MIT OR Apache-2.0"
description = "Rosenpass internal broker that runs as root and supplies exchanged keys to the kernel."
homepage = "https://rosenpass.eu/"
repository = "https://github.com/rosenpass/rosenpass"
readme = "readme.md"
[dependencies]
thiserror = { workspace = true }
rosenpass-lenses = { workspace = true }
paste = { workspace = true } # TODO: Using lenses should not necessitate importing paste
# Privileged only
wireguard-uapi = { workspace = true }
# Socket handler only
rosenpass-to = { workspace = true }
tokio = { workspace = true }
anyhow = { workspace = true }
clap = { workspace = true }
env_logger = { workspace = true }
log = { workspace = true }
# Mio broker client
mio = { workspace = true }
rosenpass-util = { workspace = true }
[[bin]]
name = "rosenpass-wireguard-broker-privileged"
path = "src/bin/priviledged.rs"
test = false
doc = false
[[bin]]
name = "rosenpass-wireguard-broker-socket-handler"
test = false
path = "src/bin/socket_handler.rs"
doc = false

View File

@@ -0,0 +1,5 @@
# Rosenpass internal broker supplying WireGuard with keys.
This crate contains a small application purpose-built to supply WireGuard in the linux kernel with pre-shared keys.
This is an internal library; not guarantee is made about its API at this point in time.

View File

@@ -0,0 +1,152 @@
use std::{borrow::BorrowMut, marker::PhantomData};
use rosenpass_lenses::LenseView;
use crate::{
api::msgs::{self, EnvelopeExt, SetPskRequestExt, SetPskResponseExt},
WireGuardBroker,
};
#[derive(thiserror::Error, Debug, Clone, Eq, PartialEq)]
pub enum BrokerClientPollResponseError<RecvError> {
#[error(transparent)]
IoError(RecvError),
#[error("Invalid message.")]
InvalidMessage,
}
impl<RecvError> From<msgs::InvalidMessageTypeError> for BrokerClientPollResponseError<RecvError> {
fn from(value: msgs::InvalidMessageTypeError) -> Self {
let msgs::InvalidMessageTypeError = value; // Assert that this is a unit type
BrokerClientPollResponseError::<RecvError>::InvalidMessage
}
}
fn io_pollerr<RecvError>(e: RecvError) -> BrokerClientPollResponseError<RecvError> {
BrokerClientPollResponseError::<RecvError>::IoError(e)
}
fn invalid_msg_pollerr<RecvError>() -> BrokerClientPollResponseError<RecvError> {
BrokerClientPollResponseError::<RecvError>::InvalidMessage
}
#[derive(thiserror::Error, Debug, Clone, Eq, PartialEq)]
pub enum BrokerClientSetPskError<SendError> {
#[error(transparent)]
IoError(SendError),
#[error("Interface name out of bounds")]
IfaceOutOfBounds,
}
pub trait BrokerClientIo {
type SendError;
type RecvError;
fn send_msg(&mut self, buf: &[u8]) -> Result<(), Self::SendError>;
fn recv_msg(&mut self) -> Result<Option<&[u8]>, Self::RecvError>;
}
#[derive(Debug)]
pub struct BrokerClient<'a, Io, IoRef>
where
Io: BrokerClientIo,
IoRef: 'a + BorrowMut<Io>,
{
io: IoRef,
_phantom_io: PhantomData<&'a mut Io>,
}
impl<'a, Io, IoRef> BrokerClient<'a, Io, IoRef>
where
Io: BrokerClientIo,
IoRef: 'a + BorrowMut<Io>,
{
pub fn new(io: IoRef) -> Self {
Self {
io,
_phantom_io: PhantomData,
}
}
pub fn io(&self) -> &IoRef {
&self.io
}
pub fn io_mut(&mut self) -> &mut IoRef {
&mut self.io
}
pub fn poll_response(
&mut self,
) -> Result<Option<msgs::SetPskResult>, BrokerClientPollResponseError<Io::RecvError>> {
let res: &[u8] = match self.io.borrow_mut().recv_msg().map_err(io_pollerr)? {
Some(r) => r,
None => return Ok(None),
};
let typ = res.get(0).ok_or(invalid_msg_pollerr())?;
let typ = msgs::MsgType::try_from(*typ)?;
let msgs::MsgType::SetPsk = typ; // Assert type
let res: msgs::Envelope<_, msgs::SetPskResponse<&[u8]>> = res
.envelope_truncating()
.map_err(|_| invalid_msg_pollerr())?;
let res: msgs::SetPskResponse<&[u8]> = res
.payload()
.set_psk_response()
.map_err(|_| invalid_msg_pollerr())?;
let res: msgs::SetPskResponseReturnCode = res.return_code()[0]
.try_into()
.map_err(|_| invalid_msg_pollerr())?;
let res: msgs::SetPskResult = res.into();
Ok(Some(res))
}
}
impl<'a, Io, IoRef> WireGuardBroker for BrokerClient<'a, Io, IoRef>
where
Io: BrokerClientIo,
IoRef: 'a + BorrowMut<Io>,
{
type Error = BrokerClientSetPskError<Io::SendError>;
fn set_psk(
&mut self,
iface: &str,
peer_id: [u8; 32],
psk: [u8; 32],
) -> Result<(), Self::Error> {
use BrokerClientSetPskError::*;
const BUF_SIZE: usize = <msgs::Envelope<(), msgs::SetPskRequest<()>> as LenseView>::LEN;
// Allocate message
let mut req = [0u8; BUF_SIZE];
// Construct message view
let mut req: msgs::Envelope<_, msgs::SetPskRequest<&mut [u8]>> =
(&mut req as &mut [u8]).envelope_truncating().unwrap();
// Populate envelope
req.msg_type_mut()
.copy_from_slice(&[msgs::MsgType::SetPsk as u8]);
{
// Derived payload
let mut req: msgs::SetPskRequest<&mut [u8]> =
req.payload_mut().set_psk_request().unwrap();
// Populate payload
req.peer_id_mut().copy_from_slice(&peer_id);
req.psk_mut().copy_from_slice(&psk);
req.set_iface(iface).ok_or(IfaceOutOfBounds)?;
}
// Send message
self.io
.borrow_mut()
.send_msg(req.all_bytes())
.map_err(IoError)?;
Ok(())
}
}

View File

@@ -0,0 +1,204 @@
use std::collections::VecDeque;
use std::io::{ErrorKind, Read, Write};
use anyhow::{bail, ensure};
use crate::WireGuardBroker;
use super::client::{
BrokerClient, BrokerClientIo, BrokerClientPollResponseError, BrokerClientSetPskError,
};
use super::msgs;
#[derive(Debug)]
pub struct MioBrokerClient {
inner: BrokerClient<'static, MioBrokerClientIo, MioBrokerClientIo>,
}
#[derive(Debug)]
struct MioBrokerClientIo {
socket: mio::net::UnixStream,
send_buf: VecDeque<u8>,
receiving_size: bool,
recv_buf: Vec<u8>,
recv_off: usize,
}
impl MioBrokerClient {
pub fn new(socket: mio::net::UnixStream) -> Self {
let io = MioBrokerClientIo {
socket,
send_buf: VecDeque::new(),
receiving_size: false,
recv_buf: Vec::new(),
recv_off: 0,
};
let inner = BrokerClient::new(io);
Self { inner }
}
pub fn poll(&mut self) -> anyhow::Result<Option<msgs::SetPskResult>> {
self.inner.io_mut().flush()?;
// This sucks
match self.inner.poll_response() {
Ok(res) => {
return Ok(res);
}
Err(BrokerClientPollResponseError::IoError(e)) => {
return Err(e);
}
Err(BrokerClientPollResponseError::InvalidMessage) => {
bail!("Invalid message");
}
};
}
}
impl WireGuardBroker for MioBrokerClient {
type Error = anyhow::Error;
fn set_psk(&mut self, iface: &str, peer_id: [u8; 32], psk: [u8; 32]) -> anyhow::Result<()> {
use BrokerClientSetPskError::*;
let e = self.inner.set_psk(iface, peer_id, psk);
match e {
Ok(()) => Ok(()),
Err(IoError(e)) => Err(e),
Err(IfaceOutOfBounds) => bail!("Interface name size is out of bounds."),
}
}
}
impl BrokerClientIo for MioBrokerClientIo {
type SendError = anyhow::Error;
type RecvError = anyhow::Error;
fn send_msg(&mut self, buf: &[u8]) -> Result<(), Self::SendError> {
self.flush()?;
self.send_or_buffer(&(buf.len() as u64).to_le_bytes())?;
self.send_or_buffer(&buf)?;
self.flush()?;
Ok(())
}
fn recv_msg(&mut self) -> Result<Option<&[u8]>, Self::RecvError> {
// Stale message in receive buffer. Reset!
if self.recv_off == self.recv_buf.len() {
self.receiving_size = true;
self.recv_off = 0;
self.recv_buf.resize(8, 0);
}
// Try filling the receive buffer
self.recv_off += raw_recv(&self.socket, &mut self.recv_buf[self.recv_off..])?;
if self.recv_off < self.recv_buf.len() {
return Ok(None);
}
// Received size, now start receiving
if self.receiving_size {
// Received the size
// Parse the received length
let len: &[u8; 8] = self.recv_buf[..].try_into().unwrap();
let len: usize = u64::from_le_bytes(*len) as usize;
ensure!(
len <= msgs::RESPONSE_MSG_BUFFER_SIZE,
"Oversized buffer ({len}) in psk buffer response."
);
// Prepare the message buffer for receiving an actual message of the given size
self.receiving_size = false;
self.recv_off = 0;
self.recv_buf.resize(len, 0);
// Try to receive the message
return self.recv_msg();
}
// Received an actual message
return Ok(Some(&self.recv_buf[..]));
}
}
impl MioBrokerClientIo {
fn flush(&mut self) -> anyhow::Result<()> {
let (fst, snd) = self.send_buf.as_slices();
let (written, res) = match raw_send(&self.socket, fst) {
Ok(w1) if w1 >= fst.len() => match raw_send(&self.socket, snd) {
Ok(w2) => (w1 + w2, Ok(())),
Err(e) => (w1, Err(e)),
},
Ok(w1) => (w1, Ok(())),
Err(e) => (0, Err(e)),
};
self.send_buf.drain(..written);
(&self.socket).try_io(|| (&self.socket).flush())?;
res
}
fn send_or_buffer(&mut self, buf: &[u8]) -> anyhow::Result<()> {
let mut off = 0;
if self.send_buf.is_empty() {
off += raw_send(&self.socket, buf)?;
}
self.send_buf.extend((&buf[off..]).iter());
Ok(())
}
}
fn raw_send(mut socket: &mio::net::UnixStream, data: &[u8]) -> anyhow::Result<usize> {
let mut off = 0;
socket.try_io(|| {
loop {
if off == data.len() {
return Ok(());
}
match socket.write(&data[off..]) {
Ok(n) => {
off += n;
}
Err(e) if e.kind() == ErrorKind::Interrupted => {
// pass retry
}
Err(e) if off > 0 || e.kind() == ErrorKind::WouldBlock => return Ok(()),
Err(e) => return Err(e),
}
}
})?;
return Ok(off);
}
fn raw_recv(mut socket: &mio::net::UnixStream, out: &mut [u8]) -> anyhow::Result<usize> {
let mut off = 0;
socket.try_io(|| {
loop {
if off == out.len() {
return Ok(());
}
match socket.read(&mut out[off..]) {
Ok(n) => {
off += n;
}
Err(e) if e.kind() == ErrorKind::Interrupted => {
// pass retry
}
Err(e) if off > 0 || e.kind() == ErrorKind::WouldBlock => return Ok(()),
Err(e) => return Err(e),
}
}
})?;
return Ok(off);
}

View File

@@ -0,0 +1,4 @@
pub mod client;
pub mod mio_client;
pub mod msgs;
pub mod server;

View File

@@ -0,0 +1,140 @@
use std::result::Result;
use std::str::{from_utf8, Utf8Error};
use rosenpass_lenses::{lense, LenseView};
pub const REQUEST_MSG_BUFFER_SIZE: usize = <Envelope<(), SetPskRequest<()>> as LenseView>::LEN;
pub const RESPONSE_MSG_BUFFER_SIZE: usize = <Envelope<(), SetPskResponse<()>> as LenseView>::LEN;
lense! { Envelope<M> :=
/// [MsgType] of this message
msg_type: 1,
/// Reserved for future use
reserved: 3,
/// The actual Paylod
payload: M::LEN
}
lense! { SetPskRequest :=
peer_id: 32,
psk: 32,
iface_size: 1, // TODO: We should have variable length strings in lenses
iface_buf: 255
}
impl SetPskRequest<&[u8]> {
pub fn iface_bin(&self) -> &[u8] {
let len = self.iface_size()[0] as usize;
&self.iface_buf()[..len]
}
pub fn iface(&self) -> Result<&str, Utf8Error> {
from_utf8(self.iface_bin())
}
}
impl SetPskRequest<&mut [u8]> {
pub fn set_iface_bin(&mut self, iface: &[u8]) -> Option<()> {
(iface.len() < 256).then_some(())?; // Assert iface.len() < 256
self.iface_size_mut()[0] = iface.len() as u8;
self.iface_buf_mut().fill(0);
(&mut self.iface_buf_mut()[..iface.len()]).copy_from_slice(iface);
Some(())
}
pub fn set_iface(&mut self, iface: &str) -> Option<()> {
self.set_iface_bin(iface.as_bytes())
}
}
lense! { SetPskResponse :=
return_code: 1
}
#[derive(thiserror::Error, Debug, Clone, Eq, PartialEq)]
pub enum SetPskError {
#[error("The wireguard pre-shared-key assignment broker experienced an internal error.")]
InternalError,
#[error("The indicated wireguard interface does not exist")]
NoSuchInterface,
#[error("The indicated peer does not exist on the wireguard interface")]
NoSuchPeer,
}
pub type SetPskResult = Result<(), SetPskError>;
#[repr(u8)]
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Copy)]
pub enum SetPskResponseReturnCode {
Success = 0x00,
InternalError = 0x01,
NoSuchInterface = 0x02,
NoSuchPeer = 0x03,
}
#[derive(Eq, PartialEq, Debug, Clone)]
pub struct InvalidSetPskResponseError;
impl TryFrom<u8> for SetPskResponseReturnCode {
type Error = InvalidSetPskResponseError;
fn try_from(value: u8) -> Result<Self, Self::Error> {
use SetPskResponseReturnCode::*;
match value {
0x00 => Ok(Success),
0x01 => Ok(InternalError),
0x02 => Ok(NoSuchInterface),
0x03 => Ok(NoSuchPeer),
_ => Err(InvalidSetPskResponseError),
}
}
}
impl From<SetPskResponseReturnCode> for SetPskResult {
fn from(value: SetPskResponseReturnCode) -> Self {
use SetPskError as E;
use SetPskResponseReturnCode as C;
match value {
C::Success => Ok(()),
C::InternalError => Err(E::InternalError),
C::NoSuchInterface => Err(E::NoSuchInterface),
C::NoSuchPeer => Err(E::NoSuchPeer),
}
}
}
impl From<SetPskResult> for SetPskResponseReturnCode {
fn from(value: SetPskResult) -> Self {
use SetPskError as E;
use SetPskResponseReturnCode as C;
match value {
Ok(()) => C::Success,
Err(E::InternalError) => C::InternalError,
Err(E::NoSuchInterface) => C::NoSuchInterface,
Err(E::NoSuchPeer) => C::NoSuchPeer,
}
}
}
#[repr(u8)]
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Copy)]
pub enum MsgType {
SetPsk = 0x01,
}
#[derive(Eq, PartialEq, Debug, Clone)]
pub struct InvalidMessageTypeError;
impl TryFrom<u8> for MsgType {
type Error = InvalidMessageTypeError;
fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
0x01 => Ok(MsgType::SetPsk),
_ => Err(InvalidMessageTypeError),
}
}
}

View File

@@ -0,0 +1,99 @@
use std::borrow::BorrowMut;
use std::marker::PhantomData;
use std::result::Result;
use rosenpass_lenses::LenseError;
use crate::api::msgs::{self, EnvelopeExt, SetPskRequestExt, SetPskResponseExt};
use crate::WireGuardBroker;
#[derive(thiserror::Error, Debug, Clone, Eq, PartialEq)]
pub enum BrokerServerError {
#[error("No such request type: {}", .0)]
NoSuchRequestType(u8),
#[error("Invalid message received.")]
InvalidMessage,
}
impl From<LenseError> for BrokerServerError {
fn from(value: LenseError) -> Self {
use BrokerServerError as Be;
use LenseError as Le;
match value {
Le::BufferSizeMismatch => Be::InvalidMessage,
}
}
}
impl From<msgs::InvalidMessageTypeError> for BrokerServerError {
fn from(value: msgs::InvalidMessageTypeError) -> Self {
let msgs::InvalidMessageTypeError = value; // Assert that this is a unit type
BrokerServerError::InvalidMessage
}
}
pub struct BrokerServer<'a, Err, Inner, Ref>
where
msgs::SetPskError: From<Err>,
Inner: WireGuardBroker<Error = Err>,
Ref: BorrowMut<Inner> + 'a,
{
inner: Ref,
_phantom: PhantomData<&'a mut Inner>,
}
impl<'a, Err, Inner, Ref> BrokerServer<'a, Err, Inner, Ref>
where
msgs::SetPskError: From<Err>,
Inner: WireGuardBroker<Error = Err>,
Ref: 'a + BorrowMut<Inner>,
{
pub fn new(inner: Ref) -> Self {
Self {
inner,
_phantom: PhantomData,
}
}
pub fn handle_message(
&mut self,
req: &[u8],
res: &mut [u8; msgs::RESPONSE_MSG_BUFFER_SIZE],
) -> Result<usize, BrokerServerError> {
use BrokerServerError::*;
let typ = req.get(0).ok_or(InvalidMessage)?;
let typ = msgs::MsgType::try_from(*typ)?;
let msgs::MsgType::SetPsk = typ; // Assert type
let req: msgs::Envelope<_, msgs::SetPskRequest<&[u8]>> = req.envelope_truncating()?;
let mut res: msgs::Envelope<_, msgs::SetPskResponse<&mut [u8]>> =
(res as &mut [u8]).envelope_truncating()?;
(&mut res).msg_type_mut()[0] = msgs::MsgType::SetPsk as u8;
self.handle_set_psk(
req.payload().set_psk_request()?,
res.payload_mut().set_psk_response()?,
)?;
Ok(res.all_bytes().len())
}
fn handle_set_psk(
&mut self,
req: msgs::SetPskRequest<&[u8]>,
mut res: msgs::SetPskResponse<&mut [u8]>,
) -> Result<(), BrokerServerError> {
// Using unwrap here since lenses can not return fixed-size arrays
// TODO: Slices should give access to fixed size arrays
let r: Result<(), Err> = self.inner.borrow_mut().set_psk(
req.iface()
.map_err(|_e| BrokerServerError::InvalidMessage)?,
req.peer_id().try_into().unwrap(),
req.psk().try_into().unwrap(),
);
let r: msgs::SetPskResult = r.map_err(|e| e.into());
let r: msgs::SetPskResponseReturnCode = r.into();
res.return_code_mut()[0] = r as u8;
Ok(())
}
}

View File

@@ -0,0 +1,56 @@
use std::io::{stdin, stdout, Read, Write};
use std::result::Result;
use rosenpass_wireguard_broker::api::msgs;
use rosenpass_wireguard_broker::api::server::BrokerServer;
use rosenpass_wireguard_broker::netlink as wg;
#[derive(thiserror::Error, Debug)]
pub enum BrokerAppError {
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error(transparent)]
WgConnectError(#[from] wg::ConnectError),
#[error(transparent)]
WgSetPskError(#[from] wg::SetPskError),
#[error("Oversized message {}; something about the request is fatally wrong", .0)]
OversizedMessage(u64),
}
fn main() -> Result<(), BrokerAppError> {
let mut broker = BrokerServer::new(wg::NetlinkWireGuardBroker::new()?);
let mut stdin = stdin().lock();
let mut stdout = stdout().lock();
loop {
// Read the message length
let mut len = [0u8; 8];
stdin.read_exact(&mut len)?;
// Parse the message length
let len = u64::from_le_bytes(len);
if (len as usize) > msgs::REQUEST_MSG_BUFFER_SIZE {
return Err(BrokerAppError::OversizedMessage(len));
}
// Read the message itself
let mut req_buf = [0u8; msgs::REQUEST_MSG_BUFFER_SIZE];
let req_buf = &mut req_buf[..(len as usize)];
stdin.read_exact(req_buf)?;
// Process the message
let mut res_buf = [0u8; msgs::RESPONSE_MSG_BUFFER_SIZE];
let res = match broker.handle_message(req_buf, &mut res_buf) {
Ok(len) => &res_buf[..len],
Err(e) => {
eprintln!("Error processing message for wireguard PSK broker: {e:?}");
continue;
}
};
// Write the response
stdout.write_all(&(res.len() as u64).to_le_bytes())?;
stdout.write_all(&res)?;
stdout.flush()?;
}
}

View File

@@ -0,0 +1,191 @@
use std::process::Stdio;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{UnixListener, UnixStream};
use tokio::process::Command;
use tokio::sync::{mpsc, oneshot};
use tokio::task;
use anyhow::{bail, ensure, Result};
use clap::{ArgGroup, Parser};
use rosenpass_util::fd::claim_fd;
use rosenpass_wireguard_broker::api::msgs;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
#[clap(group(
ArgGroup::new("socket")
.required(true)
.args(&["listen_path", "listen_fd", "stream_fd"]),
))]
struct Args {
/// Where in the file-system to create the unix socket this broker will be listening for
/// connections on
#[arg(long)]
listen_path: Option<String>,
/// When this broker is called from another process, the other process can open and bind the
/// unix socket to use themselves, passing it to this process. In Rust this can be achieved
/// using the [command-fds](https://docs.rs/command-fds/latest/command_fds/) crate.
#[arg(long)]
listen_fd: Option<i32>,
/// When this broker is called from another process, the other process can connect the unix socket
/// themselves, for instance using the `socketpair(2)` system call.
#[arg(long)]
stream_fd: Option<i32>,
/// The underlying broker, accepting commands through stdin and sending results through stdout.
#[arg(
last = true,
allow_hyphen_values = true,
default_value = "rosenpass-wireguard-broker-privileged"
)]
command: Vec<String>,
}
struct BrokerRequest {
reply_to: oneshot::Sender<BrokerResponse>,
request: Vec<u8>,
}
struct BrokerResponse {
response: Vec<u8>,
}
#[tokio::main]
async fn main() -> Result<()> {
env_logger::init();
let args = Args::parse();
let (proc_tx, proc_rx) = mpsc::channel(100);
// Start the inner broker handler
task::spawn(async move {
if let Err(e) = direct_broker_process(proc_rx, args.command).await {
log::error!("Error in broker command handler: {e}");
panic!("Can not proceed without underlying broker process");
}
});
// Listen for incoming requests
if let Some(path) = args.listen_path {
let sock = UnixListener::bind(path)?;
listen_for_clients(proc_tx, sock).await
} else if let Some(fd) = args.listen_fd {
let sock = std::os::unix::net::UnixListener::from(claim_fd(fd)?);
sock.set_nonblocking(true)?;
listen_for_clients(proc_tx, UnixListener::from_std(sock)?).await
} else if let Some(fd) = args.stream_fd {
let stream = std::os::unix::net::UnixStream::from(claim_fd(fd)?);
stream.set_nonblocking(true)?;
on_accept(proc_tx, UnixStream::from_std(stream)?).await
} else {
unreachable!();
}
}
async fn direct_broker_process(
mut queue: mpsc::Receiver<BrokerRequest>,
cmd: Vec<String>,
) -> Result<()> {
let proc = Command::new(&cmd[0])
.args(&cmd[1..])
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.spawn()?;
let mut stdin = proc.stdin.unwrap();
let mut stdout = proc.stdout.unwrap();
loop {
let BrokerRequest { reply_to, request } = queue.recv().await.unwrap();
stdin
.write_all(&(request.len() as u64).to_le_bytes())
.await?;
stdin.write_all(&request[..]).await?;
// Read the response length
let mut len = [0u8; 8];
stdout.read_exact(&mut len).await?;
// Parse the response length
let len = u64::from_le_bytes(len) as usize;
ensure!(
len <= msgs::RESPONSE_MSG_BUFFER_SIZE,
"Oversized buffer ({len}) in broker stdout."
);
// Read the message itself
let mut res_buf = request; // Avoid allocating memory if we don't have to
res_buf.resize(len as usize, 0);
stdout.read_exact(&mut res_buf[..len]).await?;
// Return to the unix socket connection worker
reply_to
.send(BrokerResponse { response: res_buf })
.or_else(|_| bail!("Unable to send respnse to unix socket worker."))?;
}
}
async fn listen_for_clients(queue: mpsc::Sender<BrokerRequest>, sock: UnixListener) -> Result<()> {
loop {
let (stream, _addr) = sock.accept().await?;
let queue = queue.clone();
task::spawn(async move {
if let Err(e) = on_accept(queue, stream).await {
log::error!("Error during connection processing: {e}");
}
});
}
// NOTE: If loop can ever terminate we need to join the spawned tasks
}
async fn on_accept(queue: mpsc::Sender<BrokerRequest>, mut stream: UnixStream) -> Result<()> {
let mut req_buf = Vec::new();
loop {
stream.readable().await?;
// Read the message length
let mut len = [0u8; 8];
stream.read_exact(&mut len).await?;
// Parse the message length
let len = u64::from_le_bytes(len) as usize;
ensure!(
len <= msgs::REQUEST_MSG_BUFFER_SIZE,
"Oversized buffer ({len}) in unix socket input."
);
// Read the message itself
req_buf.resize(len as usize, 0);
stream.read_exact(&mut req_buf[..len]).await?;
// Handle the message
let (reply_tx, reply_rx) = oneshot::channel();
queue
.send(BrokerRequest {
reply_to: reply_tx,
request: req_buf,
})
.await?;
// Wait for the reply
let BrokerResponse { response } = reply_rx.await.unwrap();
// Write reply back to unix socket
stream
.write_all(&(response.len() as u64).to_le_bytes())
.await?;
stream.write_all(&response[..]).await?;
stream.flush().await?;
// Reuse the same memory for the next message
req_buf = response;
}
}

View File

@@ -0,0 +1,15 @@
use std::result::Result;
pub trait WireGuardBroker {
type Error;
fn set_psk(
&mut self,
interface: &str,
peer_id: [u8; 32],
psk: [u8; 32],
) -> Result<(), Self::Error>;
}
pub mod api;
pub mod netlink;

View File

@@ -0,0 +1,103 @@
use wireguard_uapi::linux as wg;
use crate::api::msgs;
use crate::WireGuardBroker;
#[derive(thiserror::Error, Debug)]
pub enum ConnectError {
#[error(transparent)]
ConnectError(#[from] wg::err::ConnectError),
}
#[derive(thiserror::Error, Debug)]
pub enum NetlinkError {
#[error(transparent)]
SetDevice(#[from] wg::err::SetDeviceError),
#[error(transparent)]
GetDevice(#[from] wg::err::GetDeviceError),
}
#[derive(thiserror::Error, Debug)]
pub enum SetPskError {
#[error("The indicated wireguard interface does not exist")]
NoSuchInterface,
#[error("The indicated peer does not exist on the wireguard interface")]
NoSuchPeer,
#[error(transparent)]
NetlinkError(#[from] NetlinkError),
}
impl From<wg::err::SetDeviceError> for SetPskError {
fn from(err: wg::err::SetDeviceError) -> Self {
NetlinkError::from(err).into()
}
}
impl From<wg::err::GetDeviceError> for SetPskError {
fn from(err: wg::err::GetDeviceError) -> Self {
NetlinkError::from(err).into()
}
}
use msgs::SetPskError as SetPskMsgsError;
use SetPskError as SetPskNetlinkError;
impl From<SetPskNetlinkError> for SetPskMsgsError {
fn from(err: SetPskError) -> Self {
match err {
SetPskNetlinkError::NoSuchPeer => SetPskMsgsError::NoSuchPeer,
_ => SetPskMsgsError::InternalError,
}
}
}
pub struct NetlinkWireGuardBroker {
sock: wg::WgSocket,
}
impl NetlinkWireGuardBroker {
pub fn new() -> Result<Self, ConnectError> {
let sock = wg::WgSocket::connect()?;
Ok(Self { sock })
}
}
impl WireGuardBroker for NetlinkWireGuardBroker {
type Error = SetPskError;
fn set_psk(
&mut self,
interface: &str,
peer_id: [u8; 32],
psk: [u8; 32],
) -> Result<(), Self::Error> {
// Ensure that the peer exists by querying the device configuration
// TODO: Use InvalidInterfaceError
let state = self
.sock
.get_device(wg::DeviceInterface::from_name(interface.to_owned()))?;
if state
.peers
.iter()
.find(|p| &p.public_key == &peer_id)
.is_none()
{
return Err(SetPskError::NoSuchPeer);
}
// Peer update description
let mut set_peer = wireguard_uapi::set::Peer::from_public_key(&peer_id);
set_peer
.flags
.push(wireguard_uapi::linux::set::WgPeerF::UpdateOnly);
set_peer.preshared_key = Some(&psk);
// Device update description
let mut set_dev = wireguard_uapi::set::Device::from_ifname(interface.to_owned());
set_dev.peers.push(set_peer);
self.sock.set_device(set_dev)?;
Ok(())
}
}