This commit is contained in:
Karolin Varner
2025-11-01 20:49:47 +01:00
parent 0c960d57bc
commit 63511465de
8 changed files with 47 additions and 28 deletions

View File

@@ -1,6 +1,6 @@
use hex_literal::hex;
use rosenpass_util::zerocopy::RefMaker;
use zerocopy::{SplitByteSlice};
use zerocopy::SplitByteSlice;
use crate::RosenpassError::{self, InvalidApiMessageType};

View File

@@ -1,5 +1,5 @@
use rosenpass_util::zerocopy::ZerocopyMutSliceExt;
use zerocopy::{SplitByteSliceMut, FromBytes, Immutable, IntoBytes, KnownLayout, Ref};
use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Ref, SplitByteSliceMut};
use super::{Message, RawMsgType, RequestMsgType, ResponseMsgType};

View File

@@ -125,7 +125,8 @@ impl<B: SplitByteSlice> RequestRefMaker<B> {
self.ensure_fit()?;
let point = self.target_size();
let Self { buf, msg_type } = self;
let (buf, _) = buf.split_at(point)
let (buf, _) = buf
.split_at(point)
.map_err(|_| anyhow!("Failed to split buffer"))?;
Ok(Self { buf, msg_type })
}
@@ -135,7 +136,8 @@ impl<B: SplitByteSlice> RequestRefMaker<B> {
self.ensure_fit()?;
let point = self.buf.len() - self.target_size();
let Self { buf, msg_type } = self;
let (buf, _) = buf.split_at(point)
let (buf, _) = buf
.split_at(point)
.map_err(|_| anyhow!("Failed to split buffer"))?;
Ok(Self { buf, msg_type })
}

View File

@@ -129,7 +129,8 @@ impl<B: SplitByteSlice> ResponseRefMaker<B> {
self.ensure_fit()?;
let point = self.target_size();
let Self { buf, msg_type } = self;
let (buf, _) = buf.split_at(point)
let (buf, _) = buf
.split_at(point)
.map_err(|_| anyhow!("Failed to split buffer!"))?;
Ok(Self { buf, msg_type })
}
@@ -139,7 +140,8 @@ impl<B: SplitByteSlice> ResponseRefMaker<B> {
self.ensure_fit()?;
let point = self.buf.len() - self.target_size();
let Self { buf, msg_type } = self;
let (buf, _) = buf.split_at(point)
let (buf, _) = buf
.split_at(point)
.map_err(|_| anyhow!("Failed to split buffer!"))?;
Ok(Self { buf, msg_type })
}

View File

@@ -508,7 +508,10 @@ impl KnownResponseHasher {
/// # Panic & Safety
///
/// Panics in case of a problem with this underlying hash function
pub fn hash<Msg: IntoBytes + FromBytes + Immutable>(&self, msg: &Envelope<Msg>) -> KnownResponseHash {
pub fn hash<Msg: IntoBytes + FromBytes + Immutable>(
&self,
msg: &Envelope<Msg>,
) -> KnownResponseHash {
let data = &msg.as_bytes()[span_of!(Envelope<Msg>, msg_type..cookie)];
// This function is only used internally and results are not propagated
// to outside the peer. Thus, it uses SHAKE256 exclusively.
@@ -2188,8 +2191,9 @@ impl CryptoServer {
let peer = match msg_type {
Ok(MsgType::InitHello) => {
let msg_in: Ref<&[u8], Envelope<InitHello>> =
Ref::from_bytes(rx_buf).ok().ok_or(RosenpassError::BufferSizeMismatch)?;
let msg_in: Ref<&[u8], Envelope<InitHello>> = Ref::from_bytes(rx_buf)
.ok()
.ok_or(RosenpassError::BufferSizeMismatch)?;
// At this point, we do not know the hash functon used by the peer, thus we try both,
// with a preference for SHAKE256.
@@ -2222,8 +2226,9 @@ impl CryptoServer {
peer
}
Ok(MsgType::RespHello) => {
let msg_in: Ref<&[u8], Envelope<RespHello>> =
Ref::from_bytes(rx_buf).ok().ok_or(RosenpassError::BufferSizeMismatch)?;
let msg_in: Ref<&[u8], Envelope<RespHello>> = Ref::from_bytes(rx_buf)
.ok()
.ok_or(RosenpassError::BufferSizeMismatch)?;
let mut msg_out = truncating_cast_into::<Envelope<InitConf>>(tx_buf)?;
let peer = self.handle_resp_hello(&msg_in.payload, &mut msg_out.payload)?;
@@ -2239,8 +2244,9 @@ impl CryptoServer {
peer
}
Ok(MsgType::InitConf) => {
let msg_in: Ref<&[u8], Envelope<InitConf>> =
Ref::from_bytes(rx_buf).ok().ok_or(RosenpassError::BufferSizeMismatch)?;
let msg_in: Ref<&[u8], Envelope<InitConf>> = Ref::from_bytes(rx_buf)
.ok()
.ok_or(RosenpassError::BufferSizeMismatch)?;
let mut msg_out = truncating_cast_into::<Envelope<EmptyData>>(tx_buf)?;
@@ -2307,14 +2313,16 @@ impl CryptoServer {
peer
}
Ok(MsgType::EmptyData) => {
let msg_in: Ref<&[u8], Envelope<EmptyData>> =
Ref::from_bytes(rx_buf).ok().ok_or(RosenpassError::BufferSizeMismatch)?;
let msg_in: Ref<&[u8], Envelope<EmptyData>> = Ref::from_bytes(rx_buf)
.ok()
.ok_or(RosenpassError::BufferSizeMismatch)?;
self.handle_resp_conf(&msg_in, seal_broken.to_string())?
}
Ok(MsgType::CookieReply) => {
let msg_in: Ref<&[u8], CookieReply> =
Ref::from_bytes(rx_buf).ok().ok_or(RosenpassError::BufferSizeMismatch)?;
let msg_in: Ref<&[u8], CookieReply> = Ref::from_bytes(rx_buf)
.ok()
.ok_or(RosenpassError::BufferSizeMismatch)?;
let peer = self.handle_cookie_reply(&msg_in)?;
len = 0;
peer

View File

@@ -10,12 +10,16 @@ use crate::RosenpassError;
pub fn truncating_cast_into<T: FromBytes + KnownLayout + Immutable>(
buf: &mut [u8],
) -> Result<Ref<&mut [u8], T>, RosenpassError> {
Ref::from_bytes(&mut buf[..size_of::<T>()]).ok().ok_or(RosenpassError::BufferSizeMismatch)
Ref::from_bytes(&mut buf[..size_of::<T>()])
.ok()
.ok_or(RosenpassError::BufferSizeMismatch)
}
/// Used to parse a network message using [zerocopy], mutably
pub fn truncating_cast_into_nomut<T: FromBytes + KnownLayout + Immutable>(
buf: &[u8],
) -> Result<Ref<&[u8], T>, RosenpassError> {
Ref::from_bytes(&buf[..size_of::<T>()]).ok().ok_or(RosenpassError::BufferSizeMismatch)
Ref::from_bytes(&buf[..size_of::<T>()])
.ok()
.ok_or(RosenpassError::BufferSizeMismatch)
}

View File

@@ -203,7 +203,8 @@ where
let mut req = [0u8; BUF_SIZE];
// Construct message view
let mut req = zerocopy::Ref::<&mut [u8], Envelope<msgs::SetPskRequest>>::from_bytes(&mut req)
let mut req =
zerocopy::Ref::<&mut [u8], Envelope<msgs::SetPskRequest>>::from_bytes(&mut req)
.ok()
.ok_or(MsgError)?;

View File

@@ -79,10 +79,12 @@ where
let typ = msgs::MsgType::try_from(*typ)?;
let msgs::MsgType::SetPsk = typ; // Assert type
let req =
zerocopy::Ref::<&[u8], Envelope<SetPskRequest>>::from_bytes(req).ok().ok_or(InvalidMessage)?;
let mut res =
zerocopy::Ref::<&mut [u8], Envelope<SetPskResponse>>::from_bytes(res).ok().ok_or(InvalidMessage)?;
let req = zerocopy::Ref::<&[u8], Envelope<SetPskRequest>>::from_bytes(req)
.ok()
.ok_or(InvalidMessage)?;
let mut res = zerocopy::Ref::<&mut [u8], Envelope<SetPskResponse>>::from_bytes(res)
.ok()
.ok_or(InvalidMessage)?;
res.msg_type = msgs::MsgType::SetPsk as u8;
self.handle_set_psk(&req.payload, &mut res.payload)?;