diff --git a/util/src/length_prefix_encoding/encoder.rs b/util/src/length_prefix_encoding/encoder.rs index a736da2..1d3071e 100644 --- a/util/src/length_prefix_encoding/encoder.rs +++ b/util/src/length_prefix_encoding/encoder.rs @@ -584,3 +584,150 @@ impl> Zeroize for LengthPrefixEncoder { self.clear(); } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_lpe_error_conversion_upcast_valid() { + let len_error = MessageTooLarge; + let len_sanity_error: MessageLenSanityError = len_error.into(); + + let sanity_error: SanityError = len_error.into(); + assert!(matches!(sanity_error, SanityError::MessageTooLarge(_))); + let sanity_error: SanityError = len_sanity_error.into(); + assert!(matches!(sanity_error, SanityError::MessageTooLarge(_))); + + let pos_error = PositionOutOfBufferBounds; + let pos_sanity_error: PositionSanityError = pos_error.into(); + + let sanity_error: SanityError = pos_error.into(); + assert!(matches!( + sanity_error, + SanityError::PositionOutOfBufferBounds(_) + )); + + let sanity_error: SanityError = pos_sanity_error.into(); + assert!(matches!( + sanity_error, + SanityError::PositionOutOfBufferBounds(_) + )); + } + + #[test] + fn test_lpe_error_conversion_downcast_invalid() { + let pos_error = PositionOutOfBufferBounds; + let sanity_error = SanityError::PositionOutOfBufferBounds(pos_error.into()); + match MessageLenSanityError::try_from(sanity_error) { + Ok(_) => panic!("Conversion should always fail (incompatible enum variant)"), + Err(err) => assert!(matches!(err, PositionOutOfBufferBounds)), + } + } + + #[test] + fn test_write_to_stdio_cursor() { + use std::io::Cursor; + + let msg = String::from("Hello world"); + let prefixed_msg_size = msg.len() + HEADER_SIZE; + + let mut encoder = LengthPrefixEncoder::from_parts(msg.as_bytes(), msg.len(), 0).unwrap(); + assert_eq!(encoder.encoded_message_bytes(), prefixed_msg_size); + assert!(!encoder.exhausted()); + + let mut dummy_stdout = Cursor::new(vec![0; prefixed_msg_size + 1]); + + loop { + let result: WriteToIoReturn = encoder + .write_to_stdio(&mut dummy_stdout) + .expect("write failed"); + if dummy_stdout.position() as usize >= prefixed_msg_size { + // The entire message should've been written (and the encoder state reflect this) + assert!(result.done); + assert_eq!(result.bytes_written, msg.len()); + assert_eq!(encoder.header_written(), (msg.len() as u64).to_le_bytes()); + assert_eq!(encoder.message_written(), msg.as_bytes()); + break; + } + } + + let buffer_bytes = dummy_stdout.get_ref(); + match String::from_utf8(buffer_bytes.to_vec()) { + Ok(buffer_str) => assert_eq!(&buffer_str[HEADER_SIZE..prefixed_msg_size], msg), + Err(err) => println!("Error converting buffer to String: {:?}", err), + } + assert_eq!( + &dummy_stdout.get_ref()[HEADER_SIZE..prefixed_msg_size], + msg.as_bytes() + ); + } + + #[test] + fn test_write_offset_header() { + use std::io::Cursor; + + let mut msg = Vec::::new(); + msg.extend_from_slice(b"cats"); + msg.extend_from_slice(b" and dogs"); + let msg_len = msg.len(); + let prefixed_msg_size = msg_len + HEADER_SIZE; + msg.extend_from_slice(b" and other animals"); // To be discarded + + let mut encoder = LengthPrefixEncoder::from_short_message(msg.clone(), msg_len).unwrap(); + // Only the short message should have been stored (and the unused part discarded) + assert_eq!(encoder.message_mut(), b"cats and dogs"); + assert_eq!(encoder.message_written_mut(), []); + assert_eq!(encoder.message_left_mut(), b"cats and dogs"); + assert_eq!(encoder.buf_mut(), &msg); + + // Fast-forward as if the header had already been sent - only the message remains + encoder + .set_header_offset(HEADER_SIZE) + .expect("failed to move cursor"); + let mut sink = Cursor::new(vec![0; prefixed_msg_size + 1]); + encoder.write_all_to_stdio(&mut sink).expect("write failed"); + assert_eq!(&sink.get_ref()[0..msg_len], &msg[0..msg_len]); + + assert_eq!(encoder.message_mut(), b"cats and dogs"); + assert_eq!(encoder.message_written_mut(), b"cats and dogs"); + assert_eq!(encoder.message_left_mut(), []); + assert_eq!(encoder.buf_mut(), &msg); + } + + #[test] + fn test_some_assembly_required() { + let msg = String::from("hello world"); + let encoder = LengthPrefixEncoder::from_message(msg.as_bytes()); + assert!(encoder.encoded_message_bytes() > msg.len()); + assert!(!encoder.exhausted()); + + let (msg_buffer, msg_length, write_offset) = encoder.into_parts(); + assert_eq!(msg_buffer, msg.as_bytes()); + assert_eq!(write_offset, 0); + assert_eq!(msg_length, msg.len()); + } + + #[test] + fn test_restart_write_reset() { + let msg = String::from("hello world"); + let mut encoder = LengthPrefixEncoder::from_message(msg.as_bytes()); + assert_eq!(encoder.writing_position(), 0); + encoder.set_writing_position(4).unwrap(); + assert_eq!(encoder.writing_position(), 4); + encoder.restart_write(); + assert_eq!(encoder.writing_position(), 0); + } + + #[test] + fn test_zeroize_state() { + use zeroize::Zeroize; + + let mut msg = Vec::::new(); + msg.extend_from_slice(b"test"); + let mut encoder = LengthPrefixEncoder::from_message(msg.clone()); + assert_eq!(encoder.message(), msg); + encoder.zeroize(); + assert_eq!(encoder.message(), []); + } +}