Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alternative serialization implementation that compactly stores bytes #1303

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
40 changes: 39 additions & 1 deletion risc0/zkvm/src/guest/env.rs
Expand Up @@ -492,14 +492,30 @@ impl<W: Write + ?Sized> Write for &mut W {
pub struct FdWriter<F: Fn(&[u8])> {
fd: u32,
hook: F,
buffered_word: Option<u32>,
}

impl<F: Fn(&[u8])> FdWriter<F> {
fn new(fd: u32, hook: F) -> Self {
FdWriter { fd, hook }
FdWriter {
fd,
hook,
buffered_word: None,
}
}

fn flush(&mut self) {
if self.buffered_word.is_some() {
let words = [self.buffered_word.unwrap()];
self.buffered_word = None;
weikengchen marked this conversation as resolved.
Show resolved Hide resolved
let buffer_bytes = bytemuck::cast_slice(&words);
unsafe { sys_write(self.fd, buffer_bytes.as_ptr(), WORD_SIZE) }
(self.hook)(buffer_bytes);
}
}

fn write_bytes(&mut self, bytes: &[u8]) {
self.flush();
unsafe { sys_write(self.fd, bytes.as_ptr(), bytes.len()) }
(self.hook)(bytes);
}
Expand All @@ -516,6 +532,22 @@ impl<F: Fn(&[u8])> Write for FdWriter<F> {
}

impl<F: Fn(&[u8])> WordWrite for FdWriter<F> {
fn start_new_buffered_word(&mut self) -> crate::serde::Result<()> {
self.flush();
self.buffered_word = Some(0u32);
Ok(())
}

fn get_buffered_word(&self) -> crate::serde::Result<u32> {
assert!(self.buffered_word.is_some());
Ok(self.buffered_word.unwrap())
}

fn set_buffered_word(&mut self, word: u32) -> crate::serde::Result<()> {
self.buffered_word = Some(word);
Ok(())
}

fn write_words(&mut self, words: &[u32]) -> crate::serde::Result<()> {
self.write_bytes(bytemuck::cast_slice(words));
Ok(())
Expand All @@ -532,6 +564,12 @@ impl<F: Fn(&[u8])> WordWrite for FdWriter<F> {
}
}

impl<F: Fn(&[u8])> Drop for FdWriter<F> {
fn drop(&mut self) {
self.flush();
weikengchen marked this conversation as resolved.
Show resolved Hide resolved
}
}

#[cfg(feature = "std")]
impl<F: Fn(&[u8])> std::io::Write for FdWriter<F> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
Expand Down
55 changes: 53 additions & 2 deletions risc0/zkvm/src/serde/deserializer.rs
Expand Up @@ -88,9 +88,56 @@ pub fn from_slice<T: DeserializeOwned, P: Pod>(slice: &[P]) -> Result<T> {
}
}

#[derive(Default)]
struct ByteHandler {
pub status: usize,
pub buffer: [u8; 3],
}

impl ByteHandler {
#[inline]
fn reset(&mut self) -> Result<()> {
if self.status == 1 {
if self.buffer[0] != 0 || self.buffer[1] != 0 || self.buffer[2] != 0 {
return Err(Error::DeserializeBadByte);
}
} else if self.status == 2 {
if self.buffer[1] != 0 || self.buffer[2] != 0 {
return Err(Error::DeserializeBadByte);
}
} else if self.status == 3 {
if self.buffer[2] != 0 {
return Err(Error::DeserializeBadByte);
}
}
Comment on lines +100 to +112
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feels like this could just be slicing the range you're checking. Or is this specific for perf reasons?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sort of for performance, though I feel that the compiler would be able to optimize in the same way. All these cases are not supposed to appear, and is used here to detect errors.

self.status = 0;
Ok(())
}

#[inline]
fn handle_byte<R: WordRead>(&mut self, reader: &mut R) -> Result<u8> {
if self.status != 0 {
let res = self.buffer[self.status - 1];
self.status = (self.status + 1) % 4;
Ok(res)
} else {
let mut val = 0u32;
reader.read_words(core::slice::from_mut(&mut val))?;
self.buffer = [
(val >> 8 & 0xff) as u8,
(val >> 16 & 0xff) as u8,
(val >> 24 & 0xff) as u8,
];
self.status = 1;
Ok((val & 0xff) as u8)
}
}
}

/// Enables deserializing from a WordRead
pub struct Deserializer<'de, R: WordRead + 'de> {
reader: R,
byte_handler: ByteHandler,
phantom: core::marker::PhantomData<&'de ()>,
}

Expand Down Expand Up @@ -193,11 +240,13 @@ impl<'de, R: WordRead + 'de> Deserializer<'de, R> {
pub fn new(reader: R) -> Self {
Deserializer {
reader,
byte_handler: ByteHandler::default(),
phantom: core::marker::PhantomData,
}
}

fn try_take_word(&mut self) -> Result<u32> {
self.byte_handler.reset()?;
let mut val = 0u32;
self.reader.read_words(core::slice::from_mut(&mut val))?;
Ok(val)
Expand Down Expand Up @@ -228,7 +277,7 @@ impl<'de, 'a, R: WordRead + 'de> serde::Deserializer<'de> for &'a mut Deserializ
where
V: Visitor<'de>,
{
let val = match self.try_take_word()? {
let val = match self.byte_handler.handle_byte(&mut self.reader)? {
0 => false,
1 => true,
_ => return Err(Error::DeserializeBadBool),
Expand Down Expand Up @@ -268,6 +317,7 @@ impl<'de, 'a, R: WordRead + 'de> serde::Deserializer<'de> for &'a mut Deserializ
where
V: Visitor<'de>,
{
self.byte_handler.reset()?;
let mut bytes = [0u8; 16];
self.reader.read_padded_bytes(&mut bytes)?;
visitor.visit_i128(i128::from_le_bytes(bytes))
Expand All @@ -277,7 +327,7 @@ impl<'de, 'a, R: WordRead + 'de> serde::Deserializer<'de> for &'a mut Deserializ
where
V: Visitor<'de>,
{
visitor.visit_u32(self.try_take_word()?)
visitor.visit_u8(self.byte_handler.handle_byte(&mut self.reader)?)
}

fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
Expand Down Expand Up @@ -305,6 +355,7 @@ impl<'de, 'a, R: WordRead + 'de> serde::Deserializer<'de> for &'a mut Deserializ
where
V: Visitor<'de>,
{
self.byte_handler.reset()?;
let mut bytes = [0u8; 16];
self.reader.read_padded_bytes(&mut bytes)?;
visitor.visit_u128(u128::from_le_bytes(bytes))
Expand Down
3 changes: 3 additions & 0 deletions risc0/zkvm/src/serde/err.rs
Expand Up @@ -22,6 +22,8 @@ pub enum Error {
Custom(String),
/// Found a bool that wasn't 0 or 1
DeserializeBadBool,
/// Found some nonzero bytes in the buffer
DeserializeBadByte,
/// Found an invalid unicode char
DeserializeBadChar,
/// Found an Option discriminant that wasn't 0 or 1
Expand All @@ -44,6 +46,7 @@ impl Display for Error {
formatter.write_str(match self {
Self::Custom(msg) => msg,
Self::DeserializeBadBool => "Found a bool that wasn't 0 or 1",
Self::DeserializeBadByte => "Found some nonzero bytes in the buffer",
Self::DeserializeBadChar => "Found an invalid unicode char",
Self::DeserializeBadOption => "Found an Option discriminant that wasn't 0 or 1",
Self::DeserializeBadUtf8 => "Tried to parse invalid utf-8",
Expand Down
91 changes: 91 additions & 0 deletions risc0/zkvm/src/serde/mod.rs
Expand Up @@ -49,6 +49,7 @@ pub use serializer::{to_vec, to_vec_with_capacity, Serializer, WordWrite};
#[cfg(test)]
mod tests {
use alloc::{collections::BTreeMap, string::String, vec, vec::Vec};
use serde::{Deserialize, Serialize};

use crate::serde::{from_slice, to_vec};

Expand Down Expand Up @@ -76,4 +77,94 @@ mod tests {
let output: (u32, u64) = from_slice(data.as_slice()).unwrap();
assert_eq!(input, output);
}

#[test]
fn test_mixed_tuple() {
type TestType = (Vec<u8>, u32, u8, u8, Vec<u8>, u8);
let input: TestType = (vec![0u8, 1], 8u32, 7u8, 222u8, vec![1, 1, 1, 1, 1], 5u8);
let data = to_vec(&input).unwrap();
assert_eq!([2, 256, 8, 56839, 5, 16843009, 1281].as_slice(), data);
let output: TestType = from_slice(&data).unwrap();
assert_eq!(input, output);
}

#[test]
fn test_mixed_struct() {
#[derive(Debug, Serialize, PartialEq, Eq, Deserialize)]
struct WrappedU8(pub u8);

#[derive(Default, Debug, PartialEq, Eq, Serialize, Deserialize)]
struct TestType {
pub wrapped_u8_seq: Vec<WrappedU8>,
pub u16_seq: Vec<u16>,
pub u32_seq: Vec<u32>,
pub u64_seq: Vec<u64>,
pub i8_seq: Vec<i8>,
pub i16_seq: Vec<i16>,
pub i32_seq: Vec<i32>,
pub i64_seq: Vec<i64>,
pub u8: u8,
pub bool: bool,
pub some_u16: Option<u16>,
pub none_u32: Option<u32>,
pub string_seq: Vec<u8>,
pub string_seq_seq: Vec<Vec<u8>>,
}

let mut input = TestType::default();
input.wrapped_u8_seq = vec![WrappedU8(1u8), WrappedU8(231u8), WrappedU8(123u8)];
input.u16_seq = vec![124u16, 41374u16];
input.u32_seq = vec![14710471u32, 3590275702u32, 1u32, 2u32];
input.u64_seq = vec![352905235952532u64, 2147102974910410u64];
input.i8_seq = vec![-1i8, 120i8, -22i8];
input.i16_seq = vec![-7932i16];
input.i32_seq = vec![-4327i32, 35207277i32];
input.i64_seq = vec![-1i64, 1i64];
input.u8 = 3u8;
input.bool = true;
input.some_u16 = Some(5u16);
input.none_u32 = None;
input.string_seq = b"Here is a string.".to_vec();
input.string_seq_seq = vec![b"string a".to_vec(), b"34720471290497230".to_vec()];

let data = to_vec(&input).unwrap();
assert_eq!(
[
3u32, 8120065, 2, 124, 41374, 4, 14710471, 3590275702, 1, 2, 2, 658142100, 82167,
1578999754, 499911, 3, 4294967295, 120, 4294967274, 1, 4294959364, 2, 4294962969,
35207277, 2, 4294967295, 4294967295, 1, 0, 259, 1, 5, 0, 17, 1701995848, 544434464,
1953701985, 1735289202, 46, 2, 8, 1769108595, 1629513582, 17, 842478643, 825701424,
875575602, 858928953, 48,
]
.as_slice(),
data
);

let output: TestType = from_slice(&data).unwrap();
assert_eq!(input, output);
}

#[test]
fn test_edge_cases() {
#[derive(PartialEq, Eq, Debug, Default, Serialize, Deserialize)]
struct U8SeqThenU32(Vec<u8>, u32);
#[derive(PartialEq, Eq, Debug, Default, Serialize, Deserialize)]
struct U32ThenU8SeqT(u32, Vec<u8>);

for i in 0..=5 {
let mut input = U8SeqThenU32::default();
input.0 = vec![127u8; i];
let data = to_vec(&input).unwrap();
let output: U8SeqThenU32 = from_slice(&data).unwrap();
assert_eq!(input, output);
}

for i in 0..=5 {
let mut input = U32ThenU8SeqT::default();
input.1 = vec![127u8; i];
let data = to_vec(&input).unwrap();
let output: U32ThenU8SeqT = from_slice(&data).unwrap();
assert_eq!(input, output);
}
}
}