mirror of
https://github.com/rosenpass/rosenpass.git
synced 2025-12-05 20:40:02 -08:00
Compare commits
53 Commits
dependabot
...
3e111fa7ad
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3e111fa7ad | ||
|
|
5ec845f5d0 | ||
|
|
4736c40d84 | ||
|
|
a8ef9dc3e5 | ||
|
|
4a0e34b1fc | ||
|
|
b6229b8d33 | ||
|
|
ecc17dea44 | ||
|
|
b499a9ba5b | ||
|
|
91382e189e | ||
|
|
c9db8cfec7 | ||
|
|
52b903c8c0 | ||
|
|
7f464de421 | ||
|
|
d856116a44 | ||
|
|
ea708bca90 | ||
|
|
f1746bd067 | ||
|
|
afea7d0a2e | ||
|
|
62c974f636 | ||
|
|
62e337c6a1 | ||
|
|
aebfdfa966 | ||
|
|
f0a43932c1 | ||
|
|
e98ee3fb6c | ||
|
|
3b8c082365 | ||
|
|
adda409edc | ||
|
|
e0fa817988 | ||
|
|
777e94eb64 | ||
|
|
61e30006ec | ||
|
|
daabb8c9ee | ||
|
|
cc2656bfaa | ||
|
|
2a2d78b448 | ||
|
|
d00f91f7dd | ||
|
|
ba3d22829e | ||
|
|
9e94759dd2 | ||
|
|
13ff005d27 | ||
|
|
5b1ee307a4 | ||
|
|
64205d4342 | ||
|
|
8d7d4ab643 | ||
|
|
9d10d11453 | ||
|
|
e15aa94b0c | ||
|
|
0b4bfe4d3e | ||
|
|
32427ad049 | ||
|
|
8f847f59e3 | ||
|
|
d900bb3875 | ||
|
|
30e30b6dfd | ||
|
|
a40e654d08 | ||
|
|
38e997554e | ||
|
|
adbe0bd431 | ||
|
|
e0976986fd | ||
|
|
d932144451 | ||
|
|
fe90220748 | ||
|
|
8d57dde61a | ||
|
|
1a51478e89 | ||
|
|
5b14ef8065 | ||
|
|
a796bdd2e7 |
34
Cargo.lock
generated
34
Cargo.lock
generated
@@ -212,7 +212,7 @@ version = "0.68.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "726e4313eb6ec35d2730258ad4e15b547ee75d6afaa1361a922e78e59b7d8078"
|
||||
dependencies = [
|
||||
"bitflags 2.8.0",
|
||||
"bitflags 2.9.3",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"lazy_static",
|
||||
@@ -237,9 +237,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "2.8.0"
|
||||
version = "2.9.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36"
|
||||
checksum = "34efbcccd345379ca2868b2b2c9d3782e9cc58ba87bc7d79d5b53d9c9ae6f25d"
|
||||
|
||||
[[package]]
|
||||
name = "blake2"
|
||||
@@ -818,9 +818,9 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.10"
|
||||
version = "0.3.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d"
|
||||
checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.59.0",
|
||||
@@ -1193,7 +1193,7 @@ version = "0.7.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4"
|
||||
dependencies = [
|
||||
"bitflags 2.8.0",
|
||||
"bitflags 2.9.3",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
]
|
||||
@@ -1453,7 +1453,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"windows-targets 0.48.5",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1673,7 +1673,7 @@ version = "0.27.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053"
|
||||
dependencies = [
|
||||
"bitflags 2.8.0",
|
||||
"bitflags 2.9.3",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
]
|
||||
@@ -2032,7 +2032,7 @@ version = "0.5.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "82b568323e98e49e2a0899dcee453dd679fae22d69adf9b11dd508d1549b7e2f"
|
||||
dependencies = [
|
||||
"bitflags 2.8.0",
|
||||
"bitflags 2.9.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2257,13 +2257,19 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"base64ct",
|
||||
"bitflags 2.9.3",
|
||||
"errno",
|
||||
"libc",
|
||||
"libcrux-test-utils",
|
||||
"log",
|
||||
"mio",
|
||||
"num-traits",
|
||||
"rosenpass-to",
|
||||
"rustix",
|
||||
"static_assertions",
|
||||
"tempfile",
|
||||
"thiserror 1.0.69",
|
||||
"tinyvec",
|
||||
"tokio",
|
||||
"typenum",
|
||||
"uds",
|
||||
@@ -2340,7 +2346,7 @@ version = "0.38.44"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
|
||||
dependencies = [
|
||||
"bitflags 2.8.0",
|
||||
"bitflags 2.9.3",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
@@ -2710,6 +2716,12 @@ dependencies = [
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec"
|
||||
version = "1.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa"
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.47.0"
|
||||
@@ -3276,7 +3288,7 @@ version = "0.33.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c"
|
||||
dependencies = [
|
||||
"bitflags 2.8.0",
|
||||
"bitflags 2.9.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -80,6 +80,7 @@ hex-literal = { version = "0.4.1" }
|
||||
hex = { version = "0.4.3" }
|
||||
heck = { version = "0.5.0" }
|
||||
libc = { version = "0.2" }
|
||||
errno = { version = "0.3.13" }
|
||||
uds = { git = "https://github.com/rosenpass/uds" }
|
||||
lazy_static = "1.5"
|
||||
|
||||
@@ -95,6 +96,7 @@ criterion = "0.5.1"
|
||||
allocator-api2-tests = "0.2.15"
|
||||
procspawn = { version = "1.0.1", features = ["test-support"] }
|
||||
serde_json = { version = "1.0.140" }
|
||||
bitflags = "2.9.3"
|
||||
|
||||
#Broker dependencies (might need cleanup or changes)
|
||||
wireguard-uapi = { version = "3.0.0", features = ["xplatform"] }
|
||||
|
||||
@@ -6,12 +6,12 @@ use std::{borrow::BorrowMut, collections::VecDeque, os::fd::OwnedFd};
|
||||
use anyhow::Context;
|
||||
use rosenpass_to::{ops::copy_slice, To};
|
||||
use rosenpass_util::{
|
||||
fd::FdIo,
|
||||
functional::{run, ApplyExt},
|
||||
io::ReadExt,
|
||||
mem::DiscardResultExt,
|
||||
mio::UnixStreamExt,
|
||||
result::OkExt,
|
||||
rustix::FdIo,
|
||||
};
|
||||
use rosenpass_wireguard_broker::brokers::mio_client::MioBrokerClient;
|
||||
|
||||
@@ -243,7 +243,7 @@ where
|
||||
.context("Invalid request – socket missing.")?;
|
||||
// TODO: We need to have this outside linux
|
||||
#[cfg(target_os = "linux")]
|
||||
rosenpass_util::fd::GetSocketProtocol::demand_udp_socket(&sock)?;
|
||||
rosenpass_util::rustix::GetSocketProtocol::demand_udp_socket(&sock)?;
|
||||
let sock = std::net::UdpSocket::from(sock);
|
||||
sock.set_nonblocking(true)?;
|
||||
mio::net::UdpSocket::from_std(sock).ok()
|
||||
|
||||
@@ -26,7 +26,7 @@ use {
|
||||
command_fds::{CommandFdExt, FdMapping},
|
||||
log::{error, info},
|
||||
mio::net::UnixStream,
|
||||
rosenpass_util::fd::claim_fd,
|
||||
rosenpass_util::rustix::claim_fd,
|
||||
rosenpass_wireguard_broker::brokers::mio_client::MioBrokerClient,
|
||||
rosenpass_wireguard_broker::WireguardBrokerMio,
|
||||
rustix::net::{socketpair, AddressFamily, SocketFlags, SocketType},
|
||||
|
||||
@@ -13,11 +13,14 @@ rust-version = "1.77.0"
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
rosenpass-to = { workspace = true }
|
||||
base64ct = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
typenum = { workspace = true }
|
||||
static_assertions = { workspace = true }
|
||||
rustix = { workspace = true }
|
||||
rustix = { workspace = true, features = ["net", "fs", "process", "mm"] }
|
||||
libc = { workspace = true }
|
||||
errno = { workspace = true }
|
||||
zeroize = { workspace = true }
|
||||
zerocopy = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
@@ -32,6 +35,9 @@ tokio = { workspace = true, optional = true, features = [
|
||||
"time",
|
||||
] }
|
||||
log = { workspace = true }
|
||||
bitflags = { workspace = true }
|
||||
tinyvec = "1.10.0"
|
||||
num-traits = "0.2.19"
|
||||
|
||||
[features]
|
||||
experiment_file_descriptor_passing = ["uds"]
|
||||
|
||||
85
util/src/convert.rs
Normal file
85
util/src/convert.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
//! Additional helpers to [std::convert]: Traits for conversions between types.
|
||||
|
||||
/// Variant of [std::convert::Into] with an explicitly specified type.
|
||||
///
|
||||
/// This facilitates method chaining.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// We can create a nicer implementation of the following function using this extension trait
|
||||
///
|
||||
/// ```
|
||||
/// use rosenpass_util::convert::IntoTypeExt;
|
||||
///
|
||||
/// fn encode_char_u32_be_1(c: char) -> [u8; 4] {
|
||||
/// let i : u32 = c.into();
|
||||
/// i.to_be_bytes()
|
||||
/// }
|
||||
///
|
||||
/// fn encode_char_u32_be_2(c: char) -> [u8; 4] {
|
||||
/// c.into_type::<u32>().to_be_bytes()
|
||||
/// }
|
||||
///
|
||||
/// assert_eq!(encode_char_u32_be_1('X'), [0x00, 0x00, 0x00, 0x58]);
|
||||
/// assert_eq!(encode_char_u32_be_1('X'), encode_char_u32_be_2('X'));
|
||||
///
|
||||
/// ```
|
||||
pub trait IntoTypeExt {
|
||||
/// Variant of [std::convert::Into] with explicitly specified type.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// See [IntoType].
|
||||
fn into_type<T>(self) -> T
|
||||
where
|
||||
Self: Into<T>,
|
||||
{
|
||||
self.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> IntoTypeExt for T {}
|
||||
|
||||
/// Variant of [std::convert::TryInto] with an explicitly specified type.
|
||||
///
|
||||
/// This facilitates method chaining.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// We can create a nicer implementation of the following function using this extension trait
|
||||
///
|
||||
/// ```
|
||||
/// use rosenpass_util::convert::TryIntoTypeExt;
|
||||
/// use rosenpass_util::result::OkExt;
|
||||
///
|
||||
/// fn encode_char_u16_be_1(c: char) -> Result<[u8; 2], <char as TryInto<u16>>::Error> {
|
||||
/// let i : u16 = c.try_into()?;
|
||||
/// Ok(i.to_be_bytes())
|
||||
/// }
|
||||
///
|
||||
/// fn encode_char_u16_be_2(c: char) -> Result<[u8; 2], <char as TryInto<u16>>::Error> {
|
||||
/// c.try_into_type::<u16>()?.to_be_bytes().ok()
|
||||
/// }
|
||||
///
|
||||
/// assert_eq!(encode_char_u16_be_1('X'), Ok([0x00, 0x58]));
|
||||
/// assert_eq!(encode_char_u16_be_1('X'), encode_char_u16_be_2('X'));
|
||||
///
|
||||
/// const HEART_CAT : char = '😻'; // 1F63B
|
||||
/// assert!(encode_char_u16_be_1(HEART_CAT).is_err());
|
||||
/// assert_eq!(encode_char_u16_be_1(HEART_CAT), encode_char_u16_be_2(HEART_CAT));
|
||||
/// ```
|
||||
pub trait TryIntoTypeExt {
|
||||
/// Variant of [std::convert::TryInto] with explicitly specified type.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// See [TryIntoType].
|
||||
fn try_into_type<T>(self) -> Result<T, <Self as TryInto<T>>::Error>
|
||||
where
|
||||
Self: TryInto<T>,
|
||||
{
|
||||
self.try_into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> TryIntoTypeExt for T {}
|
||||
4
util/src/int/mod.rs
Normal file
4
util/src/int/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
//! Helpers for working with integer types
|
||||
|
||||
pub mod modular;
|
||||
pub mod u64uint;
|
||||
275
util/src/int/modular.rs
Normal file
275
util/src/int/modular.rs
Normal file
@@ -0,0 +1,275 @@
|
||||
//! Numeric types for modular arithmetic
|
||||
|
||||
use num_traits::{
|
||||
ops::overflowing::OverflowingAdd, CheckedMul, Euclid, Num, Unsigned, WrappingNeg, Zero,
|
||||
};
|
||||
|
||||
/// Summary-trait for numeric types that can serve as the basis for ModuleBase
|
||||
pub trait ModuleBase: Num + Ord + Copy + Unsigned + OverflowingAdd + Zero {}
|
||||
impl<T> ModuleBase for T where T: Num + Ord + Copy + Unsigned + OverflowingAdd + Zero {}
|
||||
|
||||
/// Represents a modulus; i.e. the range of values some number type is allowed to use.
|
||||
///
|
||||
/// This is based on some inner representation.
|
||||
///
|
||||
/// This is not just a value of the underlying representation, because it also supports the modulus
|
||||
/// [Self::new_full_range()], which indicates that the full range of the underlying type is to be
|
||||
/// supported.
|
||||
///
|
||||
/// Note that zero is not a valid modulus
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub struct Modulus<T: ModuleBase> {
|
||||
/// Inner representation of the type
|
||||
modulus: T,
|
||||
}
|
||||
|
||||
impl<T: ModuleBase> Modulus<T> {
|
||||
/// Create a new [Self] without any checks
|
||||
fn raw_new(modulus: T) -> Self {
|
||||
Self { modulus }
|
||||
}
|
||||
|
||||
/// Create a new [Self] that indicates that the full range of the underlying type is to be used
|
||||
pub fn new_full_range() -> Self {
|
||||
Self::raw_new(T::zero())
|
||||
}
|
||||
|
||||
/// Try to create a new [Self]. Will return None only if `modulus == 0`
|
||||
pub fn try_new(modulus: T) -> Option<Self> {
|
||||
match modulus == T::zero() {
|
||||
true => None,
|
||||
false => Some(Self::raw_new(modulus)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Like [Self::try_new] but will panic if `modulus == 0`
|
||||
///
|
||||
/// # Panic
|
||||
///
|
||||
/// Will panic if `modulus == 0`
|
||||
pub fn new_or_panic(modulus: T) -> Self {
|
||||
match Self::try_new(modulus) {
|
||||
None => panic!("Can not create Module with modulus zero!"),
|
||||
Some(me) => me,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this [Self] represents the full range of the underlying type
|
||||
pub fn is_full_range(&self) -> bool {
|
||||
self.modulus == T::zero()
|
||||
}
|
||||
|
||||
/// Get the raw modulus. I.e. the value of the underlying type that can be given to
|
||||
/// a modulo operation to implement modular arithmetic.
|
||||
///
|
||||
/// Will return None if [Self::is_full_range].
|
||||
pub fn modulus(&self) -> Option<T> {
|
||||
match self.is_full_range() {
|
||||
true => None,
|
||||
false => Some(self.modulus),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the given type is contained in the range represented by this [Self]
|
||||
pub fn contains(&self, v: T) -> bool {
|
||||
match self.is_full_range() {
|
||||
true => true,
|
||||
false => v < self.modulus,
|
||||
}
|
||||
}
|
||||
/// Double the modulus.
|
||||
///
|
||||
/// Correctly handles the case that `v.double().is_full_range()`.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// use rosenpass_util::int::modular::Modulus;
|
||||
///
|
||||
/// fn m(v: u8) -> Modulus<u8> {
|
||||
/// Modulus::new_or_panic(v)
|
||||
/// }
|
||||
///
|
||||
/// assert_eq!(m(100).double(), m(200));
|
||||
/// assert_eq!(m(128).double(), Modulus::new_full_range());
|
||||
/// ```
|
||||
pub fn double(&self) -> Option<Self> {
|
||||
let s = self.modulus()?;
|
||||
match s.overflowing_add(&s) {
|
||||
(d, true) if d > T::zero() => None,
|
||||
(d, _) => Some(Self::raw_new(d)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new [ModularArithmetic] by taking the value modulo the modulus
|
||||
pub fn new_number<U: Into<T>>(self, value: U) -> ModularArithmetic<T>
|
||||
where
|
||||
T: ModularArithmeticBase,
|
||||
{
|
||||
ModularArithmetic::modular_new(value.into(), self)
|
||||
}
|
||||
|
||||
/// Apply [Self::new_number] to each of the parameters, return whatever result the closure
|
||||
/// produces
|
||||
pub fn with_converted<U, const N: usize, R, F>(&self, params: [U; N], f: F) -> R
|
||||
where
|
||||
Self: Copy,
|
||||
T: std::fmt::Debug + ModularArithmeticBase,
|
||||
U: Into<T>,
|
||||
F: FnOnce([ModularArithmetic<T>; N]) -> R,
|
||||
{
|
||||
let params = params.map(|v| self.new_number(v));
|
||||
f(params)
|
||||
}
|
||||
|
||||
/// Apply [Self::new_number] to each of the parameters, converting the result to the underlying
|
||||
/// representation
|
||||
pub fn formula<U, const N: usize, F>(&self, params: [U; N], f: F) -> T
|
||||
where
|
||||
Self: Copy,
|
||||
T: std::fmt::Debug + ModularArithmeticBase,
|
||||
U: Into<T>,
|
||||
F: FnOnce([ModularArithmetic<T>; N]) -> ModularArithmetic<T>,
|
||||
{
|
||||
self.with_converted(params, f).value()
|
||||
}
|
||||
}
|
||||
|
||||
/// Summary trait for types that can serve as the basis for [ModularArithmetic]
|
||||
pub trait ModularArithmeticBase:
|
||||
ModuleBase
|
||||
+ std::fmt::Debug
|
||||
+ Num
|
||||
+ PartialOrd
|
||||
+ Ord
|
||||
+ Copy
|
||||
+ Unsigned
|
||||
+ CheckedMul
|
||||
+ Euclid
|
||||
+ WrappingNeg
|
||||
{
|
||||
}
|
||||
impl<T> ModularArithmeticBase for T where
|
||||
T: ModuleBase
|
||||
+ std::fmt::Debug
|
||||
+ Num
|
||||
+ PartialOrd
|
||||
+ Ord
|
||||
+ Copy
|
||||
+ Unsigned
|
||||
+ CheckedMul
|
||||
+ Euclid
|
||||
+ WrappingNeg
|
||||
{
|
||||
}
|
||||
|
||||
/// Modular arithmetic with an arbitrary modulus
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct ModularArithmetic<T: ModularArithmeticBase> {
|
||||
/// The modulus
|
||||
modulus: Modulus<T>,
|
||||
/// The value inside the modulus
|
||||
///
|
||||
/// Note that `self.modulus.contains(self.value)` must always hold.
|
||||
value: T,
|
||||
}
|
||||
|
||||
impl<T: ModularArithmeticBase> ModularArithmetic<T> {
|
||||
/// Construct a new [Self].
|
||||
///
|
||||
/// Will return none unless `module.`[contains](Modulus::contains)`(value)`.
|
||||
pub fn try_new(value: T, module: Modulus<T>) -> Option<Self> {
|
||||
module.contains(value).then_some(Self {
|
||||
value,
|
||||
modulus: module,
|
||||
})
|
||||
}
|
||||
|
||||
/// Construct a new [Self].
|
||||
///
|
||||
/// # Panic
|
||||
///
|
||||
/// Will punic unless `module.`[contains](Modulus::contains)`(value)`.
|
||||
pub fn modular_new(value: T, module: Modulus<T>) -> Self {
|
||||
let value = match module.modulus() {
|
||||
Some(m) => value.rem_euclid(&m),
|
||||
None => value,
|
||||
};
|
||||
|
||||
Self {
|
||||
modulus: module,
|
||||
value,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the modulus
|
||||
pub fn modulus(&self) -> &Modulus<T> {
|
||||
&self.modulus
|
||||
}
|
||||
|
||||
/// The inner value
|
||||
pub fn value(&self) -> T {
|
||||
self.value
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ModularArithmeticBase> PartialEq for ModularArithmetic<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.value == other.value
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ModularArithmeticBase> Eq for ModularArithmetic<T> {}
|
||||
|
||||
impl<T: ModularArithmeticBase> PartialOrd for ModularArithmetic<T> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ModularArithmeticBase> Ord for ModularArithmetic<T> {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
self.value.cmp(&other.value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ModularArithmeticBase> std::ops::Neg for ModularArithmetic<T> {
|
||||
type Output = Self;
|
||||
|
||||
fn neg(self) -> Self::Output {
|
||||
let ret = match self.modulus().modulus() {
|
||||
None => self.value().wrapping_neg(),
|
||||
Some(modulus) => modulus - self.value(),
|
||||
};
|
||||
|
||||
Self {
|
||||
value: ret,
|
||||
modulus: self.modulus,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ModularArithmeticBase> std::ops::Sub for ModularArithmetic<T> {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
assert_eq!(self.modulus(), rhs.modulus());
|
||||
|
||||
if self < rhs {
|
||||
return -(rhs - self);
|
||||
}
|
||||
|
||||
Self {
|
||||
value: self.value() - rhs.value(),
|
||||
modulus: self.modulus,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ModularArithmeticBase> std::ops::Add for ModularArithmetic<T> {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
self - (-rhs)
|
||||
}
|
||||
}
|
||||
12
util/src/int/u64uint/constants.rs
Normal file
12
util/src/int/u64uint/constants.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
//! Constants for u64 <-> usize conversion
|
||||
|
||||
/// Largest numeric value that can be safely represented both as
|
||||
/// a u64 and usize
|
||||
pub const MAX_USIZE_IN_U64: usize = match u64::BITS >= usize::BITS {
|
||||
true => usize::MAX,
|
||||
false => u64::MAX as usize,
|
||||
};
|
||||
|
||||
/// Largest numeric value that can be safely represented both as
|
||||
/// a u64 and usize
|
||||
pub const MAX_U64_IN_USIZE: u64 = MAX_USIZE_IN_U64 as u64;
|
||||
21
util/src/int/u64uint/mod.rs
Normal file
21
util/src/int/u64uint/mod.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
//! Making sure size representations fit both into [usize] and [u64]
|
||||
|
||||
use static_assertions::const_assert;
|
||||
|
||||
mod constants;
|
||||
pub use constants::*;
|
||||
|
||||
#[allow(clippy::module_inception)]
|
||||
mod u64uint;
|
||||
pub use u64uint::*;
|
||||
|
||||
mod range;
|
||||
pub use range::*;
|
||||
|
||||
/// Safe conversion from usize to u64
|
||||
///
|
||||
/// TODO: Deprecate this
|
||||
pub const fn usize_to_u64(v: usize) -> u64 {
|
||||
const_assert!(u64::BITS >= usize::BITS);
|
||||
v as u64
|
||||
}
|
||||
29
util/src/int/u64uint/range.rs
Normal file
29
util/src/int/u64uint/range.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
//! Working with ranges of [super::U64Uint]
|
||||
|
||||
use std::ops::Range;
|
||||
|
||||
use super::U64USize;
|
||||
|
||||
/// Extensions for working with [std::ops::Range] of [U64USize]
|
||||
pub trait U64USizeRangeExt {
|
||||
/// Convert to a usize based range
|
||||
fn usize(self) -> Range<usize>;
|
||||
/// Convert to a u64 based range
|
||||
fn u64(self) -> Range<u64>;
|
||||
}
|
||||
|
||||
impl U64USizeRangeExt for Range<U64USize> {
|
||||
fn usize(self) -> Range<usize> {
|
||||
Range {
|
||||
start: self.start.usize(),
|
||||
end: self.end.usize(),
|
||||
}
|
||||
}
|
||||
|
||||
fn u64(self) -> Range<u64> {
|
||||
Range {
|
||||
start: self.start.u64(),
|
||||
end: self.end.u64(),
|
||||
}
|
||||
}
|
||||
}
|
||||
192
util/src/int/u64uint/u64uint.rs
Normal file
192
util/src/int/u64uint/u64uint.rs
Normal file
@@ -0,0 +1,192 @@
|
||||
//! The [U64UInt type]
|
||||
|
||||
use super::MAX_U64_IN_USIZE;
|
||||
|
||||
/// Error produced by [U64USize::try_new]
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum U64USizeConversionError<T: std::fmt::Debug> {
|
||||
/// Value can not be represented as u64
|
||||
#[error("Value can not be represented as a u64 (max = {}) value: {:?}", u64::MAX, .0)]
|
||||
NoU64Repr(T),
|
||||
/// Value can not be represented as usize
|
||||
#[error("Value can not be represented as a usize (max = {}) value: {:?}", usize::MAX, .0)]
|
||||
NoUSizeRepr(T),
|
||||
/// Value can not be represented as usize or u64
|
||||
#[error("Value can not be represented as a u64 (max = {}) or usize (max = {}) value: {:?}", u64::MAX, usize::MAX, .0)]
|
||||
NoU64OrUSizeRepr(T),
|
||||
}
|
||||
|
||||
/// A number that can be represented as both a usize and a u64.
|
||||
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub struct U64USize {
|
||||
/// Enclosed data
|
||||
storage: u64,
|
||||
}
|
||||
|
||||
impl U64USize {
|
||||
/// Cast another number to a number that can be represented as usize and u64;
|
||||
///
|
||||
/// This is the internal version which does not use [TryInto]; we use this to implement
|
||||
/// [TryInto].
|
||||
fn try_new_internal<T>(v: T) -> Result<Self, U64USizeConversionError<T>>
|
||||
where
|
||||
T: Copy + TryInto<usize> + TryInto<u64> + std::fmt::Debug,
|
||||
{
|
||||
use U64USizeConversionError as E;
|
||||
|
||||
let v_u64: Result<u64, _> = v.try_into();
|
||||
let v_usize: Result<usize, _> = v.try_into();
|
||||
match (v_u64, v_usize) {
|
||||
(Ok(storage), Ok(_)) => Ok(Self { storage }),
|
||||
(Err(_), Ok(_)) => Err(E::NoU64Repr(v)),
|
||||
(Ok(_), Err(_)) => Err(E::NoUSizeRepr(v)),
|
||||
(Err(_), Err(_)) => Err(E::NoU64OrUSizeRepr(v)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Cast another number to a number that can be represented as usize and u64
|
||||
pub fn try_new<T>(v: T) -> Result<U64USize, <T as TryInto<Self>>::Error>
|
||||
where
|
||||
T: TryInto<Self>,
|
||||
{
|
||||
v.try_into()
|
||||
}
|
||||
|
||||
/// Like [Self::try_new], but panics
|
||||
pub fn new_or_panic<T>(v: T) -> Self
|
||||
where
|
||||
T: TryInto<Self>,
|
||||
<T as TryInto<Self>>::Error: std::fmt::Debug,
|
||||
{
|
||||
match Self::try_new(v) {
|
||||
Ok(v) => v,
|
||||
Err(e) => panic!(
|
||||
"Could not construct {}: {e:?}",
|
||||
std::any::type_name::<Self>()
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Return this value as a usize
|
||||
pub fn usize(&self) -> usize {
|
||||
self.storage as usize
|
||||
}
|
||||
|
||||
/// Return this value as a u64
|
||||
pub fn u64(&self) -> u64 {
|
||||
self.storage
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Sub for U64USize {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
Self::new_or_panic(self.u64() - rhs.u64())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Add for U64USize {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
Self::new_or_panic(self.u64() + rhs.u64())
|
||||
}
|
||||
}
|
||||
|
||||
/// Facilitates creation of [U64USize] in cases where truncation to the largest representable value is permissible
|
||||
pub trait TruncateIntoU64USize {
|
||||
/// Check whether calling [TruncateIntoU64Usize::truncate_to_u64usize] would truncate or return
|
||||
/// the value as-is
|
||||
fn fits_into_u64usize(&self) -> bool;
|
||||
|
||||
/// Turn [Self] into a [U64USize]. If the value is representable as a usize and a u64, then
|
||||
/// the value will be returned as is. Otherwise, the maximum representable value [MAX_U64_IN_USIZE]
|
||||
/// will be returned.
|
||||
fn truncate_to_u64usize(&self) -> U64USize;
|
||||
}
|
||||
|
||||
/// Create instances of From for SafeUSize
|
||||
macro_rules! derive_TruncateIntoU64Usize {
|
||||
($($T:ty),*) => {
|
||||
$(
|
||||
impl TruncateIntoU64USize for $T {
|
||||
fn fits_into_u64usize(&self) -> bool {
|
||||
U64USize::try_new(*self).is_ok()
|
||||
}
|
||||
|
||||
fn truncate_to_u64usize(&self) -> U64USize {
|
||||
U64USize::try_new(*self).unwrap_or(U64USize::new_or_panic(MAX_U64_IN_USIZE))
|
||||
}
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
derive_TruncateIntoU64Usize!(
|
||||
U64USize, usize, isize, bool, u8, u16, u32, u64, u128, i8, i16, i32, i64, i128
|
||||
);
|
||||
|
||||
/// Create instances of From for SafeUSize
|
||||
macro_rules! U64USize_derive_from {
|
||||
($($T:ty),*) => {
|
||||
$(
|
||||
impl From<$T> for U64USize {
|
||||
fn from(value: $T) -> Self {
|
||||
U64USize::new_or_panic::<$T>(value)
|
||||
}
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
U64USize_derive_from!(bool, u8, u16);
|
||||
|
||||
/// Create instances of TryFrom for SafeUSize
|
||||
macro_rules! U64USize_derive_try_from {
|
||||
($($T:ty),*) => {
|
||||
$(
|
||||
impl TryFrom<$T> for U64USize {
|
||||
type Error = U64USizeConversionError<$T>;
|
||||
|
||||
fn try_from(value: $T) -> Result<Self, Self::Error> {
|
||||
U64USize::try_new_internal::<$T>(value)
|
||||
}
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
U64USize_derive_try_from!(usize, isize, u32, u64, u128, i8, i16, i32, i64, i128);
|
||||
|
||||
/// Create instances of Into for SafeUSize
|
||||
macro_rules! U64USize_derive_into {
|
||||
($($T:ty),*) => {
|
||||
$(
|
||||
impl From<U64USize> for $T {
|
||||
fn from(val: U64USize) -> Self {
|
||||
val.u64() as $T
|
||||
}
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
U64USize_derive_into!(usize, u64, u128, i128);
|
||||
|
||||
/// Create instances of TryInto for SafeUSize
|
||||
macro_rules! U64USize_derive_try_into {
|
||||
($($T:ty),*) => {
|
||||
$(
|
||||
impl TryFrom<U64USize> for $T {
|
||||
type Error = <$T as TryFrom<u64>>::Error;
|
||||
|
||||
fn try_from(val: U64USize) -> Result<Self, Self::Error> {
|
||||
val.u64().try_into()
|
||||
}
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
U64USize_derive_try_into!(isize, u8, u16, u32, i8, i16, i32, i64);
|
||||
3
util/src/ipc/mod.rs
Normal file
3
util/src/ipc/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
//! Inter-process communication related resources
|
||||
|
||||
pub mod shm;
|
||||
6
util/src/ipc/shm/mod.rs
Normal file
6
util/src/ipc/shm/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
//! Resources for working with shared-memory
|
||||
|
||||
mod shared_memory_segment;
|
||||
pub use shared_memory_segment::*;
|
||||
|
||||
pub mod ringbuf;
|
||||
40
util/src/ipc/shm/ringbuf/local.rs
Normal file
40
util/src/ipc/shm/ringbuf/local.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
//! Local shared memory ring buffers – mostly for testing
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::ipc::shm::SharedMemorySegment;
|
||||
use crate::ringbuf::concurrent::framework::{ConcurrentPipeReader, ConcurrentPipeWriter};
|
||||
|
||||
use super::{ShmPipeCore, ShmPipeVariables};
|
||||
|
||||
/// A process-local shared memory pipe reader
|
||||
///
|
||||
/// See [shm_pipe()].
|
||||
pub type LocalShmPipeReader = ConcurrentPipeReader<ShmPipeCore<Arc<ShmPipeVariables>>>;
|
||||
|
||||
/// A process-local shared memory pipe writer
|
||||
///
|
||||
/// See [shm_pipe()].
|
||||
pub type LocalShmPipeWriter = ConcurrentPipeWriter<ShmPipeCore<Arc<ShmPipeVariables>>>;
|
||||
|
||||
/// Creates a process-local shared-memory pipe.
|
||||
///
|
||||
/// See [ConcurrentPipeWriter]/[ConcurrentPipeReader]. The types [LocalShmPipeWriter]/[LocalShmPipeReader] are just aliases for these.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Mind the comments in the safety section of [super::super::SharedMemorySegment]; the issues
|
||||
/// described in there *exactly* affect this implementation.
|
||||
pub fn shm_pipe(len: usize) -> anyhow::Result<(LocalShmPipeWriter, LocalShmPipeReader)> {
|
||||
let shared = Arc::new(ShmPipeVariables::new());
|
||||
let (seg_fd, seg_buf_1) = SharedMemorySegment::create(len)?;
|
||||
let seg_buf_2 = unsafe { SharedMemorySegment::from_fd(seg_fd, len) }?;
|
||||
|
||||
let writer = ShmPipeCore::new(shared.clone(), seg_buf_1);
|
||||
let reader = ShmPipeCore::new(shared, seg_buf_2);
|
||||
|
||||
Ok((
|
||||
ConcurrentPipeWriter::from_core(writer),
|
||||
ConcurrentPipeReader::from_core(reader),
|
||||
))
|
||||
}
|
||||
78
util/src/ipc/shm/ringbuf/main.rs
Normal file
78
util/src/ipc/shm/ringbuf/main.rs
Normal file
@@ -0,0 +1,78 @@
|
||||
//! Shared-memory ring buffer implementations (main part of the enclosing module)
|
||||
|
||||
use std::{borrow::Borrow, sync::atomic::AtomicU64};
|
||||
|
||||
use crate::{
|
||||
ipc::shm::SharedMemorySegment,
|
||||
ringbuf::concurrent::framework::{
|
||||
ConcurrentPipeCore, ConcurrentPipeReader, ConcurrentPipeWriter,
|
||||
},
|
||||
};
|
||||
|
||||
/// Synchronization variables for a [ShmPipeWriter]/[ShmPipeReader].
|
||||
///
|
||||
/// These values must be shared between the reader/writer in such a way that access to the inner
|
||||
/// variables is synchronized and atomic between the two parties.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ShmPipeVariables {
|
||||
/// See [crate::ringbuf::sched::RingBufferScheduler::items_read()]
|
||||
pub items_read: AtomicU64,
|
||||
/// See [crate::ringbuf::sched::RingBufferScheduler::items_written()]
|
||||
pub items_written: AtomicU64,
|
||||
}
|
||||
|
||||
impl ShmPipeVariables {
|
||||
/// Constructor
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
items_read: 0.into(),
|
||||
items_written: 0.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The [ConcurrentPipeCore] for a shared memory pipe.
|
||||
#[derive(Debug)]
|
||||
pub struct ShmPipeCore<Variables: Borrow<ShmPipeVariables>> {
|
||||
/// The synchronization variables
|
||||
variables: Variables,
|
||||
/// The memory buffer
|
||||
buf: SharedMemorySegment,
|
||||
}
|
||||
|
||||
impl<Variables: Borrow<ShmPipeVariables>> ShmPipeCore<Variables> {
|
||||
/// Constructor
|
||||
pub fn new(variables: Variables, buf: SharedMemorySegment) -> Self {
|
||||
Self { variables, buf }
|
||||
}
|
||||
}
|
||||
|
||||
/// A shared memory pipe reader
|
||||
pub type ShmPipeReader<Variables> = ConcurrentPipeReader<ShmPipeCore<Variables>>;
|
||||
|
||||
/// A shared memory pipe writer
|
||||
pub type ShmPipeWriter<Variables> = ConcurrentPipeWriter<ShmPipeCore<Variables>>;
|
||||
|
||||
impl<Variables: Borrow<ShmPipeVariables>> ConcurrentPipeCore for ShmPipeCore<Variables> {
|
||||
type AtomicType = AtomicU64;
|
||||
|
||||
fn buf_len(&self) -> u64 {
|
||||
self.buf.len() as u64
|
||||
}
|
||||
|
||||
fn items_read(&self) -> &AtomicU64 {
|
||||
&self.variables.borrow().items_read
|
||||
}
|
||||
|
||||
fn items_written(&self) -> &AtomicU64 {
|
||||
&self.variables.borrow().items_written
|
||||
}
|
||||
|
||||
fn read_from_buffer(&mut self, dst: &mut [u8], off: u64) {
|
||||
self.buf.volatile_read(dst, off as usize)
|
||||
}
|
||||
|
||||
fn write_to_buffer(&mut self, off: u64, src: &[u8]) {
|
||||
self.buf.volatile_write(off as usize, src);
|
||||
}
|
||||
}
|
||||
7
util/src/ipc/shm/ringbuf/mod.rs
Normal file
7
util/src/ipc/shm/ringbuf/mod.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
//! Shared-memory ring buffers
|
||||
|
||||
mod main;
|
||||
pub use main::*;
|
||||
|
||||
mod local;
|
||||
pub use local::*;
|
||||
367
util/src/ipc/shm/shared_memory_segment.rs
Normal file
367
util/src/ipc/shm/shared_memory_segment.rs
Normal file
@@ -0,0 +1,367 @@
|
||||
//! Accessing data in a shared memory segment
|
||||
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
os::fd::{AsFd, OwnedFd},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
int::u64uint::usize_to_u64,
|
||||
ptr::{ReadMemVolatile, WriteMemVolatile},
|
||||
result::OkExt,
|
||||
secret_memory::{
|
||||
fd::SecretMemfdConfig,
|
||||
mmap::{MMapError, MapFdConfig, MappedSegment},
|
||||
},
|
||||
};
|
||||
|
||||
/// Safe creation of shared memory segments
|
||||
///
|
||||
/// This is a slightly more convenient API than using [MapFdConfig] directly.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct SharedMemorySegmentBuilder {
|
||||
/// Configuration for allocating memory file descriptors and their secrecy level
|
||||
pub secret_memfd_cfg: SecretMemfdConfig,
|
||||
/// Configuration for mapping file descriptors into memory
|
||||
pub map_fd_cfg: MapFdConfig,
|
||||
}
|
||||
|
||||
impl SharedMemorySegmentBuilder {
|
||||
/// Create a new secret memory segment builder.
|
||||
///
|
||||
/// Note that this always sets [MapFdConfig::set_shared()], but you can overwrite this behavior
|
||||
/// by un-setting the flag again.
|
||||
///
|
||||
/// # Safety & panic
|
||||
///
|
||||
/// This function can panic when [u64] is unable to represent the given size
|
||||
/// value. This might be the case in future computers impressing me with their excessive size.
|
||||
pub const fn new(len: usize) -> Self {
|
||||
let secret_memfd_cfg = SecretMemfdConfig::new();
|
||||
let map_fd_cfg = MapFdConfig::new()
|
||||
.set_shared()
|
||||
.resize_on_mmap(usize_to_u64(len));
|
||||
Self {
|
||||
secret_memfd_cfg,
|
||||
map_fd_cfg,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a secret memory segment using the configuration stored here
|
||||
pub fn create_segment(&self) -> anyhow::Result<(OwnedFd, SharedMemorySegment)> {
|
||||
let fd = self.secret_memfd_cfg.create()?;
|
||||
let seg = self.map_fd_cfg.mappable_fd(&fd).mmap()?;
|
||||
let seg = unsafe { SharedMemorySegment::from_mapped_segment(seg) };
|
||||
|
||||
Ok((fd, seg))
|
||||
}
|
||||
}
|
||||
|
||||
/// Safe creation of and access to shared memory segments
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Any means to create a [Self] must guarantee, that further calls to [Self::volatile_write] and
|
||||
/// [Self::volatile_read] are also safe.
|
||||
///
|
||||
/// The API of this struct is specifically designed so creating multiple memory mappings of the same
|
||||
/// shared memory segment is impossible without unsafe code.
|
||||
///
|
||||
/// We recognize that the sole purpose of shared memory segments is that multiple mappings of them
|
||||
/// are created, there just is no way to do so using safe rust.
|
||||
///
|
||||
/// The reason for this is – that by the Rust documentation – any concurrent memory access leads to
|
||||
/// undefined behavior, even volatile accesses caused solely by an adversarial application on the
|
||||
/// other end of a shared memory communication channel.
|
||||
///
|
||||
/// For this reason, we force users to use unsafe code to create shared memory mappings from an
|
||||
/// existing file descriptor, as this could potentially lead to adversarial data races (and thus to
|
||||
/// undefined behavior).
|
||||
///
|
||||
/// In practice, we believe using concurrent memory access using volatile operations is going to
|
||||
/// lead to nothing worse than garbled data being transferred. The user should treat any data
|
||||
/// received through a shared-memory ring buffer as untrusted and validate this data any way, so
|
||||
/// garbled data should be caught.
|
||||
///
|
||||
/// This means that [Self::from_fd()] is unsafe in theory, but most likely safe in practice.
|
||||
///
|
||||
/// ## Concurrent, untrusted shared memory is technically undefined behavior
|
||||
///
|
||||
/// Its even worse than having to use unsafe: technically speaking, it may be impossible to
|
||||
/// use shared memory soundly in rust unless all parties with access to the segment are
|
||||
/// *trusted*. If these parties are not trusted (or buggy) they can always cause undefined
|
||||
/// behavior:
|
||||
///
|
||||
/// From the [std::sync::atomic] documentation:
|
||||
///
|
||||
/// > **The most important aspect of this model is that data races are undefined behavior.** A data race
|
||||
/// > is defined as conflicting non-synchronized accesses where at least one of the accesses is non-atomic.
|
||||
/// > Here, accesses are conflicting if they affect overlapping regions of memory and at least one of them
|
||||
/// > is a write. (A compare_exchange or compare_exchange_weak that does not succeed is not considered a
|
||||
/// > write.) They are non-synchronized if neither of them happens-before the other, according to the
|
||||
/// > happens-before order of the memory model.
|
||||
///
|
||||
/// The fact that this API uses volatile [reads](Self::volatile_read) and [writes](Self::volatile_write),
|
||||
/// and the the fact that we use mmap(2) for allocation does not mitigate this issue; from the
|
||||
/// documentation of [std::ptr::write_volatile()]:
|
||||
///
|
||||
/// > When a volatile operation is used for memory inside an allocation, it behaves exactly like write,
|
||||
/// > except for the additional guarantee that it won’t be elided or reordered (see above). This implies
|
||||
/// > that the operation will actually access memory and not e.g. be lowered to a register access. Other
|
||||
/// > than that, all the usual rules for memory accesses apply (including provenance). In particular, just
|
||||
/// > like in C, whether an operation is volatile has no bearing whatsoever on questions involving concurrent
|
||||
/// > access from multiple threads. Volatile accesses behave exactly like non-atomic accesses in that regard.
|
||||
///
|
||||
/// An allocation is defined as follows (taken from [std::ptr]):
|
||||
///
|
||||
/// > An allocation is a subset of program memory which is addressable from Rust, and within which pointer
|
||||
/// > arithmetic is possible. Examples of allocations include heap allocations, stack-allocated variables,
|
||||
/// > statics, and consts. The safety preconditions of some Rust operations - such as offset and field
|
||||
/// > projections (expr.field) - are defined in terms of the allocations on which they operate.
|
||||
///
|
||||
/// This definition clearly applies to mmap(2) allocated regions.
|
||||
///
|
||||
/// What might mitigate this issue is mapping the region just once per process:
|
||||
///
|
||||
/// > In particular, just
|
||||
/// > like in C, whether an operation is volatile has no bearing whatsoever on questions involving concurrent
|
||||
/// > access from **multiple threads**.
|
||||
///
|
||||
/// We could argue that a process is not a thread, and thus concurrent access from two processes is
|
||||
/// fine, but concurrent access from two threads is not (unless guarded by an atomic value or a
|
||||
/// mutex or some primitive actually designed for synchronization).
|
||||
///
|
||||
/// There is no wording in the spec explicitly allowing raceful, concurrent access from multiple processes.
|
||||
///
|
||||
/// The problem with basing our safety-argument on the claim that "processes are not threads" is
|
||||
/// that the line between processes and threads is drawn in the sand. For linux, read the man page
|
||||
/// of clone(2):
|
||||
///
|
||||
/// > By contrast with fork(2), these [clone, __clone2, clone3] system calls provide more precise control over what pieces of execution
|
||||
/// > context are shared between the calling process and the child process. For example, using these system
|
||||
/// > calls, the caller can control whether or not the two processes share the virtual address space, the ta‐
|
||||
/// > ble of file descriptors, and the table of signal handlers. These system calls also allow the new child
|
||||
/// > process to be placed in separate namespaces(7).
|
||||
/// >
|
||||
/// > […]
|
||||
/// >
|
||||
/// > ## CLONE_THREAD (since Linux 2.4.0)
|
||||
/// >
|
||||
/// > If CLONE_THREAD is set, the child is placed in the same thread group as the calling process. To
|
||||
/// > make the remainder of the discussion of CLONE_THREAD more readable, the term "thread" is used to
|
||||
/// > refer to the processes within a thread group.
|
||||
///
|
||||
/// According to the man page, "the term 'thread' is used to refer to the processes within a thread group.".
|
||||
///
|
||||
/// The Rust (transitively, from the C++11 Atomic) specification tells us that there must be no
|
||||
/// concurrent memory access between threads whether this access is volatile or not. The linux man
|
||||
/// pages tell us that "thread" is just a special type of "process".
|
||||
///
|
||||
/// **The most robust interpretation of these specifications is that shared memory must not be used for
|
||||
/// communication with an untrusted party across thread or process boundaries, or else the other
|
||||
/// process/thread can cause undefined behavior in our process.**
|
||||
///
|
||||
/// ## In practice
|
||||
///
|
||||
/// Realistically, using volatile reads/writes on valid, mapped memory might cause garbled values in
|
||||
/// case of a data race, but it should crash the program or do anything worse than create garbled
|
||||
/// values.
|
||||
///
|
||||
/// Mind that we do not mind garbled values here; we are implementing a shared memory communication
|
||||
/// interface, so our application must always assume, that the data it receives may be garbled. It
|
||||
/// has to be validated. We just don't want the other application to be able to do anything worse
|
||||
/// that garble the data it is sending (or receiving), so lets estimate what can *realistically*
|
||||
/// happen here if the other application maliciously causes a race.
|
||||
///
|
||||
/// The worst any of the assembly sequences below should do is cause tearing in case of a data
|
||||
/// race.
|
||||
///
|
||||
/// This leads me to the conclusion that what what we are dealing here with is not an
|
||||
/// implementation that is faulty/insecure, instead it is a definition-gap in the compiler
|
||||
/// semantics for volatile memory access for use in security-critical applications.
|
||||
///
|
||||
/// Godbolt link: https://rust.godbolt.org/z/GGjsGsc33
|
||||
/// Compiler: `rustc 1.90.0`
|
||||
///
|
||||
/// Rust code:
|
||||
///
|
||||
/// ```rust
|
||||
/// #[unsafe(no_mangle)]
|
||||
/// pub fn read_volatile(num: &[u128]) -> u128 {
|
||||
/// let ptr = num.as_ptr();
|
||||
/// unsafe { ptr.read_volatile() }
|
||||
/// }
|
||||
///
|
||||
/// #[unsafe(no_mangle)]
|
||||
/// pub fn write_volatile(num: &mut [u128]) {
|
||||
/// let ptr = num.as_mut_ptr();
|
||||
/// unsafe { ptr.write_volatile(42u128) };
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// x86_64 (`--target=x86_64-unknown-linux-gnu -O`):
|
||||
///
|
||||
/// ```asm
|
||||
/// read_volatile:
|
||||
/// mov rax, qword ptr [rdi]
|
||||
/// mov rdx, qword ptr [rdi + 8]
|
||||
/// ret
|
||||
///
|
||||
/// write_volatile:
|
||||
/// mov qword ptr [rdi + 8], 0
|
||||
/// mov qword ptr [rdi], 42
|
||||
/// ret
|
||||
/// ```
|
||||
///
|
||||
/// arm64 (`--target=aarch64-unknown-linux-gnu -O`):
|
||||
///
|
||||
/// ```asm
|
||||
/// read_volatile:
|
||||
/// ldp x0, x1, [x0]
|
||||
/// ret
|
||||
///
|
||||
/// write_volatile:
|
||||
/// mov w8, #42
|
||||
/// stp x8, xzr, [x0]
|
||||
/// ret
|
||||
/// ```
|
||||
///
|
||||
/// armv7 (`--target=armv7-unknown-linux-gnueabihf -O`)
|
||||
///
|
||||
/// ```
|
||||
/// read_volatile:
|
||||
/// push {r4, r5, r11, lr}
|
||||
/// ldrd r2, r3, [r1]
|
||||
/// ldrd r4, r5, [r1, #8]
|
||||
/// stm r0, {r2, r3, r4, r5}
|
||||
/// pop {r4, r5, r11, pc}
|
||||
///
|
||||
/// write_volatile:
|
||||
/// push {r4, r5, r11, lr}
|
||||
/// mov r2, #0
|
||||
/// mov r4, #42
|
||||
/// mov r3, r2
|
||||
/// mov r5, r2
|
||||
/// strd r2, r3, [r0, #8]
|
||||
/// strd r4, r5, [r0]
|
||||
/// pop {r4, r5, r11, pc}
|
||||
/// ```
|
||||
///
|
||||
/// risc64 (`--target=riscv64gc-unknown-linux-gnu`):
|
||||
///
|
||||
/// ```
|
||||
/// read_volatile:
|
||||
/// ld a1, 8(a0)
|
||||
/// ld a0, 0(a0)
|
||||
/// ret
|
||||
///
|
||||
/// write_volatile:
|
||||
/// sd zero, 8(a0)
|
||||
/// li a1, 42
|
||||
/// sd a1, 0(a0)
|
||||
/// ret
|
||||
/// ```
|
||||
///
|
||||
#[derive(Debug)]
|
||||
pub struct SharedMemorySegment {
|
||||
/// The underlying mapped segment
|
||||
inner: MappedSegment,
|
||||
}
|
||||
|
||||
impl SharedMemorySegment {
|
||||
/// Create a new shared memory segment.
|
||||
pub fn create(len: usize) -> anyhow::Result<(OwnedFd, Self)> {
|
||||
SharedMemorySegmentBuilder::new(len).create_segment()
|
||||
}
|
||||
|
||||
/// Create a shared memory segment from a file descriptor
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// See the comments in [Self].
|
||||
pub unsafe fn from_fd<Fd: AsFd>(fd: Fd, size: usize) -> Result<Self, MMapError> {
|
||||
let cfg = MapFdConfig::new()
|
||||
.set_shared()
|
||||
.expected_size(usize_to_u64(size));
|
||||
unsafe { Self::from_fd_with_config(fd, cfg) }
|
||||
}
|
||||
|
||||
/// Create a shared memory segment from a file descriptor
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// See the comments in [Self].
|
||||
pub unsafe fn from_fd_with_config<Fd: AsFd>(
|
||||
fd: Fd,
|
||||
cfg: MapFdConfig,
|
||||
) -> Result<Self, MMapError> {
|
||||
let segment = cfg.mappable_fd(&fd).mmap()?;
|
||||
unsafe { Self::from_mapped_segment(segment).ok() }
|
||||
}
|
||||
|
||||
|
||||
/// Create a shared memory segment from an existing mapped segment
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// See the comments in [Self].
|
||||
pub unsafe fn from_mapped_segment(inner: MappedSegment) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
/// The underlying mapped segment
|
||||
pub fn mapped_segment(&self) -> &MappedSegment {
|
||||
self.inner.borrow()
|
||||
}
|
||||
|
||||
/// A pointer to the underlying mapped segment
|
||||
pub fn ptr(&self) -> *mut u8 {
|
||||
self.mapped_segment().ptr()
|
||||
}
|
||||
|
||||
/// The length of the underlying mapped segment
|
||||
pub fn len(&self) -> usize {
|
||||
self.mapped_segment().len()
|
||||
}
|
||||
|
||||
/// Whether `self.`[len()](Self::len)` == 0`
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
/// Read data from the ring buffer
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// See comments in [Self].
|
||||
pub fn volatile_read(&self, dst: &mut [u8], off: usize) {
|
||||
let end = off + dst.len();
|
||||
assert!(end <= self.len());
|
||||
|
||||
unsafe { self.ptr().add(off).read_mem_volatile(dst) }
|
||||
}
|
||||
|
||||
/// Read data from the ring buffer
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// See comments in [Self].
|
||||
pub fn volatile_write(&self, off: usize, src: &[u8]) {
|
||||
let end = off + src.len();
|
||||
assert!(end <= self.len());
|
||||
|
||||
unsafe { self.ptr().add(off).write_mem_volatile(src) }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SharedMemorySegment> for MappedSegment {
|
||||
fn from(val: SharedMemorySegment) -> Self {
|
||||
val.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl Borrow<MappedSegment> for SharedMemorySegment {
|
||||
fn borrow(&self) -> &MappedSegment {
|
||||
self.mapped_segment()
|
||||
}
|
||||
}
|
||||
@@ -10,15 +10,16 @@ pub mod b64;
|
||||
pub mod build;
|
||||
/// Control flow abstractions and utilities.
|
||||
pub mod controlflow;
|
||||
/// File descriptor utilities.
|
||||
pub mod fd;
|
||||
pub mod convert;
|
||||
/// File system operations and handling.
|
||||
pub mod file;
|
||||
pub mod fmt;
|
||||
/// Functional programming utilities.
|
||||
pub mod functional;
|
||||
pub mod int;
|
||||
/// Input/output operations.
|
||||
pub mod io;
|
||||
pub mod ipc;
|
||||
/// Length prefix encoding schemes implementation.
|
||||
pub mod length_prefix_encoding;
|
||||
/// Memory manipulation and allocation utilities.
|
||||
@@ -27,8 +28,13 @@ pub mod mem;
|
||||
pub mod mio;
|
||||
/// Extended Option type functionality.
|
||||
pub mod option;
|
||||
pub mod ptr;
|
||||
/// Extended Result type functionality.
|
||||
pub mod result;
|
||||
pub mod ringbuf;
|
||||
pub mod rustix;
|
||||
pub mod secret_memory;
|
||||
pub mod sync;
|
||||
/// Time and duration utilities.
|
||||
pub mod time;
|
||||
#[cfg(feature = "tokio")]
|
||||
|
||||
102
util/src/mem.rs
102
util/src/mem.rs
@@ -280,6 +280,108 @@ impl<T: Sized> MoveExt for T {
|
||||
}
|
||||
}
|
||||
|
||||
/// Explicitly copy a value
|
||||
///
|
||||
/// This is sometimes useful as an alternative to the deref operator
|
||||
/// (`x.copy()` instead of `*x`) to achieve neater method chains.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use rosenpass_util::mem::CopyExt;
|
||||
/// use rosenpass_util::result::OkExt;
|
||||
///
|
||||
/// // Without CopyExt
|
||||
/// fn boilerplate_1<T: Copy>(v: &T) -> Result<T, ()> {
|
||||
/// (*v).ok()
|
||||
/// }
|
||||
///
|
||||
/// // With CopyExt
|
||||
/// fn boilerplate_2<T: Copy>(v: &T) -> Result<T, ()> {
|
||||
/// v.copy().ok()
|
||||
/// }
|
||||
///
|
||||
/// assert_eq!(boilerplate_1(&32), Ok(32));
|
||||
/// assert_eq!(boilerplate_1(&32), boilerplate_2(&32));
|
||||
/// ```
|
||||
pub trait CopyExt: Copy {
|
||||
/// Copy a value
|
||||
fn copy(&self) -> Self {
|
||||
*self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Copy> CopyExt for T {}
|
||||
|
||||
/// Helper trait for applying a mutating closure/function to a mutable reference
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// fn plus_two_mul_three(v: u64) -> u64 {
|
||||
/// (v + 2) * 3
|
||||
/// }
|
||||
///
|
||||
/// fn inconvenient_example(v: &mut u64) {
|
||||
/// *v = plus_two_mul_three(*v);
|
||||
/// }
|
||||
///
|
||||
/// fn convenient_example(v: &mut u64) {
|
||||
/// v.mutate(plus_two_mul_three);
|
||||
/// }
|
||||
///
|
||||
/// let mut x = 0;
|
||||
/// inconvenient_example(&mut x);
|
||||
/// assert_eq!(x, 6);
|
||||
/// convenient_example(&mut x);
|
||||
/// assert_eq!(x, 24);
|
||||
///
|
||||
/// x = 0;
|
||||
///
|
||||
/// x.mutate(plus_two_mul_three);
|
||||
/// assert_eq!(x, 6);
|
||||
/// x.mutate(|v| (v + 2) * 3);
|
||||
/// assert_eq!(x, 24);
|
||||
///
|
||||
/// x = 0;
|
||||
///
|
||||
/// let y = &mut x;
|
||||
/// y.mutate(plus_two_mul_three);
|
||||
/// assert_eq!(x, 6);
|
||||
///
|
||||
/// struct Cont {
|
||||
/// x: u64,
|
||||
/// }
|
||||
///
|
||||
/// impl Cont {
|
||||
/// fn x_mut(&mut self) -> &mut u64 {
|
||||
/// &mut self.x
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// let mut s = Cont { x: 0 };
|
||||
/// s.x_mut().mutate(|v| (v + 2) * 3);
|
||||
/// ```
|
||||
pub trait MutateRefExt: DerefMut
|
||||
where
|
||||
Self::Target: Sized,
|
||||
{
|
||||
/// Directly mutate a reference
|
||||
fn mutate<F: FnOnce(Self::Target) -> Self::Target>(self, f: F) -> Self;
|
||||
}
|
||||
|
||||
impl<T> MutateRefExt for T
|
||||
where
|
||||
T: DerefMut,
|
||||
<T as Deref>::Target: Copy,
|
||||
{
|
||||
fn mutate<F: FnOnce(Self::Target) -> Self::Target>(mut self, f: F) -> Self {
|
||||
let v = self.deref_mut();
|
||||
*v = f(*v);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_forgetting {
|
||||
use crate::mem::Forgetting;
|
||||
|
||||
@@ -2,8 +2,8 @@ use mio::net::{UnixListener, UnixStream};
|
||||
use std::os::fd::{OwnedFd, RawFd};
|
||||
|
||||
use crate::{
|
||||
fd::{claim_fd, claim_fd_inplace},
|
||||
result::OkExt,
|
||||
rustix::{claim_fd, claim_fd_inplace},
|
||||
};
|
||||
|
||||
/// Module containing I/O interest flags for Unix operations (see also: [mio::Interest])
|
||||
@@ -93,7 +93,7 @@ impl UnixStreamExt for UnixStream {
|
||||
fn from_fd(fd: OwnedFd) -> anyhow::Result<Self> {
|
||||
use std::os::unix::net::UnixStream as StdUnixStream;
|
||||
#[cfg(target_os = "linux")] // TODO: We should support this on other plattforms
|
||||
crate::fd::GetUnixSocketType::demand_unix_stream_socket(&fd)?;
|
||||
crate::rustix::GetUnixSocketType::demand_unix_stream_socket(&fd)?;
|
||||
let sock = StdUnixStream::from(fd);
|
||||
sock.set_nonblocking(true)?;
|
||||
UnixStream::from_std(sock).ok()
|
||||
|
||||
@@ -7,7 +7,7 @@ use std::{
|
||||
};
|
||||
use uds::UnixStreamExt as FdPassingExt;
|
||||
|
||||
use crate::fd::{claim_fd_inplace, IntoStdioErr};
|
||||
use crate::rustix::{claim_fd_inplace, IntoStdioErr};
|
||||
|
||||
/// A wrapper around a socket that combines reading from the socket with tracking
|
||||
/// received file descriptors. Limits the maximum number of file descriptors that
|
||||
|
||||
4
util/src/ptr/mod.rs
Normal file
4
util/src/ptr/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
//! Utilities for working with pointers
|
||||
|
||||
mod volatile;
|
||||
pub use volatile::*;
|
||||
51
util/src/ptr/volatile.rs
Normal file
51
util/src/ptr/volatile.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
//! Utilities relating to volatile reads/writes on pointers
|
||||
|
||||
/// Read from a memory location using
|
||||
/// [pointer::read_volatile] in a loop and store the
|
||||
/// results in the given slice
|
||||
pub trait ReadMemVolatile<T> {
|
||||
/// Read from a memory location using
|
||||
/// [pointer::read_volatile] in a loop and store the
|
||||
/// results in an array
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Refer to [pointer::read_volatile]
|
||||
unsafe fn read_mem_volatile(self, dst: &mut [T]);
|
||||
}
|
||||
|
||||
impl<T> ReadMemVolatile<T> for *const T {
|
||||
unsafe fn read_mem_volatile(self, dst: &mut [T]) {
|
||||
for (idx, dst) in dst.iter_mut().enumerate() {
|
||||
*dst = unsafe { self.add(idx).read_volatile() };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ReadMemVolatile<T> for *mut T {
|
||||
unsafe fn read_mem_volatile(self, dst: &mut [T]) {
|
||||
unsafe { self.cast_const().read_mem_volatile(dst) }
|
||||
}
|
||||
}
|
||||
|
||||
/// Write to a memory location using
|
||||
/// [pointer::write_volatile] in a loop
|
||||
/// and store the resulting values in the given slice
|
||||
pub trait WriteMemVolatile<T>: ReadMemVolatile<T> {
|
||||
/// Write to a memory location using
|
||||
/// [pointer::write_volatile] in a loop
|
||||
/// and store the resulting values in the given slice
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Refer to [pointer::write_volatile]
|
||||
unsafe fn write_mem_volatile(self, src: &[T]);
|
||||
}
|
||||
|
||||
impl<T: Copy> WriteMemVolatile<T> for *mut T {
|
||||
unsafe fn write_mem_volatile(self, src: &[T]) {
|
||||
for (idx, src) in src.iter().enumerate() {
|
||||
unsafe { self.add(idx).write_volatile(*src) }
|
||||
}
|
||||
}
|
||||
}
|
||||
3
util/src/ringbuf/concurrend/mod.rs
Normal file
3
util/src/ringbuf/concurrend/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
//! Concurrent ring buffers
|
||||
|
||||
pub mod framework;
|
||||
202
util/src/ringbuf/concurrent/framework.rs
Normal file
202
util/src/ringbuf/concurrent/framework.rs
Normal file
@@ -0,0 +1,202 @@
|
||||
//! An implementation of a concurrent ring buffer with the
|
||||
//! core (IO/memory access) operations in a detached trait. Everything around the core is
|
||||
//! implemented but before being able to use this, callers must implement an appropriate IO core.
|
||||
|
||||
use crate::int::u64uint::U64USizeRangeExt;
|
||||
use crate::result::OkExt;
|
||||
use crate::ringbuf::sched::{
|
||||
Diff, OperationType, RingBufferFromCountersError, RingBufferScheduler, ScheduledOperations,
|
||||
};
|
||||
use crate::sync::atomic::abstract_atomic::AbstractAtomic;
|
||||
|
||||
pub trait ConcurrentPipeCore {
|
||||
type AtomicType: AbstractAtomic<u64>;
|
||||
|
||||
fn buf_len(&self) -> u64;
|
||||
fn items_read(&self) -> &Self::AtomicType;
|
||||
fn items_written(&self) -> &Self::AtomicType;
|
||||
fn read_from_buffer(&mut self, dst: &mut [u8], off: u64);
|
||||
fn write_to_buffer(&mut self, off: u64, src: &[u8]);
|
||||
}
|
||||
|
||||
/// Raised by [ConcurrentPipeReader::read()]/[ConcurrentPipeWriter::write()] if the underlying
|
||||
/// ring buffer is in an inconsistent state.
|
||||
///
|
||||
/// When used in shared memory applications, this error may have been locally caused, but it may
|
||||
/// also be caused by **the other threads or processes** putting the ring buffer into an
|
||||
/// inconsistent state.
|
||||
///
|
||||
/// For this reason, you must treat this error as an unrecoverable breakdown of the communication channel. You
|
||||
/// must close and no longer use ring buffer after receiving this error, but you can handle it
|
||||
/// gracefully by reopening a new ring buffer in its place.
|
||||
///
|
||||
/// # API Stability
|
||||
///
|
||||
/// Treat this error type as opaque; do not rely on the enum values remaining the same
|
||||
#[derive(Debug, thiserror::Error, Clone)]
|
||||
pub enum InconsistentRingBufferStateError {
|
||||
/// Failed to update an inner counter; this probably indicates a data race (forbidden
|
||||
/// concurrent read/write)
|
||||
#[error("Concurrent access to the ring buffer {:?}", self)]
|
||||
ConcurrrentAccess {
|
||||
/// Whether this was a read or write
|
||||
operation_type: OperationType,
|
||||
/// Operations performed before updaing the ring buffer state
|
||||
scheduled_ops: ScheduledOperations,
|
||||
/// The scheduler that scheduled our operation
|
||||
scheduler_state: RingBufferScheduler,
|
||||
/// The counter value before compare and exchange
|
||||
expected_counter_value: u64,
|
||||
/// The counter value we actually found in the counter
|
||||
actual_counter_value: u64,
|
||||
/// The counter value we tried to set
|
||||
new_counter_value_tried_to_set: u64,
|
||||
},
|
||||
/// Could not construct ring buffer scheduler
|
||||
#[error("Inconsistent ring buffer state: {:?}", .0)]
|
||||
InconsistentCounterState(#[from] RingBufferFromCountersError),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ConcurrentPipeOperation<'a> {
|
||||
Read(&'a mut [u8]),
|
||||
Write(&'a [u8]),
|
||||
}
|
||||
|
||||
impl<'a> ConcurrentPipeOperation<'a> {
|
||||
pub fn inner_buf(&'a self) -> &'a [u8] {
|
||||
match self {
|
||||
ConcurrentPipeOperation::Read(items) => items,
|
||||
ConcurrentPipeOperation::Write(items) => items,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.inner_buf().len()
|
||||
}
|
||||
|
||||
pub fn scheduler_op(&self) -> OperationType {
|
||||
match self {
|
||||
ConcurrentPipeOperation::Read(_) => OperationType::Read,
|
||||
ConcurrentPipeOperation::Write(_) => OperationType::Write,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ConcurrentPipeImpl<Core: ConcurrentPipeCore> {
|
||||
core: Core,
|
||||
}
|
||||
|
||||
impl<Core: ConcurrentPipeCore> ConcurrentPipeImpl<Core> {
|
||||
fn from_core(core: Core) -> Self {
|
||||
Self { core }
|
||||
}
|
||||
|
||||
fn read_or_write(
|
||||
&mut self,
|
||||
mut op: ConcurrentPipeOperation,
|
||||
) -> Result<usize, InconsistentRingBufferStateError> {
|
||||
use std::sync::atomic::Ordering as O;
|
||||
|
||||
// Figure out which counter to store the result of the operation in and the orderings to
|
||||
// use for load/store operations
|
||||
let (ord_r, ord_w, ord_store_succ, ord_store_fail) = match op {
|
||||
ConcurrentPipeOperation::Read(_) => (O::Relaxed, O::Acquire, O::Relaxed, O::Relaxed),
|
||||
ConcurrentPipeOperation::Write(_) => (O::Relaxed, O::Relaxed, O::Release, O::Relaxed),
|
||||
};
|
||||
|
||||
// Construct a ring buffer scheduler from the current state
|
||||
let sched = RingBufferScheduler::try_from_counters(
|
||||
self.core.buf_len(),
|
||||
self.core.items_written().load(ord_w),
|
||||
self.core.items_read().load(ord_r),
|
||||
)?;
|
||||
|
||||
// Have the scheduler schedule the operations
|
||||
let ops = sched.schedule_contigous_operations(op.scheduler_op(), op.len());
|
||||
|
||||
// Actually perform the operations
|
||||
for (buf_slice, ring_op) in ops.with_outside_buffer_range() {
|
||||
match &mut op {
|
||||
ConcurrentPipeOperation::Read(dst) => self
|
||||
.core
|
||||
.read_from_buffer(&mut dst[buf_slice.usize()], ring_op.off),
|
||||
ConcurrentPipeOperation::Write(src) => self
|
||||
.core
|
||||
.write_to_buffer(ring_op.off, &src[buf_slice.usize()]),
|
||||
}
|
||||
}
|
||||
|
||||
// Take a differential between the scheduler and the scheduler after the operations where
|
||||
// applied. Make sure only one counter was updated
|
||||
let diff = sched
|
||||
.register_operation(&ops)
|
||||
.diff_old(&sched)
|
||||
.expect_op_only(op.scheduler_op())
|
||||
.unwrap();
|
||||
|
||||
let store_to_ctr = match op {
|
||||
ConcurrentPipeOperation::Read(_) => self.core.items_read(),
|
||||
ConcurrentPipeOperation::Write(_) => self.core.items_written(),
|
||||
};
|
||||
|
||||
// Update the counters, assuming there was any change at all
|
||||
if let Diff::Different(old, new) = diff {
|
||||
store_to_ctr
|
||||
.compare_exchange(old, new, ord_store_succ, ord_store_fail)
|
||||
.map_err(
|
||||
|actual| InconsistentRingBufferStateError::ConcurrrentAccess {
|
||||
operation_type: op.scheduler_op(),
|
||||
scheduled_ops: ops,
|
||||
scheduler_state: sched,
|
||||
expected_counter_value: old,
|
||||
actual_counter_value: actual,
|
||||
new_counter_value_tried_to_set: new,
|
||||
},
|
||||
)?;
|
||||
}
|
||||
|
||||
ops.cumulative_operation_length().usize().ok()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ConcurrentPipeReader<Core: ConcurrentPipeCore> {
|
||||
inner: ConcurrentPipeImpl<Core>,
|
||||
}
|
||||
|
||||
impl<Core: ConcurrentPipeCore> ConcurrentPipeWriter<Core> {
|
||||
pub fn from_core(core: Core) -> Self {
|
||||
Self {
|
||||
inner: ConcurrentPipeImpl::from_core(core),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn buf_len(&self) -> u64 {
|
||||
self.inner.core.buf_len()
|
||||
}
|
||||
|
||||
pub fn write(&mut self, src: &[u8]) -> Result<usize, InconsistentRingBufferStateError> {
|
||||
self.inner
|
||||
.read_or_write(ConcurrentPipeOperation::Write(src))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ConcurrentPipeWriter<Core: ConcurrentPipeCore> {
|
||||
inner: ConcurrentPipeImpl<Core>,
|
||||
}
|
||||
|
||||
impl<Core: ConcurrentPipeCore> ConcurrentPipeReader<Core> {
|
||||
pub fn from_core(core: Core) -> Self {
|
||||
Self {
|
||||
inner: ConcurrentPipeImpl::from_core(core),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn buf_len(&self) -> u64 {
|
||||
self.inner.core.buf_len()
|
||||
}
|
||||
|
||||
pub fn read(&mut self, dst: &mut [u8]) -> Result<usize, InconsistentRingBufferStateError> {
|
||||
self.inner.read_or_write(ConcurrentPipeOperation::Read(dst))
|
||||
}
|
||||
}
|
||||
3
util/src/ringbuf/concurrent/mod.rs
Normal file
3
util/src/ringbuf/concurrent/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
//! Concurrent ring buffers
|
||||
|
||||
pub mod framework;
|
||||
4
util/src/ringbuf/mod.rs
Normal file
4
util/src/ringbuf/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
//! Ring buffers and utilities for working with them
|
||||
|
||||
pub mod concurrent;
|
||||
pub mod sched;
|
||||
1376
util/src/ringbuf/sched.rs
Normal file
1376
util/src/ringbuf/sched.rs
Normal file
File diff suppressed because it is too large
Load Diff
117
util/src/rustix/error.rs
Normal file
117
util/src/rustix/error.rs
Normal file
@@ -0,0 +1,117 @@
|
||||
//! Rustix extensions for error handling
|
||||
|
||||
/// Provides access to the last system error number
|
||||
///
|
||||
/// > The integer variable errno is set by system calls and some library functions in the event of an error to indicate what went wrong.
|
||||
///
|
||||
/// -- `man 3 errno`
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This function panics if ther
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
///
|
||||
/// use rustix::io::Errno as E;
|
||||
/// use rosenpass_util::rustix::{errno, try_errno, last_os_result};
|
||||
///
|
||||
/// let res = unsafe { libc::mkdir(c"/tmp/baz".as_ptr(), 0) };
|
||||
/// assert_eq!(res, -1);
|
||||
/// assert_eq!(errno(), E::EXIST);
|
||||
/// assert_eq!(try_errno(), Some(E::EXIST));
|
||||
/// assert_eq!(last_os_result(), Err(E::EXIST));
|
||||
///
|
||||
/// // Deliberately clear the system error
|
||||
/// unsafe { libc::__errno_location().write(0) };
|
||||
/// // assert_eq!(errno(), _); // PANICS
|
||||
/// assert_eq!(try_errno(), None);
|
||||
/// assert_eq!(last_os_result(), Ok(()));
|
||||
/// ```
|
||||
///
|
||||
/// Calling errno() when there is no error causes a panic:
|
||||
///
|
||||
/// ```rust,should_panic
|
||||
///
|
||||
/// use rustix::io::Errno as E;
|
||||
/// use rosenpass_util::rustix::errno;
|
||||
///
|
||||
/// // Deliberately clear the system error
|
||||
/// unsafe { libc::__errno_location().write(0) };
|
||||
/// errno(); // PANICS
|
||||
/// ```
|
||||
pub fn errno() -> rustix::io::Errno {
|
||||
match try_errno() {
|
||||
None => panic!("Tried to retrieve last system error, but there was no system error (the system error number, errno = 0)"),
|
||||
Some(errno) => errno,
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides access to the last system error number.
|
||||
///
|
||||
/// Variant of [errno] that will return None if there was no system error.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// See [errno].
|
||||
pub fn try_errno() -> Option<rustix::io::Errno> {
|
||||
let raw = unsafe { libc::__errno_location().read() };
|
||||
match raw {
|
||||
0 => None,
|
||||
_ => Some(rustix::io::Errno::from_raw_os_error(raw)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides access to the last system error number.
|
||||
///
|
||||
/// Variant of [errno] that will return `Err(errno)` if there
|
||||
/// was a system error and `Ok(())` otherwise.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// See [errno].
|
||||
pub fn last_os_result() -> Result<(), rustix::io::Errno> {
|
||||
match try_errno() {
|
||||
None => Ok(()),
|
||||
Some(errno) => Err(errno),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert low level errors into std::io::Error
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use std::io::ErrorKind as EK;
|
||||
/// use rustix::io::Errno;
|
||||
/// use rosenpass_util::rustix::IntoStdioErr;
|
||||
///
|
||||
/// let e = Errno::INTR.into_stdio_err();
|
||||
/// assert!(matches!(e.kind(), EK::Interrupted));
|
||||
///
|
||||
/// let r : rustix::io::Result<()> = Err(Errno::INTR);
|
||||
/// assert!(matches!(r, Err(e) if e.kind() == EK::Interrupted));
|
||||
/// ```
|
||||
pub trait IntoStdioErr {
|
||||
/// Target type produced (e.g. std::io:Error or std::io::Result depending on context
|
||||
type Target;
|
||||
/// Convert low level errors to
|
||||
fn into_stdio_err(self) -> Self::Target;
|
||||
}
|
||||
|
||||
impl IntoStdioErr for rustix::io::Errno {
|
||||
type Target = std::io::Error;
|
||||
|
||||
fn into_stdio_err(self) -> Self::Target {
|
||||
std::io::Error::from_raw_os_error(self.raw_os_error())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> IntoStdioErr for rustix::io::Result<T> {
|
||||
type Target = std::io::Result<T>;
|
||||
|
||||
fn into_stdio_err(self) -> Self::Target {
|
||||
self.map_err(IntoStdioErr::into_stdio_err)
|
||||
}
|
||||
}
|
||||
255
util/src/rustix/fd.rs
Normal file
255
util/src/rustix/fd.rs
Normal file
@@ -0,0 +1,255 @@
|
||||
//! Basic utilities for working with file descriptors
|
||||
|
||||
use std::os::fd::{AsFd, BorrowedFd, FromRawFd, OwnedFd, RawFd};
|
||||
|
||||
use rustix::io::fcntl_dupfd_cloexec;
|
||||
|
||||
use super::IntoStdioErr;
|
||||
|
||||
use crate::mem::Forgetting;
|
||||
|
||||
/// Prepare a file descriptor for use in Rust code.
|
||||
///
|
||||
/// Checks if the file descriptor is valid and duplicates it to a new file descriptor.
|
||||
/// The old file descriptor is masked to avoid potential use after free (on file descriptor)
|
||||
/// in case the given file descriptor is still used somewhere
|
||||
///
|
||||
/// # Panic and safety
|
||||
///
|
||||
/// Will panic if the given file descriptor is negative of or larger than
|
||||
/// the file descriptor numbers permitted by the operating system.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use std::io::Write;
|
||||
/// use std::os::fd::{IntoRawFd, AsRawFd};
|
||||
/// use tempfile::tempdir;
|
||||
/// use rosenpass_util::rustix::{claim_fd, FdIo};
|
||||
///
|
||||
/// // Open a file and turn it into a raw file descriptor
|
||||
/// let orig = tempfile::tempfile()?.into_raw_fd();
|
||||
///
|
||||
/// // Reclaim that file and ready it for reading
|
||||
/// let mut claimed = FdIo(claim_fd(orig)?);
|
||||
///
|
||||
/// // A different file descriptor is used
|
||||
/// assert!(orig.as_raw_fd() != claimed.0.as_raw_fd());
|
||||
///
|
||||
/// // Write some data
|
||||
/// claimed.write_all(b"Hello, World!")?;
|
||||
///
|
||||
/// Ok::<(), std::io::Error>(())
|
||||
/// ```
|
||||
pub fn claim_fd(fd: RawFd) -> rustix::io::Result<OwnedFd> {
|
||||
let new = clone_fd_cloexec(unsafe { BorrowedFd::borrow_raw(fd) })?;
|
||||
mask_fd(fd)?;
|
||||
Ok(new)
|
||||
}
|
||||
|
||||
/// Prepare a file descriptor for use in Rust code.
|
||||
///
|
||||
/// Checks if the file descriptor is valid.
|
||||
///
|
||||
/// Unlike [claim_fd], this will try to reuse the same file descriptor identifier instead of masking it.
|
||||
///
|
||||
/// # Panic and safety
|
||||
///
|
||||
/// Will panic if the given file descriptor is negative of or larger than
|
||||
/// the file descriptor numbers permitted by the operating system.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use std::io::Write;
|
||||
/// use std::os::fd::IntoRawFd;
|
||||
/// use tempfile::tempdir;
|
||||
/// use rosenpass_util::rustix::{claim_fd_inplace, FdIo};
|
||||
///
|
||||
/// // Open a file and turn it into a raw file descriptor
|
||||
/// let fd = tempfile::tempfile()?.into_raw_fd();
|
||||
///
|
||||
/// // Reclaim that file and ready it for reading
|
||||
/// let mut fd = FdIo(claim_fd_inplace(fd)?);
|
||||
///
|
||||
/// // Write some data
|
||||
/// fd.write_all(b"Hello, World!")?;
|
||||
///
|
||||
/// Ok::<(), std::io::Error>(())
|
||||
/// ```
|
||||
pub fn claim_fd_inplace(fd: RawFd) -> rustix::io::Result<OwnedFd> {
|
||||
let mut new = unsafe { OwnedFd::from_raw_fd(fd) };
|
||||
let tmp = clone_fd_cloexec(&new)?;
|
||||
clone_fd_to_cloexec(tmp, &mut new)?;
|
||||
Ok(new)
|
||||
}
|
||||
|
||||
/// Will close the given file descriptor and overwrite
|
||||
/// it with a masking file descriptor (see [open_nullfd]) to prevent accidental reuse.
|
||||
///
|
||||
/// # Panic and safety
|
||||
///
|
||||
/// Will panic if the given file descriptor is negative of or larger than
|
||||
/// the file descriptor numbers permitted by the operating system.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # use std::fs::File;
|
||||
/// # use std::io::Read;
|
||||
/// # use std::os::unix::io::{AsRawFd, FromRawFd};
|
||||
/// # use std::os::fd::IntoRawFd;
|
||||
/// # use rustix::fd::AsFd;
|
||||
/// # use rosenpass_util::rustix::mask_fd;
|
||||
///
|
||||
/// // Open a temporary file
|
||||
/// let fd = tempfile::tempfile().unwrap().into_raw_fd();
|
||||
/// assert!(fd >= 0);
|
||||
///
|
||||
/// // Mask the file descriptor
|
||||
/// mask_fd(fd).unwrap();
|
||||
///
|
||||
/// // Verify the file descriptor now points to `/dev/null`
|
||||
/// // Reading from `/dev/null` always returns 0 bytes
|
||||
/// let mut replaced_file = unsafe { File::from_raw_fd(fd) };
|
||||
/// let mut buffer = [0u8; 4];
|
||||
/// let bytes_read = replaced_file.read(&mut buffer).unwrap();
|
||||
/// assert_eq!(bytes_read, 0);
|
||||
/// ```
|
||||
pub fn mask_fd(fd: RawFd) -> rustix::io::Result<()> {
|
||||
// Safety: because the OwnedFd resulting from OwnedFd::from_raw_fd is wrapped in a Forgetting,
|
||||
// it never gets dropped, meaning that fd is never closed and thus outlives the OwnedFd
|
||||
let mut owned = Forgetting::new(unsafe { OwnedFd::from_raw_fd(fd) });
|
||||
clone_fd_to_cloexec(open_nullfd()?, &mut owned)
|
||||
}
|
||||
|
||||
/// Duplicate a file descriptor, setting the close on exec flag
|
||||
pub fn clone_fd_cloexec<Fd: AsFd>(fd: Fd) -> rustix::io::Result<OwnedFd> {
|
||||
/// Avoid stdin, stdout, and stderr
|
||||
const MINFD: RawFd = 3;
|
||||
fcntl_dupfd_cloexec(fd, MINFD)
|
||||
}
|
||||
|
||||
/// Duplicate a file descriptor, setting the close on exec flag.
|
||||
///
|
||||
/// This is slightly different from [clone_fd_cloexec], as this function supports specifying an
|
||||
/// explicit destination file descriptor.
|
||||
#[cfg(target_os = "linux")]
|
||||
pub fn clone_fd_to_cloexec<Fd: AsFd>(fd: Fd, new: &mut OwnedFd) -> rustix::io::Result<()> {
|
||||
use rustix::io::{dup3, DupFlags};
|
||||
dup3(fd, new, DupFlags::CLOEXEC)
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
/// Duplicate a file descriptor, setting the close on exec flag.
|
||||
///
|
||||
/// This is slightly different from [clone_fd_cloexec], as this function supports specifying an
|
||||
/// explicit destination file descriptor.
|
||||
pub fn clone_fd_to_cloexec<Fd: AsFd>(fd: Fd, new: &mut OwnedFd) -> rustix::io::Result<()> {
|
||||
use rustix::io::{dup2, fcntl_setfd, FdFlags};
|
||||
dup2(&fd, new)?;
|
||||
fcntl_setfd(&new, FdFlags::CLOEXEC)
|
||||
}
|
||||
|
||||
/// Open a "blocked" file descriptor. I.e. a file descriptor that is neither meant for reading nor
|
||||
/// writing.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The behavior of the file descriptor when being written to or from is undefined.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use std::{fs::File, io::Write, os::fd::IntoRawFd};
|
||||
/// use rustix::fd::FromRawFd;
|
||||
/// use rosenpass_util::rustix::open_nullfd;
|
||||
///
|
||||
/// let nullfd = open_nullfd().unwrap();
|
||||
/// ```
|
||||
pub fn open_nullfd() -> rustix::io::Result<OwnedFd> {
|
||||
use rustix::fs::{open, Mode, OFlags};
|
||||
// TODO: Add tests showing that this will throw errors on use
|
||||
open("/dev/null", OFlags::CLOEXEC, Mode::empty())
|
||||
}
|
||||
|
||||
/// Read and write directly from a file descriptor
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// See [claim_fd].
|
||||
pub struct FdIo<Fd: AsFd>(pub Fd);
|
||||
|
||||
impl<Fd: AsFd> std::io::Read for FdIo<Fd> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
rustix::io::read(&self.0, buf).into_stdio_err()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Fd: AsFd> std::io::Write for FdIo<Fd> {
|
||||
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
rustix::io::write(&self.0, buf).into_stdio_err()
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> std::io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::{Read, Write};
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_claim_fd_invalid_neg() {
|
||||
let _ = claim_fd(-1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_claim_fd_invalid_max() {
|
||||
let _ = claim_fd(i64::MAX as RawFd);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_claim_fd_inplace_invalid_neg() {
|
||||
let _ = claim_fd_inplace(-1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_claim_fd_inplace_invalid_max() {
|
||||
let _ = claim_fd_inplace(i64::MAX as RawFd);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_mask_fd_invalid_neg() {
|
||||
let _ = mask_fd(-1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_mask_fd_invalid_max() {
|
||||
let _ = mask_fd(i64::MAX as RawFd);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_open_nullfd() -> anyhow::Result<()> {
|
||||
let mut file = FdIo(open_nullfd()?);
|
||||
let mut buf = [0; 10];
|
||||
assert!(matches!(file.read(&mut buf), Ok(0) | Err(_)));
|
||||
assert!(file.write(&buf).is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nullfd_read_write() {
|
||||
let nullfd = open_nullfd().unwrap();
|
||||
let mut buf = vec![0u8; 16];
|
||||
assert_eq!(rustix::io::read(&nullfd, &mut buf).unwrap(), 0);
|
||||
assert!(rustix::io::write(&nullfd, b"test").is_err());
|
||||
}
|
||||
}
|
||||
86
util/src/rustix/memfd.rs
Normal file
86
util/src/rustix/memfd.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
//! Utilities for working with memory based file descriptors
|
||||
|
||||
use std::os::fd::OwnedFd;
|
||||
|
||||
use rustix::fs::MemfdFlags;
|
||||
use rustix::io::Errno;
|
||||
use rustix::path::Arg as Path;
|
||||
|
||||
use bitflags::bitflags;
|
||||
|
||||
use crate::convert::IntoTypeExt;
|
||||
|
||||
use super::SyscallResult;
|
||||
|
||||
/// Create an anonymous file
|
||||
/// using the memfd_create(2) syscall
|
||||
///
|
||||
/// Just forwards to [rustix::fs::memfd_create]
|
||||
pub fn memfd_create<P: Path>(name: P, flags: MemfdFlags) -> rustix::io::Result<OwnedFd> {
|
||||
rustix::fs::memfd_create(name, flags)
|
||||
}
|
||||
|
||||
bitflags! {
|
||||
/// `FD_*` constants for use with [memfd_secret].
|
||||
#[repr(transparent)]
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
|
||||
pub struct MemfdSecretFlags: std::ffi::c_uint {
|
||||
/// FD_CLOEXEC
|
||||
const CLOEXEC = libc::FD_CLOEXEC as std::ffi::c_uint;
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors for [create_memfd_secret()]
|
||||
#[derive(Copy, Clone, Debug, thiserror::Error)]
|
||||
pub enum MemfdSecretError {
|
||||
/// memfd_secret(2) not supported on system
|
||||
#[error("Could not create secret memory segment using memfd_secret(2): not supported on your system")]
|
||||
NotSupported,
|
||||
/// Other error
|
||||
#[error("Could not create secret memory segment using memfd_secret(2): underlying system error: {:?}", .0)]
|
||||
SystemError(Errno),
|
||||
}
|
||||
|
||||
impl From<Errno> for MemfdSecretError {
|
||||
fn from(value: Errno) -> Self {
|
||||
match value {
|
||||
Errno::NOSYS => Self::NotSupported,
|
||||
e => Self::SystemError(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an anonymous RAM-based file to access secret memory regions
|
||||
/// using the memfd_secret(2) syscall
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use rustix::io::Errno;
|
||||
/// use rustix::fs::ftruncate;
|
||||
///
|
||||
/// use rosenpass_util::rustix::{memfd_secret, MemfdSecretFlags, IntoStdioErr};
|
||||
/// use rosenpass_util::io::handle_interrupted;
|
||||
///
|
||||
/// let res = memfd_secret(MemfdSecretFlags::empty());
|
||||
/// let fd = match res {
|
||||
/// Ok(fd) => fd,
|
||||
/// // The system might not have memfd_secret enabled; abort the test
|
||||
/// Err(Errno::NOSYS) => return Ok(()),
|
||||
/// Err(err) => return Err(err)?,
|
||||
/// };
|
||||
///
|
||||
/// handle_interrupted(|| { ftruncate(&fd, 8192).into_stdio_err() })?;
|
||||
///
|
||||
/// Ok::<(), anyhow::Error>(())
|
||||
/// ```
|
||||
pub fn memfd_secret(flags: MemfdSecretFlags) -> Result<rustix::fd::OwnedFd, MemfdSecretError> {
|
||||
let res = unsafe {
|
||||
use libc::{syscall, SYS_memfd_secret};
|
||||
syscall(SYS_memfd_secret, flags)
|
||||
.into_type::<SyscallResult>()
|
||||
.claim_fd()
|
||||
};
|
||||
|
||||
res.map_err(MemfdSecretError::from)
|
||||
}
|
||||
16
util/src/rustix/mod.rs
Normal file
16
util/src/rustix/mod.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
//! Extensions to the rustix crate for memory safe operating system interfaces
|
||||
|
||||
mod error;
|
||||
pub use error::*;
|
||||
|
||||
mod fd;
|
||||
pub use fd::*;
|
||||
|
||||
mod stat;
|
||||
pub use stat::*;
|
||||
|
||||
mod syscall;
|
||||
pub use syscall::*;
|
||||
|
||||
mod memfd;
|
||||
pub use memfd::*;
|
||||
@@ -1,235 +1,12 @@
|
||||
//! Utilities for working with file descriptors
|
||||
//! Rustix extensions for getting information about file descriptors`
|
||||
|
||||
use std::os::fd::AsFd;
|
||||
|
||||
use anyhow::bail;
|
||||
use rustix::io::fcntl_dupfd_cloexec;
|
||||
use std::os::fd::{AsFd, BorrowedFd, FromRawFd, OwnedFd, RawFd};
|
||||
|
||||
use crate::{mem::Forgetting, result::OkExt};
|
||||
use super::IntoStdioErr;
|
||||
|
||||
/// Prepare a file descriptor for use in Rust code.
|
||||
///
|
||||
/// Checks if the file descriptor is valid and duplicates it to a new file descriptor.
|
||||
/// The old file descriptor is masked to avoid potential use after free (on file descriptor)
|
||||
/// in case the given file descriptor is still used somewhere
|
||||
///
|
||||
/// # Panic and safety
|
||||
///
|
||||
/// Will panic if the given file descriptor is negative of or larger than
|
||||
/// the file descriptor numbers permitted by the operating system.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use std::io::Write;
|
||||
/// use std::os::fd::{IntoRawFd, AsRawFd};
|
||||
/// use tempfile::tempdir;
|
||||
/// use rosenpass_util::fd::{claim_fd, FdIo};
|
||||
///
|
||||
/// // Open a file and turn it into a raw file descriptor
|
||||
/// let orig = tempfile::tempfile()?.into_raw_fd();
|
||||
///
|
||||
/// // Reclaim that file and ready it for reading
|
||||
/// let mut claimed = FdIo(claim_fd(orig)?);
|
||||
///
|
||||
/// // A different file descriptor is used
|
||||
/// assert!(orig.as_raw_fd() != claimed.0.as_raw_fd());
|
||||
///
|
||||
/// // Write some data
|
||||
/// claimed.write_all(b"Hello, World!")?;
|
||||
///
|
||||
/// Ok::<(), std::io::Error>(())
|
||||
/// ```
|
||||
pub fn claim_fd(fd: RawFd) -> rustix::io::Result<OwnedFd> {
|
||||
let new = clone_fd_cloexec(unsafe { BorrowedFd::borrow_raw(fd) })?;
|
||||
mask_fd(fd)?;
|
||||
Ok(new)
|
||||
}
|
||||
|
||||
/// Prepare a file descriptor for use in Rust code.
|
||||
///
|
||||
/// Checks if the file descriptor is valid.
|
||||
///
|
||||
/// Unlike [claim_fd], this will try to reuse the same file descriptor identifier instead of masking it.
|
||||
///
|
||||
/// # Panic and safety
|
||||
///
|
||||
/// Will panic if the given file descriptor is negative of or larger than
|
||||
/// the file descriptor numbers permitted by the operating system.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use std::io::Write;
|
||||
/// use std::os::fd::IntoRawFd;
|
||||
/// use tempfile::tempdir;
|
||||
/// use rosenpass_util::fd::{claim_fd_inplace, FdIo};
|
||||
///
|
||||
/// // Open a file and turn it into a raw file descriptor
|
||||
/// let fd = tempfile::tempfile()?.into_raw_fd();
|
||||
///
|
||||
/// // Reclaim that file and ready it for reading
|
||||
/// let mut fd = FdIo(claim_fd_inplace(fd)?);
|
||||
///
|
||||
/// // Write some data
|
||||
/// fd.write_all(b"Hello, World!")?;
|
||||
///
|
||||
/// Ok::<(), std::io::Error>(())
|
||||
/// ```
|
||||
pub fn claim_fd_inplace(fd: RawFd) -> rustix::io::Result<OwnedFd> {
|
||||
let mut new = unsafe { OwnedFd::from_raw_fd(fd) };
|
||||
let tmp = clone_fd_cloexec(&new)?;
|
||||
clone_fd_to_cloexec(tmp, &mut new)?;
|
||||
Ok(new)
|
||||
}
|
||||
|
||||
/// Will close the given file descriptor and overwrite
|
||||
/// it with a masking file descriptor (see [open_nullfd]) to prevent accidental reuse.
|
||||
///
|
||||
/// # Panic and safety
|
||||
///
|
||||
/// Will panic if the given file descriptor is negative of or larger than
|
||||
/// the file descriptor numbers permitted by the operating system.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # use std::fs::File;
|
||||
/// # use std::io::Read;
|
||||
/// # use std::os::unix::io::{AsRawFd, FromRawFd};
|
||||
/// # use std::os::fd::IntoRawFd;
|
||||
/// # use rustix::fd::AsFd;
|
||||
/// # use rosenpass_util::fd::mask_fd;
|
||||
///
|
||||
/// // Open a temporary file
|
||||
/// let fd = tempfile::tempfile().unwrap().into_raw_fd();
|
||||
/// assert!(fd >= 0);
|
||||
///
|
||||
/// // Mask the file descriptor
|
||||
/// mask_fd(fd).unwrap();
|
||||
///
|
||||
/// // Verify the file descriptor now points to `/dev/null`
|
||||
/// // Reading from `/dev/null` always returns 0 bytes
|
||||
/// let mut replaced_file = unsafe { File::from_raw_fd(fd) };
|
||||
/// let mut buffer = [0u8; 4];
|
||||
/// let bytes_read = replaced_file.read(&mut buffer).unwrap();
|
||||
/// assert_eq!(bytes_read, 0);
|
||||
/// ```
|
||||
pub fn mask_fd(fd: RawFd) -> rustix::io::Result<()> {
|
||||
// Safety: because the OwnedFd resulting from OwnedFd::from_raw_fd is wrapped in a Forgetting,
|
||||
// it never gets dropped, meaning that fd is never closed and thus outlives the OwnedFd
|
||||
let mut owned = Forgetting::new(unsafe { OwnedFd::from_raw_fd(fd) });
|
||||
clone_fd_to_cloexec(open_nullfd()?, &mut owned)
|
||||
}
|
||||
|
||||
/// Duplicate a file descriptor, setting the close on exec flag
|
||||
pub fn clone_fd_cloexec<Fd: AsFd>(fd: Fd) -> rustix::io::Result<OwnedFd> {
|
||||
/// Avoid stdin, stdout, and stderr
|
||||
const MINFD: RawFd = 3;
|
||||
fcntl_dupfd_cloexec(fd, MINFD)
|
||||
}
|
||||
|
||||
/// Duplicate a file descriptor, setting the close on exec flag.
|
||||
///
|
||||
/// This is slightly different from [clone_fd_cloexec], as this function supports specifying an
|
||||
/// explicit destination file descriptor.
|
||||
#[cfg(target_os = "linux")]
|
||||
pub fn clone_fd_to_cloexec<Fd: AsFd>(fd: Fd, new: &mut OwnedFd) -> rustix::io::Result<()> {
|
||||
use rustix::io::{dup3, DupFlags};
|
||||
dup3(fd, new, DupFlags::CLOEXEC)
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
/// Duplicate a file descriptor, setting the close on exec flag.
|
||||
///
|
||||
/// This is slightly different from [clone_fd_cloexec], as this function supports specifying an
|
||||
/// explicit destination file descriptor.
|
||||
pub fn clone_fd_to_cloexec<Fd: AsFd>(fd: Fd, new: &mut OwnedFd) -> rustix::io::Result<()> {
|
||||
use rustix::io::{dup2, fcntl_setfd, FdFlags};
|
||||
dup2(&fd, new)?;
|
||||
fcntl_setfd(&new, FdFlags::CLOEXEC)
|
||||
}
|
||||
|
||||
/// Open a "blocked" file descriptor. I.e. a file descriptor that is neither meant for reading nor
|
||||
/// writing.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The behavior of the file descriptor when being written to or from is undefined.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use std::{fs::File, io::Write, os::fd::IntoRawFd};
|
||||
/// use rustix::fd::FromRawFd;
|
||||
/// use rosenpass_util::fd::open_nullfd;
|
||||
///
|
||||
/// let nullfd = open_nullfd().unwrap();
|
||||
/// ```
|
||||
pub fn open_nullfd() -> rustix::io::Result<OwnedFd> {
|
||||
use rustix::fs::{open, Mode, OFlags};
|
||||
// TODO: Add tests showing that this will throw errors on use
|
||||
open("/dev/null", OFlags::CLOEXEC, Mode::empty())
|
||||
}
|
||||
|
||||
/// Convert low level errors into std::io::Error
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use std::io::ErrorKind as EK;
|
||||
/// use rustix::io::Errno;
|
||||
/// use rosenpass_util::fd::IntoStdioErr;
|
||||
///
|
||||
/// let e = Errno::INTR.into_stdio_err();
|
||||
/// assert!(matches!(e.kind(), EK::Interrupted));
|
||||
///
|
||||
/// let r : rustix::io::Result<()> = Err(Errno::INTR);
|
||||
/// assert!(matches!(r, Err(e) if e.kind() == EK::Interrupted));
|
||||
/// ```
|
||||
pub trait IntoStdioErr {
|
||||
/// Target type produced (e.g. std::io:Error or std::io::Result depending on context
|
||||
type Target;
|
||||
/// Convert low level errors to
|
||||
fn into_stdio_err(self) -> Self::Target;
|
||||
}
|
||||
|
||||
impl IntoStdioErr for rustix::io::Errno {
|
||||
type Target = std::io::Error;
|
||||
|
||||
fn into_stdio_err(self) -> Self::Target {
|
||||
std::io::Error::from_raw_os_error(self.raw_os_error())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> IntoStdioErr for rustix::io::Result<T> {
|
||||
type Target = std::io::Result<T>;
|
||||
|
||||
fn into_stdio_err(self) -> Self::Target {
|
||||
self.map_err(IntoStdioErr::into_stdio_err)
|
||||
}
|
||||
}
|
||||
|
||||
/// Read and write directly from a file descriptor
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// See [claim_fd].
|
||||
pub struct FdIo<Fd: AsFd>(pub Fd);
|
||||
|
||||
impl<Fd: AsFd> std::io::Read for FdIo<Fd> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
rustix::io::read(&self.0, buf).into_stdio_err()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Fd: AsFd> std::io::Write for FdIo<Fd> {
|
||||
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
rustix::io::write(&self.0, buf).into_stdio_err()
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> std::io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
use crate::result::OkExt;
|
||||
|
||||
/// Helpers for accessing stat(2) information
|
||||
pub trait StatExt {
|
||||
@@ -238,7 +15,7 @@ pub trait StatExt {
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use rosenpass_util::fd::StatExt;
|
||||
/// use rosenpass_util::rustix::StatExt;
|
||||
/// assert!(rustix::fs::stat("/")?.is_socket() == false);
|
||||
/// Ok::<(), rustix::io::Errno>(())
|
||||
/// ````
|
||||
@@ -263,7 +40,7 @@ pub trait TryStatExt {
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use rosenpass_util::fd::TryStatExt;
|
||||
/// use rosenpass_util::rustix::TryStatExt;
|
||||
/// let fd = rustix::fs::open("/", rustix::fs::OFlags::empty(), rustix::fs::Mode::empty())?;
|
||||
/// assert!(matches!(fd.is_socket(), Ok(false)));
|
||||
/// Ok::<(), rustix::io::Errno>(())
|
||||
@@ -358,7 +135,7 @@ pub trait GetUnixSocketType {
|
||||
/// # use std::os::fd::{AsFd, BorrowedFd};
|
||||
/// # use std::os::unix::net::UnixListener;
|
||||
/// # use tempfile::NamedTempFile;
|
||||
/// # use rosenpass_util::fd::GetUnixSocketType;
|
||||
/// # use rosenpass_util::rustix::GetUnixSocketType;
|
||||
/// let f = {
|
||||
/// // Generate a temp file and take its path
|
||||
/// // Remove the temp file
|
||||
@@ -378,7 +155,7 @@ pub trait GetUnixSocketType {
|
||||
/// # use std::os::fd::{AsFd, BorrowedFd};
|
||||
/// # use std::os::unix::net::{UnixDatagram, UnixListener};
|
||||
/// # use tempfile::NamedTempFile;
|
||||
/// # use rosenpass_util::fd::GetUnixSocketType;
|
||||
/// # use rosenpass_util::rustix::GetUnixSocketType;
|
||||
/// let f = {
|
||||
/// // Generate a temp file and take its path
|
||||
/// // Remove the temp file
|
||||
@@ -445,7 +222,7 @@ pub trait GetSocketProtocol {
|
||||
/// ```
|
||||
/// # use std::net::UdpSocket;
|
||||
/// # use std::os::fd::{AsFd, AsRawFd};
|
||||
/// # use rosenpass_util::fd::GetSocketProtocol;
|
||||
/// # use rosenpass_util::rustix::GetSocketProtocol;
|
||||
/// let socket = UdpSocket::bind("127.0.0.1:0")?;
|
||||
/// assert_eq!(socket.as_fd().socket_protocol().unwrap().unwrap(), rustix::net::ipproto::UDP);
|
||||
/// # Ok::<(), std::io::Error>(())
|
||||
@@ -458,7 +235,7 @@ pub trait GetSocketProtocol {
|
||||
/// # use std::net::UdpSocket;
|
||||
/// # use std::net::TcpListener;
|
||||
/// # use std::os::fd::{AsFd, AsRawFd};
|
||||
/// # use rosenpass_util::fd::GetSocketProtocol;
|
||||
/// # use rosenpass_util::rustix::GetSocketProtocol;
|
||||
/// let socket = UdpSocket::bind("127.0.0.1:0")?;
|
||||
/// assert!(socket.as_fd().is_udp_socket().unwrap());
|
||||
///
|
||||
@@ -484,7 +261,7 @@ pub trait GetSocketProtocol {
|
||||
/// # use std::net::UdpSocket;
|
||||
/// # use std::net::TcpListener;
|
||||
/// # use std::os::fd::{AsFd, AsRawFd};
|
||||
/// # use rosenpass_util::fd::GetSocketProtocol;
|
||||
/// # use rosenpass_util::rustix::GetSocketProtocol;
|
||||
/// let socket = UdpSocket::bind("127.0.0.1:0")?;
|
||||
/// assert!(matches!(socket.as_fd().demand_udp_socket(), Ok(())));
|
||||
///
|
||||
@@ -514,62 +291,3 @@ where
|
||||
rustix::net::sockopt::get_socket_protocol(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::{Read, Write};
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_claim_fd_invalid_neg() {
|
||||
let _ = claim_fd(-1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_claim_fd_invalid_max() {
|
||||
let _ = claim_fd(i64::MAX as RawFd);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_claim_fd_inplace_invalid_neg() {
|
||||
let _ = claim_fd_inplace(-1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_claim_fd_inplace_invalid_max() {
|
||||
let _ = claim_fd_inplace(i64::MAX as RawFd);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_mask_fd_invalid_neg() {
|
||||
let _ = mask_fd(-1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_mask_fd_invalid_max() {
|
||||
let _ = mask_fd(i64::MAX as RawFd);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_open_nullfd() -> anyhow::Result<()> {
|
||||
let mut file = FdIo(open_nullfd()?);
|
||||
let mut buf = [0; 10];
|
||||
assert!(matches!(file.read(&mut buf), Ok(0) | Err(_)));
|
||||
assert!(file.write(&buf).is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nullfd_read_write() {
|
||||
let nullfd = open_nullfd().unwrap();
|
||||
let mut buf = vec![0u8; 16];
|
||||
assert_eq!(rustix::io::read(&nullfd, &mut buf).unwrap(), 0);
|
||||
assert!(rustix::io::write(&nullfd, b"test").is_err());
|
||||
}
|
||||
}
|
||||
33
util/src/rustix/syscall.rs
Normal file
33
util/src/rustix/syscall.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
//! Helpers for performing system calls
|
||||
|
||||
use std::os::fd::FromRawFd;
|
||||
|
||||
use super::errno;
|
||||
|
||||
#[repr(C, packed)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub struct SyscallResult(pub libc::c_long);
|
||||
|
||||
impl SyscallResult {
|
||||
pub fn raw_value(&self) -> libc::c_long {
|
||||
self.0
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// TODO…
|
||||
pub unsafe fn claim_fd(&self) -> Result<rustix::fd::OwnedFd, rustix::io::Errno> {
|
||||
let fde = self.0;
|
||||
match fde {
|
||||
e if e < 0 => Err(errno()),
|
||||
fd if fd > i32::MAX.into() => panic!("File descriptor `{fd}` is out of bounds: "),
|
||||
fd => Ok(unsafe { rustix::fd::OwnedFd::from_raw_fd(fd as i32) }),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<libc::c_long> for SyscallResult {
|
||||
fn from(value: libc::c_long) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
200
util/src/secret_memory/fd.rs
Normal file
200
util/src/secret_memory/fd.rs
Normal file
@@ -0,0 +1,200 @@
|
||||
//! Creation of secret memory file descriptors.
|
||||
//!
|
||||
//! This essentially provides a higher_level API for [memfd_secret()] and [memfd_create()]
|
||||
|
||||
// Tests: This uses nix-based integration tests
|
||||
|
||||
use std::{os::fd::OwnedFd, sync::OnceLock};
|
||||
|
||||
use rustix::{fs::MemfdFlags, io::Errno};
|
||||
|
||||
use crate::rustix::{memfd_create, memfd_secret, MemfdSecretError, MemfdSecretFlags};
|
||||
use crate::{mem::CopyExt, result::OkExt};
|
||||
|
||||
/// Cache for [memfd_secret_supported]
|
||||
static MEMFD_SECRET_SUPPORTED: OnceLock<bool> = OnceLock::new();
|
||||
|
||||
/// Check whether support for memfd_secret is available
|
||||
pub fn memfd_secret_supported() -> Result<bool, Errno> {
|
||||
match MEMFD_SECRET_SUPPORTED.get() {
|
||||
Some(v) => return Ok(*v),
|
||||
_ => {} // Continue
|
||||
};
|
||||
|
||||
use MemfdSecretError as E;
|
||||
let is_supported = match memfd_secret(MemfdSecretFlags::empty()) {
|
||||
Ok(_) => true,
|
||||
Err(E::NotSupported) => false,
|
||||
Err(E::SystemError(e)) => return Err(e),
|
||||
};
|
||||
|
||||
// We are deliberately using get_or_init here to make sure that the entire application
|
||||
// never sees different values here
|
||||
MEMFD_SECRET_SUPPORTED
|
||||
.get_or_init(|| is_supported)
|
||||
.copy()
|
||||
.ok()
|
||||
}
|
||||
|
||||
/// How secure memory file descriptors should be allocated
|
||||
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq)]
|
||||
pub enum SecretMemfdPolicy {
|
||||
/// Use memfd_secret(2) if available, otherwise fall back to less
|
||||
/// secure options
|
||||
Opportunistic,
|
||||
/// Enforce the use of memfd_secret(2)
|
||||
UseMemfdSecret,
|
||||
/// Never use memfd_secret(2)
|
||||
DisableMemfdSecret,
|
||||
}
|
||||
|
||||
impl Default for SecretMemfdPolicy {
|
||||
fn default() -> Self {
|
||||
Self::Opportunistic
|
||||
}
|
||||
}
|
||||
|
||||
impl SecretMemfdPolicy {
|
||||
/// Create a SecretMemfdPolicy with the default policy
|
||||
///
|
||||
/// Currently [Self::Opportunistic]
|
||||
pub const fn default_const() -> Self {
|
||||
Self::Opportunistic
|
||||
}
|
||||
|
||||
/// Enforce the use of the highest security configuration available
|
||||
///
|
||||
/// This might not work on some systems, which is why this is not used
|
||||
/// by default.
|
||||
///
|
||||
/// Currently [Self::UseMemfdSecret]
|
||||
pub const fn enforce_high_security() -> Self {
|
||||
Self::UseMemfdSecret
|
||||
}
|
||||
}
|
||||
|
||||
/// Which mechanism to us use when allocating secret memory file descriptors
|
||||
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq)]
|
||||
pub enum SecretMemfdMechanism {
|
||||
/// The less secure memfd_create(2) will be used
|
||||
MemfdCreate,
|
||||
/// The more secure memfd_secret(2) will be used
|
||||
MemfdSecret,
|
||||
}
|
||||
|
||||
impl SecretMemfdMechanism {
|
||||
/// Decide which mechanism to use, based on the given [SecretMemfdPolicy]
|
||||
/// and [memfd_secret_supported()].
|
||||
///
|
||||
/// If [SecretMemfdPolicy::UseMemfdSecret] is used, then [SecretMemfdMechanism::MemfdSecret]
|
||||
/// will be returned, regardless of whether it is supported.
|
||||
///
|
||||
/// Likewise, if [SecretMemfdPolicy::DisableMemfdSecret] is used, then
|
||||
/// [SecretMemfdMechanism::MemfdCreate] will be used unconditionally.
|
||||
pub fn decide_with_policy(policy: SecretMemfdPolicy) -> Result<Self, Errno> {
|
||||
use SecretMemfdMechanism as M;
|
||||
use SecretMemfdPolicy as P;
|
||||
|
||||
match policy {
|
||||
P::UseMemfdSecret => return Ok(M::MemfdSecret),
|
||||
P::DisableMemfdSecret => return Ok(M::MemfdCreate),
|
||||
P::Opportunistic => {}
|
||||
};
|
||||
|
||||
match memfd_secret_supported()? {
|
||||
true => Ok(M::MemfdSecret),
|
||||
false => Ok(M::MemfdCreate),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors for [create_memfd_secret()]
|
||||
#[derive(Copy, Clone, Debug, thiserror::Error)]
|
||||
pub enum SecretMemfdWithConfigError {
|
||||
/// Call to [memfd_secret()] failed
|
||||
#[error("{:?}", .0)]
|
||||
MemfdSecretError(#[from] MemfdSecretError),
|
||||
/// Call to [memfd()] failed
|
||||
#[error("Could not create secret memory segment using memfd_create(2) due to underlying system error: {:?}", .0)]
|
||||
MemfdCreateError(Errno),
|
||||
#[error("Failed to determine whether memfd_secret(2) is supported due to underlying system error: {:?}", .0)]
|
||||
FailedToDetectSupport(Errno),
|
||||
}
|
||||
|
||||
/// Robustly configure and create secret memory file descriptors
|
||||
///
|
||||
/// Whereas [memfd_secret] will always use memfd_secret(2), this construction allows multiple
|
||||
/// file descriptor back ends to be used to support different usage scenarios.
|
||||
///
|
||||
/// This is necessary, because older systems do not support memfd_secret(2) and using it might not
|
||||
/// always be desirable, as memfd_secret for instance also inhibits hibernation.
|
||||
#[derive(Default, Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq)]
|
||||
pub struct SecretMemfdConfig {
|
||||
/// Enable the close-on-exec flag for the new file descriptor.
|
||||
pub close_on_exec: bool,
|
||||
/// Security mechanism to use
|
||||
pub policy: SecretMemfdPolicy,
|
||||
}
|
||||
|
||||
impl SecretMemfdConfig {
|
||||
/// Create a new, default [Self]
|
||||
pub const fn new() -> Self {
|
||||
let close_on_exec = false;
|
||||
let policy = SecretMemfdPolicy::default_const();
|
||||
Self {
|
||||
close_on_exec,
|
||||
policy,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the [Self::close_on_exec] flag to true
|
||||
pub const fn cloexec(&self) -> Self {
|
||||
let mut r = *self;
|
||||
r.close_on_exec = true;
|
||||
r
|
||||
}
|
||||
|
||||
/// Set `self.policy = SecretMemoryFdPolicy::UseMemfdSecret`
|
||||
pub const fn enforce_high_security(&self) -> Self {
|
||||
let mut r = *self;
|
||||
r.policy = SecretMemfdPolicy::enforce_high_security();
|
||||
r
|
||||
}
|
||||
|
||||
/// Whether memfd_secret will be used by [Self::create()]
|
||||
pub fn mechanism(&self) -> Result<SecretMemfdMechanism, Errno> {
|
||||
SecretMemfdMechanism::decide_with_policy(self.policy)
|
||||
}
|
||||
|
||||
/// Allocate a secret file descriptor based on the configuration
|
||||
pub fn create(&self) -> Result<OwnedFd, SecretMemfdWithConfigError> {
|
||||
use SecretMemfdWithConfigError as E;
|
||||
let mech = self.mechanism().map_err(E::FailedToDetectSupport)?;
|
||||
|
||||
use SecretMemfdMechanism as M;
|
||||
match (mech, self.close_on_exec) {
|
||||
(M::MemfdCreate, cloexec) => {
|
||||
let flags = match cloexec {
|
||||
true => MemfdFlags::CLOEXEC,
|
||||
false => MemfdFlags::empty(),
|
||||
};
|
||||
memfd_create("rosenpass secret memory segment", flags).map_err(E::MemfdCreateError)
|
||||
}
|
||||
(M::MemfdSecret, cloexec) => {
|
||||
let flags = match cloexec {
|
||||
true => MemfdSecretFlags::CLOEXEC,
|
||||
false => MemfdSecretFlags::empty(),
|
||||
};
|
||||
memfd_secret(flags).map_err(E::MemfdSecretError)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a secret memory file descriptor using the default policy
|
||||
///
|
||||
/// Shorthand for
|
||||
/// [`SecretMemfdConfig`](SecretMemfdConfig)[`::new()`](SecretMemfdConfig::new)[`.create()`](SecretMemfdConfig::create)
|
||||
pub fn memfd_for_secrets_with_default_policy() -> Result<OwnedFd, SecretMemfdWithConfigError> {
|
||||
SecretMemfdConfig::new().create()
|
||||
}
|
||||
409
util/src/secret_memory/mmap.rs
Normal file
409
util/src/secret_memory/mmap.rs
Normal file
@@ -0,0 +1,409 @@
|
||||
//! This module takes care of allocating memory segments for
|
||||
//! file descriptors created with [super::fd::memfd_for_secrets_with_default_policy]
|
||||
//! and anonymous memory segments
|
||||
|
||||
// Tests: Nix based integration tests
|
||||
|
||||
#![deny(unsafe_op_in_unsafe_fn)]
|
||||
|
||||
use std::os::fd::{AsFd, AsRawFd, BorrowedFd, RawFd};
|
||||
use std::ptr::null_mut;
|
||||
|
||||
use crate::mem::CopyExt;
|
||||
|
||||
/// Size of the memory mapping for [MappableFd]
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||
pub enum MMapSizePolicy {
|
||||
/// Size is assumed to be this particular value; [MappableFd::mmap] will simply
|
||||
/// use this value without checking whether its matches the size of the underlying
|
||||
/// data
|
||||
Assumed(u64),
|
||||
/// Size is assumed to be this particular value; [MappableFd::mmap] will check the
|
||||
/// size of the underlying data and raise an error if the size of data and this value
|
||||
/// do not match
|
||||
Checked(u64),
|
||||
/// Size is defined to be this particular value; [MappableFd::mmap] will explicitly resize
|
||||
/// the underlying file descriptor to be this particular size when called.
|
||||
Resize(u64),
|
||||
}
|
||||
|
||||
impl MMapSizePolicy {
|
||||
/// The numeric value of the size
|
||||
pub fn size_value(&self) -> u64 {
|
||||
match *self {
|
||||
Self::Assumed(v) => v,
|
||||
Self::Checked(v) => v,
|
||||
Self::Resize(v) => v,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for [MappableFd::mmap]
|
||||
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct MapFdConfig {
|
||||
/// The memory region can not be read from
|
||||
pub unreadable: bool,
|
||||
/// The memory region can not be written to
|
||||
pub immutable: bool,
|
||||
/// The memory region can be executed
|
||||
pub executable: bool,
|
||||
/// The memory region is shared-memory; other mappings of the same
|
||||
/// region within and outside this process can see the modifications
|
||||
/// to the memory region (as long as they also set the shared flag)
|
||||
pub shared: bool,
|
||||
/// How [MappableFd::mmap] will determine the size to be used fo
|
||||
///
|
||||
/// You should usually set this value through [Self::set_size_policy], [Self::assume_size_without_checking],
|
||||
/// [Self::expected_size], or [Self::resize_on_mmap].
|
||||
pub size_policy: Option<MMapSizePolicy>,
|
||||
}
|
||||
|
||||
impl MapFdConfig {
|
||||
/// New MapFdConfig with all settings turned off
|
||||
///
|
||||
/// You still must set [Self::size_policy], otherwise [MappableFd::mmap] will raise
|
||||
/// an error when called
|
||||
pub const fn new() -> Self {
|
||||
MapFdConfig {
|
||||
unreadable: false,
|
||||
immutable: false,
|
||||
executable: false,
|
||||
shared: false,
|
||||
size_policy: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// New MappableFdConfig with shared memory turned on
|
||||
pub const fn shared_memory() -> Self {
|
||||
Self::new().set_shared()
|
||||
}
|
||||
|
||||
/// Set the [Self::unreadable()] flag
|
||||
pub const fn set_unreadable(&self) -> Self {
|
||||
let mut r = *self;
|
||||
r.unreadable = true;
|
||||
r
|
||||
}
|
||||
|
||||
/// Set the [Self::immutable()] flag
|
||||
pub const fn set_immutable(&self) -> Self {
|
||||
let mut r = *self;
|
||||
r.immutable = true;
|
||||
r
|
||||
}
|
||||
|
||||
/// Set the [Self::shared()] flag
|
||||
pub const fn set_shared(&self) -> Self {
|
||||
let mut r = *self;
|
||||
r.shared = true;
|
||||
r
|
||||
}
|
||||
|
||||
/// Create a [MappableFd] instance with this configuration
|
||||
pub fn mappable_fd<Fd: AsFd>(&self, fd: Fd) -> MappableFd<Fd> {
|
||||
MappableFd::new(fd, self.copy())
|
||||
}
|
||||
|
||||
/// Calculate [rustix::mm::ProtFlags] for this configuration
|
||||
pub const fn mmap_prot(&self) -> rustix::mm::ProtFlags {
|
||||
use rustix::mm::ProtFlags as P;
|
||||
|
||||
let p_read = match self.unreadable {
|
||||
true => P::empty(),
|
||||
false => P::READ,
|
||||
};
|
||||
|
||||
let p_write = match self.immutable {
|
||||
true => P::empty(),
|
||||
false => P::WRITE,
|
||||
};
|
||||
|
||||
let p_exec = match self.executable {
|
||||
true => P::EXEC,
|
||||
false => P::empty(),
|
||||
};
|
||||
|
||||
p_read.union(p_write).union(p_exec)
|
||||
}
|
||||
|
||||
/// Calculate [rustix::mm::MapFlags] for this configuration
|
||||
pub const fn mmap_flags(&self) -> rustix::mm::MapFlags {
|
||||
use rustix::mm::MapFlags as M;
|
||||
match self.shared {
|
||||
true => M::SHARED,
|
||||
false => M::empty(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set [Self::size_policy] to the given value
|
||||
pub const fn set_size_policy(&self, size_policy: MMapSizePolicy) -> Self {
|
||||
let mut r = *self;
|
||||
r.size_policy = Some(size_policy);
|
||||
r
|
||||
}
|
||||
|
||||
/// Set [Self::size_policy] to [MMapSizePolicy::Assumed] with the given value
|
||||
pub const fn assume_size_without_checking(&self, size: u64) -> Self {
|
||||
self.set_size_policy(MMapSizePolicy::Assumed(size))
|
||||
}
|
||||
|
||||
/// Set [Self::size_policy] to [MMapSizePolicy::Checked] with the given value
|
||||
pub const fn expected_size(&self, size: u64) -> Self {
|
||||
self.set_size_policy(MMapSizePolicy::Checked(size))
|
||||
}
|
||||
|
||||
/// Set [Self::size_policy] to [MMapSizePolicy::Resize] with the given value
|
||||
pub const fn resize_on_mmap(&self, size: u64) -> Self {
|
||||
self.set_size_policy(MMapSizePolicy::Resize(size))
|
||||
}
|
||||
}
|
||||
|
||||
/// Error returned by MappableFd::mmap
|
||||
#[derive(thiserror::Error, Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub enum MMapError {
|
||||
/// Error converting between usize & f64
|
||||
#[error("Requested memory map of size {requested_len} but maximum supported size is {max_supported_len}. \
|
||||
This is a low level error that usually arises on architectures where the integer type usize ({} bytes) can not \
|
||||
represent all values in u64 ({} bytes). Are you possibly on 32 bit CPU architecture requesting a buffer bigger than 4 GB?\n\
|
||||
Error: {err:?}",
|
||||
(usize::BITS as f64)/8f64, (u64::BITS as f64)/8f64
|
||||
)]
|
||||
OutOfBounds {
|
||||
/// Underlying error
|
||||
err: <u64 as TryInto<usize>>::Error,
|
||||
/// The size of the memory map requested
|
||||
requested_len: u64,
|
||||
/// Maximum supported size
|
||||
max_supported_len: usize,
|
||||
},
|
||||
/// Tried to map a file descriptor into memory, but the size policy was never set. Developer
|
||||
/// error.
|
||||
#[error("Tried to map a file descriptor into memory, but the size policy was never set. This is a developer error.")]
|
||||
MissingSizePolicy,
|
||||
/// fseek(3)/ftell(3) system call failed
|
||||
#[error("Tried to map file descriptor into memory, but failed to determine the size of the file descriptor: {:?}", .0)]
|
||||
CouldNotDetermineSize(rustix::io::Errno),
|
||||
/// Mismatch between expected and actual size of the file descriptor
|
||||
#[error("Tried to map file descriptor into memory with expected size {expected:?}, instead we found that the true size is {actual:?}")]
|
||||
IncorrectSize {
|
||||
/// Expected file descriptor size (given by caller)
|
||||
expected: u64,
|
||||
/// Actual file descriptor size
|
||||
actual: u64,
|
||||
},
|
||||
/// Negative size reported by fstat(2)
|
||||
#[error("Tried to determine the size of the underlying file descriptor, but failed to determine the size of the file descriptor because the size ({size}) is negative: {err:?}")]
|
||||
InvalidSize {
|
||||
/// Underlying error
|
||||
err: <i64 as TryInto<u64>>::Error,
|
||||
/// Reported size of the file descriptor
|
||||
size: i64,
|
||||
},
|
||||
/// ftruncate(2) system call failed
|
||||
#[error("Tried to resize the underlying file descriptor, but failed: {:?}", .0)]
|
||||
ResizeError(rustix::io::Errno),
|
||||
/// mmap(2) system call failed
|
||||
#[error("Tried to map file descriptor into memory, but mmap(2) system call failed: {:?}", .0)]
|
||||
MMapError(rustix::io::Errno),
|
||||
}
|
||||
|
||||
/// Handle mapping a file descriptor into memory
|
||||
pub struct MappableFd<Fd: AsFd> {
|
||||
/// The file descriptor this struct refers to
|
||||
fd: Fd,
|
||||
/// The configuration for mmap
|
||||
config: MapFdConfig,
|
||||
}
|
||||
|
||||
impl<Fd: AsFd> AsFd for MappableFd<Fd> {
|
||||
fn as_fd(&self) -> BorrowedFd<'_> {
|
||||
self.fd.as_fd()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Fd: AsFd> AsRawFd for MappableFd<Fd> {
|
||||
fn as_raw_fd(&self) -> RawFd {
|
||||
self.as_fd().as_raw_fd()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Fd: AsFd> MappableFd<Fd> {
|
||||
/// Create a [MappableFd] using the default configuration ([MapFdConfig])
|
||||
pub fn from_fd(fd: Fd) -> Self {
|
||||
Self::new(fd, MapFdConfig::default())
|
||||
}
|
||||
|
||||
/// Create a new [MappableFd] for an existing file descriptor
|
||||
pub fn new(fd: Fd, config: MapFdConfig) -> Self {
|
||||
Self { fd, config }
|
||||
}
|
||||
|
||||
/// Extract the underlying file descriptor
|
||||
pub fn into_fd(self) -> Fd {
|
||||
self.fd
|
||||
}
|
||||
|
||||
/// Access the [MapFdConfig] associated with Self
|
||||
pub fn config(&self) -> MapFdConfig {
|
||||
self.config
|
||||
}
|
||||
|
||||
/// Access the [MapFdConfig] associated with Self
|
||||
pub fn config_mut(&mut self) -> &mut MapFdConfig {
|
||||
&mut self.config
|
||||
}
|
||||
|
||||
/// Modify the [MapFdConfig] associated with Self, chainable
|
||||
pub fn with_config(mut self, config: MapFdConfig) -> Self {
|
||||
self.config = config;
|
||||
self
|
||||
}
|
||||
|
||||
/// Determine the size of the data associated with the file descriptor
|
||||
pub fn size_of_underlying_data(&self) -> Result<u64, MMapError> {
|
||||
use MMapError as E;
|
||||
let size = rustix::fs::fstat(self)
|
||||
.map_err(E::CouldNotDetermineSize)?
|
||||
.st_size;
|
||||
size.try_into().map_err(|err| E::InvalidSize { err, size })
|
||||
}
|
||||
|
||||
/// Map the file into memory
|
||||
///
|
||||
/// # Determining the size of the mapping.
|
||||
///
|
||||
/// Before calling this function, [MapFdConfig::size_policy] must be set.
|
||||
///
|
||||
/// Note that [Self::from_fd] and [Self::new] still allow you to create a [Self] with
|
||||
/// [MapFdConfig::size_policy] set to [None], so you can use [Self::size_of_underlying_data]
|
||||
/// to auto-detect the size of the mapping.
|
||||
///
|
||||
/// This functionality is not implemented by default, as its crucial to validate the size of
|
||||
/// the data being mapped into memory somehow; otherwise, the party that created the file
|
||||
/// descriptor can trigger a denial-of-service attack against our process by allocating an
|
||||
/// excessively large, sparse file. If you implement size auto-detection facilities, you should
|
||||
/// still enforce some bounds on the size.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// If there exist any Rust references referring to the memory region, or if you subsequently create a Rust reference referring to the resulting region, it is your responsibility to ensure that the Rust reference invariants are preserved, including ensuring that the memory is not mutated in a way that a Rust reference would not expect.
|
||||
pub fn mmap(&self) -> Result<MappedSegment, MMapError> {
|
||||
use rustix::mm::mmap;
|
||||
use MMapError as E;
|
||||
|
||||
let prot = self.config().mmap_prot();
|
||||
let flags = self.config().mmap_flags();
|
||||
|
||||
// Determine the size of the mapping to be used as u64
|
||||
let requested_size = match self.config().size_policy {
|
||||
None => return Err(E::MissingSizePolicy),
|
||||
Some(MMapSizePolicy::Assumed(size)) => size,
|
||||
Some(MMapSizePolicy::Resize(size)) => {
|
||||
rustix::fs::ftruncate(self, size).map_err(E::ResizeError)?;
|
||||
size
|
||||
}
|
||||
Some(MMapSizePolicy::Checked(expected)) => {
|
||||
let actual = self.size_of_underlying_data()?;
|
||||
if expected != actual {
|
||||
return Err(E::IncorrectSize { expected, actual });
|
||||
}
|
||||
expected
|
||||
}
|
||||
};
|
||||
|
||||
// Cast the size of the mapping to be used to usize, raising an error if the
|
||||
// requested_size can not be represented as usize (this should never happen in general,
|
||||
// but it could conceivably be thrown on 32 bit systems when very large mappings (>= 4GB)
|
||||
// are requested
|
||||
let len = requested_size.try_into().map_err(|err| {
|
||||
let max_supported_len = usize::MAX;
|
||||
E::OutOfBounds {
|
||||
err,
|
||||
requested_len: requested_size,
|
||||
max_supported_len,
|
||||
}
|
||||
})?;
|
||||
|
||||
let ptr = unsafe { mmap(null_mut(), len, prot, flags, self, 0) };
|
||||
let ptr = ptr.map_err(E::MMapError)?;
|
||||
let ptr = unsafe { MappedSegment::from_raw_parts(ptr.cast(), len) };
|
||||
|
||||
Ok(ptr)
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents exclusive ownership of a memory segment mapped into memory
|
||||
///
|
||||
/// Automatically unmaps the memory segment as this goes out of scope
|
||||
///
|
||||
/// # Panic
|
||||
///
|
||||
/// If the munmap(2) call fails, the destructor will panic. You can avoid this and use explicit
|
||||
/// error handling by calling [Self::unmap] instead
|
||||
#[derive(Debug)]
|
||||
pub struct MappedSegment {
|
||||
/// The location of the memory segment
|
||||
ptr: *mut u8,
|
||||
/// Length of the segment in bytes
|
||||
len: usize,
|
||||
}
|
||||
|
||||
unsafe impl Send for MappedSegment {}
|
||||
|
||||
impl MappedSegment {
|
||||
/// Construct a new [Self] from a pointer and a length
|
||||
///
|
||||
/// `ptr` is the address of the memory segment and `len` is its length in bytes
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// It is the responsibility of the caller to make sure that any pointer P such that `P = ptr.add(n)`,
|
||||
/// where `n < len` is a valid pointer. See #safety in [std::ptr].
|
||||
pub unsafe fn from_raw_parts(ptr: *mut u8, len: usize) -> Self {
|
||||
Self { ptr, len }
|
||||
}
|
||||
|
||||
/// Decompose a MappedSegment into its raw components: Pointer and length
|
||||
pub fn into_raw_parts(self) -> (*mut u8, usize) {
|
||||
let r = (self.ptr(), self.len());
|
||||
std::mem::forget(self);
|
||||
r
|
||||
}
|
||||
|
||||
/// The location of the memory segment
|
||||
pub fn ptr(&self) -> *mut u8 {
|
||||
self.ptr
|
||||
}
|
||||
|
||||
/// Length of the segment in bytes
|
||||
#[allow(clippy::len_without_is_empty)]
|
||||
pub fn len(&self) -> usize {
|
||||
self.len
|
||||
}
|
||||
|
||||
/// Release the memory segment
|
||||
///
|
||||
/// Compared to using the destructor which panics if unmapping fails, this allows explicit error handling to be used.
|
||||
///
|
||||
/// If this returns an error, the memory segment has not been freed. The values from
|
||||
/// [Self::into_raw_parts()] are returned as part of the error value so the caller has
|
||||
/// some chance to free the memory some other way (or to leak it if they so choose)
|
||||
pub fn unmap(self) -> Result<(), (rustix::io::Errno, *mut u8, usize)> {
|
||||
let (ptr, len) = self.into_raw_parts();
|
||||
let res = unsafe { rustix::mm::munmap(ptr.cast(), len) };
|
||||
res.map_err(|errno| (errno, ptr, len))
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for MappedSegment {
|
||||
fn drop(&mut self) {
|
||||
let mut owned = MappedSegment {
|
||||
ptr: null_mut(),
|
||||
len: 0,
|
||||
};
|
||||
std::mem::swap(self, &mut owned);
|
||||
if let Err((errno, _ptr, _len)) = owned.unmap() {
|
||||
panic!("Failed to unmap MappedSegment: {errno:?}")
|
||||
}
|
||||
}
|
||||
}
|
||||
4
util/src/secret_memory/mod.rs
Normal file
4
util/src/secret_memory/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
//! Utilities for allocating secret memory
|
||||
|
||||
pub mod fd;
|
||||
pub mod mmap;
|
||||
107
util/src/sync/atomic/abstract_atomic.rs
Normal file
107
util/src/sync/atomic/abstract_atomic.rs
Normal file
@@ -0,0 +1,107 @@
|
||||
//! Traits for types with atomic access semantics
|
||||
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
/// A trait for types with atomic access semantics
|
||||
///
|
||||
/// Using this trait allows us to achieve two goals:
|
||||
///
|
||||
/// 1. We can implement atomic semantics for types where
|
||||
/// there is no platform support for atomic semantics
|
||||
/// (e.g. by using a Mutex)
|
||||
/// 2. We can reuse implementations of concurrent data structures
|
||||
/// efficiently for the non-concurrent case. E.g. we could
|
||||
/// build a thread-local ring buffer using
|
||||
/// [crate::ringbuf::concurrent::framework::ConcurrentPipeWriter]/
|
||||
/// [crate::ringbuf::concurrent::framework::ConcurrentPipeReader]
|
||||
/// by supplying some sort `Immediate<u64>` type for the atomics
|
||||
/// in [crate::ringbuf::concurrent::framework::ConcurrentPipeCore] that implements
|
||||
/// this trait by using a [std::cell::Cell]. It may seem counter
|
||||
/// intuitive, but this setup implements perfectly fine atomic-appearing
|
||||
/// semantics just as long as the cell is thread-local.
|
||||
pub trait AbstractAtomic<T> {
|
||||
/// Like [std::sync::atomic::AtomicU64::load()]
|
||||
fn load(&self, order: Ordering) -> T;
|
||||
|
||||
/// Like [std::sync::atomic::AtomicU64::compare_exchange_weak()].
|
||||
///
|
||||
/// The default implementation just calls [AbstractAtomic::compare_exchange()].
|
||||
fn compare_exchange_weak(
|
||||
&self,
|
||||
current: T,
|
||||
new: T,
|
||||
success: Ordering,
|
||||
failure: Ordering,
|
||||
) -> Result<T, T> {
|
||||
self.compare_exchange(current, new, success, failure)
|
||||
}
|
||||
|
||||
/// Like [std::sync::atomic::AtomicU64::compare_exchange()].
|
||||
fn compare_exchange(
|
||||
&self,
|
||||
current: T,
|
||||
new: T,
|
||||
success: Ordering,
|
||||
failure: Ordering,
|
||||
) -> Result<T, T>;
|
||||
}
|
||||
|
||||
/// Implements a default type for [AbstractAtomic]
|
||||
///
|
||||
/// This in essence is to [AbstractAtomic], as [std::ops::Deref] is to [std::borrow::Borrow];
|
||||
/// the same functionality, except with a
|
||||
pub trait AbstractAtomicType: AbstractAtomic<Self::ValueType> {
|
||||
/// The underlying atomic value
|
||||
type ValueType;
|
||||
}
|
||||
|
||||
/// Implementing [AbstractAtomic] and [AbstractAtomicType] for standard atomics
|
||||
macro_rules! impl_abstract_atomic_for_atomic {
|
||||
($($Atomic:ty : $Value:ty),*) => {
|
||||
$(
|
||||
impl AbstractAtomicType for $Atomic {
|
||||
type ValueType = $Value;
|
||||
}
|
||||
|
||||
impl AbstractAtomic<$Value> for $Atomic {
|
||||
fn load(&self, order: Ordering) -> $Value {
|
||||
<$Atomic>::load(&self, order)
|
||||
}
|
||||
|
||||
fn compare_exchange_weak(
|
||||
&self,
|
||||
current: $Value,
|
||||
new: $Value,
|
||||
success: Ordering,
|
||||
failure: Ordering,
|
||||
) -> Result<$Value, $Value> {
|
||||
<$Atomic>::compare_exchange_weak(&self, current, new, success, failure)
|
||||
}
|
||||
|
||||
fn compare_exchange(
|
||||
&self,
|
||||
current: $Value,
|
||||
new: $Value,
|
||||
success: Ordering,
|
||||
failure: Ordering,
|
||||
) -> Result<$Value, $Value> {
|
||||
<$Atomic>::compare_exchange(&self, current, new, success, failure)
|
||||
}
|
||||
}
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
impl_abstract_atomic_for_atomic! {
|
||||
std::sync::atomic::AtomicBool: bool,
|
||||
std::sync::atomic::AtomicI8: i8,
|
||||
std::sync::atomic::AtomicI16: i16,
|
||||
std::sync::atomic::AtomicI32: i32,
|
||||
std::sync::atomic::AtomicI64: i64,
|
||||
std::sync::atomic::AtomicIsize: isize,
|
||||
std::sync::atomic::AtomicU8: u8,
|
||||
std::sync::atomic::AtomicU16: u16,
|
||||
std::sync::atomic::AtomicU32: u32,
|
||||
std::sync::atomic::AtomicU64: u64,
|
||||
std::sync::atomic::AtomicUsize: usize
|
||||
}
|
||||
5
util/src/sync/atomic/mod.rs
Normal file
5
util/src/sync/atomic/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
//! Synchronization using atomic semantics for multi-threaded programs.
|
||||
//!
|
||||
//! Analogous to [std::sync::atomic].
|
||||
|
||||
pub mod abstract_atomic;
|
||||
5
util/src/sync/mod.rs
Normal file
5
util/src/sync/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
//! Synchronization for multi-threaded programs.
|
||||
//!
|
||||
//! Analogous to [std::sync].
|
||||
|
||||
pub mod atomic;
|
||||
123
util/tests/pipe.rs
Normal file
123
util/tests/pipe.rs
Normal file
@@ -0,0 +1,123 @@
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
sync::{atomic::AtomicU64, Arc},
|
||||
thread,
|
||||
};
|
||||
|
||||
use rosenpass_util::{
|
||||
int::u64uint::U64USizeRangeExt,
|
||||
ipc::shm::{ringbuf::shm_pipe, SharedMemorySegment},
|
||||
ringbuf::sched::{Diff, OperationType, RingBufferScheduler},
|
||||
};
|
||||
|
||||
macro_rules! dbg_print {
|
||||
($($arg:tt)*) => {{
|
||||
use std::io::Write;
|
||||
let stderr = std::io::stderr();
|
||||
let mut stderr = stderr.lock();
|
||||
// writeln!(stderr, $($arg)*).unwrap()
|
||||
}};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pipe_test() -> anyhow::Result<()> {
|
||||
let (fd, reg1) = SharedMemorySegment::create(1024)?;
|
||||
let reg2 = unsafe { SharedMemorySegment::from_fd(fd, 1024) }?;
|
||||
|
||||
dbg_print!("Regions {:?} {:?}", reg1.ptr(), reg2.ptr());
|
||||
|
||||
let mut buf = b"______________________________________".to_owned();
|
||||
reg1.volatile_read(&mut buf, 0);
|
||||
dbg_print!(
|
||||
"Region 1 read: `{:?}` `{:?}`",
|
||||
String::from_utf8_lossy(&buf),
|
||||
&buf
|
||||
);
|
||||
|
||||
let mut buf = b"______________________________________".to_owned();
|
||||
reg2.volatile_read(&mut buf, 0);
|
||||
dbg_print!(
|
||||
"Region 1 read: `{:?}` `{:?}`",
|
||||
String::from_utf8_lossy(&buf),
|
||||
&buf
|
||||
);
|
||||
|
||||
dbg_print!("Write to region 1");
|
||||
reg1.volatile_write(0, b"Hello World");
|
||||
|
||||
let mut buf = b"______________________________________".to_owned();
|
||||
reg1.volatile_read(&mut buf, 0);
|
||||
dbg_print!(
|
||||
"Region 1 read: `{:?}` `{:?}`",
|
||||
String::from_utf8_lossy(&buf),
|
||||
&buf
|
||||
);
|
||||
|
||||
let mut buf = b"______________________________________".to_owned();
|
||||
reg2.volatile_read(&mut buf, 0);
|
||||
dbg_print!(
|
||||
"Region 1 read: `{:?}` `{:?}`",
|
||||
String::from_utf8_lossy(&buf),
|
||||
&buf
|
||||
);
|
||||
|
||||
let (mut writer, mut reader) = shm_pipe(1024)?;
|
||||
|
||||
const MSG: &[u8] = b"Hello World\0";
|
||||
const MSG_COUNT: usize = 100000;
|
||||
|
||||
let t = thread::spawn(move || {
|
||||
for _ in 0..MSG_COUNT {
|
||||
let mut buf = MSG;
|
||||
while !buf.is_empty() {
|
||||
let n = writer.write(buf).unwrap();
|
||||
buf = &buf[n..];
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let mut buf = [0u8; 1000];
|
||||
let mut buf_off = 0;
|
||||
let mut msg_no = 0usize;
|
||||
|
||||
'read_data: while msg_no < MSG_COUNT {
|
||||
let mut old_off = buf_off;
|
||||
|
||||
// Read the data from the shared memory buffer
|
||||
buf_off += reader.read(&mut buf[buf_off..])?;
|
||||
|
||||
'scan_again: loop {
|
||||
// Scan the available data for the zero terminator
|
||||
let msg_len = &buf[old_off..buf_off]
|
||||
.iter()
|
||||
.copied()
|
||||
.enumerate()
|
||||
.find(|(_off, c)| *c == 0x0)
|
||||
.map(|(off, _c)| off + old_off + 1);
|
||||
|
||||
// Next iteration, unless the terminator was found
|
||||
let msg_len = match *msg_len {
|
||||
Some(l) => l,
|
||||
None => continue 'read_data,
|
||||
};
|
||||
|
||||
// Register the newly read message
|
||||
msg_no += 1;
|
||||
|
||||
// Check that the message is correctly transferred
|
||||
let msg = &buf[0..msg_len];
|
||||
dbg_print!("CONT {:?}", &buf[..buf_off]);
|
||||
dbg_print!("RECV {msg:?}");
|
||||
assert_eq!(msg, MSG);
|
||||
|
||||
// Move any extra data to the beginning of the buffer and adjust the offsets accordingly
|
||||
buf.copy_within(msg_len..buf_off, 0);
|
||||
old_off = 0;
|
||||
buf_off -= msg_len;
|
||||
}
|
||||
}
|
||||
|
||||
t.join().unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -12,7 +12,7 @@ use tokio::task;
|
||||
use anyhow::{bail, ensure, Result};
|
||||
use clap::{ArgGroup, Parser};
|
||||
|
||||
use rosenpass_util::fd::claim_fd;
|
||||
use rosenpass_util::rustix::claim_fd;
|
||||
use rosenpass_wireguard_broker::api::msgs;
|
||||
|
||||
/// Command-line arguments for configuring the socket handler
|
||||
|
||||
Reference in New Issue
Block a user