mirror of
https://github.com/rosenpass/rosenpass.git
synced 2026-03-02 07:23:10 -08:00
feat(wireguard-broker): merge from dev/broker-architecture, fixes, test
* wireguard-broker: merge from dev/broker-architecture * use zerocopy instead of lenses * Require use_broker feature flag to comile broker binaries * Remove PhantomData from BrokerServer & BrokerClient * Modify mio client rx to be non-recursive, add integration test Co-authored-by: Karolin Varner <karo@cupdev.net> Co-authored-by: Prabhpreet Dua <615318+prabhpreet@users.noreply.github.com>
This commit is contained in:
56
wireguard-broker/src/bin/priviledged.rs
Normal file
56
wireguard-broker/src/bin/priviledged.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
use std::io::{stdin, stdout, Read, Write};
|
||||
use std::result::Result;
|
||||
|
||||
use rosenpass_wireguard_broker::api::msgs;
|
||||
use rosenpass_wireguard_broker::api::server::BrokerServer;
|
||||
use rosenpass_wireguard_broker::netlink as wg;
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum BrokerAppError {
|
||||
#[error(transparent)]
|
||||
IoError(#[from] std::io::Error),
|
||||
#[error(transparent)]
|
||||
WgConnectError(#[from] wg::ConnectError),
|
||||
#[error(transparent)]
|
||||
WgSetPskError(#[from] wg::SetPskError),
|
||||
#[error("Oversized message {}; something about the request is fatally wrong", .0)]
|
||||
OversizedMessage(u64),
|
||||
}
|
||||
|
||||
fn main() -> Result<(), BrokerAppError> {
|
||||
let mut broker = BrokerServer::new(wg::NetlinkWireGuardBroker::new()?);
|
||||
|
||||
let mut stdin = stdin().lock();
|
||||
let mut stdout = stdout().lock();
|
||||
loop {
|
||||
// Read the message length
|
||||
let mut len = [0u8; 8];
|
||||
stdin.read_exact(&mut len)?;
|
||||
|
||||
// Parse the message length
|
||||
let len = u64::from_le_bytes(len);
|
||||
if (len as usize) > msgs::REQUEST_MSG_BUFFER_SIZE {
|
||||
return Err(BrokerAppError::OversizedMessage(len));
|
||||
}
|
||||
|
||||
// Read the message itself
|
||||
let mut req_buf = [0u8; msgs::REQUEST_MSG_BUFFER_SIZE];
|
||||
let req_buf = &mut req_buf[..(len as usize)];
|
||||
stdin.read_exact(req_buf)?;
|
||||
|
||||
// Process the message
|
||||
let mut res_buf = [0u8; msgs::RESPONSE_MSG_BUFFER_SIZE];
|
||||
let res = match broker.handle_message(req_buf, &mut res_buf) {
|
||||
Ok(len) => &res_buf[..len],
|
||||
Err(e) => {
|
||||
eprintln!("Error processing message for wireguard PSK broker: {e:?}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Write the response
|
||||
stdout.write_all(&(res.len() as u64).to_le_bytes())?;
|
||||
stdout.write_all(&res)?;
|
||||
stdout.flush()?;
|
||||
}
|
||||
}
|
||||
191
wireguard-broker/src/bin/socket_handler.rs
Normal file
191
wireguard-broker/src/bin/socket_handler.rs
Normal file
@@ -0,0 +1,191 @@
|
||||
use std::process::Stdio;
|
||||
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{UnixListener, UnixStream};
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::task;
|
||||
|
||||
use anyhow::{bail, ensure, Result};
|
||||
use clap::{ArgGroup, Parser};
|
||||
|
||||
use rosenpass_util::fd::claim_fd;
|
||||
use rosenpass_wireguard_broker::api::msgs;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
#[clap(group(
|
||||
ArgGroup::new("socket")
|
||||
.required(true)
|
||||
.args(&["listen_path", "listen_fd", "stream_fd"]),
|
||||
))]
|
||||
struct Args {
|
||||
/// Where in the file-system to create the unix socket this broker will be listening for
|
||||
/// connections on
|
||||
#[arg(long)]
|
||||
listen_path: Option<String>,
|
||||
|
||||
/// When this broker is called from another process, the other process can open and bind the
|
||||
/// unix socket to use themselves, passing it to this process. In Rust this can be achieved
|
||||
/// using the [command-fds](https://docs.rs/command-fds/latest/command_fds/) crate.
|
||||
#[arg(long)]
|
||||
listen_fd: Option<i32>,
|
||||
|
||||
/// When this broker is called from another process, the other process can connect the unix socket
|
||||
/// themselves, for instance using the `socketpair(2)` system call.
|
||||
#[arg(long)]
|
||||
stream_fd: Option<i32>,
|
||||
|
||||
/// The underlying broker, accepting commands through stdin and sending results through stdout.
|
||||
#[arg(
|
||||
last = true,
|
||||
allow_hyphen_values = true,
|
||||
default_value = "rosenpass-wireguard-broker-privileged"
|
||||
)]
|
||||
command: Vec<String>,
|
||||
}
|
||||
|
||||
struct BrokerRequest {
|
||||
reply_to: oneshot::Sender<BrokerResponse>,
|
||||
request: Vec<u8>,
|
||||
}
|
||||
|
||||
struct BrokerResponse {
|
||||
response: Vec<u8>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
env_logger::init();
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
let (proc_tx, proc_rx) = mpsc::channel(100);
|
||||
|
||||
// Start the inner broker handler
|
||||
task::spawn(async move {
|
||||
if let Err(e) = direct_broker_process(proc_rx, args.command).await {
|
||||
log::error!("Error in broker command handler: {e}");
|
||||
panic!("Can not proceed without underlying broker process");
|
||||
}
|
||||
});
|
||||
|
||||
// Listen for incoming requests
|
||||
if let Some(path) = args.listen_path {
|
||||
let sock = UnixListener::bind(path)?;
|
||||
listen_for_clients(proc_tx, sock).await
|
||||
} else if let Some(fd) = args.listen_fd {
|
||||
let sock = std::os::unix::net::UnixListener::from(claim_fd(fd)?);
|
||||
sock.set_nonblocking(true)?;
|
||||
listen_for_clients(proc_tx, UnixListener::from_std(sock)?).await
|
||||
} else if let Some(fd) = args.stream_fd {
|
||||
let stream = std::os::unix::net::UnixStream::from(claim_fd(fd)?);
|
||||
stream.set_nonblocking(true)?;
|
||||
on_accept(proc_tx, UnixStream::from_std(stream)?).await
|
||||
} else {
|
||||
unreachable!();
|
||||
}
|
||||
}
|
||||
|
||||
async fn direct_broker_process(
|
||||
mut queue: mpsc::Receiver<BrokerRequest>,
|
||||
cmd: Vec<String>,
|
||||
) -> Result<()> {
|
||||
let proc = Command::new(&cmd[0])
|
||||
.args(&cmd[1..])
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.spawn()?;
|
||||
|
||||
let mut stdin = proc.stdin.unwrap();
|
||||
let mut stdout = proc.stdout.unwrap();
|
||||
|
||||
loop {
|
||||
let BrokerRequest { reply_to, request } = queue.recv().await.unwrap();
|
||||
|
||||
stdin
|
||||
.write_all(&(request.len() as u64).to_le_bytes())
|
||||
.await?;
|
||||
stdin.write_all(&request[..]).await?;
|
||||
|
||||
// Read the response length
|
||||
let mut len = [0u8; 8];
|
||||
stdout.read_exact(&mut len).await?;
|
||||
|
||||
// Parse the response length
|
||||
let len = u64::from_le_bytes(len) as usize;
|
||||
ensure!(
|
||||
len <= msgs::RESPONSE_MSG_BUFFER_SIZE,
|
||||
"Oversized buffer ({len}) in broker stdout."
|
||||
);
|
||||
|
||||
// Read the message itself
|
||||
let mut res_buf = request; // Avoid allocating memory if we don't have to
|
||||
res_buf.resize(len as usize, 0);
|
||||
stdout.read_exact(&mut res_buf[..len]).await?;
|
||||
|
||||
// Return to the unix socket connection worker
|
||||
reply_to
|
||||
.send(BrokerResponse { response: res_buf })
|
||||
.or_else(|_| bail!("Unable to send respnse to unix socket worker."))?;
|
||||
}
|
||||
}
|
||||
|
||||
async fn listen_for_clients(queue: mpsc::Sender<BrokerRequest>, sock: UnixListener) -> Result<()> {
|
||||
loop {
|
||||
let (stream, _addr) = sock.accept().await?;
|
||||
let queue = queue.clone();
|
||||
task::spawn(async move {
|
||||
if let Err(e) = on_accept(queue, stream).await {
|
||||
log::error!("Error during connection processing: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// NOTE: If loop can ever terminate we need to join the spawned tasks
|
||||
}
|
||||
|
||||
async fn on_accept(queue: mpsc::Sender<BrokerRequest>, mut stream: UnixStream) -> Result<()> {
|
||||
let mut req_buf = Vec::new();
|
||||
|
||||
loop {
|
||||
stream.readable().await?;
|
||||
|
||||
// Read the message length
|
||||
let mut len = [0u8; 8];
|
||||
stream.read_exact(&mut len).await?;
|
||||
|
||||
// Parse the message length
|
||||
let len = u64::from_le_bytes(len) as usize;
|
||||
ensure!(
|
||||
len <= msgs::REQUEST_MSG_BUFFER_SIZE,
|
||||
"Oversized buffer ({len}) in unix socket input."
|
||||
);
|
||||
|
||||
// Read the message itself
|
||||
req_buf.resize(len as usize, 0);
|
||||
stream.read_exact(&mut req_buf[..len]).await?;
|
||||
|
||||
// Handle the message
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
queue
|
||||
.send(BrokerRequest {
|
||||
reply_to: reply_tx,
|
||||
request: req_buf,
|
||||
})
|
||||
.await?;
|
||||
|
||||
// Wait for the reply
|
||||
let BrokerResponse { response } = reply_rx.await.unwrap();
|
||||
|
||||
// Write reply back to unix socket
|
||||
stream
|
||||
.write_all(&(response.len() as u64).to_le_bytes())
|
||||
.await?;
|
||||
stream.write_all(&response[..]).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
// Reuse the same memory for the next message
|
||||
req_buf = response;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user