Bencode-inspired, tight, self-describing serialization format
Diffstat (limited to 'src/de.rs')
| -rw-r--r-- | src/de.rs | 545 |
1 files changed, 545 insertions, 0 deletions
diff --git a/src/de.rs b/src/de.rs new file mode 100644 index 0000000..175c147 --- /dev/null +++ b/src/de.rs @@ -0,0 +1,545 @@ +use crate::serde::T; +use crate::{Error, Result}; +use raad::le::R; +use serde::{ + self, Deserialize, + de::{self, DeserializeSeed, Visitor}, +}; + +pub struct Deserializer<'de> { + r: &'de [u8], +} +pub fn from_bytes<'a, T: Deserialize<'a>>(x: &'a [u8]) -> Result<T> { + let mut d = Deserializer { r: x }; + T::deserialize(&mut d) +} +impl<'de> Deserializer<'de> { + pub fn leb128(&mut self) -> Result<u128> { + let mut res = 0u128; + let mut shift = 0; + let mut b = 128; + while b & 128 != 0 { + b = self.r.r::<u8>()?; + res |= ((b & 127) as u128) + .checked_shl(shift * 7) + .ok_or(Error::Overflow)?; + shift += 1; + } + Ok(res) + } + pub fn sleb128(&mut self) -> Result<i128> { + let mut res = 0u128; + let mut shift = 0; + let mut b = 128; + while b & 128 != 0 { + b = self.r.r::<u8>()?; + res |= ((b & 127) as u128) + .checked_shl(shift * 7) + .ok_or(Error::Overflow)?; + shift += 1; + } + if (shift < 128) && ((b & 64) != 0) { + res |= (!0u128).checked_shl(shift * 7).ok_or(Error::Overflow)?; + } + Ok(res.cast_signed()) + } + fn t(&mut self) -> Result<u8> { + Ok(self.r.r()?) + } + + fn a(&self) -> Result<u8> { + self.r.first().ok_or(Error::OOB).copied() + } + #[track_caller] + fn tag(&mut self, expected: T) -> Result<()> { + let t = self.t()?; + if t != expected as u8 { + return Err(Error::Expected { expected, found: t }); + } + Ok(()) + } +} + +impl<'de> serde::Deserializer<'de> for &mut Deserializer<'de> { + type Error = Error; + + fn deserialize_any<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + let tag = self.a()?; + match tag { + x if T::True == x || T::False == x => self.deserialize_bool(visitor), + x if T::Int == x => { + self.tag(T::Int)?; + let v = self.sleb128()? 
as _; + if let Ok(v64) = i64::try_from(v) { + visitor.visit_i64(v64) + } else { + visitor.visit_i128(v) + } + } + x if T::Uint == x => { + self.tag(T::Uint)?; + let v = self.leb128()?; + if let Ok(v64) = u64::try_from(v) { + visitor.visit_u64(v64) + } else { + visitor.visit_u128(v) + } + } + x if T::Float == x => self.deserialize_f32(visitor), + x if T::Double == x => self.deserialize_f64(visitor), + x if T::String == x => self.deserialize_str(visitor), + x if T::List == x => self.deserialize_seq(visitor), + + x if T::Map == x => self.deserialize_map(visitor), + x if T::None == x || T::Some == x => self.deserialize_option(visitor), + x if T::NVariant == x => { + self.tag(T::NVariant)?; // index, followed by an any + visitor.visit_map(MapAccess { de: self, len: 1 }) + } + x if T::SVariant == x => { + self.tag(T::SVariant)?; + self.tag(T::Uint)?; + let _idx = self.leb128()?; + let len = self.leb128()?; + visitor.visit_map(MapAccess { + de: self, + len: len as usize, + }) + } + x if T::UVariant == x => { + self.tag(T::UVariant)?; + self.tag(T::Uint)?; + visitor.visit_u32(self.leb128()? as _) + } + x if T::TVariant == x => { + self.tag(T::TVariant)?; + + self.tag(T::Uint)?; + let _ix = self.leb128()?; + let len = self.leb128()?; + visitor.visit_seq(SeqAccess::new(self, len as usize)) + } + x => Err(Error::NotTag(x)), + } + } + #[inline] + fn is_human_readable(&self) -> bool { + false + } + fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + match self.t()? { + x if T::True == x => visitor.visit_bool(true), + x if T::False == x => visitor.visit_bool(false), + x => Err(Error::Expected { + expected: T::True, + found: x, + }), + } + } + + fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.tag(T::Int)?; + visitor.visit_i8(self.sleb128()? 
as _) + } + + fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.tag(T::Int)?; + visitor.visit_i16(self.sleb128()? as _) + } + + fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.tag(T::Int)?; + visitor.visit_i32(self.sleb128()? as _) + } + + fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.tag(T::Int)?; + visitor.visit_i64(self.sleb128()? as _) + } + fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.tag(T::Int)?; + visitor.visit_i128(self.sleb128()?) + } + + fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.tag(T::Uint)?; + let v = self.leb128()?; + visitor.visit_u8(v as _) + } + + fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.tag(T::Uint)?; + let v = self.leb128()?; + visitor.visit_u16(v as _) + } + + fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.tag(T::Uint)?; + let v = self.leb128()?; + visitor.visit_u32(v as _) + } + + fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.tag(T::Uint)?; + let v = self.leb128()?; + visitor.visit_u64(v as _) + } + + fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.tag(T::Uint)?; + let v = self.leb128()?; + visitor.visit_u128(v) + } + + fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.tag(T::Float)?; + visitor.visit_f32(self.r.r()?) + } + + fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.tag(T::Double)?; + visitor.visit_f64(self.r.r()?) 
+ } + + fn deserialize_char<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.tag(T::Uint)?; + let v = self.leb128()?; + visitor.visit_char(char::from_u32(v as _).ok_or(Error::NotChar(v as u32))?) + } + + fn deserialize_str<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + match self.a()? { + x if T::Uint == x => { + self.tag(T::Uint)?; + let v = self.leb128()?; + visitor.visit_str(&v.to_string()) + } + _ => { + self.tag(T::String)?; + let len = self.leb128()? as usize; + let v = visitor + .visit_borrowed_str(str::from_utf8(&self.r.get(..len).ok_or(Error::OOB)?)?); + self.r = self.r.get(len..).ok_or(Error::OOB)?; + v + } + } + } + + fn deserialize_string<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.deserialize_str(visitor) + } + + fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.deserialize_byte_buf(visitor) + } + + fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.tag(T::String)?; + let len = self.leb128()? as usize; + let v = visitor.visit_borrowed_bytes(&self.r.get(..len).ok_or(Error::OOB)?); + self.r = &self.r.get(len..).ok_or(Error::OOB)?; + v + } + + fn deserialize_option<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + match self.t()? 
{ + x if T::Some == x => visitor.visit_some(self), + x if T::None == x => visitor.visit_none(), + x => Err(Error::Expected { + expected: T::Some, + found: x, + }), + } + } + + fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.tag(T::None)?; + visitor.visit_unit() + } + + fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.deserialize_unit(visitor) + } + + fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.tag(T::List)?; + let len = self.leb128()? as usize; + visitor.visit_seq(SeqAccess::new(self, len)) + } + + fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.deserialize_seq(visitor) + } + + fn deserialize_tuple_struct<V>( + self, + _name: &'static str, + _len: usize, + visitor: V, + ) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.deserialize_seq(visitor) + } + + fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.tag(T::Map)?; + let len = self.leb128()? as usize; + _visitor.visit_map(MapAccess { de: self, len }) + } + + fn deserialize_struct<V>( + self, + _name: &'static str, + _fields: &'static [&'static str], + visitor: V, + ) -> Result<V::Value> + where + V: Visitor<'de>, + { + // println!("hello"); + self.deserialize_map(visitor) + } + + fn deserialize_enum<V>( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result<V::Value> + where + V: Visitor<'de>, + { + let tag = self.t()?; + + // let variant_index = self.leb128()? 
as u32; + + visitor.visit_enum(EnumAccess::new(self, tag)) + } + + fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + match self.t()? { + x if T::String == x => { + let len = self.leb128()? as usize; + let v = str::from_utf8(&self.r.get(..len).ok_or(Error::OOB)?)?; + self.r = self.r.get(len..).ok_or(Error::OOB)?; + visitor.visit_borrowed_str(&v) + } + x if T::Uint == x => visitor.visit_u32(self.leb128()? as _), + x => Err(Error::Expected { + expected: T::String, + found: x, + }), + } + } + + fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + self.deserialize_any(visitor) + } +} + +struct SeqAccess<'a, 'de> { + de: &'a mut Deserializer<'de>, + len: usize, +} + +impl<'a, 'de> SeqAccess<'a, 'de> { + fn new(de: &'a mut Deserializer<'de>, len: usize) -> Self { + SeqAccess { de, len } + } +} + +impl<'a, 'de> de::SeqAccess<'de> for SeqAccess<'a, 'de> { + type Error = Error; + + fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>> + where + T: DeserializeSeed<'de>, + { + if self.len == 0 { + Ok(None) + } else { + self.len -= 1; + seed.deserialize(&mut *self.de).map(Some) + } + } +} + +struct EnumAccess<'a, 'de> { + de: &'a mut Deserializer<'de>, + _tag: u8, +} + +impl<'a, 'de> EnumAccess<'a, 'de> { + fn new(de: &'a mut Deserializer<'de>, _tag: u8) -> Self { + EnumAccess { de, _tag } + } +} + +impl<'a, 'de> de::EnumAccess<'de> for EnumAccess<'a, 'de> { + type Error = Error; + type Variant = Self; + + fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self)> + where + V: DeserializeSeed<'de>, + { + let val = seed.deserialize(&mut *self.de)?; + Ok((val, self)) + } +} + +impl<'a, 'de> de::VariantAccess<'de> for EnumAccess<'a, 'de> { + type Error = Error; + + fn unit_variant(self) -> Result<()> { + Ok(()) + } + + fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value> + where + T: DeserializeSeed<'de>, + { + seed.deserialize(self.de) + } + + fn 
tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + let x = self.de.leb128()?; + assert_eq!(x, len as u128); + visitor.visit_seq(SeqAccess::new(self.de, len)) + } + + fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value> + where + V: Visitor<'de>, + { + let x = self.de.leb128()?; + assert_eq!(x, _fields.len() as u128); + // T::SVariant data follows + visitor.visit_map(MapAccess { + de: self.de, + len: _fields.len(), + }) + } +} + +struct MapAccess<'a, 'de> { + de: &'a mut Deserializer<'de>, + len: usize, +} +impl<'a, 'de> de::MapAccess<'de> for MapAccess<'a, 'de> { + type Error = Error; + + fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>> + where + K: DeserializeSeed<'de>, + { + if self.len == 0 { + Ok(None) + } else { + self.len -= 1; + seed.deserialize(&mut *self.de).map(Some) + } + } + + fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value> + where + V: DeserializeSeed<'de>, + { + seed.deserialize(&mut *self.de) + } +} +#[cfg(test)] +mod tests; |