use crate::serde::T; use crate::{Error, Result}; use raad::le::R; use serde::{ self, Deserialize, de::{self, DeserializeSeed, Visitor}, }; pub struct Deserializer<'de> { r: &'de [u8], } pub fn from_bytes<'a, T: Deserialize<'a>>(x: &'a [u8]) -> Result { let mut d = Deserializer { r: x }; T::deserialize(&mut d) } impl Deserializer<'_> { pub fn leb128(&mut self) -> Result { let mut res = 0u128; let mut shift = 0; let mut b = 128; while b & 128 != 0 { b = self.r.r::()?; if shift == 18 && b > 0b11 { return Err(Error::Overflow); } res |= u128::from(b & 127) .checked_shl(shift * 7) .ok_or(Error::Overflow)?; shift += 1; } Ok(res) } pub fn sleb128(&mut self) -> Result { let mut res = 0u128; let mut shift = 0; let mut b = 128; while b & 128 != 0 { b = self.r.r::()?; if shift > 18 || (shift == 18 && !matches!(b & 127, 0..=1 | 126..=127)) { return Err(Error::Overflow); } if shift == 18 && let 0..=1 | 126..=127 = b {} res |= u128::from(b & 127) .checked_shl(shift * 7) .ok_or(Error::Overflow)?; shift += 1; } if (shift * 7 < 128) && (b & 64) != 0 { res |= (!0u128).checked_shl(shift * 7).ok_or(Error::Overflow)?; } Ok(res.cast_signed()) } fn t(&mut self) -> Result { Ok(self.r.r()?) } fn a(&self) -> Result { self.r.first().ok_or(Error::OOB).copied() } #[track_caller] fn tag(&mut self, expected: T) -> Result<()> { let t = self.t()?; if t != expected as u8 { return Err(Error::Expected { expected, found: t }); } Ok(()) } } impl<'de> serde::Deserializer<'de> for &mut Deserializer<'de> { type Error = Error; fn deserialize_any(self, visitor: V) -> Result where V: Visitor<'de>, { let tag = self.a()?; match tag { x if T::True == x || T::False == x => self.deserialize_bool(visitor), x if T::Int == x => { self.tag(T::Int)?; let v = self.sleb128()? as _; if let Ok(v64) = i64::try_from(v) { visitor.visit_i64(v64) } else { visitor.visit_i128(v) } } x if T::Uint == x => { self.tag(T::Uint)?; let v = self.leb128()?; if let Ok(v64) = u64::try_from(v) { visitor.visit_u64(v64) } else { visitor.visit_u128(v) } } x if T::Float == x => self.deserialize_f32(visitor), x if T::Double == x => self.deserialize_f64(visitor), x if T::String == x => self.deserialize_str(visitor), x if T::List == x => self.deserialize_seq(visitor), x if T::Map == x => self.deserialize_map(visitor), x if T::None == x || T::Some == x => self.deserialize_option(visitor), x if T::NVariant == x => { self.tag(T::NVariant)?; // index, followed by an any visitor.visit_map(MapAccess { de: self, len: 1 }) } x if T::SVariant == x => { self.tag(T::SVariant)?; self.tag(T::Uint)?; let _idx = self.leb128()?; let len = self.leb128()?; visitor.visit_map(MapAccess { de: self, len: len as usize, }) } x if T::UVariant == x => { self.tag(T::UVariant)?; self.tag(T::Uint)?; visitor.visit_u32(self.leb128()?.try_into()?) } x if T::TVariant == x => { self.tag(T::TVariant)?; self.tag(T::Uint)?; let _ix = self.leb128()?; let len = self.leb128()?; visitor.visit_seq(SeqAccess::new(self, len as usize)) } x => Err(Error::NotTag(x)), } } #[inline] fn is_human_readable(&self) -> bool { false } fn deserialize_bool(self, visitor: V) -> Result where V: Visitor<'de>, { match self.t()? { x if T::True == x => visitor.visit_bool(true), x if T::False == x => visitor.visit_bool(false), x => Err(Error::Expected { expected: T::True, found: x, }), } } fn deserialize_i8(self, visitor: V) -> Result where V: Visitor<'de>, { self.tag(T::Int)?; visitor.visit_i8(self.sleb128()?.try_into()?) } fn deserialize_i16(self, visitor: V) -> Result where V: Visitor<'de>, { self.tag(T::Int)?; visitor.visit_i16(self.sleb128()?.try_into()?) } fn deserialize_i32(self, visitor: V) -> Result where V: Visitor<'de>, { self.tag(T::Int)?; visitor.visit_i32(self.sleb128()?.try_into()?) } fn deserialize_i64(self, visitor: V) -> Result where V: Visitor<'de>, { self.tag(T::Int)?; visitor.visit_i64(self.sleb128()?.try_into()?) } fn deserialize_i128(self, visitor: V) -> Result where V: Visitor<'de>, { self.tag(T::Int)?; visitor.visit_i128(self.sleb128()?) } fn deserialize_u8(self, visitor: V) -> Result where V: Visitor<'de>, { self.tag(T::Uint)?; let v = self.leb128()?; visitor.visit_u8(v.try_into()?) } fn deserialize_u16(self, visitor: V) -> Result where V: Visitor<'de>, { self.tag(T::Uint)?; let v = self.leb128()?; visitor.visit_u16(v.try_into()?) } fn deserialize_u32(self, visitor: V) -> Result where V: Visitor<'de>, { self.tag(T::Uint)?; let v = self.leb128()?; visitor.visit_u32(v.try_into()?) } fn deserialize_u64(self, visitor: V) -> Result where V: Visitor<'de>, { self.tag(T::Uint)?; let v = self.leb128()?; visitor.visit_u64(v.try_into()?) } fn deserialize_u128(self, visitor: V) -> Result where V: Visitor<'de>, { self.tag(T::Uint)?; let v = self.leb128()?; visitor.visit_u128(v) } fn deserialize_f32(self, visitor: V) -> Result where V: Visitor<'de>, { self.tag(T::Float)?; visitor.visit_f32(self.r.r()?) } fn deserialize_f64(self, visitor: V) -> Result where V: Visitor<'de>, { self.tag(T::Double)?; visitor.visit_f64(self.r.r()?) } fn deserialize_char(self, visitor: V) -> Result where V: Visitor<'de>, { self.tag(T::Uint)?; let v = self.leb128()?; visitor.visit_char(char::from_u32(v.try_into()?).ok_or(Error::NotChar(v.try_into()?))?) } fn deserialize_str(self, visitor: V) -> Result where V: Visitor<'de>, { match self.a()? { x if T::Uint == x => { self.tag(T::Uint)?; let v = self.leb128()?; visitor.visit_str(&v.to_string()) } _ => { self.tag(T::String)?; let len = self.leb128()? as usize; let v = visitor .visit_borrowed_str(str::from_utf8(self.r.get(..len).ok_or(Error::OOB)?)?); self.r = self.r.get(len..).ok_or(Error::OOB)?; v } } } fn deserialize_string(self, visitor: V) -> Result where V: Visitor<'de>, { self.deserialize_str(visitor) } fn deserialize_bytes(self, visitor: V) -> Result where V: Visitor<'de>, { self.deserialize_byte_buf(visitor) } fn deserialize_byte_buf(self, visitor: V) -> Result where V: Visitor<'de>, { self.tag(T::String)?; let len = self.leb128()? as usize; let v = visitor.visit_borrowed_bytes(self.r.get(..len).ok_or(Error::OOB)?); self.r = self.r.get(len..).ok_or(Error::OOB)?; v } fn deserialize_option(self, visitor: V) -> Result where V: Visitor<'de>, { match self.t()? { x if T::Some == x => visitor.visit_some(self), x if T::None == x => visitor.visit_none(), x => Err(Error::Expected { expected: T::Some, found: x, }), } } fn deserialize_unit(self, visitor: V) -> Result where V: Visitor<'de>, { self.tag(T::None)?; visitor.visit_unit() } fn deserialize_unit_struct(self, _name: &'static str, visitor: V) -> Result where V: Visitor<'de>, { self.deserialize_unit(visitor) } fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result where V: Visitor<'de>, { visitor.visit_newtype_struct(self) } fn deserialize_seq(self, visitor: V) -> Result where V: Visitor<'de>, { self.tag(T::List)?; let len = self.leb128()? as usize; visitor.visit_seq(SeqAccess::new(self, len)) } fn deserialize_tuple(self, _len: usize, visitor: V) -> Result where V: Visitor<'de>, { self.deserialize_seq(visitor) } fn deserialize_tuple_struct( self, _name: &'static str, _len: usize, visitor: V, ) -> Result where V: Visitor<'de>, { self.deserialize_seq(visitor) } fn deserialize_map(self, visitor: V) -> Result where V: Visitor<'de>, { self.tag(T::Map)?; let len = self.leb128()? as usize; visitor.visit_map(MapAccess { de: self, len }) } fn deserialize_struct( self, _name: &'static str, _fields: &'static [&'static str], visitor: V, ) -> Result where V: Visitor<'de>, { // println!("hello"); self.deserialize_map(visitor) } fn deserialize_enum( self, _name: &'static str, _variants: &'static [&'static str], visitor: V, ) -> Result where V: Visitor<'de>, { let tag = self.t()?; // let variant_index = self.leb128()? as u32; visitor.visit_enum(EnumAccess::new(self, tag)) } fn deserialize_identifier(self, visitor: V) -> Result where V: Visitor<'de>, { match self.t()? { x if T::String == x => { let len = self.leb128()? as usize; let v = str::from_utf8(self.r.get(..len).ok_or(Error::OOB)?)?; self.r = self.r.get(len..).ok_or(Error::OOB)?; visitor.visit_borrowed_str(v) } x if T::Uint == x => visitor.visit_u32(self.leb128()?.try_into()?), x => Err(Error::Expected { expected: T::String, found: x, }), } } fn deserialize_ignored_any(self, visitor: V) -> Result where V: Visitor<'de>, { self.deserialize_any(visitor) } } struct SeqAccess<'a, 'de> { de: &'a mut Deserializer<'de>, len: usize, } impl<'a, 'de> SeqAccess<'a, 'de> { fn new(de: &'a mut Deserializer<'de>, len: usize) -> Self { SeqAccess { de, len } } } impl<'de> de::SeqAccess<'de> for SeqAccess<'_, 'de> { type Error = Error; fn next_element_seed(&mut self, seed: T) -> Result> where T: DeserializeSeed<'de>, { if self.len == 0 { Ok(None) } else { self.len -= 1; seed.deserialize(&mut *self.de).map(Some) } } } struct EnumAccess<'a, 'de> { de: &'a mut Deserializer<'de>, _tag: u8, } impl<'a, 'de> EnumAccess<'a, 'de> { fn new(de: &'a mut Deserializer<'de>, _tag: u8) -> Self { EnumAccess { de, _tag } } } impl<'de> de::EnumAccess<'de> for EnumAccess<'_, 'de> { type Error = Error; type Variant = Self; fn variant_seed(self, seed: V) -> Result<(V::Value, Self)> where V: DeserializeSeed<'de>, { let val = seed.deserialize(&mut *self.de)?; Ok((val, self)) } } impl<'de> de::VariantAccess<'de> for EnumAccess<'_, 'de> { type Error = Error; fn unit_variant(self) -> Result<()> { Ok(()) } fn newtype_variant_seed(self, seed: T) -> Result where T: DeserializeSeed<'de>, { seed.deserialize(self.de) } fn tuple_variant(self, len: usize, visitor: V) -> Result where V: Visitor<'de>, { let x = self.de.leb128()? as usize; if x != len { return Err(Error::IncoherentLen(x, len)); } visitor.visit_seq(SeqAccess::new(self.de, len)) } fn struct_variant(self, fields: &'static [&'static str], visitor: V) -> Result where V: Visitor<'de>, { let x = self.de.leb128()? as usize; if x != fields.len() { return Err(Error::IncoherentLen(x, fields.len())); } // T::SVariant data follows visitor.visit_map(MapAccess { de: self.de, len: fields.len(), }) } } struct MapAccess<'a, 'de> { de: &'a mut Deserializer<'de>, len: usize, } impl<'de> de::MapAccess<'de> for MapAccess<'_, 'de> { type Error = Error; fn next_key_seed(&mut self, seed: K) -> Result> where K: DeserializeSeed<'de>, { if self.len == 0 { Ok(None) } else { self.len -= 1; seed.deserialize(&mut *self.de).map(Some) } } fn next_value_seed(&mut self, seed: V) -> Result where V: DeserializeSeed<'de>, { seed.deserialize(&mut *self.de) } } #[cfg(test)] mod tests;