Compare commits

...

53 Commits

Author SHA1 Message Date
Karolin Varner
3e111fa7ad stash 2025-10-25 13:02:01 +02:00
Karolin Varner
5ec845f5d0 stash 2025-10-25 12:52:35 +02:00
Karolin Varner
4736c40d84 stash 2025-10-19 15:31:25 +02:00
Karolin Varner
a8ef9dc3e5 stash 2025-10-19 15:24:55 +02:00
Karolin Varner
4a0e34b1fc stash 2025-10-19 15:21:19 +02:00
Karolin Varner
b6229b8d33 stash 2025-10-19 15:09:14 +02:00
Karolin Varner
ecc17dea44 stash 2025-10-19 14:44:37 +02:00
Karolin Varner
b499a9ba5b stash 2025-10-19 13:03:09 +02:00
Karolin Varner
91382e189e stash 2025-10-19 13:02:21 +02:00
Karolin Varner
c9db8cfec7 stash 2025-10-19 13:01:37 +02:00
Karolin Varner
52b903c8c0 stash 2025-10-19 12:43:25 +02:00
Karolin Varner
7f464de421 stash 2025-10-19 12:41:49 +02:00
Karolin Varner
d856116a44 stash 2025-10-19 12:39:42 +02:00
Karolin Varner
ea708bca90 stash 2025-10-19 12:39:33 +02:00
Karolin Varner
f1746bd067 stash 2025-10-19 12:04:19 +02:00
Karolin Varner
afea7d0a2e stash 2025-10-19 11:43:45 +02:00
Karolin Varner
62c974f636 stash 2025-10-19 11:03:08 +02:00
Karolin Varner
62e337c6a1 stash 2025-10-18 14:37:30 +02:00
Karolin Varner
aebfdfa966 stash 2025-10-18 13:47:50 +02:00
Karolin Varner
f0a43932c1 stash 2025-10-18 13:45:03 +02:00
Karolin Varner
e98ee3fb6c stash 2025-10-18 13:35:21 +02:00
Karolin Varner
3b8c082365 stash 2025-10-17 15:04:35 +02:00
Karolin Varner
adda409edc stash 2025-10-12 19:14:35 +02:00
Karolin Varner
e0fa817988 stash 2025-10-12 17:15:15 +02:00
Karolin Varner
777e94eb64 stash 2025-10-12 15:52:55 +02:00
Karolin Varner
61e30006ec stash 2025-10-11 13:00:27 +02:00
Karolin Varner
daabb8c9ee stash 2025-10-08 17:39:07 +02:00
Karolin Varner
cc2656bfaa stash 2025-10-08 17:17:11 +02:00
Karolin Varner
2a2d78b448 stash 2025-10-07 19:06:06 +02:00
Karolin Varner
d00f91f7dd stash 2025-10-07 18:54:44 +02:00
Karolin Varner
ba3d22829e stash 2025-10-07 15:56:47 +02:00
Karolin Varner
9e94759dd2 stash 2025-10-06 20:55:25 +02:00
Karolin Varner
13ff005d27 stash 2025-10-06 19:38:26 +02:00
Karolin Varner
5b1ee307a4 stash 2025-10-06 19:37:33 +02:00
Karolin Varner
64205d4342 stash 2025-10-06 17:12:04 +02:00
Karolin Varner
8d7d4ab643 stash 2025-10-06 15:39:40 +02:00
Karolin Varner
9d10d11453 stash 2025-10-06 15:36:53 +02:00
Karolin Varner
e15aa94b0c stash 2025-10-06 14:31:23 +02:00
Karolin Varner
0b4bfe4d3e stash 2025-09-20 17:26:10 +02:00
Karolin Varner
32427ad049 stash 2025-09-20 17:26:10 +02:00
Karolin Varner
8f847f59e3 stash 2025-09-20 17:26:10 +02:00
Karolin Varner
d900bb3875 stash 2025-09-20 17:26:10 +02:00
Karolin Varner
30e30b6dfd stash 2025-09-20 17:26:10 +02:00
Karolin Varner
a40e654d08 stash 2025-09-20 17:26:10 +02:00
Karolin Varner
38e997554e stash 2025-09-20 17:26:10 +02:00
Karolin Varner
adbe0bd431 stash 2025-09-20 17:26:10 +02:00
Karolin Varner
e0976986fd stash 2025-09-20 17:26:10 +02:00
Karolin Varner
d932144451 stash 2025-09-20 17:26:10 +02:00
Karolin Varner
fe90220748 stash 2025-09-20 17:26:10 +02:00
Karolin Varner
8d57dde61a stash 2025-09-20 17:26:10 +02:00
Karolin Varner
1a51478e89 chore: Split rosenpass_util::rustix into multiple files 2025-09-20 17:26:10 +02:00
Karolin Varner
5b14ef8065 chore: Rename rosenpass_util::{fd -> rustix} 2025-09-20 17:26:10 +02:00
Karolin Varner
a796bdd2e7 stash 2025-09-20 17:26:10 +02:00
43 changed files with 4283 additions and 315 deletions

34
Cargo.lock generated
View File

@@ -212,7 +212,7 @@ version = "0.68.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "726e4313eb6ec35d2730258ad4e15b547ee75d6afaa1361a922e78e59b7d8078"
dependencies = [
"bitflags 2.8.0",
"bitflags 2.9.3",
"cexpr",
"clang-sys",
"lazy_static",
@@ -237,9 +237,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bitflags"
version = "2.8.0"
version = "2.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36"
checksum = "34efbcccd345379ca2868b2b2c9d3782e9cc58ba87bc7d79d5b53d9c9ae6f25d"
[[package]]
name = "blake2"
@@ -818,9 +818,9 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]]
name = "errno"
version = "0.3.10"
version = "0.3.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d"
checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad"
dependencies = [
"libc",
"windows-sys 0.59.0",
@@ -1193,7 +1193,7 @@ version = "0.7.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4"
dependencies = [
"bitflags 2.8.0",
"bitflags 2.9.3",
"cfg-if",
"libc",
]
@@ -1453,7 +1453,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
"cfg-if",
"windows-targets 0.48.5",
"windows-targets 0.52.6",
]
[[package]]
@@ -1673,7 +1673,7 @@ version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053"
dependencies = [
"bitflags 2.8.0",
"bitflags 2.9.3",
"cfg-if",
"libc",
]
@@ -2032,7 +2032,7 @@ version = "0.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82b568323e98e49e2a0899dcee453dd679fae22d69adf9b11dd508d1549b7e2f"
dependencies = [
"bitflags 2.8.0",
"bitflags 2.9.3",
]
[[package]]
@@ -2257,13 +2257,19 @@ version = "0.1.0"
dependencies = [
"anyhow",
"base64ct",
"bitflags 2.9.3",
"errno",
"libc",
"libcrux-test-utils",
"log",
"mio",
"num-traits",
"rosenpass-to",
"rustix",
"static_assertions",
"tempfile",
"thiserror 1.0.69",
"tinyvec",
"tokio",
"typenum",
"uds",
@@ -2340,7 +2346,7 @@ version = "0.38.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
dependencies = [
"bitflags 2.8.0",
"bitflags 2.9.3",
"errno",
"libc",
"linux-raw-sys",
@@ -2710,6 +2716,12 @@ dependencies = [
"serde_json",
]
[[package]]
name = "tinyvec"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa"
[[package]]
name = "tokio"
version = "1.47.0"
@@ -3276,7 +3288,7 @@ version = "0.33.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c"
dependencies = [
"bitflags 2.8.0",
"bitflags 2.9.3",
]
[[package]]

View File

@@ -80,6 +80,7 @@ hex-literal = { version = "0.4.1" }
hex = { version = "0.4.3" }
heck = { version = "0.5.0" }
libc = { version = "0.2" }
errno = { version = "0.3.13" }
uds = { git = "https://github.com/rosenpass/uds" }
lazy_static = "1.5"
@@ -95,6 +96,7 @@ criterion = "0.5.1"
allocator-api2-tests = "0.2.15"
procspawn = { version = "1.0.1", features = ["test-support"] }
serde_json = { version = "1.0.140" }
bitflags = "2.9.3"
#Broker dependencies (might need cleanup or changes)
wireguard-uapi = { version = "3.0.0", features = ["xplatform"] }

View File

@@ -6,12 +6,12 @@ use std::{borrow::BorrowMut, collections::VecDeque, os::fd::OwnedFd};
use anyhow::Context;
use rosenpass_to::{ops::copy_slice, To};
use rosenpass_util::{
fd::FdIo,
functional::{run, ApplyExt},
io::ReadExt,
mem::DiscardResultExt,
mio::UnixStreamExt,
result::OkExt,
rustix::FdIo,
};
use rosenpass_wireguard_broker::brokers::mio_client::MioBrokerClient;
@@ -243,7 +243,7 @@ where
.context("Invalid request socket missing.")?;
// TODO: We need to have this outside linux
#[cfg(target_os = "linux")]
rosenpass_util::fd::GetSocketProtocol::demand_udp_socket(&sock)?;
rosenpass_util::rustix::GetSocketProtocol::demand_udp_socket(&sock)?;
let sock = std::net::UdpSocket::from(sock);
sock.set_nonblocking(true)?;
mio::net::UdpSocket::from_std(sock).ok()

View File

@@ -26,7 +26,7 @@ use {
command_fds::{CommandFdExt, FdMapping},
log::{error, info},
mio::net::UnixStream,
rosenpass_util::fd::claim_fd,
rosenpass_util::rustix::claim_fd,
rosenpass_wireguard_broker::brokers::mio_client::MioBrokerClient,
rosenpass_wireguard_broker::WireguardBrokerMio,
rustix::net::{socketpair, AddressFamily, SocketFlags, SocketType},

View File

@@ -13,11 +13,14 @@ rust-version = "1.77.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
rosenpass-to = { workspace = true }
base64ct = { workspace = true }
anyhow = { workspace = true }
typenum = { workspace = true }
static_assertions = { workspace = true }
rustix = { workspace = true }
rustix = { workspace = true, features = ["net", "fs", "process", "mm"] }
libc = { workspace = true }
errno = { workspace = true }
zeroize = { workspace = true }
zerocopy = { workspace = true }
thiserror = { workspace = true }
@@ -32,6 +35,9 @@ tokio = { workspace = true, optional = true, features = [
"time",
] }
log = { workspace = true }
bitflags = { workspace = true }
tinyvec = "1.10.0"
num-traits = "0.2.19"
[features]
experiment_file_descriptor_passing = ["uds"]

85
util/src/convert.rs Normal file
View File

@@ -0,0 +1,85 @@
//! Additional helpers to [std::convert]: Traits for conversions between types.
/// Variant of [std::convert::Into] with an explicitly specified type.
///
/// This facilitates method chaining.
///
/// # Examples
///
/// We can create a nicer implementation of the following function using this extension trait
///
/// ```
/// use rosenpass_util::convert::IntoTypeExt;
///
/// fn encode_char_u32_be_1(c: char) -> [u8; 4] {
/// let i : u32 = c.into();
/// i.to_be_bytes()
/// }
///
/// fn encode_char_u32_be_2(c: char) -> [u8; 4] {
/// c.into_type::<u32>().to_be_bytes()
/// }
///
/// assert_eq!(encode_char_u32_be_1('X'), [0x00, 0x00, 0x00, 0x58]);
/// assert_eq!(encode_char_u32_be_1('X'), encode_char_u32_be_2('X'));
///
/// ```
pub trait IntoTypeExt {
    /// Variant of [std::convert::Into] with explicitly specified type.
    ///
    /// # Examples
    ///
    /// See the trait-level documentation of [IntoTypeExt].
    fn into_type<T>(self) -> T
    where
        Self: Into<T>,
    {
        self.into()
    }
}
impl<T> IntoTypeExt for T {}
/// Variant of [std::convert::TryInto] with an explicitly specified type.
///
/// This facilitates method chaining.
///
/// # Examples
///
/// We can create a nicer implementation of the following function using this extension trait
///
/// ```
/// use rosenpass_util::convert::TryIntoTypeExt;
/// use rosenpass_util::result::OkExt;
///
/// fn encode_char_u16_be_1(c: char) -> Result<[u8; 2], <char as TryInto<u16>>::Error> {
/// let i : u16 = c.try_into()?;
/// Ok(i.to_be_bytes())
/// }
///
/// fn encode_char_u16_be_2(c: char) -> Result<[u8; 2], <char as TryInto<u16>>::Error> {
/// c.try_into_type::<u16>()?.to_be_bytes().ok()
/// }
///
/// assert_eq!(encode_char_u16_be_1('X'), Ok([0x00, 0x58]));
/// assert_eq!(encode_char_u16_be_1('X'), encode_char_u16_be_2('X'));
///
/// const HEART_CAT : char = '😻'; // 1F63B
/// assert!(encode_char_u16_be_1(HEART_CAT).is_err());
/// assert_eq!(encode_char_u16_be_1(HEART_CAT), encode_char_u16_be_2(HEART_CAT));
/// ```
pub trait TryIntoTypeExt {
    /// Variant of [std::convert::TryInto] with explicitly specified type.
    ///
    /// # Examples
    ///
    /// See the trait-level documentation of [TryIntoTypeExt].
    fn try_into_type<T>(self) -> Result<T, <Self as TryInto<T>>::Error>
    where
        Self: TryInto<T>,
    {
        self.try_into()
    }
}
impl<T> TryIntoTypeExt for T {}

4
util/src/int/mod.rs Normal file
View File

@@ -0,0 +1,4 @@
//! Helpers for working with integer types
pub mod modular;
pub mod u64uint;

275
util/src/int/modular.rs Normal file
View File

@@ -0,0 +1,275 @@
//! Numeric types for modular arithmetic
use num_traits::{
ops::overflowing::OverflowingAdd, CheckedMul, Euclid, Num, Unsigned, WrappingNeg, Zero,
};
/// Summary-trait for numeric types that can serve as the basis for ModuleBase
pub trait ModuleBase: Num + Ord + Copy + Unsigned + OverflowingAdd + Zero {}
impl<T> ModuleBase for T where T: Num + Ord + Copy + Unsigned + OverflowingAdd + Zero {}
/// Represents a modulus; i.e. the range of values some number type is allowed to use.
///
/// This is based on some inner representation.
///
/// This is not just a value of the underlying representation, because it also supports the modulus
/// [Self::new_full_range()], which indicates that the full range of the underlying type is to be
/// supported.
///
/// Note that zero is not a valid modulus
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct Modulus<T: ModuleBase> {
/// Inner representation of the type
modulus: T,
}
impl<T: ModuleBase> Modulus<T> {
    /// Create a new [Self] without any checks
    fn raw_new(modulus: T) -> Self {
        Self { modulus }
    }
    /// Create a new [Self] that indicates that the full range of the underlying type is to be used
    pub fn new_full_range() -> Self {
        Self::raw_new(T::zero())
    }
    /// Try to create a new [Self]. Will return None only if `modulus == 0`
    pub fn try_new(modulus: T) -> Option<Self> {
        match modulus == T::zero() {
            true => None,
            false => Some(Self::raw_new(modulus)),
        }
    }
    /// Like [Self::try_new] but will panic if `modulus == 0`
    ///
    /// # Panic
    ///
    /// Will panic if `modulus == 0`
    pub fn new_or_panic(modulus: T) -> Self {
        match Self::try_new(modulus) {
            None => panic!("Can not create Module with modulus zero!"),
            Some(me) => me,
        }
    }
    /// Check if this [Self] represents the full range of the underlying type
    pub fn is_full_range(&self) -> bool {
        self.modulus == T::zero()
    }
    /// Get the raw modulus. I.e. the value of the underlying type that can be given to
    /// a modulo operation to implement modular arithmetic.
    ///
    /// Will return None if [Self::is_full_range].
    pub fn modulus(&self) -> Option<T> {
        match self.is_full_range() {
            true => None,
            false => Some(self.modulus),
        }
    }
    /// Check if the given type is contained in the range represented by this [Self]
    pub fn contains(&self, v: T) -> bool {
        match self.is_full_range() {
            true => true,
            false => v < self.modulus,
        }
    }
    /// Double the modulus.
    ///
    /// Returns `None` when the doubled modulus can no longer be represented
    /// in the underlying type, and also when `self` already is a full-range
    /// modulus. Correctly handles the case where the doubled modulus is
    /// *exactly* the full range of the underlying type: in that case
    /// `Some(Self::new_full_range())` is returned.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use rosenpass_util::int::modular::Modulus;
    ///
    /// fn m(v: u8) -> Modulus<u8> {
    ///     Modulus::new_or_panic(v)
    /// }
    ///
    /// assert_eq!(m(100).double(), Some(m(200)));
    /// assert_eq!(m(128).double(), Some(Modulus::new_full_range()));
    /// ```
    pub fn double(&self) -> Option<Self> {
        let s = self.modulus()?;
        match s.overflowing_add(&s) {
            // Overflow to a non-zero value: the doubled modulus does not fit.
            // Overflow to exactly zero means the doubled modulus is precisely
            // the full range, which raw_new(0) encodes by convention.
            (d, true) if d > T::zero() => None,
            (d, _) => Some(Self::raw_new(d)),
        }
    }
    /// Create a new [ModularArithmetic] by taking the value modulo the modulus
    pub fn new_number<U: Into<T>>(self, value: U) -> ModularArithmetic<T>
    where
        T: ModularArithmeticBase,
    {
        ModularArithmetic::modular_new(value.into(), self)
    }
    /// Apply [Self::new_number] to each of the parameters, return whatever result the closure
    /// produces
    pub fn with_converted<U, const N: usize, R, F>(&self, params: [U; N], f: F) -> R
    where
        Self: Copy,
        T: std::fmt::Debug + ModularArithmeticBase,
        U: Into<T>,
        F: FnOnce([ModularArithmetic<T>; N]) -> R,
    {
        let params = params.map(|v| self.new_number(v));
        f(params)
    }
    /// Apply [Self::new_number] to each of the parameters, converting the result to the underlying
    /// representation
    pub fn formula<U, const N: usize, F>(&self, params: [U; N], f: F) -> T
    where
        Self: Copy,
        T: std::fmt::Debug + ModularArithmeticBase,
        U: Into<T>,
        F: FnOnce([ModularArithmetic<T>; N]) -> ModularArithmetic<T>,
    {
        self.with_converted(params, f).value()
    }
}
/// Summary trait for types that can serve as the basis for [ModularArithmetic]
pub trait ModularArithmeticBase:
ModuleBase
+ std::fmt::Debug
+ Num
+ PartialOrd
+ Ord
+ Copy
+ Unsigned
+ CheckedMul
+ Euclid
+ WrappingNeg
{
}
impl<T> ModularArithmeticBase for T where
T: ModuleBase
+ std::fmt::Debug
+ Num
+ PartialOrd
+ Ord
+ Copy
+ Unsigned
+ CheckedMul
+ Euclid
+ WrappingNeg
{
}
/// Modular arithmetic with an arbitrary modulus
#[derive(Debug, Copy, Clone)]
pub struct ModularArithmetic<T: ModularArithmeticBase> {
/// The modulus
modulus: Modulus<T>,
/// The value inside the modulus
///
/// Note that `self.modulus.contains(self.value)` must always hold.
value: T,
}
impl<T: ModularArithmeticBase> ModularArithmetic<T> {
    /// Construct a new [Self].
    ///
    /// Will return None unless `module.`[contains](Modulus::contains)`(value)`.
    pub fn try_new(value: T, module: Modulus<T>) -> Option<Self> {
        module.contains(value).then_some(Self {
            value,
            modulus: module,
        })
    }
    /// Construct a new [Self], reducing `value` into range.
    ///
    /// If the modulus is finite, `value` is reduced modulo the modulus via
    /// `rem_euclid`; for a full-range modulus, `value` is used as-is.
    /// Unlike [Self::try_new], this never fails and never panics.
    pub fn modular_new(value: T, module: Modulus<T>) -> Self {
        let value = match module.modulus() {
            Some(m) => value.rem_euclid(&m),
            None => value,
        };
        Self {
            modulus: module,
            value,
        }
    }
    /// Return the modulus
    pub fn modulus(&self) -> &Modulus<T> {
        &self.modulus
    }
    /// The inner value
    pub fn value(&self) -> T {
        self.value
    }
}
impl<T: ModularArithmeticBase> PartialEq for ModularArithmetic<T> {
    /// Equality compares only the stored value; the modulus is deliberately
    /// ignored, so numbers with different moduli but equal representatives
    /// compare equal.
    // NOTE(review): cross-modulus equality may be surprising — confirm this
    // is intended before relying on it.
    fn eq(&self, other: &Self) -> bool {
        self.value == other.value
    }
}
impl<T: ModularArithmeticBase> Eq for ModularArithmetic<T> {}
impl<T: ModularArithmeticBase> PartialOrd for ModularArithmetic<T> {
    // Delegates to the total order in Ord, as required for consistency.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl<T: ModularArithmeticBase> Ord for ModularArithmetic<T> {
    /// Total order on the stored value only; the modulus is not compared
    /// (consistent with the [PartialEq] implementation).
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.value.cmp(&other.value)
    }
}
impl<T: ModularArithmeticBase> std::ops::Neg for ModularArithmetic<T> {
    type Output = Self;

    /// Additive inverse modulo the modulus.
    ///
    /// For a full-range modulus this is two's-complement negation via
    /// `wrapping_neg`. For a finite modulus `m`, the inverse of a non-zero
    /// value `v` is `m - v`; zero is its own inverse.
    ///
    /// The zero case must be handled separately: the previous implementation
    /// computed `m - 0 == m`, producing a stored value outside the modulus
    /// range and violating the invariant that
    /// `self.modulus.contains(self.value)` always holds.
    fn neg(self) -> Self::Output {
        let ret = match self.modulus().modulus() {
            None => self.value().wrapping_neg(),
            Some(_) if self.value() == T::zero() => T::zero(),
            Some(modulus) => modulus - self.value(),
        };
        Self {
            value: ret,
            modulus: self.modulus,
        }
    }
}
impl<T: ModularArithmeticBase> std::ops::Sub for ModularArithmetic<T> {
    type Output = Self;

    /// Modular subtraction.
    ///
    /// Panics if the two operands do not share the same modulus. When the
    /// left operand is at least as large as the right one, the difference is
    /// computed directly; otherwise the operands are swapped and the result
    /// is negated, which keeps the intermediate subtraction from
    /// underflowing the unsigned base type.
    fn sub(self, rhs: Self) -> Self::Output {
        assert_eq!(self.modulus(), rhs.modulus());
        match self < rhs {
            true => -(rhs - self),
            false => Self {
                value: self.value() - rhs.value(),
                modulus: self.modulus,
            },
        }
    }
}
impl<T: ModularArithmeticBase> std::ops::Add for ModularArithmetic<T> {
    type Output = Self;

    /// Modular addition, implemented via the identity `a + b == a - (-b)`.
    ///
    /// Delegating to [std::ops::Sub] and [std::ops::Neg] avoids a separate
    /// overflow-safe addition path; the subtraction implementation already
    /// handles wrap-around within the modulus.
    fn add(self, rhs: Self) -> Self::Output {
        self - (-rhs)
    }
}

View File

@@ -0,0 +1,12 @@
//! Constants for u64 <-> usize conversion
/// Largest numeric value that can be safely represented both as
/// a u64 and usize, expressed as a usize.
///
/// On targets where usize is at most 64 bits wide this is `usize::MAX`;
/// on a (hypothetical) target with a wider usize it is `u64::MAX`.
pub const MAX_USIZE_IN_U64: usize = match u64::BITS >= usize::BITS {
    true => usize::MAX,
    false => u64::MAX as usize,
};
/// Largest numeric value that can be safely represented both as
/// a u64 and usize, expressed as a u64.
///
/// Same value as [MAX_USIZE_IN_U64]; the cast is lossless by construction.
pub const MAX_U64_IN_USIZE: u64 = MAX_USIZE_IN_U64 as u64;

View File

@@ -0,0 +1,21 @@
//! Making sure size representations fit both into [usize] and [u64]
use static_assertions::const_assert;
mod constants;
pub use constants::*;
#[allow(clippy::module_inception)]
mod u64uint;
pub use u64uint::*;
mod range;
pub use range::*;
/// Safe conversion from usize to u64
///
/// Compile-time checked: the `const_assert` fails the build on any target
/// where usize is wider than 64 bits, so the `as` cast can never truncate.
///
/// TODO: Deprecate this
pub const fn usize_to_u64(v: usize) -> u64 {
    const_assert!(u64::BITS >= usize::BITS);
    v as u64
}

View File

@@ -0,0 +1,29 @@
//! Working with ranges of [super::U64Uint]
use std::ops::Range;
use super::U64USize;
/// Extensions for working with [std::ops::Range] of [U64USize]
pub trait U64USizeRangeExt {
/// Convert to a usize based range
fn usize(self) -> Range<usize>;
/// Convert to a u64 based range
fn u64(self) -> Range<u64>;
}
impl U64USizeRangeExt for Range<U64USize> {
fn usize(self) -> Range<usize> {
Range {
start: self.start.usize(),
end: self.end.usize(),
}
}
fn u64(self) -> Range<u64> {
Range {
start: self.start.u64(),
end: self.end.u64(),
}
}
}

View File

@@ -0,0 +1,192 @@
//! The [U64UInt type]
use super::MAX_U64_IN_USIZE;
/// Error produced by [U64USize::try_new]
#[derive(Debug, thiserror::Error)]
pub enum U64USizeConversionError<T: std::fmt::Debug> {
/// Value can not be represented as u64
#[error("Value can not be represented as a u64 (max = {}) value: {:?}", u64::MAX, .0)]
NoU64Repr(T),
/// Value can not be represented as usize
#[error("Value can not be represented as a usize (max = {}) value: {:?}", usize::MAX, .0)]
NoUSizeRepr(T),
/// Value can not be represented as usize or u64
#[error("Value can not be represented as a u64 (max = {}) or usize (max = {}) value: {:?}", u64::MAX, usize::MAX, .0)]
NoU64OrUSizeRepr(T),
}
/// A number that can be represented as both a usize and a u64.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct U64USize {
/// Enclosed data
storage: u64,
}
impl U64USize {
    /// Cast another number to a number that can be represented as usize and u64;
    ///
    /// This is the internal version which does not use [TryInto]; we use this to implement
    /// [TryInto].
    ///
    /// The value is accepted only if it converts losslessly to *both* u64
    /// and usize; the error variant records which of the two conversions
    /// failed.
    fn try_new_internal<T>(v: T) -> Result<Self, U64USizeConversionError<T>>
    where
        T: Copy + TryInto<usize> + TryInto<u64> + std::fmt::Debug,
    {
        use U64USizeConversionError as E;
        let v_u64: Result<u64, _> = v.try_into();
        let v_usize: Result<usize, _> = v.try_into();
        match (v_u64, v_usize) {
            // Stored as u64; the successful usize conversion is only a check.
            (Ok(storage), Ok(_)) => Ok(Self { storage }),
            (Err(_), Ok(_)) => Err(E::NoU64Repr(v)),
            (Ok(_), Err(_)) => Err(E::NoUSizeRepr(v)),
            (Err(_), Err(_)) => Err(E::NoU64OrUSizeRepr(v)),
        }
    }
    /// Cast another number to a number that can be represented as usize and u64
    pub fn try_new<T>(v: T) -> Result<U64USize, <T as TryInto<Self>>::Error>
    where
        T: TryInto<Self>,
    {
        v.try_into()
    }
    /// Like [Self::try_new], but panics
    pub fn new_or_panic<T>(v: T) -> Self
    where
        T: TryInto<Self>,
        <T as TryInto<Self>>::Error: std::fmt::Debug,
    {
        match Self::try_new(v) {
            Ok(v) => v,
            Err(e) => panic!(
                "Could not construct {}: {e:?}",
                std::any::type_name::<Self>()
            ),
        }
    }
    /// Return this value as a usize
    ///
    /// The cast cannot truncate: construction guarantees the stored value
    /// also fits into a usize.
    pub fn usize(&self) -> usize {
        self.storage as usize
    }
    /// Return this value as a u64
    pub fn u64(&self) -> u64 {
        self.storage
    }
}
impl std::ops::Sub for U64USize {
    type Output = Self;

    /// Subtraction with an explicit underflow check.
    ///
    /// The plain `-` operator only panics on underflow in debug builds; in
    /// release builds it wraps, and the wrapped u64 can still fit into a
    /// usize (on 64-bit targets), so `new_or_panic` would silently accept a
    /// garbage value. `checked_sub` makes the failure deterministic in all
    /// build profiles.
    ///
    /// # Panic
    ///
    /// Panics if `rhs > self`.
    fn sub(self, rhs: Self) -> Self::Output {
        let diff = self
            .u64()
            .checked_sub(rhs.u64())
            .expect("U64USize subtraction underflowed");
        Self::new_or_panic(diff)
    }
}
impl std::ops::Add for U64USize {
    type Output = Self;

    /// Addition with an explicit overflow check.
    ///
    /// The plain `+` operator only panics on overflow in debug builds; in
    /// release builds it wraps, and the wrapped u64 can still fit into a
    /// usize (on 64-bit targets), so `new_or_panic` would silently accept a
    /// garbage value. `checked_add` makes the failure deterministic in all
    /// build profiles.
    ///
    /// # Panic
    ///
    /// Panics if the sum overflows u64 or is not representable as a usize.
    fn add(self, rhs: Self) -> Self::Output {
        let sum = self
            .u64()
            .checked_add(rhs.u64())
            .expect("U64USize addition overflowed");
        Self::new_or_panic(sum)
    }
}
/// Facilitates creation of [U64USize] in cases where truncation to the largest representable value is permissible
pub trait TruncateIntoU64USize {
    /// Check whether calling [TruncateIntoU64USize::truncate_to_u64usize] would truncate or return
    /// the value as-is
    fn fits_into_u64usize(&self) -> bool;
    /// Turn [Self] into a [U64USize]. If the value is representable as a usize and a u64, then
    /// the value will be returned as is. Otherwise, the maximum representable value [MAX_U64_IN_USIZE]
    /// will be returned.
    fn truncate_to_u64usize(&self) -> U64USize;
}
/// Implement [TruncateIntoU64USize] for the given list of numeric types
macro_rules! derive_TruncateIntoU64Usize {
    ($($T:ty),*) => {
        $(
            impl TruncateIntoU64USize for $T {
                // A value fits iff it converts losslessly to both u64 and usize.
                fn fits_into_u64usize(&self) -> bool {
                    U64USize::try_new(*self).is_ok()
                }
                // Saturate to the largest representable value when it does not fit.
                fn truncate_to_u64usize(&self) -> U64USize {
                    U64USize::try_new(*self).unwrap_or(U64USize::new_or_panic(MAX_U64_IN_USIZE))
                }
            }
        )*
    }
}
derive_TruncateIntoU64Usize!(
    U64USize, usize, isize, bool, u8, u16, u32, u64, u128, i8, i16, i32, i64, i128
);
/// Implement infallible [From] conversions into [U64USize]
///
/// Only instantiated for source types whose entire range fits into both u64
/// and usize, so the internal `new_or_panic` can never actually panic.
macro_rules! U64USize_derive_from {
    ($($T:ty),*) => {
        $(
            impl From<$T> for U64USize {
                fn from(value: $T) -> Self {
                    U64USize::new_or_panic::<$T>(value)
                }
            }
        )*
    }
}
U64USize_derive_from!(bool, u8, u16);
/// Implement fallible [TryFrom] conversions into [U64USize]
///
/// Used for source types whose range may exceed u64 or usize.
macro_rules! U64USize_derive_try_from {
    ($($T:ty),*) => {
        $(
            impl TryFrom<$T> for U64USize {
                type Error = U64USizeConversionError<$T>;
                fn try_from(value: $T) -> Result<Self, Self::Error> {
                    U64USize::try_new_internal::<$T>(value)
                }
            }
        )*
    }
}
U64USize_derive_try_from!(usize, isize, u32, u64, u128, i8, i16, i32, i64, i128);
/// Implement infallible conversions from [U64USize] into the given types
///
/// The `as` cast is lossless for these targets: the stored value fits into
/// both u64 and usize by the type's construction invariant.
macro_rules! U64USize_derive_into {
    ($($T:ty),*) => {
        $(
            impl From<U64USize> for $T {
                fn from(val: U64USize) -> Self {
                    val.u64() as $T
                }
            }
        )*
    }
}
U64USize_derive_into!(usize, u64, u128, i128);
/// Implement fallible [TryFrom] conversions from [U64USize] into narrower types
macro_rules! U64USize_derive_try_into {
    ($($T:ty),*) => {
        $(
            impl TryFrom<U64USize> for $T {
                type Error = <$T as TryFrom<u64>>::Error;
                fn try_from(val: U64USize) -> Result<Self, Self::Error> {
                    val.u64().try_into()
                }
            }
        )*
    }
}
U64USize_derive_try_into!(isize, u8, u16, u32, i8, i16, i32, i64);

3
util/src/ipc/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
//! Inter-process communication related resources
pub mod shm;

6
util/src/ipc/shm/mod.rs Normal file
View File

@@ -0,0 +1,6 @@
//! Resources for working with shared-memory
mod shared_memory_segment;
pub use shared_memory_segment::*;
pub mod ringbuf;

View File

@@ -0,0 +1,40 @@
//! Local shared memory ring buffers mostly for testing
use std::sync::Arc;
use crate::ipc::shm::SharedMemorySegment;
use crate::ringbuf::concurrent::framework::{ConcurrentPipeReader, ConcurrentPipeWriter};
use super::{ShmPipeCore, ShmPipeVariables};
/// A process-local shared memory pipe reader
///
/// See [shm_pipe()].
pub type LocalShmPipeReader = ConcurrentPipeReader<ShmPipeCore<Arc<ShmPipeVariables>>>;
/// A process-local shared memory pipe writer
///
/// See [shm_pipe()].
pub type LocalShmPipeWriter = ConcurrentPipeWriter<ShmPipeCore<Arc<ShmPipeVariables>>>;
/// Creates a process-local shared-memory pipe.
///
/// See [ConcurrentPipeWriter]/[ConcurrentPipeReader]. The types [LocalShmPipeWriter]/[LocalShmPipeReader] are just aliases for these.
///
/// # Safety
///
/// Mind the comments in the safety section of [super::super::SharedMemorySegment]; the issues
/// described in there *exactly* affect this implementation.
pub fn shm_pipe(len: usize) -> anyhow::Result<(LocalShmPipeWriter, LocalShmPipeReader)> {
    // One set of synchronization variables (read/write counters) shared by
    // both endpoints via Arc.
    let shared = Arc::new(ShmPipeVariables::new());
    // Two independent mappings of the same memory-backed fd, one per endpoint.
    let (seg_fd, seg_buf_1) = SharedMemorySegment::create(len)?;
    // SAFETY: per the safety discussion on [super::super::SharedMemorySegment]
    // — both mappings live in this process and are accessed only through
    // volatile reads/writes; garbled data is possible, UB is accepted as a
    // documented theoretical risk.
    let seg_buf_2 = unsafe { SharedMemorySegment::from_fd(seg_fd, len) }?;
    let writer = ShmPipeCore::new(shared.clone(), seg_buf_1);
    let reader = ShmPipeCore::new(shared, seg_buf_2);
    Ok((
        ConcurrentPipeWriter::from_core(writer),
        ConcurrentPipeReader::from_core(reader),
    ))
}

View File

@@ -0,0 +1,78 @@
//! Shared-memory ring buffer implementations (main part of the enclosing module)
use std::{borrow::Borrow, sync::atomic::AtomicU64};
use crate::{
ipc::shm::SharedMemorySegment,
ringbuf::concurrent::framework::{
ConcurrentPipeCore, ConcurrentPipeReader, ConcurrentPipeWriter,
},
};
/// Synchronization variables for a [ShmPipeWriter]/[ShmPipeReader].
///
/// These values must be shared between the reader/writer in such a way that access to the inner
/// variables is synchronized and atomic between the two parties.
#[derive(Debug, Default)]
pub struct ShmPipeVariables {
/// See [crate::ringbuf::sched::RingBufferScheduler::items_read()]
pub items_read: AtomicU64,
/// See [crate::ringbuf::sched::RingBufferScheduler::items_written()]
pub items_written: AtomicU64,
}
impl ShmPipeVariables {
/// Constructor
pub fn new() -> Self {
Self {
items_read: 0.into(),
items_written: 0.into(),
}
}
}
/// The [ConcurrentPipeCore] for a shared memory pipe.
#[derive(Debug)]
pub struct ShmPipeCore<Variables: Borrow<ShmPipeVariables>> {
/// The synchronization variables
variables: Variables,
/// The memory buffer
buf: SharedMemorySegment,
}
impl<Variables: Borrow<ShmPipeVariables>> ShmPipeCore<Variables> {
/// Constructor
pub fn new(variables: Variables, buf: SharedMemorySegment) -> Self {
Self { variables, buf }
}
}
/// A shared memory pipe reader
pub type ShmPipeReader<Variables> = ConcurrentPipeReader<ShmPipeCore<Variables>>;
/// A shared memory pipe writer
pub type ShmPipeWriter<Variables> = ConcurrentPipeWriter<ShmPipeCore<Variables>>;
impl<Variables: Borrow<ShmPipeVariables>> ConcurrentPipeCore for ShmPipeCore<Variables> {
    type AtomicType = AtomicU64;
    // Buffer capacity in bytes.
    // NOTE(review): `as u64` is lossless only where usize <= 64 bits —
    // confirm no wider-usize target is supported (cf. usize_to_u64 elsewhere
    // in this crate).
    fn buf_len(&self) -> u64 {
        self.buf.len() as u64
    }
    // Shared read counter, see [crate::ringbuf::sched::RingBufferScheduler::items_read()].
    fn items_read(&self) -> &AtomicU64 {
        &self.variables.borrow().items_read
    }
    // Shared write counter, see [crate::ringbuf::sched::RingBufferScheduler::items_written()].
    fn items_written(&self) -> &AtomicU64 {
        &self.variables.borrow().items_written
    }
    // Volatile copy out of the shared segment.
    // NOTE(review): `off as usize` assumes the offset fits a usize — TODO
    // confirm the caller guarantees `off < buf_len()`.
    fn read_from_buffer(&mut self, dst: &mut [u8], off: u64) {
        self.buf.volatile_read(dst, off as usize)
    }
    // Volatile copy into the shared segment; same offset assumption as above.
    fn write_to_buffer(&mut self, off: u64, src: &[u8]) {
        self.buf.volatile_write(off as usize, src);
    }
}

View File

@@ -0,0 +1,7 @@
//! Shared-memory ring buffers
mod main;
pub use main::*;
mod local;
pub use local::*;

View File

@@ -0,0 +1,367 @@
//! Accessing data in a shared memory segment
use std::{
borrow::Borrow,
os::fd::{AsFd, OwnedFd},
};
use crate::{
int::u64uint::usize_to_u64,
ptr::{ReadMemVolatile, WriteMemVolatile},
result::OkExt,
secret_memory::{
fd::SecretMemfdConfig,
mmap::{MMapError, MapFdConfig, MappedSegment},
},
};
/// Safe creation of shared memory segments
///
/// This is a slightly more convenient API than using [MapFdConfig] directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SharedMemorySegmentBuilder {
/// Configuration for allocating memory file descriptors and their secrecy level
pub secret_memfd_cfg: SecretMemfdConfig,
/// Configuration for mapping file descriptors into memory
pub map_fd_cfg: MapFdConfig,
}
impl SharedMemorySegmentBuilder {
/// Create a new secret memory segment builder.
///
/// Note that this always sets [MapFdConfig::set_shared()], but you can overwrite this behavior
/// by un-setting the flag again.
///
/// # Safety & panic
///
/// This function can panic when [u64] is unable to represent the given size
/// value. This might be the case in future computers impressing me with their excessive size.
pub const fn new(len: usize) -> Self {
let secret_memfd_cfg = SecretMemfdConfig::new();
let map_fd_cfg = MapFdConfig::new()
.set_shared()
.resize_on_mmap(usize_to_u64(len));
Self {
secret_memfd_cfg,
map_fd_cfg,
}
}
/// Create a secret memory segment using the configuration stored here
pub fn create_segment(&self) -> anyhow::Result<(OwnedFd, SharedMemorySegment)> {
let fd = self.secret_memfd_cfg.create()?;
let seg = self.map_fd_cfg.mappable_fd(&fd).mmap()?;
let seg = unsafe { SharedMemorySegment::from_mapped_segment(seg) };
Ok((fd, seg))
}
}
/// Safe creation of and access to shared memory segments
///
/// # Safety
///
/// Any means to create a [Self] must guarantee, that further calls to [Self::volatile_write] and
/// [Self::volatile_read] are also safe.
///
/// The API of this struct is specifically designed so creating multiple memory mappings of the same
/// shared memory segment is impossible without unsafe code.
///
/// We recognize that the sole purpose of shared memory segments is that multiple mappings of them
/// are created, there just is no way to do so using safe rust.
///
/// The reason for this is that by the Rust documentation any concurrent memory access leads to
/// undefined behavior, even volatile accesses caused solely by an adversarial application on the
/// other end of a shared memory communication channel.
///
/// For this reason, we force users to use unsafe code to create shared memory mappings from an
/// existing file descriptor, as this could potentially lead to adversarial data races (and thus to
/// undefined behavior).
///
/// In practice, we believe using concurrent memory access using volatile operations is going to
/// lead to nothing worse than garbled data being transferred. The user should treat any data
/// received through a shared-memory ring buffer as untrusted and validate this data any way, so
/// garbled data should be caught.
///
/// This means that [Self::from_fd()] is unsafe in theory, but most likely safe in practice.
///
/// ## Concurrent, untrusted shared memory is technically undefined behavior
///
/// Its even worse than having to use unsafe: technically speaking, it may be impossible to
/// use shared memory soundly in rust unless all parties with access to the segment are
/// *trusted*. If these parties are not trusted (or buggy) they can always cause undefined
/// behavior:
///
/// From the [std::sync::atomic] documentation:
///
/// > **The most important aspect of this model is that data races are undefined behavior.** A data race
/// > is defined as conflicting non-synchronized accesses where at least one of the accesses is non-atomic.
/// > Here, accesses are conflicting if they affect overlapping regions of memory and at least one of them
/// > is a write. (A compare_exchange or compare_exchange_weak that does not succeed is not considered a
/// > write.) They are non-synchronized if neither of them happens-before the other, according to the
/// > happens-before order of the memory model.
///
/// The fact that this API uses volatile [reads](Self::volatile_read) and [writes](Self::volatile_write),
/// and the fact that we use mmap(2) for allocation does not mitigate this issue; from the
/// documentation of [std::ptr::write_volatile()]:
///
/// > When a volatile operation is used for memory inside an allocation, it behaves exactly like write,
/// > except for the additional guarantee that it won't be elided or reordered (see above). This implies
/// > that the operation will actually access memory and not e.g. be lowered to a register access. Other
/// > than that, all the usual rules for memory accesses apply (including provenance). In particular, just
/// > like in C, whether an operation is volatile has no bearing whatsoever on questions involving concurrent
/// > access from multiple threads. Volatile accesses behave exactly like non-atomic accesses in that regard.
///
/// An allocation is defined as follows (taken from [std::ptr]):
///
/// > An allocation is a subset of program memory which is addressable from Rust, and within which pointer
/// > arithmetic is possible. Examples of allocations include heap allocations, stack-allocated variables,
/// > statics, and consts. The safety preconditions of some Rust operations - such as offset and field
/// > projections (expr.field) - are defined in terms of the allocations on which they operate.
///
/// This definition clearly applies to mmap(2) allocated regions.
///
/// What might mitigate this issue is mapping the region just once per process:
///
/// > In particular, just
/// > like in C, whether an operation is volatile has no bearing whatsoever on questions involving concurrent
/// > access from **multiple threads**.
///
/// We could argue that a process is not a thread, and thus concurrent access from two processes is
/// fine, but concurrent access from two threads is not (unless guarded by an atomic value or a
/// mutex or some primitive actually designed for synchronization).
///
/// There is no wording in the spec explicitly allowing raceful, concurrent access from multiple processes.
///
/// The problem with basing our safety-argument on the claim that "processes are not threads" is
/// that the line between processes and threads is drawn in the sand. For linux, read the man page
/// of clone(2):
///
/// > By contrast with fork(2), these [clone, __clone2, clone3] system calls provide more precise control over what pieces of execution
/// > context are shared between the calling process and the child process. For example, using these system
/// > calls, the caller can control whether or not the two processes share the virtual address space, the ta
/// > ble of file descriptors, and the table of signal handlers. These system calls also allow the new child
/// > process to be placed in separate namespaces(7).
/// >
/// > […]
/// >
/// > ## CLONE_THREAD (since Linux 2.4.0)
/// >
/// > If CLONE_THREAD is set, the child is placed in the same thread group as the calling process. To
/// > make the remainder of the discussion of CLONE_THREAD more readable, the term "thread" is used to
/// > refer to the processes within a thread group.
///
/// According to the man page, "the term 'thread' is used to refer to the processes within a thread group.".
///
/// The Rust (transitively, from the C++11 Atomic) specification tells us that there must be no
/// concurrent memory access between threads whether this access is volatile or not. The linux man
/// pages tell us that "thread" is just a special type of "process".
///
/// **The most robust interpretation of these specifications is that shared memory must not be used for
/// communication with an untrusted party across thread or process boundaries, or else the other
/// process/thread can cause undefined behavior in our process.**
///
/// ## In practice
///
/// Realistically, using volatile reads/writes on valid, mapped memory might cause garbled values in
/// case of a data race, but it should not crash the program or do anything worse than create garbled
/// values.
///
/// Mind that we do not mind garbled values here; we are implementing a shared memory communication
/// interface, so our application must always assume that the data it receives may be garbled. It
/// has to be validated. We just don't want the other application to be able to do anything worse
/// than garble the data it is sending (or receiving), so let's estimate what can *realistically*
/// happen here if the other application maliciously causes a race.
///
/// The worst any of the assembly sequences below should do is cause tearing in case of a data
/// race.
///
/// This leads me to the conclusion that what we are dealing with here is not an
/// implementation that is faulty/insecure, instead it is a definition-gap in the compiler
/// semantics for volatile memory access for use in security-critical applications.
///
/// Godbolt link: https://rust.godbolt.org/z/GGjsGsc33
/// Compiler: `rustc 1.90.0`
///
/// Rust code:
///
/// ```rust
/// #[unsafe(no_mangle)]
/// pub fn read_volatile(num: &[u128]) -> u128 {
/// let ptr = num.as_ptr();
/// unsafe { ptr.read_volatile() }
/// }
///
/// #[unsafe(no_mangle)]
/// pub fn write_volatile(num: &mut [u128]) {
/// let ptr = num.as_mut_ptr();
/// unsafe { ptr.write_volatile(42u128) };
/// }
/// ```
///
/// x86_64 (`--target=x86_64-unknown-linux-gnu -O`):
///
/// ```asm
/// read_volatile:
/// mov rax, qword ptr [rdi]
/// mov rdx, qword ptr [rdi + 8]
/// ret
///
/// write_volatile:
/// mov qword ptr [rdi + 8], 0
/// mov qword ptr [rdi], 42
/// ret
/// ```
///
/// arm64 (`--target=aarch64-unknown-linux-gnu -O`):
///
/// ```asm
/// read_volatile:
/// ldp x0, x1, [x0]
/// ret
///
/// write_volatile:
/// mov w8, #42
/// stp x8, xzr, [x0]
/// ret
/// ```
///
/// armv7 (`--target=armv7-unknown-linux-gnueabihf -O`)
///
/// ```
/// read_volatile:
/// push {r4, r5, r11, lr}
/// ldrd r2, r3, [r1]
/// ldrd r4, r5, [r1, #8]
/// stm r0, {r2, r3, r4, r5}
/// pop {r4, r5, r11, pc}
///
/// write_volatile:
/// push {r4, r5, r11, lr}
/// mov r2, #0
/// mov r4, #42
/// mov r3, r2
/// mov r5, r2
/// strd r2, r3, [r0, #8]
/// strd r4, r5, [r0]
/// pop {r4, r5, r11, pc}
/// ```
///
/// risc64 (`--target=riscv64gc-unknown-linux-gnu`):
///
/// ```
/// read_volatile:
/// ld a1, 8(a0)
/// ld a0, 0(a0)
/// ret
///
/// write_volatile:
/// sd zero, 8(a0)
/// li a1, 42
/// sd a1, 0(a0)
/// ret
/// ```
///
#[derive(Debug)]
pub struct SharedMemorySegment {
    /// The underlying mapped segment; access goes through the volatile
    /// read/write accessors on this type
    inner: MappedSegment,
}
impl SharedMemorySegment {
    /// Create a new shared memory segment of `len` bytes.
    ///
    /// Returns the owning file descriptor together with the mapped segment.
    pub fn create(len: usize) -> anyhow::Result<(OwnedFd, Self)> {
        SharedMemorySegmentBuilder::new(len).create_segment()
    }

    /// Create a shared memory segment from a file descriptor
    ///
    /// The mapping is created as a shared mapping and is expected to be `size` bytes long.
    ///
    /// # Safety
    ///
    /// See the comments in [Self].
    pub unsafe fn from_fd<Fd: AsFd>(fd: Fd, size: usize) -> Result<Self, MMapError> {
        let cfg = MapFdConfig::new()
            .set_shared()
            .expected_size(usize_to_u64(size));
        unsafe { Self::from_fd_with_config(fd, cfg) }
    }

    /// Create a shared memory segment from a file descriptor, using an explicit
    /// mapping configuration
    ///
    /// # Safety
    ///
    /// See the comments in [Self].
    pub unsafe fn from_fd_with_config<Fd: AsFd>(
        fd: Fd,
        cfg: MapFdConfig,
    ) -> Result<Self, MMapError> {
        let segment = cfg.mappable_fd(&fd).mmap()?;
        unsafe { Self::from_mapped_segment(segment).ok() }
    }

    /// Create a shared memory segment from an existing mapped segment
    ///
    /// # Safety
    ///
    /// See the comments in [Self].
    pub unsafe fn from_mapped_segment(inner: MappedSegment) -> Self {
        Self { inner }
    }

    /// The underlying mapped segment
    pub fn mapped_segment(&self) -> &MappedSegment {
        self.inner.borrow()
    }

    /// A pointer to the underlying mapped segment
    pub fn ptr(&self) -> *mut u8 {
        self.mapped_segment().ptr()
    }

    /// The length of the underlying mapped segment
    pub fn len(&self) -> usize {
        self.mapped_segment().len()
    }

    /// Whether `self.`[len()](Self::len)` == 0`
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Read data from the shared memory segment into `dst`, starting at byte offset `off`
    ///
    /// # Panics
    ///
    /// Panics if `off + dst.len()` overflows `usize` or exceeds [Self::len].
    ///
    /// # Safety
    ///
    /// This function is safe to call; see the comments in [Self] for the discussion
    /// of why concurrent volatile access is considered acceptable here.
    pub fn volatile_read(&self, dst: &mut [u8], off: usize) {
        // Checked arithmetic: a huge `off` must not be able to wrap around in
        // release builds and slip past the bounds assertion below.
        let end = off
            .checked_add(dst.len())
            .expect("offset + length overflows usize");
        assert!(end <= self.len());
        unsafe { self.ptr().add(off).read_mem_volatile(dst) }
    }

    /// Write data from `src` into the shared memory segment, starting at byte offset `off`
    ///
    /// # Panics
    ///
    /// Panics if `off + src.len()` overflows `usize` or exceeds [Self::len].
    ///
    /// # Safety
    ///
    /// This function is safe to call; see the comments in [Self] for the discussion
    /// of why concurrent volatile access is considered acceptable here.
    pub fn volatile_write(&self, off: usize, src: &[u8]) {
        // Checked arithmetic: see volatile_read.
        let end = off
            .checked_add(src.len())
            .expect("offset + length overflows usize");
        assert!(end <= self.len());
        unsafe { self.ptr().add(off).write_mem_volatile(src) }
    }
}
impl From<SharedMemorySegment> for MappedSegment {
fn from(val: SharedMemorySegment) -> Self {
val.inner
}
}
impl Borrow<MappedSegment> for SharedMemorySegment {
    /// Borrow the underlying mapping directly from the wrapper's field
    fn borrow(&self) -> &MappedSegment {
        &self.inner
    }
}

View File

@@ -10,15 +10,16 @@ pub mod b64;
pub mod build;
/// Control flow abstractions and utilities.
pub mod controlflow;
/// File descriptor utilities.
pub mod fd;
pub mod convert;
/// File system operations and handling.
pub mod file;
pub mod fmt;
/// Functional programming utilities.
pub mod functional;
pub mod int;
/// Input/output operations.
pub mod io;
pub mod ipc;
/// Length prefix encoding schemes implementation.
pub mod length_prefix_encoding;
/// Memory manipulation and allocation utilities.
@@ -27,8 +28,13 @@ pub mod mem;
pub mod mio;
/// Extended Option type functionality.
pub mod option;
pub mod ptr;
/// Extended Result type functionality.
pub mod result;
pub mod ringbuf;
pub mod rustix;
pub mod secret_memory;
pub mod sync;
/// Time and duration utilities.
pub mod time;
#[cfg(feature = "tokio")]

View File

@@ -280,6 +280,108 @@ impl<T: Sized> MoveExt for T {
}
}
/// Explicitly copy a value
///
/// This is sometimes useful as an alternative to the deref operator
/// (`x.copy()` instead of `*x`) to achieve neater method chains.
///
/// # Examples
///
/// ```
/// use rosenpass_util::mem::CopyExt;
/// use rosenpass_util::result::OkExt;
///
/// // Without CopyExt
/// fn boilerplate_1<T: Copy>(v: &T) -> Result<T, ()> {
///     (*v).ok()
/// }
///
/// // With CopyExt
/// fn boilerplate_2<T: Copy>(v: &T) -> Result<T, ()> {
///     v.copy().ok()
/// }
///
/// assert_eq!(boilerplate_1(&32), Ok(32));
/// assert_eq!(boilerplate_1(&32), boilerplate_2(&32));
/// ```
pub trait CopyExt: Copy {
    /// Produce a bitwise copy of the value behind the reference
    fn copy(&self) -> Self;
}
impl<T: Copy> CopyExt for T {
    fn copy(&self) -> Self {
        *self
    }
}
/// Helper trait for applying a mutating closure/function to a mutable reference
///
/// Calling `r.mutate(f)` replaces the value behind `r` with `f(*r)` and hands the
/// reference back, enabling fluent method chains on `&mut` values.
///
/// # Examples
///
/// ```rust
/// use rosenpass_util::mem::MutateRefExt;
///
/// fn plus_two_mul_three(v: u64) -> u64 {
///     (v + 2) * 3
/// }
///
/// fn inconvenient_example(v: &mut u64) {
///     *v = plus_two_mul_three(*v);
/// }
///
/// fn convenient_example(v: &mut u64) {
///     v.mutate(plus_two_mul_three);
/// }
///
/// let mut x = 0;
/// inconvenient_example(&mut x);
/// assert_eq!(x, 6);
/// convenient_example(&mut x);
/// assert_eq!(x, 24);
///
/// x = 0;
///
/// x.mutate(plus_two_mul_three);
/// assert_eq!(x, 6);
/// x.mutate(|v| (v + 2) * 3);
/// assert_eq!(x, 24);
///
/// x = 0;
///
/// let y = &mut x;
/// y.mutate(plus_two_mul_three);
/// assert_eq!(x, 6);
///
/// struct Cont {
///     x: u64,
/// }
///
/// impl Cont {
///     fn x_mut(&mut self) -> &mut u64 {
///         &mut self.x
///     }
/// }
///
/// let mut s = Cont { x: 0 };
/// s.x_mut().mutate(|v| (v + 2) * 3);
/// ```
pub trait MutateRefExt: DerefMut
where
    Self::Target: Sized,
{
    /// Directly mutate a reference; replaces the pointee with `f(old)` and
    /// returns the (possibly chained) reference
    fn mutate<F: FnOnce(Self::Target) -> Self::Target>(self, f: F) -> Self;
}
impl<T> MutateRefExt for T
where
    T: DerefMut,
    <T as Deref>::Target: Copy,
{
    fn mutate<F: FnOnce(Self::Target) -> Self::Target>(mut self, f: F) -> Self {
        // Copy the old value out (Target: Copy), apply f, write the result back
        let v = self.deref_mut();
        *v = f(*v);
        self
    }
}
#[cfg(test)]
mod test_forgetting {
use crate::mem::Forgetting;

View File

@@ -2,8 +2,8 @@ use mio::net::{UnixListener, UnixStream};
use std::os::fd::{OwnedFd, RawFd};
use crate::{
fd::{claim_fd, claim_fd_inplace},
result::OkExt,
rustix::{claim_fd, claim_fd_inplace},
};
/// Module containing I/O interest flags for Unix operations (see also: [mio::Interest])
@@ -93,7 +93,7 @@ impl UnixStreamExt for UnixStream {
fn from_fd(fd: OwnedFd) -> anyhow::Result<Self> {
use std::os::unix::net::UnixStream as StdUnixStream;
#[cfg(target_os = "linux")] // TODO: We should support this on other plattforms
crate::fd::GetUnixSocketType::demand_unix_stream_socket(&fd)?;
crate::rustix::GetUnixSocketType::demand_unix_stream_socket(&fd)?;
let sock = StdUnixStream::from(fd);
sock.set_nonblocking(true)?;
UnixStream::from_std(sock).ok()

View File

@@ -7,7 +7,7 @@ use std::{
};
use uds::UnixStreamExt as FdPassingExt;
use crate::fd::{claim_fd_inplace, IntoStdioErr};
use crate::rustix::{claim_fd_inplace, IntoStdioErr};
/// A wrapper around a socket that combines reading from the socket with tracking
/// received file descriptors. Limits the maximum number of file descriptors that

4
util/src/ptr/mod.rs Normal file
View File

@@ -0,0 +1,4 @@
//! Utilities for working with pointers
mod volatile;
pub use volatile::*;

51
util/src/ptr/volatile.rs Normal file
View File

@@ -0,0 +1,51 @@
//! Utilities relating to volatile reads/writes on pointers
/// Read from a memory location using
/// [pointer::read_volatile] in a loop and store the
/// results in the given slice
pub trait ReadMemVolatile<T> {
    /// Perform `dst.len()` consecutive volatile reads starting at `self`,
    /// storing the values into `dst`
    ///
    /// # Safety
    ///
    /// Refer to [pointer::read_volatile]; `self` must be valid for reads of
    /// `dst.len()` consecutive values of type `T`
    unsafe fn read_mem_volatile(self, dst: &mut [T]);
}
impl<T> ReadMemVolatile<T> for *const T {
    unsafe fn read_mem_volatile(self, dst: &mut [T]) {
        let mut cursor = self;
        for slot in dst.iter_mut() {
            // SAFETY: the caller guarantees the source range is valid for
            // volatile reads of dst.len() elements
            *slot = unsafe { cursor.read_volatile() };
            // SAFETY: stays within (or one past the end of) the range the
            // caller vouched for
            cursor = unsafe { cursor.add(1) };
        }
    }
}
impl<T> ReadMemVolatile<T> for *mut T {
    unsafe fn read_mem_volatile(self, dst: &mut [T]) {
        // Delegate to the *const T implementation
        unsafe { self.cast_const().read_mem_volatile(dst) }
    }
}
/// Write to a memory location using
/// [pointer::write_volatile] in a loop, copying the
/// values out of the given slice
pub trait WriteMemVolatile<T>: ReadMemVolatile<T> {
    /// Perform `src.len()` consecutive volatile writes starting at `self`,
    /// taking the values from `src`
    ///
    /// # Safety
    ///
    /// Refer to [pointer::write_volatile]; `self` must be valid for writes of
    /// `src.len()` consecutive values of type `T`
    unsafe fn write_mem_volatile(self, src: &[T]);
}
impl<T: Copy> WriteMemVolatile<T> for *mut T {
    unsafe fn write_mem_volatile(self, src: &[T]) {
        for (idx, &value) in src.iter().enumerate() {
            // SAFETY: the caller guarantees the destination range is valid for
            // volatile writes of src.len() elements
            unsafe { self.add(idx).write_volatile(value) };
        }
    }
}

View File

@@ -0,0 +1,3 @@
//! Concurrent ring buffers
pub mod framework;

View File

@@ -0,0 +1,202 @@
//! An implementation of a concurrent ring buffer with the
//! core (IO/memory access) operations in a detached trait. Everything around the core is
//! implemented but before being able to use this, callers must implement an appropriate IO core.
use crate::int::u64uint::U64USizeRangeExt;
use crate::result::OkExt;
use crate::ringbuf::sched::{
Diff, OperationType, RingBufferFromCountersError, RingBufferScheduler, ScheduledOperations,
};
use crate::sync::atomic::abstract_atomic::AbstractAtomic;
/// The IO/memory-access core of a concurrent pipe.
///
/// Implementors supply the backing buffer plus the two atomic counters that
/// track how many items have been read and written; the generic pipe logic
/// ([ConcurrentPipeReader]/[ConcurrentPipeWriter]) is layered on top.
pub trait ConcurrentPipeCore {
    /// The atomic integer type used for the read/write counters
    type AtomicType: AbstractAtomic<u64>;
    /// Total capacity of the underlying buffer, in bytes
    fn buf_len(&self) -> u64;
    /// Counter of items read so far
    fn items_read(&self) -> &Self::AtomicType;
    /// Counter of items written so far
    fn items_written(&self) -> &Self::AtomicType;
    /// Copy `dst.len()` bytes out of the buffer, starting at buffer offset `off`
    fn read_from_buffer(&mut self, dst: &mut [u8], off: u64);
    /// Copy `src` into the buffer, starting at buffer offset `off`
    fn write_to_buffer(&mut self, off: u64, src: &[u8]);
}
/// Raised by [ConcurrentPipeReader::read()]/[ConcurrentPipeWriter::write()] if the underlying
/// ring buffer is in an inconsistent state.
///
/// When used in shared memory applications, this error may have been locally caused, but it may
/// also be caused by **the other threads or processes** putting the ring buffer into an
/// inconsistent state.
///
/// For this reason, you must treat this error as an unrecoverable breakdown of the communication channel. You
/// must close and no longer use the ring buffer after receiving this error, but you can handle it
/// gracefully by reopening a new ring buffer in its place.
///
/// # API Stability
///
/// Treat this error type as opaque; do not rely on the enum values remaining the same
#[derive(Debug, thiserror::Error, Clone)]
pub enum InconsistentRingBufferStateError {
    /// Failed to update an inner counter; this probably indicates a data race (forbidden
    /// concurrent read/write)
    // NOTE(review): variant name contains a typo ("Concurrrent", three r's); renaming it would
    // be a breaking change for match arms, so it is deliberately left as-is here.
    #[error("Concurrent access to the ring buffer {:?}", self)]
    ConcurrrentAccess {
        /// Whether this was a read or write
        operation_type: OperationType,
        /// Operations performed before updating the ring buffer state
        scheduled_ops: ScheduledOperations,
        /// The scheduler that scheduled our operation
        scheduler_state: RingBufferScheduler,
        /// The counter value before compare and exchange
        expected_counter_value: u64,
        /// The counter value we actually found in the counter
        actual_counter_value: u64,
        /// The counter value we tried to set
        new_counter_value_tried_to_set: u64,
    },
    /// Could not construct ring buffer scheduler
    #[error("Inconsistent ring buffer state: {:?}", .0)]
    InconsistentCounterState(#[from] RingBufferFromCountersError),
}
/// A pending read or write request against the pipe, borrowing the caller's buffer.
#[derive(Debug)]
enum ConcurrentPipeOperation<'a> {
    /// Read from the pipe into the borrowed destination buffer
    Read(&'a mut [u8]),
    /// Write the borrowed source buffer into the pipe
    Write(&'a [u8]),
}
impl<'a> ConcurrentPipeOperation<'a> {
    /// Borrow the caller-supplied buffer, regardless of operation direction.
    ///
    /// The borrow is tied to `&self` (not to `'a`), so callers may inspect the
    /// buffer and later take a fresh mutable borrow of the operation; the
    /// original `&'a self` receiver was needlessly over-constrained.
    pub fn inner_buf(&self) -> &[u8] {
        match self {
            ConcurrentPipeOperation::Read(items) => items,
            ConcurrentPipeOperation::Write(items) => items,
        }
    }
    /// Length of the caller-supplied buffer in bytes
    pub fn len(&self) -> usize {
        self.inner_buf().len()
    }
    /// The scheduler-level operation type (read vs. write) of this request
    pub fn scheduler_op(&self) -> OperationType {
        match self {
            ConcurrentPipeOperation::Read(_) => OperationType::Read,
            ConcurrentPipeOperation::Write(_) => OperationType::Write,
        }
    }
}
/// Shared implementation backing both [ConcurrentPipeReader] and [ConcurrentPipeWriter]
struct ConcurrentPipeImpl<Core: ConcurrentPipeCore> {
    /// The IO/memory-access core providing buffer storage and counters
    core: Core,
}
impl<Core: ConcurrentPipeCore> ConcurrentPipeImpl<Core> {
    /// Wrap the given IO core
    fn from_core(core: Core) -> Self {
        Self { core }
    }
    /// Perform a read or a write against the ring buffer.
    ///
    /// Returns the number of bytes actually transferred, which may be smaller
    /// than the buffer carried by `op` (partial reads/writes are scheduled by
    /// the ring buffer scheduler).
    fn read_or_write(
        &mut self,
        mut op: ConcurrentPipeOperation,
    ) -> Result<usize, InconsistentRingBufferStateError> {
        use std::sync::atomic::Ordering as O;
        // Figure out which counter to store the result of the operation in and the orderings to
        // use for load/store operations. Reads Acquire the writer's counter and writes Release
        // their own counter update, pairing the two sides; everything else is Relaxed.
        let (ord_r, ord_w, ord_store_succ, ord_store_fail) = match op {
            ConcurrentPipeOperation::Read(_) => (O::Relaxed, O::Acquire, O::Relaxed, O::Relaxed),
            ConcurrentPipeOperation::Write(_) => (O::Relaxed, O::Relaxed, O::Release, O::Relaxed),
        };
        // Construct a ring buffer scheduler from the current state
        let sched = RingBufferScheduler::try_from_counters(
            self.core.buf_len(),
            self.core.items_written().load(ord_w),
            self.core.items_read().load(ord_r),
        )?;
        // Have the scheduler schedule the operations
        let ops = sched.schedule_contigous_operations(op.scheduler_op(), op.len());
        // Actually perform the operations; a single logical operation may come back as
        // multiple contiguous chunks when it wraps around the end of the ring buffer
        for (buf_slice, ring_op) in ops.with_outside_buffer_range() {
            match &mut op {
                ConcurrentPipeOperation::Read(dst) => self
                    .core
                    .read_from_buffer(&mut dst[buf_slice.usize()], ring_op.off),
                ConcurrentPipeOperation::Write(src) => self
                    .core
                    .write_to_buffer(ring_op.off, &src[buf_slice.usize()]),
            }
        }
        // Take a differential between the scheduler and the scheduler after the operations were
        // applied. Make sure only one counter was updated
        let diff = sched
            .register_operation(&ops)
            .diff_old(&sched)
            .expect_op_only(op.scheduler_op())
            .unwrap();
        let store_to_ctr = match op {
            ConcurrentPipeOperation::Read(_) => self.core.items_read(),
            ConcurrentPipeOperation::Write(_) => self.core.items_written(),
        };
        // Update the counters, assuming there was any change at all.
        // A failed compare_exchange means somebody else modified our counter concurrently —
        // which this API forbids — so the channel is reported as broken.
        if let Diff::Different(old, new) = diff {
            store_to_ctr
                .compare_exchange(old, new, ord_store_succ, ord_store_fail)
                .map_err(
                    |actual| InconsistentRingBufferStateError::ConcurrrentAccess {
                        operation_type: op.scheduler_op(),
                        scheduled_ops: ops,
                        scheduler_state: sched,
                        expected_counter_value: old,
                        actual_counter_value: actual,
                        new_counter_value_tried_to_set: new,
                    },
                )?;
        }
        ops.cumulative_operation_length().usize().ok()
    }
}
/// The reading end of a concurrent pipe built on a [ConcurrentPipeCore]
pub struct ConcurrentPipeReader<Core: ConcurrentPipeCore> {
    /// Shared read/write machinery
    inner: ConcurrentPipeImpl<Core>,
}
impl<Core: ConcurrentPipeCore> ConcurrentPipeWriter<Core> {
    /// Construct a writer around the given IO core
    pub fn from_core(core: Core) -> Self {
        Self {
            inner: ConcurrentPipeImpl::from_core(core),
        }
    }
    /// Total capacity of the underlying ring buffer, in bytes
    pub fn buf_len(&self) -> u64 {
        self.inner.core.buf_len()
    }
    /// Write data from `src` into the ring buffer.
    ///
    /// Returns the number of bytes actually written; this may be less than
    /// `src.len()` depending on available space.
    ///
    /// # Errors
    ///
    /// See [InconsistentRingBufferStateError]; the channel must be abandoned on error.
    pub fn write(&mut self, src: &[u8]) -> Result<usize, InconsistentRingBufferStateError> {
        self.inner
            .read_or_write(ConcurrentPipeOperation::Write(src))
    }
}
/// The writing end of a concurrent pipe built on a [ConcurrentPipeCore]
pub struct ConcurrentPipeWriter<Core: ConcurrentPipeCore> {
    /// Shared read/write machinery
    inner: ConcurrentPipeImpl<Core>,
}
impl<Core: ConcurrentPipeCore> ConcurrentPipeReader<Core> {
    /// Construct a reader around the given IO core
    pub fn from_core(core: Core) -> Self {
        Self {
            inner: ConcurrentPipeImpl::from_core(core),
        }
    }
    /// Total capacity of the underlying ring buffer, in bytes
    pub fn buf_len(&self) -> u64 {
        self.inner.core.buf_len()
    }
    /// Read data from the ring buffer into `dst`.
    ///
    /// Returns the number of bytes actually read; this may be less than
    /// `dst.len()` depending on available data.
    ///
    /// # Errors
    ///
    /// See [InconsistentRingBufferStateError]; the channel must be abandoned on error.
    pub fn read(&mut self, dst: &mut [u8]) -> Result<usize, InconsistentRingBufferStateError> {
        self.inner.read_or_write(ConcurrentPipeOperation::Read(dst))
    }
}

View File

@@ -0,0 +1,3 @@
//! Concurrent ring buffers
pub mod framework;

4
util/src/ringbuf/mod.rs Normal file
View File

@@ -0,0 +1,4 @@
//! Ring buffers and utilities for working with them
pub mod concurrent;
pub mod sched;

1376
util/src/ringbuf/sched.rs Normal file

File diff suppressed because it is too large Load Diff

117
util/src/rustix/error.rs Normal file
View File

@@ -0,0 +1,117 @@
//! Rustix extensions for error handling
/// Provides access to the last system error number
///
/// > The integer variable errno is set by system calls and some library functions in the event of an error to indicate what went wrong.
///
/// -- `man 3 errno`
///
/// # Panics
///
/// This function panics if there was no system error, i.e. if `errno` is zero.
/// Use [try_errno] or [last_os_result] for non-panicking variants.
///
/// # Examples
///
/// ```rust
///
/// use rustix::io::Errno as E;
/// use rosenpass_util::rustix::{errno, try_errno, last_os_result};
///
/// let res = unsafe { libc::mkdir(c"/tmp/baz".as_ptr(), 0) };
/// assert_eq!(res, -1);
/// assert_eq!(errno(), E::EXIST);
/// assert_eq!(try_errno(), Some(E::EXIST));
/// assert_eq!(last_os_result(), Err(E::EXIST));
///
/// // Deliberately clear the system error
/// unsafe { libc::__errno_location().write(0) };
/// // assert_eq!(errno(), _); // PANICS
/// assert_eq!(try_errno(), None);
/// assert_eq!(last_os_result(), Ok(()));
/// ```
///
/// Calling errno() when there is no error causes a panic:
///
/// ```rust,should_panic
///
/// use rustix::io::Errno as E;
/// use rosenpass_util::rustix::errno;
///
/// // Deliberately clear the system error
/// unsafe { libc::__errno_location().write(0) };
/// errno(); // PANICS
/// ```
pub fn errno() -> rustix::io::Errno {
    match try_errno() {
        None => panic!("Tried to retrieve last system error, but there was no system error (the system error number, errno = 0)"),
        Some(errno) => errno,
    }
}
/// Provides access to the last system error number.
///
/// Variant of [errno] that will return None if there was no system error.
///
/// # Examples
///
/// See [errno].
pub fn try_errno() -> Option<rustix::io::Errno> {
    // SAFETY: __errno_location() returns a valid pointer to the thread-local errno
    let raw = unsafe { libc::__errno_location().read() };
    // errno == 0 means "no error"; anything else is converted to a typed Errno
    (raw != 0).then(|| rustix::io::Errno::from_raw_os_error(raw))
}
/// Provides access to the last system error number.
///
/// Variant of [errno] that will return `Err(errno)` if there
/// was a system error and `Ok(())` otherwise.
///
/// # Examples
///
/// See [errno].
pub fn last_os_result() -> Result<(), rustix::io::Errno> {
    // An absent errno maps to success; a present one becomes the error value
    try_errno().map_or(Ok(()), Err)
}
/// Convert low level errors into std::io::Error
///
/// # Examples
///
/// ```
/// use std::io::ErrorKind as EK;
/// use rustix::io::Errno;
/// use rosenpass_util::rustix::IntoStdioErr;
///
/// let e = Errno::INTR.into_stdio_err();
/// assert!(matches!(e.kind(), EK::Interrupted));
///
/// let r : rustix::io::Result<()> = Err(Errno::INTR);
/// assert!(matches!(r, Err(e) if e.kind() == EK::Interrupted));
/// ```
pub trait IntoStdioErr {
    /// Target type produced (e.g. `std::io::Error` or `std::io::Result`, depending on context)
    type Target;
    /// Convert a low level error (or result) into its std::io equivalent
    fn into_stdio_err(self) -> Self::Target;
}
impl IntoStdioErr for rustix::io::Errno {
    type Target = std::io::Error;
    fn into_stdio_err(self) -> Self::Target {
        std::io::Error::from_raw_os_error(self.raw_os_error())
    }
}
impl<T> IntoStdioErr for rustix::io::Result<T> {
    type Target = std::io::Result<T>;
    fn into_stdio_err(self) -> Self::Target {
        // Convert only the error arm; the success value passes through untouched
        self.map_err(IntoStdioErr::into_stdio_err)
    }
}

255
util/src/rustix/fd.rs Normal file
View File

@@ -0,0 +1,255 @@
//! Basic utilities for working with file descriptors
use std::os::fd::{AsFd, BorrowedFd, FromRawFd, OwnedFd, RawFd};
use rustix::io::fcntl_dupfd_cloexec;
use super::IntoStdioErr;
use crate::mem::Forgetting;
/// Prepare a file descriptor for use in Rust code.
///
/// Checks if the file descriptor is valid and duplicates it to a new file descriptor.
/// The old file descriptor is masked to avoid potential use after free (on file descriptor)
/// in case the given file descriptor is still used somewhere
///
/// # Panic and safety
///
/// Will panic if the given file descriptor is negative, or larger than
/// the file descriptor numbers permitted by the operating system.
///
/// # Examples
///
/// ```
/// use std::io::Write;
/// use std::os::fd::{IntoRawFd, AsRawFd};
/// use tempfile::tempdir;
/// use rosenpass_util::rustix::{claim_fd, FdIo};
///
/// // Open a file and turn it into a raw file descriptor
/// let orig = tempfile::tempfile()?.into_raw_fd();
///
/// // Reclaim that file and ready it for reading
/// let mut claimed = FdIo(claim_fd(orig)?);
///
/// // A different file descriptor is used
/// assert!(orig.as_raw_fd() != claimed.0.as_raw_fd());
///
/// // Write some data
/// claimed.write_all(b"Hello, World!")?;
///
/// Ok::<(), std::io::Error>(())
/// ```
pub fn claim_fd(fd: RawFd) -> rustix::io::Result<OwnedFd> {
    // Duplicate first (this also validates that `fd` is open), then mask the original.
    // SAFETY: the BorrowedFd only lives for this call, while `fd` is still open.
    let new = clone_fd_cloexec(unsafe { BorrowedFd::borrow_raw(fd) })?;
    mask_fd(fd)?;
    Ok(new)
}
/// Prepare a file descriptor for use in Rust code.
///
/// Checks if the file descriptor is valid.
///
/// Unlike [claim_fd], this will try to reuse the same file descriptor identifier instead of masking it.
///
/// # Panic and safety
///
/// Will panic if the given file descriptor is negative, or larger than
/// the file descriptor numbers permitted by the operating system.
///
/// # Examples
///
/// ```
/// use std::io::Write;
/// use std::os::fd::IntoRawFd;
/// use tempfile::tempdir;
/// use rosenpass_util::rustix::{claim_fd_inplace, FdIo};
///
/// // Open a file and turn it into a raw file descriptor
/// let fd = tempfile::tempfile()?.into_raw_fd();
///
/// // Reclaim that file and ready it for reading
/// let mut fd = FdIo(claim_fd_inplace(fd)?);
///
/// // Write some data
/// fd.write_all(b"Hello, World!")?;
///
/// Ok::<(), std::io::Error>(())
/// ```
pub fn claim_fd_inplace(fd: RawFd) -> rustix::io::Result<OwnedFd> {
    // SAFETY: by this function's contract the caller hands over ownership of `fd`
    let mut new = unsafe { OwnedFd::from_raw_fd(fd) };
    // Duplicate to a temporary (validating `fd` is open), then dup the temporary
    // back onto the original descriptor number with CLOEXEC set
    let tmp = clone_fd_cloexec(&new)?;
    clone_fd_to_cloexec(tmp, &mut new)?;
    Ok(new)
}
/// Will close the given file descriptor and overwrite
/// it with a masking file descriptor (see [open_nullfd]) to prevent accidental reuse.
///
/// # Panic and safety
///
/// Will panic if the given file descriptor is negative, or larger than
/// the file descriptor numbers permitted by the operating system.
///
/// # Example
/// ```
/// # use std::fs::File;
/// # use std::io::Read;
/// # use std::os::unix::io::{AsRawFd, FromRawFd};
/// # use std::os::fd::IntoRawFd;
/// # use rustix::fd::AsFd;
/// # use rosenpass_util::rustix::mask_fd;
///
/// // Open a temporary file
/// let fd = tempfile::tempfile().unwrap().into_raw_fd();
/// assert!(fd >= 0);
///
/// // Mask the file descriptor
/// mask_fd(fd).unwrap();
///
/// // Verify the file descriptor now points to `/dev/null`
/// // Reading from `/dev/null` always returns 0 bytes
/// let mut replaced_file = unsafe { File::from_raw_fd(fd) };
/// let mut buffer = [0u8; 4];
/// let bytes_read = replaced_file.read(&mut buffer).unwrap();
/// assert_eq!(bytes_read, 0);
/// ```
pub fn mask_fd(fd: RawFd) -> rustix::io::Result<()> {
    // Safety: because the OwnedFd resulting from OwnedFd::from_raw_fd is wrapped in a Forgetting,
    // it never gets dropped, meaning that fd is never closed and thus outlives the OwnedFd
    let mut owned = Forgetting::new(unsafe { OwnedFd::from_raw_fd(fd) });
    // dup the /dev/null descriptor over `fd`, atomically replacing its target
    clone_fd_to_cloexec(open_nullfd()?, &mut owned)
}
/// Duplicate a file descriptor, setting the close on exec flag
///
/// The duplicate is allocated at descriptor number three or above, so the
/// standard streams (stdin/stdout/stderr) are never clobbered.
pub fn clone_fd_cloexec<Fd: AsFd>(fd: Fd) -> rustix::io::Result<OwnedFd> {
    /// Avoid stdin, stdout, and stderr
    const MINFD: RawFd = 3;
    fcntl_dupfd_cloexec(fd, MINFD)
}
/// Duplicate a file descriptor, setting the close on exec flag.
///
/// This is slightly different from [clone_fd_cloexec], as this function supports specifying an
/// explicit destination file descriptor.
///
/// On Linux, dup3(2) duplicates and sets CLOEXEC in one atomic step.
#[cfg(target_os = "linux")]
pub fn clone_fd_to_cloexec<Fd: AsFd>(fd: Fd, new: &mut OwnedFd) -> rustix::io::Result<()> {
    use rustix::io::{dup3, DupFlags};
    dup3(fd, new, DupFlags::CLOEXEC)
}
#[cfg(not(target_os = "linux"))]
/// Duplicate a file descriptor, setting the close on exec flag.
///
/// This is slightly different from [clone_fd_cloexec], as this function supports specifying an
/// explicit destination file descriptor.
///
/// NOTE(review): unlike the Linux dup3(2) path, dup2 followed by fcntl_setfd is not atomic;
/// a fork/exec between the two calls could briefly leak the descriptor without CLOEXEC.
pub fn clone_fd_to_cloexec<Fd: AsFd>(fd: Fd, new: &mut OwnedFd) -> rustix::io::Result<()> {
    use rustix::io::{dup2, fcntl_setfd, FdFlags};
    dup2(&fd, new)?;
    fcntl_setfd(&new, FdFlags::CLOEXEC)
}
/// Open a "blocked" file descriptor. I.e. a file descriptor that is neither meant for reading nor
/// writing.
///
/// # Caveats
///
/// This function itself is safe; however, the behavior of the returned file descriptor when
/// being written to or read from is unspecified (it is opened on `/dev/null`, so reads are
/// expected to yield zero bytes and writes to fail — see the tests below).
///
/// # Examples
///
/// ```
/// use std::{fs::File, io::Write, os::fd::IntoRawFd};
/// use rustix::fd::FromRawFd;
/// use rosenpass_util::rustix::open_nullfd;
///
/// let nullfd = open_nullfd().unwrap();
/// ```
pub fn open_nullfd() -> rustix::io::Result<OwnedFd> {
    use rustix::fs::{open, Mode, OFlags};
    // TODO: Add tests showing that this will throw errors on use
    open("/dev/null", OFlags::CLOEXEC, Mode::empty())
}
/// Read and write directly from a file descriptor
///
/// Thin [std::io::Read]/[std::io::Write] adapters over the rustix read/write
/// calls; no userspace buffering is performed.
///
/// # Examples
///
/// See [claim_fd].
pub struct FdIo<Fd: AsFd>(pub Fd);
impl<Fd: AsFd> std::io::Read for FdIo<Fd> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        rustix::io::read(&self.0, buf).into_stdio_err()
    }
}
impl<Fd: AsFd> std::io::Write for FdIo<Fd> {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        rustix::io::write(&self.0, buf).into_stdio_err()
    }
    // There is no userspace buffer, so flushing is a no-op
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};
    // Invalid descriptors (negative or above the OS limit) must panic per the
    // documented "Panic and safety" contract of claim_fd/claim_fd_inplace/mask_fd.
    #[test]
    #[should_panic]
    fn test_claim_fd_invalid_neg() {
        let _ = claim_fd(-1);
    }
    #[test]
    #[should_panic]
    fn test_claim_fd_invalid_max() {
        let _ = claim_fd(i64::MAX as RawFd);
    }
    #[test]
    #[should_panic]
    fn test_claim_fd_inplace_invalid_neg() {
        let _ = claim_fd_inplace(-1);
    }
    #[test]
    #[should_panic]
    fn test_claim_fd_inplace_invalid_max() {
        let _ = claim_fd_inplace(i64::MAX as RawFd);
    }
    #[test]
    #[should_panic]
    fn test_mask_fd_invalid_neg() {
        let _ = mask_fd(-1);
    }
    #[test]
    #[should_panic]
    fn test_mask_fd_invalid_max() {
        let _ = mask_fd(i64::MAX as RawFd);
    }
    // The nullfd reads as empty and rejects writes (opened without write access)
    #[test]
    fn test_open_nullfd() -> anyhow::Result<()> {
        let mut file = FdIo(open_nullfd()?);
        let mut buf = [0; 10];
        assert!(matches!(file.read(&mut buf), Ok(0) | Err(_)));
        assert!(file.write(&buf).is_err());
        Ok(())
    }
    #[test]
    fn test_nullfd_read_write() {
        let nullfd = open_nullfd().unwrap();
        let mut buf = vec![0u8; 16];
        assert_eq!(rustix::io::read(&nullfd, &mut buf).unwrap(), 0);
        assert!(rustix::io::write(&nullfd, b"test").is_err());
    }
}

86
util/src/rustix/memfd.rs Normal file
View File

@@ -0,0 +1,86 @@
//! Utilities for working with memory based file descriptors
use std::os::fd::OwnedFd;
use rustix::fs::MemfdFlags;
use rustix::io::Errno;
use rustix::path::Arg as Path;
use bitflags::bitflags;
use crate::convert::IntoTypeExt;
use super::SyscallResult;
/// Create an anonymous file
/// using the memfd_create(2) syscall
///
/// Just forwards to [rustix::fs::memfd_create]
///
/// # Errors
///
/// Propagates any error raised by the underlying syscall.
pub fn memfd_create<P: Path>(name: P, flags: MemfdFlags) -> rustix::io::Result<OwnedFd> {
    rustix::fs::memfd_create(name, flags)
}
bitflags! {
    /// `FD_*` constants for use with [memfd_secret].
    ///
    /// Note: unlike most fd-creating syscalls (which take `O_CLOEXEC`),
    /// memfd_secret(2) takes the fcntl-style `FD_CLOEXEC` flag — see the
    /// memfd_secret(2) man page.
    #[repr(transparent)]
    #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
    pub struct MemfdSecretFlags: std::ffi::c_uint {
        /// FD_CLOEXEC
        const CLOEXEC = libc::FD_CLOEXEC as std::ffi::c_uint;
    }
}
/// Errors for [memfd_secret()]
#[derive(Copy, Clone, Debug, thiserror::Error)]
pub enum MemfdSecretError {
    /// memfd_secret(2) not supported on system (the syscall reported `ENOSYS`)
    #[error("Could not create secret memory segment using memfd_secret(2): not supported on your system")]
    NotSupported,
    /// Other error
    #[error("Could not create secret memory segment using memfd_secret(2): underlying system error: {:?}", .0)]
    SystemError(Errno),
}
impl From<Errno> for MemfdSecretError {
    /// Classify the raw errno: `ENOSYS` means the kernel lacks the syscall,
    /// everything else is passed through as a generic system error.
    fn from(value: Errno) -> Self {
        if value == Errno::NOSYS {
            Self::NotSupported
        } else {
            Self::SystemError(value)
        }
    }
}
/// Create an anonymous RAM-based file to access secret memory regions
/// using the memfd_secret(2) syscall
///
/// # Errors
///
/// Returns [MemfdSecretError::NotSupported] when the kernel does not provide
/// memfd_secret(2) (`ENOSYS`), and [MemfdSecretError::SystemError] for any
/// other failure.
///
/// # Examples
///
/// ```
/// use rustix::fs::ftruncate;
///
/// use rosenpass_util::rustix::{memfd_secret, MemfdSecretError, MemfdSecretFlags, IntoStdioErr};
/// use rosenpass_util::io::handle_interrupted;
///
/// let res = memfd_secret(MemfdSecretFlags::empty());
/// let fd = match res {
///     Ok(fd) => fd,
///     // The system might not have memfd_secret enabled; abort the test
///     Err(MemfdSecretError::NotSupported) => return Ok(()),
///     Err(err) => return Err(err)?,
/// };
///
/// handle_interrupted(|| { ftruncate(&fd, 8192).into_stdio_err() })?;
///
/// Ok::<(), anyhow::Error>(())
/// ```
pub fn memfd_secret(flags: MemfdSecretFlags) -> Result<rustix::fd::OwnedFd, MemfdSecretError> {
    // Safety: SYS_memfd_secret either fails (negative return value, error in
    // errno) or returns a freshly created file descriptor that nobody else
    // owns, so claiming it as an OwnedFd is sound. The flags are passed as
    // their raw c_uint representation rather than as a struct, which is the
    // form the variadic syscall(2) interface expects.
    let res = unsafe {
        use libc::{syscall, SYS_memfd_secret};
        syscall(SYS_memfd_secret, flags.bits())
            .into_type::<SyscallResult>()
            .claim_fd()
    };
    res.map_err(MemfdSecretError::from)
}

16
util/src/rustix/mod.rs Normal file
View File

@@ -0,0 +1,16 @@
//! Extensions to the rustix crate for memory safe operating system interfaces
mod error;
pub use error::*;
mod fd;
pub use fd::*;
mod stat;
pub use stat::*;
mod syscall;
pub use syscall::*;
mod memfd;
pub use memfd::*;

View File

@@ -1,235 +1,12 @@
//! Utilities for working with file descriptors
//! Rustix extensions for getting information about file descriptors
use std::os::fd::AsFd;
use anyhow::bail;
use rustix::io::fcntl_dupfd_cloexec;
use std::os::fd::{AsFd, BorrowedFd, FromRawFd, OwnedFd, RawFd};
use crate::{mem::Forgetting, result::OkExt};
use super::IntoStdioErr;
/// Prepare a file descriptor for use in Rust code.
///
/// Checks that the file descriptor is valid and duplicates it to a new file descriptor.
/// The old file descriptor is masked (see [mask_fd]) to avoid a potential use-after-free
/// (of the file descriptor) in case the given file descriptor is still used somewhere.
///
/// # Panic and safety
///
/// Will panic if the given file descriptor is negative or larger than
/// the file descriptor numbers permitted by the operating system.
///
/// # Examples
///
/// ```
/// use std::io::Write;
/// use std::os::fd::{IntoRawFd, AsRawFd};
/// use tempfile::tempdir;
/// use rosenpass_util::fd::{claim_fd, FdIo};
///
/// // Open a file and turn it into a raw file descriptor
/// let orig = tempfile::tempfile()?.into_raw_fd();
///
/// // Reclaim that file and ready it for reading
/// let mut claimed = FdIo(claim_fd(orig)?);
///
/// // A different file descriptor is used
/// assert!(orig.as_raw_fd() != claimed.0.as_raw_fd());
///
/// // Write some data
/// claimed.write_all(b"Hello, World!")?;
///
/// Ok::<(), std::io::Error>(())
/// ```
pub fn claim_fd(fd: RawFd) -> rustix::io::Result<OwnedFd> {
    // Duplicate first so we never lose access to the underlying file,
    // then replace the original slot with a masking descriptor.
    let duplicate = clone_fd_cloexec(unsafe { BorrowedFd::borrow_raw(fd) })?;
    mask_fd(fd)?;
    Ok(duplicate)
}
/// Prepare a file descriptor for use in Rust code.
///
/// Checks that the file descriptor is valid.
///
/// Unlike [claim_fd], this will try to reuse the same file descriptor identifier instead of masking it.
///
/// # Panic and safety
///
/// Will panic if the given file descriptor is negative or larger than
/// the file descriptor numbers permitted by the operating system.
///
/// # Examples
///
/// ```
/// use std::io::Write;
/// use std::os::fd::IntoRawFd;
/// use tempfile::tempdir;
/// use rosenpass_util::fd::{claim_fd_inplace, FdIo};
///
/// // Open a file and turn it into a raw file descriptor
/// let fd = tempfile::tempfile()?.into_raw_fd();
///
/// // Reclaim that file and ready it for reading
/// let mut fd = FdIo(claim_fd_inplace(fd)?);
///
/// // Write some data
/// fd.write_all(b"Hello, World!")?;
///
/// Ok::<(), std::io::Error>(())
/// ```
pub fn claim_fd_inplace(fd: RawFd) -> rustix::io::Result<OwnedFd> {
    // Take logical ownership of the descriptor number; validity is checked
    // by the dup calls below, which fail for invalid descriptors.
    let mut claimed = unsafe { OwnedFd::from_raw_fd(fd) };
    // Round-trip through a scratch duplicate so the original identifier ends
    // up holding a CLOEXEC copy of itself.
    let scratch = clone_fd_cloexec(&claimed)?;
    clone_fd_to_cloexec(scratch, &mut claimed)?;
    Ok(claimed)
}
/// Will close the given file descriptor and overwrite
/// it with a masking file descriptor (see [open_nullfd]) to prevent accidental reuse.
///
/// # Panic and safety
///
/// Will panic if the given file descriptor is negative or larger than
/// the file descriptor numbers permitted by the operating system.
///
/// # Example
/// ```
/// # use std::fs::File;
/// # use std::io::Read;
/// # use std::os::unix::io::{AsRawFd, FromRawFd};
/// # use std::os::fd::IntoRawFd;
/// # use rustix::fd::AsFd;
/// # use rosenpass_util::fd::mask_fd;
///
/// // Open a temporary file
/// let fd = tempfile::tempfile().unwrap().into_raw_fd();
/// assert!(fd >= 0);
///
/// // Mask the file descriptor
/// mask_fd(fd).unwrap();
///
/// // Verify the file descriptor now points to `/dev/null`
/// // Reading from `/dev/null` always returns 0 bytes
/// let mut replaced_file = unsafe { File::from_raw_fd(fd) };
/// let mut buffer = [0u8; 4];
/// let bytes_read = replaced_file.read(&mut buffer).unwrap();
/// assert_eq!(bytes_read, 0);
/// ```
pub fn mask_fd(fd: RawFd) -> rustix::io::Result<()> {
    // Safety: because the OwnedFd resulting from OwnedFd::from_raw_fd is wrapped in a Forgetting,
    // it never gets dropped, meaning that fd is never closed and thus outlives the OwnedFd
    let mut owned = Forgetting::new(unsafe { OwnedFd::from_raw_fd(fd) });
    // dup the nullfd on top of `fd`; this implicitly closes the original file.
    clone_fd_to_cloexec(open_nullfd()?, &mut owned)
}
/// Duplicate a file descriptor, setting the close on exec flag
///
/// The duplicate is allocated at descriptor number 3 or above (`F_DUPFD_CLOEXEC`
/// semantics), so the standard streams 0/1/2 are never clobbered.
pub fn clone_fd_cloexec<Fd: AsFd>(fd: Fd) -> rustix::io::Result<OwnedFd> {
    /// Avoid stdin, stdout, and stderr
    const MINFD: RawFd = 3;
    fcntl_dupfd_cloexec(fd, MINFD)
}
/// Duplicate a file descriptor, setting the close on exec flag.
///
/// This is slightly different from [clone_fd_cloexec], as this function supports specifying an
/// explicit destination file descriptor.
///
/// On Linux, dup3(2) duplicates and sets CLOEXEC in a single atomic syscall.
#[cfg(target_os = "linux")]
pub fn clone_fd_to_cloexec<Fd: AsFd>(fd: Fd, new: &mut OwnedFd) -> rustix::io::Result<()> {
    use rustix::io::{dup3, DupFlags};
    dup3(fd, new, DupFlags::CLOEXEC)
}
#[cfg(not(target_os = "linux"))]
/// Duplicate a file descriptor, setting the close on exec flag.
///
/// This is slightly different from [clone_fd_cloexec], as this function supports specifying an
/// explicit destination file descriptor.
///
/// Portable fallback: dup2(2) followed by a separate F_SETFD. Unlike the
/// atomic Linux dup3(2) path, there is a brief window in which the new
/// descriptor exists without the CLOEXEC flag set.
pub fn clone_fd_to_cloexec<Fd: AsFd>(fd: Fd, new: &mut OwnedFd) -> rustix::io::Result<()> {
    use rustix::io::{dup2, fcntl_setfd, FdFlags};
    dup2(&fd, new)?;
    fcntl_setfd(&new, FdFlags::CLOEXEC)
}
/// Open a "blocked" file descriptor. I.e. a file descriptor that is neither meant for reading nor
/// writing.
///
/// # Safety
///
/// The behavior of the file descriptor when being written to or from is undefined.
///
/// # Examples
///
/// ```
/// use std::{fs::File, io::Write, os::fd::IntoRawFd};
/// use rustix::fd::FromRawFd;
/// use rosenpass_util::fd::open_nullfd;
///
/// let nullfd = open_nullfd().unwrap();
/// ```
pub fn open_nullfd() -> rustix::io::Result<OwnedFd> {
    use rustix::fs::{open, Mode, OFlags};
    // No access-mode bit is set, which makes this an O_RDONLY open of
    // /dev/null: reads return EOF and writes fail (exercised by the
    // tests::test_nullfd_read_write test in this module).
    open("/dev/null", OFlags::CLOEXEC, Mode::empty())
}
/// Convert low level errors into std::io::Error
///
/// # Examples
///
/// ```
/// use std::io::ErrorKind as EK;
/// use rustix::io::Errno;
/// use rosenpass_util::fd::IntoStdioErr;
///
/// let e = Errno::INTR.into_stdio_err();
/// assert!(matches!(e.kind(), EK::Interrupted));
///
/// let r : rustix::io::Result<()> = Err(Errno::INTR);
/// assert!(matches!(r, Err(e) if e.kind() == EK::Interrupted));
/// ```
pub trait IntoStdioErr {
    /// Target type produced (e.g. std::io::Error or std::io::Result, depending on context)
    type Target;
    /// Convert a low level error (or result) into its std::io equivalent
    fn into_stdio_err(self) -> Self::Target;
}
impl IntoStdioErr for rustix::io::Errno {
    type Target = std::io::Error;
    // An Errno is a raw OS error code, so the conversion is lossless.
    fn into_stdio_err(self) -> Self::Target {
        std::io::Error::from_raw_os_error(self.raw_os_error())
    }
}
impl<T> IntoStdioErr for rustix::io::Result<T> {
    type Target = std::io::Result<T>;
    // Ok values pass through untouched; only the error side is converted.
    fn into_stdio_err(self) -> Self::Target {
        self.map_err(IntoStdioErr::into_stdio_err)
    }
}
/// Read and write directly from a file descriptor
///
/// Wraps any [AsFd] type and implements [std::io::Read] and [std::io::Write]
/// by delegating to [rustix::io::read] and [rustix::io::write] on the wrapped
/// descriptor (no buffering is performed).
///
/// # Examples
///
/// See [claim_fd].
pub struct FdIo<Fd: AsFd>(pub Fd);
impl<Fd: AsFd> std::io::Read for FdIo<Fd> {
    /// Read from the wrapped descriptor via [rustix::io::read], converting
    /// any [rustix::io::Errno] into a [std::io::Error].
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let outcome = rustix::io::read(&self.0, buf);
        outcome.into_stdio_err()
    }
}
impl<Fd: AsFd> std::io::Write for FdIo<Fd> {
    /// Write to the wrapped descriptor via [rustix::io::write], converting
    /// any [rustix::io::Errno] into a [std::io::Error].
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let outcome = rustix::io::write(&self.0, buf);
        outcome.into_stdio_err()
    }

    /// Writes go straight to the descriptor, so there is nothing to flush.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
use crate::result::OkExt;
/// Helpers for accessing stat(2) information
pub trait StatExt {
@@ -238,7 +15,7 @@ pub trait StatExt {
/// # Examples
///
/// ```
/// use rosenpass_util::fd::StatExt;
/// use rosenpass_util::rustix::StatExt;
/// assert!(rustix::fs::stat("/")?.is_socket() == false);
/// Ok::<(), rustix::io::Errno>(())
/// ````
@@ -263,7 +40,7 @@ pub trait TryStatExt {
/// # Examples
///
/// ```
/// use rosenpass_util::fd::TryStatExt;
/// use rosenpass_util::rustix::TryStatExt;
/// let fd = rustix::fs::open("/", rustix::fs::OFlags::empty(), rustix::fs::Mode::empty())?;
/// assert!(matches!(fd.is_socket(), Ok(false)));
/// Ok::<(), rustix::io::Errno>(())
@@ -358,7 +135,7 @@ pub trait GetUnixSocketType {
/// # use std::os::fd::{AsFd, BorrowedFd};
/// # use std::os::unix::net::UnixListener;
/// # use tempfile::NamedTempFile;
/// # use rosenpass_util::fd::GetUnixSocketType;
/// # use rosenpass_util::rustix::GetUnixSocketType;
/// let f = {
/// // Generate a temp file and take its path
/// // Remove the temp file
@@ -378,7 +155,7 @@ pub trait GetUnixSocketType {
/// # use std::os::fd::{AsFd, BorrowedFd};
/// # use std::os::unix::net::{UnixDatagram, UnixListener};
/// # use tempfile::NamedTempFile;
/// # use rosenpass_util::fd::GetUnixSocketType;
/// # use rosenpass_util::rustix::GetUnixSocketType;
/// let f = {
/// // Generate a temp file and take its path
/// // Remove the temp file
@@ -445,7 +222,7 @@ pub trait GetSocketProtocol {
/// ```
/// # use std::net::UdpSocket;
/// # use std::os::fd::{AsFd, AsRawFd};
/// # use rosenpass_util::fd::GetSocketProtocol;
/// # use rosenpass_util::rustix::GetSocketProtocol;
/// let socket = UdpSocket::bind("127.0.0.1:0")?;
/// assert_eq!(socket.as_fd().socket_protocol().unwrap().unwrap(), rustix::net::ipproto::UDP);
/// # Ok::<(), std::io::Error>(())
@@ -458,7 +235,7 @@ pub trait GetSocketProtocol {
/// # use std::net::UdpSocket;
/// # use std::net::TcpListener;
/// # use std::os::fd::{AsFd, AsRawFd};
/// # use rosenpass_util::fd::GetSocketProtocol;
/// # use rosenpass_util::rustix::GetSocketProtocol;
/// let socket = UdpSocket::bind("127.0.0.1:0")?;
/// assert!(socket.as_fd().is_udp_socket().unwrap());
///
@@ -484,7 +261,7 @@ pub trait GetSocketProtocol {
/// # use std::net::UdpSocket;
/// # use std::net::TcpListener;
/// # use std::os::fd::{AsFd, AsRawFd};
/// # use rosenpass_util::fd::GetSocketProtocol;
/// # use rosenpass_util::rustix::GetSocketProtocol;
/// let socket = UdpSocket::bind("127.0.0.1:0")?;
/// assert!(matches!(socket.as_fd().demand_udp_socket(), Ok(())));
///
@@ -514,62 +291,3 @@ where
rustix::net::sockopt::get_socket_protocol(self)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};
    // claim_fd is documented to panic on descriptors that can never be valid
    // (negative or beyond the OS limit).
    #[test]
    #[should_panic]
    fn test_claim_fd_invalid_neg() {
        let _ = claim_fd(-1);
    }
    // NOTE(review): on platforms where RawFd is i32, `i64::MAX as RawFd`
    // truncates to -1, making this test equivalent to the _neg case above —
    // confirm whether a large positive in-range value was intended.
    #[test]
    #[should_panic]
    fn test_claim_fd_invalid_max() {
        let _ = claim_fd(i64::MAX as RawFd);
    }
    // claim_fd_inplace has the same panic contract for invalid descriptors.
    #[test]
    #[should_panic]
    fn test_claim_fd_inplace_invalid_neg() {
        let _ = claim_fd_inplace(-1);
    }
    #[test]
    #[should_panic]
    fn test_claim_fd_inplace_invalid_max() {
        let _ = claim_fd_inplace(i64::MAX as RawFd);
    }
    // ...and so does mask_fd.
    #[test]
    #[should_panic]
    fn test_mask_fd_invalid_neg() {
        let _ = mask_fd(-1);
    }
    #[test]
    #[should_panic]
    fn test_mask_fd_invalid_max() {
        let _ = mask_fd(i64::MAX as RawFd);
    }
    // A nullfd wrapped in FdIo: reading yields EOF (or an error), writing fails.
    #[test]
    fn test_open_nullfd() -> anyhow::Result<()> {
        let mut file = FdIo(open_nullfd()?);
        let mut buf = [0; 10];
        assert!(matches!(file.read(&mut buf), Ok(0) | Err(_)));
        assert!(file.write(&buf).is_err());
        Ok(())
    }
    // Same checks through the raw rustix API: the nullfd reads as empty and
    // rejects writes (it is opened without write access, see open_nullfd).
    #[test]
    fn test_nullfd_read_write() {
        let nullfd = open_nullfd().unwrap();
        let mut buf = vec![0u8; 16];
        assert_eq!(rustix::io::read(&nullfd, &mut buf).unwrap(), 0);
        assert!(rustix::io::write(&nullfd, b"test").is_err());
    }
}

View File

@@ -0,0 +1,33 @@
//! Helpers for performing system calls
use std::os::fd::FromRawFd;
use super::errno;
/// The raw return value of a syscall(2) invocation
///
/// NOTE(review): `repr(C, packed)` has no observable effect on a struct with a
/// single `c_long` field — confirm whether it is intentional before removing.
#[repr(C, packed)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct SyscallResult(pub libc::c_long);
impl SyscallResult {
    /// The raw return value of the system call
    pub fn raw_value(&self) -> libc::c_long {
        self.0
    }

    /// Interpret this syscall result as a newly created file descriptor and
    /// take ownership of it.
    ///
    /// Negative return values are treated as failure and translated into the
    /// current value of errno(3).
    ///
    /// # Panics
    ///
    /// Panics if the value is non-negative but does not fit into an `i32`,
    /// i.e. can not be a valid file descriptor.
    ///
    /// # Safety
    ///
    /// The caller must ensure that:
    ///
    /// - `self` is the unmodified result of a syscall that returns a new file
    ///   descriptor on success, and no other code owns that descriptor — it
    ///   will be closed when the returned [rustix::fd::OwnedFd] is dropped;
    /// - errno has not been clobbered between the syscall and this call,
    ///   since the error value is read from errno on failure.
    pub unsafe fn claim_fd(&self) -> Result<rustix::fd::OwnedFd, rustix::io::Errno> {
        let fde = self.0;
        match fde {
            e if e < 0 => Err(errno()),
            fd if fd > i32::MAX.into() => panic!("File descriptor `{fd}` is out of bounds"),
            fd => Ok(unsafe { rustix::fd::OwnedFd::from_raw_fd(fd as i32) }),
        }
    }
}
impl From<libc::c_long> for SyscallResult {
    // Wrap the raw syscall(2) return value; no interpretation happens here.
    fn from(value: libc::c_long) -> Self {
        Self(value)
    }
}

View File

@@ -0,0 +1,200 @@
//! Creation of secret memory file descriptors.
//!
//! This essentially provides a higher-level API for [memfd_secret()] and [memfd_create()]
// Tests: This uses nix-based integration tests
use std::{os::fd::OwnedFd, sync::OnceLock};
use rustix::{fs::MemfdFlags, io::Errno};
use crate::rustix::{memfd_create, memfd_secret, MemfdSecretError, MemfdSecretFlags};
use crate::{mem::CopyExt, result::OkExt};
/// Cache for [memfd_secret_supported]
static MEMFD_SECRET_SUPPORTED: OnceLock<bool> = OnceLock::new();

/// Check whether support for memfd_secret is available
///
/// The answer is determined by actually issuing a probe memfd_secret(2) call
/// and is cached process-wide in [MEMFD_SECRET_SUPPORTED]; probe errors other
/// than "not supported" are reported to the caller and not cached.
pub fn memfd_secret_supported() -> Result<bool, Errno> {
    // Fast path: a previous probe already recorded the answer.
    if let Some(cached) = MEMFD_SECRET_SUPPORTED.get() {
        return Ok(*cached);
    }
    use MemfdSecretError as E;
    let is_supported = match memfd_secret(MemfdSecretFlags::empty()) {
        Ok(_) => true,
        Err(E::NotSupported) => false,
        Err(E::SystemError(e)) => return Err(e),
    };
    // We are deliberately using get_or_init here to make sure that the entire
    // application never sees different values here, even when several threads
    // probe concurrently.
    MEMFD_SECRET_SUPPORTED
        .get_or_init(|| is_supported)
        .copy()
        .ok()
}
/// How secure memory file descriptors should be allocated
///
/// See [SecretMemfdMechanism::decide_with_policy] for how a policy is
/// resolved into a concrete allocation mechanism.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq)]
pub enum SecretMemfdPolicy {
    /// Use memfd_secret(2) if available, otherwise fall back to less
    /// secure options
    Opportunistic,
    /// Enforce the use of memfd_secret(2)
    UseMemfdSecret,
    /// Never use memfd_secret(2)
    DisableMemfdSecret,
}
impl Default for SecretMemfdPolicy {
    // Mirrors [SecretMemfdPolicy::default_const]; kept in sync manually.
    fn default() -> Self {
        Self::Opportunistic
    }
}
impl SecretMemfdPolicy {
    /// Create a SecretMemfdPolicy with the default policy
    ///
    /// Const-context variant of [Default::default].
    ///
    /// Currently [Self::Opportunistic]
    pub const fn default_const() -> Self {
        Self::Opportunistic
    }
    /// Enforce the use of the highest security configuration available
    ///
    /// This might not work on some systems, which is why this is not used
    /// by default.
    ///
    /// Currently [Self::UseMemfdSecret]
    pub const fn enforce_high_security() -> Self {
        Self::UseMemfdSecret
    }
}
/// Which mechanism to use when allocating secret memory file descriptors
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq)]
pub enum SecretMemfdMechanism {
    /// The less secure memfd_create(2) will be used
    MemfdCreate,
    /// The more secure memfd_secret(2) will be used
    MemfdSecret,
}
impl SecretMemfdMechanism {
    /// Decide which mechanism to use, based on the given [SecretMemfdPolicy]
    /// and [memfd_secret_supported()].
    ///
    /// If [SecretMemfdPolicy::UseMemfdSecret] is used, then [SecretMemfdMechanism::MemfdSecret]
    /// will be returned, regardless of whether it is supported.
    ///
    /// Likewise, if [SecretMemfdPolicy::DisableMemfdSecret] is used, then
    /// [SecretMemfdMechanism::MemfdCreate] will be used unconditionally.
    pub fn decide_with_policy(policy: SecretMemfdPolicy) -> Result<Self, Errno> {
        use SecretMemfdMechanism as M;
        use SecretMemfdPolicy as P;
        let mechanism = match policy {
            P::UseMemfdSecret => M::MemfdSecret,
            P::DisableMemfdSecret => M::MemfdCreate,
            // Opportunistic: probe the kernel for memfd_secret(2) support
            // (the probe result is cached by memfd_secret_supported).
            P::Opportunistic if memfd_secret_supported()? => M::MemfdSecret,
            P::Opportunistic => M::MemfdCreate,
        };
        Ok(mechanism)
    }
}
/// Errors for [SecretMemfdConfig::create()]
#[derive(Copy, Clone, Debug, thiserror::Error)]
pub enum SecretMemfdWithConfigError {
    /// Call to [memfd_secret()] failed
    #[error("{:?}", .0)]
    MemfdSecretError(#[from] MemfdSecretError),
    /// Call to [memfd_create()] failed
    #[error("Could not create secret memory segment using memfd_create(2) due to underlying system error: {:?}", .0)]
    MemfdCreateError(Errno),
    /// Call to [memfd_secret_supported()] failed
    #[error("Failed to determine whether memfd_secret(2) is supported due to underlying system error: {:?}", .0)]
    FailedToDetectSupport(Errno),
}
/// Robustly configure and create secret memory file descriptors
///
/// Whereas [memfd_secret] will always use memfd_secret(2), this construction allows multiple
/// file descriptor back ends to be used to support different usage scenarios.
///
/// This is necessary, because older systems do not support memfd_secret(2) and using it might not
/// always be desirable, as memfd_secret for instance also inhibits hibernation.
#[derive(Default, Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq)]
pub struct SecretMemfdConfig {
    /// Enable the close-on-exec flag for the new file descriptor.
    pub close_on_exec: bool,
    /// Security mechanism to use
    pub policy: SecretMemfdPolicy,
}
impl SecretMemfdConfig {
    /// Create a new, default [Self]
    pub const fn new() -> Self {
        Self {
            close_on_exec: false,
            policy: SecretMemfdPolicy::default_const(),
        }
    }

    /// Set the [Self::close_on_exec] flag to true
    pub const fn cloexec(&self) -> Self {
        let mut updated = *self;
        updated.close_on_exec = true;
        updated
    }

    /// Set `self.policy = SecretMemoryFdPolicy::UseMemfdSecret`
    pub const fn enforce_high_security(&self) -> Self {
        let mut updated = *self;
        updated.policy = SecretMemfdPolicy::enforce_high_security();
        updated
    }

    /// Whether memfd_secret will be used by [Self::create()]
    pub fn mechanism(&self) -> Result<SecretMemfdMechanism, Errno> {
        SecretMemfdMechanism::decide_with_policy(self.policy)
    }

    /// Allocate a secret file descriptor based on the configuration
    pub fn create(&self) -> Result<OwnedFd, SecretMemfdWithConfigError> {
        use SecretMemfdMechanism as M;
        use SecretMemfdWithConfigError as E;
        // Resolve the policy to a concrete back end, then translate the
        // close-on-exec setting into that back end's flag type.
        match self.mechanism().map_err(E::FailedToDetectSupport)? {
            M::MemfdCreate => {
                let flags = if self.close_on_exec {
                    MemfdFlags::CLOEXEC
                } else {
                    MemfdFlags::empty()
                };
                memfd_create("rosenpass secret memory segment", flags).map_err(E::MemfdCreateError)
            }
            M::MemfdSecret => {
                let flags = if self.close_on_exec {
                    MemfdSecretFlags::CLOEXEC
                } else {
                    MemfdSecretFlags::empty()
                };
                memfd_secret(flags).map_err(E::MemfdSecretError)
            }
        }
    }
}
/// Create a secret memory file descriptor using the default policy
///
/// Shorthand for
/// [`SecretMemfdConfig`](SecretMemfdConfig)[`::new()`](SecretMemfdConfig::new)[`.create()`](SecretMemfdConfig::create)
///
/// # Errors
///
/// See [SecretMemfdWithConfigError].
pub fn memfd_for_secrets_with_default_policy() -> Result<OwnedFd, SecretMemfdWithConfigError> {
    SecretMemfdConfig::new().create()
}

View File

@@ -0,0 +1,409 @@
//! This module takes care of allocating memory segments for
//! file descriptors created with [super::fd::memfd_for_secrets_with_default_policy]
//! and anonymous memory segments
// Tests: Nix based integration tests
#![deny(unsafe_op_in_unsafe_fn)]
use std::os::fd::{AsFd, AsRawFd, BorrowedFd, RawFd};
use std::ptr::null_mut;
use crate::mem::CopyExt;
/// Size of the memory mapping for [MappableFd]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum MMapSizePolicy {
    /// Size is assumed to be this particular value; [MappableFd::mmap] will simply
    /// use this value without checking whether it matches the size of the underlying
    /// data
    Assumed(u64),
    /// Size is assumed to be this particular value; [MappableFd::mmap] will check the
    /// size of the underlying data and raise an error if the size of data and this value
    /// do not match
    Checked(u64),
    /// Size is defined to be this particular value; [MappableFd::mmap] will explicitly resize
    /// the underlying file descriptor to be this particular size when called.
    Resize(u64),
}
impl MMapSizePolicy {
    /// The numeric value of the size, regardless of which policy carries it
    pub fn size_value(&self) -> u64 {
        match *self {
            Self::Assumed(size) | Self::Checked(size) | Self::Resize(size) => size,
        }
    }
}
/// Configuration for [MappableFd::mmap]
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub struct MapFdConfig {
    /// The memory region can not be read from
    pub unreadable: bool,
    /// The memory region can not be written to
    pub immutable: bool,
    /// The memory region can be executed
    pub executable: bool,
    /// The memory region is shared-memory; other mappings of the same
    /// region within and outside this process can see the modifications
    /// to the memory region (as long as they also set the shared flag)
    pub shared: bool,
    /// How [MappableFd::mmap] will determine the size to be used for the mapping
    ///
    /// You should usually set this value through [Self::set_size_policy], [Self::assume_size_without_checking],
    /// [Self::expected_size], or [Self::resize_on_mmap].
    pub size_policy: Option<MMapSizePolicy>,
}
impl MapFdConfig {
    /// New MapFdConfig with all settings turned off
    ///
    /// You still must set [Self::size_policy], otherwise [MappableFd::mmap] will raise
    /// an error when called
    pub const fn new() -> Self {
        MapFdConfig {
            unreadable: false,
            immutable: false,
            executable: false,
            shared: false,
            size_policy: None,
        }
    }
    /// New MapFdConfig with shared memory turned on
    pub const fn shared_memory() -> Self {
        Self::new().set_shared()
    }
    /// Set the [Self::unreadable] flag
    pub const fn set_unreadable(&self) -> Self {
        let mut r = *self;
        r.unreadable = true;
        r
    }
    /// Set the [Self::immutable] flag
    pub const fn set_immutable(&self) -> Self {
        let mut r = *self;
        r.immutable = true;
        r
    }
    /// Set the [Self::shared] flag
    pub const fn set_shared(&self) -> Self {
        let mut r = *self;
        r.shared = true;
        r
    }
    /// Create a [MappableFd] instance with this configuration
    pub fn mappable_fd<Fd: AsFd>(&self, fd: Fd) -> MappableFd<Fd> {
        MappableFd::new(fd, self.copy())
    }
    /// Calculate [rustix::mm::ProtFlags] for this configuration
    ///
    /// Read and write default to enabled (disabled via [Self::unreadable]/
    /// [Self::immutable]); execute defaults to disabled.
    pub const fn mmap_prot(&self) -> rustix::mm::ProtFlags {
        use rustix::mm::ProtFlags as P;
        let p_read = match self.unreadable {
            true => P::empty(),
            false => P::READ,
        };
        let p_write = match self.immutable {
            true => P::empty(),
            false => P::WRITE,
        };
        let p_exec = match self.executable {
            true => P::EXEC,
            false => P::empty(),
        };
        p_read.union(p_write).union(p_exec)
    }
    /// Calculate [rustix::mm::MapFlags] for this configuration
    pub const fn mmap_flags(&self) -> rustix::mm::MapFlags {
        use rustix::mm::MapFlags as M;
        match self.shared {
            true => M::SHARED,
            false => M::empty(),
        }
    }
    /// Set [Self::size_policy] to the given value
    pub const fn set_size_policy(&self, size_policy: MMapSizePolicy) -> Self {
        let mut r = *self;
        r.size_policy = Some(size_policy);
        r
    }
    /// Set [Self::size_policy] to [MMapSizePolicy::Assumed] with the given value
    pub const fn assume_size_without_checking(&self, size: u64) -> Self {
        self.set_size_policy(MMapSizePolicy::Assumed(size))
    }
    /// Set [Self::size_policy] to [MMapSizePolicy::Checked] with the given value
    pub const fn expected_size(&self, size: u64) -> Self {
        self.set_size_policy(MMapSizePolicy::Checked(size))
    }
    /// Set [Self::size_policy] to [MMapSizePolicy::Resize] with the given value
    pub const fn resize_on_mmap(&self, size: u64) -> Self {
        self.set_size_policy(MMapSizePolicy::Resize(size))
    }
}
/// Error returned by [MappableFd::mmap]
#[derive(thiserror::Error, Debug, Copy, Clone, PartialEq, Eq)]
pub enum MMapError {
    /// Error converting between usize & f64
    #[error("Requested memory map of size {requested_len} but maximum supported size is {max_supported_len}. \
        This is a low level error that usually arises on architectures where the integer type usize ({} bytes) can not \
        represent all values in u64 ({} bytes). Are you possibly on 32 bit CPU architecture requesting a buffer bigger than 4 GB?\n\
        Error: {err:?}",
        (usize::BITS as f64)/8f64, (u64::BITS as f64)/8f64
    )]
    OutOfBounds {
        /// Underlying error
        err: <u64 as TryInto<usize>>::Error,
        /// The size of the memory map requested
        requested_len: u64,
        /// Maximum supported size
        max_supported_len: usize,
    },
    /// Tried to map a file descriptor into memory, but the size policy was never set. Developer
    /// error.
    #[error("Tried to map a file descriptor into memory, but the size policy was never set. This is a developer error.")]
    MissingSizePolicy,
    /// fstat(2) system call failed (used by [MappableFd::size_of_underlying_data])
    #[error("Tried to map file descriptor into memory, but failed to determine the size of the file descriptor: {:?}", .0)]
    CouldNotDetermineSize(rustix::io::Errno),
    /// Mismatch between expected and actual size of the file descriptor
    #[error("Tried to map file descriptor into memory with expected size {expected:?}, instead we found that the true size is {actual:?}")]
    IncorrectSize {
        /// Expected file descriptor size (given by caller)
        expected: u64,
        /// Actual file descriptor size
        actual: u64,
    },
    /// Negative size reported by fstat(2)
    #[error("Tried to determine the size of the underlying file descriptor, but failed to determine the size of the file descriptor because the size ({size}) is negative: {err:?}")]
    InvalidSize {
        /// Underlying error
        err: <i64 as TryInto<u64>>::Error,
        /// Reported size of the file descriptor
        size: i64,
    },
    /// ftruncate(2) system call failed
    #[error("Tried to resize the underlying file descriptor, but failed: {:?}", .0)]
    ResizeError(rustix::io::Errno),
    /// mmap(2) system call failed
    #[error("Tried to map file descriptor into memory, but mmap(2) system call failed: {:?}", .0)]
    MMapError(rustix::io::Errno),
}
/// Handle mapping a file descriptor into memory
///
/// Pairs a file descriptor with a [MapFdConfig]; the actual mapping is
/// performed by [MappableFd::mmap].
pub struct MappableFd<Fd: AsFd> {
    /// The file descriptor this struct refers to
    fd: Fd,
    /// The configuration for mmap
    config: MapFdConfig,
}
impl<Fd: AsFd> AsFd for MappableFd<Fd> {
    // Delegate to the wrapped file descriptor.
    fn as_fd(&self) -> BorrowedFd<'_> {
        self.fd.as_fd()
    }
}
impl<Fd: AsFd> AsRawFd for MappableFd<Fd> {
    // Delegate through the AsFd impl above.
    fn as_raw_fd(&self) -> RawFd {
        self.as_fd().as_raw_fd()
    }
}
impl<Fd: AsFd> MappableFd<Fd> {
    /// Create a [MappableFd] using the default configuration ([MapFdConfig])
    ///
    /// Note: the default configuration has no [MapFdConfig::size_policy];
    /// [Self::mmap] returns [MMapError::MissingSizePolicy] until one is set.
    pub fn from_fd(fd: Fd) -> Self {
        Self::new(fd, MapFdConfig::default())
    }
    /// Create a new [MappableFd] for an existing file descriptor
    pub fn new(fd: Fd, config: MapFdConfig) -> Self {
        Self { fd, config }
    }
    /// Extract the underlying file descriptor
    pub fn into_fd(self) -> Fd {
        self.fd
    }
    /// Access the [MapFdConfig] associated with Self
    pub fn config(&self) -> MapFdConfig {
        self.config
    }
    /// Mutably access the [MapFdConfig] associated with Self
    pub fn config_mut(&mut self) -> &mut MapFdConfig {
        &mut self.config
    }
    /// Modify the [MapFdConfig] associated with Self, chainable
    pub fn with_config(mut self, config: MapFdConfig) -> Self {
        self.config = config;
        self
    }
    /// Determine the size of the data associated with the file descriptor
    ///
    /// Uses fstat(2); a negative `st_size` is reported as [MMapError::InvalidSize].
    pub fn size_of_underlying_data(&self) -> Result<u64, MMapError> {
        use MMapError as E;
        let size = rustix::fs::fstat(self)
            .map_err(E::CouldNotDetermineSize)?
            .st_size;
        size.try_into().map_err(|err| E::InvalidSize { err, size })
    }
    /// Map the file into memory
    ///
    /// # Determining the size of the mapping.
    ///
    /// Before calling this function, [MapFdConfig::size_policy] must be set.
    ///
    /// Note that [Self::from_fd] and [Self::new] still allow you to create a [Self] with
    /// [MapFdConfig::size_policy] set to [None], so you can use [Self::size_of_underlying_data]
    /// to auto-detect the size of the mapping.
    ///
    /// This functionality is not implemented by default, as it's crucial to validate the size of
    /// the data being mapped into memory somehow; otherwise, the party that created the file
    /// descriptor can trigger a denial-of-service attack against our process by allocating an
    /// excessively large, sparse file. If you implement size auto-detection facilities, you should
    /// still enforce some bounds on the size.
    ///
    /// # Safety
    ///
    /// If there exist any Rust references referring to the memory region, or if you subsequently
    /// create a Rust reference referring to the resulting region, it is your responsibility to
    /// ensure that the Rust reference invariants are preserved, including ensuring that the
    /// memory is not mutated in a way that a Rust reference would not expect.
    pub fn mmap(&self) -> Result<MappedSegment, MMapError> {
        use rustix::mm::mmap;
        use MMapError as E;
        let prot = self.config().mmap_prot();
        let flags = self.config().mmap_flags();
        // Determine the size of the mapping to be used as u64
        let requested_size = match self.config().size_policy {
            None => return Err(E::MissingSizePolicy),
            Some(MMapSizePolicy::Assumed(size)) => size,
            Some(MMapSizePolicy::Resize(size)) => {
                rustix::fs::ftruncate(self, size).map_err(E::ResizeError)?;
                size
            }
            Some(MMapSizePolicy::Checked(expected)) => {
                let actual = self.size_of_underlying_data()?;
                if expected != actual {
                    return Err(E::IncorrectSize { expected, actual });
                }
                expected
            }
        };
        // Cast the size of the mapping to be used to usize, raising an error if the
        // requested_size can not be represented as usize (this should never happen in general,
        // but it could conceivably be thrown on 32 bit systems when very large mappings (>= 4GB)
        // are requested
        let len = requested_size.try_into().map_err(|err| {
            let max_supported_len = usize::MAX;
            E::OutOfBounds {
                err,
                requested_len: requested_size,
                max_supported_len,
            }
        })?;
        // SAFETY: a null address hint lets the kernel choose the location;
        // fd/prot/flags validity is checked by the kernel, with failures
        // surfacing as an Errno below.
        let ptr = unsafe { mmap(null_mut(), len, prot, flags, self, 0) };
        let ptr = ptr.map_err(E::MMapError)?;
        // SAFETY: mmap succeeded, so ptr..ptr+len is a valid mapping that we
        // exclusively own; ownership is transferred to the MappedSegment.
        let ptr = unsafe { MappedSegment::from_raw_parts(ptr.cast(), len) };
        Ok(ptr)
    }
}
/// Represents exclusive ownership of a memory segment mapped into memory
///
/// Automatically unmaps the memory segment as this goes out of scope
///
/// # Panic
///
/// If the munmap(2) call fails, the destructor will panic. You can avoid this and use explicit
/// error handling by calling [Self::unmap] instead
#[derive(Debug)]
pub struct MappedSegment {
    /// The location of the memory segment
    ptr: *mut u8,
    /// Length of the segment in bytes
    len: usize,
}
// SAFETY: MappedSegment models exclusive ownership of its mapping and this
// type itself hands out no aliasing references, so moving it to another
// thread is sound. NOTE(review): confirm no thread-affine resource is ever
// attached to a mapping.
unsafe impl Send for MappedSegment {}
impl MappedSegment {
    /// Construct a new [Self] from a pointer and a length
    ///
    /// `ptr` is the address of the memory segment and `len` is its length in bytes
    ///
    /// # Safety
    ///
    /// It is the responsibility of the caller to make sure that any pointer P such that `P = ptr.add(n)`,
    /// where `n < len` is a valid pointer. See #safety in [std::ptr].
    pub unsafe fn from_raw_parts(ptr: *mut u8, len: usize) -> Self {
        Self { ptr, len }
    }
    /// Decompose a MappedSegment into its raw components: Pointer and length
    pub fn into_raw_parts(self) -> (*mut u8, usize) {
        let r = (self.ptr(), self.len());
        // Prevent Drop from unmapping — ownership moves to the caller.
        std::mem::forget(self);
        r
    }
    /// The location of the memory segment
    pub fn ptr(&self) -> *mut u8 {
        self.ptr
    }
    /// Length of the segment in bytes
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.len
    }
    /// Release the memory segment
    ///
    /// Compared to using the destructor which panics if unmapping fails, this allows explicit error handling to be used.
    ///
    /// If this returns an error, the memory segment has not been freed. The values from
    /// [Self::into_raw_parts()] are returned as part of the error value so the caller has
    /// some chance to free the memory some other way (or to leak it if they so choose)
    pub fn unmap(self) -> Result<(), (rustix::io::Errno, *mut u8, usize)> {
        // into_raw_parts forgets self, so Drop will not run a second munmap.
        let (ptr, len) = self.into_raw_parts();
        let res = unsafe { rustix::mm::munmap(ptr.cast(), len) };
        res.map_err(|errno| (errno, ptr, len))
    }
}
impl Drop for MappedSegment {
    fn drop(&mut self) {
        // Move the real mapping out of `self`, leaving a harmless empty
        // placeholder behind, so that the by-value `unmap` can be called.
        // `unmap` consumes the segment without running Drop again, so there
        // is no recursion here.
        let segment = std::mem::replace(
            self,
            MappedSegment {
                ptr: null_mut(),
                len: 0,
            },
        );
        if let Err((errno, _ptr, _len)) = segment.unmap() {
            panic!("Failed to unmap MappedSegment: {errno:?}")
        }
    }
}

View File

@@ -0,0 +1,4 @@
//! Utilities for allocating secret memory

/// File-descriptor based helpers for secret memory
// NOTE(review): description inferred from the module name only — confirm
// against the module's own docs
pub mod fd;

/// mmap(2) based helpers for secret memory
pub mod mmap;

View File

@@ -0,0 +1,107 @@
//! Traits for types with atomic access semantics
use std::sync::atomic::Ordering;

/// A trait for types with atomic access semantics
///
/// Using this trait allows us to achieve two goals:
///
/// 1. We can implement atomic semantics for types where
///    there is no platform support for atomic semantics
///    (e.g. by using a Mutex)
/// 2. We can reuse implementations of concurrent data structures
///    efficiently for the non-concurrent case. E.g. we could
///    build a thread-local ring buffer using
///    [crate::ringbuf::concurrent::framework::ConcurrentPipeWriter]/
///    [crate::ringbuf::concurrent::framework::ConcurrentPipeReader]
///    by supplying some sort `Immediate<u64>` type for the atomics
///    in [crate::ringbuf::concurrent::framework::ConcurrentPipeCore] that implements
///    this trait by using a [std::cell::Cell]. It may seem counter
///    intuitive, but this setup implements perfectly fine atomic-appearing
///    semantics just as long as the cell is thread-local.
pub trait AbstractAtomic<T> {
    /// Like [std::sync::atomic::AtomicU64::load()]
    fn load(&self, order: Ordering) -> T;

    /// Like [std::sync::atomic::AtomicU64::compare_exchange_weak()].
    ///
    /// The default implementation just calls [AbstractAtomic::compare_exchange()],
    /// which is always a correct (if potentially less efficient) substitute.
    fn compare_exchange_weak(
        &self,
        current: T,
        new: T,
        success: Ordering,
        failure: Ordering,
    ) -> Result<T, T> {
        self.compare_exchange(current, new, success, failure)
    }

    /// Like [std::sync::atomic::AtomicU64::compare_exchange()].
    fn compare_exchange(
        &self,
        current: T,
        new: T,
        success: Ordering,
        failure: Ordering,
    ) -> Result<T, T>;
}

/// Implements a default type for [AbstractAtomic]
///
/// This in essence is to [AbstractAtomic], as [std::ops::Deref] is to [std::borrow::Borrow];
/// the same functionality, except with an associated type instead of a generic type
/// parameter, so the value type is uniquely determined by — and can be inferred
/// from — the atomic type itself.
pub trait AbstractAtomicType: AbstractAtomic<Self::ValueType> {
    /// The underlying atomic value
    type ValueType;
}

/// Implements [AbstractAtomic] and [AbstractAtomicType] for standard atomics
macro_rules! impl_abstract_atomic_for_atomic {
    ($($Atomic:ty : $Value:ty),*) => {
        $(
            impl AbstractAtomicType for $Atomic {
                type ValueType = $Value;
            }

            impl AbstractAtomic<$Value> for $Atomic {
                fn load(&self, order: Ordering) -> $Value {
                    // `self` is already `&$Atomic`, exactly what the
                    // fully-qualified call expects; no extra borrow needed.
                    <$Atomic>::load(self, order)
                }

                fn compare_exchange_weak(
                    &self,
                    current: $Value,
                    new: $Value,
                    success: Ordering,
                    failure: Ordering,
                ) -> Result<$Value, $Value> {
                    <$Atomic>::compare_exchange_weak(self, current, new, success, failure)
                }

                fn compare_exchange(
                    &self,
                    current: $Value,
                    new: $Value,
                    success: Ordering,
                    failure: Ordering,
                ) -> Result<$Value, $Value> {
                    <$Atomic>::compare_exchange(self, current, new, success, failure)
                }
            }
        )*
    };
}

impl_abstract_atomic_for_atomic! {
    std::sync::atomic::AtomicBool: bool,
    std::sync::atomic::AtomicI8: i8,
    std::sync::atomic::AtomicI16: i16,
    std::sync::atomic::AtomicI32: i32,
    std::sync::atomic::AtomicI64: i64,
    std::sync::atomic::AtomicIsize: isize,
    std::sync::atomic::AtomicU8: u8,
    std::sync::atomic::AtomicU16: u16,
    std::sync::atomic::AtomicU32: u32,
    std::sync::atomic::AtomicU64: u64,
    std::sync::atomic::AtomicUsize: usize
}

View File

@@ -0,0 +1,5 @@
//! Synchronization using atomic semantics for multi-threaded programs.
//!
//! Analogous to [std::sync::atomic].

/// Traits (such as `AbstractAtomic`) for types with atomic access semantics
pub mod abstract_atomic;

5
util/src/sync/mod.rs Normal file
View File

@@ -0,0 +1,5 @@
//! Synchronization for multi-threaded programs.
//!
//! Analogous to [std::sync].

/// Atomic-semantics utilities; analogous to [std::sync::atomic]
pub mod atomic;

123
util/tests/pipe.rs Normal file
View File

@@ -0,0 +1,123 @@
use std::{
borrow::Borrow,
sync::{atomic::AtomicU64, Arc},
thread,
};
use rosenpass_util::{
int::u64uint::U64USizeRangeExt,
ipc::shm::{ringbuf::shm_pipe, SharedMemorySegment},
ringbuf::sched::{Diff, OperationType, RingBufferScheduler},
};
/// Debug logging helper for this test; output is currently disabled.
///
/// The `if false` keeps the format string and its arguments fully
/// type-checked while compiling the actual I/O out, and avoids pointlessly
/// locking stderr on every invocation in the hot message loop.
macro_rules! dbg_print {
    ($($arg:tt)*) => {{
        if false {
            use std::io::Write;
            let stderr = std::io::stderr();
            let mut stderr = stderr.lock();
            writeln!(stderr, $($arg)*).unwrap()
        }
    }};
}
/// End-to-end exercise of the shared-memory IPC primitives.
///
/// Phase 1 maps the same shared-memory fd twice ([SharedMemorySegment::create]
/// plus [SharedMemorySegment::from_fd]) and performs volatile reads/writes;
/// those results are only logged through `dbg_print!` (currently compiled
/// out), not asserted.
///
/// Phase 2 builds a writer/reader pair via [shm_pipe] and streams `MSG_COUNT`
/// zero-terminated copies of `MSG` from a spawned thread to this thread,
/// reassembling messages from partial reads and asserting each one arrives
/// intact.
#[test]
fn pipe_test() -> anyhow::Result<()> {
    // --- Phase 1: raw shared-memory segment smoke test ---
    let (fd, reg1) = SharedMemorySegment::create(1024)?;
    // Map the same fd a second time so reg1/reg2 alias the same memory.
    // NOTE(review): the safety contract of `from_fd` is defined elsewhere —
    // presumably the fd must refer to a shared-memory object of at least
    // 1024 bytes; confirm against its documentation.
    let reg2 = unsafe { SharedMemorySegment::from_fd(fd, 1024) }?;
    dbg_print!("Regions {:?} {:?}", reg1.ptr(), reg2.ptr());

    // Read the (still untouched) segment through both mappings.
    let mut buf = b"______________________________________".to_owned();
    reg1.volatile_read(&mut buf, 0);
    dbg_print!(
        "Region 1 read: `{:?}` `{:?}`",
        String::from_utf8_lossy(&buf),
        &buf
    );
    let mut buf = b"______________________________________".to_owned();
    reg2.volatile_read(&mut buf, 0);
    dbg_print!(
        "Region 1 read: `{:?}` `{:?}`",
        String::from_utf8_lossy(&buf),
        &buf
    );

    // Write through the first mapping, then read back through both mappings.
    dbg_print!("Write to region 1");
    reg1.volatile_write(0, b"Hello World");
    let mut buf = b"______________________________________".to_owned();
    reg1.volatile_read(&mut buf, 0);
    dbg_print!(
        "Region 1 read: `{:?}` `{:?}`",
        String::from_utf8_lossy(&buf),
        &buf
    );
    let mut buf = b"______________________________________".to_owned();
    reg2.volatile_read(&mut buf, 0);
    dbg_print!(
        "Region 1 read: `{:?}` `{:?}`",
        String::from_utf8_lossy(&buf),
        &buf
    );

    // --- Phase 2: stream messages through a shared-memory pipe ---
    let (mut writer, mut reader) = shm_pipe(1024)?;
    const MSG: &[u8] = b"Hello World\0";
    const MSG_COUNT: usize = 100000;

    // Writer thread: push MSG_COUNT messages, looping to handle short writes.
    let t = thread::spawn(move || {
        for _ in 0..MSG_COUNT {
            let mut buf = MSG;
            while !buf.is_empty() {
                let n = writer.write(buf).unwrap();
                buf = &buf[n..];
            }
        }
    });

    // Reader side: `buf[..buf_off]` holds the unconsumed bytes received so
    // far. Complete messages always start at index 0 because consumed
    // messages are shifted out with `copy_within` below.
    let mut buf = [0u8; 1000];
    let mut buf_off = 0;
    let mut msg_no = 0usize;
    'read_data: while msg_no < MSG_COUNT {
        // Bytes before `old_off` were already scanned (terminator-free) in a
        // previous iteration, so the scan below can skip them.
        let mut old_off = buf_off;
        // Read the data from the shared memory buffer
        buf_off += reader.read(&mut buf[buf_off..])?;
        // A single read may complete several messages; keep extracting until
        // no terminator remains in the buffered data.
        'scan_again: loop {
            // Scan the available data for the zero terminator; `msg_len` is
            // the length of the first message measured from the buffer start.
            let msg_len = &buf[old_off..buf_off]
                .iter()
                .copied()
                .enumerate()
                .find(|(_off, c)| *c == 0x0)
                .map(|(off, _c)| off + old_off + 1);
            // Next iteration, unless the terminator was found
            let msg_len = match *msg_len {
                Some(l) => l,
                None => continue 'read_data,
            };
            // Register the newly read message
            msg_no += 1;
            // Check that the message is correctly transferred
            let msg = &buf[0..msg_len];
            dbg_print!("CONT {:?}", &buf[..buf_off]);
            dbg_print!("RECV {msg:?}");
            assert_eq!(msg, MSG);
            // Move any extra data to the beginning of the buffer and adjust the offsets accordingly
            buf.copy_within(msg_len..buf_off, 0);
            // Rescan from the start: the moved remainder may itself contain
            // a terminator that the `find` above stopped short of.
            old_off = 0;
            buf_off -= msg_len;
        }
    }

    // All messages were received; the writer thread must have finished
    // without panicking.
    t.join().unwrap();
    Ok(())
}

View File

@@ -12,7 +12,7 @@ use tokio::task;
use anyhow::{bail, ensure, Result};
use clap::{ArgGroup, Parser};
use rosenpass_util::fd::claim_fd;
use rosenpass_util::rustix::claim_fd;
use rosenpass_wireguard_broker::api::msgs;
/// Command-line arguments for configuring the socket handler