mirror of
https://github.com/rosenpass/rosenpass.git
synced 2026-02-28 06:23:08 -08:00
feat: First version of broker based WireGuard PSK interface
This allows us to run with minimal priviledges in the Rosenpass process itself
This commit is contained in:
152
wireguard-broker/src/api/client.rs
Normal file
152
wireguard-broker/src/api/client.rs
Normal file
@@ -0,0 +1,152 @@
|
||||
use std::{borrow::BorrowMut, marker::PhantomData};
|
||||
|
||||
use rosenpass_lenses::LenseView;
|
||||
|
||||
use crate::{
|
||||
api::msgs::{self, EnvelopeExt, SetPskRequestExt, SetPskResponseExt},
|
||||
WireGuardBroker,
|
||||
};
|
||||
|
||||
#[derive(thiserror::Error, Debug, Clone, Eq, PartialEq)]
|
||||
pub enum BrokerClientPollResponseError<RecvError> {
|
||||
#[error(transparent)]
|
||||
IoError(RecvError),
|
||||
#[error("Invalid message.")]
|
||||
InvalidMessage,
|
||||
}
|
||||
|
||||
impl<RecvError> From<msgs::InvalidMessageTypeError> for BrokerClientPollResponseError<RecvError> {
|
||||
fn from(value: msgs::InvalidMessageTypeError) -> Self {
|
||||
let msgs::InvalidMessageTypeError = value; // Assert that this is a unit type
|
||||
BrokerClientPollResponseError::<RecvError>::InvalidMessage
|
||||
}
|
||||
}
|
||||
|
||||
fn io_pollerr<RecvError>(e: RecvError) -> BrokerClientPollResponseError<RecvError> {
|
||||
BrokerClientPollResponseError::<RecvError>::IoError(e)
|
||||
}
|
||||
|
||||
fn invalid_msg_pollerr<RecvError>() -> BrokerClientPollResponseError<RecvError> {
|
||||
BrokerClientPollResponseError::<RecvError>::InvalidMessage
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug, Clone, Eq, PartialEq)]
|
||||
pub enum BrokerClientSetPskError<SendError> {
|
||||
#[error(transparent)]
|
||||
IoError(SendError),
|
||||
#[error("Interface name out of bounds")]
|
||||
IfaceOutOfBounds,
|
||||
}
|
||||
|
||||
pub trait BrokerClientIo {
|
||||
type SendError;
|
||||
type RecvError;
|
||||
|
||||
fn send_msg(&mut self, buf: &[u8]) -> Result<(), Self::SendError>;
|
||||
fn recv_msg(&mut self) -> Result<Option<&[u8]>, Self::RecvError>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct BrokerClient<'a, Io, IoRef>
|
||||
where
|
||||
Io: BrokerClientIo,
|
||||
IoRef: 'a + BorrowMut<Io>,
|
||||
{
|
||||
io: IoRef,
|
||||
_phantom_io: PhantomData<&'a mut Io>,
|
||||
}
|
||||
|
||||
impl<'a, Io, IoRef> BrokerClient<'a, Io, IoRef>
|
||||
where
|
||||
Io: BrokerClientIo,
|
||||
IoRef: 'a + BorrowMut<Io>,
|
||||
{
|
||||
pub fn new(io: IoRef) -> Self {
|
||||
Self {
|
||||
io,
|
||||
_phantom_io: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn io(&self) -> &IoRef {
|
||||
&self.io
|
||||
}
|
||||
|
||||
pub fn io_mut(&mut self) -> &mut IoRef {
|
||||
&mut self.io
|
||||
}
|
||||
|
||||
pub fn poll_response(
|
||||
&mut self,
|
||||
) -> Result<Option<msgs::SetPskResult>, BrokerClientPollResponseError<Io::RecvError>> {
|
||||
let res: &[u8] = match self.io.borrow_mut().recv_msg().map_err(io_pollerr)? {
|
||||
Some(r) => r,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
let typ = res.get(0).ok_or(invalid_msg_pollerr())?;
|
||||
let typ = msgs::MsgType::try_from(*typ)?;
|
||||
let msgs::MsgType::SetPsk = typ; // Assert type
|
||||
|
||||
let res: msgs::Envelope<_, msgs::SetPskResponse<&[u8]>> = res
|
||||
.envelope_truncating()
|
||||
.map_err(|_| invalid_msg_pollerr())?;
|
||||
let res: msgs::SetPskResponse<&[u8]> = res
|
||||
.payload()
|
||||
.set_psk_response()
|
||||
.map_err(|_| invalid_msg_pollerr())?;
|
||||
let res: msgs::SetPskResponseReturnCode = res.return_code()[0]
|
||||
.try_into()
|
||||
.map_err(|_| invalid_msg_pollerr())?;
|
||||
let res: msgs::SetPskResult = res.into();
|
||||
|
||||
Ok(Some(res))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, Io, IoRef> WireGuardBroker for BrokerClient<'a, Io, IoRef>
|
||||
where
|
||||
Io: BrokerClientIo,
|
||||
IoRef: 'a + BorrowMut<Io>,
|
||||
{
|
||||
type Error = BrokerClientSetPskError<Io::SendError>;
|
||||
|
||||
fn set_psk(
|
||||
&mut self,
|
||||
iface: &str,
|
||||
peer_id: [u8; 32],
|
||||
psk: [u8; 32],
|
||||
) -> Result<(), Self::Error> {
|
||||
use BrokerClientSetPskError::*;
|
||||
const BUF_SIZE: usize = <msgs::Envelope<(), msgs::SetPskRequest<()>> as LenseView>::LEN;
|
||||
|
||||
// Allocate message
|
||||
let mut req = [0u8; BUF_SIZE];
|
||||
|
||||
// Construct message view
|
||||
let mut req: msgs::Envelope<_, msgs::SetPskRequest<&mut [u8]>> =
|
||||
(&mut req as &mut [u8]).envelope_truncating().unwrap();
|
||||
|
||||
// Populate envelope
|
||||
req.msg_type_mut()
|
||||
.copy_from_slice(&[msgs::MsgType::SetPsk as u8]);
|
||||
{
|
||||
// Derived payload
|
||||
let mut req: msgs::SetPskRequest<&mut [u8]> =
|
||||
req.payload_mut().set_psk_request().unwrap();
|
||||
|
||||
// Populate payload
|
||||
req.peer_id_mut().copy_from_slice(&peer_id);
|
||||
req.psk_mut().copy_from_slice(&psk);
|
||||
req.set_iface(iface).ok_or(IfaceOutOfBounds)?;
|
||||
}
|
||||
|
||||
// Send message
|
||||
self.io
|
||||
.borrow_mut()
|
||||
.send_msg(req.all_bytes())
|
||||
.map_err(IoError)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
204
wireguard-broker/src/api/mio_client.rs
Normal file
204
wireguard-broker/src/api/mio_client.rs
Normal file
@@ -0,0 +1,204 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::io::{ErrorKind, Read, Write};
|
||||
|
||||
use anyhow::{bail, ensure};
|
||||
|
||||
use crate::WireGuardBroker;
|
||||
|
||||
use super::client::{
|
||||
BrokerClient, BrokerClientIo, BrokerClientPollResponseError, BrokerClientSetPskError,
|
||||
};
|
||||
use super::msgs;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MioBrokerClient {
|
||||
inner: BrokerClient<'static, MioBrokerClientIo, MioBrokerClientIo>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MioBrokerClientIo {
|
||||
socket: mio::net::UnixStream,
|
||||
send_buf: VecDeque<u8>,
|
||||
receiving_size: bool,
|
||||
recv_buf: Vec<u8>,
|
||||
recv_off: usize,
|
||||
}
|
||||
|
||||
impl MioBrokerClient {
|
||||
pub fn new(socket: mio::net::UnixStream) -> Self {
|
||||
let io = MioBrokerClientIo {
|
||||
socket,
|
||||
send_buf: VecDeque::new(),
|
||||
receiving_size: false,
|
||||
recv_buf: Vec::new(),
|
||||
recv_off: 0,
|
||||
};
|
||||
let inner = BrokerClient::new(io);
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
pub fn poll(&mut self) -> anyhow::Result<Option<msgs::SetPskResult>> {
|
||||
self.inner.io_mut().flush()?;
|
||||
|
||||
// This sucks
|
||||
match self.inner.poll_response() {
|
||||
Ok(res) => {
|
||||
return Ok(res);
|
||||
}
|
||||
Err(BrokerClientPollResponseError::IoError(e)) => {
|
||||
return Err(e);
|
||||
}
|
||||
Err(BrokerClientPollResponseError::InvalidMessage) => {
|
||||
bail!("Invalid message");
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
impl WireGuardBroker for MioBrokerClient {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn set_psk(&mut self, iface: &str, peer_id: [u8; 32], psk: [u8; 32]) -> anyhow::Result<()> {
|
||||
use BrokerClientSetPskError::*;
|
||||
let e = self.inner.set_psk(iface, peer_id, psk);
|
||||
match e {
|
||||
Ok(()) => Ok(()),
|
||||
Err(IoError(e)) => Err(e),
|
||||
Err(IfaceOutOfBounds) => bail!("Interface name size is out of bounds."),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BrokerClientIo for MioBrokerClientIo {
|
||||
type SendError = anyhow::Error;
|
||||
type RecvError = anyhow::Error;
|
||||
|
||||
fn send_msg(&mut self, buf: &[u8]) -> Result<(), Self::SendError> {
|
||||
self.flush()?;
|
||||
self.send_or_buffer(&(buf.len() as u64).to_le_bytes())?;
|
||||
self.send_or_buffer(&buf)?;
|
||||
self.flush()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn recv_msg(&mut self) -> Result<Option<&[u8]>, Self::RecvError> {
|
||||
// Stale message in receive buffer. Reset!
|
||||
if self.recv_off == self.recv_buf.len() {
|
||||
self.receiving_size = true;
|
||||
self.recv_off = 0;
|
||||
self.recv_buf.resize(8, 0);
|
||||
}
|
||||
|
||||
// Try filling the receive buffer
|
||||
self.recv_off += raw_recv(&self.socket, &mut self.recv_buf[self.recv_off..])?;
|
||||
if self.recv_off < self.recv_buf.len() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Received size, now start receiving
|
||||
if self.receiving_size {
|
||||
// Received the size
|
||||
// Parse the received length
|
||||
let len: &[u8; 8] = self.recv_buf[..].try_into().unwrap();
|
||||
let len: usize = u64::from_le_bytes(*len) as usize;
|
||||
|
||||
ensure!(
|
||||
len <= msgs::RESPONSE_MSG_BUFFER_SIZE,
|
||||
"Oversized buffer ({len}) in psk buffer response."
|
||||
);
|
||||
|
||||
// Prepare the message buffer for receiving an actual message of the given size
|
||||
self.receiving_size = false;
|
||||
self.recv_off = 0;
|
||||
self.recv_buf.resize(len, 0);
|
||||
|
||||
// Try to receive the message
|
||||
return self.recv_msg();
|
||||
}
|
||||
|
||||
// Received an actual message
|
||||
return Ok(Some(&self.recv_buf[..]));
|
||||
}
|
||||
}
|
||||
|
||||
impl MioBrokerClientIo {
|
||||
fn flush(&mut self) -> anyhow::Result<()> {
|
||||
let (fst, snd) = self.send_buf.as_slices();
|
||||
|
||||
let (written, res) = match raw_send(&self.socket, fst) {
|
||||
Ok(w1) if w1 >= fst.len() => match raw_send(&self.socket, snd) {
|
||||
Ok(w2) => (w1 + w2, Ok(())),
|
||||
Err(e) => (w1, Err(e)),
|
||||
},
|
||||
Ok(w1) => (w1, Ok(())),
|
||||
Err(e) => (0, Err(e)),
|
||||
};
|
||||
|
||||
self.send_buf.drain(..written);
|
||||
|
||||
(&self.socket).try_io(|| (&self.socket).flush())?;
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
fn send_or_buffer(&mut self, buf: &[u8]) -> anyhow::Result<()> {
|
||||
let mut off = 0;
|
||||
|
||||
if self.send_buf.is_empty() {
|
||||
off += raw_send(&self.socket, buf)?;
|
||||
}
|
||||
|
||||
self.send_buf.extend((&buf[off..]).iter());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn raw_send(mut socket: &mio::net::UnixStream, data: &[u8]) -> anyhow::Result<usize> {
|
||||
let mut off = 0;
|
||||
|
||||
socket.try_io(|| {
|
||||
loop {
|
||||
if off == data.len() {
|
||||
return Ok(());
|
||||
}
|
||||
match socket.write(&data[off..]) {
|
||||
Ok(n) => {
|
||||
off += n;
|
||||
}
|
||||
Err(e) if e.kind() == ErrorKind::Interrupted => {
|
||||
// pass – retry
|
||||
}
|
||||
Err(e) if off > 0 || e.kind() == ErrorKind::WouldBlock => return Ok(()),
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
return Ok(off);
|
||||
}
|
||||
|
||||
fn raw_recv(mut socket: &mio::net::UnixStream, out: &mut [u8]) -> anyhow::Result<usize> {
|
||||
let mut off = 0;
|
||||
|
||||
socket.try_io(|| {
|
||||
loop {
|
||||
if off == out.len() {
|
||||
return Ok(());
|
||||
}
|
||||
match socket.read(&mut out[off..]) {
|
||||
Ok(n) => {
|
||||
off += n;
|
||||
}
|
||||
Err(e) if e.kind() == ErrorKind::Interrupted => {
|
||||
// pass – retry
|
||||
}
|
||||
Err(e) if off > 0 || e.kind() == ErrorKind::WouldBlock => return Ok(()),
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
return Ok(off);
|
||||
}
|
||||
4
wireguard-broker/src/api/mod.rs
Normal file
4
wireguard-broker/src/api/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
pub mod client;
|
||||
pub mod mio_client;
|
||||
pub mod msgs;
|
||||
pub mod server;
|
||||
140
wireguard-broker/src/api/msgs.rs
Normal file
140
wireguard-broker/src/api/msgs.rs
Normal file
@@ -0,0 +1,140 @@
|
||||
use std::result::Result;
|
||||
use std::str::{from_utf8, Utf8Error};
|
||||
|
||||
use rosenpass_lenses::{lense, LenseView};
|
||||
|
||||
pub const REQUEST_MSG_BUFFER_SIZE: usize = <Envelope<(), SetPskRequest<()>> as LenseView>::LEN;
|
||||
pub const RESPONSE_MSG_BUFFER_SIZE: usize = <Envelope<(), SetPskResponse<()>> as LenseView>::LEN;
|
||||
|
||||
lense! { Envelope<M> :=
|
||||
/// [MsgType] of this message
|
||||
msg_type: 1,
|
||||
/// Reserved for future use
|
||||
reserved: 3,
|
||||
/// The actual Paylod
|
||||
payload: M::LEN
|
||||
}
|
||||
|
||||
lense! { SetPskRequest :=
|
||||
peer_id: 32,
|
||||
psk: 32,
|
||||
iface_size: 1, // TODO: We should have variable length strings in lenses
|
||||
iface_buf: 255
|
||||
}
|
||||
|
||||
impl SetPskRequest<&[u8]> {
|
||||
pub fn iface_bin(&self) -> &[u8] {
|
||||
let len = self.iface_size()[0] as usize;
|
||||
&self.iface_buf()[..len]
|
||||
}
|
||||
|
||||
pub fn iface(&self) -> Result<&str, Utf8Error> {
|
||||
from_utf8(self.iface_bin())
|
||||
}
|
||||
}
|
||||
|
||||
impl SetPskRequest<&mut [u8]> {
|
||||
pub fn set_iface_bin(&mut self, iface: &[u8]) -> Option<()> {
|
||||
(iface.len() < 256).then_some(())?; // Assert iface.len() < 256
|
||||
|
||||
self.iface_size_mut()[0] = iface.len() as u8;
|
||||
|
||||
self.iface_buf_mut().fill(0);
|
||||
(&mut self.iface_buf_mut()[..iface.len()]).copy_from_slice(iface);
|
||||
|
||||
Some(())
|
||||
}
|
||||
|
||||
pub fn set_iface(&mut self, iface: &str) -> Option<()> {
|
||||
self.set_iface_bin(iface.as_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
lense! { SetPskResponse :=
|
||||
return_code: 1
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug, Clone, Eq, PartialEq)]
|
||||
pub enum SetPskError {
|
||||
#[error("The wireguard pre-shared-key assignment broker experienced an internal error.")]
|
||||
InternalError,
|
||||
#[error("The indicated wireguard interface does not exist")]
|
||||
NoSuchInterface,
|
||||
#[error("The indicated peer does not exist on the wireguard interface")]
|
||||
NoSuchPeer,
|
||||
}
|
||||
|
||||
pub type SetPskResult = Result<(), SetPskError>;
|
||||
|
||||
#[repr(u8)]
|
||||
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Copy)]
|
||||
pub enum SetPskResponseReturnCode {
|
||||
Success = 0x00,
|
||||
InternalError = 0x01,
|
||||
NoSuchInterface = 0x02,
|
||||
NoSuchPeer = 0x03,
|
||||
}
|
||||
|
||||
#[derive(Eq, PartialEq, Debug, Clone)]
|
||||
pub struct InvalidSetPskResponseError;
|
||||
|
||||
impl TryFrom<u8> for SetPskResponseReturnCode {
|
||||
type Error = InvalidSetPskResponseError;
|
||||
|
||||
fn try_from(value: u8) -> Result<Self, Self::Error> {
|
||||
use SetPskResponseReturnCode::*;
|
||||
match value {
|
||||
0x00 => Ok(Success),
|
||||
0x01 => Ok(InternalError),
|
||||
0x02 => Ok(NoSuchInterface),
|
||||
0x03 => Ok(NoSuchPeer),
|
||||
_ => Err(InvalidSetPskResponseError),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SetPskResponseReturnCode> for SetPskResult {
|
||||
fn from(value: SetPskResponseReturnCode) -> Self {
|
||||
use SetPskError as E;
|
||||
use SetPskResponseReturnCode as C;
|
||||
match value {
|
||||
C::Success => Ok(()),
|
||||
C::InternalError => Err(E::InternalError),
|
||||
C::NoSuchInterface => Err(E::NoSuchInterface),
|
||||
C::NoSuchPeer => Err(E::NoSuchPeer),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SetPskResult> for SetPskResponseReturnCode {
|
||||
fn from(value: SetPskResult) -> Self {
|
||||
use SetPskError as E;
|
||||
use SetPskResponseReturnCode as C;
|
||||
match value {
|
||||
Ok(()) => C::Success,
|
||||
Err(E::InternalError) => C::InternalError,
|
||||
Err(E::NoSuchInterface) => C::NoSuchInterface,
|
||||
Err(E::NoSuchPeer) => C::NoSuchPeer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(u8)]
|
||||
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Copy)]
|
||||
pub enum MsgType {
|
||||
SetPsk = 0x01,
|
||||
}
|
||||
|
||||
#[derive(Eq, PartialEq, Debug, Clone)]
|
||||
pub struct InvalidMessageTypeError;
|
||||
|
||||
impl TryFrom<u8> for MsgType {
|
||||
type Error = InvalidMessageTypeError;
|
||||
|
||||
fn try_from(value: u8) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
0x01 => Ok(MsgType::SetPsk),
|
||||
_ => Err(InvalidMessageTypeError),
|
||||
}
|
||||
}
|
||||
}
|
||||
99
wireguard-broker/src/api/server.rs
Normal file
99
wireguard-broker/src/api/server.rs
Normal file
@@ -0,0 +1,99 @@
|
||||
use std::borrow::BorrowMut;
|
||||
use std::marker::PhantomData;
|
||||
use std::result::Result;
|
||||
|
||||
use rosenpass_lenses::LenseError;
|
||||
|
||||
use crate::api::msgs::{self, EnvelopeExt, SetPskRequestExt, SetPskResponseExt};
|
||||
use crate::WireGuardBroker;
|
||||
|
||||
#[derive(thiserror::Error, Debug, Clone, Eq, PartialEq)]
|
||||
pub enum BrokerServerError {
|
||||
#[error("No such request type: {}", .0)]
|
||||
NoSuchRequestType(u8),
|
||||
#[error("Invalid message received.")]
|
||||
InvalidMessage,
|
||||
}
|
||||
|
||||
impl From<LenseError> for BrokerServerError {
|
||||
fn from(value: LenseError) -> Self {
|
||||
use BrokerServerError as Be;
|
||||
use LenseError as Le;
|
||||
match value {
|
||||
Le::BufferSizeMismatch => Be::InvalidMessage,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<msgs::InvalidMessageTypeError> for BrokerServerError {
|
||||
fn from(value: msgs::InvalidMessageTypeError) -> Self {
|
||||
let msgs::InvalidMessageTypeError = value; // Assert that this is a unit type
|
||||
BrokerServerError::InvalidMessage
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BrokerServer<'a, Err, Inner, Ref>
|
||||
where
|
||||
msgs::SetPskError: From<Err>,
|
||||
Inner: WireGuardBroker<Error = Err>,
|
||||
Ref: BorrowMut<Inner> + 'a,
|
||||
{
|
||||
inner: Ref,
|
||||
_phantom: PhantomData<&'a mut Inner>,
|
||||
}
|
||||
|
||||
impl<'a, Err, Inner, Ref> BrokerServer<'a, Err, Inner, Ref>
|
||||
where
|
||||
msgs::SetPskError: From<Err>,
|
||||
Inner: WireGuardBroker<Error = Err>,
|
||||
Ref: 'a + BorrowMut<Inner>,
|
||||
{
|
||||
pub fn new(inner: Ref) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn handle_message(
|
||||
&mut self,
|
||||
req: &[u8],
|
||||
res: &mut [u8; msgs::RESPONSE_MSG_BUFFER_SIZE],
|
||||
) -> Result<usize, BrokerServerError> {
|
||||
use BrokerServerError::*;
|
||||
|
||||
let typ = req.get(0).ok_or(InvalidMessage)?;
|
||||
let typ = msgs::MsgType::try_from(*typ)?;
|
||||
let msgs::MsgType::SetPsk = typ; // Assert type
|
||||
|
||||
let req: msgs::Envelope<_, msgs::SetPskRequest<&[u8]>> = req.envelope_truncating()?;
|
||||
let mut res: msgs::Envelope<_, msgs::SetPskResponse<&mut [u8]>> =
|
||||
(res as &mut [u8]).envelope_truncating()?;
|
||||
(&mut res).msg_type_mut()[0] = msgs::MsgType::SetPsk as u8;
|
||||
self.handle_set_psk(
|
||||
req.payload().set_psk_request()?,
|
||||
res.payload_mut().set_psk_response()?,
|
||||
)?;
|
||||
Ok(res.all_bytes().len())
|
||||
}
|
||||
|
||||
fn handle_set_psk(
|
||||
&mut self,
|
||||
req: msgs::SetPskRequest<&[u8]>,
|
||||
mut res: msgs::SetPskResponse<&mut [u8]>,
|
||||
) -> Result<(), BrokerServerError> {
|
||||
// Using unwrap here since lenses can not return fixed-size arrays
|
||||
// TODO: Slices should give access to fixed size arrays
|
||||
let r: Result<(), Err> = self.inner.borrow_mut().set_psk(
|
||||
req.iface()
|
||||
.map_err(|_e| BrokerServerError::InvalidMessage)?,
|
||||
req.peer_id().try_into().unwrap(),
|
||||
req.psk().try_into().unwrap(),
|
||||
);
|
||||
let r: msgs::SetPskResult = r.map_err(|e| e.into());
|
||||
let r: msgs::SetPskResponseReturnCode = r.into();
|
||||
res.return_code_mut()[0] = r as u8;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
56
wireguard-broker/src/bin/priviledged.rs
Normal file
56
wireguard-broker/src/bin/priviledged.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
use std::io::{stdin, stdout, Read, Write};
|
||||
use std::result::Result;
|
||||
|
||||
use rosenpass_wireguard_broker::api::msgs;
|
||||
use rosenpass_wireguard_broker::api::server::BrokerServer;
|
||||
use rosenpass_wireguard_broker::netlink as wg;
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum BrokerAppError {
|
||||
#[error(transparent)]
|
||||
IoError(#[from] std::io::Error),
|
||||
#[error(transparent)]
|
||||
WgConnectError(#[from] wg::ConnectError),
|
||||
#[error(transparent)]
|
||||
WgSetPskError(#[from] wg::SetPskError),
|
||||
#[error("Oversized message {}; something about the request is fatally wrong", .0)]
|
||||
OversizedMessage(u64),
|
||||
}
|
||||
|
||||
fn main() -> Result<(), BrokerAppError> {
|
||||
let mut broker = BrokerServer::new(wg::NetlinkWireGuardBroker::new()?);
|
||||
|
||||
let mut stdin = stdin().lock();
|
||||
let mut stdout = stdout().lock();
|
||||
loop {
|
||||
// Read the message length
|
||||
let mut len = [0u8; 8];
|
||||
stdin.read_exact(&mut len)?;
|
||||
|
||||
// Parse the message length
|
||||
let len = u64::from_le_bytes(len);
|
||||
if (len as usize) > msgs::REQUEST_MSG_BUFFER_SIZE {
|
||||
return Err(BrokerAppError::OversizedMessage(len));
|
||||
}
|
||||
|
||||
// Read the message itself
|
||||
let mut req_buf = [0u8; msgs::REQUEST_MSG_BUFFER_SIZE];
|
||||
let req_buf = &mut req_buf[..(len as usize)];
|
||||
stdin.read_exact(req_buf)?;
|
||||
|
||||
// Process the message
|
||||
let mut res_buf = [0u8; msgs::RESPONSE_MSG_BUFFER_SIZE];
|
||||
let res = match broker.handle_message(req_buf, &mut res_buf) {
|
||||
Ok(len) => &res_buf[..len],
|
||||
Err(e) => {
|
||||
eprintln!("Error processing message for wireguard PSK broker: {e:?}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Write the response
|
||||
stdout.write_all(&(res.len() as u64).to_le_bytes())?;
|
||||
stdout.write_all(&res)?;
|
||||
stdout.flush()?;
|
||||
}
|
||||
}
|
||||
191
wireguard-broker/src/bin/socket_handler.rs
Normal file
191
wireguard-broker/src/bin/socket_handler.rs
Normal file
@@ -0,0 +1,191 @@
|
||||
use std::process::Stdio;
|
||||
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{UnixListener, UnixStream};
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::task;
|
||||
|
||||
use anyhow::{bail, ensure, Result};
|
||||
use clap::{ArgGroup, Parser};
|
||||
|
||||
use rosenpass_util::fd::claim_fd;
|
||||
use rosenpass_wireguard_broker::api::msgs;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
#[clap(group(
|
||||
ArgGroup::new("socket")
|
||||
.required(true)
|
||||
.args(&["listen_path", "listen_fd", "stream_fd"]),
|
||||
))]
|
||||
struct Args {
|
||||
/// Where in the file-system to create the unix socket this broker will be listening for
|
||||
/// connections on
|
||||
#[arg(long)]
|
||||
listen_path: Option<String>,
|
||||
|
||||
/// When this broker is called from another process, the other process can open and bind the
|
||||
/// unix socket to use themselves, passing it to this process. In Rust this can be achieved
|
||||
/// using the [command-fds](https://docs.rs/command-fds/latest/command_fds/) crate.
|
||||
#[arg(long)]
|
||||
listen_fd: Option<i32>,
|
||||
|
||||
/// When this broker is called from another process, the other process can connect the unix socket
|
||||
/// themselves, for instance using the `socketpair(2)` system call.
|
||||
#[arg(long)]
|
||||
stream_fd: Option<i32>,
|
||||
|
||||
/// The underlying broker, accepting commands through stdin and sending results through stdout.
|
||||
#[arg(
|
||||
last = true,
|
||||
allow_hyphen_values = true,
|
||||
default_value = "rosenpass-wireguard-broker-privileged"
|
||||
)]
|
||||
command: Vec<String>,
|
||||
}
|
||||
|
||||
struct BrokerRequest {
|
||||
reply_to: oneshot::Sender<BrokerResponse>,
|
||||
request: Vec<u8>,
|
||||
}
|
||||
|
||||
struct BrokerResponse {
|
||||
response: Vec<u8>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
env_logger::init();
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
let (proc_tx, proc_rx) = mpsc::channel(100);
|
||||
|
||||
// Start the inner broker handler
|
||||
task::spawn(async move {
|
||||
if let Err(e) = direct_broker_process(proc_rx, args.command).await {
|
||||
log::error!("Error in broker command handler: {e}");
|
||||
panic!("Can not proceed without underlying broker process");
|
||||
}
|
||||
});
|
||||
|
||||
// Listen for incoming requests
|
||||
if let Some(path) = args.listen_path {
|
||||
let sock = UnixListener::bind(path)?;
|
||||
listen_for_clients(proc_tx, sock).await
|
||||
} else if let Some(fd) = args.listen_fd {
|
||||
let sock = std::os::unix::net::UnixListener::from(claim_fd(fd)?);
|
||||
sock.set_nonblocking(true)?;
|
||||
listen_for_clients(proc_tx, UnixListener::from_std(sock)?).await
|
||||
} else if let Some(fd) = args.stream_fd {
|
||||
let stream = std::os::unix::net::UnixStream::from(claim_fd(fd)?);
|
||||
stream.set_nonblocking(true)?;
|
||||
on_accept(proc_tx, UnixStream::from_std(stream)?).await
|
||||
} else {
|
||||
unreachable!();
|
||||
}
|
||||
}
|
||||
|
||||
async fn direct_broker_process(
|
||||
mut queue: mpsc::Receiver<BrokerRequest>,
|
||||
cmd: Vec<String>,
|
||||
) -> Result<()> {
|
||||
let proc = Command::new(&cmd[0])
|
||||
.args(&cmd[1..])
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.spawn()?;
|
||||
|
||||
let mut stdin = proc.stdin.unwrap();
|
||||
let mut stdout = proc.stdout.unwrap();
|
||||
|
||||
loop {
|
||||
let BrokerRequest { reply_to, request } = queue.recv().await.unwrap();
|
||||
|
||||
stdin
|
||||
.write_all(&(request.len() as u64).to_le_bytes())
|
||||
.await?;
|
||||
stdin.write_all(&request[..]).await?;
|
||||
|
||||
// Read the response length
|
||||
let mut len = [0u8; 8];
|
||||
stdout.read_exact(&mut len).await?;
|
||||
|
||||
// Parse the response length
|
||||
let len = u64::from_le_bytes(len) as usize;
|
||||
ensure!(
|
||||
len <= msgs::RESPONSE_MSG_BUFFER_SIZE,
|
||||
"Oversized buffer ({len}) in broker stdout."
|
||||
);
|
||||
|
||||
// Read the message itself
|
||||
let mut res_buf = request; // Avoid allocating memory if we don't have to
|
||||
res_buf.resize(len as usize, 0);
|
||||
stdout.read_exact(&mut res_buf[..len]).await?;
|
||||
|
||||
// Return to the unix socket connection worker
|
||||
reply_to
|
||||
.send(BrokerResponse { response: res_buf })
|
||||
.or_else(|_| bail!("Unable to send respnse to unix socket worker."))?;
|
||||
}
|
||||
}
|
||||
|
||||
async fn listen_for_clients(queue: mpsc::Sender<BrokerRequest>, sock: UnixListener) -> Result<()> {
|
||||
loop {
|
||||
let (stream, _addr) = sock.accept().await?;
|
||||
let queue = queue.clone();
|
||||
task::spawn(async move {
|
||||
if let Err(e) = on_accept(queue, stream).await {
|
||||
log::error!("Error during connection processing: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// NOTE: If loop can ever terminate we need to join the spawned tasks
|
||||
}
|
||||
|
||||
async fn on_accept(queue: mpsc::Sender<BrokerRequest>, mut stream: UnixStream) -> Result<()> {
|
||||
let mut req_buf = Vec::new();
|
||||
|
||||
loop {
|
||||
stream.readable().await?;
|
||||
|
||||
// Read the message length
|
||||
let mut len = [0u8; 8];
|
||||
stream.read_exact(&mut len).await?;
|
||||
|
||||
// Parse the message length
|
||||
let len = u64::from_le_bytes(len) as usize;
|
||||
ensure!(
|
||||
len <= msgs::REQUEST_MSG_BUFFER_SIZE,
|
||||
"Oversized buffer ({len}) in unix socket input."
|
||||
);
|
||||
|
||||
// Read the message itself
|
||||
req_buf.resize(len as usize, 0);
|
||||
stream.read_exact(&mut req_buf[..len]).await?;
|
||||
|
||||
// Handle the message
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
queue
|
||||
.send(BrokerRequest {
|
||||
reply_to: reply_tx,
|
||||
request: req_buf,
|
||||
})
|
||||
.await?;
|
||||
|
||||
// Wait for the reply
|
||||
let BrokerResponse { response } = reply_rx.await.unwrap();
|
||||
|
||||
// Write reply back to unix socket
|
||||
stream
|
||||
.write_all(&(response.len() as u64).to_le_bytes())
|
||||
.await?;
|
||||
stream.write_all(&response[..]).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
// Reuse the same memory for the next message
|
||||
req_buf = response;
|
||||
}
|
||||
}
|
||||
15
wireguard-broker/src/lib.rs
Normal file
15
wireguard-broker/src/lib.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
use std::result::Result;
|
||||
|
||||
pub trait WireGuardBroker {
|
||||
type Error;
|
||||
|
||||
fn set_psk(
|
||||
&mut self,
|
||||
interface: &str,
|
||||
peer_id: [u8; 32],
|
||||
psk: [u8; 32],
|
||||
) -> Result<(), Self::Error>;
|
||||
}
|
||||
|
||||
pub mod api;
|
||||
pub mod netlink;
|
||||
103
wireguard-broker/src/netlink.rs
Normal file
103
wireguard-broker/src/netlink.rs
Normal file
@@ -0,0 +1,103 @@
|
||||
use wireguard_uapi::linux as wg;
|
||||
|
||||
use crate::api::msgs;
|
||||
use crate::WireGuardBroker;
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ConnectError {
|
||||
#[error(transparent)]
|
||||
ConnectError(#[from] wg::err::ConnectError),
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum NetlinkError {
|
||||
#[error(transparent)]
|
||||
SetDevice(#[from] wg::err::SetDeviceError),
|
||||
#[error(transparent)]
|
||||
GetDevice(#[from] wg::err::GetDeviceError),
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum SetPskError {
|
||||
#[error("The indicated wireguard interface does not exist")]
|
||||
NoSuchInterface,
|
||||
#[error("The indicated peer does not exist on the wireguard interface")]
|
||||
NoSuchPeer,
|
||||
#[error(transparent)]
|
||||
NetlinkError(#[from] NetlinkError),
|
||||
}
|
||||
|
||||
impl From<wg::err::SetDeviceError> for SetPskError {
|
||||
fn from(err: wg::err::SetDeviceError) -> Self {
|
||||
NetlinkError::from(err).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<wg::err::GetDeviceError> for SetPskError {
|
||||
fn from(err: wg::err::GetDeviceError) -> Self {
|
||||
NetlinkError::from(err).into()
|
||||
}
|
||||
}
|
||||
|
||||
use msgs::SetPskError as SetPskMsgsError;
|
||||
use SetPskError as SetPskNetlinkError;
|
||||
impl From<SetPskNetlinkError> for SetPskMsgsError {
|
||||
fn from(err: SetPskError) -> Self {
|
||||
match err {
|
||||
SetPskNetlinkError::NoSuchPeer => SetPskMsgsError::NoSuchPeer,
|
||||
_ => SetPskMsgsError::InternalError,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NetlinkWireGuardBroker {
|
||||
sock: wg::WgSocket,
|
||||
}
|
||||
|
||||
impl NetlinkWireGuardBroker {
|
||||
pub fn new() -> Result<Self, ConnectError> {
|
||||
let sock = wg::WgSocket::connect()?;
|
||||
Ok(Self { sock })
|
||||
}
|
||||
}
|
||||
|
||||
impl WireGuardBroker for NetlinkWireGuardBroker {
|
||||
type Error = SetPskError;
|
||||
|
||||
fn set_psk(
|
||||
&mut self,
|
||||
interface: &str,
|
||||
peer_id: [u8; 32],
|
||||
psk: [u8; 32],
|
||||
) -> Result<(), Self::Error> {
|
||||
// Ensure that the peer exists by querying the device configuration
|
||||
// TODO: Use InvalidInterfaceError
|
||||
let state = self
|
||||
.sock
|
||||
.get_device(wg::DeviceInterface::from_name(interface.to_owned()))?;
|
||||
|
||||
if state
|
||||
.peers
|
||||
.iter()
|
||||
.find(|p| &p.public_key == &peer_id)
|
||||
.is_none()
|
||||
{
|
||||
return Err(SetPskError::NoSuchPeer);
|
||||
}
|
||||
|
||||
// Peer update description
|
||||
let mut set_peer = wireguard_uapi::set::Peer::from_public_key(&peer_id);
|
||||
set_peer
|
||||
.flags
|
||||
.push(wireguard_uapi::linux::set::WgPeerF::UpdateOnly);
|
||||
set_peer.preshared_key = Some(&psk);
|
||||
|
||||
// Device update description
|
||||
let mut set_dev = wireguard_uapi::set::Device::from_ifname(interface.to_owned());
|
||||
set_dev.peers.push(set_peer);
|
||||
|
||||
self.sock.set_device(set_dev)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user