Diffstat (limited to 'src/symbol.rs')
| -rw-r--r-- | src/symbol.rs | 274 |
1 files changed, 238 insertions, 36 deletions
diff --git a/src/symbol.rs b/src/symbol.rs index a3c8b4d..910d1fa 100644 --- a/src/symbol.rs +++ b/src/symbol.rs @@ -1,25 +1,70 @@ +//! Unique symbol made from human readable tags. + +// This module implements https://en.wikipedia.org/wiki/Arithmetic_coding for the +// purpose of packing a string into a u64 to use as an ID. If you want to use +// arithmetic coding consider the https://docs.rs/arcode/latest/arcode/ crate. +// This algorithm is able to compress about 10-12 characters into a u64. + use range::*; +/// Unique symbol, given a unique tag. +/// +/// ``` +/// use uniserde::symbol::Symbol; +/// +/// const PROPERTY: Symbol = Symbol::new("Property"); +/// # let _ = PROPERTY; +/// ``` +/// +/// This type can be used to create "extendable enums" where +/// each variant is given by a different Symbol. +/// +/// Internally each [`Symbol`] is just a [`u64`]. As such, a [`Symbol`] is very cheap +/// to copy and compare. #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] pub struct Symbol(u64); impl Symbol { + /// Panicking version of [`Self::try_new()`]. + /// + /// This is more useful in const contexts as you don't have to handle the possible + /// error manually. + /// + /// # Panics + /// This function will panic when [`Self::try_new()`] would return an error. + /// + /// ```compile_fail + /// use uniserde::symbol::Symbol; + /// + /// // The tag is to long so new will panic. + /// const MY_SYMBOL: Symbol = Symbol::new("some really long tag"); + /// # let _ = MY_SYMBOL; + /// ``` #[track_caller] pub const fn new(tag: &str) -> Self { match Self::try_new(tag) { Ok(s) => s, Err(EncodeError::TooComplex { .. }) => { - panic!("Symbol is too complex to encode. Try making it shorter.") + panic!("Tag is too complex to encode. Try making it shorter.") } Err(EncodeError::UnknownChar(_)) => { panic!( - "Unknown character, supported characters are: a-z, A-Z, 0-9, _, -, and space." + "Unknown character, supported characters \ + are: a-z, A-Z, 0-9, _, -, and space." ); } } } + /// Create a new symbol from a string tag. + /// + /// `tag` can be a string with the characters `a-z`, `A-Z`, `0-9`, `_`, + /// `-`, and ` ` (space). The tag can range from 0 to about 10 characters + /// in length. There is not fixed limit on the length. If the tag + /// is too complex to store in the symbol then a + /// [`EncodeError::TooComplex`] error is returned. If this happens try using + /// a shorter tag and/or remove capitals and numbers. pub const fn try_new(tag: &str) -> Result<Self, EncodeError> { match encode(tag) { Ok(tag) => Ok(Self(tag)), @@ -27,13 +72,41 @@ impl Symbol { } } + /// Convert the symbol to it's integer representation. + /// + /// This is provided to allow using a [`Symbol`] as a const generic. + /// + /// ``` + /// use uniserde::symbol::Symbol; + /// + /// struct MyType<const N: u64>; + /// + /// type OtherType = MyType<{ Symbol::new("OtherType").to_int() }>; + /// # let _: OtherType; + /// ``` pub const fn to_int(self) -> u64 { self.0 } + /// Create a symbol from a integer. + /// + /// This is to be used in combination with [`Self::to_int()`] to rebuild + /// the [`Symbol`]. + /// + /// This function cannot fail. However, if a integer not made by + /// [`Self::to_int()`] is provided then the Symbol will likely display + /// as a random string when calling [`Display::fmt()`][core::fmt::Display::fmt] + /// or [`Debug::fmt()`][core::fmt::Debug::fmt]. pub const fn from_int(value: u64) -> Self { Self(value) } + + /// Const form of [`PartialEq::eq()`]. + /// + /// Checking for equality via [`Self::to_int()`] is also possible. + pub const fn eq(self, other: Self) -> bool { + self.0 == other.0 + } } impl core::fmt::Debug for Symbol { @@ -42,6 +115,7 @@ impl core::fmt::Debug for Symbol { } } +/// Helper for printing the symbol in the Debug impl. struct DisplayFmt(u64); impl core::fmt::Debug for DisplayFmt { @@ -56,59 +130,126 @@ impl core::fmt::Display for Symbol { } } +/// Error while encoding a tag for a [`Symbol`]. #[derive(Debug)] pub enum EncodeError<'a> { - TooComplex { encoded: &'a str }, + /// The tag is too complex to fit in a [`u64`]. + /// + /// If this happens, then reduce the string to around what `encoded` has. + TooComplex { + /// The amount of the input string that was encoded properly. + encoded: &'a str, + }, + + /// An unknown character was found in the tag. + /// + /// Supported characters: `a-z`, `A-Z`, `0-9`, `_`, `-`, and ` ` (space). UnknownChar(char), } impl<'a> EncodeError<'a> { + /// Create a TooComplex error in a const ocntext. const fn too_complex(input: &'a str, length: usize) -> Self { + // If the length is past the end then we just return all of the string. if length >= input.len() { return Self::TooComplex { encoded: input }; } + // Slice the string manually. let buffer = input.as_bytes(); - let slice = unsafe { core::slice::from_raw_parts(buffer.as_ptr(), length) }; - let s = unsafe { core::str::from_utf8_unchecked(slice) }; + let (slice, _) = buffer.split_at(length); + let s = match core::str::from_utf8(slice) { + Ok(s) => s, + + // This shouldn't happen because if we would split an unicode char we + // would have has a unknown char instead. + Err(_) => panic!("Not valid UTF8"), + }; Self::TooComplex { encoded: s } } } +impl<'a> core::fmt::Display for EncodeError<'a> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + EncodeError::TooComplex { encoded } => write!( + f, + "Tag is to complex to encode. \ + Encoded this portion of the tag: \ + `{encoded}`. Try making the tag shorter." + ), + EncodeError::UnknownChar(char) => write!( + f, + "Unknown character `{char}`, \ + supported characters are: a-z, A-Z, 0-9, _, -, and space." + ), + } + } +} + +#[cfg(feature = "std")] +impl<'a> std::error::Error for EncodeError<'a> {} + +// Decode an arithmetic coded int into a string. fn decode(input: u64, f: &mut core::fmt::Formatter) -> core::fmt::Result { + // The interval and a range in that interval. let (interval, mut range) = Interval::new(); - let mut precision = 31; // Matches the interval. + + // The precision needs to match what was used for the encoding (in the interval). + let mut precision = 31; + + // The input buffer being read from. let mut input_buffer = 0; + + // Offset into the input. let mut offset = 0; + // Read a bit to fill the precision. for _ in 0..precision { input_buffer = (input_buffer << 1) | match bit(&mut precision, &mut offset, input) { Some(x) => x, None => { + // If an error happens then we print this and return. + // As decoding is only used for displaying the symbol. write!(f, "<EOF>")?; return Ok(()); } }; } + // Loop until we run out of input or hit a \0. let mut len = 0; loop { + // The symbol that covers the input. let symbol: usize; + + // The range of the symbol. let mut low_high: Range; + + // Lower and upper bound of binary search. let mut sym_idx_low_high = (0, ALPHABET.len()); // Binary search for the symbol that covers the input. loop { + // Guess half of the range. let sym_idx_mid = (sym_idx_low_high.0 + sym_idx_low_high.1) / 2; + + // Find the range the symbol covers. low_high = calculate_range(range, sym_idx_mid, len); + if low_high.low <= input_buffer && input_buffer < low_high.high { + // The range coverse the input so its the correct symbol. symbol = sym_idx_mid; break; } else if input_buffer >= low_high.high { + // The symbol covers a region that is to low. + // We move the search to above the symbol. sym_idx_low_high.0 = sym_idx_mid + 1; } else { + // The symbol covers a region that is to high. + // We move the search to below the symbol. sym_idx_low_high.1 = sym_idx_mid - 1; } } @@ -118,25 +259,39 @@ fn decode(input: u64, f: &mut core::fmt::Formatter) -> core::fmt::Result { break; } + // Update the range to be what we found for the symbol. range = low_high; + // Keep scaling the range up. while interval.in_bottom_half(range) || interval.in_upper_half(range) { if interval.in_bottom_half(range) { + // The range is in the bottom half of the interval. range = interval.scale_bottom_half(range); + + // Move the input to follow the scaling of the range. + // The lowest bit is given a new bit from the input. input_buffer = (2 * input_buffer) | match bit(&mut precision, &mut offset, input) { Some(x) => x, None => { + // If an error happens then we print this and return. + // As decoding is only used for displaying the symbol. write!(f, "<EOF>")?; return Ok(()); } }; } else if interval.in_upper_half(range) { + // The range in the top half of the interval. range = interval.scale_upper_half(range); + + // Move the input to follow the scaling of the range. + // The lowest bit is given a new bit from the input. input_buffer = (2 * (input_buffer - interval.half())) | match bit(&mut precision, &mut offset, input) { Some(x) => x, None => { + // If an error happens then we print this and return. + // As decoding is only used for displaying the symbol. write!(f, "<EOF>")?; return Ok(()); } @@ -144,39 +299,65 @@ fn decode(input: u64, f: &mut core::fmt::Formatter) -> core::fmt::Result { } } + // Keep scaling the range when it contains the half point of the interval. while interval.in_middle_half(range) { + // The range is in the middle of the interval (contains the center). + // Scale it up by a factor of 2. range = interval.scale_middle_half(range); + + // Scale the input to follow the range. + // The lowest bit is given a new bit from the input. input_buffer = (2 * (input_buffer - interval.quarter())) | match bit(&mut precision, &mut offset, input) { Some(x) => x, None => { + // If an error happens then we print this and return. + // As decoding is only used for displaying the symbol. write!(f, "<EOF>")?; return Ok(()); } }; } + // At this point we know the symbol and are ready to find the next one + // after making sure the range is large enough for the precision. + + // Write out the symbol we found. core::fmt::Display::fmt(&(ALPHABET[symbol].0 as char), f)?; + len += 1; } Ok(()) } +// Read a bit from the input int. fn bit(precision: &mut u64, offset: &mut u64, input: u64) -> Option<u64> { + // We only have 64 bits of input possible from a u64. if *offset < 64 { + // Read the bit. let bit = (input & (1u64 << *offset)) != 0; + + // Move the cursor forward to read the next bit next time. *offset += 1; + Some(bit as _) } else { + // If there is no more precision then we have no more input. if *precision == 0 { return None; } + + // Lower the precision and give back a 0 bit as we have no more real + // information. *precision -= 1; Some(0) } } +// Arithmetic code a string into a int. +// +// This is a type of compression. const fn encode(input: &str) -> Result<u64, EncodeError> { // Interval and current range in that interval. let (interval, mut range) = Interval::new(); @@ -269,24 +450,25 @@ const fn encode(input: &str) -> Result<u64, EncodeError> { Ok(output) } +// Emit a bit into the output int. const fn emit( mut output: u64, mut offset: u64, mut pending_bit_count: usize, bit: bool, ) -> Option<(u64, u64, usize)> { - // Emit the bit. + // We can't emit a bit if we ran out of room. if offset + pending_bit_count as u64 >= 64 { return None; } + + // Emit the bit. output |= (bit as u64) << offset; offset += 1; // For each middle half scale emit the inverse bit. while pending_bit_count > 0 { - if offset + pending_bit_count as u64 >= 64 { - return None; - } + // Emit a inverted bit for each pending bit. output |= (!bit as u64) << offset; offset += 1; pending_bit_count -= 1; @@ -296,6 +478,10 @@ const fn emit( Some((output, offset, pending_bit_count)) } +// Calculate the range a symbol covers. +// +// The input `range` is the current range the symbol's range will be scaled to be +// inside of. const fn calculate_range(range: Range, symbol_index: usize, len: usize) -> Range { // The symbol's range needs to be scaled down by the current width. let new_width = range.high - range.low; @@ -310,20 +496,28 @@ const fn calculate_range(range: Range, symbol_index: usize, len: usize) -> Range } } +// Get the probability range for a symbol. +// +// Probabilities are p/65536. const fn probability(symbol_index: usize, len: usize) -> Range { if symbol_index == ALPHABET.len() { - let (_, table) = ALPHABET[ALPHABET.len() - 1]; + // If this is the \0 symbol then we use whatever is left. + // Lookup the probability of the last symbol for this len. + let (_, table) = ALPHABET[ALPHABET.len() - 1]; let t = table[len].0; let f = table[len].1; + // Use the remaining range. Range { low: t + f, high: 65536, } } else { + // Lookup the symbol. let (_, table) = ALPHABET[symbol_index]; + // Lookup the probability for this length. let t = table[len].0; let f = table[len].1; @@ -334,11 +528,14 @@ const fn probability(symbol_index: usize, len: usize) -> Range { } } +// Find the index of a symbol. const fn find_symbol(symbol: u8) -> Option<usize> { if symbol == b'\0' { - return Some(ALPHABET.len()) + // We use this special value as \0 isn't in the table. + return Some(ALPHABET.len()); } + // Look through the table. let mut i = 0; while i < ALPHABET.len() { if ALPHABET[i].0 == symbol { @@ -350,12 +547,16 @@ const fn find_symbol(symbol: u8) -> Option<usize> { } mod range { + // A range in an interval. #[derive(Debug, Copy, Clone)] pub struct Range { pub low: u64, pub high: u64, } + // An interval given by a precision. + // + // The range of the arithmetic coding "bounces" around this interval. #[derive(Debug, Copy, Clone)] pub struct Interval { one_quarter: u64, @@ -423,10 +624,12 @@ mod range { range } + /// Get the half way point. pub const fn half(&self) -> u64 { self.half } + /// Get the quarter point. pub const fn quarter(&self) -> u64 { self.one_quarter } @@ -453,21 +656,21 @@ mod test { } Err(err) => panic!("{:?}", err), }; - + // Format symbol as string to decode it. let x = format!("{}", s); - + // The symbol as a string should match the original string. assert_eq!(x, str); } } #[test] - #[should_panic(expected = "Symbol is too complex to encode. Try making it shorter.")] + #[should_panic(expected = "Tag is too complex to encode. Try making it shorter.")] fn too_complex_panic() { Symbol::new("0123456789"); } - + #[test] #[should_panic( expected = "Unknown character, supported characters are: a-z, A-Z, 0-9, _, -, and space." @@ -478,14 +681,16 @@ mod test { #[test] fn in_match() { + // This module is to give a path in the match instead of just a + // possible binding. mod s { use super::*; - + pub const A: Symbol = Symbol::new("A"); pub const B: Symbol = Symbol::new("B"); pub const C: Symbol = Symbol::new("C"); } - + match Symbol::new("B") { s::A => panic!(), s::B => {} @@ -497,27 +702,16 @@ mod test { #[test] fn in_generic() { struct X<const N: u64>; - + + // A symbol can become a int to use in a const generic. type Z = X<{ Symbol::new("Hello world").to_int() }>; let _: Z; } #[test] - fn demo() { - let n = Symbol::new("Hello").to_int().rotate_left(64); - for mut x in 0..1000 { - let mut y = x; - // y |= x << 25; - - - // println!("{:0>64b}", y); - println!("`{}`", Symbol::from_int(y)); - } - todo!(); - } - - #[test] fn examples() { + // Some fixed examples. + const _: Symbol = Symbol::new("eeeeeeeeeeeeeeee"); const _: Symbol = Symbol::new("XXX"); const _: Symbol = Symbol::new("Hello world"); @@ -533,7 +727,7 @@ mod test { const _: Symbol = Symbol::new("Example "); const _: Symbol = Symbol::new("ABCDEF"); const _: Symbol = Symbol::new(""); - + const _: Symbol = Symbol::new("Protocol"); const _: Symbol = Symbol::new("Attribute"); const _: Symbol = Symbol::new("TypeName"); @@ -557,6 +751,11 @@ mod test { #[test] fn alphabet() { + // This test is special. It generates the model table from a set + // of basic probability lists. If the table in the code is wrong + // compared to the one generated then this test will print out + // the correct table to be copy pasted into the code. + let lower = HashMap::<char, f64>::from([ ('e', 12.02), ('t', 9.10), @@ -652,9 +851,10 @@ mod test { let diff = all.iter().map(|(_, x)| x).sum::<f64>() - 1.0; assert!(diff.abs() < 0.0001, "{}", diff); + // This is how likely the end of the string is given the length. let scales: [f64; 18] = [ - 0.0001, 0.001, 0.01, 0.023, 0.048, 0.074, 0.103, 0.135, 0.169, 0.207, 0.250, 0.298, 0.353, 0.419, - 0.500, 0.603, 0.750, 0.95, + 0.0001, 0.001, 0.01, 0.023, 0.048, 0.074, 0.103, 0.135, 0.169, 0.207, 0.250, 0.298, + 0.353, 0.419, 0.500, 0.603, 0.750, 0.95, ]; let mut totals = [0u64; 18]; let with_len: Vec<(u8, [(u64, u64); 18])> = all @@ -672,6 +872,7 @@ mod test { .collect(); assert_eq!(with_len.len(), 26 + 26 + 10 + 3); + // If the table in the code is wrong then print the right one and panic. if ALPHABET != with_len { eprintln!("const ALPHABET: &[(u8, [(u64, u64); 18])] = &["); for (c, table) in &with_len { @@ -689,6 +890,7 @@ mod test { } } +// This is the model used in the arithmetic coding. #[rustfmt::skip] const ALPHABET: &[(u8, [(u64, u64); 18])] = &[ (b' ', [(0,1180),(0,1178),(0,1168),(0,1153),(0,1123),(0,1092),(0,1058),(0,1020),(0,980),(0,935),(0,885),(0,828),(0,763),(0,685),(0,590),(0,468),(0,0),(0,0),]), |