diff --git a/rmp-serde/src/bytes.rs b/rmp-serde/src/bytes.rs new file mode 100644 index 0000000..0d49470 --- /dev/null +++ b/rmp-serde/src/bytes.rs @@ -0,0 +1,171 @@ +//! Hacky serializer that only allows `u8` + +use std::fmt; +use serde::ser::Impossible; +use serde::Serialize; + +pub(crate) struct OnlyBytes; +pub(crate) struct Nope; + +impl std::error::Error for Nope { +} + +impl std::fmt::Display for Nope { + fn fmt(&self, _: &mut fmt::Formatter<'_>) -> fmt::Result { + Ok(()) + } +} + +impl std::fmt::Debug for Nope { + fn fmt(&self, _: &mut fmt::Formatter<'_>) -> fmt::Result { + Ok(()) + } +} + +impl serde::ser::Error for Nope { + fn custom(_: T) -> Self { + Self + } +} + +impl serde::de::Error for Nope { + fn custom(_: T) -> Self { + Self + } +} + +impl serde::Serializer for OnlyBytes { + type Ok = u8; + type Error = Nope; + type SerializeSeq = Impossible; + type SerializeTuple = Impossible; + type SerializeTupleStruct = Impossible; + type SerializeTupleVariant = Impossible; + type SerializeMap = Impossible; + type SerializeStruct = Impossible; + type SerializeStructVariant = Impossible; + + fn serialize_u8(self, val: u8) -> Result { + Ok(val) + } + + fn serialize_bool(self, _: bool) -> Result { + Err(Nope) + } + + fn serialize_i8(self, _: i8) -> Result { + Err(Nope) + } + + fn serialize_i16(self, _: i16) -> Result { + Err(Nope) + } + + fn serialize_i32(self, _: i32) -> Result { + Err(Nope) + } + + fn serialize_i64(self, _: i64) -> Result { + Err(Nope) + } + + fn serialize_u16(self, _: u16) -> Result { + Err(Nope) + } + + fn serialize_u32(self, _: u32) -> Result { + Err(Nope) + } + + fn serialize_u64(self, _: u64) -> Result { + Err(Nope) + } + + fn serialize_f32(self, _: f32) -> Result { + Err(Nope) + } + + fn serialize_f64(self, _: f64) -> Result { + Err(Nope) + } + + fn serialize_char(self, _: char) -> Result { + Err(Nope) + } + + fn serialize_str(self, _: &str) -> Result { + Err(Nope) + } + + fn serialize_bytes(self, _: &[u8]) -> Result { + Err(Nope) + } + + 
fn serialize_none(self) -> Result { + Err(Nope) + } + + fn serialize_some(self, _: &T) -> Result where T: Serialize { + Err(Nope) + } + + fn serialize_unit(self) -> Result { + Err(Nope) + } + + fn serialize_unit_struct(self, _: &'static str) -> Result { + Err(Nope) + } + + fn serialize_unit_variant(self, _: &'static str, _: u32, _: &'static str) -> Result { + Err(Nope) + } + + fn serialize_newtype_struct(self, _: &'static str, _: &T) -> Result where T: Serialize { + Err(Nope) + } + + fn serialize_newtype_variant(self, _: &'static str, _: u32, _: &'static str, _: &T) -> Result where T: Serialize { + Err(Nope) + } + + fn serialize_seq(self, _: Option) -> Result { + Err(Nope) + } + + fn serialize_tuple(self, _: usize) -> Result { + Err(Nope) + } + + fn serialize_tuple_struct(self, _: &'static str, _: usize) -> Result { + Err(Nope) + } + + fn serialize_tuple_variant(self, _: &'static str, _: u32, _: &'static str, _: usize) -> Result { + Err(Nope) + } + + fn serialize_map(self, _: Option) -> Result { + Err(Nope) + } + + fn serialize_struct(self, _: &'static str, _: usize) -> Result { + Err(Nope) + } + + fn serialize_struct_variant(self, _: &'static str, _: u32, _: &'static str, _: usize) -> Result { + Err(Nope) + } + + fn collect_seq(self, _: I) -> Result where I: IntoIterator, ::Item: Serialize { + Err(Nope) + } + + fn collect_map(self, _: I) -> Result where K: Serialize, V: Serialize, I: IntoIterator { + Err(Nope) + } + + fn collect_str(self, _: &T) -> Result where T: fmt::Display { + Err(Nope) + } +} diff --git a/rmp-serde/src/config.rs b/rmp-serde/src/config.rs index 83f2086..c039499 100644 --- a/rmp-serde/src/config.rs +++ b/rmp-serde/src/config.rs @@ -9,6 +9,8 @@ pub trait SerializerConfig: sealed::SerializerConfig {} impl SerializerConfig for T {} pub(crate) mod sealed { + use crate::config::BytesMode; + /// This is the inner trait - the real `SerializerConfig`. 
/// /// This hack disallows external implementations and usage of `SerializerConfig` and thus @@ -20,6 +22,7 @@ pub(crate) mod sealed { /// String struct fields fn is_named(&self) -> bool; + fn bytes(&self) -> BytesMode; } } @@ -27,6 +30,30 @@ pub(crate) mod sealed { pub(crate) struct RuntimeConfig { pub(crate) is_human_readable: bool, pub(crate) is_named: bool, + pub(crate) bytes: BytesMode, +} + +/// When to encode `[u8]` as `bytes` rather than a sequence +/// of integers. Serde without `serde_bytes` has trouble +/// using `bytes`, and this is a hack to force it. It may +/// break some data types. +#[non_exhaustive] +#[derive(Debug, Copy, Clone, Default, PartialEq, Eq)] +pub enum BytesMode { + /// Use bytes only when Serde requires it + /// (typically only when `serde_bytes` is used) + #[default] + Normal, + /// Use bytes for slices, `Vec`, and a few other types that + /// use `Iterator` in Serde. + /// + /// This may break some implementations of `Deserialize`. + /// + /// This does not include fixed-length arrays. + ForceIterables, + /// Use bytes for everything that looks like a container of `u8`. + /// This breaks some implementations of `Deserialize`. + ForceAll, } impl RuntimeConfig { @@ -34,6 +61,7 @@ impl RuntimeConfig { Self { is_human_readable: other.is_human_readable(), is_named: other.is_named(), + bytes: other.bytes(), } } } @@ -48,6 +76,11 @@ impl sealed::SerializerConfig for RuntimeConfig { fn is_named(&self) -> bool { self.is_named } + + #[inline] + fn bytes(&self) -> BytesMode { + self.bytes + } } /// The default serializer/deserializer configuration. @@ -71,6 +104,11 @@ impl sealed::SerializerConfig for DefaultConfig { fn is_human_readable(&self) -> bool { false } + + #[inline(always)] + fn bytes(&self) -> BytesMode { + BytesMode::default() + } } /// Config wrapper, that overrides struct serialization by packing as a map with field names. 
@@ -104,6 +142,10 @@ where fn is_human_readable(&self) -> bool { self.0.is_human_readable() } + + fn bytes(&self) -> BytesMode { + self.0.bytes() + } } /// Config wrapper that overrides struct serlization by packing as a tuple without field @@ -132,6 +174,10 @@ where fn is_human_readable(&self) -> bool { self.0.is_human_readable() } + + fn bytes(&self) -> BytesMode { + self.0.bytes() + } } /// Config wrapper that overrides `Serializer::is_human_readable` and @@ -160,6 +206,10 @@ where fn is_human_readable(&self) -> bool { true } + + fn bytes(&self) -> BytesMode { + self.0.bytes() + } } /// Config wrapper that overrides `Serializer::is_human_readable` and @@ -188,4 +238,8 @@ where fn is_human_readable(&self) -> bool { false } + + fn bytes(&self) -> BytesMode { + self.0.bytes() + } } diff --git a/rmp-serde/src/encode.rs b/rmp-serde/src/encode.rs index 73d9380..f22f35a 100644 --- a/rmp-serde/src/encode.rs +++ b/rmp-serde/src/encode.rs @@ -1,5 +1,7 @@ //! Serialize a Rust data structure into MessagePack data. +use crate::bytes::OnlyBytes; +use crate::config::BytesMode; use std::error; use std::fmt::{self, Display}; use std::io::Write; @@ -259,6 +261,18 @@ impl Serializer { _back_compat_config: PhantomData, } } + + /// Prefer encoding sequences of `u8` as bytes, rather than + /// as a sequence of variable-size integers. + /// + /// This reduces overhead of binary data, but it may break + /// decoding of some Serde types that happen to contain `[u8]`s, + /// but don't implement Serde's `visit_bytes`. 
+ #[inline] + pub fn with_bytes(mut self, mode: BytesMode) -> Serializer { + self.config.bytes = mode; + self + } } impl UnderlyingWrite for Serializer { @@ -280,6 +294,50 @@ impl UnderlyingWrite for Serializer { } } +/// Hack to store fixed-size arrays (which serde says are tuples) +#[derive(Debug)] +#[doc(hidden)] +pub struct Tuple<'a, W, C> { + len: u32, + // can't know if all elements are u8 until the end ;( + buf: Option>, + se: &'a mut Serializer, +} + +impl<'a, W: Write + 'a, C: SerializerConfig> SerializeTuple for Tuple<'a, W, C> { + type Ok = (); + type Error = Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> { + if let Some(buf) = &mut self.buf { + if let Ok(byte) = value.serialize(OnlyBytes) { + buf.push(byte); + return Ok(()); + } else { + encode::write_array_len(&mut self.se.wr, self.len)?; + for b in buf { + b.serialize(&mut *self.se)?; + } + self.buf = None; + } + } + value.serialize(&mut *self.se) + } + + fn end(self) -> Result { + if let Some(buf) = self.buf { + if self.len < 16 && buf.iter().all(|&b| b < 128) { + encode::write_array_len(&mut self.se.wr, self.len)?; + } else { + encode::write_bin_len(&mut self.se.wr, self.len)?; + } + self.se.wr.write_all(&buf) + .map_err(ValueWriteError::InvalidDataWrite)?; + } + Ok(()) + } +} + /// Part of serde serialization API. 
#[derive(Debug)] #[doc(hidden)] @@ -503,7 +561,7 @@ where type Error = Error; type SerializeSeq = MaybeUnknownLengthCompound<'a, W, C>; - type SerializeTuple = Compound<'a, W, C>; + type SerializeTuple = Tuple<'a, W, C>; type SerializeTupleStruct = Compound<'a, W, C>; type SerializeTupleVariant = Compound<'a, W, C>; type SerializeMap = MaybeUnknownLengthCompound<'a, W, C>; @@ -584,10 +642,7 @@ where } fn serialize_bytes(self, value: &[u8]) -> Result { - encode::write_bin_len(&mut self.wr, value.len() as u32)?; - self.wr - .write_all(value) - .map_err(|err| Error::InvalidValueWrite(ValueWriteError::InvalidDataWrite(err))) + Ok(encode::write_bin(&mut self.wr, value)?) } fn serialize_none(self) -> Result<(), Self::Error> { @@ -638,11 +693,17 @@ where self.maybe_unknown_len_compound(len.map(|len| len as u32), |wr, len| encode::write_array_len(wr, len)) } - //TODO: normal compund fn serialize_tuple(self, len: usize) -> Result { - encode::write_array_len(&mut self.wr, len as u32)?; - - self.compound() + Ok(Tuple { + buf: if self.config.bytes == BytesMode::ForceAll && len > 0 { + Some(Vec::new()) + } else { + encode::write_array_len(&mut self.wr, len as u32)?; + None + }, + len: len as u32, + se: self, + }) } fn serialize_tuple_struct(self, _name: &'static str, len: usize) -> @@ -659,7 +720,8 @@ where // encode as a map from variant idx to a sequence of its attributed data, like: {idx => [v1,...,vN]} encode::write_map_len(&mut self.wr, 1)?; self.serialize_str(variant)?; - self.serialize_tuple(len) + encode::write_array_len(&mut self.wr, len as u32)?; + self.compound() } #[inline] @@ -686,6 +748,49 @@ where self.serialize_str(variant)?; self.serialize_struct(name, len) } + + fn collect_seq(self, iter: I) -> Result where I: IntoIterator, I::Item: Serialize { + let iter = iter.into_iter(); + let len = match iter.size_hint() { + (lo, Some(hi)) if lo == hi && lo <= u32::MAX as usize => Some(lo as u32), + _ => None, + }; + + const MAX_ITER_SIZE: usize = 
std::mem::size_of::<<&[u8] as IntoIterator>::IntoIter>(); + const ITEM_PTR_SIZE: usize = std::mem::size_of::<&u8>(); + + // Estimate whether the input is `&[u8]` or similar (hacky, because Rust lacks proper specialization) + let might_be_a_bytes_iter = (std::mem::size_of::() == 1 || std::mem::size_of::() == ITEM_PTR_SIZE) + // Complex types like HashSet don't support reading bytes. + // The simplest iterator is ptr+len. + && std::mem::size_of::() <= MAX_ITER_SIZE; + + let mut iter = iter.peekable(); + if might_be_a_bytes_iter && self.config.bytes != BytesMode::Normal { + if let Some(len) = len { + // The `OnlyBytes` serializer emits `Err` for everything except `u8` + if iter.peek().map_or(false, |item| item.serialize(OnlyBytes).is_ok()) { + return self.bytes_from_iter(iter, len); + } + } + } + + let mut serializer = self.serialize_seq(len.map(|len| len as usize))?; + iter.try_for_each(|item| serializer.serialize_element(&item))?; + SerializeSeq::end(serializer) + } +} + +impl Serializer { + fn bytes_from_iter(&mut self, mut iter: I, len: u32) -> Result<(), <&mut Self as serde::Serializer>::Error> where I: Iterator, I::Item: Serialize { + encode::write_bin_len(&mut self.wr, len)?; + iter.try_for_each(|item| { + self.wr.write_all(std::slice::from_ref(&item.serialize(OnlyBytes) + .map_err(|_| Error::InvalidDataModel("BytesMode"))?)) + .map_err(ValueWriteError::InvalidDataWrite)?; // `write_all`: a short or zero-length `write` would silently drop the byte + Ok(()) + }) + } } impl<'a, W: Write + 'a> serde::Serializer for &mut ExtFieldSerializer<'a, W> { diff --git a/rmp-serde/src/lib.rs b/rmp-serde/src/lib.rs index 70292f9..b561a41 100644 --- a/rmp-serde/src/lib.rs +++ b/rmp-serde/src/lib.rs @@ -66,6 +66,7 @@ pub use crate::encode::{to_vec, to_vec_named, Serializer}; pub use crate::decode::from_slice; +mod bytes; pub mod config; pub mod decode; pub mod encode; diff --git a/rmp-serde/tests/encode.rs b/rmp-serde/tests/encode.rs index 1bdfc8f..7c3a060 100644 --- 
a/rmp-serde/tests/encode.rs +++ b/rmp-serde/tests/encode.rs @@ -2,6 +2,7 @@ extern crate 
rmp_serde as rmps; use std::io::Cursor; +use rmps::config::BytesMode; use serde::Serialize; use rmp_serde::encode::{self, Error}; @@ -209,6 +210,45 @@ fn pass_tuple() { assert_eq!([0x92, 0x2a, 0xce, 0x0, 0x1, 0x88, 0x94], buf); } +#[test] +fn pass_tuple_not_bytes() { + let mut buf = [0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; + + let val = (42u32, 100500u32); + val.serialize(&mut Serializer::new(&mut &mut buf[..]).with_bytes(BytesMode::ForceAll)).ok().unwrap(); + + assert_eq!([0x92, 0x2a, 0xce, 0x0, 0x1, 0x88, 0x94], buf); +} + +#[test] +fn pass_tuple_bytes() { + let mut buf = [0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; + + let val = (1u8, 100u8, 200u8, 254u8); + val.serialize(&mut Serializer::new(&mut &mut buf[..]).with_bytes(BytesMode::ForceAll)).ok().unwrap(); + + assert_eq!([196, 4, 1, 100, 200, 254], buf); +} + +#[test] +fn pass_hash_array_bytes() { + use std::collections::HashSet; + let mut buf = [0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; + + let val = [[255u8; 3], [1u8; 3]].into_iter().collect::>(); + val.serialize(&mut Serializer::new(&mut &mut buf[..]).with_bytes(BytesMode::ForceAll)).ok().unwrap(); +} + +#[test] +fn pass_tuple_low_bytes() { + let mut buf = [0x00, 0x00, 0x00, 0x00, 0x00]; + + let val = (1u8, 2, 3, 127); + val.serialize(&mut Serializer::new(&mut &mut buf[..]).with_bytes(BytesMode::ForceAll)).ok().unwrap(); + + assert_eq!([148, 1, 2, 3, 127], buf); +} + #[test] fn pass_option_some() { let mut buf = [0x00]; diff --git a/rmp-serde/tests/round.rs b/rmp-serde/tests/round.rs index 375ad44..120c2e6 100644 --- a/rmp-serde/tests/round.rs +++ b/rmp-serde/tests/round.rs @@ -669,6 +669,8 @@ fn roundtrip_some_failures() { #[cfg(test)] #[track_caller] fn assert_roundtrips Deserialize<'a>>(val: T) { + use rmp_serde::config::BytesMode; + assert_roundtrips_config(&val, "default", |s| s, |d| d); assert_roundtrips_config(&val, ".with_struct_map()", |s| s.with_struct_map(), |d| d); assert_roundtrips_config( @@ -704,6 +706,18 @@ fn 
assert_roundtrips Deseri }, |d| d.with_human_readable(), ); + assert_roundtrips_config( + &val, + ".with_bytes(ForceIterables)", + |s| s.with_bytes(BytesMode::ForceIterables), + |d| d, + ); + assert_roundtrips_config( + &val, + ".with_bytes(ForceAll)", + |s| s.with_bytes(BytesMode::ForceAll), + |d| d, + ); } #[cfg(test)]