mirror of
https://github.com/rosenpass/rosenpass.git
synced 2025-12-09 06:10:30 -08:00
Compare commits
9 Commits
main
...
dev/karo/a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c4eea31c5d | ||
|
|
9fd32086ea | ||
|
|
63511465de | ||
|
|
0c960d57bc | ||
|
|
8f276f70a6 | ||
|
|
9580961dd9 | ||
|
|
1a51478e89 | ||
|
|
5b14ef8065 | ||
|
|
a796bdd2e7 |
52
Cargo.lock
generated
52
Cargo.lock
generated
@@ -212,7 +212,7 @@ version = "0.68.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "726e4313eb6ec35d2730258ad4e15b547ee75d6afaa1361a922e78e59b7d8078"
|
||||
dependencies = [
|
||||
"bitflags 2.8.0",
|
||||
"bitflags 2.9.3",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"lazy_static",
|
||||
@@ -237,9 +237,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "2.8.0"
|
||||
version = "2.9.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36"
|
||||
checksum = "34efbcccd345379ca2868b2b2c9d3782e9cc58ba87bc7d79d5b53d9c9ae6f25d"
|
||||
|
||||
[[package]]
|
||||
name = "blake2"
|
||||
@@ -818,9 +818,9 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.10"
|
||||
version = "0.3.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d"
|
||||
checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.59.0",
|
||||
@@ -1193,7 +1193,7 @@ version = "0.7.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4"
|
||||
dependencies = [
|
||||
"bitflags 2.8.0",
|
||||
"bitflags 2.9.3",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
]
|
||||
@@ -1453,7 +1453,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"windows-targets 0.48.5",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1673,7 +1673,7 @@ version = "0.27.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053"
|
||||
dependencies = [
|
||||
"bitflags 2.8.0",
|
||||
"bitflags 2.9.3",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
]
|
||||
@@ -1965,7 +1965,7 @@ checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94"
|
||||
dependencies = [
|
||||
"rand_chacha 0.9.0",
|
||||
"rand_core 0.9.3",
|
||||
"zerocopy 0.8.24",
|
||||
"zerocopy 0.8.27",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2032,7 +2032,7 @@ version = "0.5.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "82b568323e98e49e2a0899dcee453dd679fae22d69adf9b11dd508d1549b7e2f"
|
||||
dependencies = [
|
||||
"bitflags 2.8.0",
|
||||
"bitflags 2.9.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2115,7 +2115,7 @@ dependencies = [
|
||||
"thiserror 1.0.69",
|
||||
"toml",
|
||||
"uds",
|
||||
"zerocopy 0.7.35",
|
||||
"zerocopy 0.8.27",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
@@ -2257,17 +2257,23 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"base64ct",
|
||||
"bitflags 2.9.3",
|
||||
"errno",
|
||||
"libc",
|
||||
"libcrux-test-utils",
|
||||
"log",
|
||||
"mio",
|
||||
"num-traits",
|
||||
"rosenpass-to",
|
||||
"rustix",
|
||||
"static_assertions",
|
||||
"tempfile",
|
||||
"thiserror 1.0.69",
|
||||
"tinyvec",
|
||||
"tokio",
|
||||
"typenum",
|
||||
"uds",
|
||||
"zerocopy 0.7.35",
|
||||
"zerocopy 0.8.27",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
@@ -2292,7 +2298,7 @@ dependencies = [
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"wireguard-uapi",
|
||||
"zerocopy 0.7.35",
|
||||
"zerocopy 0.8.27",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2340,7 +2346,7 @@ version = "0.38.44"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
|
||||
dependencies = [
|
||||
"bitflags 2.8.0",
|
||||
"bitflags 2.9.3",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
@@ -2710,6 +2716,12 @@ dependencies = [
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec"
|
||||
version = "1.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa"
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.47.0"
|
||||
@@ -3276,7 +3288,7 @@ version = "0.33.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c"
|
||||
dependencies = [
|
||||
"bitflags 2.8.0",
|
||||
"bitflags 2.9.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3303,11 +3315,11 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy"
|
||||
version = "0.8.24"
|
||||
version = "0.8.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2586fea28e186957ef732a5f8b3be2da217d65c5969d4b1e17f973ebbe876879"
|
||||
checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c"
|
||||
dependencies = [
|
||||
"zerocopy-derive 0.8.24",
|
||||
"zerocopy-derive 0.8.27",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3323,9 +3335,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy-derive"
|
||||
version = "0.8.24"
|
||||
version = "0.8.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a996a8f63c5c4448cd959ac1bab0aaa3306ccfd060472f85943ee0750f0169be"
|
||||
checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
||||
@@ -66,7 +66,7 @@ chacha20poly1305 = { version = "0.10.1", default-features = false, features = [
|
||||
"std",
|
||||
"heapless",
|
||||
] }
|
||||
zerocopy = { version = "0.7.35", features = ["derive"] }
|
||||
zerocopy = { version = "0.8.27", features = ["derive"] }
|
||||
home = "=0.5.9" # 5.11 requires rustc 1.81
|
||||
derive_builder = "0.20.1"
|
||||
tokio = { version = "1.46", features = ["macros", "rt-multi-thread"] }
|
||||
@@ -80,6 +80,7 @@ hex-literal = { version = "0.4.1" }
|
||||
hex = { version = "0.4.3" }
|
||||
heck = { version = "0.5.0" }
|
||||
libc = { version = "0.2" }
|
||||
errno = { version = "0.3.13" }
|
||||
uds = { git = "https://github.com/rosenpass/uds" }
|
||||
lazy_static = "1.5"
|
||||
|
||||
@@ -95,6 +96,7 @@ criterion = "0.5.1"
|
||||
allocator-api2-tests = "0.2.15"
|
||||
procspawn = { version = "1.0.1", features = ["test-support"] }
|
||||
serde_json = { version = "1.0.140" }
|
||||
bitflags = "2.9.3"
|
||||
|
||||
#Broker dependencies (might need cleanup or changes)
|
||||
wireguard-uapi = { version = "3.0.0", features = ["xplatform"] }
|
||||
|
||||
@@ -97,6 +97,7 @@ rustix = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
||||
[features]
|
||||
default = ["experiment_api"]
|
||||
experiment_cookie_dos_mitigation = []
|
||||
experiment_memfd_secret = ["rosenpass-wireguard-broker/experiment_memfd_secret"]
|
||||
experiment_libcrux_all = ["rosenpass-ciphers/experiment_libcrux_all"]
|
||||
|
||||
@@ -6,12 +6,12 @@ use std::{borrow::BorrowMut, collections::VecDeque, os::fd::OwnedFd};
|
||||
use anyhow::Context;
|
||||
use rosenpass_to::{ops::copy_slice, To};
|
||||
use rosenpass_util::{
|
||||
fd::FdIo,
|
||||
functional::{run, ApplyExt},
|
||||
io::ReadExt,
|
||||
mem::DiscardResultExt,
|
||||
mio::UnixStreamExt,
|
||||
result::OkExt,
|
||||
rustix::FdIo,
|
||||
};
|
||||
use rosenpass_wireguard_broker::brokers::mio_client::MioBrokerClient;
|
||||
|
||||
@@ -243,7 +243,7 @@ where
|
||||
.context("Invalid request – socket missing.")?;
|
||||
// TODO: We need to have this outside linux
|
||||
#[cfg(target_os = "linux")]
|
||||
rosenpass_util::fd::GetSocketProtocol::demand_udp_socket(&sock)?;
|
||||
rosenpass_util::rustix::GetSocketProtocol::demand_udp_socket(&sock)?;
|
||||
let sock = std::net::UdpSocket::from(sock);
|
||||
sock.set_nonblocking(true)?;
|
||||
mio::net::UdpSocket::from_std(sock).ok()
|
||||
@@ -316,7 +316,7 @@ where
|
||||
use crate::app_server::BrokerStorePtr;
|
||||
//
|
||||
use rosenpass_secret_memory::Public;
|
||||
use zerocopy::AsBytes;
|
||||
use zerocopy::IntoBytes;
|
||||
(self.app_server().brokers.store.len() - 1)
|
||||
.apply(|x| x as u64)
|
||||
.apply(|x| Public::from_slice(x.as_bytes()))
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use zerocopy::{ByteSlice, Ref};
|
||||
use zerocopy::{Ref, SplitByteSlice};
|
||||
|
||||
use rosenpass_util::zerocopy::{RefMaker, ZerocopySliceExt};
|
||||
|
||||
@@ -7,7 +7,7 @@ use super::{
|
||||
ResponseMsgType, ResponseRef, SupplyKeypairRequest, SupplyKeypairResponse,
|
||||
};
|
||||
|
||||
pub trait ByteSliceRefExt: ByteSlice {
|
||||
pub trait ByteSliceRefExt: SplitByteSlice {
|
||||
/// Shorthand for the typed use of [ZerocopySliceExt::zk_ref_maker].
|
||||
fn msg_type_maker(self) -> RefMaker<Self, RawMsgType> {
|
||||
self.zk_ref_maker()
|
||||
@@ -259,4 +259,4 @@ pub trait ByteSliceRefExt: ByteSlice {
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: ByteSlice> ByteSliceRefExt for B {}
|
||||
impl<B: SplitByteSlice> ByteSliceRefExt for B {}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use zerocopy::{ByteSliceMut, Ref};
|
||||
use zerocopy::{Ref, SplitByteSliceMut};
|
||||
|
||||
use rosenpass_util::zerocopy::RefMaker;
|
||||
|
||||
@@ -35,7 +35,7 @@ pub trait Message {
|
||||
/// # Examples
|
||||
///
|
||||
/// See [crate::api::PingRequest::setup]
|
||||
fn setup<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>>;
|
||||
fn setup<B: SplitByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>>;
|
||||
}
|
||||
|
||||
/// Additional convenience functions for working with [rosenpass_util::zerocopy::RefMaker]
|
||||
@@ -45,7 +45,7 @@ pub trait ZerocopyResponseMakerSetupMessageExt<B, T> {
|
||||
|
||||
impl<B, T> ZerocopyResponseMakerSetupMessageExt<B, T> for RefMaker<B, T>
|
||||
where
|
||||
B: ByteSliceMut,
|
||||
B: SplitByteSliceMut,
|
||||
T: Message,
|
||||
{
|
||||
/// Initialize the message using [Message::setup].
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use hex_literal::hex;
|
||||
use rosenpass_util::zerocopy::RefMaker;
|
||||
use zerocopy::ByteSlice;
|
||||
use zerocopy::SplitByteSlice;
|
||||
|
||||
use crate::RosenpassError::{self, InvalidApiMessageType};
|
||||
|
||||
@@ -169,12 +169,12 @@ pub trait RefMakerRawMsgTypeExt {
|
||||
fn parse_response_msg_type(self) -> anyhow::Result<ResponseMsgType>;
|
||||
}
|
||||
|
||||
impl<B: ByteSlice> RefMakerRawMsgTypeExt for RefMaker<B, RawMsgType> {
|
||||
impl<B: SplitByteSlice> RefMakerRawMsgTypeExt for RefMaker<B, RawMsgType> {
|
||||
fn parse_request_msg_type(self) -> anyhow::Result<RequestMsgType> {
|
||||
Ok(self.parse()?.read().try_into()?)
|
||||
Ok(zerocopy::Ref::read(&self.parse()?).try_into()?)
|
||||
}
|
||||
|
||||
fn parse_response_msg_type(self) -> anyhow::Result<ResponseMsgType> {
|
||||
Ok(self.parse()?.read().try_into()?)
|
||||
Ok(zerocopy::Ref::<B, u128>::read(&self.parse()?).try_into()?)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use rosenpass_util::zerocopy::ZerocopyMutSliceExt;
|
||||
use zerocopy::{AsBytes, ByteSliceMut, FromBytes, FromZeroes, Ref};
|
||||
use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Ref, SplitByteSliceMut};
|
||||
|
||||
use super::{Message, RawMsgType, RequestMsgType, ResponseMsgType};
|
||||
|
||||
@@ -12,8 +12,8 @@ pub const MAX_REQUEST_FDS: usize = 2;
|
||||
|
||||
/// Message envelope for API messages
|
||||
#[repr(packed)]
|
||||
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
|
||||
pub struct Envelope<M: AsBytes + FromBytes> {
|
||||
#[derive(Debug, Copy, Clone, Hash, IntoBytes, FromBytes, PartialEq, Eq, Immutable, KnownLayout)]
|
||||
pub struct Envelope<M: IntoBytes + FromBytes + Immutable + KnownLayout> {
|
||||
/// Which message this is
|
||||
pub msg_type: RawMsgType,
|
||||
/// The actual Paylod
|
||||
@@ -27,7 +27,7 @@ pub type ResponseEnvelope<M> = Envelope<M>;
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[repr(packed)]
|
||||
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
|
||||
#[derive(Debug, Copy, Clone, Hash, IntoBytes, FromBytes, PartialEq, Eq, Immutable, KnownLayout)]
|
||||
pub struct PingRequestPayload {
|
||||
/// Randomly generated connection id
|
||||
pub echo: [u8; 256],
|
||||
@@ -55,7 +55,7 @@ impl Message for PingRequest {
|
||||
}
|
||||
}
|
||||
|
||||
fn setup<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
|
||||
fn setup<B: SplitByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
|
||||
let mut r: Ref<B, Self> = buf.zk_zeroized()?;
|
||||
r.init();
|
||||
Ok(r)
|
||||
@@ -68,7 +68,7 @@ impl Message for PingRequest {
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[repr(packed)]
|
||||
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
|
||||
#[derive(Debug, Copy, Clone, Hash, IntoBytes, FromBytes, PartialEq, Eq, Immutable, KnownLayout)]
|
||||
pub struct PingResponsePayload {
|
||||
/// Randomly generated connection id
|
||||
pub echo: [u8; 256],
|
||||
@@ -96,7 +96,7 @@ impl Message for PingResponse {
|
||||
}
|
||||
}
|
||||
|
||||
fn setup<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
|
||||
fn setup<B: SplitByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
|
||||
let mut r: Ref<B, Self> = buf.zk_zeroized()?;
|
||||
r.init();
|
||||
Ok(r)
|
||||
@@ -109,7 +109,7 @@ impl Message for PingResponse {
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[repr(packed)]
|
||||
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
|
||||
#[derive(Debug, Copy, Clone, Hash, IntoBytes, FromBytes, PartialEq, Eq, Immutable, KnownLayout)]
|
||||
pub struct SupplyKeypairRequestPayload {}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
@@ -140,7 +140,7 @@ impl Message for SupplyKeypairRequest {
|
||||
}
|
||||
}
|
||||
|
||||
fn setup<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
|
||||
fn setup<B: SplitByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
|
||||
let mut r: Ref<B, Self> = buf.zk_zeroized()?;
|
||||
r.init();
|
||||
Ok(r)
|
||||
@@ -169,7 +169,7 @@ pub mod supply_keypair_response_status {
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[repr(packed)]
|
||||
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
|
||||
#[derive(Debug, Copy, Clone, Hash, IntoBytes, FromBytes, PartialEq, Eq, Immutable, KnownLayout)]
|
||||
pub struct SupplyKeypairResponsePayload {
|
||||
#[allow(missing_docs)]
|
||||
pub status: u128,
|
||||
@@ -197,7 +197,7 @@ impl Message for SupplyKeypairResponse {
|
||||
}
|
||||
}
|
||||
|
||||
fn setup<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
|
||||
fn setup<B: SplitByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
|
||||
let mut r: Ref<B, Self> = buf.zk_zeroized()?;
|
||||
r.init();
|
||||
Ok(r)
|
||||
@@ -210,7 +210,7 @@ impl Message for SupplyKeypairResponse {
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[repr(packed)]
|
||||
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
|
||||
#[derive(Debug, Copy, Clone, Hash, IntoBytes, FromBytes, PartialEq, Eq, Immutable, KnownLayout)]
|
||||
pub struct AddListenSocketRequestPayload {}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
@@ -241,7 +241,7 @@ impl Message for AddListenSocketRequest {
|
||||
}
|
||||
}
|
||||
|
||||
fn setup<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
|
||||
fn setup<B: SplitByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
|
||||
let mut r: Ref<B, Self> = buf.zk_zeroized()?;
|
||||
r.init();
|
||||
Ok(r)
|
||||
@@ -264,7 +264,7 @@ pub mod add_listen_socket_response_status {
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[repr(packed)]
|
||||
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
|
||||
#[derive(Debug, Copy, Clone, Hash, IntoBytes, FromBytes, PartialEq, Eq, Immutable, KnownLayout)]
|
||||
pub struct AddListenSocketResponsePayload {
|
||||
pub status: u128,
|
||||
}
|
||||
@@ -291,7 +291,7 @@ impl Message for AddListenSocketResponse {
|
||||
}
|
||||
}
|
||||
|
||||
fn setup<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
|
||||
fn setup<B: SplitByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
|
||||
let mut r: Ref<B, Self> = buf.zk_zeroized()?;
|
||||
r.init();
|
||||
Ok(r)
|
||||
@@ -304,7 +304,7 @@ impl Message for AddListenSocketResponse {
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[repr(packed)]
|
||||
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
|
||||
#[derive(Debug, Copy, Clone, Hash, IntoBytes, FromBytes, PartialEq, Eq, Immutable, KnownLayout)]
|
||||
pub struct AddPskBrokerRequestPayload {}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
@@ -336,7 +336,7 @@ impl Message for AddPskBrokerRequest {
|
||||
}
|
||||
}
|
||||
|
||||
fn setup<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
|
||||
fn setup<B: SplitByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
|
||||
let mut r: Ref<B, Self> = buf.zk_zeroized()?;
|
||||
r.init();
|
||||
Ok(r)
|
||||
@@ -359,7 +359,7 @@ pub mod add_psk_broker_response_status {
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[repr(packed)]
|
||||
#[derive(Debug, Copy, Clone, Hash, AsBytes, FromBytes, FromZeroes, PartialEq, Eq)]
|
||||
#[derive(Debug, Copy, Clone, Hash, IntoBytes, FromBytes, PartialEq, Eq, Immutable, KnownLayout)]
|
||||
pub struct AddPskBrokerResponsePayload {
|
||||
pub status: u128,
|
||||
}
|
||||
@@ -386,7 +386,7 @@ impl Message for AddPskBrokerResponse {
|
||||
}
|
||||
}
|
||||
|
||||
fn setup<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
|
||||
fn setup<B: SplitByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self>> {
|
||||
let mut r: Ref<B, Self> = buf.zk_zeroized()?;
|
||||
r.init();
|
||||
Ok(r)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::ensure;
|
||||
use anyhow::{anyhow, ensure};
|
||||
|
||||
use zerocopy::{ByteSlice, ByteSliceMut, Ref};
|
||||
use zerocopy::{IntoBytes, Ref, SplitByteSlice, SplitByteSliceMut};
|
||||
|
||||
use super::{ByteSliceRefExt, MessageAttributes, PingRequest, RequestMsgType};
|
||||
|
||||
@@ -13,14 +13,14 @@ struct RequestRefMaker<B> {
|
||||
msg_type: RequestMsgType,
|
||||
}
|
||||
|
||||
impl<B: ByteSlice> RequestRef<B> {
|
||||
impl<B: SplitByteSlice> RequestRef<B> {
|
||||
/// Produce a [RequestRef] from a raw message buffer,
|
||||
/// reading the type from the buffer
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use zerocopy::AsBytes;
|
||||
/// use zerocopy::IntoBytes;
|
||||
///
|
||||
/// use rosenpass::api::{PingRequest, RequestRef, RequestMsgType};
|
||||
///
|
||||
@@ -95,7 +95,7 @@ impl<B> From<Ref<B, super::AddPskBrokerRequest>> for RequestRef<B> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: ByteSlice> RequestRefMaker<B> {
|
||||
impl<B: SplitByteSlice> RequestRefMaker<B> {
|
||||
fn new(buf: B) -> anyhow::Result<Self> {
|
||||
let msg_type = buf.deref().request_msg_type_from_prefix()?;
|
||||
Ok(Self { buf, msg_type })
|
||||
@@ -125,7 +125,9 @@ impl<B: ByteSlice> RequestRefMaker<B> {
|
||||
self.ensure_fit()?;
|
||||
let point = self.target_size();
|
||||
let Self { buf, msg_type } = self;
|
||||
let (buf, _) = buf.split_at(point);
|
||||
let (buf, _) = buf
|
||||
.split_at(point)
|
||||
.map_err(|_| anyhow!("Failed to split buffer"))?;
|
||||
Ok(Self { buf, msg_type })
|
||||
}
|
||||
|
||||
@@ -134,7 +136,9 @@ impl<B: ByteSlice> RequestRefMaker<B> {
|
||||
self.ensure_fit()?;
|
||||
let point = self.buf.len() - self.target_size();
|
||||
let Self { buf, msg_type } = self;
|
||||
let (buf, _) = buf.split_at(point);
|
||||
let (buf, _) = buf
|
||||
.split_at(point)
|
||||
.map_err(|_| anyhow!("Failed to split buffer"))?;
|
||||
Ok(Self { buf, msg_type })
|
||||
}
|
||||
|
||||
@@ -159,7 +163,7 @@ pub enum RequestRef<B> {
|
||||
|
||||
impl<B> RequestRef<B>
|
||||
where
|
||||
B: ByteSlice,
|
||||
B: SplitByteSlice,
|
||||
{
|
||||
/// Access the byte data of this reference
|
||||
///
|
||||
@@ -168,25 +172,25 @@ where
|
||||
/// See [Self::parse].
|
||||
pub fn bytes(&self) -> &[u8] {
|
||||
match self {
|
||||
Self::Ping(r) => r.bytes(),
|
||||
Self::SupplyKeypair(r) => r.bytes(),
|
||||
Self::AddListenSocket(r) => r.bytes(),
|
||||
Self::AddPskBroker(r) => r.bytes(),
|
||||
Self::Ping(r) => r.as_bytes(),
|
||||
Self::SupplyKeypair(r) => r.as_bytes(),
|
||||
Self::AddListenSocket(r) => r.as_bytes(),
|
||||
Self::AddPskBroker(r) => r.as_bytes(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> RequestRef<B>
|
||||
where
|
||||
B: ByteSliceMut,
|
||||
B: SplitByteSliceMut,
|
||||
{
|
||||
/// Access the byte data of this reference; mutably
|
||||
pub fn bytes_mut(&mut self) -> &[u8] {
|
||||
match self {
|
||||
Self::Ping(r) => r.bytes_mut(),
|
||||
Self::SupplyKeypair(r) => r.bytes_mut(),
|
||||
Self::AddListenSocket(r) => r.bytes_mut(),
|
||||
Self::AddPskBroker(r) => r.bytes_mut(),
|
||||
Self::Ping(r) => r.as_mut_bytes(),
|
||||
Self::SupplyKeypair(r) => r.as_mut_bytes(),
|
||||
Self::AddListenSocket(r) => r.as_mut_bytes(),
|
||||
Self::AddPskBroker(r) => r.as_mut_bytes(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use rosenpass_util::zerocopy::{
|
||||
RefMaker, ZerocopyEmancipateExt, ZerocopyEmancipateMutExt, ZerocopySliceExt,
|
||||
};
|
||||
use zerocopy::{ByteSlice, ByteSliceMut, Ref};
|
||||
use zerocopy::{Immutable, KnownLayout, Ref, SplitByteSlice, SplitByteSliceMut};
|
||||
|
||||
use super::{Message, PingRequest, PingResponse};
|
||||
use super::{RequestRef, ResponseRef, ZerocopyResponseMakerSetupMessageExt};
|
||||
@@ -9,27 +9,27 @@ use super::{RequestRef, ResponseRef, ZerocopyResponseMakerSetupMessageExt};
|
||||
/// Extension trait for [Message]s that are requests messages
|
||||
pub trait RequestMsg: Sized + Message {
|
||||
/// The response message belonging to this request message
|
||||
type ResponseMsg: ResponseMsg;
|
||||
type ResponseMsg: ResponseMsg + Immutable + KnownLayout;
|
||||
|
||||
/// Construct a response make for this particular message
|
||||
fn zk_response_maker<B: ByteSlice>(buf: B) -> RefMaker<B, Self::ResponseMsg> {
|
||||
fn zk_response_maker<B: SplitByteSlice>(buf: B) -> RefMaker<B, Self::ResponseMsg> {
|
||||
buf.zk_ref_maker()
|
||||
}
|
||||
|
||||
/// Setup a response maker (through [Message::setup]) for this request message type
|
||||
fn setup_response<B: ByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self::ResponseMsg>> {
|
||||
fn setup_response<B: SplitByteSliceMut>(buf: B) -> anyhow::Result<Ref<B, Self::ResponseMsg>> {
|
||||
Self::zk_response_maker(buf).setup_msg()
|
||||
}
|
||||
|
||||
/// Setup a response maker from a buffer prefix (through [Message::setup]) for this request message type
|
||||
fn setup_response_from_prefix<B: ByteSliceMut>(
|
||||
fn setup_response_from_prefix<B: SplitByteSliceMut>(
|
||||
buf: B,
|
||||
) -> anyhow::Result<Ref<B, Self::ResponseMsg>> {
|
||||
Self::zk_response_maker(buf).from_prefix()?.setup_msg()
|
||||
}
|
||||
|
||||
/// Setup a response maker from a buffer suffix (through [Message::setup]) for this request message type
|
||||
fn setup_response_from_suffix<B: ByteSliceMut>(
|
||||
fn setup_response_from_suffix<B: SplitByteSliceMut>(
|
||||
buf: B,
|
||||
) -> anyhow::Result<Ref<B, Self::ResponseMsg>> {
|
||||
Self::zk_response_maker(buf).from_prefix()?.setup_msg()
|
||||
@@ -125,8 +125,8 @@ impl<B1, B2> From<AddPskBrokerPair<B1, B2>> for RequestResponsePair<B1, B2> {
|
||||
|
||||
impl<B1, B2> RequestResponsePair<B1, B2>
|
||||
where
|
||||
B1: ByteSlice,
|
||||
B2: ByteSlice,
|
||||
B1: SplitByteSlice,
|
||||
B2: SplitByteSlice,
|
||||
{
|
||||
/// Returns a tuple to both the request and the response message
|
||||
pub fn both(&self) -> (RequestRef<&[u8]>, ResponseRef<&[u8]>) {
|
||||
@@ -167,8 +167,8 @@ where
|
||||
|
||||
impl<B1, B2> RequestResponsePair<B1, B2>
|
||||
where
|
||||
B1: ByteSliceMut,
|
||||
B2: ByteSliceMut,
|
||||
B1: SplitByteSliceMut,
|
||||
B2: SplitByteSliceMut,
|
||||
{
|
||||
/// Returns a mutable tuple to both the request and the response message
|
||||
pub fn both_mut(&mut self) -> (RequestRef<&mut [u8]>, ResponseRef<&mut [u8]>) {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// TODO: This is copied verbatim from ResponseRef…not pretty
|
||||
use anyhow::ensure;
|
||||
use anyhow::{anyhow, ensure};
|
||||
|
||||
use zerocopy::{ByteSlice, ByteSliceMut, Ref};
|
||||
use zerocopy::{IntoBytes, Ref, SplitByteSlice, SplitByteSliceMut};
|
||||
|
||||
use super::{ByteSliceRefExt, MessageAttributes, PingResponse, ResponseMsgType};
|
||||
|
||||
@@ -16,14 +16,14 @@ struct ResponseRefMaker<B> {
|
||||
msg_type: ResponseMsgType,
|
||||
}
|
||||
|
||||
impl<B: ByteSlice> ResponseRef<B> {
|
||||
impl<B: SplitByteSlice> ResponseRef<B> {
|
||||
/// Produce a [ResponseRef] from a raw message buffer,
|
||||
/// reading the type from the buffer
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use zerocopy::AsBytes;
|
||||
/// use zerocopy::IntoBytes;
|
||||
///
|
||||
/// use rosenpass::api::{PingResponse, ResponseRef, ResponseMsgType};
|
||||
/// // Produce the original PingResponse
|
||||
@@ -99,7 +99,7 @@ impl<B> From<Ref<B, super::AddPskBrokerResponse>> for ResponseRef<B> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: ByteSlice> ResponseRefMaker<B> {
|
||||
impl<B: SplitByteSlice> ResponseRefMaker<B> {
|
||||
fn new(buf: B) -> anyhow::Result<Self> {
|
||||
let msg_type = buf.deref().response_msg_type_from_prefix()?;
|
||||
Ok(Self { buf, msg_type })
|
||||
@@ -129,7 +129,9 @@ impl<B: ByteSlice> ResponseRefMaker<B> {
|
||||
self.ensure_fit()?;
|
||||
let point = self.target_size();
|
||||
let Self { buf, msg_type } = self;
|
||||
let (buf, _) = buf.split_at(point);
|
||||
let (buf, _) = buf
|
||||
.split_at(point)
|
||||
.map_err(|_| anyhow!("Failed to split buffer!"))?;
|
||||
Ok(Self { buf, msg_type })
|
||||
}
|
||||
|
||||
@@ -138,7 +140,9 @@ impl<B: ByteSlice> ResponseRefMaker<B> {
|
||||
self.ensure_fit()?;
|
||||
let point = self.buf.len() - self.target_size();
|
||||
let Self { buf, msg_type } = self;
|
||||
let (buf, _) = buf.split_at(point);
|
||||
let (buf, _) = buf
|
||||
.split_at(point)
|
||||
.map_err(|_| anyhow!("Failed to split buffer!"))?;
|
||||
Ok(Self { buf, msg_type })
|
||||
}
|
||||
|
||||
@@ -163,7 +167,7 @@ pub enum ResponseRef<B> {
|
||||
|
||||
impl<B> ResponseRef<B>
|
||||
where
|
||||
B: ByteSlice,
|
||||
B: SplitByteSlice,
|
||||
{
|
||||
/// Access the byte data of this reference
|
||||
///
|
||||
@@ -172,25 +176,25 @@ where
|
||||
/// See [Self::parse].
|
||||
pub fn bytes(&self) -> &[u8] {
|
||||
match self {
|
||||
Self::Ping(r) => r.bytes(),
|
||||
Self::SupplyKeypair(r) => r.bytes(),
|
||||
Self::AddListenSocket(r) => r.bytes(),
|
||||
Self::AddPskBroker(r) => r.bytes(),
|
||||
Self::Ping(r) => r.as_bytes(),
|
||||
Self::SupplyKeypair(r) => r.as_bytes(),
|
||||
Self::AddListenSocket(r) => r.as_bytes(),
|
||||
Self::AddPskBroker(r) => r.as_bytes(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> ResponseRef<B>
|
||||
where
|
||||
B: ByteSliceMut,
|
||||
B: SplitByteSliceMut,
|
||||
{
|
||||
/// Access the byte data of this reference; mutably
|
||||
pub fn bytes_mut(&mut self) -> &[u8] {
|
||||
match self {
|
||||
Self::Ping(r) => r.bytes_mut(),
|
||||
Self::SupplyKeypair(r) => r.bytes_mut(),
|
||||
Self::AddListenSocket(r) => r.bytes_mut(),
|
||||
Self::AddPskBroker(r) => r.bytes_mut(),
|
||||
Self::Ping(r) => r.as_mut_bytes(),
|
||||
Self::SupplyKeypair(r) => r.as_mut_bytes(),
|
||||
Self::AddListenSocket(r) => r.as_mut_bytes(),
|
||||
Self::AddPskBroker(r) => r.as_mut_bytes(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::{ByteSliceRefExt, Message, PingRequest, PingResponse, RequestRef, RequestResponsePair};
|
||||
use std::{collections::VecDeque, os::fd::OwnedFd};
|
||||
use zerocopy::{ByteSlice, ByteSliceMut};
|
||||
use zerocopy::{SplitByteSlice, SplitByteSliceMut};
|
||||
|
||||
/// The rosenpass API implementation functions.
|
||||
///
|
||||
@@ -152,8 +152,8 @@ pub trait Server {
|
||||
req_fds: &mut VecDeque<OwnedFd>,
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
ReqBuf: ByteSlice,
|
||||
ResBuf: ByteSliceMut,
|
||||
ReqBuf: SplitByteSlice,
|
||||
ResBuf: SplitByteSliceMut,
|
||||
{
|
||||
match p {
|
||||
RequestResponsePair::Ping((req, res)) => self.ping(req, req_fds, res),
|
||||
@@ -182,8 +182,8 @@ pub trait Server {
|
||||
res: ResBuf,
|
||||
) -> anyhow::Result<usize>
|
||||
where
|
||||
ReqBuf: ByteSlice,
|
||||
ResBuf: ByteSliceMut,
|
||||
ReqBuf: SplitByteSlice,
|
||||
ResBuf: SplitByteSliceMut,
|
||||
{
|
||||
let req = req.parse_request_from_prefix()?;
|
||||
// TODO: This is not pretty; This match should be moved into RequestRef
|
||||
|
||||
@@ -13,7 +13,7 @@ use signal_hook_mio::v1_0 as signal_hook_mio;
|
||||
use anyhow::{bail, Context, Result};
|
||||
use derive_builder::Builder;
|
||||
use log::{error, info, warn};
|
||||
use zerocopy::AsBytes;
|
||||
use zerocopy::IntoBytes;
|
||||
|
||||
use rosenpass_util::attempt;
|
||||
use rosenpass_util::fmt::debug::NullDebug;
|
||||
|
||||
@@ -26,7 +26,7 @@ use {
|
||||
command_fds::{CommandFdExt, FdMapping},
|
||||
log::{error, info},
|
||||
mio::net::UnixStream,
|
||||
rosenpass_util::fd::claim_fd,
|
||||
rosenpass_util::rustix::claim_fd,
|
||||
rosenpass_wireguard_broker::brokers::mio_client::MioBrokerClient,
|
||||
rosenpass_wireguard_broker::WireguardBrokerMio,
|
||||
rustix::net::{socketpair, AddressFamily, SocketFlags, SocketType},
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
//! To achieve this we utilize the zerocopy library.
|
||||
//!
|
||||
use std::mem::size_of;
|
||||
use zerocopy::{AsBytes, FromBytes, FromZeroes};
|
||||
use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
|
||||
|
||||
use super::RosenpassError;
|
||||
use rosenpass_cipher_traits::primitives::{Aead as _, Kem};
|
||||
@@ -51,7 +51,7 @@ pub type MsgEnvelopeCookie = [u8; COOKIE_SIZE];
|
||||
///
|
||||
/// ```
|
||||
/// use rosenpass::msgs::{Envelope, InitHello};
|
||||
/// use zerocopy::{AsBytes, FromBytes, Ref, FromZeroes};
|
||||
/// use zerocopy::{FromZeros, IntoBytes, FromBytes, Ref};
|
||||
/// use memoffset::offset_of;
|
||||
///
|
||||
/// // Zero-initialization
|
||||
@@ -76,8 +76,8 @@ pub type MsgEnvelopeCookie = [u8; COOKIE_SIZE];
|
||||
/// assert_eq!(ih3.msg_type, 42);
|
||||
/// ```
|
||||
#[repr(packed)]
|
||||
#[derive(AsBytes, FromBytes, FromZeroes, Clone)]
|
||||
pub struct Envelope<M: AsBytes + FromBytes> {
|
||||
#[derive(IntoBytes, FromBytes, Clone, Immutable, KnownLayout)]
|
||||
pub struct Envelope<M: IntoBytes + FromBytes> {
|
||||
/// [MsgType] of this message
|
||||
pub msg_type: u8,
|
||||
/// Reserved for future use
|
||||
@@ -106,7 +106,7 @@ pub struct Envelope<M: AsBytes + FromBytes> {
|
||||
///
|
||||
/// ```
|
||||
/// use rosenpass::msgs::{Envelope, InitHello};
|
||||
/// use zerocopy::{AsBytes, FromBytes, Ref, FromZeroes};
|
||||
/// use zerocopy::{FromZeros, IntoBytes, FromBytes, Ref};
|
||||
/// use memoffset::span_of;
|
||||
///
|
||||
/// // Zero initialization
|
||||
@@ -126,7 +126,7 @@ pub struct Envelope<M: AsBytes + FromBytes> {
|
||||
/// assert_eq!(ih.payload.sidi, [1,2,3,4]);
|
||||
/// ```
|
||||
#[repr(packed)]
|
||||
#[derive(AsBytes, FromBytes, FromZeroes)]
|
||||
#[derive(IntoBytes, FromBytes, Immutable, KnownLayout)]
|
||||
pub struct InitHello {
|
||||
/// Randomly generated connection id
|
||||
pub sidi: [u8; 4],
|
||||
@@ -155,7 +155,7 @@ pub struct InitHello {
|
||||
///
|
||||
/// ```
|
||||
/// use rosenpass::msgs::{Envelope, RespHello};
|
||||
/// use zerocopy::{AsBytes, FromBytes, Ref, FromZeroes};
|
||||
/// use zerocopy::{FromZeros, IntoBytes, FromBytes, Ref};
|
||||
/// use memoffset::span_of;
|
||||
///
|
||||
/// // Zero initialization
|
||||
@@ -175,7 +175,7 @@ pub struct InitHello {
|
||||
/// assert_eq!(ih.payload.sidi, [1,2,3,4]);
|
||||
/// ```
|
||||
#[repr(packed)]
|
||||
#[derive(AsBytes, FromBytes, FromZeroes)]
|
||||
#[derive(IntoBytes, FromBytes, Immutable, KnownLayout)]
|
||||
pub struct RespHello {
|
||||
/// Randomly generated connection id
|
||||
pub sidr: [u8; 4],
|
||||
@@ -206,7 +206,7 @@ pub struct RespHello {
|
||||
///
|
||||
/// ```
|
||||
/// use rosenpass::msgs::{Envelope, InitConf};
|
||||
/// use zerocopy::{AsBytes, FromBytes, Ref, FromZeroes};
|
||||
/// use zerocopy::{IntoBytes, FromBytes, FromZeros, Ref};
|
||||
/// use memoffset::span_of;
|
||||
///
|
||||
/// // Zero initialization
|
||||
@@ -226,7 +226,7 @@ pub struct RespHello {
|
||||
/// assert_eq!(ih.payload.sidi, [1,2,3,4]);
|
||||
/// ```
|
||||
#[repr(packed)]
|
||||
#[derive(AsBytes, FromBytes, FromZeroes, Debug)]
|
||||
#[derive(IntoBytes, FromBytes, Debug, Immutable, KnownLayout)]
|
||||
pub struct InitConf {
|
||||
/// Copied from InitHello
|
||||
pub sidi: [u8; 4],
|
||||
@@ -264,7 +264,7 @@ pub struct InitConf {
|
||||
///
|
||||
/// ```
|
||||
/// use rosenpass::msgs::{Envelope, EmptyData};
|
||||
/// use zerocopy::{AsBytes, FromBytes, Ref, FromZeroes};
|
||||
/// use zerocopy::{FromZeros, IntoBytes, FromBytes, Ref};
|
||||
/// use memoffset::span_of;
|
||||
///
|
||||
/// // Zero initialization
|
||||
@@ -284,7 +284,7 @@ pub struct InitConf {
|
||||
/// assert_eq!(ih.payload.sid, [1,2,3,4]);
|
||||
/// ```
|
||||
#[repr(packed)]
|
||||
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Copy)]
|
||||
#[derive(IntoBytes, FromBytes, Clone, Copy, Immutable, KnownLayout)]
|
||||
pub struct EmptyData {
|
||||
/// Copied from RespHello
|
||||
pub sid: [u8; 4],
|
||||
@@ -311,7 +311,7 @@ pub struct EmptyData {
|
||||
///
|
||||
/// [Envelope] and [InitHello] contain some extra examples on how to use structures from the [::zerocopy] crate.
|
||||
#[repr(packed)]
|
||||
#[derive(AsBytes, FromBytes, FromZeroes)]
|
||||
#[derive(IntoBytes, FromBytes, Immutable, KnownLayout)]
|
||||
pub struct Biscuit {
|
||||
/// H(spki) – Ident ifies the initiator
|
||||
pub pidi: [u8; KEY_LEN],
|
||||
@@ -336,7 +336,7 @@ pub struct Biscuit {
|
||||
///
|
||||
/// [Envelope] and [InitHello] contain some extra examples on how to use structures from the [::zerocopy] crate.
|
||||
#[repr(packed)]
|
||||
#[derive(AsBytes, FromBytes, FromZeroes)]
|
||||
#[derive(IntoBytes, FromBytes, Immutable, KnownLayout)]
|
||||
pub struct CookieReplyInner {
|
||||
/// [MsgType] of this message
|
||||
pub msg_type: u8,
|
||||
@@ -363,7 +363,7 @@ pub struct CookieReplyInner {
|
||||
///
|
||||
/// [Envelope] and [InitHello] contain some extra examples on how to use structures from the [::zerocopy] crate.
|
||||
#[repr(packed)]
|
||||
#[derive(AsBytes, FromBytes, FromZeroes)]
|
||||
#[derive(IntoBytes, FromBytes, Immutable, KnownLayout)]
|
||||
pub struct CookieReply {
|
||||
pub inner: CookieReplyInner,
|
||||
pub padding: [u8; size_of::<Envelope<InitHello>>() - size_of::<CookieReplyInner>()],
|
||||
|
||||
@@ -17,7 +17,7 @@ use std::{
|
||||
use anyhow::{bail, ensure, Context, Result};
|
||||
use assert_tv::{TestVector, TestVectorNOP};
|
||||
use memoffset::span_of;
|
||||
use zerocopy::{AsBytes, FromBytes, Ref};
|
||||
use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Ref};
|
||||
|
||||
use rosenpass_cipher_traits::primitives::{
|
||||
Aead as _, AeadWithNonceInCiphertext, Kem, KeyedHashInstance,
|
||||
@@ -413,7 +413,7 @@ pub struct InitiatorHandshake {
|
||||
///
|
||||
/// Used as [KnownInitConfResponse] for now cache [EmptyData] (responder confirmation)
|
||||
/// responses to [InitConf]
|
||||
pub struct KnownResponse<ResponseType: AsBytes + FromBytes> {
|
||||
pub struct KnownResponse<ResponseType: IntoBytes + FromBytes> {
|
||||
/// When the response was initially computed
|
||||
pub received_at: Timing,
|
||||
/// Hash of the message that triggered the response; created using
|
||||
@@ -423,7 +423,7 @@ pub struct KnownResponse<ResponseType: AsBytes + FromBytes> {
|
||||
pub response: Envelope<ResponseType>,
|
||||
}
|
||||
|
||||
impl<ResponseType: AsBytes + FromBytes> Debug for KnownResponse<ResponseType> {
|
||||
impl<ResponseType: IntoBytes + FromBytes> Debug for KnownResponse<ResponseType> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("KnownResponse")
|
||||
.field("received_at", &self.received_at)
|
||||
@@ -435,7 +435,7 @@ impl<ResponseType: AsBytes + FromBytes> Debug for KnownResponse<ResponseType> {
|
||||
|
||||
#[test]
|
||||
fn known_response_format() {
|
||||
use zerocopy::FromZeroes;
|
||||
use zerocopy::FromZeros;
|
||||
|
||||
let v = KnownResponse::<[u8; 32]> {
|
||||
received_at: 42.0,
|
||||
@@ -464,7 +464,7 @@ pub type KnownResponseHash = Public<16>;
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use zerocopy::FromZeroes;
|
||||
/// use zerocopy::FromZeros;
|
||||
/// use rosenpass::protocol::KnownResponseHasher;
|
||||
/// use rosenpass::msgs::{Envelope, InitConf};
|
||||
///
|
||||
@@ -508,7 +508,10 @@ impl KnownResponseHasher {
|
||||
/// # Panic & Safety
|
||||
///
|
||||
/// Panics in case of a problem with this underlying hash function
|
||||
pub fn hash<Msg: AsBytes + FromBytes>(&self, msg: &Envelope<Msg>) -> KnownResponseHash {
|
||||
pub fn hash<Msg: IntoBytes + FromBytes + Immutable>(
|
||||
&self,
|
||||
msg: &Envelope<Msg>,
|
||||
) -> KnownResponseHash {
|
||||
let data = &msg.as_bytes()[span_of!(Envelope<Msg>, msg_type..cookie)];
|
||||
// This function is only used internally and results are not propagated
|
||||
// to outside the peer. Thus, it uses SHAKE256 exclusively.
|
||||
@@ -2042,7 +2045,8 @@ impl CryptoServer {
|
||||
|
||||
let mut expected = [0u8; COOKIE_SIZE];
|
||||
|
||||
let msg_in = Ref::<&[u8], Envelope<InitHello>>::new(rx_buf)
|
||||
let msg_in = Ref::<&[u8], Envelope<InitHello>>::from_bytes(rx_buf)
|
||||
.ok()
|
||||
.ok_or(RosenpassError::BufferSizeMismatch)?;
|
||||
expected.copy_from_slice(
|
||||
&hash_domains::cookie(KeyedHash::keyed_shake256())?
|
||||
@@ -2187,8 +2191,9 @@ impl CryptoServer {
|
||||
|
||||
let peer = match msg_type {
|
||||
Ok(MsgType::InitHello) => {
|
||||
let msg_in: Ref<&[u8], Envelope<InitHello>> =
|
||||
Ref::new(rx_buf).ok_or(RosenpassError::BufferSizeMismatch)?;
|
||||
let msg_in: Ref<&[u8], Envelope<InitHello>> = Ref::from_bytes(rx_buf)
|
||||
.ok()
|
||||
.ok_or(RosenpassError::BufferSizeMismatch)?;
|
||||
|
||||
// At this point, we do not know the hash functon used by the peer, thus we try both,
|
||||
// with a preference for SHAKE256.
|
||||
@@ -2221,8 +2226,9 @@ impl CryptoServer {
|
||||
peer
|
||||
}
|
||||
Ok(MsgType::RespHello) => {
|
||||
let msg_in: Ref<&[u8], Envelope<RespHello>> =
|
||||
Ref::new(rx_buf).ok_or(RosenpassError::BufferSizeMismatch)?;
|
||||
let msg_in: Ref<&[u8], Envelope<RespHello>> = Ref::from_bytes(rx_buf)
|
||||
.ok()
|
||||
.ok_or(RosenpassError::BufferSizeMismatch)?;
|
||||
|
||||
let mut msg_out = truncating_cast_into::<Envelope<InitConf>>(tx_buf)?;
|
||||
let peer = self.handle_resp_hello(&msg_in.payload, &mut msg_out.payload)?;
|
||||
@@ -2238,8 +2244,9 @@ impl CryptoServer {
|
||||
peer
|
||||
}
|
||||
Ok(MsgType::InitConf) => {
|
||||
let msg_in: Ref<&[u8], Envelope<InitConf>> =
|
||||
Ref::new(rx_buf).ok_or(RosenpassError::BufferSizeMismatch)?;
|
||||
let msg_in: Ref<&[u8], Envelope<InitConf>> = Ref::from_bytes(rx_buf)
|
||||
.ok()
|
||||
.ok_or(RosenpassError::BufferSizeMismatch)?;
|
||||
|
||||
let mut msg_out = truncating_cast_into::<Envelope<EmptyData>>(tx_buf)?;
|
||||
|
||||
@@ -2258,7 +2265,7 @@ impl CryptoServer {
|
||||
.map(|v| v.response.borrow())
|
||||
// Invalid! Found peer no with cache in index but the cache does not exist
|
||||
.unwrap();
|
||||
copy_slice(cached.as_bytes()).to(msg_out.as_bytes_mut());
|
||||
copy_slice(cached.as_bytes()).to(msg_out.as_mut_bytes());
|
||||
peer
|
||||
}
|
||||
|
||||
@@ -2306,14 +2313,16 @@ impl CryptoServer {
|
||||
peer
|
||||
}
|
||||
Ok(MsgType::EmptyData) => {
|
||||
let msg_in: Ref<&[u8], Envelope<EmptyData>> =
|
||||
Ref::new(rx_buf).ok_or(RosenpassError::BufferSizeMismatch)?;
|
||||
let msg_in: Ref<&[u8], Envelope<EmptyData>> = Ref::from_bytes(rx_buf)
|
||||
.ok()
|
||||
.ok_or(RosenpassError::BufferSizeMismatch)?;
|
||||
|
||||
self.handle_resp_conf(&msg_in, seal_broken.to_string())?
|
||||
}
|
||||
Ok(MsgType::CookieReply) => {
|
||||
let msg_in: Ref<&[u8], CookieReply> =
|
||||
Ref::new(rx_buf).ok_or(RosenpassError::BufferSizeMismatch)?;
|
||||
let msg_in: Ref<&[u8], CookieReply> = Ref::from_bytes(rx_buf)
|
||||
.ok()
|
||||
.ok_or(RosenpassError::BufferSizeMismatch)?;
|
||||
let peer = self.handle_cookie_reply(&msg_in)?;
|
||||
len = 0;
|
||||
peer
|
||||
@@ -2354,8 +2363,8 @@ impl CryptoServer {
|
||||
///
|
||||
/// To save some code, the function returns the size of the message,
|
||||
/// but the same could be easily achieved by calling [size_of] with the
|
||||
/// message type or by calling [AsBytes::as_bytes] on the message reference.
|
||||
pub fn seal_and_commit_msg<M: AsBytes + FromBytes>(
|
||||
/// message type or by calling [IntoBytes::as_bytes] on the message reference.
|
||||
pub fn seal_and_commit_msg<M: IntoBytes + FromBytes + Immutable + KnownLayout>(
|
||||
&mut self,
|
||||
peer: PeerPtr,
|
||||
msg_type: MsgType,
|
||||
@@ -3078,7 +3087,7 @@ impl IniHsPtr {
|
||||
|
||||
impl<M> Envelope<M>
|
||||
where
|
||||
M: AsBytes + FromBytes,
|
||||
M: IntoBytes + FromBytes + Immutable + KnownLayout,
|
||||
{
|
||||
/// Internal business logic: Calculate the message authentication code (`mac`) and also append cookie value
|
||||
pub fn seal(&mut self, peer: PeerPtr, srv: &CryptoServer) -> Result<()> {
|
||||
@@ -3107,7 +3116,7 @@ where
|
||||
|
||||
impl<M> Envelope<M>
|
||||
where
|
||||
M: AsBytes + FromBytes,
|
||||
M: IntoBytes + FromBytes + Immutable + KnownLayout,
|
||||
{
|
||||
/// Internal business logic: Check the message authentication code produced by [Self::seal]
|
||||
pub fn check_seal(&self, srv: &CryptoServer, shake_or_blake: KeyedHash) -> Result<bool> {
|
||||
@@ -3240,7 +3249,7 @@ impl HandshakeState {
|
||||
/// out the const generics.
|
||||
/// - By adding a value parameter of type `PhantomData<TV>`, you can choose
|
||||
/// `TV` at the call site while allowing the compiler to infer `KEM_*`
|
||||
/// const generics from `ct` and `pk`.
|
||||
/// const generics from `ct` and `pk`.
|
||||
/// - Call like: `encaps_and_mix_with_test_vector(&StaticKem, &mut ct, pk,
|
||||
/// PhantomData::<TestVectorActive>)?;`
|
||||
pub fn encaps_and_mix_with_test_vector<
|
||||
@@ -3322,7 +3331,7 @@ impl HandshakeState {
|
||||
let test_values: StoreBiscuitTestValues = TV::initialize_values();
|
||||
let mut biscuit = Secret::<BISCUIT_PT_LEN>::zero(); // pt buffer
|
||||
let mut biscuit: Ref<&mut [u8], Biscuit> =
|
||||
Ref::new(biscuit.secret_mut().as_mut_slice()).unwrap();
|
||||
Ref::from_bytes(biscuit.secret_mut().as_mut_slice()).unwrap();
|
||||
|
||||
// calculate pt contents
|
||||
biscuit
|
||||
@@ -3384,9 +3393,9 @@ impl HandshakeState {
|
||||
// Allocate and decrypt the biscuit data
|
||||
let mut biscuit = Secret::<BISCUIT_PT_LEN>::zero(); // pt buf
|
||||
let mut biscuit: Ref<&mut [u8], Biscuit> =
|
||||
Ref::new(biscuit.secret_mut().as_mut_slice()).unwrap();
|
||||
Ref::from_bytes(biscuit.secret_mut().as_mut_slice()).unwrap();
|
||||
XAead.decrypt_with_nonce_in_ctxt(
|
||||
biscuit.as_bytes_mut(),
|
||||
biscuit.as_mut_bytes(),
|
||||
bk.get(srv).value.secret(),
|
||||
&ad,
|
||||
biscuit_ct,
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::{borrow::BorrowMut, fmt::Display, net::SocketAddrV4, ops::DerefMut};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serial_test::serial;
|
||||
use zerocopy::{AsBytes, FromBytes, FromZeroes};
|
||||
use zerocopy::{FromBytes, FromZeros, Immutable, IntoBytes, KnownLayout};
|
||||
|
||||
use rosenpass_cipher_traits::primitives::Kem;
|
||||
use rosenpass_ciphers::StaticKem;
|
||||
@@ -538,10 +538,10 @@ fn init_conf_retransmission(protocol_version: ProtocolVersion) -> anyhow::Result
|
||||
srv.initiate_handshake(peer, buf.as_mut_slice())?
|
||||
.discard_result();
|
||||
let msg = truncating_cast_into::<Envelope<InitHello>>(buf.borrow_mut())?;
|
||||
Ok(msg.read())
|
||||
Ok(zerocopy::Ref::read(&msg))
|
||||
}
|
||||
|
||||
fn proc_msg<Rx: AsBytes + FromBytes, Tx: AsBytes + FromBytes>(
|
||||
fn proc_msg<Rx: IntoBytes + FromBytes + Immutable, Tx: IntoBytes + FromBytes + Immutable>(
|
||||
srv: &mut CryptoServer,
|
||||
rx: &Envelope<Rx>,
|
||||
) -> anyhow::Result<Envelope<Tx>> {
|
||||
@@ -551,7 +551,7 @@ fn init_conf_retransmission(protocol_version: ProtocolVersion) -> anyhow::Result
|
||||
.context("Failed to produce RespHello message")?
|
||||
.discard_result();
|
||||
let msg = truncating_cast_into::<Envelope<Tx>>(buf.borrow_mut())?;
|
||||
Ok(msg.read())
|
||||
Ok(zerocopy::Ref::read(&msg))
|
||||
}
|
||||
|
||||
fn proc_init_hello(
|
||||
@@ -582,17 +582,21 @@ fn init_conf_retransmission(protocol_version: ProtocolVersion) -> anyhow::Result
|
||||
}
|
||||
|
||||
// TODO: Implement Clone on our message types
|
||||
fn clone_msg<Msg: AsBytes + FromBytes>(msg: &Msg) -> anyhow::Result<Msg> {
|
||||
Ok(truncating_cast_into_nomut::<Msg>(msg.as_bytes())?.read())
|
||||
fn clone_msg<Msg: IntoBytes + FromBytes + Immutable + KnownLayout>(
|
||||
msg: &Msg,
|
||||
) -> anyhow::Result<Msg> {
|
||||
Ok(zerocopy::Ref::read(&truncating_cast_into_nomut::<Msg>(
|
||||
msg.as_bytes(),
|
||||
)?))
|
||||
}
|
||||
|
||||
fn break_payload<Msg: AsBytes + FromBytes>(
|
||||
fn break_payload<Msg: IntoBytes + FromBytes + Immutable + KnownLayout>(
|
||||
srv: &mut CryptoServer,
|
||||
peer: PeerPtr,
|
||||
msg: &Envelope<Msg>,
|
||||
) -> anyhow::Result<Envelope<Msg>> {
|
||||
let mut msg = clone_msg(msg)?;
|
||||
msg.as_bytes_mut()[memoffset::offset_of!(Envelope<Msg>, payload)] ^= 0x01;
|
||||
msg.as_mut_bytes()[memoffset::offset_of!(Envelope<Msg>, payload)] ^= 0x01;
|
||||
msg.seal(peer, srv)?; // Recalculate seal; we do not want to focus on "seal broken" errs
|
||||
Ok(msg)
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ use assert_tv::TestVectorSet;
|
||||
use base64::Engine;
|
||||
use rosenpass_cipher_traits::primitives::{Aead, Kem};
|
||||
use rosenpass_ciphers::{EphemeralKem, XAead, KEY_LEN};
|
||||
use rosenpass_secret_memory::{Public, PublicBox, Secret};
|
||||
use rosenpass_secret_memory::{Public, Secret};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(TestVectorSet)]
|
||||
|
||||
@@ -2,20 +2,24 @@
|
||||
|
||||
use std::mem::size_of;
|
||||
|
||||
use zerocopy::{FromBytes, Ref};
|
||||
use zerocopy::{FromBytes, Immutable, KnownLayout, Ref};
|
||||
|
||||
use crate::RosenpassError;
|
||||
|
||||
/// Used to parse a network message using [zerocopy]
|
||||
pub fn truncating_cast_into<T: FromBytes>(
|
||||
pub fn truncating_cast_into<T: FromBytes + KnownLayout + Immutable>(
|
||||
buf: &mut [u8],
|
||||
) -> Result<Ref<&mut [u8], T>, RosenpassError> {
|
||||
Ref::new(&mut buf[..size_of::<T>()]).ok_or(RosenpassError::BufferSizeMismatch)
|
||||
Ref::from_bytes(&mut buf[..size_of::<T>()])
|
||||
.ok()
|
||||
.ok_or(RosenpassError::BufferSizeMismatch)
|
||||
}
|
||||
|
||||
/// Used to parse a network message using [zerocopy], mutably
|
||||
pub fn truncating_cast_into_nomut<T: FromBytes>(
|
||||
pub fn truncating_cast_into_nomut<T: FromBytes + KnownLayout + Immutable>(
|
||||
buf: &[u8],
|
||||
) -> Result<Ref<&[u8], T>, RosenpassError> {
|
||||
Ref::new(&buf[..size_of::<T>()]).ok_or(RosenpassError::BufferSizeMismatch)
|
||||
Ref::from_bytes(&buf[..size_of::<T>()])
|
||||
.ok()
|
||||
.ok_or(RosenpassError::BufferSizeMismatch)
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ use rosenpass_util::{
|
||||
};
|
||||
use std::os::fd::{AsFd, AsRawFd};
|
||||
use tempfile::TempDir;
|
||||
use zerocopy::AsBytes;
|
||||
use zerocopy::IntoBytes;
|
||||
|
||||
struct KillChild(std::process::Child);
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ use rosenpass_util::{
|
||||
};
|
||||
use rosenpass_util::{mem::DiscardResultExt, zerocopy::ZerocopySliceExt};
|
||||
use tempfile::TempDir;
|
||||
use zerocopy::AsBytes;
|
||||
use zerocopy::IntoBytes;
|
||||
|
||||
use rosenpass::config::ProtocolVersion;
|
||||
use rosenpass::protocol::basic_types::SymKey;
|
||||
|
||||
@@ -13,11 +13,14 @@ rust-version = "1.77.0"
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
rosenpass-to = { workspace = true }
|
||||
base64ct = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
typenum = { workspace = true }
|
||||
static_assertions = { workspace = true }
|
||||
rustix = { workspace = true }
|
||||
rustix = { workspace = true, features = ["net", "fs", "process", "mm"] }
|
||||
libc = { workspace = true }
|
||||
errno = { workspace = true }
|
||||
zeroize = { workspace = true }
|
||||
zerocopy = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
@@ -32,6 +35,9 @@ tokio = { workspace = true, optional = true, features = [
|
||||
"time",
|
||||
] }
|
||||
log = { workspace = true }
|
||||
bitflags = { workspace = true }
|
||||
tinyvec = "1.10.0"
|
||||
num-traits = "0.2.19"
|
||||
|
||||
[features]
|
||||
experiment_file_descriptor_passing = ["uds"]
|
||||
|
||||
85
util/src/convert.rs
Normal file
85
util/src/convert.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
//! Additional helpers to [std::convert]: Traits for conversions between types.
|
||||
|
||||
/// Variant of [std::convert::Into] with an explicitly specified type.
|
||||
///
|
||||
/// This facilitates method chaining.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// We can create a nicer implementation of the following function using this extension trait
|
||||
///
|
||||
/// ```
|
||||
/// use rosenpass_util::convert::IntoTypeExt;
|
||||
///
|
||||
/// fn encode_char_u32_be_1(c: char) -> [u8; 4] {
|
||||
/// let i : u32 = c.into();
|
||||
/// i.to_be_bytes()
|
||||
/// }
|
||||
///
|
||||
/// fn encode_char_u32_be_2(c: char) -> [u8; 4] {
|
||||
/// c.into_type::<u32>().to_be_bytes()
|
||||
/// }
|
||||
///
|
||||
/// assert_eq!(encode_char_u32_be_1('X'), [0x00, 0x00, 0x00, 0x58]);
|
||||
/// assert_eq!(encode_char_u32_be_1('X'), encode_char_u32_be_2('X'));
|
||||
///
|
||||
/// ```
|
||||
pub trait IntoTypeExt {
|
||||
/// Variant of [std::convert::Into] with explicitly specified type.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// See [IntoType].
|
||||
fn into_type<T>(self) -> T
|
||||
where
|
||||
Self: Into<T>,
|
||||
{
|
||||
self.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> IntoTypeExt for T {}
|
||||
|
||||
/// Variant of [std::convert::TryInto] with an explicitly specified type.
|
||||
///
|
||||
/// This facilitates method chaining.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// We can create a nicer implementation of the following function using this extension trait
|
||||
///
|
||||
/// ```
|
||||
/// use rosenpass_util::convert::TryIntoTypeExt;
|
||||
/// use rosenpass_util::result::OkExt;
|
||||
///
|
||||
/// fn encode_char_u16_be_1(c: char) -> Result<[u8; 2], <char as TryInto<u16>>::Error> {
|
||||
/// let i : u16 = c.try_into()?;
|
||||
/// Ok(i.to_be_bytes())
|
||||
/// }
|
||||
///
|
||||
/// fn encode_char_u16_be_2(c: char) -> Result<[u8; 2], <char as TryInto<u16>>::Error> {
|
||||
/// c.try_into_type::<u16>()?.to_be_bytes().ok()
|
||||
/// }
|
||||
///
|
||||
/// assert_eq!(encode_char_u16_be_1('X'), Ok([0x00, 0x58]));
|
||||
/// assert_eq!(encode_char_u16_be_1('X'), encode_char_u16_be_2('X'));
|
||||
///
|
||||
/// const HEART_CAT : char = '😻'; // 1F63B
|
||||
/// assert!(encode_char_u16_be_1(HEART_CAT).is_err());
|
||||
/// assert_eq!(encode_char_u16_be_1(HEART_CAT), encode_char_u16_be_2(HEART_CAT));
|
||||
/// ```
|
||||
pub trait TryIntoTypeExt {
|
||||
/// Variant of [std::convert::TryInto] with explicitly specified type.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// See [TryIntoType].
|
||||
fn try_into_type<T>(self) -> Result<T, <Self as TryInto<T>>::Error>
|
||||
where
|
||||
Self: TryInto<T>,
|
||||
{
|
||||
self.try_into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> TryIntoTypeExt for T {}
|
||||
4
util/src/int/mod.rs
Normal file
4
util/src/int/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
//! Helpers for working with integer types
|
||||
|
||||
pub mod modular;
|
||||
pub mod u64uint;
|
||||
276
util/src/int/modular.rs
Normal file
276
util/src/int/modular.rs
Normal file
@@ -0,0 +1,276 @@
|
||||
//! Numeric types for modular arithmetic
|
||||
|
||||
use num_traits::{
|
||||
ops::overflowing::OverflowingAdd, CheckedMul, Euclid, Num, Unsigned, WrappingNeg, Zero,
|
||||
};
|
||||
|
||||
/// Summary-trait for numeric types that can serve as the basis for ModuleBase
|
||||
pub trait ModuleBase: Num + Ord + Copy + Unsigned + OverflowingAdd + Zero {}
|
||||
impl<T> ModuleBase for T where T: Num + Ord + Copy + Unsigned + OverflowingAdd + Zero {}
|
||||
|
||||
/// Represents a modulus; i.e. the range of values some number type is allowed to use.
|
||||
///
|
||||
/// This is based on some inner representation.
|
||||
///
|
||||
/// This is not just a value of the underlying representation, because it also supports the modulus
|
||||
/// [Self::new_full_range()], which indicates that the full range of the underlying type is to be
|
||||
/// supported.
|
||||
///
|
||||
/// Note that zero is not a valid modulus
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub struct Modulus<T: ModuleBase> {
|
||||
/// Inner representation of the type
|
||||
modulus: T,
|
||||
}
|
||||
|
||||
impl<T: ModuleBase> Modulus<T> {
|
||||
/// Create a new [Self] without any checks
|
||||
fn raw_new(modulus: T) -> Self {
|
||||
Self { modulus }
|
||||
}
|
||||
|
||||
/// Create a new [Self] that indicates that the full range of the underlying type is to be used
|
||||
pub fn new_full_range() -> Self {
|
||||
Self::raw_new(T::zero())
|
||||
}
|
||||
|
||||
/// Try to create a new [Self]. Will return None only if `modulus == 0`
|
||||
pub fn try_new(modulus: T) -> Option<Self> {
|
||||
match modulus == T::zero() {
|
||||
true => None,
|
||||
false => Some(Self::raw_new(modulus)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Like [Self::try_new] but will panic if `modulus == 0`
|
||||
///
|
||||
/// # Panic
|
||||
///
|
||||
/// Will panic if `modulus == 0`
|
||||
pub fn new_or_panic(modulus: T) -> Self {
|
||||
match Self::try_new(modulus) {
|
||||
None => panic!("Can not create Module with modulus zero!"),
|
||||
Some(me) => me,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this [Self] represents the full range of the underlying type
|
||||
pub fn is_full_range(&self) -> bool {
|
||||
self.modulus == T::zero()
|
||||
}
|
||||
|
||||
/// Get the raw modulus. I.e. the value of the underlying type that can be given to
|
||||
/// a modulo operation to implement modular arithmetic.
|
||||
///
|
||||
/// Will return None if [Self::is_full_range].
|
||||
pub fn modulus(&self) -> Option<T> {
|
||||
match self.is_full_range() {
|
||||
true => None,
|
||||
false => Some(self.modulus),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the given type is contained in the range represented by this [Self]
|
||||
pub fn contains(&self, v: T) -> bool {
|
||||
match self.is_full_range() {
|
||||
true => true,
|
||||
false => v < self.modulus,
|
||||
}
|
||||
}
|
||||
/// Double the modulus.
|
||||
///
|
||||
/// Correctly handles the case that `v.double().is_full_range()`.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// use rosenpass_util::int::modular::Modulus;
|
||||
///
|
||||
/// fn m(v: u8) -> Modulus<u8> {
|
||||
/// Modulus::new_or_panic(v)
|
||||
/// }
|
||||
///
|
||||
/// assert_eq!(m(100).double(), Some(m(200)));
|
||||
/// assert_eq!(m(128).double(), Some(Modulus::new_full_range()));
|
||||
/// assert_eq!(m(129).double(), None);
|
||||
/// ```
|
||||
pub fn double(&self) -> Option<Self> {
|
||||
let s = self.modulus()?;
|
||||
match s.overflowing_add(&s) {
|
||||
(d, true) if d > T::zero() => None,
|
||||
(d, _) => Some(Self::raw_new(d)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new [ModularArithmetic] by taking the value modulo the modulus
|
||||
pub fn new_number<U: Into<T>>(self, value: U) -> ModularArithmetic<T>
|
||||
where
|
||||
T: ModularArithmeticBase,
|
||||
{
|
||||
ModularArithmetic::modular_new(value.into(), self)
|
||||
}
|
||||
|
||||
/// Apply [Self::new_number] to each of the parameters, return whatever result the closure
|
||||
/// produces
|
||||
pub fn with_converted<U, const N: usize, R, F>(&self, params: [U; N], f: F) -> R
|
||||
where
|
||||
Self: Copy,
|
||||
T: std::fmt::Debug + ModularArithmeticBase,
|
||||
U: Into<T>,
|
||||
F: FnOnce([ModularArithmetic<T>; N]) -> R,
|
||||
{
|
||||
let params = params.map(|v| self.new_number(v));
|
||||
f(params)
|
||||
}
|
||||
|
||||
/// Apply [Self::new_number] to each of the parameters, converting the result to the underlying
|
||||
/// representation
|
||||
pub fn formula<U, const N: usize, F>(&self, params: [U; N], f: F) -> T
|
||||
where
|
||||
Self: Copy,
|
||||
T: std::fmt::Debug + ModularArithmeticBase,
|
||||
U: Into<T>,
|
||||
F: FnOnce([ModularArithmetic<T>; N]) -> ModularArithmetic<T>,
|
||||
{
|
||||
self.with_converted(params, f).value()
|
||||
}
|
||||
}
|
||||
|
||||
/// Summary trait for types that can serve as the basis for [ModularArithmetic]
|
||||
pub trait ModularArithmeticBase:
|
||||
ModuleBase
|
||||
+ std::fmt::Debug
|
||||
+ Num
|
||||
+ PartialOrd
|
||||
+ Ord
|
||||
+ Copy
|
||||
+ Unsigned
|
||||
+ CheckedMul
|
||||
+ Euclid
|
||||
+ WrappingNeg
|
||||
{
|
||||
}
|
||||
impl<T> ModularArithmeticBase for T where
|
||||
T: ModuleBase
|
||||
+ std::fmt::Debug
|
||||
+ Num
|
||||
+ PartialOrd
|
||||
+ Ord
|
||||
+ Copy
|
||||
+ Unsigned
|
||||
+ CheckedMul
|
||||
+ Euclid
|
||||
+ WrappingNeg
|
||||
{
|
||||
}
|
||||
|
||||
/// Modular arithmetic with an arbitrary modulus
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct ModularArithmetic<T: ModularArithmeticBase> {
|
||||
/// The modulus
|
||||
modulus: Modulus<T>,
|
||||
/// The value inside the modulus
|
||||
///
|
||||
/// Note that `self.modulus.contains(self.value)` must always hold.
|
||||
value: T,
|
||||
}
|
||||
|
||||
impl<T: ModularArithmeticBase> ModularArithmetic<T> {
|
||||
/// Construct a new [Self].
|
||||
///
|
||||
/// Will return none unless `module.`[contains](Modulus::contains)`(value)`.
|
||||
pub fn try_new(value: T, module: Modulus<T>) -> Option<Self> {
|
||||
module.contains(value).then_some(Self {
|
||||
value,
|
||||
modulus: module,
|
||||
})
|
||||
}
|
||||
|
||||
/// Construct a new [Self].
|
||||
///
|
||||
/// # Panic
|
||||
///
|
||||
/// Will punic unless `module.`[contains](Modulus::contains)`(value)`.
|
||||
pub fn modular_new(value: T, module: Modulus<T>) -> Self {
|
||||
let value = match module.modulus() {
|
||||
Some(m) => value.rem_euclid(&m),
|
||||
None => value,
|
||||
};
|
||||
|
||||
Self {
|
||||
modulus: module,
|
||||
value,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the modulus
|
||||
pub fn modulus(&self) -> &Modulus<T> {
|
||||
&self.modulus
|
||||
}
|
||||
|
||||
/// The inner value
|
||||
pub fn value(&self) -> T {
|
||||
self.value
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ModularArithmeticBase> PartialEq for ModularArithmetic<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.value == other.value
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ModularArithmeticBase> Eq for ModularArithmetic<T> {}
|
||||
|
||||
impl<T: ModularArithmeticBase> PartialOrd for ModularArithmetic<T> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ModularArithmeticBase> Ord for ModularArithmetic<T> {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
self.value.cmp(&other.value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ModularArithmeticBase> std::ops::Neg for ModularArithmetic<T> {
|
||||
type Output = Self;
|
||||
|
||||
fn neg(self) -> Self::Output {
|
||||
let ret = match self.modulus().modulus() {
|
||||
None => self.value().wrapping_neg(),
|
||||
Some(modulus) => modulus - self.value(),
|
||||
};
|
||||
|
||||
Self {
|
||||
value: ret,
|
||||
modulus: self.modulus,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ModularArithmeticBase> std::ops::Sub for ModularArithmetic<T> {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
assert_eq!(self.modulus(), rhs.modulus());
|
||||
|
||||
if self < rhs {
|
||||
return -(rhs - self);
|
||||
}
|
||||
|
||||
Self {
|
||||
value: self.value() - rhs.value(),
|
||||
modulus: self.modulus,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ModularArithmeticBase> std::ops::Add for ModularArithmetic<T> {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
self - (-rhs)
|
||||
}
|
||||
}
|
||||
12
util/src/int/u64uint/constants.rs
Normal file
12
util/src/int/u64uint/constants.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
//! Constants for u64 <-> usize conversion
|
||||
|
||||
/// Largest numeric value that can be safely represented both as
|
||||
/// a u64 and usize
|
||||
pub const MAX_USIZE_IN_U64: usize = match u64::BITS >= usize::BITS {
|
||||
true => usize::MAX,
|
||||
false => u64::MAX as usize,
|
||||
};
|
||||
|
||||
/// Largest numeric value that can be safely represented both as
|
||||
/// a u64 and usize
|
||||
pub const MAX_U64_IN_USIZE: u64 = MAX_USIZE_IN_U64 as u64;
|
||||
21
util/src/int/u64uint/mod.rs
Normal file
21
util/src/int/u64uint/mod.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
//! Making sure size representations fit both into [usize] and [u64]
|
||||
|
||||
use static_assertions::const_assert;
|
||||
|
||||
mod constants;
|
||||
pub use constants::*;
|
||||
|
||||
#[allow(clippy::module_inception)]
|
||||
mod u64uint;
|
||||
pub use u64uint::*;
|
||||
|
||||
mod range;
|
||||
pub use range::*;
|
||||
|
||||
/// Safe conversion from usize to u64
|
||||
///
|
||||
/// TODO: Deprecate this
|
||||
pub const fn usize_to_u64(v: usize) -> u64 {
|
||||
const_assert!(u64::BITS >= usize::BITS);
|
||||
v as u64
|
||||
}
|
||||
29
util/src/int/u64uint/range.rs
Normal file
29
util/src/int/u64uint/range.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
//! Working with ranges of [super::U64Uint]
|
||||
|
||||
use std::ops::Range;
|
||||
|
||||
use super::U64USize;
|
||||
|
||||
/// Extensions for working with [std::ops::Range] of [U64USize]
|
||||
pub trait U64USizeRangeExt {
|
||||
/// Convert to a usize based range
|
||||
fn usize(self) -> Range<usize>;
|
||||
/// Convert to a u64 based range
|
||||
fn u64(self) -> Range<u64>;
|
||||
}
|
||||
|
||||
impl U64USizeRangeExt for Range<U64USize> {
|
||||
fn usize(self) -> Range<usize> {
|
||||
Range {
|
||||
start: self.start.usize(),
|
||||
end: self.end.usize(),
|
||||
}
|
||||
}
|
||||
|
||||
fn u64(self) -> Range<u64> {
|
||||
Range {
|
||||
start: self.start.u64(),
|
||||
end: self.end.u64(),
|
||||
}
|
||||
}
|
||||
}
|
||||
192
util/src/int/u64uint/u64uint.rs
Normal file
192
util/src/int/u64uint/u64uint.rs
Normal file
@@ -0,0 +1,192 @@
|
||||
//! The [U64UInt type]
|
||||
|
||||
use super::MAX_U64_IN_USIZE;
|
||||
|
||||
/// Error produced by [U64USize::try_new]
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum U64USizeConversionError<T: std::fmt::Debug> {
|
||||
/// Value can not be represented as u64
|
||||
#[error("Value can not be represented as a u64 (max = {}) value: {:?}", u64::MAX, .0)]
|
||||
NoU64Repr(T),
|
||||
/// Value can not be represented as usize
|
||||
#[error("Value can not be represented as a usize (max = {}) value: {:?}", usize::MAX, .0)]
|
||||
NoUSizeRepr(T),
|
||||
/// Value can not be represented as usize or u64
|
||||
#[error("Value can not be represented as a u64 (max = {}) or usize (max = {}) value: {:?}", u64::MAX, usize::MAX, .0)]
|
||||
NoU64OrUSizeRepr(T),
|
||||
}
|
||||
|
||||
/// A number that can be represented as both a usize and a u64.
|
||||
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub struct U64USize {
|
||||
/// Enclosed data
|
||||
storage: u64,
|
||||
}
|
||||
|
||||
impl U64USize {
|
||||
/// Cast another number to a number that can be represented as usize and u64;
|
||||
///
|
||||
/// This is the internal version which does not use [TryInto]; we use this to implement
|
||||
/// [TryInto].
|
||||
fn try_new_internal<T>(v: T) -> Result<Self, U64USizeConversionError<T>>
|
||||
where
|
||||
T: Copy + TryInto<usize> + TryInto<u64> + std::fmt::Debug,
|
||||
{
|
||||
use U64USizeConversionError as E;
|
||||
|
||||
let v_u64: Result<u64, _> = v.try_into();
|
||||
let v_usize: Result<usize, _> = v.try_into();
|
||||
match (v_u64, v_usize) {
|
||||
(Ok(storage), Ok(_)) => Ok(Self { storage }),
|
||||
(Err(_), Ok(_)) => Err(E::NoU64Repr(v)),
|
||||
(Ok(_), Err(_)) => Err(E::NoUSizeRepr(v)),
|
||||
(Err(_), Err(_)) => Err(E::NoU64OrUSizeRepr(v)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Cast another number to a number that can be represented as usize and u64
|
||||
pub fn try_new<T>(v: T) -> Result<U64USize, <T as TryInto<Self>>::Error>
|
||||
where
|
||||
T: TryInto<Self>,
|
||||
{
|
||||
v.try_into()
|
||||
}
|
||||
|
||||
/// Like [Self::try_new], but panics
|
||||
pub fn new_or_panic<T>(v: T) -> Self
|
||||
where
|
||||
T: TryInto<Self>,
|
||||
<T as TryInto<Self>>::Error: std::fmt::Debug,
|
||||
{
|
||||
match Self::try_new(v) {
|
||||
Ok(v) => v,
|
||||
Err(e) => panic!(
|
||||
"Could not construct {}: {e:?}",
|
||||
std::any::type_name::<Self>()
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Return this value as a usize
|
||||
pub fn usize(&self) -> usize {
|
||||
self.storage as usize
|
||||
}
|
||||
|
||||
/// Return this value as a u64
|
||||
pub fn u64(&self) -> u64 {
|
||||
self.storage
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Sub for U64USize {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
Self::new_or_panic(self.u64() - rhs.u64())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Add for U64USize {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
Self::new_or_panic(self.u64() + rhs.u64())
|
||||
}
|
||||
}
|
||||
|
||||
/// Facilitates creation of [U64USize] in cases where truncation to the largest representable value is permissible
|
||||
pub trait TruncateIntoU64USize {
|
||||
/// Check whether calling [TruncateIntoU64Usize::truncate_to_u64usize] would truncate or return
|
||||
/// the value as-is
|
||||
fn fits_into_u64usize(&self) -> bool;
|
||||
|
||||
/// Turn [Self] into a [U64USize]. If the value is representable as a usize and a u64, then
|
||||
/// the value will be returned as is. Otherwise, the maximum representable value [MAX_U64_IN_USIZE]
|
||||
/// will be returned.
|
||||
fn truncate_to_u64usize(&self) -> U64USize;
|
||||
}
|
||||
|
||||
/// Create instances of From for SafeUSize
|
||||
macro_rules! derive_TruncateIntoU64Usize {
|
||||
($($T:ty),*) => {
|
||||
$(
|
||||
impl TruncateIntoU64USize for $T {
|
||||
fn fits_into_u64usize(&self) -> bool {
|
||||
U64USize::try_new(*self).is_ok()
|
||||
}
|
||||
|
||||
fn truncate_to_u64usize(&self) -> U64USize {
|
||||
U64USize::try_new(*self).unwrap_or(U64USize::new_or_panic(MAX_U64_IN_USIZE))
|
||||
}
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
derive_TruncateIntoU64Usize!(
|
||||
U64USize, usize, isize, bool, u8, u16, u32, u64, u128, i8, i16, i32, i64, i128
|
||||
);
|
||||
|
||||
/// Create instances of From for SafeUSize
|
||||
macro_rules! U64USize_derive_from {
|
||||
($($T:ty),*) => {
|
||||
$(
|
||||
impl From<$T> for U64USize {
|
||||
fn from(value: $T) -> Self {
|
||||
U64USize::try_new_internal::<$T>(value).unwrap()
|
||||
}
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
U64USize_derive_from!(bool, u8, u16);
|
||||
|
||||
/// Create instances of TryFrom for SafeUSize
|
||||
macro_rules! U64USize_derive_try_from {
|
||||
($($T:ty),*) => {
|
||||
$(
|
||||
impl TryFrom<$T> for U64USize {
|
||||
type Error = U64USizeConversionError<$T>;
|
||||
|
||||
fn try_from(value: $T) -> Result<Self, Self::Error> {
|
||||
U64USize::try_new_internal::<$T>(value)
|
||||
}
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
U64USize_derive_try_from!(usize, isize, u32, u64, u128, i8, i16, i32, i64, i128);
|
||||
|
||||
/// Create instances of Into for SafeUSize
|
||||
macro_rules! U64USize_derive_into {
|
||||
($($T:ty),*) => {
|
||||
$(
|
||||
impl From<U64USize> for $T {
|
||||
fn from(val: U64USize) -> Self {
|
||||
val.u64() as $T
|
||||
}
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
U64USize_derive_into!(usize, u64, u128, i128);
|
||||
|
||||
/// Create instances of TryInto for SafeUSize
|
||||
macro_rules! U64USize_derive_try_into {
|
||||
($($T:ty),*) => {
|
||||
$(
|
||||
impl TryFrom<U64USize> for $T {
|
||||
type Error = <$T as TryFrom<u64>>::Error;
|
||||
|
||||
fn try_from(val: U64USize) -> Result<Self, Self::Error> {
|
||||
val.u64().try_into()
|
||||
}
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
U64USize_derive_try_into!(isize, u8, u16, u32, i8, i16, i32, i64);
|
||||
3
util/src/ipc/mod.rs
Normal file
3
util/src/ipc/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
//! Inter-process communication related resources
|
||||
|
||||
pub mod shm;
|
||||
6
util/src/ipc/shm/mod.rs
Normal file
6
util/src/ipc/shm/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
//! Resources for working with shared-memory
|
||||
|
||||
mod shared_memory_segment;
|
||||
pub use shared_memory_segment::*;
|
||||
|
||||
pub mod ringbuf;
|
||||
101
util/src/ipc/shm/ringbuf/local.rs
Normal file
101
util/src/ipc/shm/ringbuf/local.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
//! Local shared memory ring buffers – mostly for testing
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::ipc::shm::SharedMemorySegment;
|
||||
use crate::ringbuf::concurrent::framework::{ConcurrentPipeReader, ConcurrentPipeWriter};
|
||||
|
||||
use super::{ShmPipeCore, ShmPipeVariables};
|
||||
|
||||
/// A process-local shared memory pipe reader
|
||||
///
|
||||
/// See [shm_pipe()].
|
||||
pub type LocalShmPipeReader = ConcurrentPipeReader<ShmPipeCore<Arc<ShmPipeVariables>>>;
|
||||
|
||||
/// A process-local shared memory pipe writer
|
||||
///
|
||||
/// See [shm_pipe()].
|
||||
pub type LocalShmPipeWriter = ConcurrentPipeWriter<ShmPipeCore<Arc<ShmPipeVariables>>>;
|
||||
|
||||
/// Creates a process-local shared-memory pipe.
|
||||
///
|
||||
/// See [ConcurrentPipeWriter]/[ConcurrentPipeReader]. The types [LocalShmPipeWriter]/[LocalShmPipeReader] are just aliases for these.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Mind the comments in the safety section of [super::super::SharedMemorySegment]; the issues
|
||||
/// described in there *exactly* affect this implementation.
|
||||
pub fn shm_pipe(len: usize) -> anyhow::Result<(LocalShmPipeWriter, LocalShmPipeReader)> {
|
||||
let shared = Arc::new(ShmPipeVariables::new());
|
||||
let (seg_fd, seg_buf_1) = SharedMemorySegment::create(len)?;
|
||||
let seg_buf_2 = unsafe { SharedMemorySegment::from_fd(seg_fd, len) }?;
|
||||
|
||||
let writer = ShmPipeCore::new(shared.clone(), seg_buf_1);
|
||||
let reader = ShmPipeCore::new(shared, seg_buf_2);
|
||||
|
||||
Ok((
|
||||
ConcurrentPipeWriter::from_core(writer),
|
||||
ConcurrentPipeReader::from_core(reader),
|
||||
))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shm_pipe() -> anyhow::Result<()> {
|
||||
let (mut writer, mut reader) = shm_pipe(1024)?;
|
||||
|
||||
const MSG: &[u8] = b"Hello World\0";
|
||||
const MSG_COUNT: usize = 100000;
|
||||
|
||||
let t = std::thread::spawn(move || {
|
||||
for _ in 0..MSG_COUNT {
|
||||
let mut buf = MSG;
|
||||
while !buf.is_empty() {
|
||||
let n = writer.write(buf).unwrap();
|
||||
buf = &buf[n..];
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let mut buf = [0u8; 1000];
|
||||
let mut buf_off = 0;
|
||||
let mut msg_no = 0usize;
|
||||
|
||||
'read_data: while msg_no < MSG_COUNT {
|
||||
let mut old_off = buf_off;
|
||||
|
||||
// Read the data from the shared memory buffer
|
||||
buf_off += reader.read(&mut buf[buf_off..])?;
|
||||
|
||||
loop {
|
||||
// Scan the available data for the zero terminator
|
||||
let msg_len = &buf[old_off..buf_off]
|
||||
.iter()
|
||||
.copied()
|
||||
.enumerate()
|
||||
.find(|(_off, c)| *c == 0x0)
|
||||
.map(|(off, _c)| off + old_off + 1);
|
||||
|
||||
// Next iteration, unless the terminator was found
|
||||
let msg_len = match *msg_len {
|
||||
Some(l) => l,
|
||||
None => continue 'read_data,
|
||||
};
|
||||
|
||||
// Register the newly read message
|
||||
msg_no += 1;
|
||||
|
||||
// Check that the message is correctly transferred
|
||||
let msg = &buf[0..msg_len];
|
||||
assert_eq!(msg, MSG);
|
||||
|
||||
// Move any extra data to the beginning of the buffer and adjust the offsets accordingly
|
||||
buf.copy_within(msg_len..buf_off, 0);
|
||||
old_off = 0;
|
||||
buf_off -= msg_len;
|
||||
}
|
||||
}
|
||||
|
||||
t.join().unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
81
util/src/ipc/shm/ringbuf/main.rs
Normal file
81
util/src/ipc/shm/ringbuf/main.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
//! Shared-memory ring buffer implementations (main part of the enclosing module)
|
||||
|
||||
use std::{borrow::Borrow, sync::atomic::AtomicU64};
|
||||
|
||||
use zerocopy::{FromBytes, IntoBytes};
|
||||
|
||||
use crate::{
|
||||
ipc::shm::SharedMemorySegment,
|
||||
ringbuf::concurrent::framework::{
|
||||
ConcurrentPipeCore, ConcurrentPipeReader, ConcurrentPipeWriter,
|
||||
},
|
||||
};
|
||||
|
||||
/// Synchronization variables for a [ShmPipeWriter]/[ShmPipeReader].
|
||||
///
|
||||
/// These values must be shared between the reader/writer in such a way that access to the inner
|
||||
/// variables is synchronized and atomic between the two parties.
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Default, IntoBytes, FromBytes)]
|
||||
pub struct ShmPipeVariables {
|
||||
/// See [crate::ringbuf::sched::RingBufferScheduler::items_read()]
|
||||
pub items_read: AtomicU64,
|
||||
/// See [crate::ringbuf::sched::RingBufferScheduler::items_written()]
|
||||
pub items_written: AtomicU64,
|
||||
}
|
||||
|
||||
impl ShmPipeVariables {
|
||||
/// Constructor
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
items_read: 0.into(),
|
||||
items_written: 0.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The [ConcurrentPipeCore] for a shared memory pipe.
|
||||
#[derive(Debug)]
|
||||
pub struct ShmPipeCore<Variables: Borrow<ShmPipeVariables>> {
|
||||
/// The synchronization variables
|
||||
variables: Variables,
|
||||
/// The memory buffer
|
||||
buf: SharedMemorySegment,
|
||||
}
|
||||
|
||||
impl<Variables: Borrow<ShmPipeVariables>> ShmPipeCore<Variables> {
|
||||
/// Constructor
|
||||
pub fn new(variables: Variables, buf: SharedMemorySegment) -> Self {
|
||||
Self { variables, buf }
|
||||
}
|
||||
}
|
||||
|
||||
/// A shared memory pipe reader
|
||||
pub type ShmPipeReader<Variables> = ConcurrentPipeReader<ShmPipeCore<Variables>>;
|
||||
|
||||
/// A shared memory pipe reader
|
||||
pub type ShmPipeWriter<Variables> = ConcurrentPipeWriter<ShmPipeCore<Variables>>;
|
||||
|
||||
impl<Variables: Borrow<ShmPipeVariables>> ConcurrentPipeCore for ShmPipeCore<Variables> {
|
||||
type AtomicType = AtomicU64;
|
||||
|
||||
fn buf_len(&self) -> u64 {
|
||||
self.buf.len() as u64
|
||||
}
|
||||
|
||||
fn items_read(&self) -> &AtomicU64 {
|
||||
&self.variables.borrow().items_read
|
||||
}
|
||||
|
||||
fn items_written(&self) -> &AtomicU64 {
|
||||
&self.variables.borrow().items_written
|
||||
}
|
||||
|
||||
fn read_from_buffer(&mut self, dst: &mut [u8], off: u64) {
|
||||
self.buf.volatile_read(dst, off as usize)
|
||||
}
|
||||
|
||||
fn write_to_buffer(&mut self, off: u64, src: &[u8]) {
|
||||
self.buf.volatile_write(off as usize, src);
|
||||
}
|
||||
}
|
||||
7
util/src/ipc/shm/ringbuf/mod.rs
Normal file
7
util/src/ipc/shm/ringbuf/mod.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
//! Shared-memory ring buffers
|
||||
|
||||
mod main;
|
||||
pub use main::*;
|
||||
|
||||
mod local;
|
||||
pub use local::*;
|
||||
399
util/src/ipc/shm/shared_memory_segment.rs
Normal file
399
util/src/ipc/shm/shared_memory_segment.rs
Normal file
@@ -0,0 +1,399 @@
|
||||
//! Accessing data in a shared memory segment
|
||||
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
os::fd::{AsFd, OwnedFd},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
int::u64uint::usize_to_u64,
|
||||
ptr::{ReadMemVolatile, WriteMemVolatile},
|
||||
result::OkExt,
|
||||
secret_memory::{
|
||||
fd::SecretMemfdConfig,
|
||||
mmap::{MMapError, MapFdConfig, MappedSegment},
|
||||
},
|
||||
};
|
||||
|
||||
/// Safe creation of shared memory segments
|
||||
///
|
||||
/// This is a slightly more convenient API than using [MapFdConfig] directly.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct SharedMemorySegmentBuilder {
|
||||
/// Configuration for allocating memory file descriptors and their secrecy level
|
||||
pub secret_memfd_cfg: SecretMemfdConfig,
|
||||
/// Configuration for mapping file descriptors into memory
|
||||
pub map_fd_cfg: MapFdConfig,
|
||||
}
|
||||
|
||||
impl SharedMemorySegmentBuilder {
|
||||
/// Create a new secret memory segment builder.
|
||||
///
|
||||
/// Note that this always sets [MapFdConfig::set_shared()], but you can overwrite this behavior
|
||||
/// by un-setting the flag again.
|
||||
///
|
||||
/// # Safety & panic
|
||||
///
|
||||
/// This function can panic when [u64] is unable to represent the given size
|
||||
/// value. This might be the case in future computers impressing me with their excessive size.
|
||||
pub const fn new(len: usize) -> Self {
|
||||
let secret_memfd_cfg = SecretMemfdConfig::new();
|
||||
let map_fd_cfg = MapFdConfig::new()
|
||||
.set_shared()
|
||||
.resize_on_mmap(usize_to_u64(len));
|
||||
Self {
|
||||
secret_memfd_cfg,
|
||||
map_fd_cfg,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a secret memory segment using the configuration stored here
|
||||
pub fn create_segment(&self) -> anyhow::Result<(OwnedFd, SharedMemorySegment)> {
|
||||
let fd = self.secret_memfd_cfg.create()?;
|
||||
let seg = self.map_fd_cfg.mappable_fd(&fd).mmap()?;
|
||||
let seg = unsafe { SharedMemorySegment::from_mapped_segment(seg) };
|
||||
|
||||
Ok((fd, seg))
|
||||
}
|
||||
}
|
||||
|
||||
/// Safe creation of and access to shared memory segments
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Any means to create a [Self] must guarantee, that further calls to [Self::volatile_write] and
|
||||
/// [Self::volatile_read] are also safe.
|
||||
///
|
||||
/// The API of this struct is specifically designed so creating multiple memory mappings of the same
|
||||
/// shared memory segment is impossible without unsafe code.
|
||||
///
|
||||
/// We recognize that the sole purpose of shared memory segments is that multiple mappings of them
|
||||
/// are created, there just is no way to do so using safe rust.
|
||||
///
|
||||
/// The reason for this is – that by the Rust documentation – any concurrent memory access leads to
|
||||
/// undefined behavior, even volatile accesses caused solely by an adversarial application on the
|
||||
/// other end of a shared memory communication channel.
|
||||
///
|
||||
/// For this reason, we force users to use unsafe code to create shared memory mappings from an
|
||||
/// existing file descriptor, as this could potentially lead to adversarial data races (and thus to
|
||||
/// undefined behavior).
|
||||
///
|
||||
/// In practice, we believe using concurrent memory access using volatile operations is going to
|
||||
/// lead to nothing worse than garbled data being transferred. The user should treat any data
|
||||
/// received through a shared-memory ring buffer as untrusted and validate this data any way, so
|
||||
/// garbled data should be caught.
|
||||
///
|
||||
/// This means that [Self::from_fd()] is unsafe in theory, but most likely safe in practice.
|
||||
///
|
||||
/// ## Concurrent, untrusted shared memory is technically undefined behavior
|
||||
///
|
||||
/// Its even worse than having to use unsafe: technically speaking, it may be impossible to
|
||||
/// use shared memory soundly in rust unless all parties with access to the segment are
|
||||
/// *trusted*. If these parties are not trusted (or buggy) they can always cause undefined
|
||||
/// behavior:
|
||||
///
|
||||
/// From the [std::sync::atomic] documentation:
|
||||
///
|
||||
/// > **The most important aspect of this model is that data races are undefined behavior.** A data race
|
||||
/// > is defined as conflicting non-synchronized accesses where at least one of the accesses is non-atomic.
|
||||
/// > Here, accesses are conflicting if they affect overlapping regions of memory and at least one of them
|
||||
/// > is a write. (A compare_exchange or compare_exchange_weak that does not succeed is not considered a
|
||||
/// > write.) They are non-synchronized if neither of them happens-before the other, according to the
|
||||
/// > happens-before order of the memory model.
|
||||
///
|
||||
/// The fact that this API uses volatile [reads](Self::volatile_read) and [writes](Self::volatile_write),
|
||||
/// and the the fact that we use mmap(2) for allocation does not mitigate this issue; from the
|
||||
/// documentation of [std::ptr::write_volatile()]:
|
||||
///
|
||||
/// > When a volatile operation is used for memory inside an allocation, it behaves exactly like write,
|
||||
/// > except for the additional guarantee that it won’t be elided or reordered (see above). This implies
|
||||
/// > that the operation will actually access memory and not e.g. be lowered to a register access. Other
|
||||
/// > than that, all the usual rules for memory accesses apply (including provenance). In particular, just
|
||||
/// > like in C, whether an operation is volatile has no bearing whatsoever on questions involving concurrent
|
||||
/// > access from multiple threads. Volatile accesses behave exactly like non-atomic accesses in that regard.
|
||||
///
|
||||
/// An allocation is defined as follows (taken from [std::ptr]):
|
||||
///
|
||||
/// > An allocation is a subset of program memory which is addressable from Rust, and within which pointer
|
||||
/// > arithmetic is possible. Examples of allocations include heap allocations, stack-allocated variables,
|
||||
/// > statics, and consts. The safety preconditions of some Rust operations - such as offset and field
|
||||
/// > projections (expr.field) - are defined in terms of the allocations on which they operate.
|
||||
///
|
||||
/// This definition clearly applies to mmap(2) allocated regions.
|
||||
///
|
||||
/// What might mitigate this issue is mapping the region just once per process:
|
||||
///
|
||||
/// > In particular, just
|
||||
/// > like in C, whether an operation is volatile has no bearing whatsoever on questions involving concurrent
|
||||
/// > access from **multiple threads**.
|
||||
///
|
||||
/// We could argue that a process is not a thread, and thus concurrent access from two processes is
|
||||
/// fine, but concurrent access from two threads is not (unless guarded by an atomic value or a
|
||||
/// mutex or some primitive actually designed for synchronization).
|
||||
///
|
||||
/// There is no wording in the spec explicitly allowing raceful, concurrent access from multiple processes.
|
||||
///
|
||||
/// The problem with basing our safety-argument on the claim that "processes are not threads" is
|
||||
/// that the line between processes and threads is drawn in the sand. For linux, read the man page
|
||||
/// of clone(2):
|
||||
///
|
||||
/// > By contrast with fork(2), these [clone, __clone2, clone3] system calls provide more precise control over what pieces of execution
|
||||
/// > context are shared between the calling process and the child process. For example, using these system
|
||||
/// > calls, the caller can control whether or not the two processes share the virtual address space, the ta‐
|
||||
/// > ble of file descriptors, and the table of signal handlers. These system calls also allow the new child
|
||||
/// > process to be placed in separate namespaces(7).
|
||||
/// >
|
||||
/// > […]
|
||||
/// >
|
||||
/// > ## CLONE_THREAD (since Linux 2.4.0)
|
||||
/// >
|
||||
/// > If CLONE_THREAD is set, the child is placed in the same thread group as the calling process. To
|
||||
/// > make the remainder of the discussion of CLONE_THREAD more readable, the term "thread" is used to
|
||||
/// > refer to the processes within a thread group.
|
||||
///
|
||||
/// According to the man page, "the term 'thread' is used to refer to the processes within a thread group.".
|
||||
///
|
||||
/// The Rust (transitively, from the C++11 Atomic) specification tells us that there must be no
|
||||
/// concurrent memory access between threads whether this access is volatile or not. The linux man
|
||||
/// pages tell us that "thread" is just a special type of "process".
|
||||
///
|
||||
/// **The most robust interpretation of these specifications is that shared memory must not be used for
|
||||
/// communication with an untrusted party across thread or process boundaries, or else the other
|
||||
/// process/thread can cause undefined behavior in our process.**
|
||||
///
|
||||
/// ## In practice
|
||||
///
|
||||
/// Realistically, using volatile reads/writes on valid, mapped memory might cause garbled values in
|
||||
/// case of a data race, but it should crash the program or do anything worse than create garbled
|
||||
/// values.
|
||||
///
|
||||
/// Mind that we do not mind garbled values here; we are implementing a shared memory communication
|
||||
/// interface, so our application must always assume, that the data it receives may be garbled. It
|
||||
/// has to be validated. We just don't want the other application to be able to do anything worse
|
||||
/// that garble the data it is sending (or receiving), so lets estimate what can *realistically*
|
||||
/// happen here if the other application maliciously causes a race.
|
||||
///
|
||||
/// The worst any of the assembly sequences below should do is cause tearing in case of a data
|
||||
/// race.
|
||||
///
|
||||
/// This leads me to the conclusion that what what we are dealing here with is not an
|
||||
/// implementation that is faulty/insecure, instead it is a definition-gap in the compiler
|
||||
/// semantics for volatile memory access for use in security-critical applications.
|
||||
///
|
||||
/// Godbolt link: <https://rust.godbolt.org/z/GGjsGsc33>
|
||||
/// Compiler: `rustc 1.90.0`
|
||||
///
|
||||
/// Rust code:
|
||||
///
|
||||
/// ```rust
|
||||
/// #[unsafe(no_mangle)]
|
||||
/// pub fn read_volatile(num: &[u128]) -> u128 {
|
||||
/// let ptr = num.as_ptr();
|
||||
/// unsafe { ptr.read_volatile() }
|
||||
/// }
|
||||
///
|
||||
/// #[unsafe(no_mangle)]
|
||||
/// pub fn write_volatile(num: &mut [u128]) {
|
||||
/// let ptr = num.as_mut_ptr();
|
||||
/// unsafe { ptr.write_volatile(42u128) };
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// x86_64 (`--target=x86_64-unknown-linux-gnu -O`):
|
||||
///
|
||||
/// ```asm
|
||||
/// read_volatile:
|
||||
/// mov rax, qword ptr [rdi]
|
||||
/// mov rdx, qword ptr [rdi + 8]
|
||||
/// ret
|
||||
///
|
||||
/// write_volatile:
|
||||
/// mov qword ptr [rdi + 8], 0
|
||||
/// mov qword ptr [rdi], 42
|
||||
/// ret
|
||||
/// ```
|
||||
///
|
||||
/// arm64 (`--target=aarch64-unknown-linux-gnu -O`):
|
||||
///
|
||||
/// ```asm
|
||||
/// read_volatile:
|
||||
/// ldp x0, x1, [x0]
|
||||
/// ret
|
||||
///
|
||||
/// write_volatile:
|
||||
/// mov w8, #42
|
||||
/// stp x8, xzr, [x0]
|
||||
/// ret
|
||||
/// ```
|
||||
///
|
||||
/// armv7 (`--target=armv7-unknown-linux-gnueabihf -O`)
|
||||
///
|
||||
/// ```asm
|
||||
/// read_volatile:
|
||||
/// push {r4, r5, r11, lr}
|
||||
/// ldrd r2, r3, [r1]
|
||||
/// ldrd r4, r5, [r1, #8]
|
||||
/// stm r0, {r2, r3, r4, r5}
|
||||
/// pop {r4, r5, r11, pc}
|
||||
///
|
||||
/// write_volatile:
|
||||
/// push {r4, r5, r11, lr}
|
||||
/// mov r2, #0
|
||||
/// mov r4, #42
|
||||
/// mov r3, r2
|
||||
/// mov r5, r2
|
||||
/// strd r2, r3, [r0, #8]
|
||||
/// strd r4, r5, [r0]
|
||||
/// pop {r4, r5, r11, pc}
|
||||
/// ```
|
||||
///
|
||||
/// risc64 (`--target=riscv64gc-unknown-linux-gnu`):
|
||||
///
|
||||
/// ```asm
|
||||
/// read_volatile:
|
||||
/// ld a1, 8(a0)
|
||||
/// ld a0, 0(a0)
|
||||
/// ret
|
||||
///
|
||||
/// write_volatile:
|
||||
/// sd zero, 8(a0)
|
||||
/// li a1, 42
|
||||
/// sd a1, 0(a0)
|
||||
/// ret
|
||||
/// ```
|
||||
///
|
||||
#[derive(Debug)]
|
||||
pub struct SharedMemorySegment {
|
||||
/// The underlying mapped segment
|
||||
inner: MappedSegment,
|
||||
}
|
||||
|
||||
impl SharedMemorySegment {
|
||||
/// Create a new shared memory segment.
|
||||
pub fn create(len: usize) -> anyhow::Result<(OwnedFd, Self)> {
|
||||
SharedMemorySegmentBuilder::new(len).create_segment()
|
||||
}
|
||||
|
||||
/// Create a shared memory segment from a file descriptor
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// See the comments in [Self].
|
||||
pub unsafe fn from_fd<Fd: AsFd>(fd: Fd, size: usize) -> Result<Self, MMapError> {
|
||||
let cfg = MapFdConfig::new()
|
||||
.set_shared()
|
||||
.expected_size(usize_to_u64(size));
|
||||
unsafe { Self::from_fd_with_config(fd, cfg) }
|
||||
}
|
||||
|
||||
/// Create a shared memory segment from a file descriptor
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// See the comments in [Self].
|
||||
pub unsafe fn from_fd_with_config<Fd: AsFd>(
|
||||
fd: Fd,
|
||||
cfg: MapFdConfig,
|
||||
) -> Result<Self, MMapError> {
|
||||
let segment = cfg.mappable_fd(&fd).mmap()?;
|
||||
unsafe { Self::from_mapped_segment(segment).ok() }
|
||||
}
|
||||
|
||||
/// Create a shared memory segment from an existing mapped segment
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// See the comments in [Self].
|
||||
pub unsafe fn from_mapped_segment(inner: MappedSegment) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
/// The underlying mapped segment
|
||||
pub fn mapped_segment(&self) -> &MappedSegment {
|
||||
self.inner.borrow()
|
||||
}
|
||||
|
||||
/// A pointer to the underlying mapped segment
|
||||
pub fn ptr(&self) -> *mut u8 {
|
||||
self.mapped_segment().ptr()
|
||||
}
|
||||
|
||||
/// The length of the underlying mapped segment
|
||||
pub fn len(&self) -> usize {
|
||||
self.mapped_segment().len()
|
||||
}
|
||||
|
||||
/// Whether `self.`[len()](Self::len)` == 0`
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
/// Read data from the ring buffer
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// See comments in [Self].
|
||||
pub fn volatile_read(&self, dst: &mut [u8], off: usize) {
|
||||
let end = off + dst.len();
|
||||
assert!(end <= self.len());
|
||||
|
||||
unsafe { self.ptr().add(off).read_mem_volatile(dst) }
|
||||
}
|
||||
|
||||
/// Read data from the ring buffer
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// See comments in [Self].
|
||||
pub fn volatile_write(&self, off: usize, src: &[u8]) {
|
||||
let end = off + src.len();
|
||||
assert!(end <= self.len());
|
||||
|
||||
unsafe { self.ptr().add(off).write_mem_volatile(src) }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SharedMemorySegment> for MappedSegment {
|
||||
fn from(val: SharedMemorySegment) -> Self {
|
||||
val.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl Borrow<MappedSegment> for SharedMemorySegment {
|
||||
fn borrow(&self) -> &MappedSegment {
|
||||
self.mapped_segment()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_shared_memory_segment() -> anyhow::Result<()> {
|
||||
let underscore = [b'_'; 38];
|
||||
let zero = [0u8; 38];
|
||||
let test_string = b"Hello World";
|
||||
|
||||
let mut after_write = zero.to_owned();
|
||||
crate::mem::cpy_min(test_string, &mut after_write);
|
||||
|
||||
let (fd, reg1) = SharedMemorySegment::create(1024)?;
|
||||
let reg2 = unsafe { SharedMemorySegment::from_fd(fd, 1024) }?;
|
||||
|
||||
let mut buf = underscore.to_owned();
|
||||
reg1.volatile_read(&mut buf, 0);
|
||||
assert_eq!(&buf, &zero);
|
||||
|
||||
let mut buf = underscore.to_owned();
|
||||
reg2.volatile_read(&mut buf, 0);
|
||||
assert_eq!(&buf, &zero);
|
||||
|
||||
reg1.volatile_write(0, test_string);
|
||||
|
||||
let mut buf = underscore.to_owned();
|
||||
reg1.volatile_read(&mut buf, 0);
|
||||
assert_eq!(&buf, &after_write);
|
||||
|
||||
let mut buf = underscore.to_owned();
|
||||
reg2.volatile_read(&mut buf, 0);
|
||||
assert_eq!(&buf, &after_write);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -10,15 +10,16 @@ pub mod b64;
|
||||
pub mod build;
|
||||
/// Control flow abstractions and utilities.
|
||||
pub mod controlflow;
|
||||
/// File descriptor utilities.
|
||||
pub mod fd;
|
||||
pub mod convert;
|
||||
/// File system operations and handling.
|
||||
pub mod file;
|
||||
pub mod fmt;
|
||||
/// Functional programming utilities.
|
||||
pub mod functional;
|
||||
pub mod int;
|
||||
/// Input/output operations.
|
||||
pub mod io;
|
||||
pub mod ipc;
|
||||
/// Length prefix encoding schemes implementation.
|
||||
pub mod length_prefix_encoding;
|
||||
/// Memory manipulation and allocation utilities.
|
||||
@@ -27,8 +28,13 @@ pub mod mem;
|
||||
pub mod mio;
|
||||
/// Extended Option type functionality.
|
||||
pub mod option;
|
||||
pub mod ptr;
|
||||
/// Extended Result type functionality.
|
||||
pub mod result;
|
||||
pub mod ringbuf;
|
||||
pub mod rustix;
|
||||
pub mod secret_memory;
|
||||
pub mod sync;
|
||||
/// Time and duration utilities.
|
||||
pub mod time;
|
||||
#[cfg(feature = "tokio")]
|
||||
|
||||
104
util/src/mem.rs
104
util/src/mem.rs
@@ -280,6 +280,110 @@ impl<T: Sized> MoveExt for T {
|
||||
}
|
||||
}
|
||||
|
||||
/// Explicitly copy a value
|
||||
///
|
||||
/// This is sometimes useful as an alternative to the deref operator
|
||||
/// (`x.copy()` instead of `*x`) to achieve neater method chains.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use rosenpass_util::mem::CopyExt;
|
||||
/// use rosenpass_util::result::OkExt;
|
||||
///
|
||||
/// // Without CopyExt
|
||||
/// fn boilerplate_1<T: Copy>(v: &T) -> Result<T, ()> {
|
||||
/// (*v).ok()
|
||||
/// }
|
||||
///
|
||||
/// // With CopyExt
|
||||
/// fn boilerplate_2<T: Copy>(v: &T) -> Result<T, ()> {
|
||||
/// v.copy().ok()
|
||||
/// }
|
||||
///
|
||||
/// assert_eq!(boilerplate_1(&32), Ok(32));
|
||||
/// assert_eq!(boilerplate_1(&32), boilerplate_2(&32));
|
||||
/// ```
|
||||
pub trait CopyExt: Copy {
|
||||
/// Copy a value
|
||||
fn copy(&self) -> Self {
|
||||
*self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Copy> CopyExt for T {}
|
||||
|
||||
/// Helper trait for applying a mutating closure/function to a mutable reference
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// use rosenpass_util::mem::MutateRefExt;
|
||||
///
|
||||
/// fn plus_two_mul_three(v: u64) -> u64 {
|
||||
/// (v + 2) * 3
|
||||
/// }
|
||||
///
|
||||
/// fn inconvenient_example(v: &mut u64) {
|
||||
/// *v = plus_two_mul_three(*v);
|
||||
/// }
|
||||
///
|
||||
/// fn convenient_example(v: &mut u64) {
|
||||
/// v.mutate(plus_two_mul_three);
|
||||
/// }
|
||||
///
|
||||
/// let mut x = 0;
|
||||
/// inconvenient_example(&mut x);
|
||||
/// assert_eq!(x, 6);
|
||||
/// convenient_example(&mut x);
|
||||
/// assert_eq!(x, 24);
|
||||
///
|
||||
/// x = 0;
|
||||
///
|
||||
/// x.mutate(plus_two_mul_three);
|
||||
/// assert_eq!(x, 6);
|
||||
/// x.mutate(|v| (v + 2) * 3);
|
||||
/// assert_eq!(x, 24);
|
||||
///
|
||||
/// x = 0;
|
||||
///
|
||||
/// let y = &mut x;
|
||||
/// y.mutate(plus_two_mul_three);
|
||||
/// assert_eq!(x, 6);
|
||||
///
|
||||
/// struct Cont {
|
||||
/// x: u64,
|
||||
/// }
|
||||
///
|
||||
/// impl Cont {
|
||||
/// fn x_mut(&mut self) -> &mut u64 {
|
||||
/// &mut self.x
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// let mut s = Cont { x: 0 };
|
||||
/// s.x_mut().mutate(|v| (v + 2) * 3);
|
||||
/// ```
|
||||
pub trait MutateRefExt: DerefMut
|
||||
where
|
||||
Self::Target: Sized,
|
||||
{
|
||||
/// Directly mutate a reference
|
||||
fn mutate<F: FnOnce(Self::Target) -> Self::Target>(self, f: F) -> Self;
|
||||
}
|
||||
|
||||
impl<T> MutateRefExt for T
|
||||
where
|
||||
T: DerefMut,
|
||||
<T as Deref>::Target: Copy,
|
||||
{
|
||||
fn mutate<F: FnOnce(Self::Target) -> Self::Target>(mut self, f: F) -> Self {
|
||||
let v = self.deref_mut();
|
||||
*v = f(*v);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_forgetting {
|
||||
use crate::mem::Forgetting;
|
||||
|
||||
@@ -2,8 +2,8 @@ use mio::net::{UnixListener, UnixStream};
|
||||
use std::os::fd::{OwnedFd, RawFd};
|
||||
|
||||
use crate::{
|
||||
fd::{claim_fd, claim_fd_inplace},
|
||||
result::OkExt,
|
||||
rustix::{claim_fd, claim_fd_inplace},
|
||||
};
|
||||
|
||||
/// Module containing I/O interest flags for Unix operations (see also: [mio::Interest])
|
||||
@@ -93,7 +93,7 @@ impl UnixStreamExt for UnixStream {
|
||||
fn from_fd(fd: OwnedFd) -> anyhow::Result<Self> {
|
||||
use std::os::unix::net::UnixStream as StdUnixStream;
|
||||
#[cfg(target_os = "linux")] // TODO: We should support this on other plattforms
|
||||
crate::fd::GetUnixSocketType::demand_unix_stream_socket(&fd)?;
|
||||
crate::rustix::GetUnixSocketType::demand_unix_stream_socket(&fd)?;
|
||||
let sock = StdUnixStream::from(fd);
|
||||
sock.set_nonblocking(true)?;
|
||||
UnixStream::from_std(sock).ok()
|
||||
|
||||
@@ -7,7 +7,7 @@ use std::{
|
||||
};
|
||||
use uds::UnixStreamExt as FdPassingExt;
|
||||
|
||||
use crate::fd::{claim_fd_inplace, IntoStdioErr};
|
||||
use crate::rustix::{claim_fd_inplace, IntoStdioErr};
|
||||
|
||||
/// A wrapper around a socket that combines reading from the socket with tracking
|
||||
/// received file descriptors. Limits the maximum number of file descriptors that
|
||||
|
||||
4
util/src/ptr/mod.rs
Normal file
4
util/src/ptr/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
//! Utilities for working with pointers
|
||||
|
||||
mod volatile;
|
||||
pub use volatile::*;
|
||||
51
util/src/ptr/volatile.rs
Normal file
51
util/src/ptr/volatile.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
//! Utilities relating to volatile reads/writes on pointers
|
||||
|
||||
/// Read from a memory location using
|
||||
/// [pointer::read_volatile] in a loop and store the
|
||||
/// results in the given slice
|
||||
pub trait ReadMemVolatile<T> {
|
||||
/// Read from a memory location using
|
||||
/// [pointer::read_volatile] in a loop and store the
|
||||
/// results in an array
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Refer to [pointer::read_volatile]
|
||||
unsafe fn read_mem_volatile(self, dst: &mut [T]);
|
||||
}
|
||||
|
||||
impl<T> ReadMemVolatile<T> for *const T {
|
||||
unsafe fn read_mem_volatile(self, dst: &mut [T]) {
|
||||
for (idx, dst) in dst.iter_mut().enumerate() {
|
||||
*dst = unsafe { self.add(idx).read_volatile() };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ReadMemVolatile<T> for *mut T {
|
||||
unsafe fn read_mem_volatile(self, dst: &mut [T]) {
|
||||
unsafe { self.cast_const().read_mem_volatile(dst) }
|
||||
}
|
||||
}
|
||||
|
||||
/// Write to a memory location using
|
||||
/// [pointer::write_volatile] in a loop
|
||||
/// and store the resulting values in the given slice
|
||||
pub trait WriteMemVolatile<T>: ReadMemVolatile<T> {
|
||||
/// Write to a memory location using
|
||||
/// [pointer::write_volatile] in a loop
|
||||
/// and store the resulting values in the given slice
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Refer to [pointer::write_volatile]
|
||||
unsafe fn write_mem_volatile(self, src: &[T]);
|
||||
}
|
||||
|
||||
impl<T: Copy> WriteMemVolatile<T> for *mut T {
|
||||
unsafe fn write_mem_volatile(self, src: &[T]) {
|
||||
for (idx, src) in src.iter().enumerate() {
|
||||
unsafe { self.add(idx).write_volatile(*src) }
|
||||
}
|
||||
}
|
||||
}
|
||||
3
util/src/ringbuf/concurrend/mod.rs
Normal file
3
util/src/ringbuf/concurrend/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
//! Concurrent ring buffers
|
||||
|
||||
pub mod framework;
|
||||
243
util/src/ringbuf/concurrent/framework.rs
Normal file
243
util/src/ringbuf/concurrent/framework.rs
Normal file
@@ -0,0 +1,243 @@
|
||||
//! An implementation of a concurrent ring buffer with the
|
||||
//! core (IO/memory access) operations in a detached trait. Everything around the core is
|
||||
//! implemented but before being able to use this, callers must implement an appropriate IO core.
|
||||
|
||||
use crate::int::u64uint::U64USizeRangeExt;
|
||||
use crate::result::OkExt;
|
||||
use crate::ringbuf::sched::{
|
||||
Diff, OperationType, RingBufferFromCountersError, RingBufferScheduler, ScheduledOperations,
|
||||
};
|
||||
use crate::sync::atomic::abstract_atomic::AbstractAtomic;
|
||||
|
||||
/// Core trait used by [ConcurrentPipeReader] and [ConcurrentPipeWriter] to implement
|
||||
/// a concurrent pipe based on a ring buffer
|
||||
pub trait ConcurrentPipeCore {
|
||||
/// The type used for the atomics; i.e. the values returned by
|
||||
/// [Self::items_read] and [Self.:items_written].
|
||||
type AtomicType: AbstractAtomic<u64>;
|
||||
|
||||
/// Length of the underlying memory buffer
|
||||
fn buf_len(&self) -> u64;
|
||||
|
||||
/// Number of items read
|
||||
fn items_read(&self) -> &Self::AtomicType;
|
||||
/// Number of items written
|
||||
fn items_written(&self) -> &Self::AtomicType;
|
||||
/// Copy data from the underlying memory buffer into the destination
|
||||
fn read_from_buffer(&mut self, dst: &mut [u8], off: u64);
|
||||
/// Write data from src into the underlying memory buffer
|
||||
fn write_to_buffer(&mut self, off: u64, src: &[u8]);
|
||||
}
|
||||
|
||||
/// Raised by [ConcurrentPipeReader::read()]/[ConcurrentPipeWriter::write()] if the underlying
|
||||
/// ring buffer is in an inconsistent state.
|
||||
///
|
||||
/// When used in shared memory applications, this error may have been locally caused, but it may
|
||||
/// also be caused by **the other threads or processes** putting the ring buffer into an
|
||||
/// inconsistent state.
|
||||
///
|
||||
/// For this reason, you must treat this error as an unrecoverable breakdown of the communication channel. You
|
||||
/// must close and no longer use ring buffer after receiving this error, but you can handle it
|
||||
/// gracefully by reopening a new ring buffer in its place.
|
||||
///
|
||||
/// # API Stability
|
||||
///
|
||||
/// Treat this error type as opaque; do not rely on the enum values remaining the same
|
||||
#[derive(Debug, thiserror::Error, Clone)]
|
||||
pub enum InconsistentRingBufferStateError {
|
||||
/// Failed to update an inner counter; this probably indicates a data race (forbidden
|
||||
/// concurrent read/write)
|
||||
#[error("Concurrent access to the ring buffer {:?}", self)]
|
||||
ConcurrrentAccess {
|
||||
/// Whether this was a read or write
|
||||
operation_type: OperationType,
|
||||
/// Operations performed before updaing the ring buffer state
|
||||
scheduled_ops: ScheduledOperations,
|
||||
/// The scheduler that scheduled our operation
|
||||
scheduler_state: RingBufferScheduler,
|
||||
/// The counter value before compare and exchange
|
||||
expected_counter_value: u64,
|
||||
/// The counter value we actually found in the counter
|
||||
actual_counter_value: u64,
|
||||
/// The counter value we tried to set
|
||||
new_counter_value_tried_to_set: u64,
|
||||
},
|
||||
/// Could not construct ring buffer scheduler
|
||||
#[error("Inconsistent ring buffer state: {:?}", .0)]
|
||||
InconsistentCounterState(#[from] RingBufferFromCountersError),
|
||||
}
|
||||
|
||||
/// Indicator for [ConcurrentPipeImpl::read_or_write] about which operation
|
||||
/// is being executed
|
||||
#[derive(Debug)]
|
||||
enum ConcurrentPipeOperation<'a> {
|
||||
/// This call implements [ConcurrentPipeReader::read]
|
||||
Read(&'a mut [u8]),
|
||||
/// This call implements [ConcurrentPipeReader::write]
|
||||
Write(&'a [u8]),
|
||||
}
|
||||
|
||||
impl<'a> ConcurrentPipeOperation<'a> {
|
||||
/// Read-only access to the operation buffer
|
||||
pub fn inner_buf(&'a self) -> &'a [u8] {
|
||||
match self {
|
||||
ConcurrentPipeOperation::Read(items) => items,
|
||||
ConcurrentPipeOperation::Write(items) => items,
|
||||
}
|
||||
}
|
||||
|
||||
/// Length of [Self::inner_buf]
|
||||
pub fn len(&self) -> usize {
|
||||
self.inner_buf().len()
|
||||
}
|
||||
|
||||
/// Decides which type of operation, in terms of [OperationType] [Self] represents
|
||||
pub fn scheduler_op(&self) -> OperationType {
|
||||
match self {
|
||||
ConcurrentPipeOperation::Read(_) => OperationType::Read,
|
||||
ConcurrentPipeOperation::Write(_) => OperationType::Write,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The implementations of [ConcurrentPipeReader] and [ConcurrentPipeWriter]
|
||||
/// happen to be extremely similar. This struct forms the basis of both.
|
||||
struct ConcurrentPipeImpl<Core: ConcurrentPipeCore> {
|
||||
/// Core trait
|
||||
core: Core,
|
||||
}
|
||||
|
||||
impl<Core: ConcurrentPipeCore> ConcurrentPipeImpl<Core> {
|
||||
/// Like [ConcurrentPipeReader::from_core] and [ConcurrentPipeWriter::from_core]
|
||||
fn from_core(core: Core) -> Self {
|
||||
Self { core }
|
||||
}
|
||||
|
||||
/// The implementations of [ConcurrentPipeReader::read] and [ConcurrentPipeWriter::write]
|
||||
/// happen to be extremely similar. This function implements both.
|
||||
fn read_or_write(
|
||||
&mut self,
|
||||
mut op: ConcurrentPipeOperation,
|
||||
) -> Result<usize, InconsistentRingBufferStateError> {
|
||||
use std::sync::atomic::Ordering as O;
|
||||
|
||||
// Figure out which counter to store the result of the operation in and the orderings to
|
||||
// use for load/store operations
|
||||
let (ord_r, ord_w, ord_store_succ, ord_store_fail) = match op {
|
||||
ConcurrentPipeOperation::Read(_) => (O::Relaxed, O::Acquire, O::Relaxed, O::Relaxed),
|
||||
ConcurrentPipeOperation::Write(_) => (O::Relaxed, O::Relaxed, O::Release, O::Relaxed),
|
||||
};
|
||||
|
||||
// Construct a ring buffer scheduler from the current state
|
||||
let sched = RingBufferScheduler::try_from_counters(
|
||||
self.core.buf_len(),
|
||||
self.core.items_written().load(ord_w),
|
||||
self.core.items_read().load(ord_r),
|
||||
)?;
|
||||
|
||||
// Have the scheduler schedule the operations
|
||||
let ops = sched.schedule_contigous_operations(op.scheduler_op(), op.len());
|
||||
|
||||
// Actually perform the operations
|
||||
for (buf_slice, ring_op) in ops.with_outside_buffer_range() {
|
||||
match &mut op {
|
||||
ConcurrentPipeOperation::Read(dst) => self
|
||||
.core
|
||||
.read_from_buffer(&mut dst[buf_slice.usize()], ring_op.off),
|
||||
ConcurrentPipeOperation::Write(src) => self
|
||||
.core
|
||||
.write_to_buffer(ring_op.off, &src[buf_slice.usize()]),
|
||||
}
|
||||
}
|
||||
|
||||
// Take a differential between the scheduler and the scheduler after the operations where
|
||||
// applied. Make sure only one counter was updated
|
||||
let diff = sched
|
||||
.register_operation(&ops)
|
||||
.diff_old(&sched)
|
||||
.expect_op_only(op.scheduler_op())
|
||||
.unwrap();
|
||||
|
||||
let store_to_ctr = match op {
|
||||
ConcurrentPipeOperation::Read(_) => self.core.items_read(),
|
||||
ConcurrentPipeOperation::Write(_) => self.core.items_written(),
|
||||
};
|
||||
|
||||
// Update the counters, assuming there was any change at all
|
||||
if let Diff::Different(old, new) = diff {
|
||||
store_to_ctr
|
||||
.compare_exchange(old, new, ord_store_succ, ord_store_fail)
|
||||
.map_err(
|
||||
|actual| InconsistentRingBufferStateError::ConcurrrentAccess {
|
||||
operation_type: op.scheduler_op(),
|
||||
scheduled_ops: ops,
|
||||
scheduler_state: sched,
|
||||
expected_counter_value: old,
|
||||
actual_counter_value: actual,
|
||||
new_counter_value_tried_to_set: new,
|
||||
},
|
||||
)?;
|
||||
}
|
||||
|
||||
ops.cumulative_operation_length().usize().ok()
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides the necessary boilerplate around [ConcurrentPipeCore] to implement
|
||||
/// reading from the pipe
|
||||
pub struct ConcurrentPipeReader<Core: ConcurrentPipeCore> {
|
||||
/// The implementations of [ConcurrentPipeReader::read] and [ConcurrentPipeWriter::write]
|
||||
/// happen to be extremely similar. We use [ConcurrentPipeImpl] to implement both.
|
||||
inner: ConcurrentPipeImpl<Core>,
|
||||
}
|
||||
|
||||
impl<Core: ConcurrentPipeCore> ConcurrentPipeWriter<Core> {
|
||||
/// Create a [Self] from a [ConcurrentPipeCore]
|
||||
pub fn from_core(core: Core) -> Self {
|
||||
Self {
|
||||
inner: ConcurrentPipeImpl::from_core(core),
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine the length of the underlying ring buffer
|
||||
pub fn buf_len(&self) -> u64 {
|
||||
self.inner.core.buf_len()
|
||||
}
|
||||
|
||||
/// Write data into the concurrent pipe.
|
||||
///
|
||||
/// Returns the number of bytes actually written.
|
||||
pub fn write(&mut self, src: &[u8]) -> Result<usize, InconsistentRingBufferStateError> {
|
||||
self.inner
|
||||
.read_or_write(ConcurrentPipeOperation::Write(src))
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides the necessary boilerplate around [ConcurrentPipeCore] to implement
|
||||
/// writing to the pipe
|
||||
pub struct ConcurrentPipeWriter<Core: ConcurrentPipeCore> {
|
||||
/// The implementations of [ConcurrentPipeReader::read] and [ConcurrentPipeWriter::write]
|
||||
/// happen to be extremely similar. We use [ConcurrentPipeImpl] to implement both.
|
||||
inner: ConcurrentPipeImpl<Core>,
|
||||
}
|
||||
|
||||
impl<Core: ConcurrentPipeCore> ConcurrentPipeReader<Core> {
|
||||
/// Create a [Self] from a [ConcurrentPipeCore]
|
||||
pub fn from_core(core: Core) -> Self {
|
||||
Self {
|
||||
inner: ConcurrentPipeImpl::from_core(core),
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine the length of the underlying ring buffer
|
||||
pub fn buf_len(&self) -> u64 {
|
||||
self.inner.core.buf_len()
|
||||
}
|
||||
|
||||
/// Read data from the concurrent pipe.
|
||||
///
|
||||
/// Returns the number of bytes read.
|
||||
pub fn read(&mut self, dst: &mut [u8]) -> Result<usize, InconsistentRingBufferStateError> {
|
||||
self.inner.read_or_write(ConcurrentPipeOperation::Read(dst))
|
||||
}
|
||||
}
|
||||
3
util/src/ringbuf/concurrent/mod.rs
Normal file
3
util/src/ringbuf/concurrent/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
//! Concurrent ring buffers
|
||||
|
||||
pub mod framework;
|
||||
4
util/src/ringbuf/mod.rs
Normal file
4
util/src/ringbuf/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
//! Ring buffers and utilities for working with them
|
||||
|
||||
pub mod concurrent;
|
||||
pub mod sched;
|
||||
1410
util/src/ringbuf/sched.rs
Normal file
1410
util/src/ringbuf/sched.rs
Normal file
File diff suppressed because it is too large
Load Diff
117
util/src/rustix/error.rs
Normal file
117
util/src/rustix/error.rs
Normal file
@@ -0,0 +1,117 @@
|
||||
//! Rustix extensions for error handling
|
||||
|
||||
/// Provides access to the last system error number
|
||||
///
|
||||
/// > The integer variable errno is set by system calls and some library functions in the event of an error to indicate what went wrong.
|
||||
///
|
||||
/// -- `man 3 errno`
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This function panics if ther
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
///
|
||||
/// use rustix::io::Errno as E;
|
||||
/// use rosenpass_util::rustix::{errno, try_errno, last_os_result};
|
||||
///
|
||||
/// let res = unsafe { libc::mkdir(c"/tmp/baz".as_ptr(), 0) };
|
||||
/// assert_eq!(res, -1);
|
||||
/// assert_eq!(errno(), E::EXIST);
|
||||
/// assert_eq!(try_errno(), Some(E::EXIST));
|
||||
/// assert_eq!(last_os_result(), Err(E::EXIST));
|
||||
///
|
||||
/// // Deliberately clear the system error
|
||||
/// unsafe { libc::__errno_location().write(0) };
|
||||
/// // assert_eq!(errno(), _); // PANICS
|
||||
/// assert_eq!(try_errno(), None);
|
||||
/// assert_eq!(last_os_result(), Ok(()));
|
||||
/// ```
|
||||
///
|
||||
/// Calling errno() when there is no error causes a panic:
|
||||
///
|
||||
/// ```rust,should_panic
|
||||
///
|
||||
/// use rustix::io::Errno as E;
|
||||
/// use rosenpass_util::rustix::errno;
|
||||
///
|
||||
/// // Deliberately clear the system error
|
||||
/// unsafe { libc::__errno_location().write(0) };
|
||||
/// errno(); // PANICS
|
||||
/// ```
|
||||
pub fn errno() -> rustix::io::Errno {
|
||||
match try_errno() {
|
||||
None => panic!("Tried to retrieve last system error, but there was no system error (the system error number, errno = 0)"),
|
||||
Some(errno) => errno,
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides access to the last system error number.
|
||||
///
|
||||
/// Variant of [errno()] that will return None if there was no system error.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// See [errno()].
|
||||
pub fn try_errno() -> Option<rustix::io::Errno> {
|
||||
let raw = unsafe { libc::__errno_location().read() };
|
||||
match raw {
|
||||
0 => None,
|
||||
_ => Some(rustix::io::Errno::from_raw_os_error(raw)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides access to the last system error number.
|
||||
///
|
||||
/// Variant of [errno()] that will return `Err(errno)` if there
|
||||
/// was a system error and `Ok(())` otherwise.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// See [errno()].
|
||||
pub fn last_os_result() -> Result<(), rustix::io::Errno> {
|
||||
match try_errno() {
|
||||
None => Ok(()),
|
||||
Some(errno) => Err(errno),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert low level errors into std::io::Error
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use std::io::ErrorKind as EK;
|
||||
/// use rustix::io::Errno;
|
||||
/// use rosenpass_util::rustix::IntoStdioErr;
|
||||
///
|
||||
/// let e = Errno::INTR.into_stdio_err();
|
||||
/// assert!(matches!(e.kind(), EK::Interrupted));
|
||||
///
|
||||
/// let r : rustix::io::Result<()> = Err(Errno::INTR);
|
||||
/// assert!(matches!(r, Err(e) if e.kind() == EK::Interrupted));
|
||||
/// ```
|
||||
pub trait IntoStdioErr {
|
||||
/// Target type produced (e.g. std::io:Error or std::io::Result depending on context
|
||||
type Target;
|
||||
/// Convert low level errors to
|
||||
fn into_stdio_err(self) -> Self::Target;
|
||||
}
|
||||
|
||||
impl IntoStdioErr for rustix::io::Errno {
|
||||
type Target = std::io::Error;
|
||||
|
||||
fn into_stdio_err(self) -> Self::Target {
|
||||
std::io::Error::from_raw_os_error(self.raw_os_error())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> IntoStdioErr for rustix::io::Result<T> {
|
||||
type Target = std::io::Result<T>;
|
||||
|
||||
fn into_stdio_err(self) -> Self::Target {
|
||||
self.map_err(IntoStdioErr::into_stdio_err)
|
||||
}
|
||||
}
|
||||
255
util/src/rustix/fd.rs
Normal file
255
util/src/rustix/fd.rs
Normal file
@@ -0,0 +1,255 @@
|
||||
//! Basic utilities for working with file descriptors
|
||||
|
||||
use std::os::fd::{AsFd, BorrowedFd, FromRawFd, OwnedFd, RawFd};
|
||||
|
||||
use rustix::io::fcntl_dupfd_cloexec;
|
||||
|
||||
use super::IntoStdioErr;
|
||||
|
||||
use crate::mem::Forgetting;
|
||||
|
||||
/// Prepare a file descriptor for use in Rust code.
|
||||
///
|
||||
/// Checks if the file descriptor is valid and duplicates it to a new file descriptor.
|
||||
/// The old file descriptor is masked to avoid potential use after free (on file descriptor)
|
||||
/// in case the given file descriptor is still used somewhere
|
||||
///
|
||||
/// # Panic and safety
|
||||
///
|
||||
/// Will panic if the given file descriptor is negative of or larger than
|
||||
/// the file descriptor numbers permitted by the operating system.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use std::io::Write;
|
||||
/// use std::os::fd::{IntoRawFd, AsRawFd};
|
||||
/// use tempfile::tempdir;
|
||||
/// use rosenpass_util::rustix::{claim_fd, FdIo};
|
||||
///
|
||||
/// // Open a file and turn it into a raw file descriptor
|
||||
/// let orig = tempfile::tempfile()?.into_raw_fd();
|
||||
///
|
||||
/// // Reclaim that file and ready it for reading
|
||||
/// let mut claimed = FdIo(claim_fd(orig)?);
|
||||
///
|
||||
/// // A different file descriptor is used
|
||||
/// assert!(orig.as_raw_fd() != claimed.0.as_raw_fd());
|
||||
///
|
||||
/// // Write some data
|
||||
/// claimed.write_all(b"Hello, World!")?;
|
||||
///
|
||||
/// Ok::<(), std::io::Error>(())
|
||||
/// ```
|
||||
pub fn claim_fd(fd: RawFd) -> rustix::io::Result<OwnedFd> {
|
||||
let new = clone_fd_cloexec(unsafe { BorrowedFd::borrow_raw(fd) })?;
|
||||
mask_fd(fd)?;
|
||||
Ok(new)
|
||||
}
|
||||
|
||||
/// Prepare a file descriptor for use in Rust code.
|
||||
///
|
||||
/// Checks if the file descriptor is valid.
|
||||
///
|
||||
/// Unlike [claim_fd], this will try to reuse the same file descriptor identifier instead of masking it.
|
||||
///
|
||||
/// # Panic and safety
|
||||
///
|
||||
/// Will panic if the given file descriptor is negative of or larger than
|
||||
/// the file descriptor numbers permitted by the operating system.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use std::io::Write;
|
||||
/// use std::os::fd::IntoRawFd;
|
||||
/// use tempfile::tempdir;
|
||||
/// use rosenpass_util::rustix::{claim_fd_inplace, FdIo};
|
||||
///
|
||||
/// // Open a file and turn it into a raw file descriptor
|
||||
/// let fd = tempfile::tempfile()?.into_raw_fd();
|
||||
///
|
||||
/// // Reclaim that file and ready it for reading
|
||||
/// let mut fd = FdIo(claim_fd_inplace(fd)?);
|
||||
///
|
||||
/// // Write some data
|
||||
/// fd.write_all(b"Hello, World!")?;
|
||||
///
|
||||
/// Ok::<(), std::io::Error>(())
|
||||
/// ```
|
||||
pub fn claim_fd_inplace(fd: RawFd) -> rustix::io::Result<OwnedFd> {
|
||||
let mut new = unsafe { OwnedFd::from_raw_fd(fd) };
|
||||
let tmp = clone_fd_cloexec(&new)?;
|
||||
clone_fd_to_cloexec(tmp, &mut new)?;
|
||||
Ok(new)
|
||||
}
|
||||
|
||||
/// Will close the given file descriptor and overwrite
|
||||
/// it with a masking file descriptor (see [open_nullfd]) to prevent accidental reuse.
|
||||
///
|
||||
/// # Panic and safety
|
||||
///
|
||||
/// Will panic if the given file descriptor is negative of or larger than
|
||||
/// the file descriptor numbers permitted by the operating system.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # use std::fs::File;
|
||||
/// # use std::io::Read;
|
||||
/// # use std::os::unix::io::{AsRawFd, FromRawFd};
|
||||
/// # use std::os::fd::IntoRawFd;
|
||||
/// # use rustix::fd::AsFd;
|
||||
/// # use rosenpass_util::rustix::mask_fd;
|
||||
///
|
||||
/// // Open a temporary file
|
||||
/// let fd = tempfile::tempfile().unwrap().into_raw_fd();
|
||||
/// assert!(fd >= 0);
|
||||
///
|
||||
/// // Mask the file descriptor
|
||||
/// mask_fd(fd).unwrap();
|
||||
///
|
||||
/// // Verify the file descriptor now points to `/dev/null`
|
||||
/// // Reading from `/dev/null` always returns 0 bytes
|
||||
/// let mut replaced_file = unsafe { File::from_raw_fd(fd) };
|
||||
/// let mut buffer = [0u8; 4];
|
||||
/// let bytes_read = replaced_file.read(&mut buffer).unwrap();
|
||||
/// assert_eq!(bytes_read, 0);
|
||||
/// ```
|
||||
pub fn mask_fd(fd: RawFd) -> rustix::io::Result<()> {
|
||||
// Safety: because the OwnedFd resulting from OwnedFd::from_raw_fd is wrapped in a Forgetting,
|
||||
// it never gets dropped, meaning that fd is never closed and thus outlives the OwnedFd
|
||||
let mut owned = Forgetting::new(unsafe { OwnedFd::from_raw_fd(fd) });
|
||||
clone_fd_to_cloexec(open_nullfd()?, &mut owned)
|
||||
}
|
||||
|
||||
/// Duplicate a file descriptor, setting the close on exec flag
|
||||
pub fn clone_fd_cloexec<Fd: AsFd>(fd: Fd) -> rustix::io::Result<OwnedFd> {
|
||||
/// Avoid stdin, stdout, and stderr
|
||||
const MINFD: RawFd = 3;
|
||||
fcntl_dupfd_cloexec(fd, MINFD)
|
||||
}
|
||||
|
||||
/// Duplicate a file descriptor, setting the close on exec flag.
|
||||
///
|
||||
/// This is slightly different from [clone_fd_cloexec], as this function supports specifying an
|
||||
/// explicit destination file descriptor.
|
||||
#[cfg(target_os = "linux")]
|
||||
pub fn clone_fd_to_cloexec<Fd: AsFd>(fd: Fd, new: &mut OwnedFd) -> rustix::io::Result<()> {
|
||||
use rustix::io::{dup3, DupFlags};
|
||||
dup3(fd, new, DupFlags::CLOEXEC)
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
/// Duplicate a file descriptor, setting the close on exec flag.
|
||||
///
|
||||
/// This is slightly different from [clone_fd_cloexec], as this function supports specifying an
|
||||
/// explicit destination file descriptor.
|
||||
pub fn clone_fd_to_cloexec<Fd: AsFd>(fd: Fd, new: &mut OwnedFd) -> rustix::io::Result<()> {
|
||||
use rustix::io::{dup2, fcntl_setfd, FdFlags};
|
||||
dup2(&fd, new)?;
|
||||
fcntl_setfd(&new, FdFlags::CLOEXEC)
|
||||
}
|
||||
|
||||
/// Open a "blocked" file descriptor. I.e. a file descriptor that is neither meant for reading nor
|
||||
/// writing.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The behavior of the file descriptor when being written to or from is undefined.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use std::{fs::File, io::Write, os::fd::IntoRawFd};
|
||||
/// use rustix::fd::FromRawFd;
|
||||
/// use rosenpass_util::rustix::open_nullfd;
|
||||
///
|
||||
/// let nullfd = open_nullfd().unwrap();
|
||||
/// ```
|
||||
pub fn open_nullfd() -> rustix::io::Result<OwnedFd> {
|
||||
use rustix::fs::{open, Mode, OFlags};
|
||||
// TODO: Add tests showing that this will throw errors on use
|
||||
open("/dev/null", OFlags::CLOEXEC, Mode::empty())
|
||||
}
|
||||
|
||||
/// Read and write directly from a file descriptor
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// See [claim_fd].
|
||||
pub struct FdIo<Fd: AsFd>(pub Fd);
|
||||
|
||||
impl<Fd: AsFd> std::io::Read for FdIo<Fd> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
rustix::io::read(&self.0, buf).into_stdio_err()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Fd: AsFd> std::io::Write for FdIo<Fd> {
|
||||
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
rustix::io::write(&self.0, buf).into_stdio_err()
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> std::io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::{Read, Write};
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_claim_fd_invalid_neg() {
|
||||
let _ = claim_fd(-1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_claim_fd_invalid_max() {
|
||||
let _ = claim_fd(i64::MAX as RawFd);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_claim_fd_inplace_invalid_neg() {
|
||||
let _ = claim_fd_inplace(-1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_claim_fd_inplace_invalid_max() {
|
||||
let _ = claim_fd_inplace(i64::MAX as RawFd);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_mask_fd_invalid_neg() {
|
||||
let _ = mask_fd(-1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_mask_fd_invalid_max() {
|
||||
let _ = mask_fd(i64::MAX as RawFd);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_open_nullfd() -> anyhow::Result<()> {
|
||||
let mut file = FdIo(open_nullfd()?);
|
||||
let mut buf = [0; 10];
|
||||
assert!(matches!(file.read(&mut buf), Ok(0) | Err(_)));
|
||||
assert!(file.write(&buf).is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nullfd_read_write() {
|
||||
let nullfd = open_nullfd().unwrap();
|
||||
let mut buf = vec![0u8; 16];
|
||||
assert_eq!(rustix::io::read(&nullfd, &mut buf).unwrap(), 0);
|
||||
assert!(rustix::io::write(&nullfd, b"test").is_err());
|
||||
}
|
||||
}
|
||||
88
util/src/rustix/memfd.rs
Normal file
88
util/src/rustix/memfd.rs
Normal file
@@ -0,0 +1,88 @@
|
||||
//! Utilities for working with memory based file descriptors
|
||||
|
||||
use std::os::fd::OwnedFd;
|
||||
|
||||
use rustix::fs::MemfdFlags;
|
||||
use rustix::io::Errno;
|
||||
use rustix::path::Arg as Path;
|
||||
|
||||
use bitflags::bitflags;
|
||||
|
||||
use crate::convert::IntoTypeExt;
|
||||
|
||||
use super::SyscallResult;
|
||||
|
||||
/// Create an anonymous file
|
||||
/// using the memfd_create(2) syscall
|
||||
///
|
||||
/// Just forwards to [rustix::fs::memfd_create]
|
||||
pub fn memfd_create<P: Path>(name: P, flags: MemfdFlags) -> rustix::io::Result<OwnedFd> {
|
||||
rustix::fs::memfd_create(name, flags)
|
||||
}
|
||||
|
||||
bitflags! {
|
||||
/// `FD_*` constants for use with [memfd_secret].
|
||||
#[repr(transparent)]
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
|
||||
pub struct MemfdSecretFlags: std::ffi::c_uint {
|
||||
/// FD_CLOEXEC
|
||||
const CLOEXEC = libc::FD_CLOEXEC as std::ffi::c_uint;
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors for [create_memfd_secret()]
|
||||
#[derive(Copy, Clone, Debug, thiserror::Error)]
|
||||
pub enum MemfdSecretError {
|
||||
/// memfd_secret(2) not supported on system
|
||||
#[error("Could not create secret memory segment using memfd_secret(2): not supported on your system")]
|
||||
NotSupported,
|
||||
/// Other error
|
||||
#[error("Could not create secret memory segment using memfd_secret(2): underlying system error: {:?}", .0)]
|
||||
SystemError(Errno),
|
||||
}
|
||||
|
||||
impl From<Errno> for MemfdSecretError {
|
||||
fn from(value: Errno) -> Self {
|
||||
match value {
|
||||
Errno::NOSYS => Self::NotSupported,
|
||||
e => Self::SystemError(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an anonymous RAM-based file to access secret memory regions
|
||||
/// using the memfd_secret(2) syscall
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use rustix::io::Errno;
|
||||
/// use rustix::fs::ftruncate;
|
||||
///
|
||||
/// use rosenpass_util::rustix::{memfd_secret, MemfdSecretFlags, IntoStdioErr, MemfdSecretError};
|
||||
/// use rosenpass_util::io::handle_interrupted;
|
||||
///
|
||||
/// let res = memfd_secret(MemfdSecretFlags::empty());
|
||||
///
|
||||
/// use MemfdSecretError as E;
|
||||
/// let fd = match res {
|
||||
/// Ok(fd) => fd,
|
||||
/// // The system might not have memfd_secret enabled; abort the test
|
||||
/// Err(E::NotSupported) => return Ok(()),
|
||||
/// Err(E::SystemError(err)) => return Err(err)?,
|
||||
/// };
|
||||
///
|
||||
/// handle_interrupted(|| { ftruncate(&fd, 8192).into_stdio_err() })?;
|
||||
///
|
||||
/// Ok::<(), anyhow::Error>(())
|
||||
/// ```
|
||||
pub fn memfd_secret(flags: MemfdSecretFlags) -> Result<rustix::fd::OwnedFd, MemfdSecretError> {
|
||||
let res = unsafe {
|
||||
use libc::{syscall, SYS_memfd_secret};
|
||||
syscall(SYS_memfd_secret, flags)
|
||||
.into_type::<SyscallResult>()
|
||||
.claim_fd()
|
||||
};
|
||||
|
||||
res.map_err(MemfdSecretError::from)
|
||||
}
|
||||
16
util/src/rustix/mod.rs
Normal file
16
util/src/rustix/mod.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
//! Extensions to the rustix crate for memory safe operating system interfaces
|
||||
|
||||
mod error;
|
||||
pub use error::*;
|
||||
|
||||
mod fd;
|
||||
pub use fd::*;
|
||||
|
||||
mod stat;
|
||||
pub use stat::*;
|
||||
|
||||
mod syscall;
|
||||
pub use syscall::*;
|
||||
|
||||
mod memfd;
|
||||
pub use memfd::*;
|
||||
@@ -1,235 +1,12 @@
|
||||
//! Utilities for working with file descriptors
|
||||
//! Rustix extensions for getting information about file descriptors`
|
||||
|
||||
use std::os::fd::AsFd;
|
||||
|
||||
use anyhow::bail;
|
||||
use rustix::io::fcntl_dupfd_cloexec;
|
||||
use std::os::fd::{AsFd, BorrowedFd, FromRawFd, OwnedFd, RawFd};
|
||||
|
||||
use crate::{mem::Forgetting, result::OkExt};
|
||||
use super::IntoStdioErr;
|
||||
|
||||
/// Prepare a file descriptor for use in Rust code.
|
||||
///
|
||||
/// Checks if the file descriptor is valid and duplicates it to a new file descriptor.
|
||||
/// The old file descriptor is masked to avoid potential use after free (on file descriptor)
|
||||
/// in case the given file descriptor is still used somewhere
|
||||
///
|
||||
/// # Panic and safety
|
||||
///
|
||||
/// Will panic if the given file descriptor is negative of or larger than
|
||||
/// the file descriptor numbers permitted by the operating system.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use std::io::Write;
|
||||
/// use std::os::fd::{IntoRawFd, AsRawFd};
|
||||
/// use tempfile::tempdir;
|
||||
/// use rosenpass_util::fd::{claim_fd, FdIo};
|
||||
///
|
||||
/// // Open a file and turn it into a raw file descriptor
|
||||
/// let orig = tempfile::tempfile()?.into_raw_fd();
|
||||
///
|
||||
/// // Reclaim that file and ready it for reading
|
||||
/// let mut claimed = FdIo(claim_fd(orig)?);
|
||||
///
|
||||
/// // A different file descriptor is used
|
||||
/// assert!(orig.as_raw_fd() != claimed.0.as_raw_fd());
|
||||
///
|
||||
/// // Write some data
|
||||
/// claimed.write_all(b"Hello, World!")?;
|
||||
///
|
||||
/// Ok::<(), std::io::Error>(())
|
||||
/// ```
|
||||
pub fn claim_fd(fd: RawFd) -> rustix::io::Result<OwnedFd> {
|
||||
let new = clone_fd_cloexec(unsafe { BorrowedFd::borrow_raw(fd) })?;
|
||||
mask_fd(fd)?;
|
||||
Ok(new)
|
||||
}
|
||||
|
||||
/// Prepare a file descriptor for use in Rust code.
|
||||
///
|
||||
/// Checks if the file descriptor is valid.
|
||||
///
|
||||
/// Unlike [claim_fd], this will try to reuse the same file descriptor identifier instead of masking it.
|
||||
///
|
||||
/// # Panic and safety
|
||||
///
|
||||
/// Will panic if the given file descriptor is negative of or larger than
|
||||
/// the file descriptor numbers permitted by the operating system.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use std::io::Write;
|
||||
/// use std::os::fd::IntoRawFd;
|
||||
/// use tempfile::tempdir;
|
||||
/// use rosenpass_util::fd::{claim_fd_inplace, FdIo};
|
||||
///
|
||||
/// // Open a file and turn it into a raw file descriptor
|
||||
/// let fd = tempfile::tempfile()?.into_raw_fd();
|
||||
///
|
||||
/// // Reclaim that file and ready it for reading
|
||||
/// let mut fd = FdIo(claim_fd_inplace(fd)?);
|
||||
///
|
||||
/// // Write some data
|
||||
/// fd.write_all(b"Hello, World!")?;
|
||||
///
|
||||
/// Ok::<(), std::io::Error>(())
|
||||
/// ```
|
||||
pub fn claim_fd_inplace(fd: RawFd) -> rustix::io::Result<OwnedFd> {
|
||||
let mut new = unsafe { OwnedFd::from_raw_fd(fd) };
|
||||
let tmp = clone_fd_cloexec(&new)?;
|
||||
clone_fd_to_cloexec(tmp, &mut new)?;
|
||||
Ok(new)
|
||||
}
|
||||
|
||||
/// Will close the given file descriptor and overwrite
|
||||
/// it with a masking file descriptor (see [open_nullfd]) to prevent accidental reuse.
|
||||
///
|
||||
/// # Panic and safety
|
||||
///
|
||||
/// Will panic if the given file descriptor is negative of or larger than
|
||||
/// the file descriptor numbers permitted by the operating system.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # use std::fs::File;
|
||||
/// # use std::io::Read;
|
||||
/// # use std::os::unix::io::{AsRawFd, FromRawFd};
|
||||
/// # use std::os::fd::IntoRawFd;
|
||||
/// # use rustix::fd::AsFd;
|
||||
/// # use rosenpass_util::fd::mask_fd;
|
||||
///
|
||||
/// // Open a temporary file
|
||||
/// let fd = tempfile::tempfile().unwrap().into_raw_fd();
|
||||
/// assert!(fd >= 0);
|
||||
///
|
||||
/// // Mask the file descriptor
|
||||
/// mask_fd(fd).unwrap();
|
||||
///
|
||||
/// // Verify the file descriptor now points to `/dev/null`
|
||||
/// // Reading from `/dev/null` always returns 0 bytes
|
||||
/// let mut replaced_file = unsafe { File::from_raw_fd(fd) };
|
||||
/// let mut buffer = [0u8; 4];
|
||||
/// let bytes_read = replaced_file.read(&mut buffer).unwrap();
|
||||
/// assert_eq!(bytes_read, 0);
|
||||
/// ```
|
||||
pub fn mask_fd(fd: RawFd) -> rustix::io::Result<()> {
|
||||
// Safety: because the OwnedFd resulting from OwnedFd::from_raw_fd is wrapped in a Forgetting,
|
||||
// it never gets dropped, meaning that fd is never closed and thus outlives the OwnedFd
|
||||
let mut owned = Forgetting::new(unsafe { OwnedFd::from_raw_fd(fd) });
|
||||
clone_fd_to_cloexec(open_nullfd()?, &mut owned)
|
||||
}
|
||||
|
||||
/// Duplicate a file descriptor, setting the close on exec flag
|
||||
pub fn clone_fd_cloexec<Fd: AsFd>(fd: Fd) -> rustix::io::Result<OwnedFd> {
|
||||
/// Avoid stdin, stdout, and stderr
|
||||
const MINFD: RawFd = 3;
|
||||
fcntl_dupfd_cloexec(fd, MINFD)
|
||||
}
|
||||
|
||||
/// Duplicate a file descriptor, setting the close on exec flag.
|
||||
///
|
||||
/// This is slightly different from [clone_fd_cloexec], as this function supports specifying an
|
||||
/// explicit destination file descriptor.
|
||||
#[cfg(target_os = "linux")]
|
||||
pub fn clone_fd_to_cloexec<Fd: AsFd>(fd: Fd, new: &mut OwnedFd) -> rustix::io::Result<()> {
|
||||
use rustix::io::{dup3, DupFlags};
|
||||
dup3(fd, new, DupFlags::CLOEXEC)
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
/// Duplicate a file descriptor, setting the close on exec flag.
|
||||
///
|
||||
/// This is slightly different from [clone_fd_cloexec], as this function supports specifying an
|
||||
/// explicit destination file descriptor.
|
||||
pub fn clone_fd_to_cloexec<Fd: AsFd>(fd: Fd, new: &mut OwnedFd) -> rustix::io::Result<()> {
|
||||
use rustix::io::{dup2, fcntl_setfd, FdFlags};
|
||||
dup2(&fd, new)?;
|
||||
fcntl_setfd(&new, FdFlags::CLOEXEC)
|
||||
}
|
||||
|
||||
/// Open a "blocked" file descriptor. I.e. a file descriptor that is neither meant for reading nor
|
||||
/// writing.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The behavior of the file descriptor when being written to or from is undefined.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use std::{fs::File, io::Write, os::fd::IntoRawFd};
|
||||
/// use rustix::fd::FromRawFd;
|
||||
/// use rosenpass_util::fd::open_nullfd;
|
||||
///
|
||||
/// let nullfd = open_nullfd().unwrap();
|
||||
/// ```
|
||||
pub fn open_nullfd() -> rustix::io::Result<OwnedFd> {
|
||||
use rustix::fs::{open, Mode, OFlags};
|
||||
// TODO: Add tests showing that this will throw errors on use
|
||||
open("/dev/null", OFlags::CLOEXEC, Mode::empty())
|
||||
}
|
||||
|
||||
/// Convert low level errors into std::io::Error
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use std::io::ErrorKind as EK;
|
||||
/// use rustix::io::Errno;
|
||||
/// use rosenpass_util::fd::IntoStdioErr;
|
||||
///
|
||||
/// let e = Errno::INTR.into_stdio_err();
|
||||
/// assert!(matches!(e.kind(), EK::Interrupted));
|
||||
///
|
||||
/// let r : rustix::io::Result<()> = Err(Errno::INTR);
|
||||
/// assert!(matches!(r, Err(e) if e.kind() == EK::Interrupted));
|
||||
/// ```
|
||||
pub trait IntoStdioErr {
|
||||
/// Target type produced (e.g. std::io:Error or std::io::Result depending on context
|
||||
type Target;
|
||||
/// Convert low level errors to
|
||||
fn into_stdio_err(self) -> Self::Target;
|
||||
}
|
||||
|
||||
impl IntoStdioErr for rustix::io::Errno {
|
||||
type Target = std::io::Error;
|
||||
|
||||
fn into_stdio_err(self) -> Self::Target {
|
||||
std::io::Error::from_raw_os_error(self.raw_os_error())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> IntoStdioErr for rustix::io::Result<T> {
|
||||
type Target = std::io::Result<T>;
|
||||
|
||||
fn into_stdio_err(self) -> Self::Target {
|
||||
self.map_err(IntoStdioErr::into_stdio_err)
|
||||
}
|
||||
}
|
||||
|
||||
/// Read and write directly from a file descriptor
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// See [claim_fd].
|
||||
pub struct FdIo<Fd: AsFd>(pub Fd);
|
||||
|
||||
impl<Fd: AsFd> std::io::Read for FdIo<Fd> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
rustix::io::read(&self.0, buf).into_stdio_err()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Fd: AsFd> std::io::Write for FdIo<Fd> {
|
||||
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
rustix::io::write(&self.0, buf).into_stdio_err()
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> std::io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
use crate::result::OkExt;
|
||||
|
||||
/// Helpers for accessing stat(2) information
|
||||
pub trait StatExt {
|
||||
@@ -238,7 +15,7 @@ pub trait StatExt {
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use rosenpass_util::fd::StatExt;
|
||||
/// use rosenpass_util::rustix::StatExt;
|
||||
/// assert!(rustix::fs::stat("/")?.is_socket() == false);
|
||||
/// Ok::<(), rustix::io::Errno>(())
|
||||
/// ````
|
||||
@@ -263,7 +40,7 @@ pub trait TryStatExt {
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use rosenpass_util::fd::TryStatExt;
|
||||
/// use rosenpass_util::rustix::TryStatExt;
|
||||
/// let fd = rustix::fs::open("/", rustix::fs::OFlags::empty(), rustix::fs::Mode::empty())?;
|
||||
/// assert!(matches!(fd.is_socket(), Ok(false)));
|
||||
/// Ok::<(), rustix::io::Errno>(())
|
||||
@@ -358,7 +135,7 @@ pub trait GetUnixSocketType {
|
||||
/// # use std::os::fd::{AsFd, BorrowedFd};
|
||||
/// # use std::os::unix::net::UnixListener;
|
||||
/// # use tempfile::NamedTempFile;
|
||||
/// # use rosenpass_util::fd::GetUnixSocketType;
|
||||
/// # use rosenpass_util::rustix::GetUnixSocketType;
|
||||
/// let f = {
|
||||
/// // Generate a temp file and take its path
|
||||
/// // Remove the temp file
|
||||
@@ -378,7 +155,7 @@ pub trait GetUnixSocketType {
|
||||
/// # use std::os::fd::{AsFd, BorrowedFd};
|
||||
/// # use std::os::unix::net::{UnixDatagram, UnixListener};
|
||||
/// # use tempfile::NamedTempFile;
|
||||
/// # use rosenpass_util::fd::GetUnixSocketType;
|
||||
/// # use rosenpass_util::rustix::GetUnixSocketType;
|
||||
/// let f = {
|
||||
/// // Generate a temp file and take its path
|
||||
/// // Remove the temp file
|
||||
@@ -445,7 +222,7 @@ pub trait GetSocketProtocol {
|
||||
/// ```
|
||||
/// # use std::net::UdpSocket;
|
||||
/// # use std::os::fd::{AsFd, AsRawFd};
|
||||
/// # use rosenpass_util::fd::GetSocketProtocol;
|
||||
/// # use rosenpass_util::rustix::GetSocketProtocol;
|
||||
/// let socket = UdpSocket::bind("127.0.0.1:0")?;
|
||||
/// assert_eq!(socket.as_fd().socket_protocol().unwrap().unwrap(), rustix::net::ipproto::UDP);
|
||||
/// # Ok::<(), std::io::Error>(())
|
||||
@@ -458,7 +235,7 @@ pub trait GetSocketProtocol {
|
||||
/// # use std::net::UdpSocket;
|
||||
/// # use std::net::TcpListener;
|
||||
/// # use std::os::fd::{AsFd, AsRawFd};
|
||||
/// # use rosenpass_util::fd::GetSocketProtocol;
|
||||
/// # use rosenpass_util::rustix::GetSocketProtocol;
|
||||
/// let socket = UdpSocket::bind("127.0.0.1:0")?;
|
||||
/// assert!(socket.as_fd().is_udp_socket().unwrap());
|
||||
///
|
||||
@@ -484,7 +261,7 @@ pub trait GetSocketProtocol {
|
||||
/// # use std::net::UdpSocket;
|
||||
/// # use std::net::TcpListener;
|
||||
/// # use std::os::fd::{AsFd, AsRawFd};
|
||||
/// # use rosenpass_util::fd::GetSocketProtocol;
|
||||
/// # use rosenpass_util::rustix::GetSocketProtocol;
|
||||
/// let socket = UdpSocket::bind("127.0.0.1:0")?;
|
||||
/// assert!(matches!(socket.as_fd().demand_udp_socket(), Ok(())));
|
||||
///
|
||||
@@ -514,62 +291,3 @@ where
|
||||
rustix::net::sockopt::get_socket_protocol(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::{Read, Write};
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_claim_fd_invalid_neg() {
|
||||
let _ = claim_fd(-1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_claim_fd_invalid_max() {
|
||||
let _ = claim_fd(i64::MAX as RawFd);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_claim_fd_inplace_invalid_neg() {
|
||||
let _ = claim_fd_inplace(-1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_claim_fd_inplace_invalid_max() {
|
||||
let _ = claim_fd_inplace(i64::MAX as RawFd);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_mask_fd_invalid_neg() {
|
||||
let _ = mask_fd(-1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_mask_fd_invalid_max() {
|
||||
let _ = mask_fd(i64::MAX as RawFd);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_open_nullfd() -> anyhow::Result<()> {
|
||||
let mut file = FdIo(open_nullfd()?);
|
||||
let mut buf = [0; 10];
|
||||
assert!(matches!(file.read(&mut buf), Ok(0) | Err(_)));
|
||||
assert!(file.write(&buf).is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nullfd_read_write() {
|
||||
let nullfd = open_nullfd().unwrap();
|
||||
let mut buf = vec![0u8; 16];
|
||||
assert_eq!(rustix::io::read(&nullfd, &mut buf).unwrap(), 0);
|
||||
assert!(rustix::io::write(&nullfd, b"test").is_err());
|
||||
}
|
||||
}
|
||||
46
util/src/rustix/syscall.rs
Normal file
46
util/src/rustix/syscall.rs
Normal file
@@ -0,0 +1,46 @@
|
||||
//! Helpers for performing system calls
|
||||
|
||||
use std::os::fd::FromRawFd;
|
||||
|
||||
use super::errno;
|
||||
|
||||
/// Wrapper type around [libc::c_long] that indicates that this value represents
|
||||
/// the result of a system call
|
||||
#[repr(transparent)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub struct SyscallResult(pub libc::c_long);
|
||||
|
||||
impl SyscallResult {
|
||||
/// Access to [Self::0]
|
||||
pub fn raw_value(&self) -> libc::c_long {
|
||||
self.0
|
||||
}
|
||||
|
||||
/// Claim the system call result as a file descriptor
|
||||
///
|
||||
/// - If [Self::raw_value] < 0, then [errno()] is called to retrieve the error type
|
||||
/// - If [Self::raw_value] > [i32::MAX], panics
|
||||
/// - Otherwise, this just forwards to [rustix::fd::OwnedFd::from_raw_fd]
|
||||
///
|
||||
/// # Panic
|
||||
///
|
||||
/// Panics if [Self::raw_value] > [i32::MAX].
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// Refer to [rustix::fd::OwnedFd::from_raw_fd].
|
||||
pub unsafe fn claim_fd(&self) -> Result<rustix::fd::OwnedFd, rustix::io::Errno> {
|
||||
let fde = self.0;
|
||||
match fde {
|
||||
e if e < 0 => Err(errno()),
|
||||
fd if fd > i32::MAX.into() => panic!("File descriptor `{fd}` is out of bounds!"),
|
||||
fd => Ok(unsafe { rustix::fd::OwnedFd::from_raw_fd(fd as i32) }),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<libc::c_long> for SyscallResult {
|
||||
fn from(value: libc::c_long) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
202
util/src/secret_memory/fd.rs
Normal file
202
util/src/secret_memory/fd.rs
Normal file
@@ -0,0 +1,202 @@
|
||||
//! Creation of secret memory file descriptors.
|
||||
//!
|
||||
//! This essentially provides a higher_level API for [memfd_secret()] and [memfd_create()]
|
||||
|
||||
// Tests: This uses nix-based integration tests
|
||||
|
||||
use std::{os::fd::OwnedFd, sync::OnceLock};
|
||||
|
||||
use rustix::{fs::MemfdFlags, io::Errno};
|
||||
|
||||
use crate::rustix::{memfd_create, memfd_secret, MemfdSecretError, MemfdSecretFlags};
|
||||
use crate::{mem::CopyExt, result::OkExt};
|
||||
|
||||
/// Cache for [memfd_secret_supported]
|
||||
static MEMFD_SECRET_SUPPORTED: OnceLock<bool> = OnceLock::new();
|
||||
|
||||
/// Check whether support for memfd_secret is available
|
||||
pub fn memfd_secret_supported() -> Result<bool, Errno> {
|
||||
match MEMFD_SECRET_SUPPORTED.get() {
|
||||
Some(v) => return Ok(*v),
|
||||
_ => {} // Continue
|
||||
};
|
||||
|
||||
use MemfdSecretError as E;
|
||||
let is_supported = match memfd_secret(MemfdSecretFlags::empty()) {
|
||||
Ok(_) => true,
|
||||
Err(E::NotSupported) => false,
|
||||
Err(E::SystemError(e)) => return Err(e),
|
||||
};
|
||||
|
||||
// We are deliberately using get_or_init here to make sure that the entire application
|
||||
// never sees different values here
|
||||
MEMFD_SECRET_SUPPORTED
|
||||
.get_or_init(|| is_supported)
|
||||
.copy()
|
||||
.ok()
|
||||
}
|
||||
|
||||
/// How secure memory file descriptors should be allocated
|
||||
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq)]
|
||||
pub enum SecretMemfdPolicy {
|
||||
/// Use memfd_secret(2) if available, otherwise fall back to less
|
||||
/// secure options
|
||||
Opportunistic,
|
||||
/// Enforce the use of memfd_secret(2)
|
||||
UseMemfdSecret,
|
||||
/// Never use memfd_secret(2)
|
||||
DisableMemfdSecret,
|
||||
}
|
||||
|
||||
impl Default for SecretMemfdPolicy {
|
||||
fn default() -> Self {
|
||||
Self::Opportunistic
|
||||
}
|
||||
}
|
||||
|
||||
impl SecretMemfdPolicy {
|
||||
/// Create a SecretMemfdPolicy with the default policy
|
||||
///
|
||||
/// Currently [Self::Opportunistic]
|
||||
pub const fn default_const() -> Self {
|
||||
Self::Opportunistic
|
||||
}
|
||||
|
||||
/// Enforce the use of the highest security configuration available
|
||||
///
|
||||
/// This might not work on some systems, which is why this is not used
|
||||
/// by default.
|
||||
///
|
||||
/// Currently [Self::UseMemfdSecret]
|
||||
pub const fn enforce_high_security() -> Self {
|
||||
Self::UseMemfdSecret
|
||||
}
|
||||
}
|
||||
|
||||
/// Which mechanism to us use when allocating secret memory file descriptors
|
||||
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq)]
|
||||
pub enum SecretMemfdMechanism {
|
||||
/// The less secure memfd_create(2) will be used
|
||||
MemfdCreate,
|
||||
/// The more secure memfd_secret(2) will be used
|
||||
MemfdSecret,
|
||||
}
|
||||
|
||||
impl SecretMemfdMechanism {
|
||||
/// Decide which mechanism to use, based on the given [SecretMemfdPolicy]
|
||||
/// and [memfd_secret_supported()].
|
||||
///
|
||||
/// If [SecretMemfdPolicy::UseMemfdSecret] is used, then [SecretMemfdMechanism::MemfdSecret]
|
||||
/// will be returned, regardless of whether it is supported.
|
||||
///
|
||||
/// Likewise, if [SecretMemfdPolicy::DisableMemfdSecret] is used, then
|
||||
/// [SecretMemfdMechanism::MemfdCreate] will be used unconditionally.
|
||||
pub fn decide_with_policy(policy: SecretMemfdPolicy) -> Result<Self, Errno> {
|
||||
use SecretMemfdMechanism as M;
|
||||
use SecretMemfdPolicy as P;
|
||||
|
||||
match policy {
|
||||
P::UseMemfdSecret => return Ok(M::MemfdSecret),
|
||||
P::DisableMemfdSecret => return Ok(M::MemfdCreate),
|
||||
P::Opportunistic => {}
|
||||
};
|
||||
|
||||
match memfd_secret_supported()? {
|
||||
true => Ok(M::MemfdSecret),
|
||||
false => Ok(M::MemfdCreate),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors for [create_memfd_secret()]
|
||||
#[derive(Copy, Clone, Debug, thiserror::Error)]
|
||||
pub enum SecretMemfdWithConfigError {
|
||||
/// Call to [memfd_secret()] failed
|
||||
#[error("{:?}", .0)]
|
||||
MemfdSecretError(#[from] MemfdSecretError),
|
||||
/// Call to [memfd()] failed
|
||||
#[error("Could not create secret memory segment using memfd_create(2) due to underlying system error: {:?}", .0)]
|
||||
MemfdCreateError(Errno),
|
||||
/// Some other (system) error occurred that prevented us from determining whether
|
||||
/// memfd_secret(2) is supported.
|
||||
#[error("Failed to determine whether memfd_secret(2) is supported due to underlying system error: {:?}", .0)]
|
||||
FailedToDetectSupport(Errno),
|
||||
}
|
||||
|
||||
/// Robustly configure and create secret memory file descriptors
|
||||
///
|
||||
/// Whereas [memfd_secret] will always use memfd_secret(2), this construction allows multiple
|
||||
/// file descriptor back ends to be used to support different usage scenarios.
|
||||
///
|
||||
/// This is necessary, because older systems do not support memfd_secret(2) and using it might not
|
||||
/// always be desirable, as memfd_secret for instance also inhibits hibernation.
|
||||
#[derive(Default, Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq)]
|
||||
pub struct SecretMemfdConfig {
|
||||
/// Enable the close-on-exec flag for the new file descriptor.
|
||||
pub close_on_exec: bool,
|
||||
/// Security mechanism to use
|
||||
pub policy: SecretMemfdPolicy,
|
||||
}
|
||||
|
||||
impl SecretMemfdConfig {
|
||||
/// Create a new, default [Self]
|
||||
pub const fn new() -> Self {
|
||||
let close_on_exec = false;
|
||||
let policy = SecretMemfdPolicy::default_const();
|
||||
Self {
|
||||
close_on_exec,
|
||||
policy,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the [Self::close_on_exec] flag to true
|
||||
pub const fn cloexec(&self) -> Self {
|
||||
let mut r = *self;
|
||||
r.close_on_exec = true;
|
||||
r
|
||||
}
|
||||
|
||||
/// Set `self.policy = SecretMemoryFdPolicy::UseMemfdSecret`
|
||||
pub const fn enforce_high_security(&self) -> Self {
|
||||
let mut r = *self;
|
||||
r.policy = SecretMemfdPolicy::enforce_high_security();
|
||||
r
|
||||
}
|
||||
|
||||
/// Whether memfd_secret will be used by [Self::create()]
|
||||
pub fn mechanism(&self) -> Result<SecretMemfdMechanism, Errno> {
|
||||
SecretMemfdMechanism::decide_with_policy(self.policy)
|
||||
}
|
||||
|
||||
/// Allocate a secret file descriptor based on the configuration
|
||||
pub fn create(&self) -> Result<OwnedFd, SecretMemfdWithConfigError> {
|
||||
use SecretMemfdWithConfigError as E;
|
||||
let mech = self.mechanism().map_err(E::FailedToDetectSupport)?;
|
||||
|
||||
use SecretMemfdMechanism as M;
|
||||
match (mech, self.close_on_exec) {
|
||||
(M::MemfdCreate, cloexec) => {
|
||||
let flags = match cloexec {
|
||||
true => MemfdFlags::CLOEXEC,
|
||||
false => MemfdFlags::empty(),
|
||||
};
|
||||
memfd_create("rosenpass secret memory segment", flags).map_err(E::MemfdCreateError)
|
||||
}
|
||||
(M::MemfdSecret, cloexec) => {
|
||||
let flags = match cloexec {
|
||||
true => MemfdSecretFlags::CLOEXEC,
|
||||
false => MemfdSecretFlags::empty(),
|
||||
};
|
||||
memfd_secret(flags).map_err(E::MemfdSecretError)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a secret memory file descriptor using the default policy
|
||||
///
|
||||
/// Shorthand for
|
||||
/// [`SecretMemfdConfig`][`::new()`](SecretMemfdConfig::new)[`.create()`](SecretMemfdConfig::create)
|
||||
pub fn memfd_for_secrets_with_default_policy() -> Result<OwnedFd, SecretMemfdWithConfigError> {
|
||||
SecretMemfdConfig::new().create()
|
||||
}
|
||||
409
util/src/secret_memory/mmap.rs
Normal file
409
util/src/secret_memory/mmap.rs
Normal file
@@ -0,0 +1,409 @@
|
||||
//! This module takes care of allocating memory segments for
|
||||
//! file descriptors created with [super::fd::memfd_for_secrets_with_default_policy]
|
||||
//! and anonymous memory segments
|
||||
|
||||
// Tests: Nix based integration tests
|
||||
|
||||
#![deny(unsafe_op_in_unsafe_fn)]
|
||||
|
||||
use std::os::fd::{AsFd, AsRawFd, BorrowedFd, RawFd};
|
||||
use std::ptr::null_mut;
|
||||
|
||||
use crate::mem::CopyExt;
|
||||
|
||||
/// Size of the memory mapping for [MappableFd]
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||
pub enum MMapSizePolicy {
|
||||
/// Size is assumed to be this particular value; [MappableFd::mmap] will simply
|
||||
/// use this value without checking whether its matches the size of the underlying
|
||||
/// data
|
||||
Assumed(u64),
|
||||
/// Size is assumed to be this particular value; [MappableFd::mmap] will check the
|
||||
/// size of the underlying data and raise an error if the size of data and this value
|
||||
/// do not match
|
||||
Checked(u64),
|
||||
/// Size is defined to be this particular value; [MappableFd::mmap] will explicitly resize
|
||||
/// the underlying file descriptor to be this particular size when called.
|
||||
Resize(u64),
|
||||
}
|
||||
|
||||
impl MMapSizePolicy {
|
||||
/// The numeric value of the size
|
||||
pub fn size_value(&self) -> u64 {
|
||||
match *self {
|
||||
Self::Assumed(v) => v,
|
||||
Self::Checked(v) => v,
|
||||
Self::Resize(v) => v,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for [MappableFd::mmap]
|
||||
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct MapFdConfig {
|
||||
/// The memory region can not be read from
|
||||
pub unreadable: bool,
|
||||
/// The memory region can not be written to
|
||||
pub immutable: bool,
|
||||
/// The memory region can be executed
|
||||
pub executable: bool,
|
||||
/// The memory region is shared-memory; other mappings of the same
|
||||
/// region within and outside this process can see the modifications
|
||||
/// to the memory region (as long as they also set the shared flag)
|
||||
pub shared: bool,
|
||||
/// How [MappableFd::mmap] will determine the size to be used fo
|
||||
///
|
||||
/// You should usually set this value through [Self::set_size_policy], [Self::assume_size_without_checking],
|
||||
/// [Self::expected_size], or [Self::resize_on_mmap].
|
||||
pub size_policy: Option<MMapSizePolicy>,
|
||||
}
|
||||
|
||||
impl MapFdConfig {
|
||||
/// New MapFdConfig with all settings turned off
|
||||
///
|
||||
/// You still must set [Self::size_policy], otherwise [MappableFd::mmap] will raise
|
||||
/// an error when called
|
||||
pub const fn new() -> Self {
|
||||
MapFdConfig {
|
||||
unreadable: false,
|
||||
immutable: false,
|
||||
executable: false,
|
||||
shared: false,
|
||||
size_policy: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// New MappableFdConfig with shared memory turned on
|
||||
pub const fn shared_memory() -> Self {
|
||||
Self::new().set_shared()
|
||||
}
|
||||
|
||||
/// Set the [Self::unreadable] flag
|
||||
pub const fn set_unreadable(&self) -> Self {
|
||||
let mut r = *self;
|
||||
r.unreadable = true;
|
||||
r
|
||||
}
|
||||
|
||||
/// Set the [Self::immutable] flag
|
||||
pub const fn set_immutable(&self) -> Self {
|
||||
let mut r = *self;
|
||||
r.immutable = true;
|
||||
r
|
||||
}
|
||||
|
||||
/// Set the [Self::shared] flag
|
||||
pub const fn set_shared(&self) -> Self {
|
||||
let mut r = *self;
|
||||
r.shared = true;
|
||||
r
|
||||
}
|
||||
|
||||
/// Create a [MappableFd] instance with this configuration
|
||||
pub fn mappable_fd<Fd: AsFd>(&self, fd: Fd) -> MappableFd<Fd> {
|
||||
MappableFd::new(fd, self.copy())
|
||||
}
|
||||
|
||||
/// Calculate [rustix::mm::ProtFlags] for this configuration
|
||||
pub const fn mmap_prot(&self) -> rustix::mm::ProtFlags {
|
||||
use rustix::mm::ProtFlags as P;
|
||||
|
||||
let p_read = match self.unreadable {
|
||||
true => P::empty(),
|
||||
false => P::READ,
|
||||
};
|
||||
|
||||
let p_write = match self.immutable {
|
||||
true => P::empty(),
|
||||
false => P::WRITE,
|
||||
};
|
||||
|
||||
let p_exec = match self.executable {
|
||||
true => P::EXEC,
|
||||
false => P::empty(),
|
||||
};
|
||||
|
||||
p_read.union(p_write).union(p_exec)
|
||||
}
|
||||
|
||||
/// Calculate [rustix::mm::MapFlags] for this configuration
|
||||
pub const fn mmap_flags(&self) -> rustix::mm::MapFlags {
|
||||
use rustix::mm::MapFlags as M;
|
||||
match self.shared {
|
||||
true => M::SHARED,
|
||||
false => M::empty(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set [Self::size_policy] to the given value
|
||||
pub const fn set_size_policy(&self, size_policy: MMapSizePolicy) -> Self {
|
||||
let mut r = *self;
|
||||
r.size_policy = Some(size_policy);
|
||||
r
|
||||
}
|
||||
|
||||
/// Set [Self::size_policy] to [MMapSizePolicy::Assumed] with the given value
|
||||
pub const fn assume_size_without_checking(&self, size: u64) -> Self {
|
||||
self.set_size_policy(MMapSizePolicy::Assumed(size))
|
||||
}
|
||||
|
||||
/// Set [Self::size_policy] to [MMapSizePolicy::Checked] with the given value
|
||||
pub const fn expected_size(&self, size: u64) -> Self {
|
||||
self.set_size_policy(MMapSizePolicy::Checked(size))
|
||||
}
|
||||
|
||||
/// Set [Self::size_policy] to [MMapSizePolicy::Resize] with the given value
|
||||
pub const fn resize_on_mmap(&self, size: u64) -> Self {
|
||||
self.set_size_policy(MMapSizePolicy::Resize(size))
|
||||
}
|
||||
}
|
||||
|
||||
/// Error returned by MappableFd::mmap
|
||||
#[derive(thiserror::Error, Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub enum MMapError {
|
||||
/// Error converting between usize & f64
|
||||
#[error("Requested memory map of size {requested_len} but maximum supported size is {max_supported_len}. \
|
||||
This is a low level error that usually arises on architectures where the integer type usize ({} bytes) can not \
|
||||
represent all values in u64 ({} bytes). Are you possibly on 32 bit CPU architecture requesting a buffer bigger than 4 GB?\n\
|
||||
Error: {err:?}",
|
||||
(usize::BITS as f64)/8f64, (u64::BITS as f64)/8f64
|
||||
)]
|
||||
OutOfBounds {
|
||||
/// Underlying error
|
||||
err: <u64 as TryInto<usize>>::Error,
|
||||
/// The size of the memory map requested
|
||||
requested_len: u64,
|
||||
/// Maximum supported size
|
||||
max_supported_len: usize,
|
||||
},
|
||||
/// Tried to map a file descriptor into memory, but the size policy was never set. Developer
|
||||
/// error.
|
||||
#[error("Tried to map a file descriptor into memory, but the size policy was never set. This is a developer error.")]
|
||||
MissingSizePolicy,
|
||||
/// fseek(3)/ftell(3) system call failed
|
||||
#[error("Tried to map file descriptor into memory, but failed to determine the size of the file descriptor: {:?}", .0)]
|
||||
CouldNotDetermineSize(rustix::io::Errno),
|
||||
/// Mismatch between expected and actual size of the file descriptor
|
||||
#[error("Tried to map file descriptor into memory with expected size {expected:?}, instead we found that the true size is {actual:?}")]
|
||||
IncorrectSize {
|
||||
/// Expected file descriptor size (given by caller)
|
||||
expected: u64,
|
||||
/// Actual file descriptor size
|
||||
actual: u64,
|
||||
},
|
||||
/// Negative size reported by fstat(2)
|
||||
#[error("Tried to determine the size of the underlying file descriptor, but failed to determine the size of the file descriptor because the size ({size}) is negative: {err:?}")]
|
||||
InvalidSize {
|
||||
/// Underlying error
|
||||
err: <i64 as TryInto<u64>>::Error,
|
||||
/// Reported size of the file descriptor
|
||||
size: i64,
|
||||
},
|
||||
/// ftruncate(2) system call failed
|
||||
#[error("Tried to resize the underlying file descriptor, but failed: {:?}", .0)]
|
||||
ResizeError(rustix::io::Errno),
|
||||
/// mmap(2) system call failed
|
||||
#[error("Tried to map file descriptor into memory, but mmap(2) system call failed: {:?}", .0)]
|
||||
MMapError(rustix::io::Errno),
|
||||
}
|
||||
|
||||
/// Handle mapping a file descriptor into memory
|
||||
pub struct MappableFd<Fd: AsFd> {
|
||||
/// The file descriptor this struct refers to
|
||||
fd: Fd,
|
||||
/// The configuration for mmap
|
||||
config: MapFdConfig,
|
||||
}
|
||||
|
||||
impl<Fd: AsFd> AsFd for MappableFd<Fd> {
|
||||
fn as_fd(&self) -> BorrowedFd<'_> {
|
||||
self.fd.as_fd()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Fd: AsFd> AsRawFd for MappableFd<Fd> {
|
||||
fn as_raw_fd(&self) -> RawFd {
|
||||
self.as_fd().as_raw_fd()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Fd: AsFd> MappableFd<Fd> {
|
||||
/// Create a [MappableFd] using the default configuration ([MapFdConfig])
|
||||
pub fn from_fd(fd: Fd) -> Self {
|
||||
Self::new(fd, MapFdConfig::default())
|
||||
}
|
||||
|
||||
/// Create a new [MappableFd] for an existing file descriptor
|
||||
pub fn new(fd: Fd, config: MapFdConfig) -> Self {
|
||||
Self { fd, config }
|
||||
}
|
||||
|
||||
/// Extract the underlying file descriptor
|
||||
pub fn into_fd(self) -> Fd {
|
||||
self.fd
|
||||
}
|
||||
|
||||
/// Access the [MapFdConfig] associated with Self
|
||||
pub fn config(&self) -> MapFdConfig {
|
||||
self.config
|
||||
}
|
||||
|
||||
/// Access the [MapFdConfig] associated with Self
|
||||
pub fn config_mut(&mut self) -> &mut MapFdConfig {
|
||||
&mut self.config
|
||||
}
|
||||
|
||||
/// Modify the [MapFdConfig] associated with Self, chainable
|
||||
pub fn with_config(mut self, config: MapFdConfig) -> Self {
|
||||
self.config = config;
|
||||
self
|
||||
}
|
||||
|
||||
/// Determine the size of the data associated with the file descriptor
|
||||
pub fn size_of_underlying_data(&self) -> Result<u64, MMapError> {
|
||||
use MMapError as E;
|
||||
let size = rustix::fs::fstat(self)
|
||||
.map_err(E::CouldNotDetermineSize)?
|
||||
.st_size;
|
||||
size.try_into().map_err(|err| E::InvalidSize { err, size })
|
||||
}
|
||||
|
||||
/// Map the file into memory
|
||||
///
|
||||
/// # Determining the size of the mapping.
|
||||
///
|
||||
/// Before calling this function, [MapFdConfig::size_policy] must be set.
|
||||
///
|
||||
/// Note that [Self::from_fd] and [Self::new] still allow you to create a [Self] with
|
||||
/// [MapFdConfig::size_policy] set to [None], so you can use [Self::size_of_underlying_data]
|
||||
/// to auto-detect the size of the mapping.
|
||||
///
|
||||
/// This functionality is not implemented by default, as its crucial to validate the size of
|
||||
/// the data being mapped into memory somehow; otherwise, the party that created the file
|
||||
/// descriptor can trigger a denial-of-service attack against our process by allocating an
|
||||
/// excessively large, sparse file. If you implement size auto-detection facilities, you should
|
||||
/// still enforce some bounds on the size.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// If there exist any Rust references referring to the memory region, or if you subsequently create a Rust reference referring to the resulting region, it is your responsibility to ensure that the Rust reference invariants are preserved, including ensuring that the memory is not mutated in a way that a Rust reference would not expect.
|
||||
pub fn mmap(&self) -> Result<MappedSegment, MMapError> {
|
||||
use rustix::mm::mmap;
|
||||
use MMapError as E;
|
||||
|
||||
let prot = self.config().mmap_prot();
|
||||
let flags = self.config().mmap_flags();
|
||||
|
||||
// Determine the size of the mapping to be used as u64
|
||||
let requested_size = match self.config().size_policy {
|
||||
None => return Err(E::MissingSizePolicy),
|
||||
Some(MMapSizePolicy::Assumed(size)) => size,
|
||||
Some(MMapSizePolicy::Resize(size)) => {
|
||||
rustix::fs::ftruncate(self, size).map_err(E::ResizeError)?;
|
||||
size
|
||||
}
|
||||
Some(MMapSizePolicy::Checked(expected)) => {
|
||||
let actual = self.size_of_underlying_data()?;
|
||||
if expected != actual {
|
||||
return Err(E::IncorrectSize { expected, actual });
|
||||
}
|
||||
expected
|
||||
}
|
||||
};
|
||||
|
||||
// Cast the size of the mapping to be used to usize, raising an error if the
|
||||
// requested_size can not be represented as usize (this should never happen in general,
|
||||
// but it could conceivably be thrown on 32 bit systems when very large mappings (>= 4GB)
|
||||
// are requested
|
||||
let len = requested_size.try_into().map_err(|err| {
|
||||
let max_supported_len = usize::MAX;
|
||||
E::OutOfBounds {
|
||||
err,
|
||||
requested_len: requested_size,
|
||||
max_supported_len,
|
||||
}
|
||||
})?;
|
||||
|
||||
let ptr = unsafe { mmap(null_mut(), len, prot, flags, self, 0) };
|
||||
let ptr = ptr.map_err(E::MMapError)?;
|
||||
let ptr = unsafe { MappedSegment::from_raw_parts(ptr.cast(), len) };
|
||||
|
||||
Ok(ptr)
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents exclusive ownership of a memory segment mapped into memory
|
||||
///
|
||||
/// Automatically unmaps the memory segment as this goes out of scope
|
||||
///
|
||||
/// # Panic
|
||||
///
|
||||
/// If the munmap(2) call fails, the destructor will panic. You can avoid this and use explicit
|
||||
/// error handling by calling [Self::unmap] instead
|
||||
#[derive(Debug)]
|
||||
pub struct MappedSegment {
|
||||
/// The location of the memory segment
|
||||
ptr: *mut u8,
|
||||
/// Length of the segment in bytes
|
||||
len: usize,
|
||||
}
|
||||
|
||||
unsafe impl Send for MappedSegment {}
|
||||
|
||||
impl MappedSegment {
|
||||
/// Construct a new [Self] from a pointer and a length
|
||||
///
|
||||
/// `ptr` is the address of the memory segment and `len` is its length in bytes
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// It is the responsibility of the caller to make sure that any pointer P such that `P = ptr.add(n)`,
|
||||
/// where `n < len` is a valid pointer. See #safety in [std::ptr].
|
||||
pub unsafe fn from_raw_parts(ptr: *mut u8, len: usize) -> Self {
|
||||
Self { ptr, len }
|
||||
}
|
||||
|
||||
/// Decompose a MappedSegment into its raw components: Pointer and length
|
||||
pub fn into_raw_parts(self) -> (*mut u8, usize) {
|
||||
let r = (self.ptr(), self.len());
|
||||
std::mem::forget(self);
|
||||
r
|
||||
}
|
||||
|
||||
/// The location of the memory segment
|
||||
pub fn ptr(&self) -> *mut u8 {
|
||||
self.ptr
|
||||
}
|
||||
|
||||
/// Length of the segment in bytes
|
||||
#[allow(clippy::len_without_is_empty)]
|
||||
pub fn len(&self) -> usize {
|
||||
self.len
|
||||
}
|
||||
|
||||
/// Release the memory segment
|
||||
///
|
||||
/// Compared to using the destructor which panics if unmapping fails, this allows explicit error handling to be used.
|
||||
///
|
||||
/// If this returns an error, the memory segment has not been freed. The values from
|
||||
/// [Self::into_raw_parts()] are returned as part of the error value so the caller has
|
||||
/// some chance to free the memory some other way (or to leak it if they so choose)
|
||||
pub fn unmap(self) -> Result<(), (rustix::io::Errno, *mut u8, usize)> {
|
||||
let (ptr, len) = self.into_raw_parts();
|
||||
let res = unsafe { rustix::mm::munmap(ptr.cast(), len) };
|
||||
res.map_err(|errno| (errno, ptr, len))
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for MappedSegment {
|
||||
fn drop(&mut self) {
|
||||
let mut owned = MappedSegment {
|
||||
ptr: null_mut(),
|
||||
len: 0,
|
||||
};
|
||||
std::mem::swap(self, &mut owned);
|
||||
if let Err((errno, _ptr, _len)) = owned.unmap() {
|
||||
panic!("Failed to unmap MappedSegment: {errno:?}")
|
||||
}
|
||||
}
|
||||
}
|
||||
4
util/src/secret_memory/mod.rs
Normal file
4
util/src/secret_memory/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
//! Utilities for allocating secret memory
|
||||
|
||||
pub mod fd;
|
||||
pub mod mmap;
|
||||
107
util/src/sync/atomic/abstract_atomic.rs
Normal file
107
util/src/sync/atomic/abstract_atomic.rs
Normal file
@@ -0,0 +1,107 @@
|
||||
//! Traits for types with atomic access semantics
|
||||
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
/// A trait for types with atomic access semantics
|
||||
///
|
||||
/// Using this trait allows us to achieve two goals:
|
||||
///
|
||||
/// 1. We can implement atomic semantics for types where
|
||||
/// there is no platform support for atomic semantics
|
||||
/// (e.g. by using a Mutex)
|
||||
/// 2. We can reuse implementations of concurrent data structures
|
||||
/// efficiently for the non-concurrent case. E.g. we could
|
||||
/// build a thread-local ring buffer using
|
||||
/// [crate::ringbuf::concurrent::framework::ConcurrentPipeWriter]/
|
||||
/// [crate::ringbuf::concurrent::framework::ConcurrentPipeReader]
|
||||
/// by supplying some sort `Immediate<u64>` type for the atomics
|
||||
/// in [crate::ringbuf::concurrent::framework::ConcurrentPipeCore] that implements
|
||||
/// this trait by using a [std::cell::Cell]. It may seem counter
|
||||
/// intuitive, but this setup implements perfectly fine atomic-appearing
|
||||
/// semantics just as long as the cell is thread-local.
|
||||
pub trait AbstractAtomic<T> {
|
||||
/// Like [std::sync::atomic::AtomicU64::load()]
|
||||
fn load(&self, order: Ordering) -> T;
|
||||
|
||||
/// Like [std::sync::atomic::AtomicU64::compare_exchange_weak()].
|
||||
///
|
||||
/// The default implementation just calls [AbstractAtomic::compare_exchange()].
|
||||
fn compare_exchange_weak(
|
||||
&self,
|
||||
current: T,
|
||||
new: T,
|
||||
success: Ordering,
|
||||
failure: Ordering,
|
||||
) -> Result<T, T> {
|
||||
self.compare_exchange(current, new, success, failure)
|
||||
}
|
||||
|
||||
/// Like [std::sync::atomic::AtomicU64::compare_exchange()].
|
||||
fn compare_exchange(
|
||||
&self,
|
||||
current: T,
|
||||
new: T,
|
||||
success: Ordering,
|
||||
failure: Ordering,
|
||||
) -> Result<T, T>;
|
||||
}
|
||||
|
||||
/// Implements a default type for [AbstractAtomic]
|
||||
///
|
||||
/// This in essence is to [AbstractAtomic], as [std::ops::Deref] is to [std::borrow::Borrow];
|
||||
/// the same functionality, except with a
|
||||
pub trait AbstractAtomicType: AbstractAtomic<Self::ValueType> {
|
||||
/// The underlying atomic value
|
||||
type ValueType;
|
||||
}
|
||||
|
||||
/// Implementing [AbstractAtomic] and [AbstractAtomicType] for standard atomics
|
||||
macro_rules! impl_abstract_atomic_for_atomic {
|
||||
($($Atomic:ty : $Value:ty),*) => {
|
||||
$(
|
||||
impl AbstractAtomicType for $Atomic {
|
||||
type ValueType = $Value;
|
||||
}
|
||||
|
||||
impl AbstractAtomic<$Value> for $Atomic {
|
||||
fn load(&self, order: Ordering) -> $Value {
|
||||
<$Atomic>::load(&self, order)
|
||||
}
|
||||
|
||||
fn compare_exchange_weak(
|
||||
&self,
|
||||
current: $Value,
|
||||
new: $Value,
|
||||
success: Ordering,
|
||||
failure: Ordering,
|
||||
) -> Result<$Value, $Value> {
|
||||
<$Atomic>::compare_exchange_weak(&self, current, new, success, failure)
|
||||
}
|
||||
|
||||
fn compare_exchange(
|
||||
&self,
|
||||
current: $Value,
|
||||
new: $Value,
|
||||
success: Ordering,
|
||||
failure: Ordering,
|
||||
) -> Result<$Value, $Value> {
|
||||
<$Atomic>::compare_exchange(&self, current, new, success, failure)
|
||||
}
|
||||
}
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
impl_abstract_atomic_for_atomic! {
|
||||
std::sync::atomic::AtomicBool: bool,
|
||||
std::sync::atomic::AtomicI8: i8,
|
||||
std::sync::atomic::AtomicI16: i16,
|
||||
std::sync::atomic::AtomicI32: i32,
|
||||
std::sync::atomic::AtomicI64: i64,
|
||||
std::sync::atomic::AtomicIsize: isize,
|
||||
std::sync::atomic::AtomicU8: u8,
|
||||
std::sync::atomic::AtomicU16: u16,
|
||||
std::sync::atomic::AtomicU32: u32,
|
||||
std::sync::atomic::AtomicU64: u64,
|
||||
std::sync::atomic::AtomicUsize: usize
|
||||
}
|
||||
5
util/src/sync/atomic/mod.rs
Normal file
5
util/src/sync/atomic/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
//! Synchronization using atomic semantics for multi-threaded programs.
|
||||
//!
|
||||
//! Analogous to [std::sync::atomic].
|
||||
|
||||
pub mod abstract_atomic;
|
||||
5
util/src/sync/mod.rs
Normal file
5
util/src/sync/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
//! Synchronization for multi-threaded programs.
|
||||
//!
|
||||
//! Analogous to [std::sync].
|
||||
|
||||
pub mod atomic;
|
||||
@@ -1,9 +1,9 @@
|
||||
//! A module providing the [`RefMaker`] type and its associated methods for constructing
|
||||
//! [`zerocopy::Ref`] references from byte buffers.
|
||||
|
||||
use anyhow::{ensure, Context};
|
||||
use anyhow::ensure;
|
||||
use std::marker::PhantomData;
|
||||
use zerocopy::{ByteSlice, ByteSliceMut, Ref};
|
||||
use zerocopy::{Immutable, KnownLayout, Ref, SplitByteSlice, SplitByteSliceMut};
|
||||
use zeroize::Zeroize;
|
||||
|
||||
use crate::zeroize::ZeroizedExt;
|
||||
@@ -19,10 +19,10 @@ use crate::zeroize::ZeroizedExt;
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes, Ref};///
|
||||
/// # use zerocopy::{IntoBytes, FromBytes, Ref, Immutable, KnownLayout};
|
||||
/// # use rosenpass_util::zerocopy::RefMaker;
|
||||
///
|
||||
/// #[derive(FromBytes, FromZeroes, AsBytes)]
|
||||
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
|
||||
/// #[repr(C)]
|
||||
/// struct Header {
|
||||
/// field1: u32,
|
||||
@@ -82,7 +82,10 @@ impl<B, T> RefMaker<B, T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: ByteSlice, T> RefMaker<B, T> {
|
||||
impl<B: SplitByteSlice, T> RefMaker<B, T>
|
||||
where
|
||||
T: KnownLayout + Immutable,
|
||||
{
|
||||
/// Parses the buffer into a [`zerocopy::Ref<B, T>`].
|
||||
///
|
||||
/// This will fail if the buffer is smaller than `size_of::<T>`.
|
||||
@@ -94,10 +97,10 @@ impl<B: ByteSlice, T> RefMaker<B, T> {
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes, Ref};
|
||||
/// # use zerocopy::{IntoBytes, FromBytes, Ref, Immutable, KnownLayout};
|
||||
/// # use rosenpass_util::zerocopy::RefMaker;
|
||||
///
|
||||
/// #[derive(FromBytes, FromZeroes, AsBytes, Debug)]
|
||||
/// #[derive(FromBytes, IntoBytes, Debug, Immutable, KnownLayout)]
|
||||
/// #[repr(C)]
|
||||
/// struct Data(u32);
|
||||
///
|
||||
@@ -116,11 +119,11 @@ impl<B: ByteSlice, T> RefMaker<B, T> {
|
||||
/// let bytes = [1u8, 2, 3, 4, 5, 6, 7, 8];
|
||||
/// let parse_error = RefMaker::<_, Data>::new(&bytes[1..5]).parse()
|
||||
/// .expect_err("Should error");
|
||||
/// assert_eq!(parse_error.to_string(), "Parser error!");
|
||||
/// assert_eq!(parse_error.to_string(), "Parser error: Alignment(AlignmentError)");
|
||||
/// ```
|
||||
pub fn parse(self) -> anyhow::Result<Ref<B, T>> {
|
||||
self.ensure_fit()?;
|
||||
Ref::<B, T>::new(self.buf).context("Parser error!")
|
||||
Ref::<B, T>::from_bytes(self.buf).map_err(|e| anyhow::anyhow!("Parser error: {e:?}"))
|
||||
}
|
||||
|
||||
/// Splits the internal buffer into a `RefMaker` containing a buffer with
|
||||
@@ -142,7 +145,10 @@ impl<B: ByteSlice, T> RefMaker<B, T> {
|
||||
/// ```
|
||||
pub fn from_prefix_with_tail(self) -> anyhow::Result<(Self, B)> {
|
||||
self.ensure_fit()?;
|
||||
let (head, tail) = self.buf.split_at(Self::target_size());
|
||||
let (head, tail) = self
|
||||
.buf
|
||||
.split_at(Self::target_size())
|
||||
.map_err(|_| anyhow::anyhow!("Failed to split buffer!"))?;
|
||||
Ok((Self::new(head), tail))
|
||||
}
|
||||
|
||||
@@ -165,7 +171,10 @@ impl<B: ByteSlice, T> RefMaker<B, T> {
|
||||
/// ```
|
||||
pub fn split_prefix(self) -> anyhow::Result<(Self, Self)> {
|
||||
self.ensure_fit()?;
|
||||
let (head, tail) = self.buf.split_at(Self::target_size());
|
||||
let (head, tail) = self
|
||||
.buf
|
||||
.split_at(Self::target_size())
|
||||
.map_err(|_| anyhow::anyhow!("Failed to split buffer!"))?;
|
||||
Ok((Self::new(head), Self::new(tail)))
|
||||
}
|
||||
|
||||
@@ -204,7 +213,10 @@ impl<B: ByteSlice, T> RefMaker<B, T> {
|
||||
pub fn from_suffix_with_head(self) -> anyhow::Result<(Self, B)> {
|
||||
self.ensure_fit()?;
|
||||
let point = self.bytes().len() - Self::target_size();
|
||||
let (head, tail) = self.buf.split_at(point);
|
||||
let (head, tail) = self
|
||||
.buf
|
||||
.split_at(point)
|
||||
.map_err(|_| anyhow::anyhow!("Failed to split buffer!"))?;
|
||||
Ok((Self::new(tail), head))
|
||||
}
|
||||
|
||||
@@ -227,7 +239,10 @@ impl<B: ByteSlice, T> RefMaker<B, T> {
|
||||
pub fn split_suffix(self) -> anyhow::Result<(Self, Self)> {
|
||||
self.ensure_fit()?;
|
||||
let point = self.bytes().len() - Self::target_size();
|
||||
let (head, tail) = self.buf.split_at(point);
|
||||
let (head, tail) = self
|
||||
.buf
|
||||
.split_at(point)
|
||||
.map_err(|_| anyhow::anyhow!("Failed to split buffer!"))?;
|
||||
Ok((Self::new(head), Self::new(tail)))
|
||||
}
|
||||
|
||||
@@ -282,7 +297,10 @@ impl<B: ByteSlice, T> RefMaker<B, T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: ByteSliceMut, T> RefMaker<B, T> {
|
||||
impl<B: SplitByteSliceMut, T> RefMaker<B, T>
|
||||
where
|
||||
T: KnownLayout + Immutable,
|
||||
{
|
||||
/// Creates a zeroized reference of type `T` from the buffer.
|
||||
///
|
||||
/// # Errors
|
||||
@@ -292,9 +310,9 @@ impl<B: ByteSliceMut, T> RefMaker<B, T> {
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes, Ref}; ///
|
||||
/// # use zerocopy::{IntoBytes, FromBytes, Ref, Immutable, KnownLayout};
|
||||
/// # use rosenpass_util::zerocopy::RefMaker;
|
||||
/// #[derive(FromBytes, FromZeroes, AsBytes)]
|
||||
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
|
||||
/// #[repr(C)]
|
||||
/// struct Data([u8; 4]);
|
||||
///
|
||||
@@ -312,7 +330,10 @@ impl<B: ByteSliceMut, T> RefMaker<B, T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: ByteSliceMut, T> Zeroize for RefMaker<B, T> {
|
||||
impl<B: SplitByteSliceMut, T> Zeroize for RefMaker<B, T>
|
||||
where
|
||||
T: KnownLayout + Immutable,
|
||||
{
|
||||
fn zeroize(&mut self) {
|
||||
self.bytes_mut().zeroize()
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//! Extension traits for converting `Ref<B, T>` into references backed by
|
||||
//! standard slices.
|
||||
|
||||
use zerocopy::{ByteSlice, ByteSliceMut, Ref};
|
||||
use zerocopy::{Immutable, KnownLayout, Ref, SplitByteSlice, SplitByteSliceMut};
|
||||
|
||||
/// A trait for converting a `Ref<B, T>` into a `Ref<&[u8], T>`.
|
||||
///
|
||||
@@ -16,9 +16,9 @@ pub trait ZerocopyEmancipateExt<B, T> {
|
||||
///
|
||||
/// ```
|
||||
/// # use std::ops::Deref;
|
||||
/// # use zerocopy::{AsBytes, ByteSlice, FromBytes, FromZeroes, Ref};
|
||||
/// # use zerocopy::{IntoBytes, SplitByteSlice, FromBytes, FromZeros, Ref, Immutable, KnownLayout};
|
||||
/// # use rosenpass_util::zerocopy::ZerocopyEmancipateExt;
|
||||
/// #[derive(FromBytes, FromZeroes, AsBytes)]
|
||||
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
|
||||
/// #[repr(C)]
|
||||
/// struct Data(u32);
|
||||
/// #[repr(align(4))]
|
||||
@@ -44,9 +44,9 @@ pub trait ZerocopyEmancipateMutExt<B, T> {
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes, Ref};
|
||||
/// # use zerocopy::{IntoBytes, FromBytes, Ref, Immutable, KnownLayout};
|
||||
/// # use rosenpass_util::zerocopy::{ZerocopyEmancipateMutExt};
|
||||
/// #[derive(FromBytes, FromZeroes, AsBytes)]
|
||||
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
|
||||
/// #[repr(C)]
|
||||
/// struct Data(u32);
|
||||
/// #[repr(align(4))]
|
||||
@@ -64,18 +64,20 @@ pub trait ZerocopyEmancipateMutExt<B, T> {
|
||||
|
||||
impl<B, T> ZerocopyEmancipateExt<B, T> for Ref<B, T>
|
||||
where
|
||||
B: ByteSlice,
|
||||
B: SplitByteSlice,
|
||||
T: KnownLayout + Immutable,
|
||||
{
|
||||
fn emancipate(&self) -> Ref<&[u8], T> {
|
||||
Ref::new(self.bytes()).unwrap()
|
||||
Ref::from_bytes(zerocopy::Ref::bytes(self)).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, T> ZerocopyEmancipateMutExt<B, T> for Ref<B, T>
|
||||
where
|
||||
B: ByteSliceMut,
|
||||
B: SplitByteSliceMut,
|
||||
T: KnownLayout + Immutable,
|
||||
{
|
||||
fn emancipate_mut(&mut self) -> Ref<&mut [u8], T> {
|
||||
Ref::new(self.bytes_mut()).unwrap()
|
||||
Ref::from_bytes(zerocopy::Ref::bytes_mut(self)).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//! Extension traits for parsing slices into [`zerocopy::Ref`] values using the
|
||||
//! [`RefMaker`] abstraction.
|
||||
|
||||
use zerocopy::{ByteSlice, ByteSliceMut, Ref};
|
||||
use zerocopy::{Immutable, KnownLayout, Ref, SplitByteSlice, SplitByteSliceMut};
|
||||
|
||||
use super::RefMaker;
|
||||
|
||||
@@ -9,16 +9,16 @@ use super::RefMaker;
|
||||
///
|
||||
/// This trait adds methods for creating [`Ref`] references from
|
||||
/// slices by using the [`RefMaker`] type internally.
|
||||
pub trait ZerocopySliceExt: Sized + ByteSlice {
|
||||
pub trait ZerocopySliceExt: Sized + SplitByteSlice {
|
||||
/// Creates a new `RefMaker` for the given slice.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes};
|
||||
/// # use zerocopy::{IntoBytes, FromBytes, Immutable, KnownLayout};
|
||||
/// # use rosenpass_util::zerocopy::{RefMaker, ZerocopySliceExt};
|
||||
///
|
||||
/// #[derive(FromBytes, FromZeroes, AsBytes)]
|
||||
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
|
||||
/// #[repr(C)]
|
||||
/// struct Data(u32);
|
||||
///
|
||||
@@ -39,10 +39,10 @@ pub trait ZerocopySliceExt: Sized + ByteSlice {
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes};
|
||||
/// # use zerocopy::{IntoBytes, FromBytes, Immutable, KnownLayout};
|
||||
/// # use rosenpass_util::zerocopy::ZerocopySliceExt;
|
||||
///
|
||||
/// #[derive(FromBytes, FromZeroes, AsBytes)]
|
||||
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
|
||||
/// #[repr(C)]
|
||||
/// struct Data(u16, u16);
|
||||
/// #[repr(align(4))]
|
||||
@@ -52,7 +52,10 @@ pub trait ZerocopySliceExt: Sized + ByteSlice {
|
||||
/// assert_eq!(data_ref.0, 0x0201);
|
||||
/// assert_eq!(data_ref.1, 0x0403);
|
||||
/// ```
|
||||
fn zk_parse<T>(self) -> anyhow::Result<Ref<Self, T>> {
|
||||
fn zk_parse<T>(self) -> anyhow::Result<Ref<Self, T>>
|
||||
where
|
||||
T: Immutable + KnownLayout,
|
||||
{
|
||||
self.zk_ref_maker().parse()
|
||||
}
|
||||
|
||||
@@ -67,9 +70,9 @@ pub trait ZerocopySliceExt: Sized + ByteSlice {
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes};
|
||||
/// # use zerocopy::{IntoBytes, FromBytes, Immutable, KnownLayout};
|
||||
/// # use rosenpass_util::zerocopy::ZerocopySliceExt;
|
||||
/// #[derive(FromBytes, FromZeroes, AsBytes)]
|
||||
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
|
||||
/// #[repr(C)]
|
||||
/// struct Header(u32);
|
||||
/// #[repr(align(4))]
|
||||
@@ -80,7 +83,10 @@ pub trait ZerocopySliceExt: Sized + ByteSlice {
|
||||
/// let header_ref = bytes.0.zk_parse_prefix::<Header>().unwrap();
|
||||
/// assert_eq!(header_ref.0, 0xDDCCBBAA);
|
||||
/// ```
|
||||
fn zk_parse_prefix<T>(self) -> anyhow::Result<Ref<Self, T>> {
|
||||
fn zk_parse_prefix<T>(self) -> anyhow::Result<Ref<Self, T>>
|
||||
where
|
||||
T: Immutable + KnownLayout,
|
||||
{
|
||||
self.zk_ref_maker().from_prefix()?.parse()
|
||||
}
|
||||
|
||||
@@ -95,9 +101,9 @@ pub trait ZerocopySliceExt: Sized + ByteSlice {
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes};
|
||||
/// # use zerocopy::{IntoBytes, FromBytes, Immutable, KnownLayout};
|
||||
/// # use rosenpass_util::zerocopy::ZerocopySliceExt;
|
||||
/// #[derive(FromBytes, FromZeroes, AsBytes)]
|
||||
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
|
||||
/// #[repr(C)]
|
||||
/// struct Header(u32);
|
||||
/// #[repr(align(4))]
|
||||
@@ -108,18 +114,21 @@ pub trait ZerocopySliceExt: Sized + ByteSlice {
|
||||
/// let header_ref = bytes.0.zk_parse_suffix::<Header>().unwrap();
|
||||
/// assert_eq!(header_ref.0, 0x30201000);
|
||||
/// ```
|
||||
fn zk_parse_suffix<T>(self) -> anyhow::Result<Ref<Self, T>> {
|
||||
fn zk_parse_suffix<T>(self) -> anyhow::Result<Ref<Self, T>>
|
||||
where
|
||||
T: Immutable + KnownLayout,
|
||||
{
|
||||
self.zk_ref_maker().from_suffix()?.parse()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: ByteSlice> ZerocopySliceExt for B {}
|
||||
impl<B: SplitByteSlice> ZerocopySliceExt for B {}
|
||||
|
||||
/// Extension trait for zero-copy parsing of mutable slices with zeroization
|
||||
/// capabilities.
|
||||
///
|
||||
/// Provides convenience methods to create zero-initialized references.
|
||||
pub trait ZerocopyMutSliceExt: ZerocopySliceExt + Sized + ByteSliceMut {
|
||||
pub trait ZerocopyMutSliceExt: ZerocopySliceExt + Sized + SplitByteSliceMut {
|
||||
/// Creates a new zeroized reference from the entire slice.
|
||||
///
|
||||
/// This zeroizes the slice first, then provides a `Ref`.
|
||||
@@ -131,9 +140,9 @@ pub trait ZerocopyMutSliceExt: ZerocopySliceExt + Sized + ByteSliceMut {
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes};
|
||||
/// # use zerocopy::{IntoBytes, FromBytes, Immutable, KnownLayout};
|
||||
/// # use rosenpass_util::zerocopy::ZerocopyMutSliceExt;
|
||||
/// #[derive(FromBytes, FromZeroes, AsBytes)]
|
||||
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
|
||||
/// #[repr(C)]
|
||||
/// struct Data([u8; 4]);
|
||||
/// #[repr(align(4))]
|
||||
@@ -143,7 +152,10 @@ pub trait ZerocopyMutSliceExt: ZerocopySliceExt + Sized + ByteSliceMut {
|
||||
/// assert_eq!(data_ref.0, [0,0,0,0]);
|
||||
/// assert_eq!(bytes.0, [0, 0, 0, 0]);
|
||||
/// ```
|
||||
fn zk_zeroized<T>(self) -> anyhow::Result<Ref<Self, T>> {
|
||||
fn zk_zeroized<T>(self) -> anyhow::Result<Ref<Self, T>>
|
||||
where
|
||||
T: Immutable + KnownLayout,
|
||||
{
|
||||
self.zk_ref_maker().make_zeroized()
|
||||
}
|
||||
|
||||
@@ -159,9 +171,9 @@ pub trait ZerocopyMutSliceExt: ZerocopySliceExt + Sized + ByteSliceMut {
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes};
|
||||
/// # use zerocopy::{IntoBytes, FromBytes, Immutable, KnownLayout};
|
||||
/// # use rosenpass_util::zerocopy::ZerocopyMutSliceExt;
|
||||
/// #[derive(FromBytes, FromZeroes, AsBytes)]
|
||||
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
|
||||
/// #[repr(C)]
|
||||
/// struct Data([u8; 4]);
|
||||
/// #[repr(align(4))]
|
||||
@@ -171,7 +183,10 @@ pub trait ZerocopyMutSliceExt: ZerocopySliceExt + Sized + ByteSliceMut {
|
||||
/// assert_eq!(data_ref.0, [0,0,0,0]);
|
||||
/// assert_eq!(bytes.0, [0, 0, 0, 0, 0xFF, 0xFF]);
|
||||
/// ```
|
||||
fn zk_zeroized_from_prefix<T>(self) -> anyhow::Result<Ref<Self, T>> {
|
||||
fn zk_zeroized_from_prefix<T>(self) -> anyhow::Result<Ref<Self, T>>
|
||||
where
|
||||
T: Immutable + KnownLayout,
|
||||
{
|
||||
self.zk_ref_maker().from_prefix()?.make_zeroized()
|
||||
}
|
||||
|
||||
@@ -187,9 +202,9 @@ pub trait ZerocopyMutSliceExt: ZerocopySliceExt + Sized + ByteSliceMut {
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// # use zerocopy::{AsBytes, FromBytes, FromZeroes};
|
||||
/// # use zerocopy::{IntoBytes, FromBytes, Immutable, KnownLayout};
|
||||
/// # use rosenpass_util::zerocopy::ZerocopyMutSliceExt;
|
||||
/// #[derive(FromBytes, FromZeroes, AsBytes)]
|
||||
/// #[derive(FromBytes, IntoBytes, Immutable, KnownLayout)]
|
||||
/// #[repr(C)]
|
||||
/// struct Data([u8; 4]);
|
||||
/// #[repr(align(4))]
|
||||
@@ -199,9 +214,12 @@ pub trait ZerocopyMutSliceExt: ZerocopySliceExt + Sized + ByteSliceMut {
|
||||
/// assert_eq!(data_ref.0, [0,0,0,0]);
|
||||
/// assert_eq!(bytes.0, [0xFF, 0xFF, 0, 0, 0, 0]);
|
||||
/// ```
|
||||
fn zk_zeroized_from_suffix<T>(self) -> anyhow::Result<Ref<Self, T>> {
|
||||
fn zk_zeroized_from_suffix<T>(self) -> anyhow::Result<Ref<Self, T>>
|
||||
where
|
||||
T: Immutable + KnownLayout,
|
||||
{
|
||||
self.zk_ref_maker().from_suffix()?.make_zeroized()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: ByteSliceMut> ZerocopyMutSliceExt for B {}
|
||||
impl<B: SplitByteSliceMut> ZerocopyMutSliceExt for B {}
|
||||
|
||||
@@ -47,7 +47,7 @@ async fn janitor_demo() -> anyhow::Result<()> {
|
||||
anyhow::Ok(())
|
||||
})
|
||||
}
|
||||
.await;
|
||||
.await?;
|
||||
|
||||
// At this point, all background jobs have finished, now we can check the result of all our
|
||||
// additions
|
||||
|
||||
@@ -41,6 +41,7 @@ rand = { workspace = true }
|
||||
procspawn = { workspace = true }
|
||||
|
||||
[features]
|
||||
default = ["experiment_api"]
|
||||
experiment_api = ["rustix", "libc"]
|
||||
experiment_memfd_secret = []
|
||||
|
||||
|
||||
@@ -38,6 +38,8 @@
|
||||
|
||||
use std::{borrow::BorrowMut, fmt::Debug};
|
||||
|
||||
use zerocopy::IntoBytes;
|
||||
|
||||
use crate::{
|
||||
api::{
|
||||
config::NetworkBrokerConfig,
|
||||
@@ -170,7 +172,8 @@ where
|
||||
let typ = msgs::MsgType::try_from(*typ)?;
|
||||
let msgs::MsgType::SetPsk = typ; // Assert type
|
||||
|
||||
let res = zerocopy::Ref::<&[u8], Envelope<SetPskResponse>>::new(res)
|
||||
let res = zerocopy::Ref::<&[u8], Envelope<SetPskResponse>>::from_bytes(res)
|
||||
.ok()
|
||||
.ok_or(invalid_msg_poller())?;
|
||||
let res: &msgs::SetPskResponse = &res.payload;
|
||||
let res: msgs::SetPskResponseReturnCode = res
|
||||
@@ -200,8 +203,10 @@ where
|
||||
let mut req = [0u8; BUF_SIZE];
|
||||
|
||||
// Construct message view
|
||||
let mut req = zerocopy::Ref::<&mut [u8], Envelope<msgs::SetPskRequest>>::new(&mut req)
|
||||
.ok_or(MsgError)?;
|
||||
let mut req =
|
||||
zerocopy::Ref::<&mut [u8], Envelope<msgs::SetPskRequest>>::from_bytes(&mut req)
|
||||
.ok()
|
||||
.ok_or(MsgError)?;
|
||||
|
||||
// Populate envelope
|
||||
req.msg_type = msgs::MsgType::SetPsk as u8;
|
||||
@@ -219,7 +224,7 @@ where
|
||||
// Send message
|
||||
self.io
|
||||
.borrow_mut()
|
||||
.send_msg(req.bytes())
|
||||
.send_msg(req.as_bytes())
|
||||
.map_err(IoError)?;
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
use std::str::{from_utf8, Utf8Error};
|
||||
|
||||
use zerocopy::{AsBytes, FromBytes, FromZeroes};
|
||||
use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
|
||||
|
||||
/// The number of bytes reserved for overhead when packaging data.
|
||||
pub const ENVELOPE_OVERHEAD: usize = 1 + 3;
|
||||
@@ -15,8 +15,8 @@ pub const RESPONSE_MSG_BUFFER_SIZE: usize = ENVELOPE_OVERHEAD + 1;
|
||||
|
||||
/// Envelope for messages being passed around.
|
||||
#[repr(packed)]
|
||||
#[derive(AsBytes, FromBytes, FromZeroes)]
|
||||
pub struct Envelope<M: AsBytes + FromBytes> {
|
||||
#[derive(IntoBytes, FromBytes, Immutable, KnownLayout)]
|
||||
pub struct Envelope<M: IntoBytes + FromBytes> {
|
||||
/// [MsgType] of this message
|
||||
pub msg_type: u8,
|
||||
/// Reserved for future use
|
||||
@@ -29,7 +29,7 @@ pub struct Envelope<M: AsBytes + FromBytes> {
|
||||
/// # Example
|
||||
///
|
||||
#[repr(packed)]
|
||||
#[derive(AsBytes, FromBytes, FromZeroes)]
|
||||
#[derive(IntoBytes, FromBytes, Immutable, KnownLayout)]
|
||||
pub struct SetPskRequest {
|
||||
/// The pre-shared key.
|
||||
pub psk: [u8; 32],
|
||||
@@ -85,7 +85,7 @@ impl SetPskRequest {
|
||||
|
||||
/// Message format for response to the set pre-shared key operation.
|
||||
#[repr(packed)]
|
||||
#[derive(AsBytes, FromBytes, FromZeroes)]
|
||||
#[derive(IntoBytes, FromBytes, Immutable, KnownLayout)]
|
||||
pub struct SetPskResponse {
|
||||
pub return_code: u8,
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
use std::borrow::BorrowMut;
|
||||
|
||||
use rosenpass_secret_memory::{Public, Secret};
|
||||
use zerocopy::IntoBytes;
|
||||
|
||||
use crate::api::msgs::{self, Envelope, SetPskRequest, SetPskResponse};
|
||||
use crate::WireGuardBroker;
|
||||
@@ -78,14 +79,16 @@ where
|
||||
let typ = msgs::MsgType::try_from(*typ)?;
|
||||
let msgs::MsgType::SetPsk = typ; // Assert type
|
||||
|
||||
let req =
|
||||
zerocopy::Ref::<&[u8], Envelope<SetPskRequest>>::new(req).ok_or(InvalidMessage)?;
|
||||
let mut res =
|
||||
zerocopy::Ref::<&mut [u8], Envelope<SetPskResponse>>::new(res).ok_or(InvalidMessage)?;
|
||||
let req = zerocopy::Ref::<&[u8], Envelope<SetPskRequest>>::from_bytes(req)
|
||||
.ok()
|
||||
.ok_or(InvalidMessage)?;
|
||||
let mut res = zerocopy::Ref::<&mut [u8], Envelope<SetPskResponse>>::from_bytes(res)
|
||||
.ok()
|
||||
.ok_or(InvalidMessage)?;
|
||||
res.msg_type = msgs::MsgType::SetPsk as u8;
|
||||
self.handle_set_psk(&req.payload, &mut res.payload)?;
|
||||
|
||||
Ok(res.bytes().len())
|
||||
Ok(res.as_bytes().len())
|
||||
}
|
||||
|
||||
/// Sets the pre-shared key for the interface identified in `req` to the pre-shared key
|
||||
@@ -138,7 +141,7 @@ mod tests {
|
||||
use crate::brokers::netlink::SetPskError;
|
||||
use crate::{SerializedBrokerConfig, WireGuardBroker};
|
||||
use rosenpass_secret_memory::{secret_policy_use_only_malloc_secrets, Secret};
|
||||
use zerocopy::AsBytes;
|
||||
use zerocopy::IntoBytes;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct MockWireGuardBroker {
|
||||
|
||||
@@ -12,7 +12,7 @@ use tokio::task;
|
||||
use anyhow::{bail, ensure, Result};
|
||||
use clap::{ArgGroup, Parser};
|
||||
|
||||
use rosenpass_util::fd::claim_fd;
|
||||
use rosenpass_util::rustix::claim_fd;
|
||||
use rosenpass_wireguard_broker::api::msgs;
|
||||
|
||||
/// Command-line arguments for configuring the socket handler
|
||||
|
||||
Reference in New Issue
Block a user