Compare commits

..

9 Commits

Author SHA1 Message Date
Karolin Varner
c4eea31c5d stash 2025-11-01 21:27:51 +01:00
Karolin Varner
9fd32086ea stash 2025-11-01 21:14:08 +01:00
Karolin Varner
63511465de stash 2025-11-01 20:49:47 +01:00
Karolin Varner
0c960d57bc stasg 2025-11-01 20:49:38 +01:00
Karolin Varner
8f276f70a6 foo 2025-11-01 20:35:10 +01:00
Karolin Varner
9580961dd9 stash 2025-11-01 18:07:45 +01:00
Karolin Varner
1a51478e89 chore: Split rosenpass_util::rustix into multiple files 2025-09-20 17:26:10 +02:00
Karolin Varner
5b14ef8065 chore: Rename rosenpass_util::{fd -> rustix} 2025-09-20 17:26:10 +02:00
Karolin Varner
a796bdd2e7 stash 2025-09-20 17:26:10 +02:00
67 changed files with 4641 additions and 529 deletions

52
Cargo.lock generated
View File

@@ -212,7 +212,7 @@ version = "0.68.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "726e4313eb6ec35d2730258ad4e15b547ee75d6afaa1361a922e78e59b7d8078"
dependencies = [
"bitflags 2.8.0",
"bitflags 2.9.3",
"cexpr",
"clang-sys",
"lazy_static",
@@ -237,9 +237,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bitflags"
version = "2.8.0"
version = "2.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36"
checksum = "34efbcccd345379ca2868b2b2c9d3782e9cc58ba87bc7d79d5b53d9c9ae6f25d"
[[package]]
name = "blake2"
@@ -818,9 +818,9 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]]
name = "errno"
version = "0.3.10"
version = "0.3.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d"
checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad"
dependencies = [
"libc",
"windows-sys 0.59.0",
@@ -1193,7 +1193,7 @@ version = "0.7.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4"
dependencies = [
"bitflags 2.8.0",
"bitflags 2.9.3",
"cfg-if",
"libc",
]
@@ -1453,7 +1453,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
"cfg-if",
"windows-targets 0.48.5",
"windows-targets 0.52.6",
]
[[package]]
@@ -1673,7 +1673,7 @@ version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053"
dependencies = [
"bitflags 2.8.0",
"bitflags 2.9.3",
"cfg-if",
"libc",
]
@@ -1965,7 +1965,7 @@ checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94"
dependencies = [
"rand_chacha 0.9.0",
"rand_core 0.9.3",
"zerocopy 0.8.24",
"zerocopy 0.8.27",
]
[[package]]
@@ -2032,7 +2032,7 @@ version = "0.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82b568323e98e49e2a0899dcee453dd679fae22d69adf9b11dd508d1549b7e2f"
dependencies = [
"bitflags 2.8.0",
"bitflags 2.9.3",
]
[[package]]
@@ -2115,7 +2115,7 @@ dependencies = [
"thiserror 1.0.69",
"toml",
"uds",
"zerocopy 0.7.35",
"zerocopy 0.8.27",
"zeroize",
]
@@ -2257,17 +2257,23 @@ version = "0.1.0"
dependencies = [
"anyhow",
"base64ct",
"bitflags 2.9.3",
"errno",
"libc",
"libcrux-test-utils",
"log",
"mio",
"num-traits",
"rosenpass-to",
"rustix",
"static_assertions",
"tempfile",
"thiserror 1.0.69",
"tinyvec",
"tokio",
"typenum",
"uds",
"zerocopy 0.7.35",
"zerocopy 0.8.27",
"zeroize",
]
@@ -2292,7 +2298,7 @@ dependencies = [
"thiserror 1.0.69",
"tokio",
"wireguard-uapi",
"zerocopy 0.7.35",
"zerocopy 0.8.27",
]
[[package]]
@@ -2340,7 +2346,7 @@ version = "0.38.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
dependencies = [
"bitflags 2.8.0",
"bitflags 2.9.3",
"errno",
"libc",
"linux-raw-sys",
@@ -2710,6 +2716,12 @@ dependencies = [
"serde_json",
]
[[package]]
name = "tinyvec"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa"
[[package]]
name = "tokio"
version = "1.47.0"
@@ -3276,7 +3288,7 @@ version = "0.33.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c"
dependencies = [
"bitflags 2.8.0",
"bitflags 2.9.3",
]
[[package]]
@@ -3303,11 +3315,11 @@ dependencies = [
[[package]]
name = "zerocopy"
version = "0.8.24"
version = "0.8.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2586fea28e186957ef732a5f8b3be2da217d65c5969d4b1e17f973ebbe876879"
checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c"
dependencies = [
"zerocopy-derive 0.8.24",
"zerocopy-derive 0.8.27",
]
[[package]]
@@ -3323,9 +3335,9 @@ dependencies = [
[[package]]
name = "zerocopy-derive"
version = "0.8.24"
version = "0.8.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a996a8f63c5c4448cd959ac1bab0aaa3306ccfd060472f85943ee0750f0169be"
checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831"
dependencies = [
"proc-macro2",
"quote",

View File

@@ -66,7 +66,7 @@ chacha20poly1305 = { version = "0.10.1", default-features = false, features = [
"std",
"heapless",
] }
zerocopy = { version = "0.7.35", features = ["derive"] }
zerocopy = { version = "0.8.27", features = ["derive"] }
home = "=0.5.9" # 5.11 requires rustc 1.81
derive_builder = "0.20.1"
tokio = { version = "1.46", features = ["macros", "rt-multi-thread"] }
@@ -80,6 +80,7 @@ hex-literal = { version = "0.4.1" }
hex = { version = "0.4.3" }
heck = { version = "0.5.0" }
libc = { version = "0.2" }
errno = { version = "0.3.13" }
uds = { git = "https://github.com/rosenpass/uds" }
lazy_static = "1.5"
@@ -95,6 +96,7 @@ criterion = "0.5.1"
allocator-api2-tests = "0.2.15"
procspawn = { version = "1.0.1", features = ["test-support"] }
serde_json = { version = "1.0.140" }
bitflags = "2.9.3"
#Broker dependencies (might need cleanup or changes)
wireguard-uapi = { version = "3.0.0", features = ["xplatform"] }

View File

@@ -97,6 +97,7 @@ rustix = { workspace = true }
serde_json = { workspace = true }
[features]
default = ["experiment_api"]
experiment_cookie_dos_mitigation = []
experiment_memfd_secret = ["rosenpass-wireguard-broker/experiment_memfd_secret"]
experiment_libcrux_all = ["rosenpass-ciphers/experiment_libcrux_all"]

View File

@@ -6,12 +6,12 @@ use std::{borrow::BorrowMut, collections::VecDeque, os::fd::OwnedFd};
use anyhow::Context;
use rosenpass_to::{ops::copy_slice, To};
use rosenpass_util::{
fd::FdIo,
functional::{run, ApplyExt},
io::ReadExt,
mem::DiscardResultExt,
mio::UnixStreamExt,
result::OkExt,
rustix::FdIo,
};
use rosenpass_wireguard_broker::brokers::mio_client::MioBrokerClient;
@@ -243,7 +243,7 @@ where
.context("Invalid request socket missing.")?;
// TODO: We need to have this outside linux
#[cfg(target_os = "linux")]
rosenpass_util::fd::GetSocketProtocol::demand_udp_socket(&sock)?;
rosenpass_util::rustix::GetSocketProtocol::demand_udp_socket(&sock)?;
let sock = std::net::UdpSocket::from(sock);
sock.set_nonblocking(true)?;
mio::net::UdpSocket::from_std(sock).ok()
@@ -316,7 +316,7 @@ where
use crate::app_server::BrokerStorePtr;
//
use rosenpass_secret_memory::Public;
use zerocopy::AsBytes;
use zerocopy::IntoBytes;
(self.app_server().brokers.store.len() - 1)
.apply(|x| x as u64)
.apply(|x| Public::from_slice(x.as_bytes()))

View File

@@ -1,4 +1,4 @@
use zerocopy::{ByteSlice, Ref};
use zerocopy::{Ref, SplitByteSlice};
use rosenpass_util::zerocopy::{RefMaker, ZerocopySliceExt};
@@ -7,7 +7,7 @@ use super::{
ResponseMsgType, ResponseRef, SupplyKeypairRequest, SupplyKeypairResponse,
};
pub trait ByteSliceRefExt: ByteSlice {
pub trait ByteSliceRefExt: SplitByteSlice {
/// Shorthand for the typed use of [ZerocopySliceExt::zk_ref_maker].
fn msg_type_maker(self) -> RefMaker<Self, RawMsgType> {
self.zk_ref_maker()
@@ -259,4 +259,4 @@ pub trait ByteSliceRefExt: ByteSlice {
}
}
impl<B: ByteSlice> ByteSliceRefExt for B {}
impl<B: SplitByteSlice> ByteSliceRefExt for B {}

View File

@@ -1,4 +1,4 @@
use zerocopy::{ByteSliceMut, Ref};
use zerocopy::{Ref, SplitByteSliceMut};
use rosenpass_util::zerocopy::RefMaker;
@@ -35,7 +35,7 @@ pub trait Message {
/// # Examples
///
/// See [crate::api::PingRequest::setup]
fn setup<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>>;
fn setup<B: SplitByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>>;
}
/// Additional convenience functions for working with [rosenpass_util::zerocopy::RefMaker]
@@ -45,7 +45,7 @@ pub trait ZerocopyResponseMakerSetupMessageExt<B, T> {
impl<B, T> ZerocopyResponseMakerSetupMessageExt<B, T> for RefMaker<B, T>
where
B: ByteSliceMut,
B: SplitByteSliceMut,
T: Message,
{
/// Initialize the message using [Message::setup].

View File

@@ -1,6 +1,6 @@
use hex_literal::hex;
use rosenpass_util::zerocopy::RefMaker;
use zerocopy::ByteSlice;
use zerocopy::SplitByteSlice;
use crate::RosenpassError::{self, InvalidApiMessageType};
@@ -169,12 +169,12 @@ pub trait RefMakerRawMsgTypeExt {
fn parse_response_msg_type(self) -> anyhow::Result<ResponseMsgType>;
}
impl<B: ByteSlice> RefMakerRawMsgTypeExt for RefMaker<B, RawMsgType> {
impl<B: SplitByteSlice> RefMakerRawMsgTypeExt for RefMaker<B, RawMsgType> {
fn parse_request_msg_type(self) -> anyhow::Result<RequestMsgType> {
Ok(self.parse()?.read().try_into()?)
Ok(zerocopy::Ref::read(&self.parse()?).try_into()?)
}
fn parse_response_msg_type(self) -> anyhow::Result<ResponseMsgType> {
Ok(self.parse()?.read().try_into()?)
Ok(zerocopy::Ref::<B, u128>::read(&self.parse()?).try_into()?)
}
}

View File

@@ -1,5 +1,5 @@
use rosenpass_util::zerocopy::ZerocopyMutSliceExt;
use zerocopy::{AsBytes, ByteSliceMut, FromBytes, FromZeroes, Ref};
use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Ref, SplitByteSliceMut};
use super::{Message, RawMsgType, RequestMsgType, ResponseMsgType};
@@ -12,8 +12,8 @@ pub const MAX_REQUEST_FDS: usize = 2;
/// Message envelope for API messages
#[repr(packed)]
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
pub struct Envelope<M: AsBytes + FromBytes> {
#[derive(Debug, Copy, Clone, Hash, IntoBytes, FromBytes, PartialEq, Eq, Immutable, KnownLayout)]
pub struct Envelope<M: IntoBytes + FromBytes + Immutable + KnownLayout> {
/// Which message this is
pub msg_type: RawMsgType,
/// The actual Paylod
@@ -27,7 +27,7 @@ pub type ResponseEnvelope<M> = Envelope<M>;
#[allow(missing_docs)]
#[repr(packed)]
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
#[derive(Debug, Copy, Clone, Hash, IntoBytes, FromBytes, PartialEq, Eq, Immutable, KnownLayout)]
pub struct PingRequestPayload {
/// Randomly generated connection id
pub echo: [u8; 256],
@@ -55,7 +55,7 @@ impl Message for PingRequest {
}
}
fn setup<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
fn setup<B: SplitByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
let mut r: Ref<B, Self> = buf.zk_zeroized()?;
r.init();
Ok(r)
@@ -68,7 +68,7 @@ impl Message for PingRequest {
#[allow(missing_docs)]
#[repr(packed)]
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
#[derive(Debug, Copy, Clone, Hash, IntoBytes, FromBytes, PartialEq, Eq, Immutable, KnownLayout)]
pub struct PingResponsePayload {
/// Randomly generated connection id
pub echo: [u8; 256],
@@ -96,7 +96,7 @@ impl Message for PingResponse {
}
}
fn setup<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
fn setup<B: SplitByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
let mut r: Ref<B, Self> = buf.zk_zeroized()?;
r.init();
Ok(r)
@@ -109,7 +109,7 @@ impl Message for PingResponse {
#[allow(missing_docs)]
#[repr(packed)]
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
#[derive(Debug, Copy, Clone, Hash, IntoBytes, FromBytes, PartialEq, Eq, Immutable, KnownLayout)]
pub struct SupplyKeypairRequestPayload {}
#[allow(missing_docs)]
@@ -140,7 +140,7 @@ impl Message for SupplyKeypairRequest {
}
}
fn setup<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
fn setup<B: SplitByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
let mut r: Ref<B, Self> = buf.zk_zeroized()?;
r.init();
Ok(r)
@@ -169,7 +169,7 @@ pub mod supply_keypair_response_status {
#[allow(missing_docs)]
#[repr(packed)]
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
#[derive(Debug, Copy, Clone, Hash, IntoBytes, FromBytes, PartialEq, Eq, Immutable, KnownLayout)]
pub struct SupplyKeypairResponsePayload {
#[allow(missing_docs)]
pub status: u128,
@@ -197,7 +197,7 @@ impl Message for SupplyKeypairResponse {
}
}
fn setup<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
fn setup<B: SplitByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
let mut r: Ref<B, Self> = buf.zk_zeroized()?;
r.init();
Ok(r)
@@ -210,7 +210,7 @@ impl Message for SupplyKeypairResponse {
#[allow(missing_docs)]
#[repr(packed)]
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
#[derive(Debug, Copy, Clone, Hash, IntoBytes, FromBytes, PartialEq, Eq, Immutable, KnownLayout)]
pub struct AddListenSocketRequestPayload {}
#[allow(missing_docs)]
@@ -241,7 +241,7 @@ impl Message for AddListenSocketRequest {
}
}
fn setup<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
fn setup<B: SplitByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
let mut r: Ref<B, Self> = buf.zk_zeroized()?;
r.init();
Ok(r)
@@ -264,7 +264,7 @@ pub mod add_listen_socket_response_status {
#[allow(missing_docs)]
#[repr(packed)]
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
#[derive(Debug, Copy, Clone, Hash, IntoBytes, FromBytes, PartialEq, Eq, Immutable, KnownLayout)]
pub struct AddListenSocketResponsePayload {
pub status: u128,
}
@@ -291,7 +291,7 @@ impl Message for AddListenSocketResponse {
}
}
fn setup<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
fn setup<B: SplitByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
let mut r: Ref<B, Self> = buf.zk_zeroized()?;
r.init();
Ok(r)
@@ -304,7 +304,7 @@ impl Message for AddListenSocketResponse {
#[allow(missing_docs)]
#[repr(packed)]
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
#[derive(Debug, Copy, Clone, Hash, IntoBytes, FromBytes, PartialEq, Eq, Immutable, KnownLayout)]
pub struct AddPskBrokerRequestPayload {}
#[allow(missing_docs)]
@@ -336,7 +336,7 @@ impl Message for AddPskBrokerRequest {
}
}
fn setup<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
fn setup<B: SplitByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
let mut r: Ref<B, Self> = buf.zk_zeroized()?;
r.init();
Ok(r)
@@ -359,7 +359,7 @@ pub mod add_psk_broker_response_status {
#[allow(missing_docs)]
#[repr(packed)]
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
#[derive(Debug, Copy, Clone, Hash, IntoBytes, FromBytes, PartialEq, Eq, Immutable, KnownLayout)]
pub struct AddPskBrokerResponsePayload {
pub status: u128,
}
@@ -386,7 +386,7 @@ impl Message for AddPskBrokerResponse {
}
}
fn setup<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
fn setup<B: SplitByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
let mut r: Ref<B, Self> = buf.zk_zeroized()?;
r.init();
Ok(r)

View File

@@ -1,6 +1,6 @@
use anyhow::ensure;
use anyhow::{anyhow, ensure};
use zerocopy::{ByteSlice, ByteSliceMut, Ref};
use zerocopy::{IntoBytes, Ref, SplitByteSlice, SplitByteSliceMut};
use super::{ByteSliceRefExt, MessageAttributes, PingRequest, RequestMsgType};
@@ -13,14 +13,14 @@ struct RequestRefMaker<B> {
msg_type: RequestMsgType,
}
impl<B: ByteSlice> RequestRef<B> {
impl<B: SplitByteSlice> RequestRef<B> {
/// Produce a [RequestRef] from a raw message buffer,
/// reading the type from the buffer
///
/// # Examples
///
/// ```
/// use zerocopy::AsBytes;
/// use zerocopy::IntoBytes;
///
/// use rosenpass::api::{PingRequest, RequestRef, RequestMsgType};
///
@@ -95,7 +95,7 @@ impl<B> From<Ref<B, super::AddPskBrokerRequest>> for RequestRef<B> {
}
}
impl<B: ByteSlice> RequestRefMaker<B> {
impl<B: SplitByteSlice> RequestRefMaker<B> {
fn new(buf: B) -> anyhow::Result<Self> {
let msg_type = buf.deref().request_msg_type_from_prefix()?;
Ok(Self { buf, msg_type })
@@ -125,7 +125,9 @@ impl<B: ByteSlice> RequestRefMaker<B> {
self.ensure_fit()?;
let point = self.target_size();
let Self { buf, msg_type } = self;
let (buf, _) = buf.split_at(point);
let (buf, _) = buf
.split_at(point)
.map_err(|_| anyhow!("Failed to split buffer"))?;
Ok(Self { buf, msg_type })
}
@@ -134,7 +136,9 @@ impl<B: ByteSlice> RequestRefMaker<B> {
self.ensure_fit()?;
let point = self.buf.len() - self.target_size();
let Self { buf, msg_type } = self;
let (buf, _) = buf.split_at(point);
let (buf, _) = buf
.split_at(point)
.map_err(|_| anyhow!("Failed to split buffer"))?;
Ok(Self { buf, msg_type })
}
@@ -159,7 +163,7 @@ pub enum RequestRef<B> {
impl<B> RequestRef<B>
where
B: ByteSlice,
B: SplitByteSlice,
{
/// Access the byte data of this reference
///
@@ -168,25 +172,25 @@ where
/// See [Self::parse].
pub fn bytes(&self) -> &[u8] {
match self {
Self::Ping(r) => r.bytes(),
Self::SupplyKeypair(r) => r.bytes(),
Self::AddListenSocket(r) => r.bytes(),
Self::AddPskBroker(r) => r.bytes(),
Self::Ping(r) => r.as_bytes(),
Self::SupplyKeypair(r) => r.as_bytes(),
Self::AddListenSocket(r) => r.as_bytes(),
Self::AddPskBroker(r) => r.as_bytes(),
}
}
}
impl<B> RequestRef<B>
where
B: ByteSliceMut,
B: SplitByteSliceMut,
{
/// Access the byte data of this reference; mutably
pub fn bytes_mut(&mut self) -> &[u8] {
match self {
Self::Ping(r) => r.bytes_mut(),
Self::SupplyKeypair(r) => r.bytes_mut(),
Self::AddListenSocket(r) => r.bytes_mut(),
Self::AddPskBroker(r) => r.bytes_mut(),
Self::Ping(r) => r.as_mut_bytes(),
Self::SupplyKeypair(r) => r.as_mut_bytes(),
Self::AddListenSocket(r) => r.as_mut_bytes(),
Self::AddPskBroker(r) => r.as_mut_bytes(),
}
}
}

View File

@@ -1,7 +1,7 @@
use rosenpass_util::zerocopy::{
RefMaker, ZerocopyEmancipateExt, ZerocopyEmancipateMutExt, ZerocopySliceExt,
};
use zerocopy::{ByteSlice, ByteSliceMut, Ref};
use zerocopy::{Immutable, KnownLayout, Ref, SplitByteSlice, SplitByteSliceMut};
use super::{Message, PingRequest, PingResponse};
use super::{RequestRef, ResponseRef, ZerocopyResponseMakerSetupMessageExt};
@@ -9,27 +9,27 @@ use super::{RequestRef, ResponseRef, ZerocopyResponseMakerSetupMessageExt};
/// Extension trait for [Message]s that are requests messages
pub trait RequestMsg: Sized + Message {
/// The response message belonging to this request message
type ResponseMsg: ResponseMsg;
type ResponseMsg: ResponseMsg + Immutable + KnownLayout;
/// Construct a response make for this particular message
fn zk_response_maker<B: ByteSlice>(buf: B) -> RefMaker<B, Self::ResponseMsg> {
fn zk_response_maker<B: SplitByteSlice>(buf: B) -> RefMaker<B, Self::ResponseMsg> {
buf.zk_ref_maker()
}
/// Setup a response maker (through [Message::setup]) for this request message type
fn setup_response<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self::ResponseMsg>> {
fn setup_response<B: SplitByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self::ResponseMsg>> {
Self::zk_response_maker(buf).setup_msg()
}
/// Setup a response maker from a buffer prefix (through [Message::setup]) for this request message type
fn setup_response_from_prefix<B: ByteSliceMut>(
fn setup_response_from_prefix<B: SplitByteSliceMut>(
buf: B,
) -> anyhow::Result<Ref<B, Self::ResponseMsg>> {
Self::zk_response_maker(buf).from_prefix()?.setup_msg()
}
/// Setup a response maker from a buffer suffix (through [Message::setup]) for this request message type
fn setup_response_from_suffix<B: ByteSliceMut>(
fn setup_response_from_suffix<B: SplitByteSliceMut>(
buf: B,
) -> anyhow::Result<Ref<B, Self::ResponseMsg>> {
Self::zk_response_maker(buf).from_prefix()?.setup_msg()
@@ -125,8 +125,8 @@ impl<B1, B2> From<AddPskBrokerPair<B1, B2>> for RequestResponsePair<B1, B2> {
impl<B1, B2> RequestResponsePair<B1, B2>
where
B1: ByteSlice,
B2: ByteSlice,
B1: SplitByteSlice,
B2: SplitByteSlice,
{
/// Returns a tuple to both the request and the response message
pub fn both(&self) -> (RequestRef<&[u8]>, ResponseRef<&[u8]>) {
@@ -167,8 +167,8 @@ where
impl<B1, B2> RequestResponsePair<B1, B2>
where
B1: ByteSliceMut,
B2: ByteSliceMut,
B1: SplitByteSliceMut,
B2: SplitByteSliceMut,
{
/// Returns a mutable tuple to both the request and the response message
pub fn both_mut(&mut self) -> (RequestRef<&mut [u8]>, ResponseRef<&mut [u8]>) {

View File

@@ -1,7 +1,7 @@
// TODO: This is copied verbatim from ResponseRef…not pretty
use anyhow::ensure;
use anyhow::{anyhow, ensure};
use zerocopy::{ByteSlice, ByteSliceMut, Ref};
use zerocopy::{IntoBytes, Ref, SplitByteSlice, SplitByteSliceMut};
use super::{ByteSliceRefExt, MessageAttributes, PingResponse, ResponseMsgType};
@@ -16,14 +16,14 @@ struct ResponseRefMaker<B> {
msg_type: ResponseMsgType,
}
impl<B: ByteSlice> ResponseRef<B> {
impl<B: SplitByteSlice> ResponseRef<B> {
/// Produce a [ResponseRef] from a raw message buffer,
/// reading the type from the buffer
///
/// # Examples
///
/// ```
/// use zerocopy::AsBytes;
/// use zerocopy::IntoBytes;
///
/// use rosenpass::api::{PingResponse, ResponseRef, ResponseMsgType};
/// // Produce the original PingResponse
@@ -99,7 +99,7 @@ impl<B> From<Ref<B, super::AddPskBrokerResponse>> for ResponseRef<B> {
}
}
impl<B: ByteSlice> ResponseRefMaker<B> {
impl<B: SplitByteSlice> ResponseRefMaker<B> {
fn new(buf: B) -> anyhow::Result<Self> {
let msg_type = buf.deref().response_msg_type_from_prefix()?;
Ok(Self { buf, msg_type })
@@ -129,7 +129,9 @@ impl<B: ByteSlice> ResponseRefMaker<B> {
self.ensure_fit()?;
let point = self.target_size();
let Self { buf, msg_type } = self;
let (buf, _) = buf.split_at(point);
let (buf, _) = buf
.split_at(point)
.map_err(|_| anyhow!("Failed to split buffer!"))?;
Ok(Self { buf, msg_type })
}
@@ -138,7 +140,9 @@ impl<B: ByteSlice> ResponseRefMaker<B> {
self.ensure_fit()?;
let point = self.buf.len() - self.target_size();
let Self { buf, msg_type } = self;
let (buf, _) = buf.split_at(point);
let (buf, _) = buf
.split_at(point)
.map_err(|_| anyhow!("Failed to split buffer!"))?;
Ok(Self { buf, msg_type })
}
@@ -163,7 +167,7 @@ pub enum ResponseRef<B> {
impl<B> ResponseRef<B>
where
B: ByteSlice,
B: SplitByteSlice,
{
/// Access the byte data of this reference
///
@@ -172,25 +176,25 @@ where
/// See [Self::parse].
pub fn bytes(&self) -> &[u8] {
match self {
Self::Ping(r) => r.bytes(),
Self::SupplyKeypair(r) => r.bytes(),
Self::AddListenSocket(r) => r.bytes(),
Self::AddPskBroker(r) => r.bytes(),
Self::Ping(r) => r.as_bytes(),
Self::SupplyKeypair(r) => r.as_bytes(),
Self::AddListenSocket(r) => r.as_bytes(),
Self::AddPskBroker(r) => r.as_bytes(),
}
}
}
impl<B> ResponseRef<B>
where
B: ByteSliceMut,
B: SplitByteSliceMut,
{
/// Access the byte data of this reference; mutably
pub fn bytes_mut(&mut self) -> &[u8] {
match self {
Self::Ping(r) => r.bytes_mut(),
Self::SupplyKeypair(r) => r.bytes_mut(),
Self::AddListenSocket(r) => r.bytes_mut(),
Self::AddPskBroker(r) => r.bytes_mut(),
Self::Ping(r) => r.as_mut_bytes(),
Self::SupplyKeypair(r) => r.as_mut_bytes(),
Self::AddListenSocket(r) => r.as_mut_bytes(),
Self::AddPskBroker(r) => r.as_mut_bytes(),
}
}
}

View File

@@ -1,6 +1,6 @@
use super::{ByteSliceRefExt, Message, PingRequest, PingResponse, RequestRef, RequestResponsePair};
use std::{collections::VecDeque, os::fd::OwnedFd};
use zerocopy::{ByteSlice, ByteSliceMut};
use zerocopy::{SplitByteSlice, SplitByteSliceMut};
/// The rosenpass API implementation functions.
///
@@ -152,8 +152,8 @@ pub trait Server {
req_fds: &mut VecDeque<OwnedFd>,
) -> anyhow::Result<()>
where
ReqBuf: ByteSlice,
ResBuf: ByteSliceMut,
ReqBuf: SplitByteSlice,
ResBuf: SplitByteSliceMut,
{
match p {
RequestResponsePair::Ping((req, res)) => self.ping(req, req_fds, res),
@@ -182,8 +182,8 @@ pub trait Server {
res: ResBuf,
) -> anyhow::Result<usize>
where
ReqBuf: ByteSlice,
ResBuf: ByteSliceMut,
ReqBuf: SplitByteSlice,
ResBuf: SplitByteSliceMut,
{
let req = req.parse_request_from_prefix()?;
// TODO: This is not pretty; This match should be moved into RequestRef

View File

@@ -13,7 +13,7 @@ use signal_hook_mio::v1_0 as signal_hook_mio;
use anyhow::{bail, Context, Result};
use derive_builder::Builder;
use log::{error, info, warn};
use zerocopy::AsBytes;
use zerocopy::IntoBytes;
use rosenpass_util::attempt;
use rosenpass_util::fmt::debug::NullDebug;

View File

@@ -26,7 +26,7 @@ use {
command_fds::{CommandFdExt, FdMapping},
log::{error, info},
mio::net::UnixStream,
rosenpass_util::fd::claim_fd,
rosenpass_util::rustix::claim_fd,
rosenpass_wireguard_broker::brokers::mio_client::MioBrokerClient,
rosenpass_wireguard_broker::WireguardBrokerMio,
rustix::net::{socketpair, AddressFamily, SocketFlags, SocketType},

View File

@@ -9,7 +9,7 @@
//! To achieve this we utilize the zerocopy library.
//!
use std::mem::size_of;
use zerocopy::{AsBytes, FromBytes, FromZeroes};
use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
use super::RosenpassError;
use rosenpass_cipher_traits::primitives::{Aead as _, Kem};
@@ -51,7 +51,7 @@ pub type MsgEnvelopeCookie = [u8; COOKIE_SIZE];
///
/// ```
/// use rosenpass::msgs::{Envelope, InitHello};
/// use zerocopy::{AsBytes, FromBytes, Ref, FromZeroes};
/// use zerocopy::{FromZeros, IntoBytes, FromBytes, Ref};
/// use memoffset::offset_of;
///
/// // Zero-initialization
@@ -76,8 +76,8 @@ pub type MsgEnvelopeCookie = [u8; COOKIE_SIZE];
/// assert_eq!(ih3.msg_type, 42);
/// ```
#[repr(packed)]
#[derive(AsBytes, FromBytes, FromZeroes, Clone)]
pub struct Envelope<M: AsBytes + FromBytes> {
#[derive(IntoBytes, FromBytes, Clone, Immutable, KnownLayout)]
pub struct Envelope<M: IntoBytes + FromBytes> {
/// [MsgType] of this message
pub msg_type: u8,
/// Reserved for future use
@@ -106,7 +106,7 @@ pub struct Envelope<M: AsBytes + FromBytes> {
///
/// ```
/// use rosenpass::msgs::{Envelope, InitHello};
/// use zerocopy::{AsBytes, FromBytes, Ref, FromZeroes};
/// use zerocopy::{FromZeros, IntoBytes, FromBytes, Ref};
/// use memoffset::span_of;
///
/// // Zero initialization
@@ -126,7 +126,7 @@ pub struct Envelope<M: AsBytes + FromBytes> {
/// assert_eq!(ih.payload.sidi, [1,2,3,4]);
/// ```
#[repr(packed)]
#[derive(AsBytes, FromBytes, FromZeroes)]
#[derive(IntoBytes, FromBytes, Immutable, KnownLayout)]
pub struct InitHello {
/// Randomly generated connection id
pub sidi: [u8; 4],
@@ -155,7 +155,7 @@ pub struct InitHello {
///
/// ```
/// use rosenpass::msgs::{Envelope, RespHello};
/// use zerocopy::{AsBytes, FromBytes, Ref, FromZeroes};
/// use zerocopy::{FromZeros, IntoBytes, FromBytes, Ref};
/// use memoffset::span_of;
///
/// // Zero initialization
@@ -175,7 +175,7 @@ pub struct InitHello {
/// assert_eq!(ih.payload.sidi, [1,2,3,4]);
/// ```
#[repr(packed)]
#[derive(AsBytes, FromBytes, FromZeroes)]
#[derive(IntoBytes, FromBytes, Immutable, KnownLayout)]
pub struct RespHello {
/// Randomly generated connection id
pub sidr: [u8; 4],
@@ -206,7 +206,7 @@ pub struct RespHello {
///
/// ```
/// use rosenpass::msgs::{Envelope, InitConf};
/// use zerocopy::{AsBytes, FromBytes, Ref, FromZeroes};
/// use zerocopy::{IntoBytes, FromBytes, FromZeros, Ref};
/// use memoffset::span_of;
///
/// // Zero initialization
@@ -226,7 +226,7 @@ pub struct RespHello {
/// assert_eq!(ih.payload.sidi, [1,2,3,4]);
/// ```
#[repr(packed)]
#[derive(AsBytes, FromBytes, FromZeroes, Debug)]
#[derive(IntoBytes, FromBytes, Debug, Immutable, KnownLayout)]
pub struct InitConf {
/// Copied from InitHello
pub sidi: [u8; 4],
@@ -264,7 +264,7 @@ pub struct InitConf {
///
/// ```
/// use rosenpass::msgs::{Envelope, EmptyData};
/// use zerocopy::{AsBytes, FromBytes, Ref, FromZeroes};
/// use zerocopy::{FromZeros, IntoBytes, FromBytes, Ref};
/// use memoffset::span_of;
///
/// // Zero initialization
@@ -284,7 +284,7 @@ pub struct InitConf {
/// assert_eq!(ih.payload.sid, [1,2,3,4]);
/// ```
#[repr(packed)]
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Copy)]
#[derive(IntoBytes, FromBytes, Clone, Copy, Immutable, KnownLayout)]
pub struct EmptyData {
/// Copied from RespHello
pub sid: [u8; 4],
@@ -311,7 +311,7 @@ pub struct EmptyData {
///
/// [Envelope] and [InitHello] contain some extra examples on how to use structures from the [::zerocopy] crate.
#[repr(packed)]
#[derive(AsBytes, FromBytes, FromZeroes)]
#[derive(IntoBytes, FromBytes, Immutable, KnownLayout)]
pub struct Biscuit {
/// H(spki) Ident ifies the initiator
pub pidi: [u8; KEY_LEN],
@@ -336,7 +336,7 @@ pub struct Biscuit {
///
/// [Envelope] and [InitHello] contain some extra examples on how to use structures from the [::zerocopy] crate.
#[repr(packed)]
#[derive(AsBytes, FromBytes, FromZeroes)]
#[derive(IntoBytes, FromBytes, Immutable, KnownLayout)]
pub struct CookieReplyInner {
/// [MsgType] of this message
pub msg_type: u8,
@@ -363,7 +363,7 @@ pub struct CookieReplyInner {
///
/// [Envelope] and [InitHello] contain some extra examples on how to use structures from the [::zerocopy] crate.
#[repr(packed)]
#[derive(AsBytes, FromBytes, FromZeroes)]
#[derive(IntoBytes, FromBytes, Immutable, KnownLayout)]
pub struct CookieReply {
pub inner: CookieReplyInner,
pub padding: [u8; size_of::<Envelope<InitHello>>() - size_of::<CookieReplyInner>()],

View File

@@ -17,7 +17,7 @@ use std::{
use anyhow::{bail, ensure, Context, Result};
use assert_tv::{TestVector, TestVectorNOP};
use memoffset::span_of;
use zerocopy::{AsBytes, FromBytes, Ref};
use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Ref};
use rosenpass_cipher_traits::primitives::{
Aead as _, AeadWithNonceInCiphertext, Kem, KeyedHashInstance,
@@ -413,7 +413,7 @@ pub struct InitiatorHandshake {
///
/// Used as [KnownInitConfResponse] for now cache [EmptyData] (responder confirmation)
/// responses to [InitConf]
pub struct KnownResponse<ResponseType: AsBytes + FromBytes> {
pub struct KnownResponse<ResponseType: IntoBytes + FromBytes> {
/// When the response was initially computed
pub received_at: Timing,
/// Hash of the message that triggered the response; created using
@@ -423,7 +423,7 @@ pub struct KnownResponse<ResponseType: AsBytes + FromBytes> {
pub response: Envelope<ResponseType>,
}
impl<ResponseType: AsBytes + FromBytes> Debug for KnownResponse<ResponseType> {
impl<ResponseType: IntoBytes + FromBytes> Debug for KnownResponse<ResponseType> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("KnownResponse")
.field("received_at", &self.received_at)
@@ -435,7 +435,7 @@ impl<ResponseType: AsBytes + FromBytes> Debug for KnownResponse<ResponseType> {
#[test]
fn known_response_format() {
use zerocopy::FromZeroes;
use zerocopy::FromZeros;
let v = KnownResponse::<[u8; 32]> {
received_at: 42.0,
@@ -464,7 +464,7 @@ pub type KnownResponseHash = Public<16>;
/// # Examples
///
/// ```
/// use zerocopy::FromZeroes;
/// use zerocopy::FromZeros;
/// use rosenpass::protocol::KnownResponseHasher;
/// use rosenpass::msgs::{Envelope, InitConf};
///
@@ -508,7 +508,10 @@ impl KnownResponseHasher {
/// # Panic & Safety
///
/// Panics in case of a problem with this underlying hash function
pub fn hash<Msg: AsBytes + FromBytes>(&self, msg: &Envelope<Msg>) -> KnownResponseHash {
pub fn hash<Msg: IntoBytes + FromBytes + Immutable>(
&self,
msg: &Envelope<Msg>,
) -> KnownResponseHash {
let data = &msg.as_bytes()[span_of!(Envelope<Msg>, msg_type..cookie)];
// This function is only used internally and results are not propagated
// to outside the peer. Thus, it uses SHAKE256 exclusively.
@@ -2042,7 +2045,8 @@ impl CryptoServer {
let mut expected = [0u8; COOKIE_SIZE];
let msg_in = Ref::<&[u8], Envelope<InitHello>>::new(rx_buf)
let msg_in = Ref::<&[u8], Envelope<InitHello>>::from_bytes(rx_buf)
.ok()
.ok_or(RosenpassError::BufferSizeMismatch)?;
expected.copy_from_slice(
&hash_domains::cookie(KeyedHash::keyed_shake256())?
@@ -2187,8 +2191,9 @@ impl CryptoServer {
let peer = match msg_type {
Ok(MsgType::InitHello) => {
let msg_in: Ref<&[u8], Envelope<InitHello>> =
Ref::new(rx_buf).ok_or(RosenpassError::BufferSizeMismatch)?;
let msg_in: Ref<&[u8], Envelope<InitHello>> = Ref::from_bytes(rx_buf)
.ok()
.ok_or(RosenpassError::BufferSizeMismatch)?;
// At this point, we do not know the hash functon used by the peer, thus we try both,
// with a preference for SHAKE256.
@@ -2221,8 +2226,9 @@ impl CryptoServer {
peer
}
Ok(MsgType::RespHello) => {
let msg_in: Ref<&[u8], Envelope<RespHello>> =
Ref::new(rx_buf).ok_or(RosenpassError::BufferSizeMismatch)?;
let msg_in: Ref<&[u8], Envelope<RespHello>> = Ref::from_bytes(rx_buf)
.ok()
.ok_or(RosenpassError::BufferSizeMismatch)?;
let mut msg_out = truncating_cast_into::<Envelope<InitConf>>(tx_buf)?;
let peer = self.handle_resp_hello(&msg_in.payload, &mut msg_out.payload)?;
@@ -2238,8 +2244,9 @@ impl CryptoServer {
peer
}
Ok(MsgType::InitConf) => {
let msg_in: Ref<&[u8], Envelope<InitConf>> =
Ref::new(rx_buf).ok_or(RosenpassError::BufferSizeMismatch)?;
let msg_in: Ref<&[u8], Envelope<InitConf>> = Ref::from_bytes(rx_buf)
.ok()
.ok_or(RosenpassError::BufferSizeMismatch)?;
let mut msg_out = truncating_cast_into::<Envelope<EmptyData>>(tx_buf)?;
@@ -2258,7 +2265,7 @@ impl CryptoServer {
.map(|v| v.response.borrow())
// Invalid! Found peer no with cache in index but the cache does not exist
.unwrap();
copy_slice(cached.as_bytes()).to(msg_out.as_bytes_mut());
copy_slice(cached.as_bytes()).to(msg_out.as_mut_bytes());
peer
}
@@ -2306,14 +2313,16 @@ impl CryptoServer {
peer
}
Ok(MsgType::EmptyData) => {
let msg_in: Ref<&[u8], Envelope<EmptyData>> =
Ref::new(rx_buf).ok_or(RosenpassError::BufferSizeMismatch)?;
let msg_in: Ref<&[u8], Envelope<EmptyData>> = Ref::from_bytes(rx_buf)
.ok()
.ok_or(RosenpassError::BufferSizeMismatch)?;
self.handle_resp_conf(&msg_in, seal_broken.to_string())?
}
Ok(MsgType::CookieReply) => {
let msg_in: Ref<&[u8], CookieReply> =
Ref::new(rx_buf).ok_or(RosenpassError::BufferSizeMismatch)?;
let msg_in: Ref<&[u8], CookieReply> = Ref::from_bytes(rx_buf)
.ok()
.ok_or(RosenpassError::BufferSizeMismatch)?;
let peer = self.handle_cookie_reply(&msg_in)?;
len = 0;
peer
@@ -2354,8 +2363,8 @@ impl CryptoServer {
///
/// To save some code, the function returns the size of the message,
/// but the same could be easily achieved by calling [size_of] with the
/// message type or by calling [AsBytes::as_bytes] on the message reference.
pub fn seal_and_commit_msg<M: AsBytes + FromBytes>(
/// message type or by calling [IntoBytes::as_bytes] on the message reference.
pub fn seal_and_commit_msg<M: IntoBytes + FromBytes + Immutable + KnownLayout>(
&mut self,
peer: PeerPtr,
msg_type: MsgType,
@@ -3078,7 +3087,7 @@ impl IniHsPtr {
impl<M> Envelope<M>
where
M: AsBytes + FromBytes,
M: IntoBytes + FromBytes + Immutable + KnownLayout,
{
/// Internal business logic: Calculate the message authentication code (`mac`) and also append cookie value
pub fn seal(&mut self, peer: PeerPtr, srv: &CryptoServer) -> Result<()> {
@@ -3107,7 +3116,7 @@ where
impl<M> Envelope<M>
where
M: AsBytes + FromBytes,
M: IntoBytes + FromBytes + Immutable + KnownLayout,
{
/// Internal business logic: Check the message authentication code produced by [Self::seal]
pub fn check_seal(&self, srv: &CryptoServer, shake_or_blake: KeyedHash) -> Result<bool> {
@@ -3240,7 +3249,7 @@ impl HandshakeState {
/// out the const generics.
/// - By adding a value parameter of type `PhantomData<TV>`, you can choose
/// `TV` at the call site while allowing the compiler to infer `KEM_*`
/// const generics from `ct` and `pk`.
/// const generics from `ct` and `pk`.
/// - Call like: `encaps_and_mix_with_test_vector(&StaticKem, &mut ct, pk,
/// PhantomData::<TestVectorActive>)?;`
pub fn encaps_and_mix_with_test_vector<
@@ -3322,7 +3331,7 @@ impl HandshakeState {
let test_values: StoreBiscuitTestValues = TV::initialize_values();
let mut biscuit = Secret::<BISCUIT_PT_LEN>::zero(); // pt buffer
let mut biscuit: Ref<&mut [u8], Biscuit> =
Ref::new(biscuit.secret_mut().as_mut_slice()).unwrap();
Ref::from_bytes(biscuit.secret_mut().as_mut_slice()).unwrap();
// calculate pt contents
biscuit
@@ -3384,9 +3393,9 @@ impl HandshakeState {
// Allocate and decrypt the biscuit data
let mut biscuit = Secret::<BISCUIT_PT_LEN>::zero(); // pt buf
let mut biscuit: Ref<&mut [u8], Biscuit> =
Ref::new(biscuit.secret_mut().as_mut_slice()).unwrap();
Ref::from_bytes(biscuit.secret_mut().as_mut_slice()).unwrap();
XAead.decrypt_with_nonce_in_ctxt(
biscuit.as_bytes_mut(),
biscuit.as_mut_bytes(),
bk.get(srv).value.secret(),
&ad,
biscuit_ct,

View File

@@ -2,7 +2,7 @@ use std::{borrow::BorrowMut, fmt::Display, net::SocketAddrV4, ops::DerefMut};
use anyhow::{Context, Result};
use serial_test::serial;
use zerocopy::{AsBytes, FromBytes, FromZeroes};
use zerocopy::{FromBytes, FromZeros, Immutable, IntoBytes, KnownLayout};
use rosenpass_cipher_traits::primitives::Kem;
use rosenpass_ciphers::StaticKem;
@@ -538,10 +538,10 @@ fn init_conf_retransmission(protocol_version: ProtocolVersion) -> anyhow::Result
srv.initiate_handshake(peer, buf.as_mut_slice())?
.discard_result();
let msg = truncating_cast_into::<Envelope<InitHello>>(buf.borrow_mut())?;
Ok(msg.read())
Ok(zerocopy::Ref::read(&msg))
}
fn proc_msg<Rx: AsBytes + FromBytes, Tx: AsBytes + FromBytes>(
fn proc_msg<Rx: IntoBytes + FromBytes + Immutable, Tx: IntoBytes + FromBytes + Immutable>(
srv: &mut CryptoServer,
rx: &Envelope<Rx>,
) -> anyhow::Result<Envelope<Tx>> {
@@ -551,7 +551,7 @@ fn init_conf_retransmission(protocol_version: ProtocolVersion) -> anyhow::Result
.context("Failed to produce RespHello message")?
.discard_result();
let msg = truncating_cast_into::<Envelope<Tx>>(buf.borrow_mut())?;
Ok(msg.read())
Ok(zerocopy::Ref::read(&msg))
}
fn proc_init_hello(
@@ -582,17 +582,21 @@ fn init_conf_retransmission(protocol_version: ProtocolVersion) -> anyhow::Result
}
// TODO: Implement Clone on our message types
fn clone_msg<Msg: AsBytes + FromBytes>(msg: &Msg) -> anyhow::Result<Msg> {
Ok(truncating_cast_into_nomut::<Msg>(msg.as_bytes())?.read())
fn clone_msg<Msg: IntoBytes + FromBytes + Immutable + KnownLayout>(
msg: &Msg,
) -> anyhow::Result<Msg> {
Ok(zerocopy::Ref::read(&truncating_cast_into_nomut::<Msg>(
msg.as_bytes(),
)?))
}
fn break_payload<Msg: AsBytes + FromBytes>(
fn break_payload<Msg: IntoBytes + FromBytes + Immutable + KnownLayout>(
srv: &mut CryptoServer,
peer: PeerPtr,
msg: &Envelope<Msg>,
) -> anyhow::Result<Envelope<Msg>> {
let mut msg = clone_msg(msg)?;
msg.as_bytes_mut()[memoffset::offset_of!(Envelope<Msg>, payload)] ^= 0x01;
msg.as_mut_bytes()[memoffset::offset_of!(Envelope<Msg>, payload)] ^= 0x01;
msg.seal(peer, srv)?; // Recalculate seal; we do not want to focus on "seal broken" errs
Ok(msg)
}

View File

@@ -17,7 +17,7 @@ use assert_tv::TestVectorSet;
use base64::Engine;
use rosenpass_cipher_traits::primitives::{Aead, Kem};
use rosenpass_ciphers::{EphemeralKem, XAead, KEY_LEN};
use rosenpass_secret_memory::{Public, PublicBox, Secret};
use rosenpass_secret_memory::{Public, Secret};
use serde_json::Value;
#[derive(TestVectorSet)]

View File

@@ -2,20 +2,24 @@
use std::mem::size_of;
use zerocopy::{FromBytes, Ref};
use zerocopy::{FromBytes, Immutable, KnownLayout, Ref};
use crate::RosenpassError;
/// Used to parse a network message using [zerocopy]
pub fn truncating_cast_into<T: FromBytes>(
pub fn truncating_cast_into<T: FromBytes + KnownLayout + Immutable>(
buf: &mut [u8],
) -> Result<Ref<&mut [u8], T>, RosenpassError> {
Ref::new(&mut buf[..size_of::<T>()]).ok_or(RosenpassError::BufferSizeMismatch)
Ref::from_bytes(&mut buf[..size_of::<T>()])
.ok()
.ok_or(RosenpassError::BufferSizeMismatch)
}
/// Used to parse a network message using [zerocopy], immutably
pub fn truncating_cast_into_nomut<T: FromBytes>(
pub fn truncating_cast_into_nomut<T: FromBytes + KnownLayout + Immutable>(
buf: &[u8],
) -> Result<Ref<&[u8], T>, RosenpassError> {
Ref::new(&buf[..size_of::<T>()]).ok_or(RosenpassError::BufferSizeMismatch)
Ref::from_bytes(&buf[..size_of::<T>()])
.ok()
.ok_or(RosenpassError::BufferSizeMismatch)
}

View File

@@ -27,7 +27,7 @@ use rosenpass_util::{
};
use std::os::fd::{AsFd, AsRawFd};
use tempfile::TempDir;
use zerocopy::AsBytes;
use zerocopy::IntoBytes;
struct KillChild(std::process::Child);

View File

@@ -14,7 +14,7 @@ use rosenpass_util::{
};
use rosenpass_util::{mem::DiscardResultExt, zerocopy::ZerocopySliceExt};
use tempfile::TempDir;
use zerocopy::AsBytes;
use zerocopy::IntoBytes;
use rosenpass::config::ProtocolVersion;
use rosenpass::protocol::basic_types::SymKey;

View File

@@ -13,11 +13,14 @@ rust-version = "1.77.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
rosenpass-to = { workspace = true }
base64ct = { workspace = true }
anyhow = { workspace = true }
typenum = { workspace = true }
static_assertions = { workspace = true }
rustix = { workspace = true }
rustix = { workspace = true, features = ["net", "fs", "process", "mm"] }
libc = { workspace = true }
errno = { workspace = true }
zeroize = { workspace = true }
zerocopy = { workspace = true }
thiserror = { workspace = true }
@@ -32,6 +35,9 @@ tokio = { workspace = true, optional = true, features = [
"time",
] }
log = { workspace = true }
bitflags = { workspace = true }
tinyvec = "1.10.0"
num-traits = "0.2.19"
[features]
experiment_file_descriptor_passing = ["uds"]

85
util/src/convert.rs Normal file
View File

@@ -0,0 +1,85 @@
//! Additional helpers to [std::convert]: Traits for conversions between types.
/// Variant of [std::convert::Into] with an explicitly specified target type.
///
/// This facilitates method chaining: `x.into_type::<u32>()` can appear in the
/// middle of an expression where `let i: u32 = x.into();` would force a binding.
///
/// # Examples
///
/// We can create a nicer implementation of the following function using this extension trait
///
/// ```
/// use rosenpass_util::convert::IntoTypeExt;
///
/// fn encode_char_u32_be_1(c: char) -> [u8; 4] {
///     let i : u32 = c.into();
///     i.to_be_bytes()
/// }
///
/// fn encode_char_u32_be_2(c: char) -> [u8; 4] {
///     c.into_type::<u32>().to_be_bytes()
/// }
///
/// assert_eq!(encode_char_u32_be_1('X'), [0x00, 0x00, 0x00, 0x58]);
/// assert_eq!(encode_char_u32_be_1('X'), encode_char_u32_be_2('X'));
///
/// ```
pub trait IntoTypeExt {
    /// Variant of [std::convert::Into] with explicitly specified type.
    ///
    /// # Examples
    ///
    /// See [IntoTypeExt].
    fn into_type<T>(self) -> T
    where
        Self: Into<T>,
    {
        self.into()
    }
}
// Blanket impl: every type gains `into_type` wherever a matching `Into` exists.
impl<T> IntoTypeExt for T {}
/// Variant of [std::convert::TryInto] with an explicitly specified target type.
///
/// This facilitates method chaining, analogous to [IntoTypeExt] but for
/// fallible conversions.
///
/// # Examples
///
/// We can create a nicer implementation of the following function using this extension trait
///
/// ```
/// use rosenpass_util::convert::TryIntoTypeExt;
/// use rosenpass_util::result::OkExt;
///
/// fn encode_char_u16_be_1(c: char) -> Result<[u8; 2], <char as TryInto<u16>>::Error> {
///     let i : u16 = c.try_into()?;
///     Ok(i.to_be_bytes())
/// }
///
/// fn encode_char_u16_be_2(c: char) -> Result<[u8; 2], <char as TryInto<u16>>::Error> {
///     c.try_into_type::<u16>()?.to_be_bytes().ok()
/// }
///
/// assert_eq!(encode_char_u16_be_1('X'), Ok([0x00, 0x58]));
/// assert_eq!(encode_char_u16_be_1('X'), encode_char_u16_be_2('X'));
///
/// const HEART_CAT : char = '😻'; // 1F63B
/// assert!(encode_char_u16_be_1(HEART_CAT).is_err());
/// assert_eq!(encode_char_u16_be_1(HEART_CAT), encode_char_u16_be_2(HEART_CAT));
/// ```
pub trait TryIntoTypeExt {
    /// Variant of [std::convert::TryInto] with explicitly specified type.
    ///
    /// # Examples
    ///
    /// See [TryIntoTypeExt].
    fn try_into_type<T>(self) -> Result<T, <Self as TryInto<T>>::Error>
    where
        Self: TryInto<T>,
    {
        self.try_into()
    }
}
// Blanket impl: every type gains `try_into_type` wherever a matching `TryInto` exists.
impl<T> TryIntoTypeExt for T {}

4
util/src/int/mod.rs Normal file
View File

@@ -0,0 +1,4 @@
//! Helpers for working with integer types
pub mod modular;
pub mod u64uint;

276
util/src/int/modular.rs Normal file
View File

@@ -0,0 +1,276 @@
//! Numeric types for modular arithmetic
use num_traits::{
ops::overflowing::OverflowingAdd, CheckedMul, Euclid, Num, Unsigned, WrappingNeg, Zero,
};
/// Summary-trait for numeric types that can serve as the representation of a [Modulus]
pub trait ModuleBase: Num + Ord + Copy + Unsigned + OverflowingAdd + Zero {}
// Blanket impl: any type meeting the bounds is automatically a ModuleBase.
impl<T> ModuleBase for T where T: Num + Ord + Copy + Unsigned + OverflowingAdd + Zero {}
/// Represents a modulus; i.e. the range of values some number type is allowed to use.
///
/// This is based on some inner representation.
///
/// This is not just a value of the underlying representation, because it also supports the modulus
/// [Self::new_full_range()], which indicates that the full range of the underlying type is to be
/// supported.
///
/// Note that zero is not a valid modulus; internally, a stored zero is repurposed
/// to encode the full-range modulus (see [Self::is_full_range]).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct Modulus<T: ModuleBase> {
    // Inner representation; zero encodes "full range", any other value is the modulus itself.
    modulus: T,
}
impl<T: ModuleBase> Modulus<T> {
    /// Construct a [Self] directly from its inner representation, performing no validation.
    fn raw_new(modulus: T) -> Self {
        Self { modulus }
    }

    /// Construct the special modulus covering the full range of the underlying type.
    ///
    /// Internally this is encoded as a stored modulus of zero.
    pub fn new_full_range() -> Self {
        Self::raw_new(T::zero())
    }

    /// Fallibly construct a [Self]; yields None exactly when `modulus == 0`,
    /// since zero is not a valid modulus.
    pub fn try_new(modulus: T) -> Option<Self> {
        if modulus == T::zero() {
            None
        } else {
            Some(Self::raw_new(modulus))
        }
    }

    /// Infallible variant of [Self::try_new].
    ///
    /// # Panic
    ///
    /// Will panic if `modulus == 0`
    pub fn new_or_panic(modulus: T) -> Self {
        Self::try_new(modulus)
            .unwrap_or_else(|| panic!("Can not create Module with modulus zero!"))
    }

    /// Whether this [Self] denotes the full range of the underlying type
    pub fn is_full_range(&self) -> bool {
        T::zero() == self.modulus
    }

    /// The raw modulus value, suitable as the right-hand side of a modulo operation.
    ///
    /// Yields None when [Self::is_full_range] holds, since that modulus has no
    /// representation in `T`.
    pub fn modulus(&self) -> Option<T> {
        if self.is_full_range() {
            None
        } else {
            Some(self.modulus)
        }
    }

    /// Whether `v` lies inside the value range described by this modulus
    pub fn contains(&self, v: T) -> bool {
        self.is_full_range() || v < self.modulus
    }

    /// Double the modulus.
    ///
    /// Correctly handles the case that `v.double().is_full_range()`.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use rosenpass_util::int::modular::Modulus;
    ///
    /// fn m(v: u8) -> Modulus<u8> {
    ///     Modulus::new_or_panic(v)
    /// }
    ///
    /// assert_eq!(m(100).double(), Some(m(200)));
    /// assert_eq!(m(128).double(), Some(Modulus::new_full_range()));
    /// assert_eq!(m(129).double(), None);
    /// ```
    pub fn double(&self) -> Option<Self> {
        let m = self.modulus()?;
        let (doubled, overflowed) = m.overflowing_add(&m);
        // Overflowing to exactly zero means the doubled modulus spans the whole
        // range of T (the full-range encoding); any other overflow is unrepresentable.
        if overflowed && doubled > T::zero() {
            None
        } else {
            Some(Self::raw_new(doubled))
        }
    }

    /// Wrap `value` into a [ModularArithmetic] number under this modulus,
    /// reducing it into range
    pub fn new_number<U: Into<T>>(self, value: U) -> ModularArithmetic<T>
    where
        T: ModularArithmeticBase,
    {
        ModularArithmetic::modular_new(value.into(), self)
    }

    /// Convert each parameter through [Self::new_number], then hand the array to `f`,
    /// returning whatever `f` produces
    pub fn with_converted<U, const N: usize, R, F>(&self, params: [U; N], f: F) -> R
    where
        Self: Copy,
        T: std::fmt::Debug + ModularArithmeticBase,
        U: Into<T>,
        F: FnOnce([ModularArithmetic<T>; N]) -> R,
    {
        f(params.map(|p| self.new_number(p)))
    }

    /// Like [Self::with_converted], but unwraps the result back into the raw representation
    pub fn formula<U, const N: usize, F>(&self, params: [U; N], f: F) -> T
    where
        Self: Copy,
        T: std::fmt::Debug + ModularArithmeticBase,
        U: Into<T>,
        F: FnOnce([ModularArithmetic<T>; N]) -> ModularArithmetic<T>,
    {
        self.with_converted(params, f).value()
    }
}
/// Summary trait for numeric types that can serve as the representation of [ModularArithmetic]
pub trait ModularArithmeticBase:
    ModuleBase
    + std::fmt::Debug
    + Num
    + PartialOrd
    + Ord
    + Copy
    + Unsigned
    + CheckedMul
    + Euclid
    + WrappingNeg
{
}
// Blanket impl: any type meeting the bounds is automatically a ModularArithmeticBase.
impl<T> ModularArithmeticBase for T where
    T: ModuleBase
        + std::fmt::Debug
        + Num
        + PartialOrd
        + Ord
        + Copy
        + Unsigned
        + CheckedMul
        + Euclid
        + WrappingNeg
{
}
/// A number carrying its own [Modulus]; arithmetic on it wraps modulo that modulus
#[derive(Debug, Copy, Clone)]
pub struct ModularArithmetic<T: ModularArithmeticBase> {
    /// The modulus
    modulus: Modulus<T>,
    /// The value inside the modulus
    ///
    /// Invariant: `self.modulus.contains(self.value)` must always hold.
    value: T,
}
impl<T: ModularArithmeticBase> ModularArithmetic<T> {
    /// Construct a new [Self] without reducing `value`.
    ///
    /// Will return None unless `module.`[contains](Modulus::contains)`(value)`.
    pub fn try_new(value: T, module: Modulus<T>) -> Option<Self> {
        module.contains(value).then_some(Self {
            value,
            modulus: module,
        })
    }
    /// Construct a new [Self], reducing `value` into range.
    ///
    /// `value` is taken modulo the modulus (via `rem_euclid`), so this never
    /// fails and never panics; under the full-range modulus the value is used as-is.
    pub fn modular_new(value: T, module: Modulus<T>) -> Self {
        let value = match module.modulus() {
            Some(m) => value.rem_euclid(&m),
            None => value,
        };
        Self {
            modulus: module,
            value,
        }
    }
    /// Return the modulus
    pub fn modulus(&self) -> &Modulus<T> {
        &self.modulus
    }
    /// The inner value; always within the range described by [Self::modulus]
    pub fn value(&self) -> T {
        self.value
    }
}
impl<T: ModularArithmeticBase> PartialEq for ModularArithmetic<T> {
    /// Equality considers only the wrapped value; the modulus is ignored.
    fn eq(&self, other: &Self) -> bool {
        self.value().eq(&other.value())
    }
}
impl<T: ModularArithmeticBase> Eq for ModularArithmetic<T> {}
impl<T: ModularArithmeticBase> PartialOrd for ModularArithmetic<T> {
    /// Delegates to [Ord::cmp]; the ordering here is always total.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl<T: ModularArithmeticBase> Ord for ModularArithmetic<T> {
    /// Ordering considers only the wrapped value; the modulus is ignored.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.value().cmp(&other.value())
    }
}
impl<T: ModularArithmeticBase> std::ops::Neg for ModularArithmetic<T> {
    type Output = Self;

    /// Additive inverse modulo the modulus.
    ///
    /// For the full-range modulus this is two's-complement negation
    /// (`wrapping_neg`); otherwise it is `modulus - value`, with zero mapped
    /// back to zero so the invariant `modulus.contains(value)` keeps holding.
    fn neg(self) -> Self::Output {
        let ret = match self.modulus().modulus() {
            None => self.value().wrapping_neg(),
            // Bug fix: without this arm, negating zero would produce
            // `value == modulus`, which lies outside the legal range
            // `[0, modulus)` and violates the struct invariant.
            Some(_) if self.value() == T::zero() => T::zero(),
            Some(modulus) => modulus - self.value(),
        };
        Self {
            value: ret,
            modulus: self.modulus,
        }
    }
}
impl<T: ModularArithmeticBase> std::ops::Sub for ModularArithmetic<T> {
    type Output = Self;

    /// Modular subtraction; both operands must share the same modulus.
    ///
    /// # Panics
    ///
    /// Panics when the two operands carry different moduli.
    fn sub(self, rhs: Self) -> Self::Output {
        assert_eq!(self.modulus(), rhs.modulus());
        match self < rhs {
            // a - b == -(b - a); the swapped difference never underflows.
            true => -(rhs - self),
            false => Self {
                value: self.value() - rhs.value(),
                modulus: self.modulus,
            },
        }
    }
}
impl<T: ModularArithmeticBase> std::ops::Add for ModularArithmetic<T> {
    type Output = Self;

    /// Modular addition, expressed via subtraction of the additive inverse.
    fn add(self, rhs: Self) -> Self::Output {
        let negated_rhs = -rhs;
        self - negated_rhs
    }
}

View File

@@ -0,0 +1,12 @@
//! Constants for u64 <-> usize conversion
/// Largest numeric value that is representable both as a `u64` and a `usize`.
///
/// Equal to `min(u64::MAX, usize::MAX)`, stored as a `usize`.
pub const MAX_USIZE_IN_U64: usize = if u64::BITS >= usize::BITS {
    usize::MAX
} else {
    u64::MAX as usize
};

/// Largest numeric value that is representable both as a `u64` and a `usize`.
///
/// The same number as [MAX_USIZE_IN_U64], stored as a `u64`.
pub const MAX_U64_IN_USIZE: u64 = MAX_USIZE_IN_U64 as u64;

View File

@@ -0,0 +1,21 @@
//! Making sure size representations fit both into [usize] and [u64]
use static_assertions::const_assert;
mod constants;
pub use constants::*;
#[allow(clippy::module_inception)]
mod u64uint;
pub use u64uint::*;
mod range;
pub use range::*;
/// Safe conversion from usize to u64
///
/// The cast can never truncate: the `const_assert` below guarantees at
/// compile time that `u64` is at least as wide as `usize` on this target.
///
/// TODO: Deprecate this
pub const fn usize_to_u64(v: usize) -> u64 {
    const_assert!(u64::BITS >= usize::BITS);
    v as u64
}

View File

@@ -0,0 +1,29 @@
//! Working with ranges of [super::U64USize]
use std::ops::Range;
use super::U64USize;
/// Extensions for working with [std::ops::Range] of [U64USize].
///
/// Both conversions are lossless since a [U64USize] is by construction
/// representable as either integer type.
pub trait U64USizeRangeExt {
    /// Convert to a usize based range
    fn usize(self) -> Range<usize>;
    /// Convert to a u64 based range
    fn u64(self) -> Range<u64>;
}
impl U64USizeRangeExt for Range<U64USize> {
    fn usize(self) -> Range<usize> {
        // Convert each bound individually; `start..end` builds the new Range.
        self.start.usize()..self.end.usize()
    }

    fn u64(self) -> Range<u64> {
        self.start.u64()..self.end.u64()
    }
}

View File

@@ -0,0 +1,192 @@
//! The [U64USize] type
use super::MAX_U64_IN_USIZE;
/// Error produced by [U64USize::try_new] and the derived [TryFrom] instances.
///
/// Carries the offending input value so callers can report it.
#[derive(Debug, thiserror::Error)]
pub enum U64USizeConversionError<T: std::fmt::Debug> {
    /// Value can not be represented as u64
    #[error("Value can not be represented as a u64 (max = {}) value: {:?}", u64::MAX, .0)]
    NoU64Repr(T),
    /// Value can not be represented as usize
    #[error("Value can not be represented as a usize (max = {}) value: {:?}", usize::MAX, .0)]
    NoUSizeRepr(T),
    /// Value can not be represented as usize or u64
    #[error("Value can not be represented as a u64 (max = {}) or usize (max = {}) value: {:?}", u64::MAX, usize::MAX, .0)]
    NoU64OrUSizeRepr(T),
}
/// A number that can be represented as both a usize and a u64.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct U64USize {
    // Canonical storage as u64; construction guarantees the value also fits a usize.
    storage: u64,
}
impl U64USize {
    /// Internal constructor: cast some number into a value representable as
    /// both a usize and a u64.
    ///
    /// Implemented without [TryInto] on [Self], because this function is itself
    /// used to implement the [TryFrom] instances.
    fn try_new_internal<T>(v: T) -> Result<Self, U64USizeConversionError<T>>
    where
        T: Copy + TryInto<usize> + TryInto<u64> + std::fmt::Debug,
    {
        use U64USizeConversionError as E;
        let as_u64: Result<u64, _> = v.try_into();
        let as_usize: Result<usize, _> = v.try_into();
        match (as_u64, as_usize) {
            (Ok(storage), Ok(_)) => Ok(Self { storage }),
            (Err(_), Err(_)) => Err(E::NoU64OrUSizeRepr(v)),
            (Err(_), Ok(_)) => Err(E::NoU64Repr(v)),
            (Ok(_), Err(_)) => Err(E::NoUSizeRepr(v)),
        }
    }

    /// Cast some number into a number that can be represented as usize and u64
    pub fn try_new<T>(v: T) -> Result<U64USize, <T as TryInto<Self>>::Error>
    where
        T: TryInto<Self>,
    {
        <T as TryInto<Self>>::try_into(v)
    }

    /// Panicking variant of [Self::try_new]
    pub fn new_or_panic<T>(v: T) -> Self
    where
        T: TryInto<Self>,
        <T as TryInto<Self>>::Error: std::fmt::Debug,
    {
        Self::try_new(v).unwrap_or_else(|e| {
            panic!(
                "Could not construct {}: {e:?}",
                std::any::type_name::<Self>()
            )
        })
    }

    /// This value as a usize
    pub fn usize(&self) -> usize {
        self.storage as usize
    }

    /// This value as a u64
    pub fn u64(&self) -> u64 {
        self.storage
    }
}
impl std::ops::Sub for U64USize {
    type Output = Self;

    /// Subtracts `rhs` from `self`.
    ///
    /// # Panics
    ///
    /// Panics on underflow (`rhs > self`). Using `checked_sub` makes this
    /// deterministic: the plain `-` operator would wrap in release builds,
    /// producing a huge — but still representable — garbage value that
    /// `new_or_panic` would silently accept.
    fn sub(self, rhs: Self) -> Self::Output {
        let diff = self
            .u64()
            .checked_sub(rhs.u64())
            .expect("U64USize subtraction underflowed");
        Self::new_or_panic(diff)
    }
}
impl std::ops::Add for U64USize {
    type Output = Self;

    /// Adds `rhs` to `self`.
    ///
    /// # Panics
    ///
    /// Panics on u64 overflow or when the sum is not representable as a usize.
    /// `checked_add` is used so release builds panic instead of wrapping and
    /// feeding a silently-wrong value into `new_or_panic`.
    fn add(self, rhs: Self) -> Self::Output {
        let sum = self
            .u64()
            .checked_add(rhs.u64())
            .expect("U64USize addition overflowed");
        Self::new_or_panic(sum)
    }
}
/// Facilitates creation of [U64USize] in cases where truncation to the largest representable value is permissible
pub trait TruncateIntoU64USize {
    /// Check whether calling [TruncateIntoU64USize::truncate_to_u64usize] would truncate or return
    /// the value as-is
    fn fits_into_u64usize(&self) -> bool;
    /// Turn [Self] into a [U64USize]. If the value is representable as a usize and a u64, then
    /// the value will be returned as is. Otherwise, the maximum representable value [MAX_U64_IN_USIZE]
    /// will be returned.
    ///
    /// NOTE(review): unrepresentable *negative* inputs also map to the maximum
    /// value, not to zero — confirm this is the intended saturation behavior.
    fn truncate_to_u64usize(&self) -> U64USize;
}
/// Derive [TruncateIntoU64USize] for primitive integer types via [U64USize::try_new]
macro_rules! derive_TruncateIntoU64Usize {
    ($($T:ty),*) => {
        $(
            impl TruncateIntoU64USize for $T {
                fn fits_into_u64usize(&self) -> bool {
                    U64USize::try_new(*self).is_ok()
                }
                fn truncate_to_u64usize(&self) -> U64USize {
                    U64USize::try_new(*self).unwrap_or(U64USize::new_or_panic(MAX_U64_IN_USIZE))
                }
            }
        )*
    }
}
// Instantiate for U64USize itself and all primitive integer/bool types.
derive_TruncateIntoU64Usize!(
    U64USize, usize, isize, bool, u8, u16, u32, u64, u128, i8, i16, i32, i64, i128
);
/// Create instances of [From] for [U64USize] (only for types whose entire
/// range is guaranteed to fit, so the internal conversion cannot fail)
macro_rules! U64USize_derive_from {
    ($($T:ty),*) => {
        $(
            impl From<$T> for U64USize {
                fn from(value: $T) -> Self {
                    U64USize::try_new_internal::<$T>(value).unwrap()
                }
            }
        )*
    }
}
// bool, u8 and u16 always fit into both usize and u64.
U64USize_derive_from!(bool, u8, u16);
/// Create instances of [TryFrom] for [U64USize] (for types whose values may
/// not be representable as both usize and u64)
macro_rules! U64USize_derive_try_from {
    ($($T:ty),*) => {
        $(
            impl TryFrom<$T> for U64USize {
                type Error = U64USizeConversionError<$T>;
                fn try_from(value: $T) -> Result<Self, Self::Error> {
                    U64USize::try_new_internal::<$T>(value)
                }
            }
        )*
    }
}
U64USize_derive_try_from!(usize, isize, u32, u64, u128, i8, i16, i32, i64, i128);
/// Create instances of [From]<[U64USize]> for target types that can always hold
/// the stored value (yielding the matching [Into] for free)
macro_rules! U64USize_derive_into {
    ($($T:ty),*) => {
        $(
            impl From<U64USize> for $T {
                fn from(val: U64USize) -> Self {
                    // The construction invariant guarantees this cast is lossless.
                    val.u64() as $T
                }
            }
        )*
    }
}
U64USize_derive_into!(usize, u64, u128, i128);
/// Create instances of [TryFrom]<[U64USize]> for target types that may be too
/// narrow for the stored value (yielding the matching [TryInto] for free)
macro_rules! U64USize_derive_try_into {
    ($($T:ty),*) => {
        $(
            impl TryFrom<U64USize> for $T {
                type Error = <$T as TryFrom<u64>>::Error;
                fn try_from(val: U64USize) -> Result<Self, Self::Error> {
                    val.u64().try_into()
                }
            }
        )*
    }
}
U64USize_derive_try_into!(isize, u8, u16, u32, i8, i16, i32, i64);

3
util/src/ipc/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
//! Inter-process communication related resources
pub mod shm;

6
util/src/ipc/shm/mod.rs Normal file
View File

@@ -0,0 +1,6 @@
//! Resources for working with shared-memory
mod shared_memory_segment;
pub use shared_memory_segment::*;
pub mod ringbuf;

View File

@@ -0,0 +1,101 @@
//! Local shared memory ring buffers mostly for testing
use std::sync::Arc;
use crate::ipc::shm::SharedMemorySegment;
use crate::ringbuf::concurrent::framework::{ConcurrentPipeReader, ConcurrentPipeWriter};
use super::{ShmPipeCore, ShmPipeVariables};
/// A process-local shared memory pipe reader.
///
/// The synchronization variables are shared with the writer through an [Arc].
///
/// See [shm_pipe()].
pub type LocalShmPipeReader = ConcurrentPipeReader<ShmPipeCore<Arc<ShmPipeVariables>>>;
/// A process-local shared memory pipe writer.
///
/// The synchronization variables are shared with the reader through an [Arc].
///
/// See [shm_pipe()].
pub type LocalShmPipeWriter = ConcurrentPipeWriter<ShmPipeCore<Arc<ShmPipeVariables>>>;
/// Creates a process-local shared-memory pipe of `len` buffer bytes.
///
/// See [ConcurrentPipeWriter]/[ConcurrentPipeReader]. The types [LocalShmPipeWriter]/[LocalShmPipeReader] are just aliases for these.
///
/// # Safety
///
/// Mind the comments in the safety section of [super::super::SharedMemorySegment]; the issues
/// described in there *exactly* affect this implementation.
pub fn shm_pipe(len: usize) -> anyhow::Result<(LocalShmPipeWriter, LocalShmPipeReader)> {
    // Counters shared between both ends via Arc (same process, no fd passing needed).
    let shared = Arc::new(ShmPipeVariables::new());
    let (seg_fd, seg_buf_1) = SharedMemorySegment::create(len)?;
    // SAFETY(review): this maps the same segment a second time within this
    // process; per the SharedMemorySegment type-level docs this relies on
    // volatile access semantics and is believed sound in practice — the data
    // transferred must be treated as untrusted either way.
    let seg_buf_2 = unsafe { SharedMemorySegment::from_fd(seg_fd, len) }?;
    let writer = ShmPipeCore::new(shared.clone(), seg_buf_1);
    let reader = ShmPipeCore::new(shared, seg_buf_2);
    Ok((
        ConcurrentPipeWriter::from_core(writer),
        ConcurrentPipeReader::from_core(reader),
    ))
}
#[test]
// Round-trips MSG_COUNT zero-terminated messages through a local shm pipe:
// a spawned thread writes, the main thread reads and re-assembles messages
// from the byte stream, asserting each one arrives intact and in order.
fn test_shm_pipe() -> anyhow::Result<()> {
    let (mut writer, mut reader) = shm_pipe(1024)?;
    const MSG: &[u8] = b"Hello World\0";
    const MSG_COUNT: usize = 100000;
    let t = std::thread::spawn(move || {
        for _ in 0..MSG_COUNT {
            let mut buf = MSG;
            // write() may accept fewer bytes than offered; loop until the
            // whole message has been handed to the pipe.
            while !buf.is_empty() {
                let n = writer.write(buf).unwrap();
                buf = &buf[n..];
            }
        }
    });
    // Reassembly buffer; `buf_off` is the end of valid data, `old_off` the
    // start of not-yet-scanned data.
    let mut buf = [0u8; 1000];
    let mut buf_off = 0;
    let mut msg_no = 0usize;
    'read_data: while msg_no < MSG_COUNT {
        let mut old_off = buf_off;
        // Read the data from the shared memory buffer
        buf_off += reader.read(&mut buf[buf_off..])?;
        loop {
            // Scan the available data for the zero terminator
            let msg_len = &buf[old_off..buf_off]
                .iter()
                .copied()
                .enumerate()
                .find(|(_off, c)| *c == 0x0)
                .map(|(off, _c)| off + old_off + 1);
            // Next iteration, unless the terminator was found
            let msg_len = match *msg_len {
                Some(l) => l,
                None => continue 'read_data,
            };
            // Register the newly read message
            msg_no += 1;
            // Check that the message is correctly transferred; complete
            // messages always start at offset 0 because leftovers are shifted
            // to the front below.
            let msg = &buf[0..msg_len];
            assert_eq!(msg, MSG);
            // Move any extra data to the beginning of the buffer and adjust the offsets accordingly
            buf.copy_within(msg_len..buf_off, 0);
            old_off = 0;
            buf_off -= msg_len;
        }
    }
    t.join().unwrap();
    Ok(())
}

View File

@@ -0,0 +1,81 @@
//! Shared-memory ring buffer implementations (main part of the enclosing module)
use std::{borrow::Borrow, sync::atomic::AtomicU64};
use zerocopy::{FromBytes, IntoBytes};
use crate::{
ipc::shm::SharedMemorySegment,
ringbuf::concurrent::framework::{
ConcurrentPipeCore, ConcurrentPipeReader, ConcurrentPipeWriter,
},
};
/// Synchronization variables for a [ShmPipeWriter]/[ShmPipeReader].
///
/// These values must be shared between the reader/writer in such a way that access to the inner
/// variables is synchronized and atomic between the two parties.
///
/// `repr(C)` pins the field layout, and the zerocopy derives allow the struct
/// itself to be (de)serialized from raw bytes, e.g. when placed in shared memory.
#[repr(C)]
#[derive(Debug, Default, IntoBytes, FromBytes)]
pub struct ShmPipeVariables {
    /// See [crate::ringbuf::sched::RingBufferScheduler::items_read()]
    pub items_read: AtomicU64,
    /// See [crate::ringbuf::sched::RingBufferScheduler::items_written()]
    pub items_written: AtomicU64,
}
impl ShmPipeVariables {
/// Constructor
pub fn new() -> Self {
Self {
items_read: 0.into(),
items_written: 0.into(),
}
}
}
/// The [ConcurrentPipeCore] for a shared memory pipe.
///
/// Pairs the shared synchronization counters with the mapped memory buffer;
/// `Variables` is generic so the counters can live in an `Arc`, a reference,
/// or any other [Borrow]-able container.
#[derive(Debug)]
pub struct ShmPipeCore<Variables: Borrow<ShmPipeVariables>> {
    /// The synchronization variables
    variables: Variables,
    /// The memory buffer
    buf: SharedMemorySegment,
}
impl<Variables: Borrow<ShmPipeVariables>> ShmPipeCore<Variables> {
    /// Constructor; pairs pre-existing synchronization variables with a mapped segment.
    pub fn new(variables: Variables, buf: SharedMemorySegment) -> Self {
        Self { variables, buf }
    }
}
/// A shared memory pipe reader
pub type ShmPipeReader<Variables> = ConcurrentPipeReader<ShmPipeCore<Variables>>;
/// A shared memory pipe writer
pub type ShmPipeWriter<Variables> = ConcurrentPipeWriter<ShmPipeCore<Variables>>;
impl<Variables: Borrow<ShmPipeVariables>> ConcurrentPipeCore for ShmPipeCore<Variables> {
    type AtomicType = AtomicU64;
    fn buf_len(&self) -> u64 {
        // usize -> u64 widening; lossless on targets where usize <= 64 bits.
        self.buf.len() as u64
    }
    fn items_read(&self) -> &AtomicU64 {
        &self.variables.borrow().items_read
    }
    fn items_written(&self) -> &AtomicU64 {
        &self.variables.borrow().items_written
    }
    fn read_from_buffer(&mut self, dst: &mut [u8], off: u64) {
        // NOTE(review): `off as usize` would truncate on 32-bit targets if the
        // framework ever passed an offset > usize::MAX — presumably offsets are
        // bounded by buf_len(); confirm against the scheduler's contract.
        self.buf.volatile_read(dst, off as usize)
    }
    fn write_to_buffer(&mut self, off: u64, src: &[u8]) {
        // Same truncation caveat as read_from_buffer applies to this cast.
        self.buf.volatile_write(off as usize, src);
    }
}

View File

@@ -0,0 +1,7 @@
//! Shared-memory ring buffers
mod main;
pub use main::*;
mod local;
pub use local::*;

View File

@@ -0,0 +1,399 @@
//! Accessing data in a shared memory segment
use std::{
borrow::Borrow,
os::fd::{AsFd, OwnedFd},
};
use crate::{
int::u64uint::usize_to_u64,
ptr::{ReadMemVolatile, WriteMemVolatile},
result::OkExt,
secret_memory::{
fd::SecretMemfdConfig,
mmap::{MMapError, MapFdConfig, MappedSegment},
},
};
/// Safe creation of shared memory segments
///
/// This is a slightly more convenient API than using [MapFdConfig] directly:
/// it bundles the memfd-creation settings with the mapping settings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SharedMemorySegmentBuilder {
    /// Configuration for allocating memory file descriptors and their secrecy level
    pub secret_memfd_cfg: SecretMemfdConfig,
    /// Configuration for mapping file descriptors into memory
    pub map_fd_cfg: MapFdConfig,
}
impl SharedMemorySegmentBuilder {
    /// Create a new secret memory segment builder for a segment of `len` bytes.
    ///
    /// Note that this always sets [MapFdConfig::set_shared()], but you can overwrite this behavior
    /// by un-setting the flag again.
    ///
    /// # Safety & panic
    ///
    /// NOTE(review): `usize_to_u64` is backed by a compile-time assertion
    /// (`u64::BITS >= usize::BITS`), so on supported targets this conversion
    /// cannot panic at runtime; a hypothetical target with usize wider than
    /// u64 would fail to compile rather than panic.
    pub const fn new(len: usize) -> Self {
        let secret_memfd_cfg = SecretMemfdConfig::new();
        let map_fd_cfg = MapFdConfig::new()
            .set_shared()
            .resize_on_mmap(usize_to_u64(len));
        Self {
            secret_memfd_cfg,
            map_fd_cfg,
        }
    }
    /// Create a secret memory segment using the configuration stored here,
    /// returning both the backing file descriptor and the mapped segment
    pub fn create_segment(&self) -> anyhow::Result<(OwnedFd, SharedMemorySegment)> {
        let fd = self.secret_memfd_cfg.create()?;
        let seg = self.map_fd_cfg.mappable_fd(&fd).mmap()?;
        // SAFETY(review): the mapping was created exclusively by this call, so
        // the segment upholds the access guarantees required by
        // SharedMemorySegment — see the type-level safety discussion.
        let seg = unsafe { SharedMemorySegment::from_mapped_segment(seg) };
        Ok((fd, seg))
    }
}
}
/// Safe creation of and access to shared memory segments
///
/// # Safety
///
/// Any means to create a [Self] must guarantee, that further calls to [Self::volatile_write] and
/// [Self::volatile_read] are also safe.
///
/// The API of this struct is specifically designed so creating multiple memory mappings of the same
/// shared memory segment is impossible without unsafe code.
///
/// We recognize that the sole purpose of shared memory segments is that multiple mappings of them
/// are created, there just is no way to do so using safe rust.
///
/// The reason for this is that by the Rust documentation any concurrent memory access leads to
/// undefined behavior, even volatile accesses caused solely by an adversarial application on the
/// other end of a shared memory communication channel.
///
/// For this reason, we force users to use unsafe code to create shared memory mappings from an
/// existing file descriptor, as this could potentially lead to adversarial data races (and thus to
/// undefined behavior).
///
/// In practice, we believe using concurrent memory access using volatile operations is going to
/// lead to nothing worse than garbled data being transferred. The user should treat any data
/// received through a shared-memory ring buffer as untrusted and validate this data any way, so
/// garbled data should be caught.
///
/// This means that [Self::from_fd()] is unsafe in theory, but most likely safe in practice.
///
/// ## Concurrent, untrusted shared memory is technically undefined behavior
///
/// It's even worse than having to use unsafe: technically speaking, it may be impossible to
/// use shared memory soundly in rust unless all parties with access to the segment are
/// *trusted*. If these parties are not trusted (or buggy) they can always cause undefined
/// behavior:
///
/// From the [std::sync::atomic] documentation:
///
/// > **The most important aspect of this model is that data races are undefined behavior.** A data race
/// > is defined as conflicting non-synchronized accesses where at least one of the accesses is non-atomic.
/// > Here, accesses are conflicting if they affect overlapping regions of memory and at least one of them
/// > is a write. (A compare_exchange or compare_exchange_weak that does not succeed is not considered a
/// > write.) They are non-synchronized if neither of them happens-before the other, according to the
/// > happens-before order of the memory model.
///
/// The fact that this API uses volatile [reads](Self::volatile_read) and [writes](Self::volatile_write),
/// and the fact that we use mmap(2) for allocation does not mitigate this issue; from the
/// documentation of [std::ptr::write_volatile()]:
///
/// > When a volatile operation is used for memory inside an allocation, it behaves exactly like write,
/// > except for the additional guarantee that it wont be elided or reordered (see above). This implies
/// > that the operation will actually access memory and not e.g. be lowered to a register access. Other
/// > than that, all the usual rules for memory accesses apply (including provenance). In particular, just
/// > like in C, whether an operation is volatile has no bearing whatsoever on questions involving concurrent
/// > access from multiple threads. Volatile accesses behave exactly like non-atomic accesses in that regard.
///
/// An allocation is defined as follows (taken from [std::ptr]):
///
/// > An allocation is a subset of program memory which is addressable from Rust, and within which pointer
/// > arithmetic is possible. Examples of allocations include heap allocations, stack-allocated variables,
/// > statics, and consts. The safety preconditions of some Rust operations - such as offset and field
/// > projections (expr.field) - are defined in terms of the allocations on which they operate.
///
/// This definition clearly applies to mmap(2) allocated regions.
///
/// What might mitigate this issue is mapping the region just once per process:
///
/// > In particular, just
/// > like in C, whether an operation is volatile has no bearing whatsoever on questions involving concurrent
/// > access from **multiple threads**.
///
/// We could argue that a process is not a thread, and thus concurrent access from two processes is
/// fine, but concurrent access from two threads is not (unless guarded by an atomic value or a
/// mutex or some primitive actually designed for synchronization).
///
/// There is no wording in the spec explicitly allowing raceful, concurrent access from multiple processes.
///
/// The problem with basing our safety-argument on the claim that "processes are not threads" is
/// that the line between processes and threads is drawn in the sand. For linux, read the man page
/// of clone(2):
///
/// > By contrast with fork(2), these [clone, __clone2, clone3] system calls provide more precise control over what pieces of execution
/// > context are shared between the calling process and the child process. For example, using these system
/// > calls, the caller can control whether or not the two processes share the virtual address space, the ta
/// > ble of file descriptors, and the table of signal handlers. These system calls also allow the new child
/// > process to be placed in separate namespaces(7).
/// >
/// > […]
/// >
/// > ## CLONE_THREAD (since Linux 2.4.0)
/// >
/// > If CLONE_THREAD is set, the child is placed in the same thread group as the calling process. To
/// > make the remainder of the discussion of CLONE_THREAD more readable, the term "thread" is used to
/// > refer to the processes within a thread group.
///
/// According to the man page, "the term 'thread' is used to refer to the processes within a thread group.".
///
/// The Rust (transitively, from the C++11 Atomic) specification tells us that there must be no
/// concurrent memory access between threads whether this access is volatile or not. The linux man
/// pages tell us that "thread" is just a special type of "process".
///
/// **The most robust interpretation of these specifications is that shared memory must not be used for
/// communication with an untrusted party across thread or process boundaries, or else the other
/// process/thread can cause undefined behavior in our process.**
///
/// ## In practice
///
/// Realistically, using volatile reads/writes on valid, mapped memory might cause garbled values in
/// case of a data race, but it should not crash the program or do anything worse than create
/// garbled values.
///
/// Mind that we do not mind garbled values here; we are implementing a shared memory communication
/// interface, so our application must always assume, that the data it receives may be garbled. It
/// has to be validated. We just don't want the other application to be able to do anything worse
/// than garble the data it is sending (or receiving), so let's estimate what can *realistically*
/// happen here if the other application maliciously causes a race.
///
/// The worst any of the assembly sequences below should do is cause tearing in case of a data
/// race.
///
/// This leads me to the conclusion that what we are dealing with here is not an
/// implementation that is faulty/insecure, instead it is a definition-gap in the compiler
/// semantics for volatile memory access for use in security-critical applications.
///
/// Godbolt link: <https://rust.godbolt.org/z/GGjsGsc33>
/// Compiler: `rustc 1.90.0`
///
/// Rust code:
///
/// ```rust
/// #[unsafe(no_mangle)]
/// pub fn read_volatile(num: &[u128]) -> u128 {
/// let ptr = num.as_ptr();
/// unsafe { ptr.read_volatile() }
/// }
///
/// #[unsafe(no_mangle)]
/// pub fn write_volatile(num: &mut [u128]) {
/// let ptr = num.as_mut_ptr();
/// unsafe { ptr.write_volatile(42u128) };
/// }
/// ```
///
/// x86_64 (`--target=x86_64-unknown-linux-gnu -O`):
///
/// ```asm
/// read_volatile:
/// mov rax, qword ptr [rdi]
/// mov rdx, qword ptr [rdi + 8]
/// ret
///
/// write_volatile:
/// mov qword ptr [rdi + 8], 0
/// mov qword ptr [rdi], 42
/// ret
/// ```
///
/// arm64 (`--target=aarch64-unknown-linux-gnu -O`):
///
/// ```asm
/// read_volatile:
/// ldp x0, x1, [x0]
/// ret
///
/// write_volatile:
/// mov w8, #42
/// stp x8, xzr, [x0]
/// ret
/// ```
///
/// armv7 (`--target=armv7-unknown-linux-gnueabihf -O`)
///
/// ```asm
/// read_volatile:
/// push {r4, r5, r11, lr}
/// ldrd r2, r3, [r1]
/// ldrd r4, r5, [r1, #8]
/// stm r0, {r2, r3, r4, r5}
/// pop {r4, r5, r11, pc}
///
/// write_volatile:
/// push {r4, r5, r11, lr}
/// mov r2, #0
/// mov r4, #42
/// mov r3, r2
/// mov r5, r2
/// strd r2, r3, [r0, #8]
/// strd r4, r5, [r0]
/// pop {r4, r5, r11, pc}
/// ```
///
/// risc64 (`--target=riscv64gc-unknown-linux-gnu`):
///
/// ```asm
/// read_volatile:
/// ld a1, 8(a0)
/// ld a0, 0(a0)
/// ret
///
/// write_volatile:
/// sd zero, 8(a0)
/// li a1, 42
/// sd a1, 0(a0)
/// ret
/// ```
///
#[derive(Debug)]
pub struct SharedMemorySegment {
    /// The underlying mapped segment; exposed read-only via [Self::mapped_segment]
    inner: MappedSegment,
}
impl SharedMemorySegment {
    /// Create a new shared memory segment.
    ///
    /// Allocates a fresh memory file descriptor of `len` bytes and maps it;
    /// see [SharedMemorySegmentBuilder::new] for the configuration used.
    pub fn create(len: usize) -> anyhow::Result<(OwnedFd, Self)> {
        SharedMemorySegmentBuilder::new(len).create_segment()
    }
    /// Create a shared memory segment from a file descriptor.
    ///
    /// The mapping is shared and expected to be `size` bytes long.
    ///
    /// # Safety
    ///
    /// See the comments in [Self].
    pub unsafe fn from_fd<Fd: AsFd>(fd: Fd, size: usize) -> Result<Self, MMapError> {
        let cfg = MapFdConfig::new()
            .set_shared()
            .expected_size(usize_to_u64(size));
        unsafe { Self::from_fd_with_config(fd, cfg) }
    }
    /// Create a shared memory segment from a file descriptor, using an
    /// explicitly given mapping configuration
    ///
    /// # Safety
    ///
    /// See the comments in [Self].
    pub unsafe fn from_fd_with_config<Fd: AsFd>(
        fd: Fd,
        cfg: MapFdConfig,
    ) -> Result<Self, MMapError> {
        let segment = cfg.mappable_fd(&fd).mmap()?;
        unsafe { Self::from_mapped_segment(segment).ok() }
    }
    /// Create a shared memory segment from an existing mapped segment
    ///
    /// # Safety
    ///
    /// See the comments in [Self].
    pub unsafe fn from_mapped_segment(inner: MappedSegment) -> Self {
        Self { inner }
    }
    /// The underlying mapped segment
    pub fn mapped_segment(&self) -> &MappedSegment {
        self.inner.borrow()
    }
    /// A pointer to the underlying mapped segment
    pub fn ptr(&self) -> *mut u8 {
        self.mapped_segment().ptr()
    }
    /// The length of the underlying mapped segment
    pub fn len(&self) -> usize {
        self.mapped_segment().len()
    }
    /// Whether `self.`[len()](Self::len)` == 0`
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Read data from the shared memory segment into `dst`, starting at byte
    /// offset `off`, using volatile reads
    ///
    /// # Panics
    ///
    /// Panics if `off + dst.len()` exceeds [Self::len].
    ///
    /// # Safety
    ///
    /// This function is safe to call; see the comments in [Self] for the
    /// discussion of concurrent access through multiple mappings.
    pub fn volatile_read(&self, dst: &mut [u8], off: usize) {
        let end = off + dst.len();
        assert!(end <= self.len());
        unsafe { self.ptr().add(off).read_mem_volatile(dst) }
    }
    /// Write data from `src` into the shared memory segment at byte offset
    /// `off`, using volatile writes
    ///
    /// # Panics
    ///
    /// Panics if `off + src.len()` exceeds [Self::len].
    ///
    /// # Safety
    ///
    /// This function is safe to call; see the comments in [Self] for the
    /// discussion of concurrent access through multiple mappings.
    pub fn volatile_write(&self, off: usize, src: &[u8]) {
        let end = off + src.len();
        assert!(end <= self.len());
        unsafe { self.ptr().add(off).write_mem_volatile(src) }
    }
}
impl From<SharedMemorySegment> for MappedSegment {
fn from(val: SharedMemorySegment) -> Self {
val.inner
}
}
impl Borrow<MappedSegment> for SharedMemorySegment {
    /// Borrow the underlying mapping
    fn borrow(&self) -> &MappedSegment {
        &self.inner
    }
}
#[test]
fn test_shared_memory_segment() -> anyhow::Result<()> {
    // Reference patterns for the first 38 bytes of the segment
    let filler = [b'_'; 38];
    let zeroes = [0u8; 38];
    let message = b"Hello World";

    // What the prefix should look like after writing `message`
    let mut expected = zeroes;
    crate::mem::cpy_min(message, &mut expected);

    // Map the same segment twice: once directly, once through its fd
    let (fd, seg_a) = SharedMemorySegment::create(1024)?;
    let seg_b = unsafe { SharedMemorySegment::from_fd(fd, 1024) }?;

    // Freshly created memory reads as zero through both mappings
    for seg in [&seg_a, &seg_b] {
        let mut scratch = filler;
        seg.volatile_read(&mut scratch, 0);
        assert_eq!(&scratch, &zeroes);
    }

    // A write through one mapping is visible through both
    seg_a.volatile_write(0, message);
    for seg in [&seg_a, &seg_b] {
        let mut scratch = filler;
        seg.volatile_read(&mut scratch, 0);
        assert_eq!(&scratch, &expected);
    }
    Ok(())
}

View File

@@ -10,15 +10,16 @@ pub mod b64;
pub mod build;
/// Control flow abstractions and utilities.
pub mod controlflow;
/// File descriptor utilities.
pub mod fd;
pub mod convert;
/// File system operations and handling.
pub mod file;
pub mod fmt;
/// Functional programming utilities.
pub mod functional;
pub mod int;
/// Input/output operations.
pub mod io;
pub mod ipc;
/// Length prefix encoding schemes implementation.
pub mod length_prefix_encoding;
/// Memory manipulation and allocation utilities.
@@ -27,8 +28,13 @@ pub mod mem;
pub mod mio;
/// Extended Option type functionality.
pub mod option;
pub mod ptr;
/// Extended Result type functionality.
pub mod result;
pub mod ringbuf;
pub mod rustix;
pub mod secret_memory;
pub mod sync;
/// Time and duration utilities.
pub mod time;
#[cfg(feature = "tokio")]

View File

@@ -280,6 +280,110 @@ impl<T: Sized> MoveExt for T {
}
}
/// Method-call spelling of the dereferencing copy operation (`*x`)
///
/// Sometimes `x.copy().do_stuff()` reads more smoothly in a method chain
/// than `(*x).do_stuff()`.
///
/// # Examples
///
/// ```
/// use rosenpass_util::mem::CopyExt;
/// use rosenpass_util::result::OkExt;
///
/// // Using the deref operator
/// fn via_deref<T: Copy>(v: &T) -> Result<T, ()> {
///     (*v).ok()
/// }
///
/// // Using CopyExt
/// fn via_copy<T: Copy>(v: &T) -> Result<T, ()> {
///     v.copy().ok()
/// }
///
/// assert_eq!(via_deref(&32), Ok(32));
/// assert_eq!(via_deref(&32), via_copy(&32));
/// ```
pub trait CopyExt: Copy {
    /// Produce a bitwise copy of the value behind `self`
    fn copy(&self) -> Self {
        *self
    }
}

// Blanket implementation: every Copy type gains the extension method.
impl<T: Copy> CopyExt for T {}
/// Helper trait for applying a mutating closure/function to a mutable reference
///
/// Replaces the `*v = f(*v)` boilerplate with a single method call on the
/// reference itself.
///
/// # Examples
///
/// ```rust
/// use rosenpass_util::mem::MutateRefExt;
///
/// fn plus_two_mul_three(v: u64) -> u64 {
///     (v + 2) * 3
/// }
///
/// let mut x = 0;
/// x.mutate(plus_two_mul_three);
/// assert_eq!(x, 6);
/// x.mutate(|v| (v + 2) * 3);
/// assert_eq!(x, 24);
///
/// // Works on anything that dereferences to a Copy value, including
/// // references returned from accessor methods
/// struct Cont {
///     x: u64,
/// }
///
/// impl Cont {
///     fn x_mut(&mut self) -> &mut u64 {
///         &mut self.x
///     }
/// }
///
/// let mut s = Cont { x: 0 };
/// s.x_mut().mutate(|v| (v + 2) * 3);
/// assert_eq!(s.x, 6);
/// ```
pub trait MutateRefExt: DerefMut
where
    Self::Target: Sized,
{
    /// Directly mutate a reference
    fn mutate<F: FnOnce(Self::Target) -> Self::Target>(self, f: F) -> Self;
}

impl<T> MutateRefExt for T
where
    T: DerefMut,
    <T as Deref>::Target: Copy,
{
    fn mutate<F: FnOnce(Self::Target) -> Self::Target>(mut self, f: F) -> Self {
        // Copy the pointee out, transform it, and write the result back in place
        *self = f(*self);
        self
    }
}
#[cfg(test)]
mod test_forgetting {
use crate::mem::Forgetting;

View File

@@ -2,8 +2,8 @@ use mio::net::{UnixListener, UnixStream};
use std::os::fd::{OwnedFd, RawFd};
use crate::{
fd::{claim_fd, claim_fd_inplace},
result::OkExt,
rustix::{claim_fd, claim_fd_inplace},
};
/// Module containing I/O interest flags for Unix operations (see also: [mio::Interest])
@@ -93,7 +93,7 @@ impl UnixStreamExt for UnixStream {
fn from_fd(fd: OwnedFd) -> anyhow::Result<Self> {
use std::os::unix::net::UnixStream as StdUnixStream;
#[cfg(target_os = "linux")] // TODO: We should support this on other plattforms
crate::fd::GetUnixSocketType::demand_unix_stream_socket(&fd)?;
crate::rustix::GetUnixSocketType::demand_unix_stream_socket(&fd)?;
let sock = StdUnixStream::from(fd);
sock.set_nonblocking(true)?;
UnixStream::from_std(sock).ok()

View File

@@ -7,7 +7,7 @@ use std::{
};
use uds::UnixStreamExt as FdPassingExt;
use crate::fd::{claim_fd_inplace, IntoStdioErr};
use crate::rustix::{claim_fd_inplace, IntoStdioErr};
/// A wrapper around a socket that combines reading from the socket with tracking
/// received file descriptors. Limits the maximum number of file descriptors that

4
util/src/ptr/mod.rs Normal file
View File

@@ -0,0 +1,4 @@
//! Utilities for working with pointers
mod volatile;
pub use volatile::*;

51
util/src/ptr/volatile.rs Normal file
View File

@@ -0,0 +1,51 @@
//! Utilities relating to volatile reads/writes on pointers
/// Read from a memory location using
/// [pointer::read_volatile] in a loop and store the
/// results in the given slice
pub trait ReadMemVolatile<T> {
    /// Read `dst.len()` consecutive values from the pointed-to memory via
    /// [pointer::read_volatile], storing them in `dst`
    ///
    /// # Safety
    ///
    /// Refer to [pointer::read_volatile]
    unsafe fn read_mem_volatile(self, dst: &mut [T]);
}
impl<T> ReadMemVolatile<T> for *const T {
    unsafe fn read_mem_volatile(self, dst: &mut [T]) {
        for idx in 0..dst.len() {
            // SAFETY: the caller guarantees that `self .. self + dst.len()` is
            // valid for volatile reads (see the trait-level contract)
            dst[idx] = unsafe { self.add(idx).read_volatile() };
        }
    }
}
impl<T> ReadMemVolatile<T> for *mut T {
    unsafe fn read_mem_volatile(self, dst: &mut [T]) {
        // Defer to the `*const T` implementation
        unsafe { (self as *const T).read_mem_volatile(dst) }
    }
}
/// Write to a memory location using
/// [pointer::write_volatile] in a loop,
/// taking the values to write from the given slice
pub trait WriteMemVolatile<T>: ReadMemVolatile<T> {
    /// Write the values in `src` to consecutive memory locations via
    /// [pointer::write_volatile]
    ///
    /// # Safety
    ///
    /// Refer to [pointer::write_volatile]
    unsafe fn write_mem_volatile(self, src: &[T]);
}
impl<T: Copy> WriteMemVolatile<T> for *mut T {
    unsafe fn write_mem_volatile(self, src: &[T]) {
        for (idx, value) in src.iter().copied().enumerate() {
            // SAFETY: the caller guarantees that `self .. self + src.len()` is
            // valid for volatile writes (see the trait-level contract)
            unsafe { self.add(idx).write_volatile(value) }
        }
    }
}

View File

@@ -0,0 +1,3 @@
//! Concurrent ring buffers
pub mod framework;

View File

@@ -0,0 +1,243 @@
//! An implementation of a concurrent ring buffer with the
//! core (IO/memory access) operations in a detached trait. Everything around the core is
//! implemented but before being able to use this, callers must implement an appropriate IO core.
use crate::int::u64uint::U64USizeRangeExt;
use crate::result::OkExt;
use crate::ringbuf::sched::{
Diff, OperationType, RingBufferFromCountersError, RingBufferScheduler, ScheduledOperations,
};
use crate::sync::atomic::abstract_atomic::AbstractAtomic;
/// Core trait used by [ConcurrentPipeReader] and [ConcurrentPipeWriter] to implement
/// a concurrent pipe based on a ring buffer
///
/// Implementors supply the raw storage (a buffer of [Self::buf_len] bytes) and
/// the two atomic progress counters.
pub trait ConcurrentPipeCore {
    /// The type used for the atomics; i.e. the values returned by
    /// [Self::items_read] and [Self::items_written].
    type AtomicType: AbstractAtomic<u64>;
    /// Length of the underlying memory buffer
    fn buf_len(&self) -> u64;
    /// Number of items read
    fn items_read(&self) -> &Self::AtomicType;
    /// Number of items written
    fn items_written(&self) -> &Self::AtomicType;
    /// Copy data from the underlying memory buffer into the destination
    fn read_from_buffer(&mut self, dst: &mut [u8], off: u64);
    /// Write data from src into the underlying memory buffer
    fn write_to_buffer(&mut self, off: u64, src: &[u8]);
}
/// Raised by [ConcurrentPipeReader::read()]/[ConcurrentPipeWriter::write()] if the underlying
/// ring buffer is in an inconsistent state.
///
/// When used in shared memory applications, this error may have been locally caused, but it may
/// also be caused by **the other threads or processes** putting the ring buffer into an
/// inconsistent state.
///
/// For this reason, you must treat this error as an unrecoverable breakdown of the communication channel. You
/// must close and no longer use the ring buffer after receiving this error, but you can handle it
/// gracefully by reopening a new ring buffer in its place.
///
/// # API Stability
///
/// Treat this error type as opaque; do not rely on the enum values remaining the same
#[derive(Debug, thiserror::Error, Clone)]
pub enum InconsistentRingBufferStateError {
    /// Failed to update an inner counter; this probably indicates a data race (forbidden
    /// concurrent read/write)
    // NOTE(review): the variant name has a typo ("Concurrrent"); renaming it
    // would break existing matches, so it is left as-is
    #[error("Concurrent access to the ring buffer {:?}", self)]
    ConcurrrentAccess {
        /// Whether this was a read or write
        operation_type: OperationType,
        /// Operations performed before updating the ring buffer state
        scheduled_ops: ScheduledOperations,
        /// The scheduler that scheduled our operation
        scheduler_state: RingBufferScheduler,
        /// The counter value before compare and exchange
        expected_counter_value: u64,
        /// The counter value we actually found in the counter
        actual_counter_value: u64,
        /// The counter value we tried to set
        new_counter_value_tried_to_set: u64,
    },
    /// Could not construct ring buffer scheduler
    #[error("Inconsistent ring buffer state: {:?}", .0)]
    InconsistentCounterState(#[from] RingBufferFromCountersError),
}
/// Indicator for [ConcurrentPipeImpl::read_or_write] about which operation
/// is being executed
#[derive(Debug)]
enum ConcurrentPipeOperation<'a> {
    /// This call implements [ConcurrentPipeReader::read]
    Read(&'a mut [u8]),
    /// This call implements [ConcurrentPipeWriter::write]
    Write(&'a [u8]),
}
impl<'a> ConcurrentPipeOperation<'a> {
    /// Borrow the buffer this operation works on, read-only
    pub fn inner_buf(&'a self) -> &'a [u8] {
        match self {
            Self::Read(buf) => buf,
            Self::Write(buf) => buf,
        }
    }
    /// Number of bytes in [Self::inner_buf]
    pub fn len(&self) -> usize {
        self.inner_buf().len()
    }
    /// Map this operation onto the scheduler's [OperationType]
    pub fn scheduler_op(&self) -> OperationType {
        match self {
            Self::Read(_) => OperationType::Read,
            Self::Write(_) => OperationType::Write,
        }
    }
}
/// The implementations of [ConcurrentPipeReader] and [ConcurrentPipeWriter]
/// happen to be extremely similar. This struct forms the basis of both.
struct ConcurrentPipeImpl<Core: ConcurrentPipeCore> {
    /// Core trait implementation providing buffer storage and progress counters
    core: Core,
}
impl<Core: ConcurrentPipeCore> ConcurrentPipeImpl<Core> {
    /// Like [ConcurrentPipeReader::from_core] and [ConcurrentPipeWriter::from_core]
    fn from_core(core: Core) -> Self {
        Self { core }
    }
    /// The implementations of [ConcurrentPipeReader::read] and [ConcurrentPipeWriter::write]
    /// happen to be extremely similar. This function implements both.
    ///
    /// Returns the number of items transferred, which may be less than the
    /// length of the operation's buffer.
    fn read_or_write(
        &mut self,
        mut op: ConcurrentPipeOperation,
    ) -> Result<usize, InconsistentRingBufferStateError> {
        use std::sync::atomic::Ordering as O;
        // Figure out which counter to store the result of the operation in and the orderings to
        // use for load/store operations.
        // A reader acquires the writer's counter so the written data is visible
        // to it; a writer releases its own counter update so readers observing
        // the new counter also observe the data written before it.
        let (ord_r, ord_w, ord_store_succ, ord_store_fail) = match op {
            ConcurrentPipeOperation::Read(_) => (O::Relaxed, O::Acquire, O::Relaxed, O::Relaxed),
            ConcurrentPipeOperation::Write(_) => (O::Relaxed, O::Relaxed, O::Release, O::Relaxed),
        };
        // Construct a ring buffer scheduler from the current state
        let sched = RingBufferScheduler::try_from_counters(
            self.core.buf_len(),
            self.core.items_written().load(ord_w),
            self.core.items_read().load(ord_r),
        )?;
        // Have the scheduler schedule the operations
        let ops = sched.schedule_contigous_operations(op.scheduler_op(), op.len());
        // Actually perform the operations; each scheduled chunk maps a slice of
        // the outside buffer onto an offset inside the ring buffer
        for (buf_slice, ring_op) in ops.with_outside_buffer_range() {
            match &mut op {
                ConcurrentPipeOperation::Read(dst) => self
                    .core
                    .read_from_buffer(&mut dst[buf_slice.usize()], ring_op.off),
                ConcurrentPipeOperation::Write(src) => self
                    .core
                    .write_to_buffer(ring_op.off, &src[buf_slice.usize()]),
            }
        }
        // Take a differential between the scheduler and the scheduler after the operations were
        // applied. Make sure only one counter was updated
        // unwrap: all scheduled operations are of the requested type, so only
        // the matching counter should have moved — TODO confirm expect_op_only's contract
        let diff = sched
            .register_operation(&ops)
            .diff_old(&sched)
            .expect_op_only(op.scheduler_op())
            .unwrap();
        let store_to_ctr = match op {
            ConcurrentPipeOperation::Read(_) => self.core.items_read(),
            ConcurrentPipeOperation::Write(_) => self.core.items_written(),
        };
        // Update the counters, assuming there was any change at all.
        // A failed compare_exchange means somebody else touched our counter
        // concurrently; this is reported as an unrecoverable channel breakdown.
        if let Diff::Different(old, new) = diff {
            store_to_ctr
                .compare_exchange(old, new, ord_store_succ, ord_store_fail)
                .map_err(
                    |actual| InconsistentRingBufferStateError::ConcurrrentAccess {
                        operation_type: op.scheduler_op(),
                        scheduled_ops: ops,
                        scheduler_state: sched,
                        expected_counter_value: old,
                        actual_counter_value: actual,
                        new_counter_value_tried_to_set: new,
                    },
                )?;
        }
        ops.cumulative_operation_length().usize().ok()
    }
}
/// Provides the necessary boilerplate around [ConcurrentPipeCore] to implement
/// reading from the pipe
///
/// Construct with [Self::from_core].
pub struct ConcurrentPipeReader<Core: ConcurrentPipeCore> {
    /// The implementations of [ConcurrentPipeReader::read] and [ConcurrentPipeWriter::write]
    /// happen to be extremely similar. We use [ConcurrentPipeImpl] to implement both.
    inner: ConcurrentPipeImpl<Core>,
}
impl<Core: ConcurrentPipeCore> ConcurrentPipeWriter<Core> {
    /// Wrap a [ConcurrentPipeCore] in a writer handle
    pub fn from_core(core: Core) -> Self {
        let inner = ConcurrentPipeImpl::from_core(core);
        Self { inner }
    }
    /// Length of the underlying ring buffer
    pub fn buf_len(&self) -> u64 {
        self.inner.core.buf_len()
    }
    /// Write data into the concurrent pipe.
    ///
    /// Returns the number of bytes actually written.
    pub fn write(&mut self, src: &[u8]) -> Result<usize, InconsistentRingBufferStateError> {
        let op = ConcurrentPipeOperation::Write(src);
        self.inner.read_or_write(op)
    }
}
/// Provides the necessary boilerplate around [ConcurrentPipeCore] to implement
/// writing to the pipe
///
/// Construct with [Self::from_core].
pub struct ConcurrentPipeWriter<Core: ConcurrentPipeCore> {
    /// The implementations of [ConcurrentPipeReader::read] and [ConcurrentPipeWriter::write]
    /// happen to be extremely similar. We use [ConcurrentPipeImpl] to implement both.
    inner: ConcurrentPipeImpl<Core>,
}
impl<Core: ConcurrentPipeCore> ConcurrentPipeReader<Core> {
    /// Wrap a [ConcurrentPipeCore] in a reader handle
    pub fn from_core(core: Core) -> Self {
        let inner = ConcurrentPipeImpl::from_core(core);
        Self { inner }
    }
    /// Length of the underlying ring buffer
    pub fn buf_len(&self) -> u64 {
        self.inner.core.buf_len()
    }
    /// Read data from the concurrent pipe.
    ///
    /// Returns the number of bytes read.
    pub fn read(&mut self, dst: &mut [u8]) -> Result<usize, InconsistentRingBufferStateError> {
        let op = ConcurrentPipeOperation::Read(dst);
        self.inner.read_or_write(op)
    }
}

View File

@@ -0,0 +1,3 @@
//! Concurrent ring buffers
pub mod framework;

4
util/src/ringbuf/mod.rs Normal file
View File

@@ -0,0 +1,4 @@
//! Ring buffers and utilities for working with them
pub mod concurrent;
pub mod sched;

1410
util/src/ringbuf/sched.rs Normal file

File diff suppressed because it is too large Load Diff

117
util/src/rustix/error.rs Normal file
View File

@@ -0,0 +1,117 @@
//! Rustix extensions for error handling
/// Provides access to the last system error number
///
/// > The integer variable errno is set by system calls and some library functions in the event of an error to indicate what went wrong.
///
/// -- `man 3 errno`
///
/// # Panics
///
/// This function panics if there was no system error, i.e. if `errno` is zero.
/// Use [try_errno] or [last_os_result] for non-panicking variants.
///
/// # Examples
///
/// ```rust
///
/// use rustix::io::Errno as E;
/// use rosenpass_util::rustix::{errno, try_errno, last_os_result};
///
/// let res = unsafe { libc::mkdir(c"/tmp/baz".as_ptr(), 0) };
/// assert_eq!(res, -1);
/// assert_eq!(errno(), E::EXIST);
/// assert_eq!(try_errno(), Some(E::EXIST));
/// assert_eq!(last_os_result(), Err(E::EXIST));
///
/// // Deliberately clear the system error
/// unsafe { libc::__errno_location().write(0) };
/// // assert_eq!(errno(), _); // PANICS
/// assert_eq!(try_errno(), None);
/// assert_eq!(last_os_result(), Ok(()));
/// ```
///
/// Calling errno() when there is no error causes a panic:
///
/// ```rust,should_panic
///
/// use rustix::io::Errno as E;
/// use rosenpass_util::rustix::errno;
///
/// // Deliberately clear the system error
/// unsafe { libc::__errno_location().write(0) };
/// errno(); // PANICS
/// ```
pub fn errno() -> rustix::io::Errno {
    match try_errno() {
        None => panic!("Tried to retrieve last system error, but there was no system error (the system error number, errno = 0)"),
        Some(errno) => errno,
    }
}
/// Provides access to the last system error number.
///
/// Variant of [errno()] that will return None if there was no system error.
///
/// # Examples
///
/// See [errno()].
pub fn try_errno() -> Option<rustix::io::Errno> {
    // SAFETY: `__errno_location()` yields a valid pointer to the calling
    // thread's errno value (see errno(3))
    let raw = unsafe { libc::__errno_location().read() };
    (raw != 0).then(|| rustix::io::Errno::from_raw_os_error(raw))
}
/// Provides access to the last system error number.
///
/// Variant of [errno()] that maps "a system error occurred" to `Err(errno)`
/// and "no system error" to `Ok(())`.
///
/// # Examples
///
/// See [errno()].
pub fn last_os_result() -> Result<(), rustix::io::Errno> {
    try_errno().map_or(Ok(()), Err)
}
/// Convert low level errors into std::io::Error
///
/// # Examples
///
/// ```
/// use std::io::ErrorKind as EK;
/// use rustix::io::Errno;
/// use rosenpass_util::rustix::IntoStdioErr;
///
/// let e = Errno::INTR.into_stdio_err();
/// assert!(matches!(e.kind(), EK::Interrupted));
///
/// let r : rustix::io::Result<()> = Err(Errno::INTR);
/// assert!(matches!(r, Err(e) if e.kind() == EK::Interrupted));
/// ```
pub trait IntoStdioErr {
    /// Target type produced (e.g. [std::io::Error] or [std::io::Result], depending on context)
    type Target;
    /// Convert a low level error value into [Self::Target]
    fn into_stdio_err(self) -> Self::Target;
}
impl IntoStdioErr for rustix::io::Errno {
    type Target = std::io::Error;
    fn into_stdio_err(self) -> Self::Target {
        std::io::Error::from_raw_os_error(self.raw_os_error())
    }
}
impl<T> IntoStdioErr for rustix::io::Result<T> {
    type Target = std::io::Result<T>;
    fn into_stdio_err(self) -> Self::Target {
        self.map_err(IntoStdioErr::into_stdio_err)
    }
}

255
util/src/rustix/fd.rs Normal file
View File

@@ -0,0 +1,255 @@
//! Basic utilities for working with file descriptors
use std::os::fd::{AsFd, BorrowedFd, FromRawFd, OwnedFd, RawFd};
use rustix::io::fcntl_dupfd_cloexec;
use super::IntoStdioErr;
use crate::mem::Forgetting;
/// Prepare a file descriptor for use in Rust code.
///
/// Checks if the file descriptor is valid and duplicates it to a new file descriptor.
/// The old file descriptor is masked to avoid potential use after free (on file descriptor)
/// in case the given file descriptor is still used somewhere
///
/// # Panic and safety
///
/// Will panic if the given file descriptor is negative or larger than
/// the file descriptor numbers permitted by the operating system.
///
/// # Examples
///
/// ```
/// use std::io::Write;
/// use std::os::fd::{IntoRawFd, AsRawFd};
/// use tempfile::tempdir;
/// use rosenpass_util::rustix::{claim_fd, FdIo};
///
/// // Open a file and turn it into a raw file descriptor
/// let orig = tempfile::tempfile()?.into_raw_fd();
///
/// // Reclaim that file and ready it for reading
/// let mut claimed = FdIo(claim_fd(orig)?);
///
/// // A different file descriptor is used
/// assert!(orig.as_raw_fd() != claimed.0.as_raw_fd());
///
/// // Write some data
/// claimed.write_all(b"Hello, World!")?;
///
/// Ok::<(), std::io::Error>(())
/// ```
pub fn claim_fd(fd: RawFd) -> rustix::io::Result<OwnedFd> {
    // Duplicate first (this validates `fd`), then mask the original so stale
    // copies of the raw fd number cannot alias the new descriptor
    let new = clone_fd_cloexec(unsafe { BorrowedFd::borrow_raw(fd) })?;
    mask_fd(fd)?;
    Ok(new)
}
/// Prepare a file descriptor for use in Rust code.
///
/// Checks if the file descriptor is valid.
///
/// Unlike [claim_fd], this will try to reuse the same file descriptor identifier instead of masking it.
///
/// # Panic and safety
///
/// Will panic if the given file descriptor is negative or larger than
/// the file descriptor numbers permitted by the operating system.
///
/// # Examples
///
/// ```
/// use std::io::Write;
/// use std::os::fd::IntoRawFd;
/// use tempfile::tempdir;
/// use rosenpass_util::rustix::{claim_fd_inplace, FdIo};
///
/// // Open a file and turn it into a raw file descriptor
/// let fd = tempfile::tempfile()?.into_raw_fd();
///
/// // Reclaim that file and ready it for reading
/// let mut fd = FdIo(claim_fd_inplace(fd)?);
///
/// // Write some data
/// fd.write_all(b"Hello, World!")?;
///
/// Ok::<(), std::io::Error>(())
/// ```
pub fn claim_fd_inplace(fd: RawFd) -> rustix::io::Result<OwnedFd> {
    // Take ownership of the original slot, round-trip through a temporary
    // duplicate (this validates `fd`), then copy the duplicate back into the
    // original slot with the close-on-exec flag set
    let mut new = unsafe { OwnedFd::from_raw_fd(fd) };
    let tmp = clone_fd_cloexec(&new)?;
    clone_fd_to_cloexec(tmp, &mut new)?;
    Ok(new)
}
/// Will close the given file descriptor and overwrite
/// it with a masking file descriptor (see [open_nullfd]) to prevent accidental reuse.
///
/// # Panic and safety
///
/// Will panic if the given file descriptor is negative or larger than
/// the file descriptor numbers permitted by the operating system.
///
/// # Example
/// ```
/// # use std::fs::File;
/// # use std::io::Read;
/// # use std::os::unix::io::{AsRawFd, FromRawFd};
/// # use std::os::fd::IntoRawFd;
/// # use rustix::fd::AsFd;
/// # use rosenpass_util::rustix::mask_fd;
///
/// // Open a temporary file
/// let fd = tempfile::tempfile().unwrap().into_raw_fd();
/// assert!(fd >= 0);
///
/// // Mask the file descriptor
/// mask_fd(fd).unwrap();
///
/// // Verify the file descriptor now points to `/dev/null`
/// // Reading from `/dev/null` always returns 0 bytes
/// let mut replaced_file = unsafe { File::from_raw_fd(fd) };
/// let mut buffer = [0u8; 4];
/// let bytes_read = replaced_file.read(&mut buffer).unwrap();
/// assert_eq!(bytes_read, 0);
/// ```
pub fn mask_fd(fd: RawFd) -> rustix::io::Result<()> {
    // Safety: because the OwnedFd resulting from OwnedFd::from_raw_fd is wrapped in a Forgetting,
    // it never gets dropped, meaning that fd is never closed and thus outlives the OwnedFd
    let mut owned = Forgetting::new(unsafe { OwnedFd::from_raw_fd(fd) });
    clone_fd_to_cloexec(open_nullfd()?, &mut owned)
}
/// Duplicate a file descriptor, setting the close on exec flag
pub fn clone_fd_cloexec<Fd: AsFd>(fd: Fd) -> rustix::io::Result<OwnedFd> {
    /// Lowest fd number the duplicate may receive; keeps it clear of
    /// stdin (0), stdout (1), and stderr (2)
    const MINFD: RawFd = 3;
    fcntl_dupfd_cloexec(fd, MINFD)
}
/// Duplicate a file descriptor, setting the close on exec flag.
///
/// This is slightly different from [clone_fd_cloexec], as this function supports specifying an
/// explicit destination file descriptor.
///
/// On Linux, dup3(2) duplicates the source onto the destination and sets
/// CLOEXEC in a single atomic step.
#[cfg(target_os = "linux")]
pub fn clone_fd_to_cloexec<Fd: AsFd>(fd: Fd, new: &mut OwnedFd) -> rustix::io::Result<()> {
use rustix::io::{dup3, DupFlags};
dup3(fd, new, DupFlags::CLOEXEC)
}
#[cfg(not(target_os = "linux"))]
/// Duplicate a file descriptor, setting the close on exec flag.
///
/// This is slightly different from [clone_fd_cloexec], as this function supports specifying an
/// explicit destination file descriptor.
///
/// NOTE(review): unlike the Linux dup3(2) variant, this dup2-then-fcntl
/// sequence is not atomic — there is a brief window in which `new` exists
/// without CLOEXEC, so a concurrent fork+exec could leak it.
pub fn clone_fd_to_cloexec<Fd: AsFd>(fd: Fd, new: &mut OwnedFd) -> rustix::io::Result<()> {
use rustix::io::{dup2, fcntl_setfd, FdFlags};
dup2(&fd, new)?;
fcntl_setfd(&new, FdFlags::CLOEXEC)
}
/// Open a "blocked" file descriptor. I.e. a file descriptor that is neither meant for reading nor
/// writing.
///
/// # Safety
///
/// The behavior of the file descriptor when being written to or from is undefined.
///
/// # Examples
///
/// ```
/// use std::{fs::File, io::Write, os::fd::IntoRawFd};
/// use rustix::fd::FromRawFd;
/// use rosenpass_util::rustix::open_nullfd;
///
/// let nullfd = open_nullfd().unwrap();
/// ```
pub fn open_nullfd() -> rustix::io::Result<OwnedFd> {
use rustix::fs::{open, Mode, OFlags};
// TODO: Add tests showing that this will throw errors on use
// Only CLOEXEC is passed, i.e. the access-mode bits are zero (O_RDONLY):
// writes fail outright, and reads from /dev/null yield EOF (0 bytes).
open("/dev/null", OFlags::CLOEXEC, Mode::empty())
}
/// Read and write directly from a file descriptor
///
/// A thin newtype wrapper providing [std::io::Read] and [std::io::Write]
/// implementations on top of raw read(2)/write(2) calls. The wrapped
/// descriptor is public and may be accessed as `self.0`.
///
/// # Examples
///
/// See [claim_fd].
pub struct FdIo<Fd: AsFd>(pub Fd);
impl<Fd: AsFd> std::io::Read for FdIo<Fd> {
    /// Forwards to read(2) on the wrapped descriptor, translating the
    /// error type into [std::io::Error]
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let res = rustix::io::read(&self.0, buf);
        res.into_stdio_err()
    }
}
impl<Fd: AsFd> std::io::Write for FdIo<Fd> {
    /// Forwards to write(2) on the wrapped descriptor, translating the
    /// error type into [std::io::Error]
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let res = rustix::io::write(&self.0, buf);
        res.into_stdio_err()
    }

    /// No-op: writes go straight to the descriptor, nothing is buffered
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::{Read, Write};
// The std fd wrapper types assert that the raw value is not -1,
// so claiming an invalid descriptor is expected to panic.
#[test]
#[should_panic]
fn test_claim_fd_invalid_neg() {
let _ = claim_fd(-1);
}
// NOTE(review): `i64::MAX as RawFd` truncates to -1 on 32-bit RawFd,
// so this exercises the same rejected value as the _neg test.
#[test]
#[should_panic]
fn test_claim_fd_invalid_max() {
let _ = claim_fd(i64::MAX as RawFd);
}
#[test]
#[should_panic]
fn test_claim_fd_inplace_invalid_neg() {
let _ = claim_fd_inplace(-1);
}
#[test]
#[should_panic]
fn test_claim_fd_inplace_invalid_max() {
let _ = claim_fd_inplace(i64::MAX as RawFd);
}
#[test]
#[should_panic]
fn test_mask_fd_invalid_neg() {
let _ = mask_fd(-1);
}
#[test]
#[should_panic]
fn test_mask_fd_invalid_max() {
let _ = mask_fd(i64::MAX as RawFd);
}
// A nullfd (read-only /dev/null) yields EOF on read and errors on write.
#[test]
fn test_open_nullfd() -> anyhow::Result<()> {
let mut file = FdIo(open_nullfd()?);
let mut buf = [0; 10];
assert!(matches!(file.read(&mut buf), Ok(0) | Err(_)));
assert!(file.write(&buf).is_err());
Ok(())
}
// Same checks via the raw rustix read/write calls
#[test]
fn test_nullfd_read_write() {
let nullfd = open_nullfd().unwrap();
let mut buf = vec![0u8; 16];
assert_eq!(rustix::io::read(&nullfd, &mut buf).unwrap(), 0);
assert!(rustix::io::write(&nullfd, b"test").is_err());
}
}

88
util/src/rustix/memfd.rs Normal file
View File

@@ -0,0 +1,88 @@
//! Utilities for working with memory based file descriptors
use std::os::fd::OwnedFd;
use rustix::fs::MemfdFlags;
use rustix::io::Errno;
use rustix::path::Arg as Path;
use bitflags::bitflags;
use crate::convert::IntoTypeExt;
use super::SyscallResult;
/// Create an anonymous file
/// using the memfd_create(2) syscall
///
/// Just forwards to [rustix::fs::memfd_create]; provided here so that the
/// memfd-related functions live in one module alongside [memfd_secret].
pub fn memfd_create<P: Path>(name: P, flags: MemfdFlags) -> rustix::io::Result<OwnedFd> {
rustix::fs::memfd_create(name, flags)
}
bitflags! {
/// `FD_*` constants for use with [memfd_secret].
///
/// repr(transparent) over `c_uint` so the bit pattern can be handed
/// directly to the raw syscall.
#[repr(transparent)]
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct MemfdSecretFlags: std::ffi::c_uint {
/// FD_CLOEXEC: close the descriptor automatically on exec(2)
const CLOEXEC = libc::FD_CLOEXEC as std::ffi::c_uint;
}
}
/// Errors for [memfd_secret()]
#[derive(Copy, Clone, Debug, thiserror::Error)]
pub enum MemfdSecretError {
/// memfd_secret(2) not supported on system (syscall returned ENOSYS)
#[error("Could not create secret memory segment using memfd_secret(2): not supported on your system")]
NotSupported,
/// Other error
#[error("Could not create secret memory segment using memfd_secret(2): underlying system error: {:?}", .0)]
SystemError(Errno),
}
impl From<Errno> for MemfdSecretError {
    /// ENOSYS indicates the syscall is absent from the kernel; map it to
    /// [MemfdSecretError::NotSupported]. Everything else is a generic
    /// [MemfdSecretError::SystemError].
    fn from(value: Errno) -> Self {
        if value == Errno::NOSYS {
            Self::NotSupported
        } else {
            Self::SystemError(value)
        }
    }
}
/// Create an anonymous RAM-based file to access secret memory regions
/// using the memfd_secret(2) syscall
///
/// # Errors
///
/// Returns [MemfdSecretError::NotSupported] when the kernel lacks
/// memfd_secret(2) (ENOSYS), [MemfdSecretError::SystemError] otherwise.
///
/// # Examples
///
/// ```
/// use rustix::io::Errno;
/// use rustix::fs::ftruncate;
///
/// use rosenpass_util::rustix::{memfd_secret, MemfdSecretFlags, IntoStdioErr, MemfdSecretError};
/// use rosenpass_util::io::handle_interrupted;
///
/// let res = memfd_secret(MemfdSecretFlags::empty());
///
/// use MemfdSecretError as E;
/// let fd = match res {
///     Ok(fd) => fd,
///     // The system might not have memfd_secret enabled; abort the test
///     Err(E::NotSupported) => return Ok(()),
///     Err(E::SystemError(err)) => return Err(err)?,
/// };
///
/// handle_interrupted(|| { ftruncate(&fd, 8192).into_stdio_err() })?;
///
/// Ok::<(), anyhow::Error>(())
/// ```
pub fn memfd_secret(flags: MemfdSecretFlags) -> Result<rustix::fd::OwnedFd, MemfdSecretError> {
    // Pass the raw bits rather than the bitflags wrapper: arguments to the
    // C variadic syscall(2) must be primitive types. MemfdSecretFlags is
    // repr(transparent) over c_uint, but extracting `bits()` avoids relying
    // on that layout guarantee surviving the varargs ABI.
    let res = unsafe {
        use libc::{syscall, SYS_memfd_secret};
        syscall(SYS_memfd_secret, flags.bits())
            .into_type::<SyscallResult>()
            .claim_fd()
    };
    res.map_err(MemfdSecretError::from)
}

16
util/src/rustix/mod.rs Normal file
View File

@@ -0,0 +1,16 @@
//! Extensions to the rustix crate for memory safe operating system interfaces
// Error translation helpers (e.g. IntoStdioErr)
mod error;
pub use error::*;
// File descriptor claiming, masking, duplication, and FdIo
mod fd;
pub use fd::*;
// stat(2) and socket introspection helpers
mod stat;
pub use stat::*;
// Raw system call result handling (SyscallResult)
mod syscall;
pub use syscall::*;
// memfd_create(2)/memfd_secret(2) wrappers
mod memfd;
pub use memfd::*;

View File

@@ -1,235 +1,12 @@
//! Utilities for working with file descriptors
//! Rustix extensions for getting information about file descriptors`
use std::os::fd::AsFd;
use anyhow::bail;
use rustix::io::fcntl_dupfd_cloexec;
use std::os::fd::{AsFd, BorrowedFd, FromRawFd, OwnedFd, RawFd};
use crate::{mem::Forgetting, result::OkExt};
use super::IntoStdioErr;
/// Prepare a file descriptor for use in Rust code.
///
/// Checks if the file descriptor is valid and duplicates it to a new file descriptor.
/// The old file descriptor is masked to avoid potential use after free (on file descriptor)
/// in case the given file descriptor is still used somewhere
///
/// # Panic and safety
///
/// Will panic if the given file descriptor is negative of or larger than
/// the file descriptor numbers permitted by the operating system.
///
/// # Examples
///
/// ```
/// use std::io::Write;
/// use std::os::fd::{IntoRawFd, AsRawFd};
/// use tempfile::tempdir;
/// use rosenpass_util::fd::{claim_fd, FdIo};
///
/// // Open a file and turn it into a raw file descriptor
/// let orig = tempfile::tempfile()?.into_raw_fd();
///
/// // Reclaim that file and ready it for reading
/// let mut claimed = FdIo(claim_fd(orig)?);
///
/// // A different file descriptor is used
/// assert!(orig.as_raw_fd() != claimed.0.as_raw_fd());
///
/// // Write some data
/// claimed.write_all(b"Hello, World!")?;
///
/// Ok::<(), std::io::Error>(())
/// ```
pub fn claim_fd(fd: RawFd) -> rustix::io::Result<OwnedFd> {
let new = clone_fd_cloexec(unsafe { BorrowedFd::borrow_raw(fd) })?;
mask_fd(fd)?;
Ok(new)
}
/// Prepare a file descriptor for use in Rust code.
///
/// Checks if the file descriptor is valid.
///
/// Unlike [claim_fd], this will try to reuse the same file descriptor identifier instead of masking it.
///
/// # Panic and safety
///
/// Will panic if the given file descriptor is negative of or larger than
/// the file descriptor numbers permitted by the operating system.
///
/// # Examples
///
/// ```
/// use std::io::Write;
/// use std::os::fd::IntoRawFd;
/// use tempfile::tempdir;
/// use rosenpass_util::fd::{claim_fd_inplace, FdIo};
///
/// // Open a file and turn it into a raw file descriptor
/// let fd = tempfile::tempfile()?.into_raw_fd();
///
/// // Reclaim that file and ready it for reading
/// let mut fd = FdIo(claim_fd_inplace(fd)?);
///
/// // Write some data
/// fd.write_all(b"Hello, World!")?;
///
/// Ok::<(), std::io::Error>(())
/// ```
pub fn claim_fd_inplace(fd: RawFd) -> rustix::io::Result<OwnedFd> {
let mut new = unsafe { OwnedFd::from_raw_fd(fd) };
let tmp = clone_fd_cloexec(&new)?;
clone_fd_to_cloexec(tmp, &mut new)?;
Ok(new)
}
/// Will close the given file descriptor and overwrite
/// it with a masking file descriptor (see [open_nullfd]) to prevent accidental reuse.
///
/// # Panic and safety
///
/// Will panic if the given file descriptor is negative of or larger than
/// the file descriptor numbers permitted by the operating system.
///
/// # Example
/// ```
/// # use std::fs::File;
/// # use std::io::Read;
/// # use std::os::unix::io::{AsRawFd, FromRawFd};
/// # use std::os::fd::IntoRawFd;
/// # use rustix::fd::AsFd;
/// # use rosenpass_util::fd::mask_fd;
///
/// // Open a temporary file
/// let fd = tempfile::tempfile().unwrap().into_raw_fd();
/// assert!(fd >= 0);
///
/// // Mask the file descriptor
/// mask_fd(fd).unwrap();
///
/// // Verify the file descriptor now points to `/dev/null`
/// // Reading from `/dev/null` always returns 0 bytes
/// let mut replaced_file = unsafe { File::from_raw_fd(fd) };
/// let mut buffer = [0u8; 4];
/// let bytes_read = replaced_file.read(&mut buffer).unwrap();
/// assert_eq!(bytes_read, 0);
/// ```
pub fn mask_fd(fd: RawFd) -> rustix::io::Result<()> {
// Safety: because the OwnedFd resulting from OwnedFd::from_raw_fd is wrapped in a Forgetting,
// it never gets dropped, meaning that fd is never closed and thus outlives the OwnedFd
let mut owned = Forgetting::new(unsafe { OwnedFd::from_raw_fd(fd) });
clone_fd_to_cloexec(open_nullfd()?, &mut owned)
}
/// Duplicate a file descriptor, setting the close on exec flag
pub fn clone_fd_cloexec<Fd: AsFd>(fd: Fd) -> rustix::io::Result<OwnedFd> {
/// Avoid stdin, stdout, and stderr
const MINFD: RawFd = 3;
fcntl_dupfd_cloexec(fd, MINFD)
}
/// Duplicate a file descriptor, setting the close on exec flag.
///
/// This is slightly different from [clone_fd_cloexec], as this function supports specifying an
/// explicit destination file descriptor.
#[cfg(target_os = "linux")]
pub fn clone_fd_to_cloexec<Fd: AsFd>(fd: Fd, new: &mut OwnedFd) -> rustix::io::Result<()> {
use rustix::io::{dup3, DupFlags};
dup3(fd, new, DupFlags::CLOEXEC)
}
#[cfg(not(target_os = "linux"))]
/// Duplicate a file descriptor, setting the close on exec flag.
///
/// This is slightly different from [clone_fd_cloexec], as this function supports specifying an
/// explicit destination file descriptor.
pub fn clone_fd_to_cloexec<Fd: AsFd>(fd: Fd, new: &mut OwnedFd) -> rustix::io::Result<()> {
use rustix::io::{dup2, fcntl_setfd, FdFlags};
dup2(&fd, new)?;
fcntl_setfd(&new, FdFlags::CLOEXEC)
}
/// Open a "blocked" file descriptor. I.e. a file descriptor that is neither meant for reading nor
/// writing.
///
/// # Safety
///
/// The behavior of the file descriptor when being written to or from is undefined.
///
/// # Examples
///
/// ```
/// use std::{fs::File, io::Write, os::fd::IntoRawFd};
/// use rustix::fd::FromRawFd;
/// use rosenpass_util::fd::open_nullfd;
///
/// let nullfd = open_nullfd().unwrap();
/// ```
pub fn open_nullfd() -> rustix::io::Result<OwnedFd> {
use rustix::fs::{open, Mode, OFlags};
// TODO: Add tests showing that this will throw errors on use
open("/dev/null", OFlags::CLOEXEC, Mode::empty())
}
/// Convert low level errors into std::io::Error
///
/// # Examples
///
/// ```
/// use std::io::ErrorKind as EK;
/// use rustix::io::Errno;
/// use rosenpass_util::fd::IntoStdioErr;
///
/// let e = Errno::INTR.into_stdio_err();
/// assert!(matches!(e.kind(), EK::Interrupted));
///
/// let r : rustix::io::Result<()> = Err(Errno::INTR);
/// assert!(matches!(r, Err(e) if e.kind() == EK::Interrupted));
/// ```
pub trait IntoStdioErr {
/// Target type produced (e.g. std::io::Error or std::io::Result, depending on context)
type Target;
/// Convert low level errors into the standard library representation
fn into_stdio_err(self) -> Self::Target;
}
impl IntoStdioErr for rustix::io::Errno {
type Target = std::io::Error;
// Errno is a raw OS error code, so the std constructor for raw codes applies
fn into_stdio_err(self) -> Self::Target {
std::io::Error::from_raw_os_error(self.raw_os_error())
}
}
impl<T> IntoStdioErr for rustix::io::Result<T> {
type Target = std::io::Result<T>;
// Lift the Errno conversion over Result: Ok passes through untouched
fn into_stdio_err(self) -> Self::Target {
self.map_err(IntoStdioErr::into_stdio_err)
}
}
/// Read and write directly from a file descriptor
///
/// # Examples
///
/// See [claim_fd].
pub struct FdIo<Fd: AsFd>(pub Fd);
impl<Fd: AsFd> std::io::Read for FdIo<Fd> {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
rustix::io::read(&self.0, buf).into_stdio_err()
}
}
impl<Fd: AsFd> std::io::Write for FdIo<Fd> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
rustix::io::write(&self.0, buf).into_stdio_err()
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
use crate::result::OkExt;
/// Helpers for accessing stat(2) information
pub trait StatExt {
@@ -238,7 +15,7 @@ pub trait StatExt {
/// # Examples
///
/// ```
/// use rosenpass_util::fd::StatExt;
/// use rosenpass_util::rustix::StatExt;
/// assert!(rustix::fs::stat("/")?.is_socket() == false);
/// Ok::<(), rustix::io::Errno>(())
/// ````
@@ -263,7 +40,7 @@ pub trait TryStatExt {
/// # Examples
///
/// ```
/// use rosenpass_util::fd::TryStatExt;
/// use rosenpass_util::rustix::TryStatExt;
/// let fd = rustix::fs::open("/", rustix::fs::OFlags::empty(), rustix::fs::Mode::empty())?;
/// assert!(matches!(fd.is_socket(), Ok(false)));
/// Ok::<(), rustix::io::Errno>(())
@@ -358,7 +135,7 @@ pub trait GetUnixSocketType {
/// # use std::os::fd::{AsFd, BorrowedFd};
/// # use std::os::unix::net::UnixListener;
/// # use tempfile::NamedTempFile;
/// # use rosenpass_util::fd::GetUnixSocketType;
/// # use rosenpass_util::rustix::GetUnixSocketType;
/// let f = {
/// // Generate a temp file and take its path
/// // Remove the temp file
@@ -378,7 +155,7 @@ pub trait GetUnixSocketType {
/// # use std::os::fd::{AsFd, BorrowedFd};
/// # use std::os::unix::net::{UnixDatagram, UnixListener};
/// # use tempfile::NamedTempFile;
/// # use rosenpass_util::fd::GetUnixSocketType;
/// # use rosenpass_util::rustix::GetUnixSocketType;
/// let f = {
/// // Generate a temp file and take its path
/// // Remove the temp file
@@ -445,7 +222,7 @@ pub trait GetSocketProtocol {
/// ```
/// # use std::net::UdpSocket;
/// # use std::os::fd::{AsFd, AsRawFd};
/// # use rosenpass_util::fd::GetSocketProtocol;
/// # use rosenpass_util::rustix::GetSocketProtocol;
/// let socket = UdpSocket::bind("127.0.0.1:0")?;
/// assert_eq!(socket.as_fd().socket_protocol().unwrap().unwrap(), rustix::net::ipproto::UDP);
/// # Ok::<(), std::io::Error>(())
@@ -458,7 +235,7 @@ pub trait GetSocketProtocol {
/// # use std::net::UdpSocket;
/// # use std::net::TcpListener;
/// # use std::os::fd::{AsFd, AsRawFd};
/// # use rosenpass_util::fd::GetSocketProtocol;
/// # use rosenpass_util::rustix::GetSocketProtocol;
/// let socket = UdpSocket::bind("127.0.0.1:0")?;
/// assert!(socket.as_fd().is_udp_socket().unwrap());
///
@@ -484,7 +261,7 @@ pub trait GetSocketProtocol {
/// # use std::net::UdpSocket;
/// # use std::net::TcpListener;
/// # use std::os::fd::{AsFd, AsRawFd};
/// # use rosenpass_util::fd::GetSocketProtocol;
/// # use rosenpass_util::rustix::GetSocketProtocol;
/// let socket = UdpSocket::bind("127.0.0.1:0")?;
/// assert!(matches!(socket.as_fd().demand_udp_socket(), Ok(())));
///
@@ -514,62 +291,3 @@ where
rustix::net::sockopt::get_socket_protocol(self)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::{Read, Write};
#[test]
#[should_panic]
fn test_claim_fd_invalid_neg() {
let _ = claim_fd(-1);
}
#[test]
#[should_panic]
fn test_claim_fd_invalid_max() {
let _ = claim_fd(i64::MAX as RawFd);
}
#[test]
#[should_panic]
fn test_claim_fd_inplace_invalid_neg() {
let _ = claim_fd_inplace(-1);
}
#[test]
#[should_panic]
fn test_claim_fd_inplace_invalid_max() {
let _ = claim_fd_inplace(i64::MAX as RawFd);
}
#[test]
#[should_panic]
fn test_mask_fd_invalid_neg() {
let _ = mask_fd(-1);
}
#[test]
#[should_panic]
fn test_mask_fd_invalid_max() {
let _ = mask_fd(i64::MAX as RawFd);
}
#[test]
fn test_open_nullfd() -> anyhow::Result<()> {
let mut file = FdIo(open_nullfd()?);
let mut buf = [0; 10];
assert!(matches!(file.read(&mut buf), Ok(0) | Err(_)));
assert!(file.write(&buf).is_err());
Ok(())
}
#[test]
fn test_nullfd_read_write() {
let nullfd = open_nullfd().unwrap();
let mut buf = vec![0u8; 16];
assert_eq!(rustix::io::read(&nullfd, &mut buf).unwrap(), 0);
assert!(rustix::io::write(&nullfd, b"test").is_err());
}
}

View File

@@ -0,0 +1,46 @@
//! Helpers for performing system calls
use std::os::fd::FromRawFd;
use super::errno;
/// Wrapper type around [libc::c_long] that indicates that this value represents
/// the result of a system call
///
/// repr(transparent), so it has the exact layout of the wrapped c_long.
#[repr(transparent)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct SyscallResult(pub libc::c_long);
impl SyscallResult {
/// Access the wrapped raw system call return value (field `.0`)
pub fn raw_value(&self) -> libc::c_long {
self.0
}
/// Claim the system call result as a file descriptor
///
/// - If [Self::raw_value] < 0, then [errno()] is called to retrieve the error type
/// - If [Self::raw_value] > [i32::MAX], panics
/// - Otherwise, this just forwards to [rustix::fd::OwnedFd::from_raw_fd]
///
/// # Panic
///
/// Panics if [Self::raw_value] > [i32::MAX].
///
/// # Safety
///
/// Refer to [rustix::fd::OwnedFd::from_raw_fd].
///
/// NOTE(review): this takes `&self`, so it can be invoked twice on the same
/// result, yielding two OwnedFd values that both believe they own the same
/// descriptor — callers must claim each syscall result at most once.
pub unsafe fn claim_fd(&self) -> Result<rustix::fd::OwnedFd, rustix::io::Errno> {
let fde = self.0;
match fde {
// Negative return value: the syscall failed, fetch the error code
e if e < 0 => Err(errno()),
// A valid fd must fit in an i32; anything larger is out of bounds
fd if fd > i32::MAX.into() => panic!("File descriptor `{fd}` is out of bounds!"),
fd => Ok(unsafe { rustix::fd::OwnedFd::from_raw_fd(fd as i32) }),
}
}
}
impl From<libc::c_long> for SyscallResult {
// Plain newtype wrapping; no interpretation of the value happens here
fn from(value: libc::c_long) -> Self {
Self(value)
}
}

View File

@@ -0,0 +1,202 @@
//! Creation of secret memory file descriptors.
//!
//! This essentially provides a higher_level API for [memfd_secret()] and [memfd_create()]
// Tests: This uses nix-based integration tests
use std::{os::fd::OwnedFd, sync::OnceLock};
use rustix::{fs::MemfdFlags, io::Errno};
use crate::rustix::{memfd_create, memfd_secret, MemfdSecretError, MemfdSecretFlags};
use crate::{mem::CopyExt, result::OkExt};
/// Cache for [memfd_secret_supported]
///
/// OnceLock ensures that, once initialized, the whole process observes a
/// single stable answer for the lifetime of the program.
static MEMFD_SECRET_SUPPORTED: OnceLock<bool> = OnceLock::new();
/// Check whether support for memfd_secret is available
///
/// The answer is probed once (by actually attempting a memfd_secret(2)
/// call) and then cached in [MEMFD_SECRET_SUPPORTED] for the rest of the
/// process lifetime.
pub fn memfd_secret_supported() -> Result<bool, Errno> {
    // Fast path: a previous call already determined the answer
    if let Some(cached) = MEMFD_SECRET_SUPPORTED.get() {
        return Ok(*cached);
    }

    use MemfdSecretError as E;
    let is_supported = match memfd_secret(MemfdSecretFlags::empty()) {
        Ok(_) => true,
        Err(E::NotSupported) => false,
        Err(E::SystemError(e)) => return Err(e),
    };

    // We are deliberately using get_or_init here to make sure that the entire
    // application never sees different values here
    Ok(*MEMFD_SECRET_SUPPORTED.get_or_init(|| is_supported))
}
/// How secure memory file descriptors should be allocated
///
/// See [SecretMemfdMechanism::decide_with_policy] for how a policy is
/// resolved to a concrete mechanism.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq)]
pub enum SecretMemfdPolicy {
/// Use memfd_secret(2) if available, otherwise fall back to less
/// secure options
Opportunistic,
/// Enforce the use of memfd_secret(2)
UseMemfdSecret,
/// Never use memfd_secret(2)
DisableMemfdSecret,
}
impl Default for SecretMemfdPolicy {
fn default() -> Self {
Self::Opportunistic
}
}
impl SecretMemfdPolicy {
/// Create a SecretMemfdPolicy with the default policy
///
/// Currently [Self::Opportunistic]
///
/// This const variant exists so defaults can be used in const contexts,
/// where [Default::default] is unavailable.
pub const fn default_const() -> Self {
Self::Opportunistic
}
/// Enforce the use of the highest security configuration available
///
/// This might not work on some systems, which is why this is not used
/// by default.
///
/// Currently [Self::UseMemfdSecret]
pub const fn enforce_high_security() -> Self {
Self::UseMemfdSecret
}
}
/// Which mechanism to use when allocating secret memory file descriptors
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq)]
pub enum SecretMemfdMechanism {
/// The less secure memfd_create(2) will be used
MemfdCreate,
/// The more secure memfd_secret(2) will be used
MemfdSecret,
}
impl SecretMemfdMechanism {
    /// Decide which mechanism to use, based on the given [SecretMemfdPolicy]
    /// and [memfd_secret_supported()].
    ///
    /// If [SecretMemfdPolicy::UseMemfdSecret] is used, then [SecretMemfdMechanism::MemfdSecret]
    /// will be returned, regardless of whether it is supported.
    ///
    /// Likewise, if [SecretMemfdPolicy::DisableMemfdSecret] is used, then
    /// [SecretMemfdMechanism::MemfdCreate] will be used unconditionally.
    pub fn decide_with_policy(policy: SecretMemfdPolicy) -> Result<Self, Errno> {
        use SecretMemfdMechanism as M;
        use SecretMemfdPolicy as P;
        let mechanism = match policy {
            P::UseMemfdSecret => M::MemfdSecret,
            P::DisableMemfdSecret => M::MemfdCreate,
            // Opportunistic: probe the system for memfd_secret(2) support
            P::Opportunistic if memfd_secret_supported()? => M::MemfdSecret,
            P::Opportunistic => M::MemfdCreate,
        };
        Ok(mechanism)
    }
}
/// Errors for [SecretMemfdConfig::create()]
#[derive(Copy, Clone, Debug, thiserror::Error)]
pub enum SecretMemfdWithConfigError {
/// Call to [memfd_secret()] failed
#[error("{:?}", .0)]
MemfdSecretError(#[from] MemfdSecretError),
/// Call to [memfd_create()] failed
#[error("Could not create secret memory segment using memfd_create(2) due to underlying system error: {:?}", .0)]
MemfdCreateError(Errno),
/// Some other (system) error occurred that prevented us from determining whether
/// memfd_secret(2) is supported.
#[error("Failed to determine whether memfd_secret(2) is supported due to underlying system error: {:?}", .0)]
FailedToDetectSupport(Errno),
}
/// Robustly configure and create secret memory file descriptors
///
/// Whereas [memfd_secret] will always use memfd_secret(2), this construction allows multiple
/// file descriptor back ends to be used to support different usage scenarios.
///
/// This is necessary, because older systems do not support memfd_secret(2) and using it might not
/// always be desirable, as memfd_secret for instance also inhibits hibernation.
#[derive(Default, Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq)]
pub struct SecretMemfdConfig {
/// Enable the close-on-exec flag for the new file descriptor.
pub close_on_exec: bool,
/// Security mechanism to use (see [SecretMemfdPolicy])
pub policy: SecretMemfdPolicy,
}
impl SecretMemfdConfig {
    /// Create a new, default [Self]
    ///
    /// close_on_exec is off; the policy is [SecretMemfdPolicy::default_const]
    pub const fn new() -> Self {
        Self {
            close_on_exec: false,
            policy: SecretMemfdPolicy::default_const(),
        }
    }

    /// Set the [Self::close_on_exec] flag to true
    pub const fn cloexec(&self) -> Self {
        let mut updated = *self;
        updated.close_on_exec = true;
        updated
    }

    /// Set `self.policy = SecretMemoryFdPolicy::UseMemfdSecret`
    pub const fn enforce_high_security(&self) -> Self {
        let mut updated = *self;
        updated.policy = SecretMemfdPolicy::enforce_high_security();
        updated
    }

    /// Whether memfd_secret will be used by [Self::create()]
    pub fn mechanism(&self) -> Result<SecretMemfdMechanism, Errno> {
        SecretMemfdMechanism::decide_with_policy(self.policy)
    }

    /// Allocate a secret file descriptor based on the configuration
    pub fn create(&self) -> Result<OwnedFd, SecretMemfdWithConfigError> {
        use SecretMemfdMechanism as M;
        use SecretMemfdWithConfigError as E;
        let cloexec = self.close_on_exec;
        match self.mechanism().map_err(E::FailedToDetectSupport)? {
            M::MemfdCreate => {
                let flags = if cloexec {
                    MemfdFlags::CLOEXEC
                } else {
                    MemfdFlags::empty()
                };
                memfd_create("rosenpass secret memory segment", flags).map_err(E::MemfdCreateError)
            }
            M::MemfdSecret => {
                let flags = if cloexec {
                    MemfdSecretFlags::CLOEXEC
                } else {
                    MemfdSecretFlags::empty()
                };
                memfd_secret(flags).map_err(E::MemfdSecretError)
            }
        }
    }
}
/// Create a secret memory file descriptor using the default policy
///
/// Shorthand for
/// [`SecretMemfdConfig::new()`](SecretMemfdConfig::new)[`.create()`](SecretMemfdConfig::create)
pub fn memfd_for_secrets_with_default_policy() -> Result<OwnedFd, SecretMemfdWithConfigError> {
SecretMemfdConfig::new().create()
}

View File

@@ -0,0 +1,409 @@
//! This module takes care of allocating memory segments for
//! file descriptors created with [super::fd::memfd_for_secrets_with_default_policy]
//! and anonymous memory segments
// Tests: Nix based integration tests
#![deny(unsafe_op_in_unsafe_fn)]
use std::os::fd::{AsFd, AsRawFd, BorrowedFd, RawFd};
use std::ptr::null_mut;
use crate::mem::CopyExt;
/// Size of the memory mapping for [MappableFd]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum MMapSizePolicy {
/// Size is assumed to be this particular value; [MappableFd::mmap] will simply
/// use this value without checking whether its matches the size of the underlying
/// data
Assumed(u64),
/// Size is assumed to be this particular value; [MappableFd::mmap] will check the
/// size of the underlying data and raise an error if the size of data and this value
/// do not match
Checked(u64),
/// Size is defined to be this particular value; [MappableFd::mmap] will explicitly resize
/// the underlying file descriptor to be this particular size when called.
Resize(u64),
}
impl MMapSizePolicy {
    /// The numeric size value, regardless of which policy variant carries it
    pub fn size_value(&self) -> u64 {
        match self {
            Self::Assumed(v) | Self::Checked(v) | Self::Resize(v) => *v,
        }
    }
}
/// Configuration for [MappableFd::mmap]
///
/// Note: the derived [Default] leaves [Self::size_policy] as [None];
/// a size policy must be chosen before mapping.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub struct MapFdConfig {
/// The memory region can not be read from
pub unreadable: bool,
/// The memory region can not be written to
pub immutable: bool,
/// The memory region can be executed
pub executable: bool,
/// The memory region is shared-memory; other mappings of the same
/// region within and outside this process can see the modifications
/// to the memory region (as long as they also set the shared flag)
pub shared: bool,
/// How [MappableFd::mmap] will determine the size to be used for the mapping
///
/// You should usually set this value through [Self::set_size_policy], [Self::assume_size_without_checking],
/// [Self::expected_size], or [Self::resize_on_mmap].
pub size_policy: Option<MMapSizePolicy>,
}
impl MapFdConfig {
    /// New MapFdConfig with all settings turned off
    ///
    /// You still must set [Self::size_policy], otherwise [MappableFd::mmap] will raise
    /// an error when called
    pub const fn new() -> Self {
        Self {
            unreadable: false,
            immutable: false,
            executable: false,
            shared: false,
            size_policy: None,
        }
    }

    /// New MapFdConfig with shared memory turned on
    pub const fn shared_memory() -> Self {
        Self::new().set_shared()
    }

    /// Copy of this config with the [Self::unreadable] flag set
    pub const fn set_unreadable(&self) -> Self {
        Self {
            unreadable: true,
            ..*self
        }
    }

    /// Copy of this config with the [Self::immutable] flag set
    pub const fn set_immutable(&self) -> Self {
        Self {
            immutable: true,
            ..*self
        }
    }

    /// Copy of this config with the [Self::shared] flag set
    pub const fn set_shared(&self) -> Self {
        Self {
            shared: true,
            ..*self
        }
    }

    /// Create a [MappableFd] instance with this configuration
    pub fn mappable_fd<Fd: AsFd>(&self, fd: Fd) -> MappableFd<Fd> {
        // MapFdConfig is Copy; hand the new MappableFd its own copy
        MappableFd::new(fd, *self)
    }

    /// Calculate [rustix::mm::ProtFlags] for this configuration
    pub const fn mmap_prot(&self) -> rustix::mm::ProtFlags {
        use rustix::mm::ProtFlags as P;
        // Start with no permissions and add the ones this config allows
        let mut prot = P::empty();
        if !self.unreadable {
            prot = prot.union(P::READ);
        }
        if !self.immutable {
            prot = prot.union(P::WRITE);
        }
        if self.executable {
            prot = prot.union(P::EXEC);
        }
        prot
    }

    /// Calculate [rustix::mm::MapFlags] for this configuration
    pub const fn mmap_flags(&self) -> rustix::mm::MapFlags {
        use rustix::mm::MapFlags as M;
        if self.shared {
            M::SHARED
        } else {
            M::empty()
        }
    }

    /// Copy of this config with [Self::size_policy] set to the given value
    pub const fn set_size_policy(&self, size_policy: MMapSizePolicy) -> Self {
        Self {
            size_policy: Some(size_policy),
            ..*self
        }
    }

    /// Set [Self::size_policy] to [MMapSizePolicy::Assumed] with the given value
    pub const fn assume_size_without_checking(&self, size: u64) -> Self {
        self.set_size_policy(MMapSizePolicy::Assumed(size))
    }

    /// Set [Self::size_policy] to [MMapSizePolicy::Checked] with the given value
    pub const fn expected_size(&self, size: u64) -> Self {
        self.set_size_policy(MMapSizePolicy::Checked(size))
    }

    /// Set [Self::size_policy] to [MMapSizePolicy::Resize] with the given value
    pub const fn resize_on_mmap(&self, size: u64) -> Self {
        self.set_size_policy(MMapSizePolicy::Resize(size))
    }
}
/// Error returned by MappableFd::mmap
#[derive(thiserror::Error, Debug, Copy, Clone, PartialEq, Eq)]
pub enum MMapError {
/// Error converting between usize & f64
#[error("Requested memory map of size {requested_len} but maximum supported size is {max_supported_len}. \
    This is a low level error that usually arises on architectures where the integer type usize ({} bytes) can not \
    represent all values in u64 ({} bytes). Are you possibly on 32 bit CPU architecture requesting a buffer bigger than 4 GB?\n\
    Error: {err:?}",
    (usize::BITS as f64)/8f64, (u64::BITS as f64)/8f64
)]
OutOfBounds {
/// Underlying error
err: <u64 as TryInto<usize>>::Error,
/// The size of the memory map requested
requested_len: u64,
/// Maximum supported size
max_supported_len: usize,
},
/// Tried to map a file descriptor into memory, but the size policy was never set. Developer
/// error.
#[error("Tried to map a file descriptor into memory, but the size policy was never set. This is a developer error.")]
MissingSizePolicy,
/// fstat(2) system call failed while determining the descriptor's size
#[error("Tried to map file descriptor into memory, but failed to determine the size of the file descriptor: {:?}", .0)]
CouldNotDetermineSize(rustix::io::Errno),
/// Mismatch between expected and actual size of the file descriptor
#[error("Tried to map file descriptor into memory with expected size {expected:?}, instead we found that the true size is {actual:?}")]
IncorrectSize {
/// Expected file descriptor size (given by caller)
expected: u64,
/// Actual file descriptor size
actual: u64,
},
/// Negative size reported by fstat(2)
#[error("Tried to determine the size of the underlying file descriptor, but failed to determine the size of the file descriptor because the size ({size}) is negative: {err:?}")]
InvalidSize {
/// Underlying error
err: <i64 as TryInto<u64>>::Error,
/// Reported size of the file descriptor
size: i64,
},
/// ftruncate(2) system call failed
#[error("Tried to resize the underlying file descriptor, but failed: {:?}", .0)]
ResizeError(rustix::io::Errno),
/// mmap(2) system call failed
#[error("Tried to map file descriptor into memory, but mmap(2) system call failed: {:?}", .0)]
MMapError(rustix::io::Errno),
}
/// Handle mapping a file descriptor into memory
///
/// Pairs an fd with the [MapFdConfig] that governs how it will be mapped.
pub struct MappableFd<Fd: AsFd> {
/// The file descriptor this struct refers to
fd: Fd,
/// The configuration for mmap
config: MapFdConfig,
}
impl<Fd: AsFd> AsFd for MappableFd<Fd> {
    /// Borrow the wrapped file descriptor (straight delegation to the inner `fd`)
    fn as_fd(&self) -> BorrowedFd<'_> {
        self.fd.as_fd()
    }
}
impl<Fd: AsFd> AsRawFd for MappableFd<Fd> {
    /// Expose the raw file descriptor number, routed through [AsFd::as_fd]
    fn as_raw_fd(&self) -> RawFd {
        self.as_fd().as_raw_fd()
    }
}
impl<Fd: AsFd> MappableFd<Fd> {
    /// Create a [MappableFd] using the default configuration ([MapFdConfig])
    pub fn from_fd(fd: Fd) -> Self {
        Self::new(fd, MapFdConfig::default())
    }
    /// Create a new [MappableFd] for an existing file descriptor with an explicit configuration
    pub fn new(fd: Fd, config: MapFdConfig) -> Self {
        Self { fd, config }
    }
    /// Extract the underlying file descriptor, consuming `self`
    pub fn into_fd(self) -> Fd {
        self.fd
    }
    /// Access (a copy of) the [MapFdConfig] associated with Self
    pub fn config(&self) -> MapFdConfig {
        self.config
    }
    /// Mutably access the [MapFdConfig] associated with Self
    pub fn config_mut(&mut self) -> &mut MapFdConfig {
        &mut self.config
    }
    /// Replace the [MapFdConfig] associated with Self, chainable
    pub fn with_config(mut self, config: MapFdConfig) -> Self {
        self.config = config;
        self
    }
    /// Determine the size of the data associated with the file descriptor
    ///
    /// Uses fstat(2). Returns [MMapError::CouldNotDetermineSize] if the syscall fails
    /// and [MMapError::InvalidSize] if the reported size is negative.
    pub fn size_of_underlying_data(&self) -> Result<u64, MMapError> {
        use MMapError as E;
        let size = rustix::fs::fstat(self)
            .map_err(E::CouldNotDetermineSize)?
            .st_size;
        // st_size is signed; a negative value cannot be converted to u64 and is an error
        size.try_into().map_err(|err| E::InvalidSize { err, size })
    }
    /// Map the file into memory
    ///
    /// # Determining the size of the mapping
    ///
    /// Before calling this function, [MapFdConfig::size_policy] must be set.
    ///
    /// Note that [Self::from_fd] and [Self::new] still allow you to create a [Self] with
    /// [MapFdConfig::size_policy] set to [None], so you can use [Self::size_of_underlying_data]
    /// to auto-detect the size of the mapping.
    ///
    /// This functionality is not implemented by default, as it's crucial to validate the size of
    /// the data being mapped into memory somehow; otherwise, the party that created the file
    /// descriptor can trigger a denial-of-service attack against our process by allocating an
    /// excessively large, sparse file. If you implement size auto-detection facilities, you should
    /// still enforce some bounds on the size.
    ///
    /// # Safety
    ///
    /// If there exist any Rust references referring to the memory region, or if you subsequently
    /// create a Rust reference referring to the resulting region, it is your responsibility to
    /// ensure that the Rust reference invariants are preserved, including ensuring that the
    /// memory is not mutated in a way that a Rust reference would not expect.
    pub fn mmap(&self) -> Result<MappedSegment, MMapError> {
        use rustix::mm::mmap;
        use MMapError as E;
        let prot = self.config().mmap_prot();
        let flags = self.config().mmap_flags();
        // Determine the size of the mapping to be used as u64
        let requested_size = match self.config().size_policy {
            None => return Err(E::MissingSizePolicy),
            Some(MMapSizePolicy::Assumed(size)) => size,
            Some(MMapSizePolicy::Resize(size)) => {
                // Grow/shrink the file to the requested size before mapping
                rustix::fs::ftruncate(self, size).map_err(E::ResizeError)?;
                size
            }
            Some(MMapSizePolicy::Checked(expected)) => {
                // Verify that the file's true size matches the caller's expectation
                let actual = self.size_of_underlying_data()?;
                if expected != actual {
                    return Err(E::IncorrectSize { expected, actual });
                }
                expected
            }
        };
        // Cast the size of the mapping to be used to usize, raising an error if the
        // requested_size can not be represented as usize (this should never happen in general,
        // but it could conceivably be thrown on 32 bit systems when very large mappings
        // (>= 4GB) are requested)
        let len = requested_size.try_into().map_err(|err| {
            let max_supported_len = usize::MAX;
            E::OutOfBounds {
                err,
                requested_len: requested_size,
                max_supported_len,
            }
        })?;
        // SAFETY: a null address hint lets the kernel choose the mapping address; the fd and
        // length were validated above, and the resulting pointer is immediately wrapped in
        // MappedSegment, which takes ownership of the mapping.
        let ptr = unsafe { mmap(null_mut(), len, prot, flags, self, 0) };
        let ptr = ptr.map_err(E::MMapError)?;
        // SAFETY: mmap succeeded, so `ptr` refers to a live mapping of exactly `len` bytes
        let ptr = unsafe { MappedSegment::from_raw_parts(ptr.cast(), len) };
        Ok(ptr)
    }
}
/// Represents exclusive ownership of a memory segment mapped into memory
///
/// Automatically unmaps the memory segment as this goes out of scope.
///
/// # Panic
///
/// If the munmap(2) call fails, the destructor will panic. You can avoid this and use explicit
/// error handling by calling [Self::unmap] instead.
#[derive(Debug)]
pub struct MappedSegment {
    /// Start address of the memory segment
    ptr: *mut u8,
    /// Length of the segment in bytes
    len: usize,
}
// SAFETY: MappedSegment represents exclusive ownership of its mapping (see struct docs),
// and process memory mappings are not tied to the creating thread, so moving the value to
// another thread is sound. NOTE(review): the type is deliberately not Sync — shared access
// from multiple threads would need external synchronization.
unsafe impl Send for MappedSegment {}
impl MappedSegment {
    /// Construct a new [Self] from a pointer and a length
    ///
    /// `ptr` is the start address of the memory segment and `len` is its length in bytes.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that every pointer `P = ptr.add(n)` with `n < len` is a
    /// valid pointer. See #safety in [std::ptr].
    pub unsafe fn from_raw_parts(ptr: *mut u8, len: usize) -> Self {
        Self { ptr, len }
    }
    /// Decompose a MappedSegment into its raw components: pointer and length
    ///
    /// Ownership of the mapping transfers to the caller; the destructor will not run.
    pub fn into_raw_parts(self) -> (*mut u8, usize) {
        let parts = (self.ptr, self.len);
        // Suppress Drop so the segment we are handing out is not unmapped
        std::mem::forget(self);
        parts
    }
    /// The location of the memory segment
    pub fn ptr(&self) -> *mut u8 {
        self.ptr
    }
    /// Length of the segment in bytes
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.len
    }
    /// Release the memory segment with explicit error handling
    ///
    /// Unlike the destructor — which panics when unmapping fails — this returns the error to
    /// the caller. On failure the memory has not been freed; the raw parts (as produced by
    /// [Self::into_raw_parts()]) are carried in the error value, so the caller still has a
    /// chance to free the memory some other way (or to leak it deliberately).
    pub fn unmap(self) -> Result<(), (rustix::io::Errno, *mut u8, usize)> {
        let (ptr, len) = self.into_raw_parts();
        // SAFETY: `ptr`/`len` describe the mapping this segment owned, as established by
        // the contract of `from_raw_parts`.
        match unsafe { rustix::mm::munmap(ptr.cast(), len) } {
            Ok(()) => Ok(()),
            Err(errno) => Err((errno, ptr, len)),
        }
    }
}
impl Drop for MappedSegment {
    /// Unmaps the segment; panics if munmap(2) fails (use [MappedSegment::unmap] for
    /// explicit error handling).
    fn drop(&mut self) {
        // `unmap` consumes its receiver but `drop` only has `&mut self`, so swap the real
        // segment into a local we own, leaving a harmless null/zero-length dummy behind.
        let mut owned = MappedSegment {
            ptr: null_mut(),
            len: 0,
        };
        std::mem::swap(self, &mut owned);
        // `unmap` goes through into_raw_parts(), which forgets `owned`, so its destructor
        // cannot run a second time; the dummy left in `self` is dropped field-wise (no-op).
        if let Err((errno, _ptr, _len)) = owned.unmap() {
            panic!("Failed to unmap MappedSegment: {errno:?}")
        }
    }
}

View File

@@ -0,0 +1,4 @@
//! Utilities for allocating secret memory
pub mod fd;
pub mod mmap;

View File

@@ -0,0 +1,107 @@
//! Traits for types with atomic access semantics
use std::sync::atomic::Ordering;
/// A trait for types with atomic access semantics
///
/// Using this trait allows us to achieve two goals:
///
/// 1. We can implement atomic semantics for types where
///    there is no platform support for atomic semantics
///    (e.g. by using a Mutex)
/// 2. We can reuse implementations of concurrent data structures
///    efficiently for the non-concurrent case. E.g. we could
///    build a thread-local ring buffer using
///    [crate::ringbuf::concurrent::framework::ConcurrentPipeWriter]/
///    [crate::ringbuf::concurrent::framework::ConcurrentPipeReader]
///    by supplying some sort of `Immediate<u64>` type for the atomics
///    in [crate::ringbuf::concurrent::framework::ConcurrentPipeCore] that implements
///    this trait by using a [std::cell::Cell]. It may seem counter
///    intuitive, but this setup implements perfectly fine atomic-appearing
///    semantics just as long as the cell is thread-local.
pub trait AbstractAtomic<T> {
    /// Like [std::sync::atomic::AtomicU64::load()], generalized over the value type `T`
    fn load(&self, order: Ordering) -> T;
    /// Like [std::sync::atomic::AtomicU64::compare_exchange_weak()].
    ///
    /// The default implementation just calls [AbstractAtomic::compare_exchange()],
    /// i.e. it never fails spuriously; implementors may override it with a cheaper
    /// weak variant where the platform provides one.
    fn compare_exchange_weak(
        &self,
        current: T,
        new: T,
        success: Ordering,
        failure: Ordering,
    ) -> Result<T, T> {
        self.compare_exchange(current, new, success, failure)
    }
    /// Like [std::sync::atomic::AtomicU64::compare_exchange()].
    fn compare_exchange(
        &self,
        current: T,
        new: T,
        success: Ordering,
        failure: Ordering,
    ) -> Result<T, T>;
}
/// Implements a default type for [AbstractAtomic]
///
/// This in essence is to [AbstractAtomic], as [std::ops::Deref] is to [std::borrow::Borrow]:
/// the same functionality, except with an associated type that names the single canonical
/// value type instead of a generic type parameter.
pub trait AbstractAtomicType: AbstractAtomic<Self::ValueType> {
    /// The underlying atomic value type
    type ValueType;
}
/// Implements [AbstractAtomic] and [AbstractAtomicType] for standard atomics
///
/// Takes a comma-separated list of `AtomicType : value_type` pairs and forwards every
/// trait method to the inherent method of the same name on the atomic type.
macro_rules! impl_abstract_atomic_for_atomic {
    ($($Atomic:ty : $Value:ty),*) => {
        $(
            impl AbstractAtomicType for $Atomic {
                type ValueType = $Value;
            }
            impl AbstractAtomic<$Value> for $Atomic {
                fn load(&self, order: Ordering) -> $Value {
                    // `self` is already `&$Atomic`; pass it directly instead of `&self`,
                    // which would produce a `&&$Atomic` and rely on deref coercion
                    <$Atomic>::load(self, order)
                }
                fn compare_exchange_weak(
                    &self,
                    current: $Value,
                    new: $Value,
                    success: Ordering,
                    failure: Ordering,
                ) -> Result<$Value, $Value> {
                    <$Atomic>::compare_exchange_weak(self, current, new, success, failure)
                }
                fn compare_exchange(
                    &self,
                    current: $Value,
                    new: $Value,
                    success: Ordering,
                    failure: Ordering,
                ) -> Result<$Value, $Value> {
                    <$Atomic>::compare_exchange(self, current, new, success, failure)
                }
            }
        )*
    };
}
impl_abstract_atomic_for_atomic! {
    std::sync::atomic::AtomicBool: bool,
    std::sync::atomic::AtomicI8: i8,
    std::sync::atomic::AtomicI16: i16,
    std::sync::atomic::AtomicI32: i32,
    std::sync::atomic::AtomicI64: i64,
    std::sync::atomic::AtomicIsize: isize,
    std::sync::atomic::AtomicU8: u8,
    std::sync::atomic::AtomicU16: u16,
    std::sync::atomic::AtomicU32: u32,
    std::sync::atomic::AtomicU64: u64,
    std::sync::atomic::AtomicUsize: usize
}

View File

@@ -0,0 +1,5 @@
//! Synchronization using atomic semantics for multi-threaded programs.
//!
//! Analogous to [std::sync::atomic].
pub mod abstract_atomic;

5
util/src/sync/mod.rs Normal file
View File

@@ -0,0 +1,5 @@
//! Synchronization for multi-threaded programs.
//!
//! Analogous to [std::sync].
pub mod atomic;

View File

@@ -1,9 +1,9 @@
//! A module providing the [`RefMaker`] type and its associated methods for constructing
//! [`zerocopy::Ref`] references from byte buffers.
use anyhow::{ensure, Context};
use anyhow::ensure;
use std::marker::PhantomData;
use zerocopy::{ByteSlice, ByteSliceMut, Ref};
use zerocopy::{Immutable, KnownLayout, Ref, SplitByteSlice, SplitByteSliceMut};
use zeroize::Zeroize;
use crate::zeroize::ZeroizedExt;
@@ -19,10 +19,10 @@ use crate::zeroize::ZeroizedExt;
/// # Example
///
/// ```
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes, Ref};///
/// # use zerocopy::{IntoBytes, FromBytes, Ref, Immutable, KnownLayout};
/// # use rosenpass_util::zerocopy::RefMaker;
///
/// #[derive(FromBytes, FromZeroes, AsBytes)]
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
/// #[repr(C)]
/// struct Header {
/// field1: u32,
@@ -82,7 +82,10 @@ impl<B, T> RefMaker<B, T> {
}
}
impl<B: ByteSlice, T> RefMaker<B, T> {
impl<B: SplitByteSlice, T> RefMaker<B, T>
where
T: KnownLayout + Immutable,
{
/// Parses the buffer into a [`zerocopy::Ref<B, T>`].
///
/// This will fail if the buffer is smaller than `size_of::<T>`.
@@ -94,10 +97,10 @@ impl<B: ByteSlice, T> RefMaker<B, T> {
/// # Example
///
/// ```
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes, Ref};
/// # use zerocopy::{IntoBytes, FromBytes, Ref, Immutable, KnownLayout};
/// # use rosenpass_util::zerocopy::RefMaker;
///
/// #[derive(FromBytes, FromZeroes, AsBytes, Debug)]
/// #[derive(FromBytes, IntoBytes, Debug, Immutable, KnownLayout)]
/// #[repr(C)]
/// struct Data(u32);
///
@@ -116,11 +119,11 @@ impl<B: ByteSlice, T> RefMaker<B, T> {
/// let bytes = [1u8, 2, 3, 4, 5, 6, 7, 8];
/// let parse_error = RefMaker::<_, Data>::new(&bytes[1..5]).parse()
/// .expect_err("Should error");
/// assert_eq!(parse_error.to_string(), "Parser error!");
/// assert_eq!(parse_error.to_string(), "Parser error: Alignment(AlignmentError)");
/// ```
pub fn parse(self) -> anyhow::Result<Ref<B, T>> {
self.ensure_fit()?;
Ref::<B, T>::new(self.buf).context("Parser error!")
Ref::<B, T>::from_bytes(self.buf).map_err(|e| anyhow::anyhow!("Parser error: {e:?}"))
}
/// Splits the internal buffer into a `RefMaker` containing a buffer with
@@ -142,7 +145,10 @@ impl<B: ByteSlice, T> RefMaker<B, T> {
/// ```
pub fn from_prefix_with_tail(self) -> anyhow::Result<(Self, B)> {
self.ensure_fit()?;
let (head, tail) = self.buf.split_at(Self::target_size());
let (head, tail) = self
.buf
.split_at(Self::target_size())
.map_err(|_| anyhow::anyhow!("Failed to split buffer!"))?;
Ok((Self::new(head), tail))
}
@@ -165,7 +171,10 @@ impl<B: ByteSlice, T> RefMaker<B, T> {
/// ```
pub fn split_prefix(self) -> anyhow::Result<(Self, Self)> {
self.ensure_fit()?;
let (head, tail) = self.buf.split_at(Self::target_size());
let (head, tail) = self
.buf
.split_at(Self::target_size())
.map_err(|_| anyhow::anyhow!("Failed to split buffer!"))?;
Ok((Self::new(head), Self::new(tail)))
}
@@ -204,7 +213,10 @@ impl<B: ByteSlice, T> RefMaker<B, T> {
pub fn from_suffix_with_head(self) -> anyhow::Result<(Self, B)> {
self.ensure_fit()?;
let point = self.bytes().len() - Self::target_size();
let (head, tail) = self.buf.split_at(point);
let (head, tail) = self
.buf
.split_at(point)
.map_err(|_| anyhow::anyhow!("Failed to split buffer!"))?;
Ok((Self::new(tail), head))
}
@@ -227,7 +239,10 @@ impl<B: ByteSlice, T> RefMaker<B, T> {
pub fn split_suffix(self) -> anyhow::Result<(Self, Self)> {
self.ensure_fit()?;
let point = self.bytes().len() - Self::target_size();
let (head, tail) = self.buf.split_at(point);
let (head, tail) = self
.buf
.split_at(point)
.map_err(|_| anyhow::anyhow!("Failed to split buffer!"))?;
Ok((Self::new(head), Self::new(tail)))
}
@@ -282,7 +297,10 @@ impl<B: ByteSlice, T> RefMaker<B, T> {
}
}
impl<B: ByteSliceMut, T> RefMaker<B, T> {
impl<B: SplitByteSliceMut, T> RefMaker<B, T>
where
T: KnownLayout + Immutable,
{
/// Creates a zeroized reference of type `T` from the buffer.
///
/// # Errors
@@ -292,9 +310,9 @@ impl<B: ByteSliceMut, T> RefMaker<B, T> {
/// # Example
///
/// ```
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes, Ref}; ///
/// # use zerocopy::{IntoBytes, FromBytes, Ref, Immutable, KnownLayout};
/// # use rosenpass_util::zerocopy::RefMaker;
/// #[derive(FromBytes, FromZeroes, AsBytes)]
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
/// #[repr(C)]
/// struct Data([u8; 4]);
///
@@ -312,7 +330,10 @@ impl<B: ByteSliceMut, T> RefMaker<B, T> {
}
}
impl<B: ByteSliceMut, T> Zeroize for RefMaker<B, T> {
impl<B: SplitByteSliceMut, T> Zeroize for RefMaker<B, T>
where
T: KnownLayout + Immutable,
{
fn zeroize(&mut self) {
self.bytes_mut().zeroize()
}

View File

@@ -1,7 +1,7 @@
//! Extension traits for converting `Ref<B, T>` into references backed by
//! standard slices.
use zerocopy::{ByteSlice, ByteSliceMut, Ref};
use zerocopy::{Immutable, KnownLayout, Ref, SplitByteSlice, SplitByteSliceMut};
/// A trait for converting a `Ref<B, T>` into a `Ref<&[u8], T>`.
///
@@ -16,9 +16,9 @@ pub trait ZerocopyEmancipateExt<B, T> {
///
/// ```
/// # use std::ops::Deref;
/// # use zerocopy::{AsBytes, ByteSlice, FromBytes, FromZeroes, Ref};
/// # use zerocopy::{IntoBytes, SplitByteSlice, FromBytes, FromZeros, Ref, Immutable, KnownLayout};
/// # use rosenpass_util::zerocopy::ZerocopyEmancipateExt;
/// #[derive(FromBytes, FromZeroes, AsBytes)]
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
/// #[repr(C)]
/// struct Data(u32);
/// #[repr(align(4))]
@@ -44,9 +44,9 @@ pub trait ZerocopyEmancipateMutExt<B, T> {
/// # Example
///
/// ```
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes, Ref};
/// # use zerocopy::{IntoBytes, FromBytes, Ref, Immutable, KnownLayout};
/// # use rosenpass_util::zerocopy::{ZerocopyEmancipateMutExt};
/// #[derive(FromBytes, FromZeroes, AsBytes)]
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
/// #[repr(C)]
/// struct Data(u32);
/// #[repr(align(4))]
@@ -64,18 +64,20 @@ pub trait ZerocopyEmancipateMutExt<B, T> {
impl<B, T> ZerocopyEmancipateExt<B, T> for Ref<B, T>
where
B: ByteSlice,
B: SplitByteSlice,
T: KnownLayout + Immutable,
{
fn emancipate(&self) -> Ref<&[u8], T> {
Ref::new(self.bytes()).unwrap()
Ref::from_bytes(zerocopy::Ref::bytes(self)).unwrap()
}
}
impl<B, T> ZerocopyEmancipateMutExt<B, T> for Ref<B, T>
where
B: ByteSliceMut,
B: SplitByteSliceMut,
T: KnownLayout + Immutable,
{
fn emancipate_mut(&mut self) -> Ref<&mut [u8], T> {
Ref::new(self.bytes_mut()).unwrap()
Ref::from_bytes(zerocopy::Ref::bytes_mut(self)).unwrap()
}
}

View File

@@ -1,7 +1,7 @@
//! Extension traits for parsing slices into [`zerocopy::Ref`] values using the
//! [`RefMaker`] abstraction.
use zerocopy::{ByteSlice, ByteSliceMut, Ref};
use zerocopy::{Immutable, KnownLayout, Ref, SplitByteSlice, SplitByteSliceMut};
use super::RefMaker;
@@ -9,16 +9,16 @@ use super::RefMaker;
///
/// This trait adds methods for creating [`Ref`] references from
/// slices by using the [`RefMaker`] type internally.
pub trait ZerocopySliceExt: Sized + ByteSlice {
pub trait ZerocopySliceExt: Sized + SplitByteSlice {
/// Creates a new `RefMaker` for the given slice.
///
/// # Example
///
/// ```
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes};
/// # use zerocopy::{IntoBytes, FromBytes, Immutable, KnownLayout};
/// # use rosenpass_util::zerocopy::{RefMaker, ZerocopySliceExt};
///
/// #[derive(FromBytes, FromZeroes, AsBytes)]
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
/// #[repr(C)]
/// struct Data(u32);
///
@@ -39,10 +39,10 @@ pub trait ZerocopySliceExt: Sized + ByteSlice {
/// # Example
///
/// ```
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes};
/// # use zerocopy::{IntoBytes, FromBytes, Immutable, KnownLayout};
/// # use rosenpass_util::zerocopy::ZerocopySliceExt;
///
/// #[derive(FromBytes, FromZeroes, AsBytes)]
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
/// #[repr(C)]
/// struct Data(u16, u16);
/// #[repr(align(4))]
@@ -52,7 +52,10 @@ pub trait ZerocopySliceExt: Sized + ByteSlice {
/// assert_eq!(data_ref.0, 0x0201);
/// assert_eq!(data_ref.1, 0x0403);
/// ```
fn zk_parse<T>(self) -> anyhow::Result<Ref<Self, T>> {
fn zk_parse<T>(self) -> anyhow::Result<Ref<Self, T>>
where
T: Immutable + KnownLayout,
{
self.zk_ref_maker().parse()
}
@@ -67,9 +70,9 @@ pub trait ZerocopySliceExt: Sized + ByteSlice {
/// # Example
///
/// ```
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes};
/// # use zerocopy::{IntoBytes, FromBytes, Immutable, KnownLayout};
/// # use rosenpass_util::zerocopy::ZerocopySliceExt;
/// #[derive(FromBytes, FromZeroes, AsBytes)]
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
/// #[repr(C)]
/// struct Header(u32);
/// #[repr(align(4))]
@@ -80,7 +83,10 @@ pub trait ZerocopySliceExt: Sized + ByteSlice {
/// let header_ref = bytes.0.zk_parse_prefix::<Header>().unwrap();
/// assert_eq!(header_ref.0, 0xDDCCBBAA);
/// ```
fn zk_parse_prefix<T>(self) -> anyhow::Result<Ref<Self, T>> {
fn zk_parse_prefix<T>(self) -> anyhow::Result<Ref<Self, T>>
where
T: Immutable + KnownLayout,
{
self.zk_ref_maker().from_prefix()?.parse()
}
@@ -95,9 +101,9 @@ pub trait ZerocopySliceExt: Sized + ByteSlice {
/// # Example
///
/// ```
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes};
/// # use zerocopy::{IntoBytes, FromBytes, Immutable, KnownLayout};
/// # use rosenpass_util::zerocopy::ZerocopySliceExt;
/// #[derive(FromBytes, FromZeroes, AsBytes)]
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
/// #[repr(C)]
/// struct Header(u32);
/// #[repr(align(4))]
@@ -108,18 +114,21 @@ pub trait ZerocopySliceExt: Sized + ByteSlice {
/// let header_ref = bytes.0.zk_parse_suffix::<Header>().unwrap();
/// assert_eq!(header_ref.0, 0x30201000);
/// ```
fn zk_parse_suffix<T>(self) -> anyhow::Result<Ref<Self, T>> {
fn zk_parse_suffix<T>(self) -> anyhow::Result<Ref<Self, T>>
where
T: Immutable + KnownLayout,
{
self.zk_ref_maker().from_suffix()?.parse()
}
}
impl<B: ByteSlice> ZerocopySliceExt for B {}
impl<B: SplitByteSlice> ZerocopySliceExt for B {}
/// Extension trait for zero-copy parsing of mutable slices with zeroization
/// capabilities.
///
/// Provides convenience methods to create zero-initialized references.
pub trait ZerocopyMutSliceExt: ZerocopySliceExt + Sized + ByteSliceMut {
pub trait ZerocopyMutSliceExt: ZerocopySliceExt + Sized + SplitByteSliceMut {
/// Creates a new zeroized reference from the entire slice.
///
/// This zeroizes the slice first, then provides a `Ref`.
@@ -131,9 +140,9 @@ pub trait ZerocopyMutSliceExt: ZerocopySliceExt + Sized + ByteSliceMut {
/// # Example
///
/// ```
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes};
/// # use zerocopy::{IntoBytes, FromBytes, Immutable, KnownLayout};
/// # use rosenpass_util::zerocopy::ZerocopyMutSliceExt;
/// #[derive(FromBytes, FromZeroes, AsBytes)]
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
/// #[repr(C)]
/// struct Data([u8; 4]);
/// #[repr(align(4))]
@@ -143,7 +152,10 @@ pub trait ZerocopyMutSliceExt: ZerocopySliceExt + Sized + ByteSliceMut {
/// assert_eq!(data_ref.0, [0,0,0,0]);
/// assert_eq!(bytes.0, [0, 0, 0, 0]);
/// ```
fn zk_zeroized<T>(self) -> anyhow::Result<Ref<Self, T>> {
fn zk_zeroized<T>(self) -> anyhow::Result<Ref<Self, T>>
where
T: Immutable + KnownLayout,
{
self.zk_ref_maker().make_zeroized()
}
@@ -159,9 +171,9 @@ pub trait ZerocopyMutSliceExt: ZerocopySliceExt + Sized + ByteSliceMut {
/// # Example
///
/// ```
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes};
/// # use zerocopy::{IntoBytes, FromBytes, Immutable, KnownLayout};
/// # use rosenpass_util::zerocopy::ZerocopyMutSliceExt;
/// #[derive(FromBytes, FromZeroes, AsBytes)]
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
/// #[repr(C)]
/// struct Data([u8; 4]);
/// #[repr(align(4))]
@@ -171,7 +183,10 @@ pub trait ZerocopyMutSliceExt: ZerocopySliceExt + Sized + ByteSliceMut {
/// assert_eq!(data_ref.0, [0,0,0,0]);
/// assert_eq!(bytes.0, [0, 0, 0, 0, 0xFF, 0xFF]);
/// ```
fn zk_zeroized_from_prefix<T>(self) -> anyhow::Result<Ref<Self, T>> {
fn zk_zeroized_from_prefix<T>(self) -> anyhow::Result<Ref<Self, T>>
where
T: Immutable + KnownLayout,
{
self.zk_ref_maker().from_prefix()?.make_zeroized()
}
@@ -187,9 +202,9 @@ pub trait ZerocopyMutSliceExt: ZerocopySliceExt + Sized + ByteSliceMut {
/// # Example
///
/// ```
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes};
/// # use zerocopy::{IntoBytes, FromBytes, Immutable, KnownLayout};
/// # use rosenpass_util::zerocopy::ZerocopyMutSliceExt;
/// #[derive(FromBytes, FromZeroes, AsBytes)]
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
/// #[repr(C)]
/// struct Data([u8; 4]);
/// #[repr(align(4))]
@@ -199,9 +214,12 @@ pub trait ZerocopyMutSliceExt: ZerocopySliceExt + Sized + ByteSliceMut {
/// assert_eq!(data_ref.0, [0,0,0,0]);
/// assert_eq!(bytes.0, [0xFF, 0xFF, 0, 0, 0, 0]);
/// ```
fn zk_zeroized_from_suffix<T>(self) -> anyhow::Result<Ref<Self, T>> {
fn zk_zeroized_from_suffix<T>(self) -> anyhow::Result<Ref<Self, T>>
where
T: Immutable + KnownLayout,
{
self.zk_ref_maker().from_suffix()?.make_zeroized()
}
}
impl<B: ByteSliceMut> ZerocopyMutSliceExt for B {}
impl<B: SplitByteSliceMut> ZerocopyMutSliceExt for B {}

View File

@@ -47,7 +47,7 @@ async fn janitor_demo() -> anyhow::Result<()> {
anyhow::Ok(())
})
}
.await;
.await?;
// At this point, all background jobs have finished, now we can check the result of all our
// additions

View File

@@ -41,6 +41,7 @@ rand = { workspace = true }
procspawn = { workspace = true }
[features]
default = ["experiment_api"]
experiment_api = ["rustix", "libc"]
experiment_memfd_secret = []

View File

@@ -38,6 +38,8 @@
use std::{borrow::BorrowMut, fmt::Debug};
use zerocopy::IntoBytes;
use crate::{
api::{
config::NetworkBrokerConfig,
@@ -170,7 +172,8 @@ where
let typ = msgs::MsgType::try_from(*typ)?;
let msgs::MsgType::SetPsk = typ; // Assert type
let res = zerocopy::Ref::<&[u8], Envelope<SetPskResponse>>::new(res)
let res = zerocopy::Ref::<&[u8], Envelope<SetPskResponse>>::from_bytes(res)
.ok()
.ok_or(invalid_msg_poller())?;
let res: &msgs::SetPskResponse = &res.payload;
let res: msgs::SetPskResponseReturnCode = res
@@ -200,8 +203,10 @@ where
let mut req = [0u8; BUF_SIZE];
// Construct message view
let mut req = zerocopy::Ref::<&mut [u8], Envelope<msgs::SetPskRequest>>::new(&mut req)
.ok_or(MsgError)?;
let mut req =
zerocopy::Ref::<&mut [u8], Envelope<msgs::SetPskRequest>>::from_bytes(&mut req)
.ok()
.ok_or(MsgError)?;
// Populate envelope
req.msg_type = msgs::MsgType::SetPsk as u8;
@@ -219,7 +224,7 @@ where
// Send message
self.io
.borrow_mut()
.send_msg(req.bytes())
.send_msg(req.as_bytes())
.map_err(IoError)?;
Ok(())

View File

@@ -3,7 +3,7 @@
use std::str::{from_utf8, Utf8Error};
use zerocopy::{AsBytes, FromBytes, FromZeroes};
use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
/// The number of bytes reserved for overhead when packaging data.
pub const ENVELOPE_OVERHEAD: usize = 1 + 3;
@@ -15,8 +15,8 @@ pub const RESPONSE_MSG_BUFFER_SIZE: usize = ENVELOPE_OVERHEAD + 1;
/// Envelope for messages being passed around.
#[repr(packed)]
#[derive(AsBytes, FromBytes, FromZeroes)]
pub struct Envelope<M: AsBytes + FromBytes> {
#[derive(IntoBytes, FromBytes, Immutable, KnownLayout)]
pub struct Envelope<M: IntoBytes + FromBytes> {
/// [MsgType] of this message
pub msg_type: u8,
/// Reserved for future use
@@ -29,7 +29,7 @@ pub struct Envelope<M: AsBytes + FromBytes> {
/// # Example
///
#[repr(packed)]
#[derive(AsBytes, FromBytes, FromZeroes)]
#[derive(IntoBytes, FromBytes, Immutable, KnownLayout)]
pub struct SetPskRequest {
/// The pre-shared key.
pub psk: [u8; 32],
@@ -85,7 +85,7 @@ impl SetPskRequest {
/// Message format for response to the set pre-shared key operation.
#[repr(packed)]
#[derive(AsBytes, FromBytes, FromZeroes)]
#[derive(IntoBytes, FromBytes, Immutable, KnownLayout)]
pub struct SetPskResponse {
pub return_code: u8,
}

View File

@@ -8,6 +8,7 @@
use std::borrow::BorrowMut;
use rosenpass_secret_memory::{Public, Secret};
use zerocopy::IntoBytes;
use crate::api::msgs::{self, Envelope, SetPskRequest, SetPskResponse};
use crate::WireGuardBroker;
@@ -78,14 +79,16 @@ where
let typ = msgs::MsgType::try_from(*typ)?;
let msgs::MsgType::SetPsk = typ; // Assert type
let req =
zerocopy::Ref::<&[u8], Envelope<SetPskRequest>>::new(req).ok_or(InvalidMessage)?;
let mut res =
zerocopy::Ref::<&mut [u8], Envelope<SetPskResponse>>::new(res).ok_or(InvalidMessage)?;
let req = zerocopy::Ref::<&[u8], Envelope<SetPskRequest>>::from_bytes(req)
.ok()
.ok_or(InvalidMessage)?;
let mut res = zerocopy::Ref::<&mut [u8], Envelope<SetPskResponse>>::from_bytes(res)
.ok()
.ok_or(InvalidMessage)?;
res.msg_type = msgs::MsgType::SetPsk as u8;
self.handle_set_psk(&req.payload, &mut res.payload)?;
Ok(res.bytes().len())
Ok(res.as_bytes().len())
}
/// Sets the pre-shared key for the interface identified in `req` to the pre-shared key
@@ -138,7 +141,7 @@ mod tests {
use crate::brokers::netlink::SetPskError;
use crate::{SerializedBrokerConfig, WireGuardBroker};
use rosenpass_secret_memory::{secret_policy_use_only_malloc_secrets, Secret};
use zerocopy::AsBytes;
use zerocopy::IntoBytes;
#[derive(Debug, Clone)]
struct MockWireGuardBroker {

View File

@@ -12,7 +12,7 @@ use tokio::task;
use anyhow::{bail, ensure, Result};
use clap::{ArgGroup, Parser};
use rosenpass_util::fd::claim_fd;
use rosenpass_util::rustix::claim_fd;
use rosenpass_wireguard_broker::api::msgs;
/// Command-line arguments for configuring the socket handler