Unnamed repository; edit this file 'description' to name the repository.
Add support for borsh
| -rw-r--r-- | lib/smol_str/Cargo.toml | 3 | ||||
| -rw-r--r-- | lib/smol_str/src/borsh.rs | 58 | ||||
| -rw-r--r-- | lib/smol_str/src/lib.rs | 2 | ||||
| -rw-r--r-- | lib/smol_str/tests/test.rs | 52 |
4 files changed, 114 insertions, 1 deletions
diff --git a/lib/smol_str/Cargo.toml b/lib/smol_str/Cargo.toml index c1e34e7d7b..7dd7a5f9bb 100644 --- a/lib/smol_str/Cargo.toml +++ b/lib/smol_str/Cargo.toml @@ -13,6 +13,7 @@ all-features = true [dependencies] serde = { version = "1.0", optional = true, default-features = false } +borsh = { version = "1.4.0", optional = true, default-features = false } arbitrary = { version = "1.3", optional = true } [dev-dependencies] @@ -22,4 +23,4 @@ serde = { version = "1.0", features = ["derive"] } [features] default = ["std"] -std = ["serde?/std"] +std = ["serde?/std", "borsh?/std"] diff --git a/lib/smol_str/src/borsh.rs b/lib/smol_str/src/borsh.rs new file mode 100644 index 0000000000..12580cb4f2 --- /dev/null +++ b/lib/smol_str/src/borsh.rs @@ -0,0 +1,58 @@ +use crate::{Repr, SmolStr, INLINE_CAP}; +use alloc::string::{String, ToString}; +use borsh::io::{Error, ErrorKind, Read, Write}; +use borsh::{BorshDeserialize, BorshSerialize}; +use core::intrinsics::transmute; + +impl BorshSerialize for SmolStr { + fn serialize<W: Write>(&self, writer: &mut W) -> borsh::io::Result<()> { + self.as_str().serialize(writer) + } +} + +impl BorshDeserialize for SmolStr { + #[inline] + fn deserialize_reader<R: Read>(reader: &mut R) -> borsh::io::Result<Self> { + let len = u32::deserialize_reader(reader)?; + if (len as usize) < INLINE_CAP { + let mut buf = [0u8; INLINE_CAP]; + reader.read_exact(&mut buf[..len as usize])?; + _ = core::str::from_utf8(&buf[..len as usize]).map_err(|err| { + let msg = err.to_string(); + Error::new(ErrorKind::InvalidData, msg) + })?; + Ok(SmolStr(Repr::Inline { + len: unsafe { transmute(len as u8) }, + buf, + })) + } else { + // u8::vec_from_reader always returns Some on success in current implementation + let vec = u8::vec_from_reader(len, reader)?.ok_or_else(|| { + Error::new( + ErrorKind::Other, + "u8::vec_from_reader unexpectedly returned None".to_string(), + ) + })?; + Ok(SmolStr::from(String::from_utf8(vec).map_err(|err| { + let msg = err.to_string(); + Error::new(ErrorKind::InvalidData, msg) + })?)) + } + } +} + +#[cfg(feature = "borsh/unstable__schema")] +mod schema { + use alloc::collections::BTreeMap; + use borsh::schema::{Declaration, Definition}; + use borsh::BorshSchema; + impl BorshSchema for SmolStr { + fn add_definitions_recursively(definitions: &mut BTreeMap<Declaration, Definition>) { + str::add_definitions_recursively(definitions) + } + + fn declaration() -> Declaration { + str::declaration() + } + } +} diff --git a/lib/smol_str/src/lib.rs b/lib/smol_str/src/lib.rs index 448315c338..cc8612ee45 100644 --- a/lib/smol_str/src/lib.rs +++ b/lib/smol_str/src/lib.rs @@ -795,5 +795,7 @@ impl<'a> arbitrary::Arbitrary<'a> for SmolStr { } } +#[cfg(feature = "borsh")] +mod borsh; #[cfg(feature = "serde")] mod serde; diff --git a/lib/smol_str/tests/test.rs b/lib/smol_str/tests/test.rs index 0d553caabc..22b9df2afd 100644 --- a/lib/smol_str/tests/test.rs +++ b/lib/smol_str/tests/test.rs @@ -348,3 +348,55 @@ mod test_str_ext { assert!(!result.is_heap_allocated()); } } +#[cfg(feature = "borsh")] + +mod borsh_tests { + use borsh::BorshDeserialize; + use smol_str::{SmolStr, ToSmolStr}; + use std::io::Cursor; + + #[test] + fn borsh_serialize_stack() { + let smolstr_on_stack = "aßΔCaßδc".to_smolstr(); + let mut buffer = Vec::new(); + borsh::BorshSerialize::serialize(&smolstr_on_stack, &mut buffer).unwrap(); + let mut cursor = Cursor::new(buffer); + let decoded: SmolStr = borsh::BorshDeserialize::deserialize_reader(&mut cursor).unwrap(); + assert_eq!(smolstr_on_stack, decoded); + } + #[test] + fn borsh_serialize_heap() { + let smolstr_on_heap = "aßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδc".to_smolstr(); + let mut buffer = Vec::new(); + borsh::BorshSerialize::serialize(&smolstr_on_heap, &mut buffer).unwrap(); + let mut cursor = Cursor::new(buffer); + let decoded: SmolStr = borsh::BorshDeserialize::deserialize_reader(&mut cursor).unwrap(); + assert_eq!(smolstr_on_heap, decoded); + } + #[test] + fn borsh_non_utf8_stack() { + let invalid_utf8: Vec<u8> = vec![0xF0, 0x9F, 0x8F]; // Incomplete UTF-8 sequence + + let wrong_utf8 = SmolStr::from(unsafe { String::from_utf8_unchecked(invalid_utf8) }); + let mut buffer = Vec::new(); + borsh::BorshSerialize::serialize(&wrong_utf8, &mut buffer).unwrap(); + let mut cursor = Cursor::new(buffer); + let result = SmolStr::deserialize_reader(&mut cursor); + assert!(result.is_err()); + } + + #[test] + fn borsh_non_utf8_heap() { + let invalid_utf8: Vec<u8> = vec![ + 0xC1, 0x8A, 0x5F, 0xE2, 0x3A, 0x9E, 0x3B, 0xAA, 0x01, 0x08, 0x6F, 0x2F, 0xC0, 0x32, + 0xAB, 0xE1, 0x9A, 0x2F, 0x4A, 0x3F, 0x25, 0x0D, 0x8A, 0x2A, 0x19, 0x11, 0xF0, 0x7F, + 0x0E, 0x80, + ]; + let wrong_utf8 = SmolStr::from(unsafe { String::from_utf8_unchecked(invalid_utf8) }); + let mut buffer = Vec::new(); + borsh::BorshSerialize::serialize(&wrong_utf8, &mut buffer).unwrap(); + let mut cursor = Cursor::new(buffer); + let result = SmolStr::deserialize_reader(&mut cursor); + assert!(result.is_err()); + } +} |