From 730a03957ade3b1bc05daf6959e6997586289d78 Mon Sep 17 00:00:00 2001 From: Karolin Varner Date: Sat, 3 Aug 2024 16:50:21 +0200 Subject: [PATCH] feat: A variety of utilities in preparation for implementing the API --- Cargo.lock | 3 + util/Cargo.toml | 3 + util/src/length_prefix_encoding/decoder.rs | 359 +++++++++++++++++++ util/src/length_prefix_encoding/encoder.rs | 381 +++++++++++++++++++++ util/src/length_prefix_encoding/mod.rs | 2 + util/src/lib.rs | 4 + util/src/mio.rs | 39 +++ util/src/result.rs | 11 + util/src/zerocopy/mod.rs | 7 + util/src/zerocopy/ref_maker.rs | 106 ++++++ util/src/zerocopy/zerocopy_ref_ext.rs | 27 ++ util/src/zerocopy/zerocopy_slice_ext.rs | 39 +++ util/src/zeroize/mod.rs | 2 + util/src/zeroize/zeroized_ext.rs | 10 + 14 files changed, 993 insertions(+) create mode 100644 util/src/length_prefix_encoding/decoder.rs create mode 100644 util/src/length_prefix_encoding/encoder.rs create mode 100644 util/src/length_prefix_encoding/mod.rs create mode 100644 util/src/mio.rs create mode 100644 util/src/zerocopy/mod.rs create mode 100644 util/src/zerocopy/ref_maker.rs create mode 100644 util/src/zerocopy/zerocopy_ref_ext.rs create mode 100644 util/src/zerocopy/zerocopy_slice_ext.rs create mode 100644 util/src/zeroize/mod.rs create mode 100644 util/src/zeroize/zeroized_ext.rs diff --git a/Cargo.lock b/Cargo.lock index 858723b..7f318b6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2039,9 +2039,12 @@ version = "0.1.0" dependencies = [ "anyhow", "base64ct", + "mio 1.0.1", "rustix", "static_assertions", + "thiserror", "typenum", + "zerocopy", "zeroize", ] diff --git a/util/Cargo.toml b/util/Cargo.toml index 7e2ef08..e80ba0f 100644 --- a/util/Cargo.toml +++ b/util/Cargo.toml @@ -18,3 +18,6 @@ typenum = { workspace = true } static_assertions = { workspace = true } rustix = {workspace = true} zeroize = {workspace = true} +zerocopy = { workspace = true } +thiserror = { workspace = true } +mio = { workspace = true } diff --git a/util/src/length_prefix_encoding/decoder.rs b/util/src/length_prefix_encoding/decoder.rs new file mode 100644 index 0000000..26a7130 --- /dev/null +++ b/util/src/length_prefix_encoding/decoder.rs @@ -0,0 +1,359 @@ +use std::{borrow::BorrowMut, cmp::min, io}; + +use thiserror::Error; +use zeroize::Zeroize; + +use crate::{ + io::{TryIoErrorKind, TryIoResultKindHintExt}, + result::ensure_or, +}; + +pub const HEADER_SIZE: usize = std::mem::size_of::(); + +#[derive(Error, Debug)] +pub enum SanityError { + #[error("Offset is out of read buffer bounds")] + OutOfBufferBounds, + #[error("Offset is out of message buffer bounds")] + OutOfMessageBounds, +} + +#[derive(Error, Debug)] +#[error("Message too lage ({msg_size} bytes) for buffer ({buf_size} bytes)")] +pub struct MessageTooLargeError { + msg_size: usize, + buf_size: usize, +} + +impl MessageTooLargeError { + pub fn new(msg_size: usize, buf_size: usize) -> Self { + Self { msg_size, buf_size } + } + + pub fn ensure(msg_size: usize, buf_size: usize) -> Result<(), Self> { + let err = MessageTooLargeError { msg_size, buf_size }; + ensure_or(msg_size <= buf_size, err) + } +} + +#[derive(Debug)] +pub struct ReadFromIoReturn<'a> { + pub bytes_read: usize, + pub message: Option<&'a mut [u8]>, +} + +impl<'a> ReadFromIoReturn<'a> { + pub fn new(bytes_read: usize, message: Option<&'a mut [u8]>) -> Self { + Self { + bytes_read, + message, + } + } +} + +#[derive(Debug, Error)] +pub enum ReadFromIoError { + #[error("Error reading from the underlying stream")] + IoError(#[from] io::Error), + #[error("Message size out of buffer bounds")] + MessageTooLargeError(#[from] MessageTooLargeError), +} + +impl TryIoErrorKind for ReadFromIoError { + fn try_io_error_kind(&self) -> Option { + match self { + ReadFromIoError::IoError(ioe) => Some(ioe.kind()), + _ => None, + } + } +} + +#[derive(Debug, Default, Clone)] +pub struct LengthPrefixDecoder> { + header: [u8; HEADER_SIZE], + buf: Buf, + off: usize, +} + +impl> LengthPrefixDecoder { + pub fn new(buf: Buf) -> Self { + let header = Default::default(); + let off = 0; + Self { header, buf, off } + } + + pub fn clear(&mut self) { + self.zeroize() + } + + pub fn from_parts(header: [u8; HEADER_SIZE], buf: Buf, off: usize) -> Self { + Self { header, buf, off } + } + + pub fn into_parts(self) -> ([u8; HEADER_SIZE], Buf, usize) { + let Self { header, buf, off } = self; + (header, buf, off) + } + + pub fn read_all_from_stdio( + &mut self, + mut r: R, + ) -> Result<&mut [u8], ReadFromIoError> { + use io::ErrorKind as K; + loop { + match self.read_from_stdio(&mut r).try_io_err_kind_hint() { + // Success (appeasing the borrow checker by calling message_mut()) + Ok(ReadFromIoReturn { + message: Some(_), .. + }) => break Ok(self.message_mut().unwrap().unwrap()), + + // Unexpected EOF + Ok(ReadFromIoReturn { bytes_read: 0, .. }) => { + break Err(ReadFromIoError::IoError(io::Error::new( + K::UnexpectedEof, + "", + ))) + } + + // Retry + Ok(ReadFromIoReturn { message: None, .. }) => continue, + Err((_, Some(K::Interrupted))) => continue, + + // Other error + Err((e, _)) => break Err(e), + } + } + } + + pub fn read_from_stdio( + &mut self, + mut r: R, + ) -> Result { + Ok(match self.next_slice_to_write_to()? { + // Read some bytes; any MessageTooLargeError in the call to self.message_mut() is + // ignored to ensure this function changes no state upon errors; the user should rerun + // the function and colect the MessageTooLargeError on the following invocation + Some(buf) => { + let bytes_read = r.read(buf)?; + self.advance(bytes_read).unwrap(); + let message = self.message_mut().ok().flatten(); + ReadFromIoReturn { + bytes_read, + message, + } + } + // Message is already fully read; full delegation to self.message_mut() + None => ReadFromIoReturn { + bytes_read: 0, + message: self.message_mut()?, + }, + }) + } + + pub fn next_slice_to_write_to(&mut self) -> Result, MessageTooLargeError> { + fn some_if_nonempty(buf: &mut [u8]) -> Option<&mut [u8]> { + match buf.is_empty() { + true => None, + false => Some(buf), + } + } + + macro_rules! return_if_nonempty_some { + ($opt:expr) => {{ + // Deliberate double expansion of $opt to appease the borrow checker *sigh* + if $opt.and_then(some_if_nonempty).is_some() { + return Ok($opt); + } + }}; + } + + return_if_nonempty_some!(Some(self.header_buffer_left_mut())); + return_if_nonempty_some!(self.message_fragment_left_mut()?); + Ok(None) + } + + pub fn advance(&mut self, count: usize) -> Result<(), SanityError> { + let off = self.off + count; + let msg_off = off.saturating_sub(HEADER_SIZE); + + use SanityError as E; + let alloc = self.message_buffer().len(); + let msgsz = self.message_size(); + ensure_or(msg_off <= alloc, E::OutOfBufferBounds)?; + ensure_or( + msgsz.map(|s| msg_off <= s).unwrap_or(true), + E::OutOfMessageBounds, + )?; + + self.off = off; + Ok(()) + } + + pub fn ensure_sufficient_msg_buffer(&self) -> Result<(), MessageTooLargeError> { + let buf_size = self.message_buffer().len(); + let msg_size = match self.get_header() { + None => return Ok(()), + Some(v) => v, + }; + MessageTooLargeError::ensure(msg_size, buf_size) + } + + pub fn header_buffer(&self) -> &[u8] { + &self.header[..] + } + + pub fn header_buffer_mut(&mut self) -> &mut [u8] { + &mut self.header[..] + } + + pub fn message_buffer(&self) -> &[u8] { + self.buf.borrow() + } + + pub fn message_buffer_mut(&mut self) -> &mut [u8] { + self.buf.borrow_mut() + } + + pub fn bytes_read(&self) -> &usize { + &self.off + } + + pub fn into_message_buffer(self) -> Buf { + let Self { buf, .. } = self; + buf + } + + pub fn header_buffer_offset(&self) -> usize { + min(self.off, HEADER_SIZE) + } + + pub fn message_buffer_offset(&self) -> usize { + self.off.saturating_sub(HEADER_SIZE) + } + + pub fn has_header(&self) -> bool { + self.header_buffer_offset() == HEADER_SIZE + } + + pub fn has_message(&self) -> Result { + self.ensure_sufficient_msg_buffer()?; + let msg_size = match self.get_header() { + None => return Ok(false), + Some(v) => v, + }; + Ok(self.message_buffer_avail().len() == msg_size) + } + + pub fn header_buffer_avail(&self) -> &[u8] { + let off = self.header_buffer_offset(); + &self.header_buffer()[..off] + } + + pub fn header_buffer_avail_mut(&mut self) -> &mut [u8] { + let off = self.header_buffer_offset(); + &mut self.header_buffer_mut()[..off] + } + + pub fn header_buffer_left(&self) -> &[u8] { + let off = self.header_buffer_offset(); + &self.header_buffer()[off..] + } + + pub fn header_buffer_left_mut(&mut self) -> &mut [u8] { + let off = self.header_buffer_offset(); + &mut self.header_buffer_mut()[off..] + } + + pub fn message_buffer_avail(&self) -> &[u8] { + let off = self.message_buffer_offset(); + &self.message_buffer()[..off] + } + + pub fn message_buffer_avail_mut(&mut self) -> &mut [u8] { + let off = self.message_buffer_offset(); + &mut self.message_buffer_mut()[..off] + } + + pub fn message_buffer_left(&self) -> &[u8] { + let off = self.message_buffer_offset(); + &self.message_buffer()[off..] + } + + pub fn message_buffer_left_mut(&mut self) -> &mut [u8] { + let off = self.message_buffer_offset(); + &mut self.message_buffer_mut()[off..] + } + + pub fn get_header(&self) -> Option { + match self.header_buffer_offset() == HEADER_SIZE { + false => None, + true => Some(u64::from_le_bytes(self.header) as usize), + } + } + + pub fn message_size(&self) -> Option { + self.get_header() + } + + pub fn encoded_message_bytes(&self) -> Option { + self.message_size().map(|sz| sz + HEADER_SIZE) + } + + pub fn message_fragment(&self) -> Result, MessageTooLargeError> { + self.ensure_sufficient_msg_buffer()?; + Ok(self.message_size().map(|sz| &self.message_buffer()[..sz])) + } + + pub fn message_fragment_mut(&mut self) -> Result, MessageTooLargeError> { + self.ensure_sufficient_msg_buffer()?; + Ok(self + .message_size() + .map(|sz| &mut self.message_buffer_mut()[..sz])) + } + + pub fn message_fragment_avail(&self) -> Result, MessageTooLargeError> { + let off = self.message_buffer_avail().len(); + self.message_fragment() + .map(|frag| frag.map(|frag| &frag[..off])) + } + + pub fn message_fragment_avail_mut( + &mut self, + ) -> Result, MessageTooLargeError> { + let off = self.message_buffer_avail().len(); + self.message_fragment_mut() + .map(|frag| frag.map(|frag| &mut frag[..off])) + } + + pub fn message_fragment_left(&self) -> Result, MessageTooLargeError> { + let off = self.message_buffer_avail().len(); + self.message_fragment() + .map(|frag| frag.map(|frag| &frag[off..])) + } + + pub fn message_fragment_left_mut(&mut self) -> Result, MessageTooLargeError> { + let off = self.message_buffer_avail().len(); + self.message_fragment_mut() + .map(|frag| frag.map(|frag| &mut frag[off..])) + } + + pub fn message(&self) -> Result, MessageTooLargeError> { + let sz = self.message_size(); + self.message_fragment_avail() + .map(|frag_opt| frag_opt.and_then(|frag| (frag.len() == sz?).then_some(frag))) + } + + pub fn message_mut(&mut self) -> Result, MessageTooLargeError> { + let sz = self.message_size(); + self.message_fragment_avail_mut() + .map(|frag_opt| frag_opt.and_then(|frag| (frag.len() == sz?).then_some(frag))) + } +} + +impl> Zeroize for LengthPrefixDecoder { + fn zeroize(&mut self) { + self.header.zeroize(); + self.message_buffer_mut().zeroize(); + self.off.zeroize(); + } +} diff --git a/util/src/length_prefix_encoding/encoder.rs b/util/src/length_prefix_encoding/encoder.rs new file mode 100644 index 0000000..a7f86b5 --- /dev/null +++ b/util/src/length_prefix_encoding/encoder.rs @@ -0,0 +1,381 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + cmp::min, + io, +}; + +use thiserror::Error; +use zeroize::Zeroize; + +use crate::{io::IoResultKindHintExt, result::ensure_or}; + +pub const HEADER_SIZE: usize = std::mem::size_of::(); + +#[derive(Error, Debug, Clone, Copy)] +#[error("Write position is out of buffer bounds")] +pub struct PositionOutOfBufferBounds; + +#[derive(Error, Debug, Clone, Copy)] +#[error("Write position is out of message bounds")] +pub struct PositionOutOfMessageBounds; + +#[derive(Error, Debug, Clone, Copy)] +#[error("Write position is out of header bounds")] +pub struct PositionOutOfHeaderBounds; + +#[derive(Error, Debug, Clone, Copy)] +#[error("Message length is bigger than buffer length")] +pub struct MessageTooLarge; + +#[derive(Error, Debug, Clone, Copy)] +pub enum MessageLenSanityError { + #[error("{0:?}")] + PositionOutOfMessageBounds(#[from] PositionOutOfMessageBounds), + #[error("{0:?}")] + MessageTooLarge(#[from] MessageTooLarge), +} + +#[derive(Error, Debug, Clone, Copy)] +pub enum PositionSanityError { + #[error("{0:?}")] + PositionOutOfMessageBounds(#[from] PositionOutOfMessageBounds), + #[error("{0:?}")] + PositionOutOfBufferBounds(#[from] PositionOutOfBufferBounds), +} + +#[derive(Error, Debug, Clone, Copy)] +pub enum SanityError { + #[error("{0:?}")] + PositionOutOfMessageBounds(#[from] PositionOutOfMessageBounds), + #[error("{0:?}")] + PositionOutOfBufferBounds(#[from] PositionOutOfBufferBounds), + #[error("{0:?}")] + MessageTooLarge(#[from] MessageTooLarge), +} + +impl TryFrom for MessageLenSanityError { + type Error = PositionOutOfBufferBounds; + + fn try_from(value: SanityError) -> Result { + use {MessageLenSanityError as T, SanityError as F}; + match value { + F::PositionOutOfMessageBounds(e) => Ok(T::PositionOutOfMessageBounds(e)), + F::MessageTooLarge(e) => Ok(T::MessageTooLarge(e)), + F::PositionOutOfBufferBounds(e) => Err(e), + } + } +} + +impl From for SanityError { + fn from(value: MessageLenSanityError) -> Self { + use {MessageLenSanityError as F, SanityError as T}; + match value { + F::PositionOutOfMessageBounds(e) => T::PositionOutOfMessageBounds(e), + F::MessageTooLarge(e) => T::MessageTooLarge(e), + } + } +} + +impl From for SanityError { + fn from(value: PositionSanityError) -> Self { + use {PositionSanityError as F, SanityError as T}; + match value { + F::PositionOutOfBufferBounds(e) => T::PositionOutOfBufferBounds(e), + F::PositionOutOfMessageBounds(e) => T::PositionOutOfMessageBounds(e), + } + } +} + +pub struct WriteToIoReturn { + pub bytes_written: usize, + pub done: bool, +} + +#[derive(Clone, Copy, Debug)] +pub struct LengthPrefixEncoder> { + buf: Buf, + header: [u8; HEADER_SIZE], + pos: usize, +} + +impl> LengthPrefixEncoder { + pub fn from_buffer(buf: Buf) -> Self { + let (header, pos) = ([0u8; HEADER_SIZE], 0); + let mut r = Self { buf, header, pos }; + r.clear(); + r + } + + pub fn from_message(msg: Buf) -> Self { + let mut r = Self::from_buffer(msg); + r.restart_write_with_new_message(r.buffer_bytes().len()) + .unwrap(); + r + } + + pub fn from_short_message(msg: Buf, len: usize) -> Result { + let mut r = Self::from_message(msg); + r.set_message_len(len)?; + Ok(r) + } + + pub fn from_parts(buf: Buf, len: usize, pos: usize) -> Result { + let mut r = Self::from_buffer(buf); + r.set_msg_len_and_position(len, pos)?; + Ok(r) + } + + pub fn into_buffer(self) -> Buf { + let Self { buf, .. } = self; + buf + } + + pub fn into_parts(self) -> (Buf, usize, usize) { + let len = self.message_len(); + let pos = self.writing_position(); + let buf = self.into_buffer(); + (buf, len, pos) + } + + pub fn clear(&mut self) { + self.set_msg_len_and_position(0, 0).unwrap(); + self.set_message_offset(0).unwrap(); + } + + pub fn write_all_to_stdio(&mut self, mut w: W) -> io::Result<()> { + use io::ErrorKind as K; + loop { + match self.write_to_stdio(&mut w).io_err_kind_hint() { + // Done + Ok(WriteToIoReturn { done: true, .. }) => break Ok(()), + + // Retry + Ok(WriteToIoReturn { done: false, .. }) => continue, + Err((_, K::Interrupted)) => continue, + + Err((e, _)) => break Err(e), + } + } + } + + pub fn write_to_stdio(&mut self, mut w: W) -> io::Result { + if self.exhausted() { + return Ok(WriteToIoReturn { + bytes_written: 0, + done: true, + }); + } + + let buf = self.next_slice_to_write(); + let bytes_written = w.write(buf)?; + self.advance(bytes_written).unwrap(); + + let done = self.exhausted(); + Ok(WriteToIoReturn { + bytes_written, + done, + }) + } + + pub fn restart_write(&mut self) { + self.set_writing_position(0).unwrap() + } + + pub fn restart_write_with_new_message( + &mut self, + len: usize, + ) -> Result<(), MessageLenSanityError> { + self.set_msg_len_and_position(len, 0) + .map_err(|e| e.try_into().unwrap()) + } + + pub fn next_slice_to_write(&self) -> &[u8] { + let s = self.header_left(); + if !s.is_empty() { + return s; + } + + let s = self.message_left(); + if !s.is_empty() { + return s; + } + + &[] + } + + pub fn exhausted(&self) -> bool { + self.next_slice_to_write().is_empty() + } + + pub fn message(&self) -> &[u8] { + &self.buffer_bytes()[..self.message_len()] + } + + pub fn header_written(&self) -> &[u8] { + &self.header()[..self.header_offset()] + } + + pub fn header_left(&self) -> &[u8] { + &self.header()[self.header_offset()..] + } + + pub fn message_written(&self) -> &[u8] { + &self.message()[..self.message_offset()] + } + + pub fn message_left(&self) -> &[u8] { + &self.message()[self.message_offset()..] + } + + pub fn buf(&self) -> &Buf { + &self.buf + } + + pub fn buffer_bytes(&self) -> &[u8] { + self.buf().borrow() + } + + pub fn decode_header(&self) -> u64 { + u64::from_le_bytes(self.header) + } + + pub fn header(&self) -> &[u8; HEADER_SIZE] { + &self.header + } + + pub fn message_len(&self) -> usize { + self.decode_header() as usize + } + + pub fn encoded_message_bytes(&self) -> usize { + self.message_len() + HEADER_SIZE + } + + pub fn writing_position(&self) -> usize { + self.pos + } + + pub fn header_offset(&self) -> usize { + min(self.writing_position(), HEADER_SIZE) + } + + pub fn message_offset(&self) -> usize { + self.writing_position().saturating_sub(HEADER_SIZE) + } + + pub fn set_header(&mut self, header: [u8; HEADER_SIZE]) -> Result<(), MessageLenSanityError> { + self.offset_transaction(|t| { + t.header = header; + t.ensure_msg_in_buf_bounds()?; + t.ensure_pos_in_msg_bounds()?; + Ok(()) + }) + } + + pub fn encode_and_set_header(&mut self, header: u64) -> Result<(), MessageLenSanityError> { + self.set_header(header.to_le_bytes()) + } + + pub fn set_message_len(&mut self, len: usize) -> Result<(), MessageLenSanityError> { + self.encode_and_set_header(len as u64) + } + + pub fn set_writing_position(&mut self, pos: usize) -> Result<(), PositionSanityError> { + self.offset_transaction(|t| { + t.pos = pos; + t.ensure_pos_in_buf_bounds()?; + t.ensure_pos_in_msg_bounds()?; + Ok(()) + }) + } + + pub fn set_header_offset(&mut self, off: usize) -> Result<(), PositionOutOfHeaderBounds> { + ensure_or(off <= HEADER_SIZE, PositionOutOfHeaderBounds)?; + self.set_writing_position(off).unwrap(); + Ok(()) + } + + pub fn set_message_offset(&mut self, off: usize) -> Result<(), PositionSanityError> { + self.set_writing_position(off + HEADER_SIZE) + } + + pub fn advance(&mut self, off: usize) -> Result<(), PositionSanityError> { + self.set_writing_position(self.writing_position() + off) + } + + pub fn set_msg_len_and_position(&mut self, len: usize, pos: usize) -> Result<(), SanityError> { + self.pos = 0; + self.set_message_len(len)?; + self.set_writing_position(pos)?; + Ok(()) + } + + fn offset_transaction(&mut self, f: F) -> Result<(), E> + where + F: FnOnce(&mut LengthPrefixEncoder<&[u8]>) -> Result<(), E>, + { + let (header, pos) = { + let (buf, header, pos) = (self.buffer_bytes(), self.header, self.pos); + let mut tmp = LengthPrefixEncoder { buf, header, pos }; + f(&mut tmp)?; + Ok((tmp.header, tmp.pos)) + }?; + (self.header, self.pos) = (header, pos); + Ok(()) + } + + fn ensure_pos_in_buf_bounds(&self) -> Result<(), PositionOutOfBufferBounds> { + ensure_or( + self.message_offset() <= self.buffer_bytes().len(), + PositionOutOfBufferBounds, + ) + } + + fn ensure_pos_in_msg_bounds(&self) -> Result<(), PositionOutOfMessageBounds> { + ensure_or( + self.message_offset() <= self.message_len(), + PositionOutOfMessageBounds, + ) + } + + fn ensure_msg_in_buf_bounds(&self) -> Result<(), MessageTooLarge> { + ensure_or( + self.message_len() <= self.buffer_bytes().len(), + MessageTooLarge, + ) + } +} + +impl> LengthPrefixEncoder { + pub fn buf_mut(&mut self) -> &mut Buf { + &mut self.buf + } + + pub fn buffer_bytes_mut(&mut self) -> &mut [u8] { + self.buf.borrow_mut() + } + + pub fn message_mut(&mut self) -> &mut [u8] { + let off = self.message_len(); + &mut self.buffer_bytes_mut()[..off] + } + + pub fn message_written_mut(&mut self) -> &mut [u8] { + let off = self.message_offset(); + &mut self.message_mut()[..off] + } + + pub fn message_left_mut(&mut self) -> &mut [u8] { + let off = self.message_offset(); + &mut self.message_mut()[off..] + } +} + +impl> Zeroize for LengthPrefixEncoder { + fn zeroize(&mut self) { + self.buffer_bytes_mut().zeroize(); + self.header.zeroize(); + self.pos.zeroize(); + self.clear(); + } +} diff --git a/util/src/length_prefix_encoding/mod.rs b/util/src/length_prefix_encoding/mod.rs new file mode 100644 index 0000000..a4d5bf5 --- /dev/null +++ b/util/src/length_prefix_encoding/mod.rs @@ -0,0 +1,2 @@ +pub mod decoder; +pub mod encoder; diff --git a/util/src/lib.rs b/util/src/lib.rs index f16e98f..c4d3f31 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -5,8 +5,12 @@ pub mod fd; pub mod file; pub mod functional; pub mod io; +pub mod length_prefix_encoding; pub mod mem; +pub mod mio; pub mod ord; pub mod result; pub mod time; pub mod typenum; +pub mod zerocopy; +pub mod zeroize; diff --git a/util/src/mio.rs b/util/src/mio.rs new file mode 100644 index 0000000..10757ee --- /dev/null +++ b/util/src/mio.rs @@ -0,0 +1,39 @@ +use mio::net::{UnixListener, UnixStream}; +use rustix::fd::RawFd; + +use crate::fd::claim_fd; + +pub mod interest { + use mio::Interest; + pub const R: Interest = Interest::READABLE; + pub const W: Interest = Interest::WRITABLE; + pub const RW: Interest = R.add(W); +} + +pub trait UnixListenerExt: Sized { + fn claim_fd(fd: RawFd) -> anyhow::Result; +} + +impl UnixListenerExt for UnixListener { + fn claim_fd(fd: RawFd) -> anyhow::Result { + use std::os::unix::net::UnixListener as StdUnixListener; + + let sock = StdUnixListener::from(claim_fd(fd)?); + sock.set_nonblocking(true)?; + Ok(UnixListener::from_std(sock)) + } +} + +pub trait UnixStreamExt: Sized { + fn claim_fd(fd: RawFd) -> anyhow::Result; +} + +impl UnixStreamExt for UnixStream { + fn claim_fd(fd: RawFd) -> anyhow::Result { + use std::os::unix::net::UnixStream as StdUnixStream; + + let sock = StdUnixStream::from(claim_fd(fd)?); + sock.set_nonblocking(true)?; + Ok(UnixStream::from_std(sock)) + } +} diff --git a/util/src/result.rs b/util/src/result.rs index 986ac6b..76b6062 100644 --- a/util/src/result.rs +++ b/util/src/result.rs @@ -96,3 +96,14 @@ impl GuaranteedValue for Guaranteed { self.unwrap() } } + +pub fn ensure_or(b: bool, err: E) -> Result<(), E> { + match b { + true => Ok(()), + false => Err(err), + } +} + +pub fn bail_if(b: bool, err: E) -> Result<(), E> { + ensure_or(!b, err) +} diff --git a/util/src/zerocopy/mod.rs b/util/src/zerocopy/mod.rs new file mode 100644 index 0000000..6856bea --- /dev/null +++ b/util/src/zerocopy/mod.rs @@ -0,0 +1,7 @@ +mod ref_maker; +mod zerocopy_ref_ext; +mod zerocopy_slice_ext; + +pub use ref_maker::*; +pub use zerocopy_ref_ext::*; +pub use zerocopy_slice_ext::*; diff --git a/util/src/zerocopy/ref_maker.rs b/util/src/zerocopy/ref_maker.rs new file mode 100644 index 0000000..13e6fe9 --- /dev/null +++ b/util/src/zerocopy/ref_maker.rs @@ -0,0 +1,106 @@ +use std::marker::PhantomData; + +use anyhow::{ensure, Context}; +use zerocopy::{ByteSlice, ByteSliceMut, Ref}; +use zeroize::Zeroize; + +use crate::zeroize::ZeroizedExt; + +#[derive(Clone, Copy, Debug)] +pub struct RefMaker { + buf: B, + _phantom_t: PhantomData, +} + +impl RefMaker { + pub fn new(buf: B) -> Self { + let _phantom_t = PhantomData; + Self { buf, _phantom_t } + } + + pub const fn target_size() -> usize { + std::mem::size_of::() + } + + pub fn into_buf(self) -> B { + self.buf + } + + pub fn buf(&self) -> &B { + &self.buf + } + + pub fn buf_mut(&mut self) -> &mut B { + &mut self.buf + } +} + +impl RefMaker { + pub fn parse(self) -> anyhow::Result> { + self.ensure_fit()?; + Ref::::new(self.buf).context("Parser error!") + } + + pub fn from_prefix_with_tail(self) -> anyhow::Result<(Self, B)> { + self.ensure_fit()?; + let (head, tail) = self.buf.split_at(Self::target_size()); + Ok((Self::new(head), tail)) + } + + pub fn split_prefix(self) -> anyhow::Result<(Self, Self)> { + self.ensure_fit()?; + let (head, tail) = self.buf.split_at(Self::target_size()); + Ok((Self::new(head), Self::new(tail))) + } + + pub fn from_prefix(self) -> anyhow::Result { + Ok(Self::from_prefix_with_tail(self)?.0) + } + + pub fn from_suffix_with_tail(self) -> anyhow::Result<(Self, B)> { + self.ensure_fit()?; + let point = self.bytes().len() - Self::target_size(); + let (head, tail) = self.buf.split_at(point); + Ok((Self::new(head), tail)) + } + + pub fn split_suffix(self) -> anyhow::Result<(Self, Self)> { + self.ensure_fit()?; + let (head, tail) = self.buf.split_at(Self::target_size()); + Ok((Self::new(head), Self::new(tail))) + } + + pub fn from_suffix(self) -> anyhow::Result { + Ok(Self::from_suffix_with_tail(self)?.0) + } + + pub fn bytes(&self) -> &[u8] { + self.buf().deref() + } + + pub fn ensure_fit(&self) -> anyhow::Result<()> { + let have = self.bytes().len(); + let need = Self::target_size(); + ensure!( + need <= have, + "Buffer is undersized at {have} bytes (need {need} bytes)!" + ); + Ok(()) + } +} + +impl RefMaker { + pub fn make_zeroized(self) -> anyhow::Result> { + self.zeroized().parse() + } + + pub fn bytes_mut(&mut self) -> &mut [u8] { + self.buf_mut().deref_mut() + } +} + +impl Zeroize for RefMaker { + fn zeroize(&mut self) { + self.bytes_mut().zeroize() + } +} diff --git a/util/src/zerocopy/zerocopy_ref_ext.rs b/util/src/zerocopy/zerocopy_ref_ext.rs new file mode 100644 index 0000000..1acc52a --- /dev/null +++ b/util/src/zerocopy/zerocopy_ref_ext.rs @@ -0,0 +1,27 @@ +use zerocopy::{ByteSlice, ByteSliceMut, Ref}; + +pub trait ZerocopyEmancipateExt { + fn emancipate(&self) -> Ref<&[u8], T>; +} + +pub trait ZerocopyEmancipateMutExt { + fn emancipate_mut(&mut self) -> Ref<&mut [u8], T>; +} + +impl ZerocopyEmancipateExt for Ref +where + B: ByteSlice, +{ + fn emancipate(&self) -> Ref<&[u8], T> { + Ref::new(self.bytes()).unwrap() + } +} + +impl ZerocopyEmancipateMutExt for Ref +where + B: ByteSliceMut, +{ + fn emancipate_mut(&mut self) -> Ref<&mut [u8], T> { + Ref::new(self.bytes_mut()).unwrap() + } +} diff --git a/util/src/zerocopy/zerocopy_slice_ext.rs b/util/src/zerocopy/zerocopy_slice_ext.rs new file mode 100644 index 0000000..eb0000a --- /dev/null +++ b/util/src/zerocopy/zerocopy_slice_ext.rs @@ -0,0 +1,39 @@ +use zerocopy::{ByteSlice, ByteSliceMut, Ref}; + +use super::RefMaker; + +pub trait ZerocopySliceExt: Sized + ByteSlice { + fn zk_ref_maker(self) -> RefMaker { + RefMaker::::new(self) + } + + fn zk_parse(self) -> anyhow::Result> { + self.zk_ref_maker().parse() + } + + fn zk_parse_prefix(self) -> anyhow::Result> { + self.zk_ref_maker().from_prefix()?.parse() + } + + fn zk_parse_suffix(self) -> anyhow::Result> { + self.zk_ref_maker().from_prefix()?.parse() + } +} + +impl ZerocopySliceExt for B {} + +pub trait ZerocopyMutSliceExt: ZerocopySliceExt + Sized + ByteSliceMut { + fn zk_zeroized(self) -> anyhow::Result> { + self.zk_ref_maker().make_zeroized() + } + + fn zk_zeroized_from_prefix(self) -> anyhow::Result> { + self.zk_ref_maker().from_prefix()?.make_zeroized() + } + + fn zk_zeroized_from_suffic(self) -> anyhow::Result> { + self.zk_ref_maker().from_prefix()?.make_zeroized() + } +} + +impl ZerocopyMutSliceExt for B {} diff --git a/util/src/zeroize/mod.rs b/util/src/zeroize/mod.rs new file mode 100644 index 0000000..f1d37be --- /dev/null +++ b/util/src/zeroize/mod.rs @@ -0,0 +1,2 @@ +mod zeroized_ext; +pub use zeroized_ext::*; diff --git a/util/src/zeroize/zeroized_ext.rs b/util/src/zeroize/zeroized_ext.rs new file mode 100644 index 0000000..c4f87d5 --- /dev/null +++ b/util/src/zeroize/zeroized_ext.rs @@ -0,0 +1,10 @@ +use zeroize::Zeroize; + +pub trait ZeroizedExt: Zeroize + Sized { + fn zeroized(mut self) -> Self { + self.zeroize(); + self + } +} + +impl ZeroizedExt for T {}