From c159f7445c038ef317f47c769a45d2b18e365972 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Thu, 22 Feb 2024 14:11:11 -0800 Subject: [PATCH 01/45] Bitcode rewrite. --- .cargo/config.toml | 2 +- .github/workflows/build.yml | 45 - .gitignore | 8 +- Cargo.toml | 36 +- README.md | 148 +-- bitcode_derive/Cargo.toml | 6 +- bitcode_derive/src/attribute.rs | 434 +------- bitcode_derive/src/decode.rs | 471 +++++++-- bitcode_derive/src/derive.rs | 294 ------ bitcode_derive/src/encode.rs | 453 +++++++-- bitcode_derive/src/huffman.rs | 68 -- bitcode_derive/src/lib.rs | 22 +- bitcode_derive/src/shared.rs | 70 ++ fuzz/.gitignore | 2 - fuzz/Cargo.toml | 9 +- fuzz/fuzz_targets/fuzz.rs | 116 ++- src/__private.rs | 159 --- src/benches.rs | 562 ++++------- src/benches_borrowed.rs | 130 +++ src/bit_buffer.rs | 228 ----- src/bool.rs | 79 ++ src/buffer.rs | 113 --- src/code.rs | 512 ---------- src/code_impls.rs | 1039 -------------------- src/coder.rs | 106 ++ src/consume.rs | 51 + src/derive/array.rs | 72 ++ src/derive/empty.rs | 27 + src/derive/impls.rs | 300 ++++++ src/derive/map.rs | 119 +++ src/derive/mod.rs | 219 +++++ src/derive/option.rs | 129 +++ src/derive/smart_ptr.rs | 76 ++ src/derive/variant.rs | 139 +++ src/derive/vec.rs | 394 ++++++++ src/encoding/bit_string/ascii.rs | 29 - src/encoding/bit_string/ascii_lowercase.rs | 31 - src/encoding/bit_string/bit_utils.rs | 35 - src/encoding/bit_string/mod.rs | 191 ---- src/encoding/expect_normalized_float.rs | 168 ---- src/encoding/expected_range_u64.rs | 199 ---- src/encoding/gamma.rs | 149 --- src/encoding/mod.rs | 124 --- src/encoding/prelude.rs | 120 --- src/error.rs | 47 + src/ext/arrayvec.rs | 220 +++++ src/ext/glam.rs | 42 + src/ext/mod.rs | 69 ++ src/f32.rs | 149 +++ src/fast.rs | 522 ++++++++++ src/guard.rs | 47 - src/histogram.rs | 95 ++ src/int.rs | 143 +++ src/length.rs | 268 +++++ src/lib.rs | 332 ++----- src/nightly.rs | 62 +- src/pack.rs | 714 ++++++++++++++ src/pack_ints.rs | 461 +++++++++ src/read.rs | 183 ---- 
src/register_buffer.rs | 157 --- src/serde/de.rs | 573 ++++++++--- src/serde/guard.rs | 22 + src/serde/mod.rs | 87 +- src/serde/ser.rs | 493 ++++++++-- src/serde/variant.rs | 71 ++ src/str.rs | 216 ++++ src/tests.rs | 560 ----------- src/u8_char.rs | 43 + src/word.rs | 5 - src/word_buffer.rs | 595 ----------- src/write.rs | 186 ---- 71 files changed, 7042 insertions(+), 7004 deletions(-) delete mode 100644 .github/workflows/build.yml delete mode 100644 bitcode_derive/src/derive.rs delete mode 100644 bitcode_derive/src/huffman.rs create mode 100644 bitcode_derive/src/shared.rs delete mode 100644 src/__private.rs create mode 100644 src/benches_borrowed.rs delete mode 100644 src/bit_buffer.rs create mode 100644 src/bool.rs delete mode 100644 src/buffer.rs delete mode 100644 src/code.rs delete mode 100644 src/code_impls.rs create mode 100644 src/coder.rs create mode 100644 src/consume.rs create mode 100644 src/derive/array.rs create mode 100644 src/derive/empty.rs create mode 100644 src/derive/impls.rs create mode 100644 src/derive/map.rs create mode 100644 src/derive/mod.rs create mode 100644 src/derive/option.rs create mode 100644 src/derive/smart_ptr.rs create mode 100644 src/derive/variant.rs create mode 100644 src/derive/vec.rs delete mode 100644 src/encoding/bit_string/ascii.rs delete mode 100644 src/encoding/bit_string/ascii_lowercase.rs delete mode 100644 src/encoding/bit_string/bit_utils.rs delete mode 100644 src/encoding/bit_string/mod.rs delete mode 100644 src/encoding/expect_normalized_float.rs delete mode 100644 src/encoding/expected_range_u64.rs delete mode 100644 src/encoding/gamma.rs delete mode 100644 src/encoding/mod.rs delete mode 100644 src/encoding/prelude.rs create mode 100644 src/error.rs create mode 100644 src/ext/arrayvec.rs create mode 100644 src/ext/glam.rs create mode 100644 src/ext/mod.rs create mode 100644 src/f32.rs create mode 100644 src/fast.rs delete mode 100644 src/guard.rs create mode 100644 src/histogram.rs create mode 100644 
src/int.rs create mode 100644 src/length.rs create mode 100644 src/pack.rs create mode 100644 src/pack_ints.rs delete mode 100644 src/read.rs delete mode 100644 src/register_buffer.rs create mode 100644 src/serde/guard.rs create mode 100644 src/serde/variant.rs create mode 100644 src/str.rs delete mode 100644 src/tests.rs create mode 100644 src/u8_char.rs delete mode 100644 src/word.rs delete mode 100644 src/word_buffer.rs delete mode 100644 src/write.rs diff --git a/.cargo/config.toml b/.cargo/config.toml index ddff440..d5135e9 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,2 +1,2 @@ [build] -rustflags = ["-C", "target-cpu=native"] +rustflags = ["-C", "target-cpu=native"] \ No newline at end of file diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml deleted file mode 100644 index aefb774..0000000 --- a/.github/workflows/build.yml +++ /dev/null @@ -1,45 +0,0 @@ -name: Build - -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] - -env: - CARGO_TERM_COLOR: always - -jobs: - build: - - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v2 - - uses: actions-rs/toolchain@v1 - with: - # Nightly toolchain must ship the `rust-std` component for - # `i686-unknown-linux-gnu` and `mips64-unknown-linux-gnuabi64`. 
- # In practice, `rust-std` almost always ships for - # `i686-unknown-linux-gnu` so we just need to check this page for a - # compatible nightly: - # https://rust-lang.github.io/rustup-components-history/mips64-unknown-linux-gnuabi64.html - toolchain: nightly-2023-07-04 - override: true - components: rustfmt, miri - - name: Lint - run: cargo fmt --check - - name: Test (debug) - run: cargo test - - name: Install i686 and GCC multilib - run: rustup target add i686-unknown-linux-gnu && sudo apt update && sudo apt install -y gcc-multilib - - name: Test (32-bit) - run: cargo test --target i686-unknown-linux-gnu - - name: Setup Miri - run: cargo miri setup - - name: Test (miri) - run: MIRIFLAGS="-Zmiri-permissive-provenance" cargo miri test - - name: Setup Miri (big-endian) - run: rustup target add mips64-unknown-linux-gnuabi64 && cargo miri setup --target mips64-unknown-linux-gnuabi64 - - name: Test (miri big-endian) - run: MIRIFLAGS="-Zmiri-permissive-provenance" cargo miri test --target mips64-unknown-linux-gnuabi64 \ No newline at end of file diff --git a/.gitignore b/.gitignore index 4b3ff11..2de8d77 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,4 @@ -/target -/Cargo.lock -/bitcode_derive/Cargo.lock +target/ +Cargo.lock +perf.* .idea -perf.data -perf.data.old diff --git a/Cargo.toml b/Cargo.toml index 3dcf281..0a9639d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,40 +6,36 @@ members = [ [package] name = "bitcode" authors = [ "Cai Bear", "Finn Bear" ] -version = "0.5.1" +version = "0.6.0" edition = "2021" license = "MIT OR Apache-2.0" repository = "https://github.com/SoftbearStudios/bitcode" description = "bitcode is a bitwise binary serializer" exclude = ["fuzz/"] +publish = false # TODO remove when ready (also remove in bitcode_derive). 
[dependencies] -bitcode_derive = { version = "0.5.0", path="./bitcode_derive", optional = true } -bytemuck = { version = "1.13", features = [ "extern_crate_alloc" ] } -from_bytes_or_zeroed = "0.1" -residua-zigzag = "0.1.0" +arrayvec = { version = "0.7", default-features = false, optional = true } +bitcode_derive = { version = "0.6.0", path = "./bitcode_derive", optional = true } +bytemuck = { version = "1.14", features = [ "min_const_generics", "must_cast" ] } +glam = { version = "0.22", default-features = false, features = [ "std" ], optional = true } serde = { version = "1.0", optional = true } -simdutf8 = { version = "0.1.4", optional = true } [dev-dependencies] -arrayvec = { version = "0.7.2", features = [ "serde" ] } +arrayvec = { version = "0.7", features = ["serde"] } bincode = "1.3.3" -bitvec = { version = "1.0.1" } -flate2 = "1.0.25" +flate2 = "1.0.28" lz4_flex = "0.10.0" -musli = "0.0.42" -paste = "1.0.12" -postcard = { version = "1.0", features = ["alloc"] } -rand = { version = "0.8.5", default-features = false } +paste = "1.0.14" +rand = "0.8.5" rand_chacha = "0.3.1" -serde = { version = "1.0.159", features = [ "derive" ] } +serde = { version = "1.0", features = [ "derive" ] } +zstd = "0.13.0" [features] derive = [ "bitcode_derive" ] -default = [ "derive", "simdutf8" ] +default = [ "derive" ] -[package.metadata.docs.rs] -features = ["serde"] - -[profile.bench] -lto = true +# TODO halfs speed of benches_borrowed::bench_splitcode_decode +#[profile.bench] +#lto = true diff --git a/README.md b/README.md index e84985b..9197745 100644 --- a/README.md +++ b/README.md @@ -2,111 +2,63 @@ [![Documentation](https://docs.rs/bitcode/badge.svg)](https://docs.rs/bitcode) [![crates.io](https://img.shields.io/crates/v/bitcode.svg)](https://crates.io/crates/bitcode) [![Build](https://github.com/SoftbearStudios/bitcode/actions/workflows/build.yml/badge.svg)](https://github.com/SoftbearStudios/bitcode/actions/workflows/build.yml) -[![unsafe 
forbidden](https://img.shields.io/badge/unsafe-forbidden-success.svg)](https://github.com/rust-secure-code/safety-dance/) -A bitwise encoder/decoder similar to [bincode](https://github.com/bincode-org/bincode), which attempts to shrink the serialized size without sacrificing speed (as would be the case with compression). - -The format may change between major versions, so we are free to optimize it. - -## Comparison with [bincode](https://github.com/bincode-org/bincode) - -### Features - -- Bitwise serialization -- [Gamma](https://en.wikipedia.org/wiki/Elias_gamma_coding) encoded lengths and enum variant indices - -### Additional features with `#[derive(bitcode::Encode, bitcode::Decode)]` - -- Enums use the fewest possible bits, e.g. an enum with 4 variants uses 2 bits -- Apply attributes to fields/enum variants: - -| Attribute | Type | Result | -|-----------------------------------------------|---------------|------------------------------------------------------------------------------------------------------------| -| `#[bitcode_hint(ascii)]` | String | Uses 7 bits per character | -| `#[bitcode_hint(ascii_lowercase)]` | String | Uses 5 bits per character | -| `#[bitcode_hint(expected_range = "50..100"]` | u8-u64 | Uses log2(range.end - range.start) bits | -| `#[bitcode_hint(expected_range = "0.0..1.0"]` | f32/f64 | Uses ~25 bits for `f32` and ~54 bits for `f64` | -| `#[bitcode_hint(frequency = 123)` | enum variant | Frequent variants use fewer bits (see [Huffman coding](https://en.wikipedia.org/wiki/Huffman_coding)) | -| `#[bitcode_hint(gamma)]` | i8-i64/u8-u64 | Small integers use fewer bits (see [Elias gamma coding](https://en.wikipedia.org/wiki/Elias_gamma_coding)) | -| `#[bitcode(with_serde)]` | T: Serialize | Uses `serde::Serialize` instead of `bitcode::Encode` | - -### Limitations - -- Doesn't support streaming APIs -- Format may change between major versions -- With `feature = "derive"`, types containing themselves must use `#[bitcode(recursive)]` to 
compile - -## Benchmarks vs. [bincode](https://github.com/bincode-org/bincode) and [postcard](https://github.com/jamesmunns/postcard) - -### Primitives (size in bits) - -| Type | Bitcode (derive) | Bitcode (serde) | Bincode | Bincode (varint) | Postcard | -|---------------------|------------------|-----------------|---------|------------------|----------| -| bool | 1 | 1 | 8 | 8 | 8 | -| u8/i8 | 8 | 8 | 8 | 8 | 8 | -| u16/i16 | 16 | 16 | 16 | 8-24 | 8-24 | -| u32/i32 | 32 | 32 | 32 | 8-40 | 8-40 | -| u64/i64 | 64 | 64 | 64 | 8-72 | 8-80 | -| u128/i128 | 128 | 128 | 128 | 8-136 | 8-152 | -| usize/isize | 64 | 64 | 64 | 8-72 | 8-80 | -| f32 | 32 | 32 | 32 | 32 | 32 | -| f64 | 64 | 64 | 64 | 64 | 64 | -| char | 21 | 21 | 8-32 | 8-32 | 16-40 | -| Option<()> | 1 | 1 | 8 | 8 | 8 | -| Result<(), ()> | 1 | 1-3 | 32 | 8 | 8 | -| enum { A, B, C, D } | 2 | 1-5 | 32 | 8 | 8 | -| Duration | 94 | 96 | 96 | 16-112 | 16-120 | - -Note: These are defaults, and can be optimized with hints in the case of Bitcode (derive) or custom `impl Serialize` in the case of `serde` serializers. 
- -### Values (size in bits) - -| Value | Bitcode (derive) | Bitcode (serde) | Bincode | Bincode (varint) | Postcard | -|---------------------|------------------|-----------------|---------|------------------|----------| -| [true; 4] | 4 | 4 | 32 | 32 | 32 | -| vec![(); 0] | 1 | 1 | 64 | 8 | 8 | -| vec![(); 1] | 3 | 3 | 64 | 8 | 8 | -| vec![(); 256] | 17 | 17 | 64 | 24 | 16 | -| vec![(); 65536] | 33 | 33 | 64 | 40 | 24 | -| "" | 1 | 1 | 64 | 8 | 8 | -| "abcd" | 37 | 37 | 96 | 40 | 40 | -| "abcd1234" | 71 | 71 | 128 | 72 | 72 | - - -### Random [Structs and Enums](https://github.com/SoftbearStudios/bitcode/blob/2a47235eee64f4a7c49ad1841a5b509abd2d0e99/src/benches.rs#L16-L88) (average size and speed) - -| Format | Size (bytes) | Serialize (ns) | Deserialize (ns) | -|------------------------|--------------|----------------|------------------| -| Bitcode (derive) | 6.2 | 14 | 50 | -| Bitcode (serde) | 6.7 | 18 | 59 | -| Bincode | 20.3 | 17 | 61 | -| Bincode (varint) | 10.9 | 26 | 68 | -| Bincode (LZ4) | 9.9 | 58 | 73 | -| Bincode (Deflate Fast) | 8.4 | 336 | 279 | -| Bincode (Deflate Best) | 7.8 | 1990 | 275 | -| Postcard | 10.7 | 21 | 57 | - -### More benchmarks - -[rust_serialization_benchmark](https://david.kolo.ski/rust_serialization_benchmark/) - -## Acknowledgement - -Some test cases were derived from [bincode](https://github.com/bincode-org/bincode) (see comment in `tests.rs`). +A binary encoder/decoder with the following goals: +- 🔥 Blazingly fast +- 🐁 Tiny serialized size +- 💎 Highly compressible by Deflate/LZ4/Zstd + +In contrast, these are non-goals: +- Stable format across major versions +- Self describing format +- Compatibility with languages other than Rust + +See [rust_serialization_benchmark](https://github.com/djkoloski/rust_serialization_benchmark) for benchmarks. 
+ +## Example +```rust +use bitcode::{Encode, Decode}; + +#[derive(Encode, Decode, PartialEq, Debug)] +struct Foo<'a> { + x: u32, + y: &'a str, +} + +fn main() { + let original = Foo { + x: 10, + y: "abc", + }; + + let encoded: Vec = bitcode::encode(&original); // No error + let decoded: Foo<'_> = bitcode::decode(&encoded).unwrap(); + assert_eq!(original, decoded); +} +``` + +## Tuple vs Array +If you have multiple values of the same type: +- Use a tuple or struct when the values are semantically different: `x: u32, y: u32` +- Use an array when all values are semantically similar: `pixels: [u8; 16]` + +## Implementation Details +- Heavily inspired by +- All instances of each field are grouped together making compression easier +- Uses smaller integers where possible all the way down to 1 bit +- Validation is performed up front on typed vectors before deserialization +- Code is designed to be auto-vectorized by LLVM ## License - Licensed under either of - - * Apache License, Version 2.0 - ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) - * MIT license - ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) +* Apache License, Version 2.0 + ([LICENSE-APACHE](LICENSE-APACHE) or ) +* MIT license + ([LICENSE-MIT](LICENSE-MIT) or ) at your option. ## Contribution - Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be -dual licensed as above, without any additional terms or conditions. \ No newline at end of file +dual licensed as above, without any additional terms or conditions. 
diff --git a/bitcode_derive/Cargo.toml b/bitcode_derive/Cargo.toml index 7527c8c..3bb0d19 100644 --- a/bitcode_derive/Cargo.toml +++ b/bitcode_derive/Cargo.toml @@ -1,17 +1,17 @@ [package] name = "bitcode_derive" authors = [ "Cai Bear", "Finn Bear" ] -version = "0.5.0" +version = "0.6.0" edition = "2021" license = "MIT OR Apache-2.0" repository = "https://github.com/SoftbearStudios/bitcode/" description = "Implementation of #[derive(Encode, Decode)] for bitcode" +publish = false # TODO remove when ready [lib] proc-macro = true [dependencies] -packagemerge = "0.1" proc-macro2 = "1.0" quote = "1.0" -syn = { version = "2.0.3", features = [ "extra-traits" ] } +syn = { version = "2.0.3", features = [ "extra-traits", "visit-mut" ] } diff --git a/bitcode_derive/src/attribute.rs b/bitcode_derive/src/attribute.rs index 38809cc..a9e944d 100644 --- a/bitcode_derive/src/attribute.rs +++ b/bitcode_derive/src/attribute.rs @@ -1,85 +1,35 @@ -use crate::huffman::huffman; -use crate::{err, error, private}; +use crate::{err, error}; use proc_macro2::TokenStream; -use quote::quote; use std::str::FromStr; use syn::punctuated::Punctuated; use syn::spanned::Spanned; -use syn::{parse2, Attribute, DataEnum, Expr, Lit, Meta, Path, Result, Token, Type}; +use syn::{parse2, Attribute, Expr, ExprLit, Lit, Meta, Path, Result, Token, Type}; enum BitcodeAttr { BoundType(Type), - Encoding(Encoding), - Frequency(f64), - Recursive, - WithSerde, } impl BitcodeAttr { - fn new(nested: &Meta, is_hint: bool) -> Result { + fn new(nested: &Meta) -> Result { let path = path_ident_string(nested.path(), &nested)?; match path.as_str() { - _ if is_hint => match nested { - Meta::Path(p) => { - let encoding = match path.as_str() { - "ascii" => Encoding::Ascii, - "ascii_lowercase" => Encoding::AsciiLowercase, - "fixed" => Encoding::Fixed, - "gamma" => Encoding::Gamma, - _ => return err(p, "unknown hint"), - }; - Ok(Self::Encoding(encoding)) - } - Meta::NameValue(name_value) => { - let expr = &name_value.value; - 
let expr_lit = match expr { - Expr::Lit(expr_lit) => expr_lit, - _ => return err(&expr, "expected literal"), - }; - - match path.as_str() { - "frequency" => { - let frequency: f64 = match &expr_lit.lit { - Lit::Float(float_lit) => float_lit.base10_parse::().unwrap(), - Lit::Int(int_lit) => int_lit.base10_parse::().unwrap(), - _ => return err(expr_lit, "expected number"), - }; - Ok(Self::Frequency(frequency)) - } - "expected_range" => Ok(BitcodeAttr::Encoding(match &expr_lit.lit { - Lit::Str(str_lit) => { - let range = str_lit.value(); - parse_expected_range(&range).map_err(|s| error(expr_lit, s))? - } - _ => return err(expr_lit, "expected str"), - })), - _ => err(&name_value, "unknown hint"), - } - } - _ => err(&nested, "unknown hint"), - }, "bound_type" => match nested { Meta::NameValue(name_value) => { let expr = &name_value.value; - let expr_lit = match expr { - Expr::Lit(expr_lit) => expr_lit, - _ => return err(&expr, "expected literal"), + let str_lit = match expr { + Expr::Lit(ExprLit { + lit: Lit::Str(v), .. + }) => v, + _ => return err(&expr, "expected string e.g. 
\"T\""), }; - match &expr_lit.lit { - Lit::Str(str_lit) => { - let value = TokenStream::from_str(&str_lit.value()).unwrap(); - Ok(Self::BoundType( - parse2(value).map_err(|e| error(str_lit, &format!("{e}")))?, - )) - } - _ => err(expr_lit, "expected str"), - } + let value = TokenStream::from_str(&str_lit.value()).unwrap(); + Ok(Self::BoundType( + parse2(value).map_err(|e| error(str_lit, &format!("{e}")))?, + )) } _ => err(&nested, "expected name value"), }, - "recursive" if matches!(nested, Meta::Path(_)) => Ok(Self::Recursive), - "with_serde" if matches!(nested, Meta::Path(_)) => Ok(Self::WithSerde), _ => err(&nested, "unknown attribute"), } } @@ -96,127 +46,26 @@ impl BitcodeAttr { return err(nested, "can only apply bound to fields"); } } - Self::Encoding(encoding) => { - if attrs.encoding.is_some() { - return err(nested, "duplicate"); - } - attrs.encoding = Some(encoding); - } - Self::Frequency(w) => { - if let AttrType::Variant { frequency, .. } = &mut attrs.attr_type { - if frequency.is_some() { - return err(nested, "duplicate"); - } - *frequency = Some(w); - } else { - return err(nested, "can only apply frequency to enum variants"); - } - } - Self::Recursive => { - if let AttrType::Derive { recursive, .. } = &mut attrs.attr_type { - if *recursive { - return err(nested, "duplicate"); - } - *recursive = true; - } else { - return err(nested, "can only apply frequency to enum variants"); - } - } - Self::WithSerde => { - if attrs.with_serde { - return err(nested, "duplicate"); - } - attrs.with_serde = true; - } } Ok(()) } } -#[derive(Copy, Clone, Debug)] -enum Encoding { - Ascii, - AsciiLowercase, - Fixed, - ExpectNormalizedFloat, - ExpectedRangeU64 { min: u64, max: u64 }, - Gamma, -} - -impl Encoding { - fn tokens(&self) -> TokenStream { - let private = private(); - match self { - Self::Ascii => quote! { #private::BitString(#private::Ascii) }, - Self::AsciiLowercase => quote! { #private::BitString(#private::AsciiLowercase) }, - Self::Fixed => quote! 
{ #private::Fixed }, - Self::ExpectNormalizedFloat => quote! { #private::ExpectNormalizedFloat }, - Self::ExpectedRangeU64 { min, max } => { - quote! { - #private::ExpectedRangeU64::<#min, #max> - } - } - Self::Gamma => quote! { #private::Gamma }, - } - } -} - #[derive(Clone)] pub struct BitcodeAttrs { attr_type: AttrType, - encoding: Option, - with_serde: bool, } #[derive(Clone)] enum AttrType { - Derive { - recursive: bool, - }, - Variant { - derive_attrs: Box, - frequency: Option, - }, - Field { - parent_attrs: Box, - bound_type: Option, - }, + Derive, + Variant, + Field { bound_type: Option }, } impl BitcodeAttrs { fn new(attr_type: AttrType) -> Self { - Self { - attr_type, - encoding: None, - with_serde: false, - } - } - - fn parent(&self) -> Option<&Self> { - match &self.attr_type { - AttrType::Derive { .. } => None, - AttrType::Variant { derive_attrs, .. } => Some(derive_attrs), - AttrType::Field { parent_attrs, .. } => Some(parent_attrs), - } - } - - pub fn is_recursive(&self) -> bool { - match &self.attr_type { - AttrType::Derive { recursive, .. } => *recursive, - AttrType::Variant { derive_attrs, .. } => derive_attrs.is_recursive(), - AttrType::Field { parent_attrs, .. } => parent_attrs.is_recursive(), - } - } - - pub fn with_serde(&self) -> bool { - if self.with_serde { - return true; - } - if let Some(parent) = self.parent() { - parent.with_serde() - } else { - false - } + Self { attr_type } } pub fn bound_type(&self) -> Option { @@ -226,39 +75,21 @@ impl BitcodeAttrs { } } - // Gets the most specific encoding. For example field encoding overrides variant encoding which - // overrides enum encoding. 
- fn most_specific_encoding(&self) -> Option { - self.encoding - .or_else(|| self.parent().and_then(|p| p.most_specific_encoding())) - } - - pub fn get_encoding(&self) -> Option { - let encoding = self.most_specific_encoding(); - encoding.map(|e| e.tokens()) - } - pub fn parse_derive(attrs: &[Attribute]) -> Result { - let mut ret = Self::new(AttrType::Derive { recursive: false }); + let mut ret = Self::new(AttrType::Derive); ret.parse_inner(attrs)?; Ok(ret) } - pub fn parse_variant(attrs: &[Attribute], derive_attrs: &Self) -> Result { - let mut ret = Self::new(AttrType::Variant { - derive_attrs: Box::new(derive_attrs.clone()), - frequency: None, - }); + #[allow(unused)] // TODO + pub fn parse_variant(attrs: &[Attribute], _derive_attrs: &Self) -> Result { + let mut ret = Self::new(AttrType::Variant); ret.parse_inner(attrs)?; Ok(ret) } - /// `parent_attrs` is either derive or variant attrs. - pub fn parse_field(attrs: &[Attribute], parent_attrs: &Self) -> Result { - let mut ret = Self::new(AttrType::Field { - parent_attrs: Box::new(parent_attrs.clone()), - bound_type: None, - }); + pub fn parse_field(attrs: &[Attribute], _parent_attrs: &Self) -> Result { + let mut ret = Self::new(AttrType::Field { bound_type: None }); ret.parse_inner(attrs)?; Ok(ret) } @@ -266,182 +97,19 @@ impl BitcodeAttrs { fn parse_inner(&mut self, attrs: &[Attribute]) -> Result<()> { for attr in attrs { let path = path_ident_string(attr.path(), attr)?; - let is_hint = match path.as_str() { - "bitcode" => false, - "bitcode_hint" => true, - _ => continue, // Ignore all other attributes. - }; + if path.as_str() != "bitcode" { + continue; // Ignore all other attributes. 
+ } let nested = attr.parse_args_with(Punctuated::::parse_terminated)?; for nested in nested { - BitcodeAttr::new(&nested, is_hint)?.apply(self, &nested)?; + BitcodeAttr::new(&nested)?.apply(self, &nested)?; } } Ok(()) } } -#[derive(Copy, Clone, Debug)] -pub struct PrefixCode { - pub value: u32, - pub bits: usize, -} - -impl PrefixCode { - fn format_code(&self) -> TokenStream { - // TODO leading zeros up to bits. - let binary = format!("{:#b}", self.value); - TokenStream::from_str(&binary).unwrap() - } - - fn format_mask(&self) -> TokenStream { - let mask = (1u64 << self.bits) - 1; - let binary = format!("{:#b}", mask); - TokenStream::from_str(&binary).unwrap() - } -} - -pub struct VariantEncoding { - variant_count: u32, - codes: Option>, -} - -impl VariantEncoding { - pub fn parse_data_enum(data_enum: &DataEnum, attrs: &BitcodeAttrs) -> Result { - let variant_count = data_enum.variants.len() as u32; - - let codes = if variant_count >= 2 { - let frequencies: Result> = data_enum - .variants - .iter() - .map(|variant| { - if let AttrType::Variant { frequency, .. } = - BitcodeAttrs::parse_variant(&variant.attrs, attrs)?.attr_type - { - Ok(frequency.unwrap_or(1.0)) - } else { - unreachable!() - } - }) - .collect(); - - let frequencies = frequencies?; - - Some(huffman(&frequencies, 32)) - } else { - None - }; - - Ok(Self { - variant_count, - codes, - }) - } - - fn iter_codes(&self) -> impl Iterator + '_ { - self.codes.as_ref().unwrap().iter().enumerate() - } - - pub fn encode_variants( - &self, - // variant_index, before_fields, variant_bits - mut encode: impl FnMut(usize, TokenStream, usize) -> Result, - ) -> Result { - // if variant_count is 0 or 1 no encoding is required. - Ok(match self.variant_count { - 0 => quote! {}, - 1 => { - let encode_variant = encode(0, quote! {}, 0)?; - quote! 
{ - match self { - #encode_variant - } - } - } - _ => { - let variants: Result = self - .iter_codes() - .map(|(i, prefix_code)| { - let code = prefix_code.format_code(); - let bits = prefix_code.bits; - - encode( - i, - quote! { - enc_variant!(#code, #bits); - }, - bits, - ) - }) - .collect(); - let variants = variants?; - - quote! { - match self { - #variants - } - } - } - }) - } - - pub fn decode_variants( - &self, - // variant_index, before_fields, variant_bits - mut decode: impl FnMut(usize, TokenStream, usize) -> Result, - ) -> Result { - // if variant_count is 0 or 1 no encoding is required. - Ok(match self.variant_count { - 0 => { - let private = private(); - quote! { - end_dec!(); // No variants so we call to avoid unused warning. - Err(#private::invalid_variant()) - } - } - 1 => { - let decode_variant = decode(0, quote! {}, 0)?; - quote! { - Ok({#decode_variant}) - } - } - _ => { - let variants: Result = self - .iter_codes() - .map(|(i, prefix_code)| { - let mask = prefix_code.format_mask(); - let code = prefix_code.format_code(); - let bits = prefix_code.bits; - let decode_variant = decode(i, quote! {}, bits)?; - - // Match anything as the last pattern to avoid _ => unreachable!(). - let pattern = if i == self.variant_count as usize - 1 { - quote! { _ } - } else { - quote! { b if b & #mask == #code } - }; - - Ok(quote! { - #pattern => { - dec_variant_advance!(#bits); - #decode_variant - } - }) - }) - .collect(); - let variants = variants?; - - quote! 
{ - #[allow(clippy::verbose_bit_mask)] - Ok(match dec_variant_peek!() { - #variants, - }) - } - } - }) - } -} - fn path_ident_string(path: &Path, spanned: &impl Spanned) -> Result { if let Some(path) = path.get_ident() { Ok(path.to_string()) @@ -449,53 +117,3 @@ fn path_ident_string(path: &Path, spanned: &impl Spanned) -> Result { err(spanned, "expected ident") } } - -type Result2 = std::result::Result; - -fn parse_expected_range(range: &str) -> Result2 { - range - .split_once("..") - .and_then(|(min, max)| { - parse_expected_range_u64(min, max) - .or_else(|| parse_expected_range_i64(min, max)) - .or_else(|| parse_expected_range_f64(min, max)) - }) - .unwrap_or(Err("not a range, e.g. 0..1")) -} - -fn parse_expected_range_u64(min: &str, max: &str) -> Option> { - let min = u64::from_str(min).ok()?; - let max = u64::from_str(max).ok()?; - Some(if min >= max { - Err("the lower bound must be less than the higher bound") - } else { - Ok(Encoding::ExpectedRangeU64 { min, max }) - }) -} - -fn parse_expected_range_i64(min: &str, max: &str) -> Option> { - let min = i64::from_str(min).ok()?; - let max = i64::from_str(max).ok()?; - Some(if min >= max { - Err("the lower bound must be less than the higher bound") - } else { - Err("signed integer ranges are not yet supported") - }) -} - -fn parse_expected_range_f64(min: &str, max: &str) -> Option> { - let either_int = i64::from_str(min).is_ok() || i64::from_str(max).is_ok(); - - let min = f64::from_str(min).ok()?; - let max = f64::from_str(max).ok()?; - - Some(if either_int { - Err("both bounds must be floats or ints") - } else if min >= max { - Err("the start bound must be less than the end bound") - } else if (min..max) != (0.0..1.0) { - Err("float ranges other than 0.0..1.0 are not yet supported") - } else { - Ok(Encoding::ExpectNormalizedFloat) - }) -} diff --git a/bitcode_derive/src/decode.rs b/bitcode_derive/src/decode.rs index 4f7b317..7b5bdac 100644 --- a/bitcode_derive/src/decode.rs +++ b/bitcode_derive/src/decode.rs @@ 
-1,115 +1,412 @@ -use crate::derive::{unwrap_encoding, Derive}; -use crate::private; -use proc_macro2::TokenStream; +use crate::attribute::BitcodeAttrs; +use crate::bound::FieldBounds; +use crate::shared::{ + destructure_fields, field_name, remove_lifetimes, replace_lifetimes, ReplaceSelves, +}; +use crate::{err, private}; +use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; -use syn::{parse_quote, Path, Type}; +use syn::{ + parse_quote, Data, DeriveInput, Fields, GenericParam, Lifetime, LifetimeParam, Path, + PredicateLifetime, Result, Type, WherePredicate, +}; -pub struct Decode; +const DE_LIFETIME: &str = "__de"; +fn de_lifetime() -> Lifetime { + parse_quote!('__de) // Must match DE_LIFETIME. +} -impl Derive for Decode { - fn serde_impl(&self) -> (TokenStream, Path) { - let private = private(); - ( - quote! { - end_dec!(); - #private::deserialize_compat(encoding, reader) - }, - parse_quote!(#private::DeserializeOwned), - ) - } +#[derive(Copy, Clone)] +#[repr(u8)] +enum Item { + Type, + Default, + Populate, + Decode, + DecodeInPlace, +} + +impl Item { + const ALL: [Self; 5] = [ + Self::Type, + Self::Default, + Self::Populate, + Self::Decode, + Self::DecodeInPlace, + ]; + const COUNT: usize = Self::ALL.len(); fn field_impl( - &self, - with_serde: bool, + self, field_name: TokenStream, + global_field_name: TokenStream, + real_field_name: TokenStream, field_type: &Type, - encoding: Option, - ) -> (TokenStream, Path) { - let private = private(); - if with_serde { - let encoding = unwrap_encoding(encoding); - ( - // Field is using serde making DECODE_MAX unknown so we flush the current register - // buffer and read directly from the reader. See optimized_dec! macro in code.rs. - quote! 
{ - let #field_name = #private::deserialize_compat(#encoding, flush!())?; - }, - parse_quote!(#private::DeserializeOwned), - ) - } else if let Some(encoding) = encoding { - ( - // Field has an encoding making DECODE_MAX unknown so we flush the current register - // buffer and read directly from the reader. See optimized_dec! macro in code.rs. + ) -> TokenStream { + match self { + Self::Type => { + let de_type = replace_lifetimes(field_type, DE_LIFETIME); + let private = private(); + let de = de_lifetime(); quote! { - let #field_name = #private::Decode::decode(#encoding, flush!())?; - }, - parse_quote!(#private::Decode), - ) - } else { - ( - // Field has a known DECODE_MAX. dec! will evaluate if it can read from the current - // register buffer. See optimized_dec! macro in code.rs. + #global_field_name: <#de_type as #private::Decode<#de>>::Decoder, + } + } + Self::Default => quote! { + #global_field_name: Default::default(), + }, + Self::Populate => quote! { + self.#global_field_name.populate(input, __length)?; + }, + Self::Decode => quote! { + let #field_name = self.#global_field_name.decode(); + }, + Self::DecodeInPlace => { + let de_type = replace_lifetimes(field_type, DE_LIFETIME); + let private = private(); quote! { - dec!(#field_name, #field_type); - }, - parse_quote!(#private::Decode), - ) + self.#global_field_name.decode_in_place(#private::uninit_field!(out.#real_field_name: #de_type)); + } + } } } - fn struct_impl(&self, destructure_fields: TokenStream, do_fields: TokenStream) -> TokenStream { - quote! { - #do_fields - end_dec!(); - Ok(Self #destructure_fields) + fn struct_impl( + self, + ident: &Ident, + destructure_fields: &TokenStream, + do_fields: &TokenStream, + ) -> TokenStream { + match self { + Self::Decode => { + quote! { + #do_fields + #ident #destructure_fields + } + } + _ => quote! 
{ #do_fields }, } } - fn variant_impl( - &self, - before_fields: TokenStream, - field_impls: TokenStream, - destructure_variant: TokenStream, + pub fn variant_impls( + self, + variant_count: usize, + mut pattern: impl FnMut(usize) -> TokenStream, + mut inner: impl FnMut(Self, usize) -> TokenStream, ) -> TokenStream { - quote! { - #before_fields - #field_impls - end_dec!(); - #destructure_variant + // if variant_count is 0 or 1 variants don't have to be decoded. + let decode_variants = variant_count > 1; + let never = variant_count == 0; + + match self { + Self::Type => { + let de = de_lifetime(); + let inners: TokenStream = (0..variant_count).map(|i| inner(self, i)).collect(); + let variants = decode_variants + .then(|| { + let private = private(); + let c_style = inners.is_empty(); + quote! { variants: #private::VariantDecoder<#de, #variant_count, #c_style>, } + }) + .unwrap_or_default(); + quote! { + #variants + #inners + } + } + Self::Default => { + let variants = decode_variants + .then(|| quote! { variants: Default::default(), }) + .unwrap_or_default(); + let inners: TokenStream = (0..variant_count).map(|i| inner(self, i)).collect(); + quote! { + #variants + #inners + } + } + Self::Populate => { + if never { + let private = private(); + return quote! { + if __length != 0 { + return #private::invalid_enum_variant(); + } + }; + } + + let variants = decode_variants + .then(|| { + quote! { self.variants.populate(input, __length)?; } + }) + .unwrap_or_default(); + let inners: TokenStream = (0..variant_count) + .map(|i| { + let inner = inner(self, i); + if inner.is_empty() { + quote! {} + } else { + let i: u8 = i + .try_into() + .expect("enums with more than 256 variants are not supported"); // TODO don't panic. + let length = decode_variants + .then(|| { + quote! { + let __length = self.variants.length(#i); + } + }) + .unwrap_or_default(); + quote! { + #length + #inner + } + } + }) + .collect(); + quote! 
{ + #variants + #inners + } + } + Self::Decode | Self::DecodeInPlace => { + if never { + return quote! { + // Safety: View::populate will error on length != 0 so decode won't be called. + unsafe { std::hint::unreachable_unchecked() } + }; + } + let mut pattern = |i: usize| { + let pattern = pattern(i); + matches!(self, Self::DecodeInPlace) + .then(|| { + quote! { + out.write(#pattern); + } + }) + .unwrap_or(pattern) + }; + let item = Self::Decode; // DecodeInPlace doesn't work on enums. + + decode_variants + .then(|| { + let variants: TokenStream = (0..variant_count) + .map(|i| { + let inner = inner(item, i); + let pattern = pattern(i); + let i: u8 = i.try_into().unwrap(); // Already checked in reserve impl. + quote! { + #i => { + #inner + #pattern + }, + } + }) + .collect(); + quote! { + match self.variants.decode() { + #variants + // Safety: VariantDecoder::decode outputs numbers less than N. + _ => unsafe { std::hint::unreachable_unchecked() } + } + } + }) + .or_else(|| { + (variant_count == 1).then(|| { + let inner = inner(item, 0); + let pattern = pattern(0); + quote! { + #inner + #pattern + } + }) + }) + .unwrap_or_default() + } } } - fn is_encode(&self) -> bool { - false - } + // TODO dedup with encode.rs + fn field_impls( + self, + global_prefix: Option<&str>, + fields: &Fields, + parent_attrs: &BitcodeAttrs, + bounds: &mut FieldBounds, + ) -> Result { + fields + .iter() + .enumerate() + .map(move |(i, field)| { + let field_attrs = BitcodeAttrs::parse_field(&field.attrs, parent_attrs)?; - fn stream_trait_ident(&self) -> TokenStream { - let private = private(); - quote! { #private::Read } - } + let name = field_name(i, field, false); + let real_name = field_name(i, field, true); - fn trait_ident(&self) -> TokenStream { - let private = private(); - quote! { #private::Decode } - } + let global_name = global_prefix + .map(|global_prefix| { + let ident = + Ident::new(&format!("{global_prefix}{name}"), Span::call_site()); + quote! 
{ #ident } + }) + .unwrap_or_else(|| name.clone()); - fn min_bits(&self) -> TokenStream { - quote! { DECODE_MIN } - } + let field_impl = self.field_impl(name, global_name, real_name, &field.ty); - fn max_bits(&self) -> TokenStream { - quote! { DECODE_MAX } + let private = private(); + let de = de_lifetime(); + let bound: Path = parse_quote!(#private::Decode<#de>); + bounds.add_bound_type(field.clone(), &field_attrs, bound); + Ok(field_impl) + }) + .collect() } +} - fn trait_fn_impl(&self, body: TokenStream) -> TokenStream { - let private = private(); - quote! { - #[allow(clippy::all)] - #[cfg_attr(not(debug_assertions), inline(always))] - fn decode(encoding: impl #private::Encoding, reader: &mut impl #private::Read) -> #private::Result { - #private::optimized_dec!(encoding, reader); - #body - } +struct Output([TokenStream; Item::COUNT]); + +impl Output { + fn make_ghost(mut self) -> Self { + let type_ = &mut self.0[Item::Type as usize]; + if type_.is_empty() { + let de = de_lifetime(); + *type_ = quote! { __spooky: std::marker::PhantomData<&#de ()>, }; } + let default = &mut self.0[Item::Default as usize]; + if default.is_empty() { + *default = quote! 
{ __spooky: Default::default(), }; + } + self } } + +pub fn derive_impl(mut input: DeriveInput) -> Result { + let attrs = BitcodeAttrs::parse_derive(&input.attrs)?; + let mut generics = input.generics; + let mut bounds = FieldBounds::default(); + + let ident = input.ident; + syn::visit_mut::visit_data_mut(&mut ReplaceSelves(&ident), &mut input.data); + let output = (match input.data { + Data::Struct(data_struct) => { + let destructure_fields = &destructure_fields(&data_struct.fields); + Output(Item::ALL.map(|item| { + let field_impls = item + .field_impls(None, &data_struct.fields, &attrs, &mut bounds) + .unwrap(); // TODO don't unwrap + item.struct_impl(&ident, destructure_fields, &field_impls) + })) + } + Data::Enum(data_enum) => { + let variant_count = data_enum.variants.len(); + Output(Item::ALL.map(|item| { + item.variant_impls( + variant_count, + |i| { + let variant = &data_enum.variants[i]; + let variant_name = &variant.ident; + let destructure_fields = destructure_fields(&variant.fields); + quote! { + #ident::#variant_name #destructure_fields + } + }, + |item, i| { + let variant = &data_enum.variants[i]; + let global_prefix = format!("{}_", &variant.ident); + let attrs = BitcodeAttrs::parse_variant(&variant.attrs, &attrs).unwrap(); // TODO don't unwrap. + item.field_impls(Some(&global_prefix), &variant.fields, &attrs, &mut bounds) + .unwrap() // TODO don't unwrap. + }, + ) + })) + } + Data::Union(u) => err(&u.union_token, "unions are not supported")?, + }) + .make_ghost(); + + bounds.apply_to_generics(&mut generics); + let input_generics = generics.clone(); + let (_, input_generics, _) = input_generics.split_for_impl(); + let input_ty = quote! { #ident #input_generics }; + + // Add 'de lifetime after isolating input_generics. 
+ let de = de_lifetime(); + let de_where_predicate = WherePredicate::Lifetime(PredicateLifetime { + lifetime: de.clone(), + colon_token: parse_quote!(:), + bounds: generics + .params + .iter() + .filter_map(|p| { + if let GenericParam::Lifetime(p) = p { + Some(p.lifetime.clone()) + } else { + None + } + }) + .collect(), + }); + + // Push de_param after bounding 'de: 'a. + let de_param = GenericParam::Lifetime(LifetimeParam::new(de.clone())); + generics.params.push(de_param.clone()); // TODO bound to other lifetimes. + generics + .make_where_clause() + .predicates + .push(de_where_predicate); + + let combined_generics = generics.clone(); + let (impl_generics, _, where_clause) = combined_generics.split_for_impl(); + + // Decoder can't contain any lifetimes from input (which would limit reuse of decoder). + remove_lifetimes(&mut generics); + generics.params.push(de_param); // Re-add de_param since remove_lifetimes removed it. + let (decoder_impl_generics, decoder_generics, decoder_where_clause) = generics.split_for_impl(); + + let Output([type_body, default_body, populate_body, decode_body, decode_in_place_body]) = + output; + let decoder_ident = Ident::new(&format!("{ident}Decoder"), Span::call_site()); + let decoder_ty = quote! { #decoder_ident #decoder_generics }; + let private = private(); + + let ret = quote! { + const _: () = { + impl #impl_generics #private::Decode<#de> for #input_ty #where_clause { + type Decoder = #decoder_ty; + } + + #[allow(non_snake_case)] + pub struct #decoder_ident #decoder_impl_generics #decoder_where_clause { + #type_body + } + + // Avoids bounding #impl_generics: Default. 
+ impl #decoder_impl_generics std::default::Default for #decoder_ty #decoder_where_clause { + fn default() -> Self { + Self { + #default_body + } + } + } + + impl #decoder_impl_generics #private::View<#de> for #decoder_ty #decoder_where_clause { + fn populate(&mut self, input: &mut &#de [u8], __length: usize) -> #private::Result<()> { + #populate_body + Ok(()) + } + } + + impl #impl_generics #private::Decoder<#de, #input_ty> for #decoder_ty #where_clause { + #[cfg_attr(not(debug_assertions), inline(always))] + fn decode(&mut self) -> #input_ty { + #decode_body + } + + #[cfg_attr(not(debug_assertions), inline(always))] + fn decode_in_place(&mut self, out: &mut std::mem::MaybeUninit<#input_ty>) { + #decode_in_place_body + } + } + }; + }; + // panic!("{ret}"); + Ok(ret) +} diff --git a/bitcode_derive/src/derive.rs b/bitcode_derive/src/derive.rs deleted file mode 100644 index 44622c3..0000000 --- a/bitcode_derive/src/derive.rs +++ /dev/null @@ -1,294 +0,0 @@ -use crate::attribute::{BitcodeAttrs, VariantEncoding}; -use crate::bound::FieldBounds; -use crate::{err, private}; -use proc_macro2::{Ident, Span, TokenStream}; -use quote::quote; -use syn::{parse_quote, Data, DeriveInput, Field, Fields, Generics, Path, Result, Type}; - -struct Output { - body: TokenStream, - bit_bounds: BitBounds, -} - -pub struct BitBounds { - min: TokenStream, - max: TokenStream, -} - -impl BitBounds { - fn zero() -> Self { - Self::new(0, 0) - } - - fn unbounded() -> Self { - Self::new(1, usize::MAX) - } - - fn new(min: usize, max: usize) -> Self { - Self { - min: quote! { #min }, - max: quote! { #max }, - } - } - - fn add(&mut self, other: Self) { - let a_min = &self.min; - let a_max = &self.max; - let b_min = &other.min; - let b_max = &other.max; - - *self = Self { - min: quote! { #a_min + #b_min }, - max: quote! 
{ (#a_max).saturating_add(#b_max) }, - }; - } - - fn or(&mut self, other: Self) { - let a_min = &self.min; - let a_max = &self.max; - let b_min = &other.min; - let b_max = &other.max; - let private = private(); - - *self = Self { - min: quote! { #private::min(#a_min, #b_min) }, - max: quote! { #private::max(#a_max, #b_max) }, - }; - } -} - -pub fn unwrap_encoding(encoding: Option) -> TokenStream { - encoding.unwrap_or_else(|| quote! { encoding }) -} - -/// Derive code shared between Encode and Decode. -pub trait Derive { - /// Returns (serde_impl, bound). - fn serde_impl(&self) -> (TokenStream, Path); - - /// Returns (field_impl, bound). - fn field_impl( - &self, - with_serde: bool, - field_name: TokenStream, - field_type: &Type, - encoding: Option, - ) -> (TokenStream, Path); - - fn struct_impl(&self, destructure_fields: TokenStream, do_fields: TokenStream) -> TokenStream; - - fn variant_impl( - &self, - before_fields: TokenStream, - field_impls: TokenStream, - destructure_variant: TokenStream, - ) -> TokenStream; - - fn is_encode(&self) -> bool; - - fn stream_trait_ident(&self) -> TokenStream; - - fn trait_ident(&self) -> TokenStream; - - fn min_bits(&self) -> TokenStream; - - fn max_bits(&self) -> TokenStream; - - fn trait_fn_impl(&self, body: TokenStream) -> TokenStream; - - fn field_impls( - &self, - fields: &Fields, - parent_attrs: &BitcodeAttrs, - bounds: &mut FieldBounds, - ) -> Result { - fields - .iter() - .enumerate() - .map(move |(i, field)| { - let field_attrs = BitcodeAttrs::parse_field(&field.attrs, parent_attrs)?; - let encoding = field_attrs.get_encoding(); - - let field_name = field_name(i, field); - let field_type = &field.ty; - - let (field_impl, bound) = - self.field_impl(field_attrs.with_serde(), field_name, field_type, encoding); - bounds.add_bound_type(field.clone(), &field_attrs, bound); - Ok(field_impl) - }) - .collect() - } - - fn field_bit_bounds(&self, fields: &Fields, parent_attrs: &BitcodeAttrs) -> BitBounds { - let mut recursive_max 
= quote! { 0usize }; - let min: TokenStream = fields - .iter() - .map(|field| { - let ty = &field.ty; - let field_attrs = BitcodeAttrs::parse_field(&field.attrs, parent_attrs).unwrap(); - - // Encodings can make our bounds inaccurate and serde types can't give us bounds. - let unknown_bounds = - field_attrs.get_encoding().is_some() || field_attrs.with_serde(); - let BitBounds { min, max } = if unknown_bounds { - BitBounds::unbounded() - } else { - let max = if field_attrs.is_recursive() { - quote! { usize::MAX } - } else { - let max_bits = self.max_bits(); - quote! { <#ty>::#max_bits } - }; - - let min_bits = self.min_bits(); - BitBounds { - min: quote! {<#ty>::#min_bits }, - max, - } - }; - - recursive_max = quote! { #recursive_max.saturating_add(#max) }; - let min = quote! { #min + }; - min - }) - .collect(); - - let min = quote! { #min 0 }; - let max = quote! { #recursive_max }; - - BitBounds { min, max } - } - - fn derive_impl(&self, input: DeriveInput) -> Result { - let attrs = BitcodeAttrs::parse_derive(&input.attrs)?; - let mut generics = input.generics; - let mut bounds = FieldBounds::default(); - - let ident = input.ident; - let output = match input.data { - _ if attrs.with_serde() => { - let (body, bound) = self.serde_impl(); - add_type_bound(&mut generics, parse_quote!(Self), bound); - - Output { - body, - bit_bounds: BitBounds::unbounded(), - } - } - Data::Struct(data_struct) => { - let destructure_fields = destructure_fields(&data_struct.fields); - let do_fields = self.field_impls(&data_struct.fields, &attrs, &mut bounds)?; - let body = self.struct_impl(destructure_fields, do_fields); - let bit_bounds = self.field_bit_bounds(&data_struct.fields, &attrs); - - Output { body, bit_bounds } - } - Data::Enum(data_enum) => { - let variant_encoding = VariantEncoding::parse_data_enum(&data_enum, &attrs)?; - let mut enum_bit_bounds: Option = None; - - let variant_impls = (if self.is_encode() { - VariantEncoding::encode_variants - } else { - 
VariantEncoding::decode_variants - })( - &variant_encoding, - |variant_index, before_fields, bits| { - let variant = &data_enum.variants[variant_index]; - let attrs = BitcodeAttrs::parse_variant(&variant.attrs, &attrs).unwrap(); - let variant_name = &variant.ident; - - let destructure_fields = destructure_fields(&variant.fields); - let field_impls = self.field_impls(&variant.fields, &attrs, &mut bounds)?; - - let destructure_variant = quote! { - Self::#variant_name #destructure_fields - }; - - let body = - self.variant_impl(before_fields, field_impls, destructure_variant); - let mut variant_bit_bounds = self.field_bit_bounds(&variant.fields, &attrs); - - // Bits to encode variant index. - variant_bit_bounds.add(BitBounds::new(bits, bits)); - - if let Some(enum_bit_bounds) = &mut enum_bit_bounds { - enum_bit_bounds.or(variant_bit_bounds); - } else { - enum_bit_bounds = Some(variant_bit_bounds) - } - - Ok(body) - }, - )?; - - let stream_trait = self.stream_trait_ident(); - let body = quote! { - use #stream_trait as _; - #variant_impls - }; - - Output { - body, - bit_bounds: enum_bit_bounds.unwrap_or_else(BitBounds::zero), - } - } - Data::Union(u) => err(&u.union_token, "unions are not supported")?, - }; - - bounds.apply_to_generics(&mut generics); - - let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - let trait_ident = self.trait_ident(); - - let BitBounds { min, max } = output.bit_bounds; - let min_bits = self.min_bits(); - let max_bits = self.max_bits(); - - let bit_bounds = quote! { - const #min_bits: usize = #min; - const #max_bits: usize = #max; - }; - - let impl_trait_fn = self.trait_fn_impl(output.body); - Ok(quote! 
{ - impl #impl_generics #trait_ident for #ident #ty_generics #where_clause { - #bit_bounds - #impl_trait_fn - } - }) - } -} - -fn add_type_bound(generics: &mut Generics, typ: Type, bound: Path) { - generics - .make_where_clause() - .predicates - .push(parse_quote!(#typ: #bound)); -} - -fn destructure_fields(fields: &Fields) -> TokenStream { - let field_names = fields.iter().enumerate().map(|(i, f)| field_name(i, f)); - match fields { - Fields::Named(_) => quote! { - {#(#field_names),*} - }, - Fields::Unnamed(_) => quote! { - (#(#field_names),*) - }, - Fields::Unit => quote! {}, - } -} - -fn field_name(i: usize, field: &Field) -> TokenStream { - field - .ident - .as_ref() - .map(|id| quote! {#id}) - .unwrap_or_else(|| { - let name = format!("f{i}"); - let ident = Ident::new(&name, Span::call_site()); - quote! {#ident} - }) -} diff --git a/bitcode_derive/src/encode.rs b/bitcode_derive/src/encode.rs index 1f9870a..9338bdc 100644 --- a/bitcode_derive/src/encode.rs +++ b/bitcode_derive/src/encode.rs @@ -1,119 +1,386 @@ -use crate::derive::{unwrap_encoding, Derive}; -use crate::private; -use proc_macro2::TokenStream; +use crate::attribute::BitcodeAttrs; +use crate::bound::FieldBounds; +use crate::shared::{ + destructure_fields, field_name, remove_lifetimes, replace_lifetimes, ReplaceSelves, +}; +use crate::{err, private}; +use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; -use syn::{parse_quote, Path, Type}; +use syn::{parse_quote, Data, DeriveInput, Fields, Path, Result, Type}; -pub struct Encode; +#[derive(Copy, Clone)] +enum Item { + Type, + Default, + Encode, + EncodeVectored, + CollectInto, + Reserve, +} -impl Derive for Encode { - fn serde_impl(&self) -> (TokenStream, Path) { - let private = private(); - ( - quote! 
{ - #private::serialize_compat(self, encoding, writer)?; - }, - parse_quote!(#private::Serialize), - ) - } +impl Item { + const ALL: [Self; 6] = [ + Self::Type, + Self::Default, + Self::Encode, + Self::EncodeVectored, + Self::CollectInto, + Self::Reserve, + ]; + const COUNT: usize = Self::ALL.len(); fn field_impl( - &self, - with_serde: bool, + self, field_name: TokenStream, + global_field_name: TokenStream, + real_field_name: TokenStream, field_type: &Type, - encoding: Option, - ) -> (TokenStream, Path) { - let private = private(); - if with_serde { - let encoding = unwrap_encoding(encoding); - ( - // Field is using serde making ENCODE_MAX unknown so we flush the current register - // buffer and write directly to the writer. See optimized_enc! macro in code.rs. - quote! { - #private::serialize_compat(#field_name, #encoding, flush!())?; - }, - parse_quote!(#private::Serialize), - ) - } else if let Some(encoding) = encoding { - ( - // Field has an encoding making ENCODE_MAX unknown so we flush the current register - // buffer and write directly to the writer. See optimized_enc! macro in code.rs. - quote! { - #private::Encode::encode(#field_name, #encoding, flush!())?; - }, - parse_quote!(#private::Encode), - ) - } else { - ( - // Field has a known ENCODE_MAX. enc! will evaluate if it can fit within the current - // register buffer. See optimized_enc! macro in code.rs. + ) -> TokenStream { + match self { + Self::Type => { + let static_type = replace_lifetimes(field_type, "static"); + let private = private(); quote! { - enc!(#field_name, #field_type); - }, - parse_quote!(#private::Encode), - ) + #global_field_name: <#static_type as #private::Encode>::Encoder, + } + } + Self::Default => quote! 
{ + #global_field_name: Default::default(), + }, + Self::Encode | Self::EncodeVectored => { + let static_type = replace_lifetimes(field_type, "static"); + let value = if &static_type != field_type { + let underscore_type = replace_lifetimes(field_type, "_"); + + // HACK: Since encoders don't have lifetimes we can't reference as Encode>::Encoder since 'a + // does not exist. Instead we replace this with as Encode>::Encoder and transmute it to + // T<'a>. No encoder actually encodes T<'static> any differently from T<'a> so this is sound. + quote! { + unsafe { std::mem::transmute::<&#underscore_type, &#static_type>(#field_name) } + } + } else { + quote! { #field_name } + }; + + if matches!(self, Self::EncodeVectored) { + quote! { + self.#global_field_name.encode_vectored(i.clone().map(|me| { + let #field_name = &me.#real_field_name; + #value + })); + } + } else { + quote! { + self.#global_field_name.encode(#value); + } + } + } + Self::CollectInto => quote! { + self.#global_field_name.collect_into(out); + }, + Self::Reserve => quote! { + self.#global_field_name.reserve(__additional); + }, } } - fn struct_impl(&self, destructure_fields: TokenStream, do_fields: TokenStream) -> TokenStream { - let private = private(); - quote! { - let Self #destructure_fields = self; - #private::optimized_enc!(encoding, writer); - #do_fields - end_enc!(); + fn struct_impl( + self, + ident: &Ident, + destructure_fields: &TokenStream, + do_fields: &TokenStream, + ) -> TokenStream { + match self { + Self::Encode => { + quote! { + let #ident #destructure_fields = v; + #do_fields + } + } + _ => quote! { #do_fields }, } } - fn variant_impl( - &self, - before_fields: TokenStream, - field_impls: TokenStream, - destructure_variant: TokenStream, + pub fn variant_impls( + self, + variant_count: usize, + mut pattern: impl FnMut(usize) -> TokenStream, + mut inner: impl FnMut(Self, usize) -> TokenStream, ) -> TokenStream { - let private = private(); - quote! 
{ - #destructure_variant => { - #private::optimized_enc!(encoding, writer); - #before_fields - #field_impls - end_enc!(); - }, + // if variant_count is 0 or 1 variants don't have to be encoded. + let encode_variants = variant_count > 1; + match self { + Self::Type => { + let variants = encode_variants + .then(|| { + let private = private(); + quote! { variants: #private::VariantEncoder<#variant_count>, } + }) + .unwrap_or_default(); + let inners: TokenStream = (0..variant_count).map(|i| inner(self, i)).collect(); + quote! { + #variants + #inners + } + } + Self::Default => { + let variants = encode_variants + .then(|| quote! { variants: Default::default(), }) + .unwrap_or_default(); + let inners: TokenStream = (0..variant_count).map(|i| inner(self, i)).collect(); + quote! { + #variants + #inners + } + } + Self::Encode => { + let variants = encode_variants + .then(|| { + let variants: TokenStream = (0..variant_count) + .map(|i| { + let pattern = pattern(i); + let i: u8 = i + .try_into() + .expect("enums with more than 256 variants are not supported"); // TODO don't panic. + quote! { + #pattern => #i, + } + }) + .collect(); + quote! { + #[allow(unused_variables)] + self.variants.encode(&match v { + #variants + }); + } + }) + .unwrap_or_default(); + let inners: TokenStream = (0..variant_count) + .map(|i| { + // We don't know the exact number of this variant since there are more than one so we have to + // reserve one at a time. + let reserve = encode_variants + .then(|| { + let reserve = inner(Self::Reserve, i); + quote! { + let __additional = std::num::NonZeroUsize::MIN; + #reserve + } + }) + .unwrap_or_default(); + let inner = inner(self, i); + let pattern = pattern(i); + quote! { + #pattern => { + #reserve + #inner + } + } + }) + .collect(); + (variant_count != 0) + .then(|| { + quote! { + #variants + match v { + #inners + } + } + }) + .unwrap_or_default() + } + Self::EncodeVectored => unimplemented!(), // TODO encode enum vectored. 
+ Self::CollectInto => { + let variants = encode_variants + .then(|| { + quote! { self.variants.collect_into(out); } + }) + .unwrap_or_default(); + let inners: TokenStream = (0..variant_count).map(|i| inner(self, i)).collect(); + quote! { + #variants + #inners + } + } + Self::Reserve => { + encode_variants + .then(|| { + quote! { self.variants.reserve(__additional); } + }) + .or_else(|| { + (variant_count == 1).then(|| { + // We know the exact number of this variant since it's the only one so we can reserve it. + inner(self, 0) + }) + }) + .unwrap_or_default() + } } } - fn is_encode(&self) -> bool { - true - } + fn field_impls( + self, + global_prefix: Option<&str>, + fields: &Fields, + parent_attrs: &BitcodeAttrs, + bounds: &mut FieldBounds, + ) -> Result { + fields + .iter() + .enumerate() + .map(move |(i, field)| { + let field_attrs = BitcodeAttrs::parse_field(&field.attrs, parent_attrs)?; - fn stream_trait_ident(&self) -> TokenStream { - let private = private(); - quote! { #private::Write } - } + let name = field_name(i, field, false); + let real_name = field_name(i, field, true); - fn trait_ident(&self) -> TokenStream { - let private = private(); - quote! { #private::Encode } - } + let global_name = global_prefix + .map(|global_prefix| { + let ident = + Ident::new(&format!("{global_prefix}{name}"), Span::call_site()); + quote! { #ident } + }) + .unwrap_or_else(|| name.clone()); - fn min_bits(&self) -> TokenStream { - quote! { ENCODE_MIN } + let field_impl = self.field_impl(name, global_name, real_name, &field.ty); + let private = private(); + let bound: Path = parse_quote!(#private::Encode); + bounds.add_bound_type(field.clone(), &field_attrs, bound); + Ok(field_impl) + }) + .collect() } +} - fn max_bits(&self) -> TokenStream { - quote! { ENCODE_MAX } - } +struct Output([TokenStream; Item::COUNT]); - fn trait_fn_impl(&self, body: TokenStream) -> TokenStream { - let private = private(); - quote! 
{ - #[allow(clippy::all)] - #[cfg_attr(not(debug_assertions), inline(always))] - fn encode(&self, encoding: impl #private::Encoding, writer: &mut impl #private::Write) -> #private::Result<()> { - #body - Ok(()) - } +pub fn derive_impl(mut input: DeriveInput) -> Result { + let attrs = BitcodeAttrs::parse_derive(&input.attrs)?; + let mut generics = input.generics; + let mut bounds = FieldBounds::default(); + + let ident = input.ident; + syn::visit_mut::visit_data_mut(&mut ReplaceSelves(&ident), &mut input.data); + + let (output, is_encode_vectored) = match input.data { + Data::Struct(data_struct) => { + let destructure_fields = &destructure_fields(&data_struct.fields); + ( + Output(Item::ALL.map(|item| { + let field_impls = item + .field_impls(None, &data_struct.fields, &attrs, &mut bounds) + .unwrap(); // TODO don't unwrap + item.struct_impl(&ident, destructure_fields, &field_impls) + })), + true, + ) } - } + Data::Enum(data_enum) => { + let variant_count = data_enum.variants.len(); + ( + Output(Item::ALL.map(|item| { + if matches!(item, Item::EncodeVectored) { + return Default::default(); // Unimplemented for now. + } + + item.variant_impls( + variant_count, + |i| { + let variant = &data_enum.variants[i]; + let variant_name = &variant.ident; + let destructure_fields = destructure_fields(&variant.fields); + quote! { + #ident::#variant_name #destructure_fields + } + }, + |item, i| { + let variant = &data_enum.variants[i]; + let global_prefix = format!("{}_", &variant.ident); + let attrs = + BitcodeAttrs::parse_variant(&variant.attrs, &attrs).unwrap(); // TODO don't unwrap. + item.field_impls( + Some(&global_prefix), + &variant.fields, + &attrs, + &mut bounds, + ) + .unwrap() // TODO don't unwrap. 
+ }, + ) + })), + false, + ) + } + Data::Union(u) => err(&u.union_token, "unions are not supported")?, + }; + + bounds.apply_to_generics(&mut generics); + let input_generics = generics.clone(); + let (impl_generics, input_generics, where_clause) = input_generics.split_for_impl(); + let input_ty = quote! { #ident #input_generics }; + + // Encoder can't contain any lifetimes from input (which would limit reuse of encoder). + remove_lifetimes(&mut generics); + let (encoder_impl_generics, encoder_generics, encoder_where_clause) = generics.split_for_impl(); + + let Output( + [type_body, default_body, encode_body, encode_vectored_body, collect_into_body, reserve_body], + ) = output; + let encoder_ident = Ident::new(&format!("{ident}Encoder"), Span::call_site()); + let encoder_ty = quote! { #encoder_ident #encoder_generics }; + let private = private(); + + let encode_vectored = is_encode_vectored.then(|| quote! { + // #[cfg_attr(not(debug_assertions), inline(always))] + // #[inline(never)] + fn encode_vectored<'__v>(&mut self, i: impl Iterator + Clone) where #input_ty: '__v { + #[allow(unused_imports)] + use #private::Buffer as _; + #encode_vectored_body + } + }); + + let ret = quote! { + const _: () = { + impl #impl_generics #private::Encode for #input_ty #where_clause { + type Encoder = #encoder_ty; + } + + #[allow(non_snake_case)] + pub struct #encoder_ident #encoder_impl_generics #encoder_where_clause { + #type_body + } + + // Avoids bounding #impl_generics: Default. 
+ impl #encoder_impl_generics std::default::Default for #encoder_ty #encoder_where_clause { + fn default() -> Self { + Self { + #default_body + } + } + } + + impl #impl_generics #private::Encoder<#input_ty> for #encoder_ty #where_clause { + #[cfg_attr(not(debug_assertions), inline(always))] + fn encode(&mut self, v: &#input_ty) { + #[allow(unused_imports)] + use #private::Buffer as _; + #encode_body + } + #encode_vectored + } + + impl #encoder_impl_generics #private::Buffer for #encoder_ty #encoder_where_clause { + fn collect_into(&mut self, out: &mut Vec) { + #collect_into_body + } + + fn reserve(&mut self, __additional: std::num::NonZeroUsize) { + #reserve_body + } + } + }; + }; + // panic!("{ret}"); + Ok(ret) } diff --git a/bitcode_derive/src/huffman.rs b/bitcode_derive/src/huffman.rs deleted file mode 100644 index 2a2e1e5..0000000 --- a/bitcode_derive/src/huffman.rs +++ /dev/null @@ -1,68 +0,0 @@ -use crate::attribute::PrefixCode; - -/// Returns tuples of (code, bit length) in same order as input frequencies. 
-pub fn huffman(frequencies: &[f64], max_len: u8) -> Vec { - struct Symbol { - index: usize, - code: u32, - len: u8, - } - - let frequencies = frequencies.iter().map(|f| f.max(0.0)).collect::>(); - let lengths = packagemerge::package_merge(&frequencies, max_len as u32).unwrap(); - let mut symbols = lengths - .into_iter() - .enumerate() - .map(|(index, len)| Symbol { - index, - code: u32::MAX, - len: len as u8, - }) - .collect::>(); - symbols.sort_by_key(|symbol| (symbol.len, symbol.index)); - let mut code: u32 = 0; - let mut last_len: u8 = 0; - for (i, symbol) in symbols.iter_mut().enumerate() { - if i > 0 { - code = (code + 1) << (symbol.len - last_len); - } - symbol.code = code; - last_len = symbol.len; - } - symbols.sort_by_key(|symbol| symbol.index); - - symbols - .into_iter() - .map(|symbol| PrefixCode { - value: symbol.code.reverse_bits() >> (u32::BITS - symbol.len as u32), - bits: symbol.len as usize, - }) - .collect() -} - -#[cfg(test)] -mod tests { - use super::huffman; - - #[test] - fn unconstrained() { - let symbol_frequencies = vec![('a', 10), ('b', 1), ('c', 15), ('d', 7)]; - let frequencies = symbol_frequencies - .iter() - .map(|(_, l)| *l as f64) - .collect::>(); - let code_len = huffman(&frequencies, 3); - assert_eq!(code_len, vec![(0b10, 2), (0b110, 3), (0b0, 1), (0b111, 3)]); - } - - #[test] - fn constrained() { - let symbol_frequencies = vec![('a', 10), ('b', 1), ('c', 15), ('d', 7)]; - let frequencies = symbol_frequencies - .iter() - .map(|(_, l)| *l as f64) - .collect::>(); - let code_len = huffman(&frequencies, 2); - assert_eq!(code_len, vec![(0, 2), (1, 2), (2, 2), (3, 2)]); - } -} diff --git a/bitcode_derive/src/lib.rs b/bitcode_derive/src/lib.rs index ddc4bd5..a566032 100644 --- a/bitcode_derive/src/lib.rs +++ b/bitcode_derive/src/lib.rs @@ -1,6 +1,3 @@ -use crate::decode::Decode; -use crate::derive::Derive; -use crate::encode::Encode; use proc_macro::TokenStream; use quote::quote; use syn::spanned::Spanned; @@ -9,24 +6,21 @@ use 
syn::{parse_macro_input, DeriveInput}; mod attribute; mod bound; mod decode; -mod derive; mod encode; -mod huffman; +mod shared; -#[proc_macro_derive(Encode, attributes(bitcode, bitcode_hint))] +#[proc_macro_derive(Encode, attributes(bitcode))] pub fn derive_encode(input: TokenStream) -> TokenStream { - derive(Encode, input) + let input = parse_macro_input!(input as DeriveInput); + encode::derive_impl(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() } -#[proc_macro_derive(Decode, attributes(bitcode, bitcode_hint))] +#[proc_macro_derive(Decode, attributes(bitcode))] pub fn derive_decode(input: TokenStream) -> TokenStream { - derive(Decode, input) -} - -fn derive(derive: impl Derive, input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); - derive - .derive_impl(input) + decode::derive_impl(input) .unwrap_or_else(syn::Error::into_compile_error) .into() } diff --git a/bitcode_derive/src/shared.rs b/bitcode_derive/src/shared.rs new file mode 100644 index 0000000..5841642 --- /dev/null +++ b/bitcode_derive/src/shared.rs @@ -0,0 +1,70 @@ +use proc_macro2::{Ident, Span, TokenStream}; +use quote::{quote, ToTokens}; +use syn::visit_mut::VisitMut; +use syn::{Field, Fields, GenericParam, Generics, Index, Lifetime, Type, WherePredicate}; + +pub fn destructure_fields(fields: &Fields) -> TokenStream { + let field_names = fields + .iter() + .enumerate() + .map(|(i, f)| field_name(i, f, false)); + match fields { + Fields::Named(_) => quote! { + {#(#field_names),*} + }, + Fields::Unnamed(_) => quote! { + (#(#field_names),*) + }, + Fields::Unit => quote! {}, + } +} + +pub fn field_name(i: usize, field: &Field, real: bool) -> TokenStream { + field + .ident + .as_ref() + .map(|id| quote! 
{#id}) + .unwrap_or_else(|| { + if real { + Index::from(i).to_token_stream() + } else { + Ident::new(&format!("f{i}"), Span::call_site()).to_token_stream() + } + }) +} + +pub fn remove_lifetimes(generics: &mut Generics) { + generics.params = std::mem::take(&mut generics.params) + .into_iter() + .filter(|param| !matches!(param, GenericParam::Lifetime(_))) + .collect(); + if let Some(where_clause) = &mut generics.where_clause { + where_clause.predicates = std::mem::take(&mut where_clause.predicates) + .into_iter() + .filter(|predicate| !matches!(predicate, WherePredicate::Lifetime(_))) + .collect() + } +} + +#[must_use] +pub fn replace_lifetimes(t: &Type, s: &str) -> Type { + let mut t = t.clone(); + syn::visit_mut::visit_type_mut(&mut ReplaceLifetimes(s), &mut t); + t +} + +struct ReplaceLifetimes<'a>(&'a str); +impl VisitMut for ReplaceLifetimes<'_> { + fn visit_lifetime_mut(&mut self, lifetime: &mut Lifetime) { + lifetime.ident = Ident::new(self.0, lifetime.ident.span()); + } +} + +pub struct ReplaceSelves<'a>(pub &'a Ident); +impl VisitMut for ReplaceSelves<'_> { + fn visit_ident_mut(&mut self, ident: &mut Ident) { + if ident == "Self" { + *ident = self.0.clone(); + } + } +} diff --git a/fuzz/.gitignore b/fuzz/.gitignore index e940bbe..b400c27 100644 --- a/fuzz/.gitignore +++ b/fuzz/.gitignore @@ -1,4 +1,2 @@ -target corpus artifacts -/Cargo.lock diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 9289246..653b3c4 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -9,13 +9,10 @@ edition = "2018" cargo-fuzz = true [dependencies] -bitvec = { version = "1.0.1" } +arrayvec = { version = "0.7", features = ["serde"] } +bitcode = { path = "..", features = [ "arrayvec", "serde" ] } libfuzzer-sys = "0.4" -serde = { version ="1.0", features=["derive"] } - -[dependencies.bitcode] -path = ".." 
-features = ["serde"] +serde = { version ="1.0", features = [ "derive" ] } # Prevent this from interfering with workspaces [workspace] diff --git a/fuzz/fuzz_targets/fuzz.rs b/fuzz/fuzz_targets/fuzz.rs index 358ed4d..2f91781 100644 --- a/fuzz/fuzz_targets/fuzz.rs +++ b/fuzz/fuzz_targets/fuzz.rs @@ -1,12 +1,11 @@ #![no_main] use libfuzzer_sys::fuzz_target; extern crate bitcode; +use arrayvec::{ArrayString, ArrayVec}; use bitcode::{Decode, Encode}; -use bitvec::prelude::*; use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap}; -use std::ffi::CString; -use std::time::Duration; +use std::num::NonZeroU32; fuzz_target!(|data: &[u8]| { if data.len() < 3 { @@ -14,38 +13,35 @@ fuzz_target!(|data: &[u8]| { } let (start, data) = data.split_at(3); - let mut bv = BitVec::::default(); - for byte in data { - let boolean = match byte { - 0 => false, - 1 => true, - _ => return, - }; - bv.push(boolean); - } - let data = bv.as_raw_slice(); - macro_rules! test { ($typ1: expr, $typ2: expr, $data: expr, $($typ: ty),*) => { { let mut j = 0; $( - let mut buffer = bitcode::Buffer::new(); - if j == $typ1 { - for _ in 0..2 { - if $typ2 == 0 { - if let Ok(de) = buffer.decode::<$typ>(data) { - let data2 = buffer.encode(&de).unwrap(); + if $typ2 == 0 { + let mut encode_buffer = bitcode::EncodeBuffer::<$typ>::default(); + let mut decode_buffer = bitcode::DecodeBuffer::<$typ>::default(); + + let mut previous = None; + for _ in 0..2 { + let current = if let Ok(de) = decode_buffer.decode(data) { + let data2 = encode_buffer.encode(&de); let de2 = bitcode::decode::<$typ>(&data2).unwrap(); assert_eq!(de, de2); + true + } else { + false + }; + if let Some(previous) = std::mem::replace(&mut previous, Some(current)) { + assert_eq!(previous, current); } - } else if $typ2 == 1 { - if let Ok(de) = buffer.deserialize::<$typ>(data) { - let data2 = buffer.serialize(&de).unwrap(); - let de2 = bitcode::deserialize::<$typ>(&data2).unwrap(); - assert_eq!(de, de2); - } + } + } else if $typ2 
== 1 { + if let Ok(de) = bitcode::deserialize::<$typ>(data) { + let data2 = bitcode::serialize(&de).unwrap(); + let de2 = bitcode::deserialize::<$typ>(&data2).unwrap(); + assert_eq!(de, de2); } } } @@ -74,18 +70,11 @@ fuzz_target!(|data: &[u8]| { [$typ; 1], [$typ; 2], [$typ; 3], - [$typ; 7], - [$typ; 8], - ([bool; 1], $typ), - ([bool; 2], $typ), - ([bool; 3], $typ), - ([bool; 4], $typ), - ([bool; 5], $typ), - ([bool; 6], $typ), - ([bool; 7], $typ), Option<$typ>, Vec<$typ>, - HashMap + HashMap, + ArrayVec<$typ, 0>, + ArrayVec<$typ, 5> ); } #[allow(unused)] @@ -98,6 +87,30 @@ fuzz_target!(|data: &[u8]| { } } + #[rustfmt::skip] + mod enums { + use super::*; + #[derive(Serialize, Deserialize, Encode, Decode, Debug, PartialEq)] + pub enum Enum2 { A, B } + #[derive(Serialize, Deserialize, Encode, Decode, Debug, PartialEq)] + pub enum Enum3 { A, B, C } + #[derive(Serialize, Deserialize, Encode, Decode, Debug, PartialEq)] + pub enum Enum4 { A, B, C, D } + #[derive(Serialize, Deserialize, Encode, Decode, Debug, PartialEq)] + pub enum Enum5 { A, B, C, D, E } + #[derive(Serialize, Deserialize, Encode, Decode, Debug, PartialEq)] + pub enum Enum6 { A, B, C, D, E, F } + #[derive(Serialize, Deserialize, Encode, Decode, Debug, PartialEq)] + pub enum Enum7 { A, B, C, D, E, F, G } + #[derive(Serialize, Deserialize, Encode, Decode, Debug, PartialEq)] + pub enum Enum15 { A, B, C, D, E, F, G, H, I, J, K, L, M, N, O } + #[derive(Serialize, Deserialize, Encode, Decode, Debug, PartialEq)] + pub enum Enum16 { A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P } + #[derive(Serialize, Deserialize, Encode, Decode, Debug, PartialEq)] + pub enum Enum17 { A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q } + } + use enums::*; + #[derive(Serialize, Deserialize, Encode, Decode, Debug, PartialEq)] enum Enum { A, @@ -106,17 +119,7 @@ fuzz_target!(|data: &[u8]| { D { a: u8, b: u8 }, E(String), F, - G(#[bitcode_hint(expected_range = "0.0..1.0")] BitsEqualF32), - H(#[bitcode_hint(expected_range = 
"0.0..1.0")] BitsEqualF64), - I(#[bitcode_hint(expected_range = "0..32")] u8), - J(#[bitcode_hint(expected_range = "3..51")] u16), - K(#[bitcode_hint(expected_range = "200..5000")] u32), - L(#[bitcode_hint(gamma)] i8), - M(#[bitcode_hint(gamma)] u64), - N(#[bitcode_hint(ascii)] String), - O(#[bitcode_hint(ascii_lowercase)] String), P(BTreeMap), - Q(Duration), } #[derive(Serialize, Deserialize, Encode, Decode, Debug)] @@ -145,6 +148,7 @@ fuzz_target!(|data: &[u8]| { (), bool, char, + NonZeroU32, u8, i8, u16, @@ -155,13 +159,25 @@ fuzz_target!(|data: &[u8]| { i64, u128, i128, - usize, - isize, + // usize, + // isize, BitsEqualF32, BitsEqualF64, Vec, String, - CString, - Enum + Enum2, + Enum3, + Enum4, + Enum5, + Enum6, + Enum7, + Enum15, + Enum16, + Enum17, + Enum, + ArrayString<5>, + ArrayString<70>, + ArrayVec, + ArrayVec ); }); diff --git a/src/__private.rs b/src/__private.rs deleted file mode 100644 index 990d7e3..0000000 --- a/src/__private.rs +++ /dev/null @@ -1,159 +0,0 @@ -// Exports for derive macro. #[doc(hidden)] because not stable between versions. - -pub use crate::code::*; -pub use crate::encoding::*; -pub use crate::nightly::{max, min}; -pub use crate::read::Read; -pub use crate::register_buffer::*; -pub use crate::write::Write; -pub use crate::Error; - -#[cfg(any(test, feature = "serde"))] -pub use crate::serde::de::deserialize_compat; -#[cfg(any(test, feature = "serde"))] -pub use crate::serde::ser::serialize_compat; -#[cfg(any(test, feature = "serde"))] -pub use serde::{de::DeserializeOwned, Serialize}; - -// TODO only define once. 
-pub type Result = std::result::Result; - -pub fn invalid_variant() -> Error { - crate::E::Invalid("enum variant").e() -} - -#[cfg(all(test, debug_assertions))] -mod tests { - use crate::{Decode, Encode}; - use serde::{Deserialize, Serialize}; - use std::marker::PhantomData; - - #[derive(Debug, Default, PartialEq, Encode, Decode)] - #[bitcode(recursive)] - struct Recursive { - a: Option>, - b: Option>, - c: Vec, - } - - #[test] - fn test_recursive() { - // If these functions aren't called, Rust hides some kinds of compile errors. - crate::encode(&Recursive::default()).unwrap(); - let _ = crate::decode::(&[]); - } - - trait ParamTrait { - type One; - type Two: Encode + Decode; - type Three; - type Four; - } - - struct Param; - - #[derive(Serialize, Deserialize)] - struct SerdeU32(u32); - - impl ParamTrait for Param { - type One = i8; - type Two = u16; - type Three = SerdeU32; - type Four = &'static str; - } - - #[derive(Encode, Decode)] - #[bitcode_hint(gamma)] - struct UsesParamTrait { - #[bitcode(bound_type = "B::One")] - a: Vec, - #[bitcode(bound_type = "A::One")] // Make sure redundant bound_type works. - b: A::One, - c: Vec, // Always Encode + Decode so no bound_type needed. 
- #[bitcode(with_serde, bound_type = "(A::Three, B::Three)")] - d: Vec<(A::Three, B::Three)>, - e: PhantomData, - } - - #[test] - fn test_uses_param_trait() { - type T = UsesParamTrait; - let t: T = UsesParamTrait { - a: vec![1, 2, 3], - b: 1, - c: vec![1, 2, 3], - d: vec![(SerdeU32(1), SerdeU32(2))], - e: PhantomData, - }; - - let encoded = crate::encode(&t).unwrap(); - let _ = crate::decode::(&encoded).unwrap(); - } - - #[derive(Debug, PartialEq, Encode, Decode)] - struct Empty; - - #[derive(Debug, PartialEq, Encode, Decode)] - struct Tuple(usize, u8); - - #[derive(Debug, PartialEq, Encode, Decode)] - struct Generic(usize, T); - - #[derive(Debug, PartialEq, Encode, Decode)] - struct FooInner { - foo: u8, - #[bitcode_hint(gamma)] - bar: usize, - baz: String, - } - - #[derive(Debug, PartialEq, Encode, Decode)] - #[allow(unused)] - enum Foo { - #[bitcode_hint(frequency = 100)] - A, - #[bitcode_hint(frequency = 10)] - B(String), - C { - #[bitcode_hint(gamma)] - baz: usize, - qux: f32, - }, - #[bitcode_hint(fixed)] - Foo(FooInner, #[bitcode_hint(gamma)] i64), - #[bitcode_hint(expected_range = "0..10")] - Tuple(Tuple), - Empty(Empty), - } - - #[derive(Encode, Decode)] - enum Never {} - - #[derive(Copy, Clone, Debug, PartialEq, Encode, Decode)] - enum Xyz { - #[bitcode_hint(frequency = 2)] - X, - Y, - Z, - } - - #[test] - fn test_encode_x() { - let v = [Xyz::X; 16]; - let encoded = crate::encode(&v).unwrap(); - assert_eq!(encoded.len(), 2); - - let decoded: [Xyz; 16] = crate::decode(&encoded).unwrap(); - assert_eq!(v, decoded); - } - - #[test] - fn test_encode_y() { - let v = [Xyz::Y; 16]; - let encoded = crate::encode(&v).unwrap(); - assert_eq!(encoded.len(), 4); - - let decoded: [Xyz; 16] = crate::decode(&encoded).unwrap(); - assert_eq!(v, decoded); - } -} diff --git a/src/benches.rs b/src/benches.rs index 5f2d13c..7e9c1c5 100644 --- a/src/benches.rs +++ b/src/benches.rs @@ -1,62 +1,78 @@ -use crate::Buffer; -use crate::{Decode, Encode}; -use bincode::Options; -use 
flate2::read::DeflateDecoder; -use flate2::write::DeflateEncoder; -use flate2::Compression; -use lz4_flex::{compress_prepend_size, decompress_size_prepended}; -use paste::paste; -use rand::distributions::Alphanumeric; use rand::prelude::*; use rand_chacha::ChaCha20Rng; -use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; -use test::{black_box, Bencher}; - -// type StringImpl = arrayvec::ArrayString<16>; -type StringImpl = String; - -#[derive(Debug, Default, PartialEq, Encode, Decode, Serialize, Deserialize)] -struct Data { - #[bitcode_hint(expected_range = "0.0..1.0")] - x: Option, - y: Option, - z: u16, - #[bitcode_hint(ascii)] - s: StringImpl, - e: DataEnum, -} +use test::black_box; + +#[cfg(feature = "arrayvec")] +use arrayvec::{ArrayString, ArrayVec}; + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[cfg_attr(feature = "derive", derive(crate::Encode, crate::Decode))] +pub struct Data { + #[cfg(feature = "arrayvec")] + pub entity: ArrayString<8>, + #[cfg(not(feature = "arrayvec"))] + pub entity: String, + + pub x: u8, + pub y: bool, + + #[cfg(feature = "arrayvec")] + pub item: ArrayString<12>, + #[cfg(not(feature = "arrayvec"))] + pub item: String, + + pub z: u16, -fn gen_len(r: &mut (impl Rng + ?Sized)) -> usize { - (r.gen::().powi(4) * 16.0) as usize + #[cfg(feature = "arrayvec")] + pub e: ArrayVec, + #[cfg(not(feature = "arrayvec"))] + pub e: Vec, } +pub const MAX_DATA_ENUMS: usize = 5; impl Distribution for rand::distributions::Standard { fn sample(&self, rng: &mut R) -> Data { - let n = gen_len(rng); Data { - x: rng.gen_bool(0.15).then(|| rng.gen()), - y: rng.gen_bool(0.3).then(|| rng.gen()), - z: rng.gen(), - s: StringImpl::try_from( - rng.sample_iter(Alphanumeric) - .take(n) - .map(char::from) - .collect::() - .as_str(), - ) + entity: (*[ + "cow", "sheep", "zombie", "skeleton", "spider", "creeper", "parrot", "bee", + ] + .choose(rng) + .unwrap()) + .try_into() .unwrap(), - e: rng.gen(), + x: rng.gen(), + y: rng.gen_bool(0.1), + 
item: (*[ + "dirt", + "stone", + "pickaxe", + "sand", + "gravel", + "shovel", + "chestplate", + "steak", + ] + .choose(rng) + .unwrap()) + .try_into() + .unwrap(), + z: rng.gen(), + e: (0..rng.gen_range(0..MAX_DATA_ENUMS)) + .map(|_| rng.gen()) + .collect(), } } } -#[derive(Debug, Default, PartialEq, Encode, Decode, Serialize, Deserialize)] -enum DataEnum { - #[default] - #[bitcode_hint(frequency = 10)] +#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[cfg_attr(feature = "derive", derive(crate::Encode, crate::Decode))] +pub enum DataEnum { Bar, - Baz(#[bitcode_hint(ascii)] StringImpl), + #[cfg(feature = "arrayvec")] + Baz(ArrayString<16>), + #[cfg(not(feature = "arrayvec"))] + Baz(String), Foo(Option), } @@ -65,19 +81,18 @@ impl Distribution for rand::distributions::Standard { if rng.gen_bool(0.9) { DataEnum::Bar } else if rng.gen_bool(0.5) { - let n = gen_len(rng); + let n = rng.gen_range(0..15); DataEnum::Baz( - StringImpl::try_from( - rng.sample_iter(Alphanumeric) - .take(n) - .map(char::from) - .collect::() - .as_str(), - ) - .unwrap(), + rng.sample_iter(rand::distributions::Alphanumeric) + .take(n) + .map(char::from) + .collect::() + .as_str() + .try_into() + .unwrap(), ) } else { - DataEnum::Foo(rng.gen_bool(0.5).then(|| rng.gen())) + DataEnum::Foo(rng.gen_bool(0.2).then(|| rng.gen())) } } } @@ -87,366 +102,193 @@ fn random_data(n: usize) -> Vec { (0..n).map(|_| rng.gen()).collect() } -fn bitcode_encode(v: &(impl Encode + ?Sized)) -> Vec { - crate::encode(v).unwrap() +#[cfg(feature = "derive")] +fn bitcode_encode(v: &(impl crate::Encode + ?Sized)) -> Vec { + crate::encode(v) } -fn bitcode_decode(v: &[u8]) -> T { +#[cfg(feature = "derive")] +fn bitcode_decode(v: &[u8]) -> T { crate::decode(v).unwrap() } +#[cfg(feature = "serde")] fn bitcode_serialize(v: &(impl Serialize + ?Sized)) -> Vec { - crate::serde::serialize(v).unwrap() -} - -fn bitcode_deserialize(v: &[u8]) -> T { - crate::serde::deserialize(v).unwrap() -} - -fn bincode_fixint_serialize(v: &(impl 
Serialize + ?Sized)) -> Vec { - bincode::serialize(v).unwrap() -} - -fn bincode_fixint_deserialize(v: &[u8]) -> T { - bincode::deserialize(v).unwrap() -} - -fn bincode_varint_serialize(v: &(impl Serialize + ?Sized)) -> Vec { - bincode::DefaultOptions::new().serialize(v).unwrap() -} - -fn bincode_varint_deserialize(v: &[u8]) -> T { - bincode::DefaultOptions::new().deserialize(v).unwrap() -} - -fn bincode_lz4_serialize(v: &(impl Serialize + ?Sized)) -> Vec { - compress_prepend_size(&bincode::DefaultOptions::new().serialize(v).unwrap()) -} - -fn bincode_lz4_deserialize(v: &[u8]) -> T { - bincode::DefaultOptions::new() - .deserialize(&decompress_size_prepended(v).unwrap()) - .unwrap() -} - -fn bincode_flate2_fast_serialize(v: &(impl Serialize + ?Sized)) -> Vec { - let mut e = DeflateEncoder::new(Vec::new(), Compression::fast()); - bincode::DefaultOptions::new() - .serialize_into(&mut e, v) - .unwrap(); - e.finish().unwrap() -} - -fn bincode_flate2_fast_deserialize(v: &[u8]) -> T { - bincode::DefaultOptions::new() - .deserialize_from(DeflateDecoder::new(v)) - .unwrap() -} - -fn bincode_flate2_best_serialize(v: &(impl Serialize + ?Sized)) -> Vec { - let mut e = DeflateEncoder::new(Vec::new(), Compression::best()); - bincode::DefaultOptions::new() - .serialize_into(&mut e, v) - .unwrap(); - e.finish().unwrap() -} - -fn bincode_flate2_best_deserialize(v: &[u8]) -> T { - bincode_flate2_fast_deserialize(v) -} - -fn postcard_serialize(v: &(impl Serialize + ?Sized)) -> Vec { - postcard::to_allocvec(v).unwrap() -} - -fn postcard_deserialize(buf: &[u8]) -> T { - postcard::from_bytes(buf).unwrap() -} - -fn bench_data() -> Vec { - random_data(1000) + crate::serialize(v).unwrap() } -fn bench_serialize(b: &mut Bencher, ser: fn(&[Data]) -> Vec) { - let data = bench_data(); - b.iter(|| { - black_box(ser(black_box(&data))); - }) +#[cfg(feature = "serde")] +fn bitcode_deserialize(v: &[u8]) -> T { + crate::deserialize(v).unwrap() } -fn bench_deserialize(b: &mut Bencher, ser: fn(&[Data]) 
-> Vec, de: fn(&[u8]) -> Vec) { - let data = bench_data(); - let serialized_data = &ser(&data); - assert_eq!(de(serialized_data), data); - b.iter(|| { - black_box(de(black_box(serialized_data))); - }) -} - -#[bench] -fn bench_bitcode_buffer_serialize(b: &mut Bencher) { - let data = bench_data(); - let mut buf = Buffer::new(); - buf.serialize(&data).unwrap(); - let initial_cap = buf.capacity(); - b.iter(|| { - black_box(buf.serialize(black_box(&data)).unwrap()); - }); - assert_eq!(buf.capacity(), initial_cap); -} - -#[bench] -fn bench_bitcode_buffer_deserialize(b: &mut Bencher) { - let data = bench_data(); - let bytes = &crate::serde::serialize(&data).unwrap(); - let mut buf = Buffer::new(); - assert_eq!(buf.deserialize::>(bytes).unwrap(), data); - let initial_cap = buf.capacity(); - b.iter(|| { - black_box(buf.deserialize::>(black_box(bytes)).unwrap()); - }); - assert_eq!(buf.capacity(), initial_cap); -} - -#[bench] -fn bench_bitcode_long_string_serialize(b: &mut Bencher) { - let data = "abcde1234☺".repeat(1000); - let mut buf = Buffer::new(); - buf.serialize(&data).unwrap(); - b.iter(|| { - black_box(buf.serialize(black_box(&data)).unwrap()); - }); -} - -#[bench] -fn bench_bitcode_long_string_deserialize(b: &mut Bencher) { - let data = "abcde1234☺".repeat(1000); - let mut buf = Buffer::new(); - let bytes = buf.serialize(&data).unwrap().to_vec(); - assert_eq!(buf.deserialize::(&bytes).unwrap(), data); - b.iter(|| { - black_box(buf.deserialize::(black_box(&bytes)).unwrap()); - }); +pub fn bench_data() -> Vec { + random_data(crate::limit_bench_miri(1000)) } +#[cfg(any(feature = "derive", feature = "serde"))] macro_rules! bench { ($serialize:ident, $deserialize:ident, $($name:ident),*) => { - paste! { + paste::paste! 
{ $( #[bench] - fn [] (b: &mut Bencher) { - bench_serialize(b, [<$name _ $serialize>]) + fn [] (b: &mut test::Bencher) { + let data = bench_data(); + b.iter(|| { + black_box([<$name _ $serialize>](black_box(&data))); + }) } #[bench] - fn [] (b: &mut Bencher) { - bench_deserialize(b, [<$name _ $serialize>], [<$name _ $deserialize>]) + fn [] (b: &mut test::Bencher) { + let data = bench_data(); + let serialized_data = &[<$name _ $serialize>](&data); + assert_eq!([<$name _ $deserialize>]::>(serialized_data), data); + b.iter(|| { + black_box([<$name _ $deserialize>]::>(black_box(serialized_data))); + }) } )* } } } -mod derive { - use super::*; - bench!(encode, decode, bitcode); -} - -bench!( - serialize, - deserialize, - bitcode, - bincode_fixint, - bincode_varint, - bincode_lz4, - bincode_flate2_fast, - bincode_flate2_best, - postcard -); +#[cfg(feature = "serde")] +bench!(serialize, deserialize, bitcode); +#[cfg(feature = "derive")] +bench!(encode, decode, bitcode); #[cfg(test)] mod tests { use super::*; + use bincode::Options; use std::time::{Duration, Instant}; - // cargo test comparison1 --release -- --nocapture --include-ignored + /// # With many allocations in deserialize + /// cargo test --release --features=serde -- --show-output comparison1 + /// + /// # With String -> ArrayString and Vec -> ArrayVec + /// cargo test --release --all-features -- --show-output comparison1 #[test] - #[ignore = "don't run unless --include-ignored"] + #[cfg_attr(debug_assertions, ignore = "don't run unless --include-ignored")] fn comparison1() { let data = &random_data(10000); - let print_results = - |name: &'static str, ser: fn(&[Data]) -> Vec, de: fn(&[u8]) -> Vec| { - let b = ser(&data); - let zeros = 100.0 * b.iter().filter(|&&b| b == 0).count() as f32 / b.len() as f32; - let _precision = 2 - (zeros.log10().ceil() as usize).min(1); - - fn benchmark_ns(f: impl Fn()) -> usize { - const WARMUP: usize = 10; + let print_single = |name: &str, + compression: &str, + ser: &dyn 
Fn(&[Data]) -> Vec, + de: &dyn Fn(&[u8]) -> Vec| { + let b = ser(&data); + // if compression.is_empty() { + // print!("{name} {compression} "); + // println!("{}", String::from_utf8_lossy(&b).replace(char::is_control, "�")); + // // println!("{:?}", b); + // } + + fn benchmark_ns(f: impl Fn()) -> usize { + const WARMUP: usize = 2; + let start = Instant::now(); + for _ in 0..WARMUP { + f(); + } + let warmup_duration = start.elapsed(); + let per_second = (WARMUP as f32 / warmup_duration.as_secs_f32()) as usize; + let samples: usize = (per_second / 32).max(1); + let mut duration = Duration::ZERO; + for _ in 0..samples { let start = Instant::now(); - for _ in 0..WARMUP { - f(); - } - let warmup_duration = start.elapsed(); - let per_second = (WARMUP as f32 / warmup_duration.as_secs_f32()) as usize; - let samples: usize = (per_second / 4).max(10); - let mut duration = Duration::ZERO; - for _ in 0..samples { - let start = Instant::now(); - f(); - duration += start.elapsed(); - } - duration.as_nanos() as usize / samples + f(); + duration += start.elapsed(); } + duration.as_nanos() as usize / samples + } - let ser_time = benchmark_ns(|| { - black_box(ser(black_box(&data))); - }) / data.len(); + let ser_time = benchmark_ns(|| { + black_box(ser(black_box(&data))); + }) / data.len(); - let de_time = benchmark_ns(|| { - black_box(de(black_box(&b))); - }) / data.len(); + let de_time = benchmark_ns(|| { + black_box(de(black_box(&b))); + }) / data.len(); - // {zeros:>4.1$}% - println!( - "| {name:<22} | {:<12.1} | {ser_time:<10} | {de_time:<10} |", + println!( + "| {name:<16} | {compression:<12} | {:<12.1} | {ser_time:<10} | {de_time:<10} |", b.len() as f32 / data.len() as f32, - //precision, ); - }; + }; - println!("| Format | Size (bytes) | Serialize (ns) | Deserialize (ns) |"); - println!("|------------------------|--------------|----------------|------------------|"); - print_results("Bitcode (derive)", bitcode_encode, bitcode_decode); - print_results("Bitcode (serde)", 
bitcode_serialize, bitcode_deserialize); - print_results( - "Bincode", - bincode_fixint_serialize, - bincode_fixint_deserialize, - ); - print_results( - "Bincode (varint)", - bincode_varint_serialize, - bincode_varint_deserialize, - ); + let print_results = + |name: &str, ser: fn(&[Data]) -> Vec, de: fn(&[u8]) -> Vec| { + for (compression, encode, decode) in compression::ALGORITHMS { + print_single(name, compression, &|v| encode(&ser(v)), &|v| de(&decode(v))); + } + }; - // These use varint since it makes the result smaller and actually speeds up compression. - print_results( - "Bincode (LZ4)", - bincode_lz4_serialize, - bincode_lz4_deserialize, - ); + println!("| Format | Compression | Size (bytes) | Serialize (ns) | Deserialize (ns) |"); + println!("|------------------|--------------|--------------|----------------|------------------|"); print_results( - "Bincode (Deflate Fast)", - bincode_flate2_fast_serialize, - bincode_flate2_fast_deserialize, + "bincode", + |v| bincode::serialize(v).unwrap(), + |v| bincode::deserialize(v).unwrap(), ); print_results( - "Bincode (Deflate Best)", - bincode_flate2_best_serialize, - bincode_flate2_best_deserialize, + "bincode-varint", + |v| bincode::DefaultOptions::new().serialize(v).unwrap(), + |v| bincode::DefaultOptions::new().deserialize(v).unwrap(), ); - - // TODO compressed postcard. 
- print_results("Postcard", postcard_serialize, postcard_deserialize); + #[cfg(feature = "serde")] + print_results("bitcode", bitcode_serialize, bitcode_deserialize); + #[cfg(feature = "derive")] + print_results("bitcode-derive", bitcode_encode, bitcode_decode); } +} - #[test] - #[cfg(debug_assertions)] - fn comparison2() { - use std::ops::RangeInclusive; - - fn compare_inner(name: &str, r: RangeInclusive) -> String { - fn measure(t: &T) -> [usize; 5] { - const COUNT: usize = 8; - let many: [&T; COUNT] = std::array::from_fn(|_| t); - [ - bitcode_encode(&many).len(), - bitcode_serialize(&many).len(), - bincode_fixint_serialize(&many).len(), - bincode_varint_serialize(&many).len(), - postcard_serialize(&many).len(), - ] - .map(|b| 8 * b / COUNT) - } - - let lo = measure(&r.start()); - let hi = measure(&r.end()); - - let v: Vec<_> = lo - .into_iter() - .zip(hi) - .map(|(lo, hi)| { - if lo == hi { - format!("{lo}") - } else { - format!("{lo}-{hi}") - } - }) - .collect(); - - format!( - "| {name:<19} | {:<16} | {:<15} | {:<7} | {:<16} | {:<8} |", - v[0], v[1], v[2], v[3], v[4], - ) - } +mod compression { + use flate2::read::DeflateDecoder; + use flate2::write::DeflateEncoder; + use flate2::Compression; + use lz4_flex::{compress_prepend_size, decompress_size_prepended}; + use std::io::{Read, Write}; + + pub static ALGORITHMS: &[(&str, fn(&[u8]) -> Vec, fn(&[u8]) -> Vec)] = &[ + ("", ToOwned::to_owned, ToOwned::to_owned), + ("lz4", lz4_encode, lz4_decode), + ("deflate-fast", deflate_fast_encode, deflate_decode), + ("deflate-best", deflate_best_encode, deflate_decode), + ("zstd-0", zstd_encode::<0>, zstd_decode), + ("zstd-22", zstd_encode::<22>, zstd_decode), + ]; + + fn lz4_encode(v: &[u8]) -> Vec { + compress_prepend_size(v) + } - fn compare(name: &str, r: RangeInclusive) { - println!("{}", compare_inner(name, r)); - } + fn lz4_decode(v: &[u8]) -> Vec { + decompress_size_prepended(v).unwrap() + } - fn compare_one(name: &str, t: T) { - println!("{}", compare_inner(name, 
&t..=&t)); - } + fn deflate_fast_encode(v: &[u8]) -> Vec { + let mut e = DeflateEncoder::new(Vec::new(), Compression::fast()); + e.write_all(v).unwrap(); + e.finish().unwrap() + } - fn compare_int( - name: &str, - u: RangeInclusive, - s: RangeInclusive, - ) { - let unsigned = compare_inner(name, u); - let signed = compare_inner(name, s); - assert_eq!(unsigned, signed, "unsigned/signed sizes are different"); - println!("{unsigned}"); - } + fn deflate_best_encode(v: &[u8]) -> Vec { + let mut e = DeflateEncoder::new(Vec::new(), Compression::best()); + e.write_all(v).unwrap(); + e.finish().unwrap() + } - #[derive(Clone, Encode, Decode, Serialize, Deserialize)] - enum Enum { - A, - B, - C, - D, - } + fn deflate_decode(v: &[u8]) -> Vec { + let mut bytes = vec![]; + DeflateDecoder::new(v).read_to_end(&mut bytes).unwrap(); + bytes + } - println!("| Type | Bitcode (derive) | Bitcode (serde) | Bincode | Bincode (varint) | Postcard |"); - println!("|---------------------|------------------|-----------------|---------|------------------|----------|"); - compare("bool", false..=true); - compare_int("u8/i8", 0u8..=u8::MAX, 0i8..=i8::MAX); - compare_int("u16/i16", 0u16..=u16::MAX, 0i16..=i16::MAX); - compare_int("u32/i32", 0u32..=u32::MAX, 0i32..=i32::MAX); - compare_int("u64/i64", 0u64..=u64::MAX, 0i64..=i64::MAX); - compare_int("u128/i128", 0u128..=u128::MAX, 0i128..=i128::MAX); - compare_int("usize/isize", 0usize..=usize::MAX, 0isize..=isize::MAX); - compare_one("f32", 0f32); - compare_one("f64", 0f64); - compare("char", (0 as char)..=char::MAX); - compare("Option<()>", None..=Some(())); - compare("Result<(), ()>", Ok(())..=Err(())); - compare("enum { A, B, C, D }", Enum::A..=Enum::D); - compare( - "Duration", - Duration::ZERO..=Duration::new(u64::MAX, 999_999_999), - ); + fn zstd_encode(v: &[u8]) -> Vec { + zstd::stream::encode_all(v, LEVEL).unwrap() + } - println!(); - println!("| Value | Bitcode (derive) | Bitcode (serde) | Bincode | Bincode (varint) | Postcard |"); - 
println!("|---------------------|------------------|-----------------|---------|------------------|----------|"); - compare_one("[true; 4]", [true; 4]); - compare_one("vec![(); 0]", vec![(); 0]); - compare_one("vec![(); 1]", vec![(); 1]); - compare_one("vec![(); 256]", vec![(); 256]); - compare_one("vec![(); 65536]", vec![(); 65536]); - compare_one(r#""""#, ""); - compare_one(r#""abcd""#, "abcd"); - compare_one(r#""abcd1234""#, "abcd1234"); + fn zstd_decode(v: &[u8]) -> Vec { + zstd::stream::decode_all(v).unwrap() } } diff --git a/src/benches_borrowed.rs b/src/benches_borrowed.rs new file mode 100644 index 0000000..e0f47f9 --- /dev/null +++ b/src/benches_borrowed.rs @@ -0,0 +1,130 @@ +use crate::benches::{bench_data, Data, DataEnum, MAX_DATA_ENUMS}; +use serde::{Deserialize, Serialize}; +use std::array; +use test::{black_box, Bencher}; + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[cfg_attr(feature = "derive", derive(crate::Encode, crate::Decode))] +struct Data2<'a> { + entity: &'a str, + x: u8, + y: bool, + item: &'a str, + z: u16, + e: [DataEnum2<'a>; MAX_DATA_ENUMS], +} + +impl<'a> From<&'a Data> for Data2<'a> { + fn from(v: &'a Data) -> Self { + Self { + entity: &v.entity, + x: v.x, + y: v.y, + item: &v.item, + z: v.z, + e: array::from_fn(|i| v.e.get(i).map(From::from).unwrap_or(DataEnum2::Bar)), + } + } +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[cfg_attr(feature = "derive", derive(crate::Encode, crate::Decode))] +enum DataEnum2<'a> { + Bar, + Baz(&'a str), + Foo(Option), +} + +impl<'a> From<&'a DataEnum> for DataEnum2<'a> { + fn from(v: &'a DataEnum) -> Self { + match v { + DataEnum::Bar => Self::Bar, + DataEnum::Baz(v) => Self::Baz(v), + DataEnum::Foo(v) => Self::Foo(*v), + } + } +} + +fn bench_data2(bench_data: &[Data]) -> Vec { + bench_data.iter().map(From::from).collect() +} + +#[bench] +fn bench_bincode_serialize(b: &mut Bencher) { + let data = bench_data(); + let data = bench_data2(&data); + let mut buffer = vec![]; + + 
b.iter(|| { + let buffer = black_box(&mut buffer); + buffer.clear(); + bincode::serialize_into(buffer, black_box(&data)).unwrap(); + }) +} + +#[bench] +fn bench_bincode_deserialize(b: &mut Bencher) { + let data = bench_data(); + let data = bench_data2(&data); + let mut bytes = vec![]; + bincode::serialize_into(&mut bytes, &data).unwrap(); + + assert_eq!( + bincode::deserialize::>(&mut bytes.as_slice()).unwrap(), + data + ); + b.iter(|| { + black_box(bincode::deserialize::>(&mut black_box(bytes.as_slice())).unwrap()); + }) +} + +#[cfg(feature = "derive")] +#[bench] +fn bench_bitcode_encode(b: &mut Bencher) { + let data = bench_data(); + let data = bench_data2(&data); + let mut buffer = crate::EncodeBuffer::default(); + + b.iter(|| { + black_box(buffer.encode(black_box(&data))); + }) +} + +#[cfg(feature = "derive")] +#[bench] +fn bench_bitcode_decode(b: &mut Bencher) { + let data = bench_data(); + let data = bench_data2(&data); + let mut encode_buffer = crate::EncodeBuffer::default(); + let bytes = encode_buffer.encode(&data); + + let mut decode_buffer = crate::DecodeBuffer::>::default(); + assert_eq!(decode_buffer.decode(bytes).unwrap(), data); + b.iter(|| { + black_box(decode_buffer.decode(black_box(bytes)).unwrap()); + }) +} + +#[cfg(feature = "serde")] +#[bench] +fn bench_bitcode_serialize(b: &mut Bencher) { + let data = bench_data(); + let data = bench_data2(&data); + + b.iter(|| { + black_box(crate::serialize(black_box(&data)).unwrap()); + }) +} + +#[cfg(feature = "serde")] +#[bench] +fn bench_bitcode_deserialize(b: &mut Bencher) { + let data = bench_data(); + let data = bench_data2(&data); + let bytes = crate::serialize(&data).unwrap(); + + assert_eq!(crate::deserialize::>(&bytes).unwrap(), data); + b.iter(|| { + black_box(crate::deserialize::>(black_box(bytes.as_slice())).unwrap()); + }) +} diff --git a/src/bit_buffer.rs b/src/bit_buffer.rs deleted file mode 100644 index 058de06..0000000 --- a/src/bit_buffer.rs +++ /dev/null @@ -1,228 +0,0 @@ -use 
crate::buffer::BufferTrait; -use crate::encoding::ByteEncoding; -use crate::read::Read; -use crate::word::*; -use crate::write::Write; -use crate::{Result, E}; -use bitvec::domain::Domain; -use bitvec::prelude::*; -use std::num::NonZeroUsize; - -/// A slow proof of concept [`Buffer`] that uses [`BitVec`]. Useful for comparison. -#[derive(Debug, Default)] -pub struct BitBuffer { - bits: BitVec, - read_bytes_buf: Box<[u8]>, -} - -impl BufferTrait for BitBuffer { - type Writer = BitWriter; - type Reader<'a> = BitReader<'a>; - type Context = (); - - fn capacity(&self) -> usize { - self.bits.capacity() / u8::BITS as usize - } - - fn with_capacity(cap: usize) -> Self { - Self { - bits: BitVec::with_capacity(cap * u8::BITS as usize), - ..Default::default() - } - } - - fn start_write(&mut self) -> Self::Writer { - self.bits.clear(); - Self::Writer { - bits: std::mem::take(&mut self.bits), - } - } - - fn finish_write(&mut self, writer: Self::Writer) -> &[u8] { - let Self::Writer { bits } = writer; - self.bits = bits; - - self.bits.force_align(); - self.bits.as_raw_slice() - } - - fn start_read<'a>(&'a mut self, bytes: &'a [u8]) -> (Self::Reader<'a>, Self::Context) { - let bits = BitSlice::from_slice(bytes); - let reader = Self::Reader { - bits, - read_bytes_buf: &mut self.read_bytes_buf, - advanced_too_far: false, - }; - - (reader, ()) - } - - fn finish_read(reader: Self::Reader<'_>, _: Self::Context) -> Result<()> { - if reader.advanced_too_far { - return Err(E::Eof.e()); - } - - if reader.bits.is_empty() { - return Ok(()); - } - - // Make sure no trailing 1 bits or zero bytes. - let e = match reader.bits.domain() { - Domain::Enclave(e) => e, - Domain::Region { head, body, tail } => { - if !body.is_empty() { - return Err(E::ExpectedEof.e()); - } - head.xor(tail).ok_or_else(|| E::ExpectedEof.e())? 
- } - }; - (e.into_bitslice().count_ones() == 0) - .then_some(()) - .ok_or_else(|| E::ExpectedEof.e()) - } -} - -pub struct BitWriter { - bits: BitVec, -} - -impl Write for BitWriter { - type Revert = usize; - fn get_revert(&mut self) -> Self::Revert { - self.bits.len() - } - fn revert(&mut self, index: Self::Revert) { - self.bits.truncate(index); - } - - fn write_bit(&mut self, v: bool) { - self.bits.push(v); - } - - fn write_bits(&mut self, word: Word, bits: usize) { - self.bits - .extend_from_bitslice(&BitSlice::::from_slice(&word.to_le_bytes())[..bits]); - } - - fn write_bytes(&mut self, bytes: &[u8]) { - self.bits - .extend_from_bitslice(BitSlice::::from_slice(bytes)); - } - - fn write_encoded_bytes(&mut self, bytes: &[u8]) -> bool { - for b in bytes { - let word = *b as Word; - if !C::validate(word, 1) { - return false; - } - self.write_bits(C::pack(word), C::BITS_PER_BYTE) - } - true - } -} - -pub struct BitReader<'a> { - bits: &'a BitSlice, - read_bytes_buf: &'a mut Box<[u8]>, - advanced_too_far: bool, -} - -impl BitReader<'_> { - fn read_slice(&mut self, bits: usize) -> Result<&BitSlice> { - if bits > self.bits.len() { - return Err(E::Eof.e()); - } - - let (slice, remaining) = self.bits.split_at(bits); - self.bits = remaining; - Ok(slice) - } -} - -impl Read for BitReader<'_> { - fn advance(&mut self, bits: usize) { - if bits > self.bits.len() { - // Handle the error later since we can't return it. 
- self.advanced_too_far = true; - } - self.bits = &self.bits[bits.min(self.bits.len())..]; - } - - fn peek_bits(&mut self) -> Result { - if self.advanced_too_far { - return Err(E::Eof.e()); - } - - let bits = self.bits.len().min(64); - - let mut v = [0; 8]; - BitSlice::::from_slice_mut(&mut v)[..bits].copy_from_bitslice(&self.bits[..bits]); - Ok(Word::from_le_bytes(v)) - } - - fn read_bit(&mut self) -> Result { - Ok(self.read_slice(1)?[0]) - } - - fn read_bits(&mut self, bits: usize) -> Result { - let slice = self.read_slice(bits)?; - - let mut v = [0; 8]; - BitSlice::::from_slice_mut(&mut v)[..bits].copy_from_bitslice(slice); - Ok(Word::from_le_bytes(v)) - } - - fn read_bytes(&mut self, len: NonZeroUsize) -> Result<&[u8]> { - let len = len.get(); - - // Take to avoid borrowing issue. - let mut tmp = std::mem::take(self.read_bytes_buf); - - let bits = len - .checked_mul(u8::BITS as usize) - .ok_or_else(|| E::Eof.e())?; - let slice = self.read_slice(bits)?; - - // Only allocate after reserve_read to prevent memory exhaustion attacks. - if tmp.len() < len { - tmp = vec![0; len.next_power_of_two()].into_boxed_slice() - } - - tmp.as_mut_bits()[..slice.len()].copy_from_bitslice(slice); - *self.read_bytes_buf = tmp; - Ok(&self.read_bytes_buf[..len]) - } - - fn read_encoded_bytes(&mut self, len: NonZeroUsize) -> Result<&[u8]> { - let len = len.get(); - - // Take to avoid borrowing issue. - let mut tmp = std::mem::take(self.read_bytes_buf); - - let bits = len - .checked_mul(C::BITS_PER_BYTE) - .ok_or_else(|| E::Eof.e())?; - let slice = self.read_slice(bits)?; - - // Only allocate after reserve_read to prevent memory exhaustion attacks. 
- if tmp.len() < len { - tmp = vec![0; len.next_power_of_two()].into_boxed_slice() - } - - for (dst, src) in tmp[..len].iter_mut().zip(slice.chunks(C::BITS_PER_BYTE)) { - let mut byte = [0u8]; - byte.as_mut_bits()[..C::BITS_PER_BYTE].copy_from_bitslice(src); - *dst = C::unpack(byte[0] as Word) as u8; - } - *self.read_bytes_buf = tmp; - Ok(&self.read_bytes_buf[..len]) - } - - fn reserve_bits(&self, bits: usize) -> Result<()> { - if bits <= self.bits.len() { - Ok(()) - } else { - Err(E::Eof.e()) - } - } -} diff --git a/src/bool.rs b/src/bool.rs new file mode 100644 index 0000000..19cd31b --- /dev/null +++ b/src/bool.rs @@ -0,0 +1,79 @@ +use crate::coder::{Buffer, Decoder, Encoder, Result, View}; +use crate::fast::{CowSlice, NextUnchecked, PushUnchecked, VecImpl}; +use crate::pack::{pack_bools, unpack_bools}; +use std::num::NonZeroUsize; + +#[derive(Debug, Default)] +pub struct BoolEncoder(VecImpl); + +impl Encoder for BoolEncoder { + #[inline(always)] + fn as_primitive(&mut self) -> Option<&mut VecImpl> { + Some(&mut self.0) + } + + #[inline(always)] + fn encode(&mut self, t: &bool) { + unsafe { self.0.push_unchecked(*t) }; + } +} + +impl Buffer for BoolEncoder { + fn collect_into(&mut self, out: &mut Vec) { + pack_bools(self.0.as_slice(), out); + self.0.clear(); + } + + fn reserve(&mut self, additional: NonZeroUsize) { + self.0.reserve(additional.get()); + } +} + +#[derive(Debug, Default)] +pub struct BoolDecoder<'a>(CowSlice<'a, bool>); + +impl<'a> View<'a> for BoolDecoder<'a> { + fn populate(&mut self, input: &mut &'_ [u8], length: usize) -> Result<()> { + unpack_bools(input, length, &mut self.0)?; + Ok(()) + } +} + +impl<'a> Decoder<'a, bool> for BoolDecoder<'a> { + #[inline(always)] + fn as_primitive_ptr(&self) -> Option<*const u8> { + Some(self.0.ref_slice().as_ptr() as *const u8) + } + + #[inline(always)] + unsafe fn as_primitive_advance(&mut self, n: usize) { + self.0.mut_slice().advance(n); + } + + #[inline(always)] + fn decode(&mut self) -> bool { + unsafe 
{ self.0.mut_slice().next_unchecked() } + } +} + +#[cfg(test)] +mod test { + fn bench_data() -> Vec { + (0..=1000).map(|_| false).collect() + } + crate::bench_encode_decode!(bool_vec: Vec<_>); +} + +#[cfg(test)] +mod test2 { + fn bench_data() -> Vec> { + crate::random_data::(125) + .into_iter() + .map(|n| { + let n = 1 + n / 16; + (0..n).map(|_| false).collect() + }) + .collect() + } + crate::bench_encode_decode!(bool_vecs: Vec>); +} diff --git a/src/buffer.rs b/src/buffer.rs deleted file mode 100644 index 0f817c4..0000000 --- a/src/buffer.rs +++ /dev/null @@ -1,113 +0,0 @@ -use crate::read::Read; -use crate::word_buffer::WordBuffer; -use crate::write::Write; -use crate::{Result, E}; - -/// A buffer for reusing allocations between any number of calls to [`Buffer::encode`] and/or -/// [`Buffer::decode`]. -/// -/// ### Usage -/// ```edition2021 -/// use bitcode::Buffer; -/// -/// // We preallocate buffers with capacity 1000. This will allow us to encode and decode without -/// // any allocations as long as the encoded object takes less than 1000 bytes. -/// let bytes = 1000; -/// let mut encode_buf = Buffer::with_capacity(bytes); -/// let mut decode_buf = Buffer::with_capacity(bytes); -/// -/// // The object that we will encode. -/// let target: [u8; 5] = [1, 2, 3, 4, 5]; -/// -/// // We encode into `encode_buf`. This won't cause any allocations. -/// let encoded: &[u8] = encode_buf.encode(&target).unwrap(); -/// assert!(encoded.len() <= bytes, "oh no we allocated"); -/// -/// // We decode into `decode_buf` because `encoded` is borrowing `encode_buf`. -/// let decoded: [u8; 5] = decode_buf.decode(&encoded).unwrap(); -/// assert_eq!(target, decoded); -/// -/// // If we need ownership of `encoded`, we can convert it to a vec. -/// // This will allocate, but it's still more efficient than calling bitcode::encode. 
-/// let _owned: Vec = encoded.to_vec(); -/// ``` -#[derive(Default)] -pub struct Buffer(pub(crate) WordBuffer); - -impl Buffer { - /// Constructs a new buffer without any capacity. - pub fn new() -> Self { - Self::default() - } - - /// Constructs a new buffer with at least the specified capacity in bytes. - pub fn with_capacity(capacity: usize) -> Self { - Self(BufferTrait::with_capacity(capacity)) - } - - /// Returns the capacity in bytes. - #[cfg(test)] - pub(crate) fn capacity(&self) -> usize { - self.0.capacity() - } -} - -pub trait BufferTrait: Default { - type Writer: Write; - type Reader<'a>: Read; - type Context; - - fn capacity(&self) -> usize; - fn with_capacity(capacity: usize) -> Self; - - /// Clears the buffer. - fn start_write(&mut self) -> Self::Writer; - /// Returns the written bytes. - fn finish_write(&mut self, writer: Self::Writer) -> &[u8]; - - fn start_read<'a>(&'a mut self, bytes: &'a [u8]) -> (Self::Reader<'a>, Self::Context); - /// Check for errors such as Eof and ExpectedEof - fn finish_read(reader: Self::Reader<'_>, context: Self::Context) -> Result<()>; - /// Overrides decoding errors with Eof since the reader might allow reading slightly past the - /// end. Only WordBuffer currently does this. - fn finish_read_with_result( - reader: Self::Reader<'_>, - context: Self::Context, - decode_result: Result, - ) -> Result { - let finish_result = Self::finish_read(reader, context); - if let Err(e) = &finish_result { - if e.same(&E::Eof.e()) { - return Err(E::Eof.e()); - } - } - let t = decode_result?; - finish_result?; - Ok(t) - } -} - -#[cfg(all(test, not(miri), debug_assertions))] -mod tests { - use crate::bit_buffer::BitBuffer; - use crate::buffer::BufferTrait; - use crate::word_buffer::WordBuffer; - use paste::paste; - - macro_rules! test_with_capacity { - ($name:ty, $t:ty) => { - paste! 
{ - #[test] - fn []() { - for cap in 0..200 { - let buf = $t::with_capacity(cap); - assert!(buf.capacity() >= cap, "with_capacity: {cap}, capacity {}", buf.capacity()); - } - } - } - } - } - - test_with_capacity!(bit_buffer, BitBuffer); - test_with_capacity!(word_buffer, WordBuffer); -} diff --git a/src/code.rs b/src/code.rs deleted file mode 100644 index 1b11a6a..0000000 --- a/src/code.rs +++ /dev/null @@ -1,512 +0,0 @@ -use crate::buffer::BufferTrait; -use crate::encoding::{Encoding, Fixed}; -use crate::read::Read; -use crate::write::Write; -use crate::Result; - -pub(crate) fn encode_internal<'a>( - buffer: &'a mut impl BufferTrait, - t: &(impl Encode + ?Sized), -) -> Result<&'a [u8]> { - let mut writer = buffer.start_write(); - t.encode(Fixed, &mut writer)?; - Ok(buffer.finish_write(writer)) -} - -pub(crate) fn decode_internal( - buffer: &mut B, - bytes: &[u8], -) -> Result { - let (mut reader, context) = buffer.start_read(bytes); - let decode_result = T::decode(Fixed, &mut reader); - B::finish_read_with_result(reader, context, decode_result) -} - -/// A type which can be encoded to bytes with [`encode`][`crate::encode`]. -/// -/// Must use `#[derive(Encode)]` to implement. -/// ``` -/// #[derive(bitcode::Encode)] -/// // If your struct contains itself you must annotate it with `#[bitcode(recursive)]`. -/// // This disables certain speed optimizations that aren't possible on recursive types. -/// struct MyStruct { -/// a: u32, -/// b: bool, -/// // If you want to use serde::Serialize on a field instead of bitcode::Encode. -/// #[cfg(feature = "serde")] -/// #[bitcode(with_serde)] -/// c: String, -/// } -/// ``` -pub trait Encode { - // The minimum and maximum number of bits a type can encode as. For now these are only valid if - // the encoding is fixed. Before using them make sure the encoding passed to encode is fixed. - // TODO make these const functions that take an encoding (once const fn is available in traits). 
- #[doc(hidden)] - const ENCODE_MIN: usize; - - // If max is lower than the actual max, we may not encode all the bits. - #[doc(hidden)] - const ENCODE_MAX: usize; - - #[doc(hidden)] - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()>; -} - -/// A type which can be decoded from bytes with [`decode`][`crate::decode`]. -/// -/// Must use `#[derive(Decode)]` to implement. -/// ``` -/// #[derive(bitcode::Decode)] -/// // If your struct contains itself you must annotate it with `#[bitcode(recursive)]`. -/// // This disables certain speed optimizations that aren't possible on recursive types. -/// struct MyStruct { -/// a: u32, -/// b: bool, -/// // If you want to use serde::Deserialize on a field instead of bitcode::Decode. -/// #[cfg(feature = "serde")] -/// #[bitcode(with_serde)] -/// c: String, -/// } -/// ``` -pub trait Decode: Sized { - // Copy of Encode constants. See Encode for details. - // If min is higher than the actual min, we may get EOFs. - #[doc(hidden)] - const DECODE_MIN: usize; - - #[doc(hidden)] - const DECODE_MAX: usize; - - #[doc(hidden)] - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result; -} - -/// A macro that facilitates writing to a RegisterWriter when encoding multiple values less than 64 bits. -/// This can dramatically speed operations like encoding a tuple of 8 bytes. -/// -/// Once you call `optimized_enc!()`, you must call `end_enc!()` at the end to flush the remaining bits. -/// If the execution path diverges it must never converge or this won't optimize well. -#[doc(hidden)] -#[macro_export] -macro_rules! optimized_enc { - ($encoding:ident, $writer:ident) => { - let mut buf = $crate::__private::RegisterWriter::new($writer); - #[allow(unused_mut)] - let mut i: usize = 0; - #[allow(unused)] - let no_encoding_upstream = $encoding.is_fixed(); - - // Call on each field (that doesn't get it's encoding overridden in the derive macro). - #[allow(unused)] - macro_rules! 
enc { - ($t:expr, $T:ty) => { - // ENCODE_MAX is only accurate if there isn't any encoding upstream. - // Downstream encodings make ENCODE_MAX = usize::MAX in derive macro. - if <$T as $crate::__private::Encode>::ENCODE_MAX.saturating_add(i) <= 64 - && no_encoding_upstream - { - <$T as $crate::__private::Encode>::encode(&$t, $encoding, &mut buf.inner)?; - } else { - if i != 0 { - buf.flush(); - } - - if <$T as $crate::__private::Encode>::ENCODE_MAX < 64 && no_encoding_upstream { - <$T as $crate::__private::Encode>::encode(&$t, $encoding, &mut buf.inner)?; - } else { - <$T as $crate::__private::Encode>::encode(&$t, $encoding, buf.writer)?; - } - } - - i = if <$T as $crate::__private::Encode>::ENCODE_MAX.saturating_add(i) <= 64 - && no_encoding_upstream - { - <$T as $crate::__private::Encode>::ENCODE_MAX + i - } else { - if <$T as $crate::__private::Encode>::ENCODE_MAX < 64 && no_encoding_upstream { - <$T as $crate::__private::Encode>::ENCODE_MAX - } else { - 0 - } - }; - }; - } - - // Call to flush the contents of the RegisterWriter and get the inner writer. - macro_rules! flush { - () => {{ - if i != 0 { - buf.flush(); - } - i = 0; - &mut *buf.writer - }}; - } - - // Call to encode an enum variant. Faster than flush!().write_bits($variant, $bits). - // Must be fist call after optimized_enc!. - #[allow(unused)] - macro_rules! enc_variant { - ($variant:literal, $bits:literal) => { - debug_assert!(i == 0); - buf.inner.write_bits($variant, $bits); - i = $bits + i; - }; - } - - // Call once done encoding. - macro_rules! end_enc { - () => { - let _ = flush!(); - let _ = i; - #[allow(clippy::drop_non_drop)] - drop(buf); - }; - } - }; -} -pub use optimized_enc; - -// These benchmarks ensure that optimized_enc is working. They all run about 8 times faster with optimized_enc. 
-#[cfg(all(test, not(miri)))] -mod optimized_enc_tests { - use std::collections::{BinaryHeap, VecDeque}; - use test::{black_box, Bencher}; - - type A = u8; - type B = u8; - - #[derive(Clone, Debug, PartialEq, crate::Encode, crate::Decode)] - struct Foo { - a: A, - b: B, - } - - #[bench] - fn bench_foo(b: &mut Bencher) { - let mut buffer = crate::Buffer::new(); - let foo = Foo { a: 1, b: 2 }; - let foo = vec![foo; 4000]; - - let bytes = buffer.encode(&foo).unwrap().to_vec(); - let decoded: Vec = buffer.decode(&bytes).unwrap(); - assert_eq!(foo, decoded); - - b.iter(|| { - let foo = black_box(foo.as_slice()); - let bytes = buffer.encode(foo).unwrap(); - black_box(bytes); - }) - } - - #[bench] - fn bench_tuple(b: &mut Bencher) { - let mut buffer = crate::Buffer::new(); - let foo = vec![(0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8); 1000]; - - b.iter(|| { - let foo = black_box(foo.as_slice()); - let bytes = buffer.encode(foo).unwrap(); - black_box(bytes); - }) - } - - #[bench] - fn bench_array(b: &mut Bencher) { - let mut buffer = crate::Buffer::new(); - let foo = vec![[0u8; 8]; 1000]; - - b.iter(|| { - let foo = black_box(foo.as_slice()); - let bytes = buffer.encode(foo).unwrap(); - black_box(bytes); - }) - } - - #[bench] - fn bench_bool_slice(b: &mut Bencher) { - let mut buffer = crate::Buffer::new(); - let foo = vec![false; 8 * 1000]; - - b.iter(|| { - let foo = black_box(foo.as_slice()); - let bytes = buffer.encode(foo).unwrap(); - black_box(bytes); - }) - } - - #[bench] - fn bench_vec(b: &mut Bencher) { - let mut buffer = crate::Buffer::new(); - let foo = vec![0u8; 8 * 1000]; - - b.iter(|| { - let foo = black_box(foo.as_slice()); - let bytes = buffer.encode(foo).unwrap(); - black_box(bytes); - }) - } - - #[bench] - fn bench_vec_deque(b: &mut Bencher) { - let mut buffer = crate::Buffer::new(); - let mut foo = VecDeque::from(vec![0u8; 8000]); - for _ in 0..4000 { - // Make it not contiguous. 
- foo.pop_front().unwrap(); - foo.push_back(1u8); - } - - b.iter(|| { - let foo = black_box(&foo); - let bytes = buffer.encode(foo).unwrap(); - black_box(bytes); - }) - } - - // BinaryHeap::encode isn't optimized yet. - #[bench] - fn bench_binary_heap(b: &mut Bencher) { - let mut buffer = crate::Buffer::new(); - let foo = BinaryHeap::from_iter((0u16..8000).map(|v| v as u8)); - - b.iter(|| { - let foo = black_box(&foo); - let bytes = buffer.encode(foo).unwrap(); - black_box(bytes); - }) - } -} - -/// A macro that facilitates reading from a RegisterReader when decoding multiple values less than 64 bits. -/// This can dramatically speed operations like decoding a tuple of 8 bytes. -/// -/// Once you call `optimized_dec!()`, you must call `end_dec!()` at the end to advance the reader. -/// If the execution path diverges it must never converge or this won't optimize well. -#[doc(hidden)] -#[macro_export] -macro_rules! optimized_dec { - ($encoding:ident, $reader:ident) => { - #[allow(unused_mut)] - let mut buf = $crate::__private::RegisterReader::new($reader); - #[allow(unused_mut)] - let mut i: usize = 0; - #[allow(unused)] - let no_encoding_upstream = $encoding.is_fixed(); - - // Call on each field (that doesn't get it's encoding overridden in the derive macro). - #[allow(unused)] - macro_rules! dec { - ($t:ident, $T:ty) => { - // DECODE_MAX is only accurate if there isn't any encoding upstream. - // Downstream encodings make DECODE_MAX = usize::MAX in derive macro. - let $t = if i >= <$T as $crate::__private::Decode>::DECODE_MAX - && no_encoding_upstream - { - <$T as $crate::__private::Decode>::decode($encoding, &mut buf.inner)? - } else { - if <$T as $crate::__private::Decode>::DECODE_MAX < 64 && no_encoding_upstream { - buf.refill()?; - <$T as $crate::__private::Decode>::decode($encoding, &mut buf.inner)? - } else { - buf.advance_reader(); - <$T as $crate::__private::Decode>::decode($encoding, buf.reader)? 
- } - }; - - i = if i >= <$T as $crate::__private::Decode>::DECODE_MAX && no_encoding_upstream { - i - <$T as $crate::__private::Decode>::DECODE_MAX - } else { - if <$T as $crate::__private::Decode>::DECODE_MAX < 64 && no_encoding_upstream { - // Needs saturating since it's const (even though we've checked it). - 64usize.saturating_sub(<$T as $crate::__private::Decode>::DECODE_MAX) - } else { - 0 - } - }; - }; - } - - // Call to flush the contents of the RegisterReader and get the inner reader. - macro_rules! flush { - () => {{ - let _ = i; - i = 0; - buf.advance_reader(); - &mut *buf.reader - }}; - } - - // Call to peek bits to decode an enum variant. Faster than flush!().peek_bits()?. - // Must be fist call after optimized_dec!. - #[allow(unused)] - macro_rules! dec_variant_peek { - () => {{ - buf.refill()?; - buf.inner.peek_bits()? - }}; - } - - // Call to advance bits to decode an enum variant. Faster than flush!().advance($bits)?. - // Must be second call after dec_variant_peek!. - #[allow(unused)] - macro_rules! dec_variant_advance { - ($bits:literal) => { - debug_assert!(i == 0); - buf.inner.advance($bits); - i = 64 - $bits; - }; - } - - // Call once done decoding. - macro_rules! end_dec { - () => { - let _ = flush!(); - let _ = i; - #[allow(clippy::drop_non_drop)] - drop(buf); - }; - } - }; -} -pub use optimized_dec; - -// These benchmarks ensure that optimized_dec is working. They run 4-8 times faster with optimized_dec. 
-#[cfg(all(test, not(miri)))] -mod optimized_dec_tests { - use std::collections::{BTreeSet, BinaryHeap, VecDeque}; - use test::{black_box, Bencher}; - - type A = u8; - type B = u8; - - #[derive(Clone, Debug, PartialEq, crate::Encode, crate::Decode)] - #[repr(C, align(8))] - struct Foo { - a: A, - b: B, - c: A, - d: B, - e: A, - f: B, - g: A, - h: B, - } - - #[bench] - fn bench_foo(b: &mut Bencher) { - let mut buffer = crate::Buffer::new(); - let foo = Foo { - a: 1, - b: 2, - c: 3, - d: 4, - e: 5, - f: 6, - g: 7, - h: 8, - }; - let foo = vec![foo; 1000]; - type T = Vec; - - let bytes = buffer.encode(&foo).unwrap().to_vec(); - let decoded: T = buffer.decode(&bytes).unwrap(); - assert_eq!(foo, decoded); - - b.iter(|| { - let bytes = black_box(bytes.as_slice()); - black_box(buffer.decode::(bytes).unwrap()) - }) - } - - #[bench] - fn bench_tuple(b: &mut Bencher) { - let mut buffer = crate::Buffer::new(); - let foo = vec![(0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8); 1000]; - type T = Vec<(u8, u8, u8, u8, u8, u8, u8, u8)>; - - let bytes = buffer.encode(&foo).unwrap().to_vec(); - let decoded: T = buffer.decode(&bytes).unwrap(); - assert_eq!(foo, decoded); - - b.iter(|| { - let bytes = black_box(bytes.as_slice()); - black_box(buffer.decode::(bytes).unwrap()) - }) - } - - #[bench] - fn bench_array(b: &mut Bencher) { - let mut buffer = crate::Buffer::new(); - let foo = vec![[0u8; 8]; 1000]; - type T = Vec<[u8; 8]>; - - let bytes = buffer.encode(&foo).unwrap().to_vec(); - let decoded: T = buffer.decode(&bytes).unwrap(); - assert_eq!(foo, decoded); - - b.iter(|| { - let bytes = black_box(bytes.as_slice()); - black_box(buffer.decode::(bytes).unwrap()) - }) - } - - #[bench] - fn bench_vec(b: &mut Bencher) { - let mut buffer = crate::Buffer::new(); - let foo = vec![0u8; 8000]; - type T = Vec; - - let bytes = buffer.encode(&foo).unwrap().to_vec(); - let decoded: T = buffer.decode(&bytes).unwrap(); - assert_eq!(foo, decoded); - - b.iter(|| { - let bytes = black_box(bytes.as_slice()); - 
black_box(buffer.decode::(bytes).unwrap()) - }) - } - - #[bench] - fn bench_vec_deque(b: &mut Bencher) { - let mut buffer = crate::Buffer::new(); - let mut foo = VecDeque::from(vec![0u8; 8000]); - for _ in 0..4000 { - // Make it not contiguous. - foo.pop_front().unwrap(); - foo.push_back(1u8); - } - type T = VecDeque; - - let bytes = buffer.encode(&foo).unwrap().to_vec(); - let decoded: T = buffer.decode(&bytes).unwrap(); - assert_eq!(foo, decoded); - - b.iter(|| { - let bytes = black_box(bytes.as_slice()); - black_box(buffer.decode::(bytes).unwrap()) - }) - } - - #[bench] - fn bench_binary_heap(b: &mut Bencher) { - let mut buffer = crate::Buffer::new(); - let foo = BinaryHeap::from_iter((0u16..8000).map(|v| v as u8)); - type T = BinaryHeap; - - let bytes = buffer.encode(&foo).unwrap().to_vec(); - let decoded: T = buffer.decode(&bytes).unwrap(); - - // Binary heaps can't be compared directly. - assert_eq!( - BTreeSet::from_iter(foo.iter().copied()), - BTreeSet::from_iter(decoded.iter().copied()) - ); - - b.iter(|| { - let bytes = black_box(bytes.as_slice()); - black_box(buffer.decode::(bytes).unwrap()) - }) - } -} diff --git a/src/code_impls.rs b/src/code_impls.rs deleted file mode 100644 index 51899a6..0000000 --- a/src/code_impls.rs +++ /dev/null @@ -1,1039 +0,0 @@ -use crate::code::{optimized_dec, optimized_enc, Decode, Encode}; -use crate::encoding::{Encoding, Fixed, Gamma}; -use crate::guard::guard_len; -use crate::nightly::{max, min}; -use crate::read::Read; -use crate::write::Write; -use crate::{Result, E}; -use std::collections::{BTreeMap, HashMap, HashSet}; -use std::ffi::{CStr, CString}; -use std::hash::{BuildHasher, Hash}; -use std::marker::PhantomData; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; -use std::num::*; -use std::time::Duration; - -macro_rules! impl_enc_const { - ($v:expr) => { - const ENCODE_MIN: usize = $v; - const ENCODE_MAX: usize = $v; - }; -} - -macro_rules! 
impl_enc_size { - ($t:ty) => { - impl_enc_const!(std::mem::size_of::<$t>() * u8::BITS as usize); - }; -} - -macro_rules! impl_enc_same { - ($other:ty) => { - const ENCODE_MIN: usize = <$other>::ENCODE_MIN; - const ENCODE_MAX: usize = <$other>::ENCODE_MAX; - }; -} - -macro_rules! impl_dec_from_enc { - () => { - const DECODE_MIN: usize = Self::ENCODE_MIN; - const DECODE_MAX: usize = Self::ENCODE_MAX; - }; -} - -macro_rules! impl_dec_same { - ($other:ty) => { - const DECODE_MIN: usize = <$other>::DECODE_MIN; - const DECODE_MAX: usize = <$other>::DECODE_MAX; - }; -} - -impl Encode for bool { - impl_enc_const!(1); - - #[inline(always)] - fn encode(&self, _: impl Encoding, writer: &mut impl Write) -> Result<()> { - writer.write_bit(*self); - Ok(()) - } -} - -impl Decode for bool { - impl_dec_from_enc!(); - - #[inline(always)] - fn decode(_: impl Encoding, reader: &mut impl Read) -> Result { - reader.read_bit() - } -} - -macro_rules! impl_uints { - ($read:ident, $write:ident, $($int: ty),*) => { - $( - impl Encode for $int { - impl_enc_size!(Self); - - #[inline(always)] - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - encoding.$write::<{ Self::BITS as usize }>(writer, (*self).into()); - Ok(()) - } - } - - impl Decode for $int { - impl_dec_from_enc!(); - - #[inline(always)] - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - Ok(encoding.$read::<{ Self::BITS as usize }>(reader)? as Self) - } - } - )* - } -} - -macro_rules! 
impl_ints { - ($read:ident, $write:ident, $($int: ty => $uint: ty),*) => { - $( - impl Encode for $int { - impl_enc_size!(Self); - - #[inline(always)] - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - let word = if encoding.zigzag() { - zigzag::ZigZagEncode::zigzag_encode(*self).into() - } else { - (*self as $uint).into() - }; - encoding.$write::<{ Self::BITS as usize }>(writer, word); - Ok(()) - } - } - - impl Decode for $int { - impl_dec_from_enc!(); - - #[inline(always)] - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - let word = encoding.$read::<{ Self::BITS as usize }>(reader)?; - let sint = if encoding.zigzag() { - zigzag::ZigZagDecode::zigzag_decode(word as $uint) - } else { - word as Self - }; - Ok(sint) - } - } - )* - } -} - -impl_uints!(read_u64, write_u64, u8, u16, u32, u64); -impl_ints!(read_u64, write_u64, i8 => u8, i16 => u16, i32 => u32, i64 => u64); -impl_uints!(read_u128, write_u128, u128); -impl_ints!(read_u128, write_u128, i128 => u128); - -macro_rules! impl_try_int { - ($a:ty, $b:ty) => { - impl Encode for $a { - impl_enc_size!($b); - - #[inline(always)] - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - (*self as $b).encode(encoding, writer) - } - } - - impl Decode for $a { - impl_dec_from_enc!(); - - #[inline(always)] - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - <$b>::decode(encoding, reader)? - .try_into() - .map_err(|_| E::Invalid(stringify!($a)).e()) - } - } - }; -} - -impl_try_int!(usize, u64); -impl_try_int!(isize, i64); - -macro_rules! 
impl_float { - ($a:ty, $write:ident, $read:ident) => { - impl Encode for $a { - impl_enc_size!($a); - - #[inline(always)] - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - encoding.$write(writer, *self); - Ok(()) - } - } - - impl Decode for $a { - impl_dec_from_enc!(); - - #[inline(always)] - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - encoding.$read(reader) - } - } - }; -} - -impl_float!(f32, write_f32, read_f32); -impl_float!(f64, write_f64, read_f64); - -// Subtracts 1 in encode and adds one in decode (so gamma is smaller). -macro_rules! impl_non_zero { - ($($a:ty),*) => { - $( - impl Encode for $a { - impl_enc_size!($a); - - #[inline(always)] - fn encode(&self, _: impl Encoding, writer: &mut impl Write) -> Result<()> { - (self.get() - 1).encode(Fixed, writer) - } - } - - impl Decode for $a { - impl_dec_from_enc!(); - - #[inline(always)] - fn decode(_: impl Encoding, reader: &mut impl Read) -> Result { - let v = Decode::decode(Fixed, reader)?; - let _ = Self::new(v); // Type inference. - Self::new(v.wrapping_add(1)).ok_or_else(|| E::Invalid("non zero").e()) - } - } - )* - }; -} - -impl_non_zero!(NonZeroU8, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroUsize); -impl_non_zero!(NonZeroI8, NonZeroI16, NonZeroI32, NonZeroI64, NonZeroIsize); - -impl Encode for char { - impl_enc_const!(21); - - #[inline(always)] - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - encoding.write_u64::<21>(writer, *self as u64); - Ok(()) - } -} - -impl Decode for char { - impl_dec_from_enc!(); - - #[inline(always)] - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - let bits = encoding.read_u64::<21>(reader)? 
as u32; - char::from_u32(bits).ok_or_else(|| E::Invalid("char").e()) - } -} - -impl Encode for Option { - const ENCODE_MIN: usize = 1; - const ENCODE_MAX: usize = T::ENCODE_MAX.saturating_add(1); - - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - if let Some(t) = self { - fn encode_some( - t: &T, - encoding: impl Encoding, - writer: &mut impl Write, - ) -> Result<()> { - optimized_enc!(encoding, writer); - enc!(true, bool); - enc!(t, T); - end_enc!(); - Ok(()) - } - encode_some(t, encoding, writer) - } else { - writer.write_false(); - Ok(()) - } - } -} - -impl Decode for Option { - const DECODE_MIN: usize = 1; - const DECODE_MAX: usize = T::DECODE_MAX.saturating_add(1); - - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - optimized_dec!(encoding, reader); - dec!(v, bool); - if v { - dec!(t, T); - end_dec!(); - Ok(Some(t)) - } else { - end_dec!(); - Ok(None) - } - } -} - -macro_rules! impl_either { - ($typ: path, $a: ident, $a_t:ty, $b:ident, $b_t: ty $(,$($generic: ident);*)*) => { - impl $(<$($generic: Encode),*>)* Encode for $typ { - const ENCODE_MIN: usize = 1 + min(<$a_t>::ENCODE_MIN, <$b_t>::ENCODE_MIN); - const ENCODE_MAX: usize = max(<$a_t>::ENCODE_MAX, <$b_t>::ENCODE_MAX).saturating_add(1); - - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - match self { - Self::$a(a) => { - writer.write_false(); - optimized_enc!(encoding, writer); - enc!(a, $a_t); - end_enc!(); - Ok(()) - }, - Self::$b(b) => { - optimized_enc!(encoding, writer); - enc!(true, bool); - enc!(b, $b_t); - end_enc!(); - Ok(()) - }, - } - } - } - - impl $(<$($generic: Decode),*>)* Decode for $typ { - const DECODE_MIN: usize = 1 + min(<$a_t>::DECODE_MIN, <$b_t>::DECODE_MIN); - const DECODE_MAX: usize = max(<$a_t>::DECODE_MAX, <$b_t>::DECODE_MAX).saturating_add(1); - - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - optimized_dec!(encoding, reader); - dec!(v, bool); - Ok(if v 
{ - dec!(b, $b_t); - end_dec!(); - Self::$b(b) - } else { - dec!(a, $a_t); - end_dec!(); - Self::$a(a) - }) - } - } - } -} - -impl_either!(std::result::Result, Ok, T, Err, E, T ; E); - -macro_rules! impl_wrapper { - ($(::$ptr: ident)*) => { - impl Encode for $(::$ptr)* { - impl_enc_same!(T); - - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - T::encode(&self.0, encoding, writer) - } - } - - impl Decode for $(::$ptr)* { - impl_dec_same!(T); - - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - Ok(Self(T::decode(encoding, reader)?)) - } - } - } -} - -impl_wrapper!(::std::num::Wrapping); -impl_wrapper!(::std::cmp::Reverse); - -macro_rules! impl_smart_ptr { - ($(::$ptr: ident)*) => { - impl Encode for $(::$ptr)* { - impl_enc_same!(T); - - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - T::encode(self, encoding, writer) - } - } - - impl Decode for $(::$ptr)* { - impl_dec_same!(T); - - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - Ok(T::decode(encoding, reader)?.into()) - } - } - - impl Decode for $(::$ptr)*<[T]> { - impl_dec_same!(Vec); - - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - Ok(Vec::::decode(encoding, reader)?.into()) // TODO avoid Vec allocation for Rc<[T]> and Arc<[T]>. - } - } - - impl Decode for $(::$ptr)* { - impl_dec_same!(String); - - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - Ok(String::decode(encoding, reader)?.into()) - } - } - } -} - -impl_smart_ptr!(::std::boxed::Box); -impl_smart_ptr!(::std::rc::Rc); -impl_smart_ptr!(::std::sync::Arc); - -// Writes multiple elements per flush. -#[cfg_attr(not(debug_assertions), inline(always))] -fn encode_elements( - elements: &[T], - encoding: impl Encoding, - writer: &mut impl Write, -) -> Result<()> { - if T::ENCODE_MAX == 0 { - return Ok(()); // Nothing to serialize. 
- } - let chunk_size = 64 / T::ENCODE_MAX; - - if chunk_size > 1 && encoding.is_fixed() { - let mut buf = crate::register_buffer::RegisterWriter::new(writer); - - let chunks = elements.chunks_exact(chunk_size); - let remainder = chunks.remainder(); - - for chunk in chunks { - for t in chunk { - t.encode(encoding, &mut buf.inner)?; - } - buf.flush(); - } - - if !remainder.is_empty() { - for t in remainder { - t.encode(encoding, &mut buf.inner)?; - } - buf.flush(); - } - } else { - for t in elements.iter() { - t.encode(encoding, writer)? - } - } - Ok(()) -} - -// Reads multiple elements per flush. -#[cfg_attr(not(debug_assertions), inline(always))] -fn decode_elements( - len: usize, - encoding: impl Encoding, - reader: &mut impl Read, -) -> Result> { - let chunk_size = if encoding.is_fixed() && T::DECODE_MAX != 0 { - 64 / T::DECODE_MAX - } else { - 1 - }; - - if chunk_size >= 2 { - let chunks = len / chunk_size; - let remainder = len % chunk_size; - - let mut ret = Vec::with_capacity(len); - let mut buf = crate::register_buffer::RegisterReader::new(reader); - - for _ in 0..chunks { - buf.refill()?; - let r = &mut buf.inner; - - // This avoids checking if allocation is needed for every item for chunks divisible by 8. - // Adding more impls for other sizes slows down this case for some reason. - if chunk_size % 8 == 0 { - for _ in 0..chunk_size / 8 { - ret.extend([ - T::decode(encoding, r)?, - T::decode(encoding, r)?, - T::decode(encoding, r)?, - T::decode(encoding, r)?, - T::decode(encoding, r)?, - T::decode(encoding, r)?, - T::decode(encoding, r)?, - T::decode(encoding, r)?, - ]) - } - } else { - for _ in 0..chunk_size { - ret.push(T::decode(encoding, r)?) - } - } - } - - buf.refill()?; - for _ in 0..remainder { - ret.push(T::decode(encoding, &mut buf.inner)?); - } - buf.advance_reader(); - - Ok(ret) - } else { - // This is faster than extend for some reason. 
- let mut vec = Vec::with_capacity(len); - for _ in 0..len { - // Avoid generating allocation logic in push (we've allocated enough capacity). - if vec.len() == vec.capacity() { - panic!(); - } - vec.push(T::decode(encoding, reader)?); - } - Ok(vec) - } -} - -impl Encode for [T; N] { - const ENCODE_MIN: usize = T::ENCODE_MIN * N; - const ENCODE_MAX: usize = T::ENCODE_MAX.saturating_mul(N); - - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - encode_elements(self, encoding, writer) - } -} - -impl Decode for [T; N] { - const DECODE_MIN: usize = T::DECODE_MIN * N; - const DECODE_MAX: usize = T::DECODE_MAX.saturating_mul(N); - - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - // TODO find a safe way to decode an array without allocating. - // Maybe use ArrayVec, but that would require another dep. - Ok(decode_elements(N, encoding, reader)? - .try_into() - .ok() - .unwrap()) - } -} - -// Blocked TODO: https://github.com/rust-lang/rust/issues/37653 -// -// Implement faster encoding of &[u8] or more generally any &[bytemuck::Pod] that encodes the same. -impl Encode for [T] { - const ENCODE_MIN: usize = 1; - // [()] max bits is 127 (gamma of u64::MAX - 1). - const ENCODE_MAX: usize = (T::ENCODE_MAX.saturating_mul(usize::MAX)).saturating_add(127); - - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - self.len().encode(Gamma, writer)?; - encode_elements(self, encoding, writer) - } -} - -impl Encode for Vec { - impl_enc_same!([T]); - - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - self.as_slice().encode(encoding, writer) - } -} - -// Blocked TODO: https://github.com/rust-lang/rust/issues/37653 -// -// Implement faster decoding of Vec or more generally any Vec that encodes the same. -impl Decode for Vec { - const DECODE_MIN: usize = 1; - // Vec<()> max bits is 127 (gamma of u64::MAX - 1). 
- const DECODE_MAX: usize = (T::DECODE_MAX.saturating_mul(usize::MAX)).saturating_add(127); - - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - let len = usize::decode(Gamma, reader)?; - guard_len::(len, encoding, reader)?; - decode_elements(len, encoding, reader) - } -} - -macro_rules! impl_iter_encode { - ($item:ty) => { - impl_enc_same!([$item]); - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - self.len().encode(Gamma, writer)?; - for t in self { - t.encode(encoding, writer)?; - } - Ok(()) - } - }; -} - -macro_rules! impl_collection { - ($collection: ident $(,$bound: ident)*) => { - impl Encode for std::collections::$collection { - impl_iter_encode!(T); - } - - impl Decode for std::collections::$collection { - impl_dec_same!(Vec); - - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - let len = usize::decode(Gamma, reader)?; - guard_len::(len, encoding, reader)?; - - (0..len).map(|_| T::decode(encoding, reader)).collect() - } - } - } -} - -impl_collection!(BTreeSet, Ord); -impl_collection!(LinkedList); - -// Some collections can be efficiently created from a Vec such as BinaryHeap/VecDeque. -macro_rules! impl_collection_decode_from_vec { - ($collection: ident $(,$bound: ident)*) => { - impl Decode for std::collections::$collection { - impl_dec_same!(Vec); - - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - Ok(Vec::decode(encoding, reader)?.into()) - } - } - } -} - -impl Encode for std::collections::VecDeque { - impl_enc_same!([T]); - - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - self.len().encode(Gamma, writer)?; - let (a, b) = self.as_slices(); - encode_elements(a, encoding, writer)?; - encode_elements(b, encoding, writer) - } -} -impl_collection_decode_from_vec!(VecDeque); - -impl Encode for std::collections::BinaryHeap { - // TODO optimize with encode_elements(binary_heap.as_slice(), ..) once it's stable. 
- impl_iter_encode!(T); -} -impl_collection_decode_from_vec!(BinaryHeap, Ord); - -impl Encode for str { - impl_enc_same!([u8]); - - #[inline(always)] - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - encoding.write_str(writer, self); - Ok(()) - } -} - -impl Encode for String { - impl_enc_same!(str); - - #[inline(always)] - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - self.as_str().encode(encoding, writer) - } -} - -impl Decode for String { - impl_dec_from_enc!(); - - #[inline(always)] - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - Ok(encoding.read_str(reader)?.to_owned()) - } -} - -impl Encode for CStr { - impl_enc_same!(str); - - #[inline(always)] - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - encoding.write_byte_str(writer, self.to_bytes()); - Ok(()) - } -} - -impl Encode for CString { - impl_enc_same!(CStr); - - #[inline(always)] - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - self.as_c_str().encode(encoding, writer) - } -} - -impl Decode for CString { - impl_dec_from_enc!(); - - #[inline(always)] - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - CString::new(encoding.read_byte_str(reader)?).map_err(|_| E::Invalid("CString").e()) - } -} - -impl Encode for BTreeMap { - impl_iter_encode!((K, V)); -} - -impl Decode for BTreeMap { - impl_dec_same!(Vec<(K, V)>); - - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - let len = usize::decode(Gamma, reader)?; - guard_len::<(K, V)>(len, encoding, reader)?; - - // Collect is faster than insert for BTreeMap since it can add the items in bulk once it - // ensures they are sorted. 
- (0..len) - .map(|_| <(K, V)>::decode(encoding, reader)) - .collect() - } -} - -impl Encode for HashMap { - impl_iter_encode!((K, V)); -} - -impl Decode for HashMap { - impl_dec_same!(Vec<(K, V)>); - - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - let len = usize::decode(Gamma, reader)?; - guard_len::<(K, V)>(len, encoding, reader)?; - - // Insert is faster than collect for HashMap since it only reserves size_hint / 2 in collect. - let mut map = Self::with_capacity_and_hasher(len, Default::default()); - for _ in 0..len { - let (k, v) = <(K, V)>::decode(encoding, reader)?; - map.insert(k, v); - } - Ok(map) - } -} - -impl Encode for HashSet { - impl_iter_encode!(T); -} - -impl Decode for HashSet { - impl_dec_same!(Vec); - - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - let len = usize::decode(Gamma, reader)?; - guard_len::(len, encoding, reader)?; - - // Insert is faster than collect for HashSet since it only reserves size_hint / 2 in collect. - let mut set = Self::with_capacity_and_hasher(len, Default::default()); - for _ in 0..len { - set.insert(T::decode(encoding, reader)?); - } - Ok(set) - } -} - -macro_rules! impl_ipvx_addr { - ($addr:ident, $bytes:expr, $int:ty) => { - impl Encode for $addr { - impl_enc_const!($bytes * u8::BITS as usize); - - #[inline(always)] - fn encode(&self, _: impl Encoding, writer: &mut impl Write) -> Result<()> { - <$int>::from_le_bytes(self.octets()).encode(Fixed, writer) - } - } - - impl Decode for $addr { - impl_dec_from_enc!(); - - #[inline(always)] - fn decode(_: impl Encoding, reader: &mut impl Read) -> Result { - Ok(Self::from( - <$int as Decode>::decode(Fixed, reader)?.to_le_bytes(), - )) - } - } - }; -} - -impl_ipvx_addr!(Ipv4Addr, 4, u32); -impl_ipvx_addr!(Ipv6Addr, 16, u128); -impl_either!(IpAddr, V4, Ipv4Addr, V6, Ipv6Addr); - -macro_rules! 
impl_socket_addr_vx { - ($addr:ident, $ip_addr:ident, $bytes:expr $(,$extra: expr)*) => { - impl Encode for $addr { - impl_enc_const!(($bytes) * u8::BITS as usize); - - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - optimized_enc!(encoding, writer); - enc!(self.ip(), $ip_addr); - enc!(self.port(), u16); - end_enc!(); - Ok(()) - } - } - - impl Decode for $addr { - impl_dec_from_enc!(); - - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - optimized_dec!(encoding, reader); - dec!(ip, $ip_addr); - dec!(port, u16); - end_dec!(); - Ok(Self::new( - ip, - port - $(,$extra)* - )) - } - } - } -} - -impl_socket_addr_vx!(SocketAddrV4, Ipv4Addr, 4 + 2); -impl_socket_addr_vx!(SocketAddrV6, Ipv6Addr, 16 + 2, 0, 0); -impl_either!(SocketAddr, V4, SocketAddrV4, V6, SocketAddrV6); - -impl Encode for Duration { - impl_enc_const!(94); // 64 bits seconds + 30 bits nanoseconds - - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - encoding.write_u128::<{ Self::ENCODE_MAX }>(writer, self.as_nanos()); - Ok(()) - } -} - -impl Decode for Duration { - impl_dec_from_enc!(); - - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - let nanos = encoding.read_u128::<{ Self::DECODE_MAX }>(reader)?; - - // Manual implementation of Duration::from_nanos since it takes a u64 instead of a u128. - const NANOS_PER_SEC: u128 = Duration::new(1, 0).as_nanos(); - let secs = (nanos / NANOS_PER_SEC) - .try_into() - .map_err(|_| E::Invalid("Duration").e())?; - Ok(Duration::new(secs, (nanos % NANOS_PER_SEC) as u32)) - } -} - -impl Encode for PhantomData { - impl_enc_const!(0); - - fn encode(&self, _: impl Encoding, _: &mut impl Write) -> Result<()> { - Ok(()) - } -} - -impl Decode for PhantomData { - impl_dec_from_enc!(); - - fn decode(_: impl Encoding, _: &mut impl Read) -> Result { - Ok(PhantomData) - } -} - -// TODO maybe Atomic*, Bound, Cell, Range, RangeInclusive, SystemTime. 
- -// Allows `&str` and `&[T]` to implement encode. -impl<'a, T: Encode + ?Sized> Encode for &'a T { - impl_enc_same!(T); - - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - T::encode(self, encoding, writer) - } -} - -impl Encode for () { - impl_enc_const!(0); - - fn encode(&self, _: impl Encoding, _: &mut impl Write) -> Result<()> { - Ok(()) - } -} - -impl Decode for () { - impl_dec_from_enc!(); - - fn decode(_: impl Encoding, _: &mut impl Read) -> Result { - Ok(()) - } -} - -macro_rules! impl_tuples { - ($($len:expr => ($($n:tt $name:ident)+))+) => { - $( - impl<$($name),+> Encode for ($($name,)+) - where - $($name: Encode,)+ - { - const ENCODE_MIN: usize = $(<$name>::ENCODE_MIN +)+ 0; - const ENCODE_MAX: usize = 0usize $(.saturating_add(<$name>::ENCODE_MAX))+; - - #[cfg_attr(not(debug_assertions), inline(always))] - fn encode(&self, encoding: impl Encoding, writer: &mut impl Write) -> Result<()> { - optimized_enc!(encoding, writer); - $( - enc!(self.$n, $name); - )+ - end_enc!(); - Ok(()) - } - } - - impl<$($name),+> Decode for ($($name,)+) - where - $($name: Decode,)+ - { - const DECODE_MIN: usize = $(<$name>::DECODE_MIN +)+ 0; - const DECODE_MAX: usize = 0usize $(.saturating_add(<$name>::DECODE_MAX))+; - - #[allow(non_snake_case)] - #[cfg_attr(not(debug_assertions), inline(always))] - fn decode(encoding: impl Encoding, reader: &mut impl Read) -> Result { - optimized_dec!(encoding, reader); - $( - dec!($name, $name); - )+ - end_dec!(); - Ok(($($name,)+)) - } - } - )+ - } -} - -impl_tuples! 
{ - 1 => (0 T0) - 2 => (0 T0 1 T1) - 3 => (0 T0 1 T1 2 T2) - 4 => (0 T0 1 T1 2 T2 3 T3) - 5 => (0 T0 1 T1 2 T2 3 T3 4 T4) - 6 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5) - 7 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6) - 8 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7) - 9 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8) - 10 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9) - 11 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10) - 12 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11) - 13 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12) - 14 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13) - 15 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14) - 16 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15) -} - -#[cfg(all(test, not(miri)))] -mod tests { - use paste::paste; - use std::net::*; - use std::time::Duration; - use test::{black_box, Bencher}; - - macro_rules! bench { - ($name:ident, $t:ty, $v:expr) => { - paste! { - #[bench] - fn [](b: &mut Bencher) { - let mut buffer = crate::Buffer::new(); - let v = vec![$v; 1000]; - let _ = buffer.encode(&v).unwrap(); - - b.iter(|| { - let v = black_box(v.as_slice()); - let bytes = buffer.encode(v).unwrap(); - black_box(bytes); - }) - } - - #[bench] - fn [](b: &mut Bencher) { - let mut buffer = crate::Buffer::new(); - let v = vec![$v; 1000]; - - let bytes = buffer.encode(&v).unwrap().to_vec(); - let decoded: Vec<$t> = buffer.decode(&bytes).unwrap(); - assert_eq!(v, decoded); - - b.iter(|| { - let bytes = black_box(bytes.as_slice()); - black_box(buffer.decode::>(bytes).unwrap()) - }) - } - } - }; - } - - bench!(char, char, 'a'); // TODO bench on random chars. 
- bench!(duration, Duration, Duration::new(123, 456)); - bench!(ipv4_addr, Ipv4Addr, Ipv4Addr::from([1, 2, 3, 4])); - bench!(ipv6_addr, Ipv6Addr, Ipv6Addr::from([4; 16])); - bench!( - socket_addr_v4, - SocketAddrV4, - SocketAddrV4::new(Ipv4Addr::from([1, 2, 3, 4]), 1234) - ); - bench!( - socket_addr_v6, - SocketAddrV6, - SocketAddrV6::new(Ipv6Addr::from([4; 16]), 1234, 0, 0) - ); - - macro_rules! bench_map_or_set { - ($name:ident, $t:ty, $f:expr) => { - paste! { - #[bench] - fn [](b: &mut Bencher) { - let mut buffer = crate::Buffer::new(); - let v = $t::from_iter((0u16..1000).map($f)); - let _ = buffer.encode(&v).unwrap(); - - b.iter(|| { - let v = black_box(&v); - let bytes = buffer.encode(v).unwrap(); - black_box(bytes); - }) - } - - #[bench] - fn [](b: &mut Bencher) { - let mut buffer = crate::Buffer::new(); - let v = $t::from_iter((0u16..1000).map($f)); - - let bytes = buffer.encode(&v).unwrap().to_vec(); - let decoded: $t = buffer.decode(&bytes).unwrap(); - assert_eq!(v, decoded); - - b.iter(|| { - let bytes = black_box(bytes.as_slice()); - black_box(buffer.decode::<$t>(bytes).unwrap()) - }) - } - } - }; - } - - macro_rules! bench_map { - ($name:ident, $t:ident) => { - bench_map_or_set!($name, std::collections::$t::, |v| (v, v)); - }; - } - bench_map!(btree_map, BTreeMap); - bench_map!(hash_map, HashMap); - - macro_rules! bench_set { - ($name:ident, $t:ident) => { - bench_map_or_set!($name, std::collections::$t::, |v| v); - }; - } - bench_set!(btree_set, BTreeSet); - bench_set!(hash_set, HashSet); -} diff --git a/src/coder.rs b/src/coder.rs new file mode 100644 index 0000000..885534c --- /dev/null +++ b/src/coder.rs @@ -0,0 +1,106 @@ +use crate::fast::VecImpl; +use std::mem::MaybeUninit; +use std::num::NonZeroUsize; + +pub type Result = std::result::Result; + +pub trait Buffer { + /// Convenience function for `collect_into`. 
+ fn collect(&mut self) -> Vec { + let mut vec = vec![]; + self.collect_into(&mut vec); + vec + } + + /// Collects the buffer into a single `Vec`. This clears the buffer. + fn collect_into(&mut self, out: &mut Vec); + + /// Reserves space for `additional` calls to [`Encoder::encode`]. May be a no-op. Takes a NonZeroUsize to avoid + /// useless calls. + fn reserve(&mut self, additional: NonZeroUsize); +} + +/// Iterators passed to [`Encoder::encode_vectored`] must have length <= this. +pub const MAX_VECTORED_CHUNK: usize = 64; + +pub trait Encoder: Buffer + Default { + /// Returns a `VecImpl` if `T` is a type that can be encoded by copying. + #[inline(always)] + fn as_primitive(&mut self) -> Option<&mut VecImpl> + where + T: Sized, + { + None + } + + /// Encodes a single value. Can't error since anything can be encoded. + fn encode(&mut self, t: &T); + + /// Calls [`Self::encode`] once for every item in `i`. Only use this with **FAST** iterators. + // #[inline(always)] + fn encode_vectored<'a>(&mut self, i: impl Iterator + Clone) + where + T: 'a, + { + for t in i { + self.encode(t); + } + } +} + +pub trait View<'a> { + /// Reads `length` items out of `input` provisioning `length` calls to [`Decoder::decode`]. This overwrites the view. + fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()>; +} + +pub trait Decoder<'a, T>: View<'a> + Default { + /// Returns a `Some(ptr)` to the current element if it can be decoded by copying. + #[inline(always)] + fn as_primitive_ptr(&self) -> Option<*const u8> { + None + } + + /// Assuming [`Self::as_primitive_ptr`] returns `Some(ptr)`, this advances `ptr` by `n`. + /// # Safety + /// All advances and decodes must not pass `Self::populate(_, length)`. + unsafe fn as_primitive_advance(&mut self, n: usize) { + let _ = n; + unreachable!(); + } + + /// Decodes a single value. Can't error since `View::populate` has already validated the input. 
+ fn decode(&mut self) -> T; + + /// [`Self::decode`] without redundant copies. Only downside is panics will leak the value. + /// The only panics out of our control are Hash/Ord/PartialEq for BinaryHeap/BTreeMap/HashMap. + /// E.g. if a user PartialEq panics we will leak some memory which is an acceptable tradeoff. + /// TODO make this required and add default impl for Self::decode. + #[inline(always)] + fn decode_in_place(&mut self, out: &mut MaybeUninit) { + out.write(self.decode()); + } +} + +macro_rules! decode_from_in_place { + ($t:ty) => { + #[inline(always)] + fn decode(&mut self) -> $t { + let mut out = std::mem::MaybeUninit::uninit(); + self.decode_in_place(&mut out); + unsafe { out.assume_init() } + } + }; +} +pub(crate) use decode_from_in_place; + +#[doc(hidden)] +#[macro_export] +macro_rules! __private_uninit_field { + ($uninit:ident.$field:tt:$field_ty:ty) => { + unsafe { + &mut *(std::ptr::addr_of_mut!((*$uninit.as_mut_ptr()).$field) + as *mut std::mem::MaybeUninit<$field_ty>) + } + }; +} +pub use __private_uninit_field as uninit_field; diff --git a/src/consume.rs b/src/consume.rs new file mode 100644 index 0000000..0919652 --- /dev/null +++ b/src/consume.rs @@ -0,0 +1,51 @@ +use crate::coder::Result; +use crate::error::{err, error}; + +/// Attempts to claim `bytes` bytes out of `input`. +pub fn consume_bytes<'a>(input: &mut &'a [u8], bytes: usize) -> Result<&'a [u8]> { + if bytes > input.len() { + return err("EOF"); + } + let (bytes, remaining) = input.split_at(bytes); + *input = remaining; + Ok(bytes) +} + +/// Attempts to claim one byte out of `input`. +pub fn consume_byte(input: &mut &[u8]) -> Result { + Ok(consume_bytes(input, 1)?[0]) +} + +/// Like `consume_bytes` but consumes `[u8; N]` instead of `u8`. +pub fn consume_byte_arrays<'a, const N: usize>( + input: &mut &'a [u8], + length: usize, +) -> Result<&'a [[u8; N]]> { + // Avoid * overflow by using / instead. 
+ if input.len() / N < length { + return err("EOF"); + } + + // Safety: input.len() >= mid since we've checked it above. + let mid = length * N; + let (bytes, remaining) = unsafe { (input.get_unchecked(..mid), input.get_unchecked(mid..)) }; + + *input = remaining; + Ok(bytemuck::cast_slice(bytes)) +} + +/// Check if `input` is empty or return error. +pub fn expect_eof(input: &[u8]) -> Result<()> { + if cfg!(not(fuzzing)) && !input.is_empty() { + err("Expected EOF") + } else { + Ok(()) + } +} + +/// Returns `Ok(length * x)` if it does not overflow. +pub fn mul_length(length: usize, x: usize) -> Result { + length + .checked_mul(x) + .ok_or_else(|| error("length overflow")) +} diff --git a/src/derive/array.rs b/src/derive/array.rs new file mode 100644 index 0000000..eccdb9c --- /dev/null +++ b/src/derive/array.rs @@ -0,0 +1,72 @@ +use crate::coder::{Buffer, Decoder, Encoder, Result, View}; +use crate::consume::mul_length; +use crate::derive::{Decode, Encode}; +use std::mem::MaybeUninit; +use std::num::NonZeroUsize; + +#[derive(Debug)] +pub struct ArrayEncoder(T::Encoder); + +// Can't derive since it would bound T: Default. +impl Default for ArrayEncoder { + fn default() -> Self { + Self(Default::default()) + } +} + +impl Encoder<[T; N]> for ArrayEncoder { + #[inline(always)] + fn encode(&mut self, array: &[T; N]) { + // TODO use encode_vectored if N is large enough. + for v in array { + self.0.encode(v); + } + } +} + +impl Buffer for ArrayEncoder { + fn collect_into(&mut self, out: &mut Vec) { + self.0.collect_into(out); + } + + fn reserve(&mut self, additional: NonZeroUsize) { + self.0.reserve( + additional + .checked_mul(NonZeroUsize::new(N).unwrap()) + .unwrap(), + ); + } +} + +#[derive(Debug)] +pub struct ArrayDecoder<'a, T: Decode<'a>, const N: usize>(T::Decoder); + +// Can't derive since it would bound T: Default. 
+impl<'a, T: Decode<'a>, const N: usize> Default for ArrayDecoder<'a, T, N> { + fn default() -> Self { + Self(Default::default()) + } +} + +impl<'a, T: Decode<'a>, const N: usize> View<'a> for ArrayDecoder<'a, T, N> { + fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> { + let length = mul_length(length, N)?; + self.0.populate(input, length) + } +} + +impl<'a, T: Decode<'a>, const N: usize> Decoder<'a, [T; N]> for ArrayDecoder<'a, T, N> { + #[inline(always)] + fn decode(&mut self) -> [T; N] { + std::array::from_fn(|_| self.0.decode()) + } + + #[inline(always)] + fn decode_in_place(&mut self, out: &mut MaybeUninit<[T; N]>) { + // Safety: Equivalent to nightly MaybeUninit::transpose. + let out = unsafe { &mut *(out.as_mut_ptr() as *mut [MaybeUninit; N]) }; + for out in out { + self.0.decode_in_place(out) + } + } +} diff --git a/src/derive/empty.rs b/src/derive/empty.rs new file mode 100644 index 0000000..d10642f --- /dev/null +++ b/src/derive/empty.rs @@ -0,0 +1,27 @@ +use crate::coder::{Buffer, Decoder, Encoder, Result, View}; +use std::marker::PhantomData; +use std::num::NonZeroUsize; + +#[derive(Debug, Default)] +pub struct EmptyCoder; + +impl Encoder> for EmptyCoder { + fn encode(&mut self, _: &PhantomData) {} +} + +impl Buffer for EmptyCoder { + fn collect_into(&mut self, _: &mut Vec) {} + fn reserve(&mut self, _: NonZeroUsize) {} +} + +impl<'a> View<'a> for EmptyCoder { + fn populate(&mut self, _: &mut &'a [u8], _: usize) -> Result<()> { + Ok(()) + } +} + +impl<'a, T> Decoder<'a, PhantomData> for EmptyCoder { + fn decode(&mut self) -> PhantomData { + PhantomData + } +} diff --git a/src/derive/impls.rs b/src/derive/impls.rs new file mode 100644 index 0000000..98e731d --- /dev/null +++ b/src/derive/impls.rs @@ -0,0 +1,300 @@ +use crate::bool::{BoolDecoder, BoolEncoder}; +use crate::coder::{Buffer, Decoder, Encoder, Result, View}; +use crate::derive::array::{ArrayDecoder, ArrayEncoder}; +use crate::derive::empty::EmptyCoder; +use 
crate::derive::map::{MapDecoder, MapEncoder}; +use crate::derive::option::{OptionDecoder, OptionEncoder}; +use crate::derive::smart_ptr::{DerefEncoder, FromDecoder}; +use crate::derive::vec::{VecDecoder, VecEncoder}; +use crate::derive::{Decode, Encode}; +use crate::f32::{F32Decoder, F32Encoder}; +use crate::int::{CheckedIntDecoder, IntDecoder, IntEncoder}; +use crate::str::{StrDecoder, StrEncoder}; +use std::collections::{BTreeMap, BTreeSet, BinaryHeap, HashMap, HashSet, LinkedList, VecDeque}; +use std::hash::{BuildHasher, Hash}; +use std::marker::PhantomData; +use std::mem::MaybeUninit; +use std::num::*; + +macro_rules! impl_both { + ($t:ty, $encoder:ident, $decoder:ident) => { + impl Encode for $t { + type Encoder = $encoder; + } + impl<'a> Decode<'a> for $t { + type Decoder = $decoder<'a>; + } + }; +} +impl_both!(bool, BoolEncoder, BoolDecoder); +impl_both!(f32, F32Encoder, F32Decoder); +impl_both!(String, StrEncoder, StrDecoder); + +macro_rules! impl_int { + ($($a:ty => $b:ty),+) => { + $( + impl Encode for $a { + type Encoder = IntEncoder<$b>; + } + + impl<'a> Decode<'a> for $a { + type Decoder = IntDecoder<'a, $b>; + } + )+ + } +} +impl_int!(u8 => u8, u16 => u16, u32 => u32, u64 => u64, u128 => u128); +impl_int!(i8 => u8, i16 => u16, i32 => u32, i64 => u64, i128 => u128); +impl_int!(f64 => u64); // Totally an int... + +macro_rules! impl_checked_int { + ($($a:ty => $b:ty),+) => { + $( + impl Encode for $a { + type Encoder = IntEncoder<$b>; + } + impl<'a> Decode<'a> for $a { + type Decoder = CheckedIntDecoder<'a, $a, $b>; + } + )+ + } +} +impl_checked_int!(NonZeroU8 => u8, NonZeroU16 => u16, NonZeroU32 => u32, NonZeroU64 => u64, NonZeroU128 => u128); +impl_checked_int!(NonZeroI8 => u8, NonZeroI16 => u16, NonZeroI32 => u32, NonZeroI64 => u64, NonZeroI128 => u128); +impl_checked_int!(char => u32); + +macro_rules! 
impl_t { + ($t:ident, $encoder:ident, $decoder:ident) => { + impl Encode for $t { + type Encoder = $encoder; + } + impl<'a, T: Decode<'a>> Decode<'a> for $t { + type Decoder = $decoder<'a, T>; + } + }; +} +impl_t!(LinkedList, VecEncoder, VecDecoder); +impl_t!(Option, OptionEncoder, OptionDecoder); +impl_t!(Vec, VecEncoder, VecDecoder); +impl_t!(VecDeque, VecEncoder, VecDecoder); + +macro_rules! impl_smart_ptr { + ($(::$ptr: ident)*) => { + impl Encode for $(::$ptr)* { + type Encoder = DerefEncoder; + } + + impl<'a, T: Decode<'a>> Decode<'a> for $(::$ptr)* { + type Decoder = FromDecoder<'a, T>; + } + + impl<'a, T: Decode<'a>> Decode<'a> for $(::$ptr)*<[T]> { + // TODO avoid Vec allocation for Rc<[T]> and Arc<[T]>. + type Decoder = FromDecoder<'a, Vec>; + } + + impl<'a> Decode<'a> for $(::$ptr)* { + // TODO avoid String allocation for Rc and Arc. + type Decoder = FromDecoder<'a, String>; + } + } +} +impl_smart_ptr!(::std::boxed::Box); +impl_smart_ptr!(::std::rc::Rc); +impl_smart_ptr!(::std::sync::Arc); + +impl Encode for [T; N] { + type Encoder = ArrayEncoder; +} +impl<'a, T: Decode<'a>, const N: usize> Decode<'a> for [T; N] { + type Decoder = ArrayDecoder<'a, T, N>; +} + +// Convenience impls copied from serde etc. Makes Box work on Box<[T]>. +impl Encode for [T] { + type Encoder = VecEncoder; +} +impl Encode for str { + type Encoder = StrEncoder; +} + +// Partial zero copy deserialization like serde. 
+impl Encode for &str { + type Encoder = StrEncoder; +} +impl<'a> Decode<'a> for &'a str { + type Decoder = StrDecoder<'a>; +} + +impl Encode for BinaryHeap { + type Encoder = VecEncoder; +} +impl<'a, T: Decode<'a> + Ord> Decode<'a> for BinaryHeap { + type Decoder = VecDecoder<'a, T>; +} +impl Encode for BTreeSet { + type Encoder = VecEncoder; +} +impl<'a, T: Decode<'a> + Ord> Decode<'a> for BTreeSet { + type Decoder = VecDecoder<'a, T>; +} +impl Encode for HashSet { + type Encoder = VecEncoder; +} +impl<'a, T: Decode<'a> + Eq + Hash, S: BuildHasher + Default> Decode<'a> for HashSet { + type Decoder = VecDecoder<'a, T>; +} + +impl Encode for BTreeMap { + type Encoder = MapEncoder; +} +impl<'a, K: Decode<'a> + Ord, V: Decode<'a>> Decode<'a> for BTreeMap { + type Decoder = MapDecoder<'a, K, V>; +} +impl Encode for HashMap { + type Encoder = MapEncoder; +} +impl<'a, K: Decode<'a> + Eq + Hash, V: Decode<'a>, S: BuildHasher + Default> Decode<'a> + for HashMap +{ + type Decoder = MapDecoder<'a, K, V>; +} + +impl Encode for PhantomData { + type Encoder = EmptyCoder; +} +impl<'a, T> Decode<'a> for PhantomData { + type Decoder = EmptyCoder; +} + +macro_rules! 
impl_tuples { + ($(($($n:tt $name:ident)*))+) => { + $( + #[allow(unused, clippy::unused_unit)] + const _: () = { + impl<$($name: Encode,)*> Encode for ($($name,)*) { + type Encoder = TupleEncoder<$($name,)*>; + } + + #[derive(Debug)] + pub struct TupleEncoder<$($name: Encode,)*>( + $($name::Encoder,)* + ); + + impl<$($name: Encode,)*> Default for TupleEncoder<$($name,)*> { + fn default() -> Self { + Self( + $($name::Encoder::default(),)* + ) + } + } + + impl<$($name: Encode,)*> Encoder<($($name,)*)> for TupleEncoder<$($name,)*> { + #[inline(always)] + fn encode(&mut self, t: &($($name,)*)) { + $( + self.$n.encode(&t.$n); + )* + } + + // #[inline(always)] + fn encode_vectored<'a>(&mut self, i: impl Iterator + Clone) where ($($name,)*): 'a { + $( + self.$n.encode_vectored(i.clone().map(|t| &t.$n)); + )* + } + } + + impl<$($name: Encode,)*> Buffer for TupleEncoder<$($name,)*> { + fn collect_into(&mut self, out: &mut Vec) { + $( + self.$n.collect_into(out); + )* + } + + fn reserve(&mut self, length: NonZeroUsize) { + $( + self.$n.reserve(length); + )* + } + } + + impl<'a, $($name: Decode<'a>,)*> Decode<'a> for ($($name,)*) { + type Decoder = TupleDecoder<'a, $($name,)*>; + } + + #[derive(Debug)] + pub struct TupleDecoder<'a, $($name: Decode<'a>,)*>( + $($name::Decoder,)* + std::marker::PhantomData<&'a ()>, + ); + + impl<'a, $($name: Decode<'a>,)*> Default for TupleDecoder<'a, $($name,)*> { + fn default() -> Self { + Self( + $($name::Decoder::default(),)* + Default::default(), + ) + } + } + + impl<'a, $($name: Decode<'a>,)*> Decoder<'a, ($($name,)*)> for TupleDecoder<'a, $($name,)*> { + #[inline(always)] + fn decode(&mut self) -> ($($name,)*) { + ( + $(self.$n.decode(),)* + ) + } + + #[inline(always)] + fn decode_in_place(&mut self, out: &mut MaybeUninit<($($name,)*)>) { + $( + self.$n.decode_in_place(crate::coder::uninit_field!(out.$n: $name)); + )* + } + } + + impl<'a, $($name: Decode<'a>,)*> View<'a> for TupleDecoder<'a, $($name,)*> { + fn populate(&mut self, input: 
&mut &'a [u8], length: usize) -> Result<()> { + $( + self.$n.populate(input, length)?; + )* + Ok(()) + } + } + }; + )+ + } +} + +impl_tuples! { + () + (0 T0) + (0 T0 1 T1) + (0 T0 1 T1 2 T2) + (0 T0 1 T1 2 T2 3 T3) + (0 T0 1 T1 2 T2 3 T3 4 T4) + (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5) + (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6) + (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7) + (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8) + (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9) + (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10) + (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11) + (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12) + (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13) + (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14) + (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15) +} + +#[cfg(test)] +mod tests { + type Tuple = (u64, u32, u8, i32, u8, u16, i8, (u8, u8, u8, u8), i8); + fn bench_data() -> Vec<(Tuple, Option)> { + crate::random_data(1000) + .into_iter() + .map(|t: Tuple| (t, None)) + .collect() + } + crate::bench_encode_decode!(tuple_vec: Vec<_>); +} diff --git a/src/derive/map.rs b/src/derive/map.rs new file mode 100644 index 0000000..c85e58c --- /dev/null +++ b/src/derive/map.rs @@ -0,0 +1,119 @@ +use crate::coder::{Buffer, Decoder, Encoder, Result, View}; +use crate::derive::{Decode, Encode}; +use crate::length::{LengthDecoder, LengthEncoder}; +use std::collections::{BTreeMap, HashMap}; +use std::hash::{BuildHasher, Hash}; +use std::num::NonZeroUsize; + +#[derive(Debug)] +pub struct MapEncoder { + lengths: LengthEncoder, + keys: K::Encoder, + values: V::Encoder, +} + +// Can't derive since it would bound K + V: Default. 
+impl Default for MapEncoder { + fn default() -> Self { + Self { + lengths: Default::default(), + keys: Default::default(), + values: Default::default(), + } + } +} + +impl Buffer for MapEncoder { + fn collect_into(&mut self, out: &mut Vec) { + self.lengths.collect_into(out); + self.keys.collect_into(out); + self.values.collect_into(out); + } + + fn reserve(&mut self, additional: NonZeroUsize) { + self.lengths.reserve(additional); + // We don't know the lengths of the maps, so we can't reserve more. + } +} + +#[derive(Debug)] +pub struct MapDecoder<'a, K: Decode<'a>, V: Decode<'a>> { + lengths: LengthDecoder<'a>, + keys: K::Decoder, + values: V::Decoder, +} + +// Can't derive since it would bound K + V: Default. +impl<'a, K: Decode<'a>, V: Decode<'a>> Default for MapDecoder<'a, K, V> { + fn default() -> Self { + Self { + lengths: Default::default(), + keys: Default::default(), + values: Default::default(), + } + } +} + +impl<'a, K: Decode<'a>, V: Decode<'a>> View<'a> for MapDecoder<'a, K, V> { + fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> { + self.lengths.populate(input, length)?; + self.keys.populate(input, self.lengths.length())?; + self.values.populate(input, self.lengths.length()) + } +} + +macro_rules! encode_body { + ($t:ty) => { + #[inline(always)] + fn encode(&mut self, map: &$t) { + let n = map.len(); + self.lengths.encode(&n); + + if let Some(n) = NonZeroUsize::new(n) { + self.keys.reserve(n); + self.values.reserve(n); + for (k, v) in map { + self.keys.encode(k); + self.values.encode(v); + } + } + } + }; +} +macro_rules! decode_body { + ($t:ty) => { + #[inline(always)] + fn decode(&mut self) -> $t { + // BTreeMap::from_iter is faster than BTreeMap::insert since it can add the items in + // bulk once it ensures they are sorted. They are about equivalent for HashMap. 
+ (0..self.lengths.decode()) + .map(|_| (self.keys.decode(), self.values.decode())) + .collect() + } + }; +} + +impl Encoder> for MapEncoder { + encode_body!(BTreeMap); +} +impl<'a, K: Decode<'a> + Ord, V: Decode<'a>> Decoder<'a, BTreeMap> for MapDecoder<'a, K, V> { + decode_body!(BTreeMap); +} + +impl Encoder> for MapEncoder { + encode_body!(HashMap); +} +impl<'a, K: Decode<'a> + Eq + Hash, V: Decode<'a>, S: BuildHasher + Default> + Decoder<'a, HashMap> for MapDecoder<'a, K, V> +{ + decode_body!(HashMap); +} + +#[cfg(test)] +mod test { + use std::collections::{BTreeMap, HashMap}; + fn bench_data>() -> T { + (0..=255).map(|k| (k, 0)).collect() + } + crate::bench_encode_decode!(btree_map: BTreeMap<_, _>, hash_map: HashMap<_, _>); +} diff --git a/src/derive/mod.rs b/src/derive/mod.rs new file mode 100644 index 0000000..d5167bd --- /dev/null +++ b/src/derive/mod.rs @@ -0,0 +1,219 @@ +use crate::coder::{Buffer, Decoder, Encoder, View}; +use crate::consume::expect_eof; +use crate::Error; +use std::num::NonZeroUsize; + +mod array; +mod empty; +mod impls; +mod map; +mod option; +mod smart_ptr; +mod variant; +pub(crate) mod vec; + +// For derive macro. +#[cfg(feature = "derive")] +#[doc(hidden)] +pub mod __private { + pub use crate::coder::{uninit_field, Buffer, Decoder, Encoder, Result, View}; + pub use crate::derive::variant::{VariantDecoder, VariantEncoder}; + pub use crate::derive::{Decode, Encode}; + pub fn invalid_enum_variant() -> Result { + crate::error::err("invalid enum variant") + } +} + +/// A type which can be encoded to bytes with [`encode`]. +/// +/// Use `#[derive(Encode)]` to implement. +pub trait Encode { + #[doc(hidden)] + type Encoder: Encoder; +} + +/// A type which can be decoded from bytes with [`decode`]. +/// +/// Use `#[derive(Decode)]` to implement. +pub trait Decode<'a>: Sized { + #[doc(hidden)] + type Decoder: Decoder<'a, Self>; +} + +/// A type which can be decoded without borrowing any bytes from the input. 
+/// +/// This type is a shorter version of `for<'de> Decode<'de>`. +pub trait DecodeOwned: for<'de> Decode<'de> {} +impl DecodeOwned for T where T: for<'de> Decode<'de> {} + +/// Encodes a `T:` [`Encode`] into a [`Vec`]. +/// +/// **Warning:** The format is subject to change between major versions. +pub fn encode(t: &T) -> Vec { + let mut encoder = T::Encoder::default(); + encoder.reserve(NonZeroUsize::new(1).unwrap()); + + #[inline(never)] + fn encode_inner(encoder: &mut T::Encoder, t: &T) { + encoder.encode(t); + } + encode_inner(&mut encoder, t); + encoder.collect() +} + +/// Decodes a [`&[u8]`][`prim@slice`] into an instance of `T:` [`Decode`]. +/// +/// **Warning:** The format is subject to change between major versions. +pub fn decode<'a, T: Decode<'a>>(mut bytes: &'a [u8]) -> Result { + let mut decoder = T::Decoder::default(); + decoder.populate(&mut bytes, 1)?; + expect_eof(bytes)?; + #[inline(never)] + fn decode_inner<'a, T: Decode<'a>>(decoder: &mut T::Decoder) -> T { + decoder.decode() + } + Ok(decode_inner(&mut decoder)) +} + +/// A buffer for reusing allocations between multiple calls to [`EncodeBuffer::encode`]. +pub struct EncodeBuffer { + encoder: T::Encoder, + out: Vec, +} + +// #[derive(Default)] bounds T: Default. +impl Default for EncodeBuffer { + fn default() -> Self { + Self { + encoder: Default::default(), + out: Default::default(), + } + } +} + +impl EncodeBuffer { + /// Encodes a `T:` [`Encode`] into a [`&[u8]`][`prim@slice`]. + /// + /// Can reuse allocations when called multiple times on the same [`EncodeBuffer`]. + /// + /// **Warning:** The format is subject to change between major versions. + pub fn encode<'a>(&'a mut self, t: &T) -> &'a [u8] { + // TODO dedup with encode. 
+ self.encoder.reserve(NonZeroUsize::new(1).unwrap()); + #[inline(never)] + fn encode_inner(encoder: &mut T::Encoder, t: &T) { + encoder.encode(t); + } + encode_inner(&mut self.encoder, t); + self.out.clear(); + self.encoder.collect_into(&mut self.out); + self.out.as_slice() + } +} + +/// A buffer for reusing allocations between multiple calls to [`DecodeBuffer::decode`]. +/// +/// TODO don't bound [`DecodeBuffer`] to decode's `&'a [u8]`. +pub struct DecodeBuffer<'a, T: Decode<'a>>(>::Decoder); + +impl<'a, T: Decode<'a>> Default for DecodeBuffer<'a, T> { + fn default() -> Self { + Self(Default::default()) + } +} + +impl<'a, T: Decode<'a>> DecodeBuffer<'a, T> { + /// Decodes a [`&[u8]`][`prim@slice`] into an instance of `T:` [`Decode`]. + /// + /// Can reuse allocations when called multiple times on the same [`DecodeBuffer`]. + /// + /// **Warning:** The format is subject to change between major versions. + pub fn decode(&mut self, mut bytes: &'a [u8]) -> Result { + // TODO dedup with decode. + self.0.populate(&mut bytes, 1)?; + expect_eof(bytes)?; + #[inline(never)] + fn decode_inner<'a, T: Decode<'a>>(decoder: &mut T::Decoder) -> T { + decoder.decode() + } + let ret = decode_inner(&mut self.0); + Ok(ret) + } +} + +#[cfg(test)] +mod tests { + use crate::{Decode, Encode}; + + #[test] + fn decode() { + macro_rules! 
test { + ($v:expr, $t:ty) => { + let encoded = super::encode::<$t>(&$v); + println!("{:<24} {encoded:?}", stringify!($t)); + assert_eq!($v, super::decode::<$t>(&encoded).unwrap()); + }; + } + + test!(("abc", "123"), (&str, &str)); + test!(Vec::>::new(), Vec>); + test!(vec![None, Some(1), None], Vec>); + } + + #[derive(Encode, Decode)] + enum Never {} + + #[derive(Encode, Decode)] + enum One { + A(u8), + } + + #[derive(Encode, Decode)] + enum Two { + A(u8), + B(i8), + } + + #[derive(Encode, Decode)] + struct TupleStruct(u8, i8); + + #[derive(Encode, Decode)] + struct Generic(T); + + #[derive(Encode, Decode)] + struct GenericManual(#[bitcode(bound_type = "T")] T); + + #[derive(Encode, Decode)] + struct GenericWhere(A, B) + where + A: From; + + #[derive(Encode, Decode)] + struct Lifetime<'a>(&'a str); + + #[derive(Encode, Decode)] + struct LifetimeWhere<'a, 'b>(&'a str, &'b str) + where + 'a: 'b; + + #[derive(Encode, Decode)] + struct ConstGeneric([u8; N]); + + #[derive(Encode, Decode)] + struct Empty; + + #[derive(Encode, Decode)] + struct AssociatedConst([u8; Self::N]); + impl AssociatedConst { + const N: usize = 1; + } + + #[derive(Encode, Decode)] + struct AssociatedConstTrait([u8; ::N]); + trait Trait { + const N: usize; + } + impl Trait for AssociatedConstTrait { + const N: usize = 1; + } +} diff --git a/src/derive/option.rs b/src/derive/option.rs new file mode 100644 index 0000000..368749a --- /dev/null +++ b/src/derive/option.rs @@ -0,0 +1,129 @@ +use crate::coder::{Buffer, Decoder, Encoder, Result, View, MAX_VECTORED_CHUNK}; +use crate::derive::variant::{VariantDecoder, VariantEncoder}; +use crate::derive::{Decode, Encode}; +use crate::fast::{FastArrayVec, PushUnchecked}; +use std::num::NonZeroUsize; + +#[derive(Debug)] +pub struct OptionEncoder { + variants: VariantEncoder<2>, + some: T::Encoder, +} + +// Can't derive since it would bound T: Default. 
+impl Default for OptionEncoder { + fn default() -> Self { + Self { + variants: Default::default(), + some: Default::default(), + } + } +} + +impl Encoder> for OptionEncoder { + #[inline(always)] + fn encode(&mut self, t: &Option) { + self.variants.encode(&(t.is_some() as u8)); + if let Some(t) = t { + self.some.reserve(NonZeroUsize::new(1).unwrap()); + self.some.encode(t); + } + } + + fn encode_vectored<'a>(&mut self, i: impl Iterator> + Clone) + where + Option: 'a, + { + // Types with many vectorized encoders benefit from a &[&T] since encode_vectorized is still + // faster even with the extra indirection. TODO vectored encoder count >= 8 instead of size_of. + if std::mem::size_of::() >= 64 { + let mut uninit = std::mem::MaybeUninit::uninit(); + let mut refs = FastArrayVec::<_, MAX_VECTORED_CHUNK>::new(&mut uninit); + + for t in i { + self.variants.encode(&(t.is_some() as u8)); + if let Some(t) = t { + // Safety: Even if all `Some` won't write more than MAX_VECTORED_CHUNK elements. + unsafe { refs.push_unchecked(t) }; + } + } + + let refs = refs.as_slice(); + let Some(some_count) = NonZeroUsize::new(refs.len()) else { + return; + }; + self.some.reserve(some_count); + self.some.encode_vectored(refs.iter().copied()); + } else { + let mut some_count = 0; + for t in i.clone() { + let is_some = t.is_some() as u8; + some_count += is_some as usize; + self.variants.encode(&is_some); + } + + let Some(some_sum) = NonZeroUsize::new(some_count) else { + return; + }; + self.some.reserve(some_sum); + for t in i.flatten() { + self.some.encode(t); + } + } + } +} + +impl Buffer for OptionEncoder { + fn collect_into(&mut self, out: &mut Vec) { + self.variants.collect_into(out); + self.some.collect_into(out); + } + + fn reserve(&mut self, additional: NonZeroUsize) { + self.variants.reserve(additional); + // We don't know how many are Some, so we can't reserve more. 
+ } +} + +#[derive(Debug)] +pub struct OptionDecoder<'a, T: Decode<'a>> { + variants: VariantDecoder<'a, 2, false>, + some: T::Decoder, +} + +// Can't derive since it would bound T: Default. +impl<'a, T: Decode<'a>> Default for OptionDecoder<'a, T> { + fn default() -> Self { + Self { + variants: Default::default(), + some: Default::default(), + } + } +} + +impl<'a, T: Decode<'a>> View<'a> for OptionDecoder<'a, T> { + fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> { + self.variants.populate(input, length)?; + self.some.populate(input, self.variants.length(1)) + } +} + +impl<'a, T: Decode<'a>> Decoder<'a, Option> for OptionDecoder<'a, T> { + #[inline(always)] + fn decode(&mut self) -> Option { + if self.variants.decode() != 0 { + Some(self.some.decode()) + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + #[rustfmt::skip] + fn bench_data() -> Vec> { + crate::random_data(1000) + } + crate::bench_encode_decode!(option_vec: Vec<_>); +} diff --git a/src/derive/smart_ptr.rs b/src/derive/smart_ptr.rs new file mode 100644 index 0000000..323cfd1 --- /dev/null +++ b/src/derive/smart_ptr.rs @@ -0,0 +1,76 @@ +use crate::coder::{Buffer, Decoder, Encoder, Result, View}; +use crate::derive::{Decode, Encode}; +use std::num::NonZeroUsize; +use std::ops::Deref; + +pub struct DerefEncoder(T::Encoder); + +// Can't derive since it would bound T: Default. +impl Default for DerefEncoder { + fn default() -> Self { + Self(Default::default()) + } +} + +impl, T: Encode + ?Sized> Encoder for DerefEncoder { + #[inline(always)] + fn encode(&mut self, t: &D) { + self.0.encode(t) + } +} + +impl Buffer for DerefEncoder { + fn collect_into(&mut self, out: &mut Vec) { + self.0.collect_into(out); + } + fn reserve(&mut self, additional: NonZeroUsize) { + self.0.reserve(additional); + } +} + +/// Decodes a `T` and then converts it with [`From`]. For `T` -> `Box` and `Vec` -> `Box<[T]>`. 
+#[derive(Debug)] +pub struct FromDecoder<'a, T: Decode<'a>>(T::Decoder); + +// Can't derive since it would bound T: Default. +impl<'a, T: Decode<'a>> Default for FromDecoder<'a, T> { + fn default() -> Self { + Self(Default::default()) + } +} + +impl<'a, T: Decode<'a>> View<'a> for FromDecoder<'a, T> { + fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> { + self.0.populate(input, length) + } +} + +impl<'a, F: From, T: Decode<'a>> Decoder<'a, F> for FromDecoder<'a, T> { + #[inline(always)] + fn decode(&mut self) -> F { + F::from(self.0.decode()) + } +} + +#[cfg(test)] +mod tests { + use crate::{decode, encode}; + + #[test] + fn box_() { + let v = Box::new(123u8); + assert_eq!(decode::>(&encode(&v)).unwrap(), v); + } + + #[test] + fn box_slice() { + let v = vec![123u8].into_boxed_slice(); + assert_eq!(decode::>(&encode(&v)).unwrap(), v); + } + + #[test] + fn box_str() { + let v = "box".to_string().into_boxed_str(); + assert_eq!(decode::>(&encode(&v)).unwrap(), v); + } +} diff --git a/src/derive/variant.rs b/src/derive/variant.rs new file mode 100644 index 0000000..f1a12c5 --- /dev/null +++ b/src/derive/variant.rs @@ -0,0 +1,139 @@ +use crate::coder::{Buffer, Decoder, Encoder, Result, View}; +use crate::fast::{CowSlice, NextUnchecked, PushUnchecked, VecImpl}; +use crate::pack::{pack_bytes_less_than, unpack_bytes_less_than}; +use std::num::NonZeroUsize; + +#[derive(Debug, Default)] +pub struct VariantEncoder(VecImpl); + +impl Encoder for VariantEncoder { + #[inline(always)] + fn encode(&mut self, v: &u8) { + unsafe { self.0.push_unchecked(*v) }; + } +} + +impl Buffer for VariantEncoder { + fn collect_into(&mut self, out: &mut Vec) { + assert!(N >= 2); + pack_bytes_less_than::(self.0.as_slice(), out); + self.0.clear(); + } + + fn reserve(&mut self, additional: NonZeroUsize) { + self.0.reserve(additional.get()) + } +} + +#[derive(Debug)] +pub struct VariantDecoder<'a, const N: usize, const C_STYLE: bool> { + variants: CowSlice<'a, u8>, + 
histogram: [usize; N], // Not required if C_STYLE. TODO don't reserve space for it. +} + +// [(); N] doesn't implement Default. +impl Default for VariantDecoder<'_, N, C_STYLE> { + fn default() -> Self { + Self { + variants: Default::default(), + histogram: std::array::from_fn(|_| 0), + } + } +} + +// C style enums don't require length, so we can skip making a histogram for them. +impl<'a, const N: usize> VariantDecoder<'a, N, false> { + pub fn length(&self, variant_index: u8) -> usize { + self.histogram[variant_index as usize] + } +} + +impl<'a, const N: usize, const C_STYLE: bool> View<'a> for VariantDecoder<'a, N, C_STYLE> { + fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> { + assert!(N >= 2); + if C_STYLE { + unpack_bytes_less_than::(input, length, &mut self.variants)?; + } else { + self.histogram = unpack_bytes_less_than::(input, length, &mut self.variants)?; + } + Ok(()) + } +} + +impl<'a, const N: usize, const C_STYLE: bool> Decoder<'a, u8> for VariantDecoder<'a, N, C_STYLE> { + // Guaranteed to output numbers less than N. + #[inline(always)] + fn decode(&mut self) -> u8 { + unsafe { self.variants.mut_slice().next_unchecked() } + } +} + +#[cfg(test)] +mod tests { + use crate::{decode, encode}; + + #[allow(unused)] + #[test] + fn test_c_style_enum() { + #[derive(crate::Encode, crate::Decode)] + enum Enum1 { + A, + B, + C, + D, + E, + F, + } + #[derive(crate::Decode)] + enum Enum2 { + A, + B, + C, + D, + E, + } + // 5 and 6 element enums serialize the same, so we can use them to test variant bounds checking. 
+ assert!(matches!(decode(&encode(&Enum1::A)), Ok(Enum2::A))); + assert!(decode::(&encode(&Enum1::F)).is_err()); + assert!(matches!(decode(&encode(&Enum1::F)), Ok(Enum1::F))); + } + + #[allow(unused)] + #[test] + fn test_rust_style_enum() { + #[derive(crate::Encode, crate::Decode)] + enum Enum1 { + A(u8), + B, + C, + D, + E, + F, + } + #[derive(crate::Decode)] + enum Enum2 { + A(u8), + B, + C, + D, + E, + } + // 5 and 6 element enums serialize the same, so we can use them to test variant bounds checking. + assert!(matches!(decode(&encode(&Enum1::A(1))), Ok(Enum2::A(1)))); + assert!(decode::(&encode(&Enum1::F)).is_err()); + assert!(matches!(decode(&encode(&Enum1::F)), Ok(Enum1::F))); + } + + #[derive(Debug, PartialEq, crate::Encode, crate::Decode)] + enum BoolEnum { + True, + False, + } + fn bench_data() -> Vec { + crate::random_data(1000) + .into_iter() + .map(|v| if v { BoolEnum::True } else { BoolEnum::False }) + .collect() + } + crate::bench_encode_decode!(bool_enum_vec: Vec<_>); +} diff --git a/src/derive/vec.rs b/src/derive/vec.rs new file mode 100644 index 0000000..0f773c0 --- /dev/null +++ b/src/derive/vec.rs @@ -0,0 +1,394 @@ +use crate::coder::{Buffer, Decoder, Encoder, Result, View, MAX_VECTORED_CHUNK}; +use crate::derive::{Decode, Encode}; +use crate::length::{LengthDecoder, LengthEncoder}; +use std::collections::{BTreeSet, BinaryHeap, HashSet, LinkedList, VecDeque}; +use std::hash::{BuildHasher, Hash}; +use std::mem::MaybeUninit; +use std::num::NonZeroUsize; +use std::ptr::NonNull; + +#[derive(Debug)] +pub struct VecEncoder { + // pub(crate) for arrayvec.rs + pub(crate) lengths: LengthEncoder, + pub(crate) elements: T::Encoder, + vectored_impl: Option>, +} + +// Can't derive since it would bound T: Default. 
+impl Default for VecEncoder { + fn default() -> Self { + Self { + lengths: Default::default(), + elements: Default::default(), + vectored_impl: Default::default(), + } + } +} + +impl Buffer for VecEncoder { + fn collect_into(&mut self, out: &mut Vec) { + self.lengths.collect_into(out); + self.elements.collect_into(out); + } + + fn reserve(&mut self, additional: NonZeroUsize) { + self.lengths.reserve(additional); + // We don't know the lengths of the vectors, so we can't reserve more. + } +} + +/// Copies `N` or `n` bytes from `src` to `dst` depending on if `src` lies within a memory page. +/// https://stackoverflow.com/questions/37800739/is-it-safe-to-read-past-the-end-of-a-buffer-within-the-same-page-on-x86-and-x64 +/// Safety: Same as [`copy_nonoverlapping_unaligned`] but with the additional requirements that +/// `n != 0 && n <= N` and `dst` has room for a `[T; N]`. +/// Is a macro instead of an `#[inline(always)] fn` because it optimizes better. +macro_rules! unsafe_wild_copy { + // pub unsafe fn wild_copy(src: *const T, dst: *mut T, n: usize) { + ([$T:ident; $N:ident], $src:ident, $dst:ident, $n:ident) => { + debug_assert!($n != 0 && $n <= $N); + + let page_size = 4096; + let read_size = std::mem::size_of::<[$T; $N]>(); + let within_page = $src as usize & (page_size - 1) < (page_size - read_size) && cfg!(all( + // Miri doesn't like this. + not(miri), + // cargo fuzz's memory sanitizer complains about buffer overrun. + // Without nightly we can't detect memory sanitizers, so we check debug_assertions. + not(debug_assertions), + // x86/x86_64/aarch64 all have min page size of 4096, so reading past the end of a non-empty + // buffer won't page fault. 
+ any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64") + )); + + if within_page { + std::ptr::write_unaligned($dst as *mut std::mem::MaybeUninit<[$T; $N]>, + std::ptr::read_unaligned($src as *const std::mem::MaybeUninit<[$T; $N]>) + ); + } else { + #[cold] + unsafe fn cold(src: *const T, dst: *mut T, n: usize) { + crate::derive::vec::copy_nonoverlapping_unaligned(src, dst, n); + } + cold($src, $dst, $n); + } + } +} +pub(crate) use unsafe_wild_copy; + +/// Equivalent to `std::ptr::copy_nonoverlapping` but neither `src` nor `dst` has to be aligned. +/// Safety: Same as [`std::ptr::copy_nonoverlapping`], but without any alignment requirements. +#[inline(always)] +pub unsafe fn copy_nonoverlapping_unaligned(src: *const T, dst: *mut T, n: usize) { + std::ptr::copy_nonoverlapping( + src as *const u8, + dst as *mut u8, + n * std::mem::size_of::(), + ); +} + +impl VecEncoder { + /// Copy fixed size slices. Much faster than memcpy. + #[inline(never)] + fn encode_vectored_max_len<'a, I: Iterator + Clone, const N: usize>( + &mut self, + i: I, + ) where + T: 'a, + { + unsafe { + let primitives = self.elements.as_primitive().unwrap(); + primitives.reserve(i.size_hint().1.unwrap() * N); + + let mut dst = primitives.end_ptr(); + if self.lengths.encode_vectored_max_len::<_, N>( + i.clone(), + #[inline(always)] + |s| { + let src = s.as_ptr(); + let n = s.len(); + // Safety: encode_vectored_max_len skips len == 0 and ensures len <= N. + // `dst` has enough space for `[T; N]` because we've reserved size_hint * N. + unsafe_wild_copy!([T; N], src, dst, n); + dst = dst.add(n); + }, + ) { + // Use fallback for impls that copy more than 64 bytes. 
+ let size = std::mem::size_of::(); + self.vectored_impl = NonNull::new(match N { + 1 if size <= 32 => Self::encode_vectored_max_len::, + 2 if size <= 16 => Self::encode_vectored_max_len::, + 4 if size <= 8 => Self::encode_vectored_max_len::, + 8 if size <= 4 => Self::encode_vectored_max_len::, + 16 if size <= 2 => Self::encode_vectored_max_len::, + 32 if size <= 1 => Self::encode_vectored_max_len::, + _ => Self::encode_vectored_fallback::, + } as *mut ()); + let f: fn(&mut Self, i: I) = std::mem::transmute(self.vectored_impl); + f(self, i); + return; + } + primitives.set_end_ptr(dst); + } + } + + /// Fallback for when length > [`Self::encode_vectored_max_len`]'s max_len. + #[inline(never)] + fn encode_vectored_fallback<'a, I: Iterator>(&mut self, i: I) + where + T: 'a, + { + let primitives = self.elements.as_primitive().unwrap(); + self.lengths.encode_vectored_fallback(i, |s| unsafe { + let n = s.len(); + primitives.reserve(n); + let ptr = primitives.end_ptr(); + copy_nonoverlapping_unaligned(s.as_ptr(), ptr, n); + primitives.set_end_ptr(ptr.add(n)); + }); + } +} + +impl Encoder<[T]> for VecEncoder { + #[inline(always)] + fn encode(&mut self, v: &[T]) { + let n = v.len(); + self.lengths.encode(&n); + + if let Some(primitive) = self.elements.as_primitive() { + primitive.reserve(n); + unsafe { + let ptr = primitive.end_ptr(); + copy_nonoverlapping_unaligned(v.as_ptr(), ptr, n); + primitive.set_end_ptr(ptr.add(n)); + } + } else if let Some(n) = NonZeroUsize::new(n) { + self.elements.reserve(n); + // Uses chunks to keep everything in the CPU cache. TODO pick optimal chunk size. + for chunk in v.chunks(MAX_VECTORED_CHUNK) { + self.elements.encode_vectored(chunk.iter()); + } + } + } + + #[inline(always)] + fn encode_vectored<'a>(&mut self, i: impl Iterator + Clone) + where + [T]: 'a, + { + if self.elements.as_primitive().is_some() { + /// Convert impl trait to named generic type. 
+ #[inline(always)] + fn inner<'a, T: Encode + 'a, I: Iterator + Clone>( + me: &mut VecEncoder, + i: I, + ) { + // We can't set this in the Default constructor because we don't have the type I. + if me.vectored_impl.is_none() { + // Use match to avoid "use of generic parameter from outer function". + // Start at the pointer size (assumed to be 8 bytes) to not be wasteful. + me.vectored_impl = NonNull::new(match (8 / std::mem::size_of::()).max(1) { + 1 => VecEncoder::encode_vectored_max_len::, + 2 => VecEncoder::encode_vectored_max_len::, + 4 => VecEncoder::encode_vectored_max_len::, + 8 => VecEncoder::encode_vectored_max_len::, + _ => unreachable!(), + } as *mut ()); + } + let f: fn(&mut VecEncoder, i: I) = + unsafe { std::mem::transmute(me.vectored_impl) }; + f(me, i); + } + inner(self, i); + } else { + for v in i { + self.encode(v); + } + } + } +} + +#[derive(Debug)] +pub struct VecDecoder<'a, T: Decode<'a>> { + // pub(crate) for arrayvec::ArrayVec. + pub(crate) lengths: LengthDecoder<'a>, + pub(crate) elements: T::Decoder, +} + +// Can't derive since it would bound T: Default. +impl<'a, T: Decode<'a>> Default for VecDecoder<'a, T> { + fn default() -> Self { + Self { + lengths: Default::default(), + elements: Default::default(), + } + } +} + +impl<'a, T: Decode<'a>> View<'a> for VecDecoder<'a, T> { + fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> { + self.lengths.populate(input, length)?; + self.elements.populate(input, self.lengths.length()) + } +} + +macro_rules! encode_body { + ($t:ty) => { + #[inline(always)] + fn encode(&mut self, v: &$t) { + let n = v.len(); + self.lengths.encode(&n); + if let Some(n) = NonZeroUsize::new(n) { + self.elements.reserve(n); + for v in v { + self.elements.encode(v); + } + } + } + }; +} +// Faster on some collections. +macro_rules! 
encode_body_internal_iteration { + ($t:ty) => { + #[inline(always)] + fn encode(&mut self, v: &$t) { + let n = v.len(); + self.lengths.encode(&n); + if let Some(n) = NonZeroUsize::new(n) { + self.elements.reserve(n); + v.iter().for_each(|v| self.elements.encode(v)); + } + } + }; +} +macro_rules! decode_body { + ($t:ty) => { + #[inline(always)] + fn decode(&mut self) -> $t { + // - BTreeSet::from_iter is faster than BTreeSet::insert (see comment in map.rs). + // - HashSet is about the same either way. + // - Vec::from_iter is slower (so it doesn't use this). + (0..self.lengths.decode()) + .map(|_| self.elements.decode()) + .collect() + } + }; +} + +impl Encoder> for VecEncoder { + #[inline(always)] + fn encode(&mut self, v: &Vec) { + self.encode(v.as_slice()) + } + + #[inline(always)] + fn encode_vectored<'a>(&mut self, i: impl Iterator> + Clone) + where + Vec: 'a, + { + self.encode_vectored(i.map(Vec::as_slice)); + } +} +impl<'a, T: Decode<'a>> Decoder<'a, Vec> for VecDecoder<'a, T> { + crate::coder::decode_from_in_place!(Vec); + + #[inline(always)] + fn decode_in_place(&mut self, out: &mut MaybeUninit>) { + let length = self.lengths.decode(); + // Fast path, avoid memcpy and mutating len. + if length == 0 { + out.write(Vec::new()); + return; + } + + let v = out.write(Vec::with_capacity(length)); + if let Some(primitive) = self.elements.as_primitive_ptr() { + unsafe { + copy_nonoverlapping_unaligned(primitive as *const T, v.as_mut_ptr(), length); + self.elements.as_primitive_advance(length); + } + } else { + let spare = v.spare_capacity_mut(); + for i in 0..length { + let out = unsafe { spare.get_unchecked_mut(i) }; + self.elements.decode_in_place(out); + } + } + unsafe { v.set_len(length) }; + } +} + +impl Encoder> for VecEncoder { + encode_body!(BinaryHeap); // When BinaryHeap::as_slice is stable use [T] impl. 
+} +impl<'a, T: Decode<'a> + Ord> Decoder<'a, BinaryHeap> for VecDecoder<'a, T> { + #[inline(always)] + fn decode(&mut self) -> BinaryHeap { + let v: Vec = self.decode(); + v.into() + } +} + +impl Encoder> for VecEncoder { + encode_body!(BTreeSet); +} +impl<'a, T: Decode<'a> + Ord> Decoder<'a, BTreeSet> for VecDecoder<'a, T> { + decode_body!(BTreeSet); +} + +impl Encoder> for VecEncoder { + // Internal iteration is 1.6x faster. Interestingly this does not apply to HashMap which + // I assume is due to HashSet::iter being implemented with HashMap::keys. + encode_body_internal_iteration!(HashSet); +} +impl<'a, T: Decode<'a> + Eq + Hash, S: BuildHasher + Default> Decoder<'a, HashSet> + for VecDecoder<'a, T> +{ + decode_body!(HashSet); +} + +impl Encoder> for VecEncoder { + encode_body!(LinkedList); +} +impl<'a, T: Decode<'a>> Decoder<'a, LinkedList> for VecDecoder<'a, T> { + decode_body!(LinkedList); +} + +impl Encoder> for VecEncoder { + encode_body_internal_iteration!(VecDeque); // Internal iteration is 10x faster. +} +impl<'a, T: Decode<'a>> Decoder<'a, VecDeque> for VecDecoder<'a, T> { + #[inline(always)] + fn decode(&mut self) -> VecDeque { + let v: Vec = self.decode(); + v.into() + } +} + +#[cfg(test)] +mod test { + use std::collections::*; + fn bench_data>() -> T { + (0..=255).collect() + } + crate::bench_encode_decode!( + btree_set: BTreeSet<_>, + hash_set: HashSet<_>, + linked_list: LinkedList<_>, + vec: Vec<_>, + vec_deque: VecDeque<_> + ); + + // BinaryHeap can't use bench_encode_decode because it doesn't implement PartialEq. 
+ #[bench] + fn bench_binary_heap_decode(b: &mut test::Bencher) { + type T = BinaryHeap; + let data: T = bench_data(); + let encoded = crate::encode(&data); + b.iter(|| { + let decoded: T = crate::decode::(&encoded).unwrap(); + debug_assert!(data.iter().eq(decoded.iter())); + decoded + }) + } +} diff --git a/src/encoding/bit_string/ascii.rs b/src/encoding/bit_string/ascii.rs deleted file mode 100644 index 67b5f4a..0000000 --- a/src/encoding/bit_string/ascii.rs +++ /dev/null @@ -1,29 +0,0 @@ -use crate::encoding::bit_string::bit_utils::{pack_lsb, unpack_lsb}; -use crate::encoding::bit_string::ByteEncoding; -use crate::encoding::prelude::*; - -#[derive(Copy, Clone)] -pub struct Ascii; - -impl Ascii { - const MASK: Word = Word::from_le_bytes([0x7F; 8]); -} - -impl ByteEncoding for Ascii { - const BITS_PER_BYTE: usize = 7; - - #[inline(always)] - fn validate(word: Word, _: usize) -> bool { - word & !Self::MASK == 0 - } - - #[inline(always)] - fn pack(word: Word) -> Word { - pack_lsb::<{ Self::BITS_PER_BYTE }>(word) - } - - #[inline(always)] - fn unpack(word: Word) -> Word { - unpack_lsb::<{ Self::BITS_PER_BYTE }>(word) - } -} diff --git a/src/encoding/bit_string/ascii_lowercase.rs b/src/encoding/bit_string/ascii_lowercase.rs deleted file mode 100644 index 9cc6c02..0000000 --- a/src/encoding/bit_string/ascii_lowercase.rs +++ /dev/null @@ -1,31 +0,0 @@ -use crate::encoding::bit_string::bit_utils::{pack_lsb, unpack_lsb}; -use crate::encoding::bit_string::ByteEncoding; -use crate::encoding::prelude::*; - -#[derive(Copy, Clone)] -pub struct AsciiLowercase; - -impl AsciiLowercase { - const DATA_MASK: Word = Word::from_le_bytes([0b00011111; 8]); - const SET_MASK: Word = Word::from_le_bytes([0b01100000; 8]); -} - -impl ByteEncoding for AsciiLowercase { - const BITS_PER_BYTE: usize = 5; - - #[inline(always)] - fn validate(word: Word, bytes: usize) -> bool { - let extra_bits = WORD_BITS - (bytes * u8::BITS as usize); - word & !Self::DATA_MASK == ((Self::SET_MASK << extra_bits) 
>> extra_bits) - } - - #[inline(always)] - fn pack(word: Word) -> Word { - pack_lsb::<{ Self::BITS_PER_BYTE }>(word) - } - - #[inline(always)] - fn unpack(word: Word) -> Word { - unpack_lsb::<{ Self::BITS_PER_BYTE }>(word) | Self::SET_MASK - } -} diff --git a/src/encoding/bit_string/bit_utils.rs b/src/encoding/bit_string/bit_utils.rs deleted file mode 100644 index 7241287..0000000 --- a/src/encoding/bit_string/bit_utils.rs +++ /dev/null @@ -1,35 +0,0 @@ -use crate::word::Word; - -#[inline(always)] -pub fn pack_lsb(word: Word) -> Word { - let mask = Word::from_le_bytes([(1 << BITS) - 1; 8]); - - // TODO: use pext (see https://github.com/SoftbearStudios/bitcode/issues/17) - - // Mask off bits that we don't care about. - let bytes = (word & mask).to_le_bytes(); - let mut ret1 = 0; - for (i, &b) in bytes[..4].iter().enumerate() { - ret1 |= (b as u32) << (i * BITS); - } - let mut ret2 = 0; - for (i, &b) in bytes[4..].iter().enumerate() { - ret2 |= (b as u32) << (i * BITS); - } - - // 2 steps + merge is a tiny bit faster. - ret1 as u64 | (ret2 as u64) << (BITS * 4) -} - -#[inline(always)] -pub fn unpack_lsb(word: Word) -> Word { - // TODO: use pdep (see https://github.com/SoftbearStudios/bitcode/issues/17) - - let mut bytes = [0u8; 8]; - - for (i, b) in bytes.iter_mut().enumerate() { - *b = (word >> (i * BITS) & ((1 << BITS) - 1)) as u8; - } - - Word::from_le_bytes(bytes) -} diff --git a/src/encoding/bit_string/mod.rs b/src/encoding/bit_string/mod.rs deleted file mode 100644 index 14b5ee8..0000000 --- a/src/encoding/bit_string/mod.rs +++ /dev/null @@ -1,191 +0,0 @@ -use crate::encoding::prelude::*; -use crate::encoding::{Fixed, Gamma}; -use crate::Encode; -use std::num::NonZeroUsize; - -mod ascii; -mod ascii_lowercase; -mod bit_utils; - -pub use ascii::Ascii; -pub use ascii_lowercase::AsciiLowercase; - -/// Encodes strings with character sizes other than 8 bits (e.g. Ascii). 
-#[derive(Copy, Clone)] -pub struct BitString(pub C); - -impl Encoding for BitString { - #[inline(always)] - fn write_byte_str(self, writer: &mut impl Write, bytes: &[u8]) { - let n = bytes.len(); - n.encode(Gamma, writer).unwrap(); - if n == 0 { - return; - } - - let revert = writer.get_revert(); - writer.write_false(); - let is_valid = writer.write_encoded_bytes::(bytes); - - if !is_valid { - #[cold] - fn cold(writer: &mut W, v: &[u8], revert: W::Revert) { - writer.revert(revert); - writer.write_bit(true); - writer.write_bytes(v); - } - cold(writer, bytes, revert); - } - } - - #[inline(always)] - fn read_bytes(self, reader: &mut impl Read, len: NonZeroUsize) -> Result<&[u8]> { - let is_valid = !reader.read_bit()?; - if is_valid { - reader.read_encoded_bytes::(len) - } else { - #[cold] - fn cold(reader: &mut impl Read, len: NonZeroUsize) -> Result<&[u8]> { - reader.read_bytes(len) - } - cold(reader, len) - } - } -} - -/// A `u8` encoding for [`BitString`]. Each `u8` is encoded with a fixed number of bits -/// (e.g. Ascii = 7 bits). -pub trait ByteEncoding: Copy { - const BITS_PER_BYTE: usize; - - /// Returns if the `word` of up to 8 characters valid. Only `bytes` bytes are included - /// (the remaining are zeroed). `bytes` must be at least 1. - fn validate(word: Word, bytes: usize) -> bool; - - /// Packs 8 bytes to 8 * [`Self::BITS_PER_BYTE`] bits. The returned extra bits must - /// be zeroed. - fn pack(word: Word) -> Word; - - /// Unpacks 8 * [`Self::BITS_PER_BYTE`] bits to 8 bytes. The inputted extra bits are - /// undefined. The returned extra bytes are undefined. - fn unpack(word: Word) -> Word; -} - -// For benchmarking overhead of BitString. 
DO NOT USE -impl ByteEncoding for Fixed { - const BITS_PER_BYTE: usize = 8; - - #[inline(always)] - fn validate(_: Word, _: usize) -> bool { - true - } - - #[inline(always)] - fn pack(word: Word) -> Word { - word - } - - #[inline(always)] - fn unpack(word: Word) -> Word { - word - } -} - -#[cfg(all(test, debug_assertions, not(miri)))] -mod tests { - use super::*; - use crate::encoding::prelude::test_prelude::*; - use crate::encoding::BitString; - - #[test] - fn test() { - fn t(value: V) { - test_encoding(BitString(Ascii), value.clone()); - test_encoding(BitString(AsciiLowercase), value.clone()); - test_encoding(BitString(Fixed), value.clone()); - test_encoding(Fixed, value); - } - - for i in 0..u8::MAX { - t(i.to_string()); - } - - t("abcd123".repeat(10)); - t("hello".to_string()); - t("☺".to_string()); - - #[derive(Encode, Copy, Clone)] - struct AsciiString(#[bitcode_hint(ascii)] &'static str); - #[derive(Encode, Copy, Clone)] - struct AsciiLowercaseString(#[bitcode_hint(ascii_lowercase)] &'static str); - - let is_valid_bit = 1; - - // Is ascii (ascii is 2 bits shorter, ascii_lowercase is 8 bits shorter). - let s = "foo"; - let len_bits = 5; - assert_eq!( - crate::encode(&[s; 8]).unwrap().len(), - len_bits + s.len() * Fixed::BITS_PER_BYTE - ); - assert_eq!( - crate::encode(&[AsciiString(s); 8]).unwrap().len(), - len_bits + is_valid_bit + s.len() * Ascii::BITS_PER_BYTE - ); - assert_eq!( - crate::encode(&[AsciiLowercaseString(s); 8]).unwrap().len(), - len_bits + is_valid_bit + s.len() * AsciiLowercase::BITS_PER_BYTE - ); - - // Isn't ascii (both 1 bit longer output). 
- let s = "☺☺☺"; - let len_bits = 7; - assert_eq!( - crate::encode(&[s; 8]).unwrap().len(), - len_bits + s.len() * Fixed::BITS_PER_BYTE - ); - assert_eq!( - crate::encode(&[AsciiString(s); 8]).unwrap().len(), - len_bits + is_valid_bit + s.len() * Fixed::BITS_PER_BYTE - ); - assert_eq!( - crate::encode(&[AsciiLowercaseString(s); 8]).unwrap().len(), - len_bits + is_valid_bit + s.len() * Fixed::BITS_PER_BYTE - ); - } -} - -#[cfg(all(test, not(miri)))] -mod benches { - use super::*; - use crate::encoding::bench_prelude::*; - - fn string_dataset() -> Vec { - // return vec!["a".repeat(10000)]; - let max_size = 16; - dataset::() - .into_iter() - .map(|n| "e".repeat(n as usize % (max_size + 1))) - .collect() - } - - mod ascii { - use super::*; - bench_encoding!(crate::encoding::BitString(Ascii), string_dataset); - } - - mod ascii_lowercase { - use super::*; - bench_encoding!(crate::encoding::BitString(AsciiLowercase), string_dataset); - } - - mod fixed { - use super::*; - bench_encoding!(crate::encoding::BitString(Fixed), string_dataset); - } - - mod fixed_string { - use super::*; - bench_encoding!(crate::encoding::Fixed, string_dataset); - } -} diff --git a/src/encoding/expect_normalized_float.rs b/src/encoding/expect_normalized_float.rs deleted file mode 100644 index f024241..0000000 --- a/src/encoding/expect_normalized_float.rs +++ /dev/null @@ -1,168 +0,0 @@ -use crate::code::{Decode, Encode}; -use crate::encoding::prelude::*; -use crate::encoding::Fixed; - -#[derive(Copy, Clone)] -pub struct ExpectNormalizedFloat; - -// Cannot currently be more than 12 because that would make f64 > 64 bits (requiring multiple reads/writes). -const MAX_EXP_ZEROS: usize = 12; - -macro_rules! 
impl_float { - ($write:ident, $read:ident, $t:ty, $i: ty, $mantissa:literal, $exp_bias: literal) => { - #[inline(always)] - fn $write(self, writer: &mut impl Write, v: $t) { - let mantissa_bits = $mantissa as usize; - let exp_bias = $exp_bias as u32; - let sign_bit = 1 << (<$i>::BITS - 1); - - let bits = v.to_bits(); - let sign = bits & sign_bit; - let bits_without_sign = bits & !sign_bit; - let exp = (bits_without_sign >> mantissa_bits) as u32; - let exp_zeros = (exp_bias - 1).wrapping_sub(exp) as usize; - - if (sign | exp_zeros as $i) < MAX_EXP_ZEROS as $i { - let mantissa = bits as $i & !(<$i>::MAX << mantissa_bits); - let v = (((mantissa as u64) << 1) | 1) << exp_zeros; - writer.write_bits(v, mantissa_bits + exp_zeros + 1); - } else { - #[cold] - fn cold(writer: &mut impl Write, v: $t) { - writer.write_zeros(MAX_EXP_ZEROS); - v.encode(Fixed, writer).unwrap() - } - cold(writer, v); - } - } - - #[inline(always)] - fn $read(self, reader: &mut impl Read) -> Result<$t> { - let mantissa_bits = $mantissa as usize; - let exp_bias = $exp_bias as u32; - - let v = reader.peek_bits()?; - let exp_zeros = v.trailing_zeros() as usize; - - if exp_zeros < MAX_EXP_ZEROS { - let exp_bits = exp_zeros + 1; - reader.advance(mantissa_bits + exp_bits); - - let mantissa = (v >> exp_bits) as $i & !(<$i>::MAX << mantissa_bits); - let exp = (exp_bias - 1) - exp_zeros as u32; - Ok(<$t>::from_bits(exp as $i << mantissa_bits | mantissa)) - } else { - #[cold] - fn cold(reader: &mut impl Read) -> Result<$t> { - reader.advance(MAX_EXP_ZEROS); - <$t>::decode(Fixed, reader) - } - cold(reader) - } - } - } -} - -impl Encoding for ExpectNormalizedFloat { - impl_float!(write_f32, read_f32, f32, u32, 23, 127); - impl_float!(write_f64, read_f64, f64, u64, 52, 1023); -} - -#[cfg(all(test, not(miri)))] -mod benches { - mod f32 { - use crate::encoding::bench_prelude::*; - bench_encoding!(crate::encoding::ExpectNormalizedFloat, dataset::); - } - - mod f64 { - use crate::encoding::bench_prelude::*; - 
bench_encoding!(crate::encoding::ExpectNormalizedFloat, dataset::); - } -} - -#[cfg(all(test, debug_assertions, not(miri)))] -mod tests { - macro_rules! impl_test { - ($t:ty, $i:ty) => { - use crate::encoding::expect_normalized_float::*; - use crate::encoding::prelude::test_prelude::*; - use rand::{Rng, SeedableRng}; - - fn t(value: $t) { - #[derive(Copy, Clone, Debug, Encode, Decode)] - struct ExactBits(#[bitcode_hint(expected_range = "0.0..1.0")] $t); - - impl PartialEq for ExactBits { - fn eq(&self, other: &Self) -> bool { - self.0.to_bits() == other.0.to_bits() - } - } - test_encoding(ExpectNormalizedFloat, ExactBits(value)); - } - - #[test] - fn test_random() { - let mut rng = rand_chacha::ChaCha20Rng::from_seed(Default::default()); - for _ in 0..100000 { - let f = <$t>::from_bits(rng.gen::<$i>()); - t(f) - } - } - - #[test] - fn test2() { - t(0.0); - t(0.5); - t(1.0); - t(-1.0); - t(<$t>::INFINITY); - t(<$t>::NEG_INFINITY); - t(<$t>::NAN); - t(0.0000000000001); - - fn normalized_floats(n: usize) -> impl Iterator { - let scale = 1.0 / n as $t; - (0..n).map(move |i| i as $t * scale) - } - - fn normalized_float_bits(n: usize) -> $t { - use crate::buffer::BufferTrait; - use crate::word_buffer::WordBuffer; - - let mut buffer = WordBuffer::default(); - let mut writer = buffer.start_write(); - for v in normalized_floats(n) { - v.encode(ExpectNormalizedFloat, &mut writer).unwrap(); - } - let bytes = buffer.finish_write(writer).to_vec(); - - let (mut reader, context) = buffer.start_read(&bytes); - for v in normalized_floats(n) { - let decoded = <$t>::decode(ExpectNormalizedFloat, &mut reader).unwrap(); - assert_eq!(decoded, v); - } - WordBuffer::finish_read(reader, context).unwrap(); - - (bytes.len() * u8::BITS as usize) as $t / n as $t - } - - if <$i>::BITS == 32 { - assert!((25.0..25.5).contains(&normalized_float_bits(1 << 12))); - // panic!("bits {}", normalized_float_bits(6000000)); // bits 25.013674 - } else { - 
assert!((54.0..54.5).contains(&normalized_float_bits(1 << 12))); - // panic!("bits {}", normalized_float_bits(6000000)); // bits 54.019532 - } - } - }; - } - - mod f32 { - impl_test!(f32, u32); - } - - mod f64 { - impl_test!(f64, u64); - } -} diff --git a/src/encoding/expected_range_u64.rs b/src/encoding/expected_range_u64.rs deleted file mode 100644 index e257cf6..0000000 --- a/src/encoding/expected_range_u64.rs +++ /dev/null @@ -1,199 +0,0 @@ -use crate::encoding::prelude::*; - -#[derive(Copy, Clone)] -pub struct ExpectedRangeU64; - -impl ExpectedRangeU64 { - const RANGE: u64 = MAX - MIN; - const _A: () = assert!(Self::RANGE < u64::MAX / 2); - - const fn range_bits(self) -> usize { - ilog2_u64(Self::RANGE.next_power_of_two()) as usize - } - - const fn invalid_bit_pattern(self) -> Option { - if Self::RANGE.is_power_of_two() { - None - } else { - Some(Self::RANGE) - } - } - - const fn has_header_bit(self) -> bool { - self.invalid_bit_pattern().is_none() - } - - const fn total_bits(self) -> usize { - self.range_bits() + self.has_header_bit() as usize - } - - const fn is_pointless(self, bits: usize) -> bool { - bits <= self.total_bits() - } -} - -impl Encoding for ExpectedRangeU64 { - #[inline(always)] - fn write_u64(self, writer: &mut impl Write, word: Word) { - // Don't use use this encoding if it's pointless. - if self.is_pointless(BITS) { - writer.write_bits(word, BITS); - return; - } - - // TODO could extend min and max. 
- if (MIN..MAX).contains(&word) { - let value = word - MIN; - let header_bit = self.has_header_bit() as u64; - let value_with_header = (value << header_bit) | header_bit; - writer.write_bits(value_with_header, self.total_bits()); - } else { - #[cold] - fn cold( - me: ExpectedRangeU64, - word: Word, - bits: usize, - writer: &mut impl Write, - ) { - if let Some(invalid_bit_pattern) = me.invalid_bit_pattern() { - writer.write_bits(invalid_bit_pattern, me.range_bits()); - writer.write_bits(word, bits); - } else { - writer.write_false(); - writer.write_bits(word, bits); - } - } - cold(self, word, BITS, writer); - } - } - - #[inline(always)] - fn read_u64(self, reader: &mut impl Read) -> Result { - // Don't use use this encoding if it's pointless. - if self.is_pointless(BITS) { - return reader.read_bits(BITS); - } - - let raw_bits = reader.peek_bits()?; - let total_bits = self.total_bits(); - - let value_and_header = raw_bits & ((1 << total_bits) - 1); - if let Some(invalid_bit_pattern) = self.invalid_bit_pattern() { - if value_and_header != invalid_bit_pattern { - reader.advance(total_bits); - - let value = value_and_header; - let word = value + MIN; - if BITS < WORD_BITS && word >= (1 << BITS) { - Err(E::Invalid("expected range").e()) - } else { - Ok(word) - } - } else { - #[cold] - fn cold(reader: &mut impl Read, bits: usize, skip: usize) -> Result { - reader.advance(skip); - reader.read_bits(bits) - } - cold(reader, BITS, self.range_bits()) - } - } else if value_and_header & 1 != 0 { - reader.advance(total_bits); - - let value = value_and_header >> 1; - let word = value + MIN; - if BITS < WORD_BITS && word >= (1 << BITS) { - Err(E::Invalid("expected range").e()) - } else { - Ok(word) - } - } else { - #[cold] - fn cold(reader: &mut impl Read, bits: usize) -> Result { - reader.advance(1); - reader.read_bits(bits) - } - cold(reader, BITS) - } - } -} - -#[cfg(all(test, not(miri)))] -mod benches { - use crate::encoding::prelude::bench_prelude::*; - use rand::prelude::*; - 
- fn dataset() -> Vec { - let mut rng = rand_chacha::ChaCha20Rng::from_seed(Default::default()); - (0..1000).map(|_| rng.gen_range(0..100)).collect() - } - - bench_encoding!(super::ExpectedRangeU64::<0, 100>, dataset); -} - -#[cfg(all(test, debug_assertions, not(miri)))] -mod tests { - use super::*; - use crate::encoding::prelude::test_prelude::*; - - #[test] - fn test() { - fn t(value: V) { - let encoding: ExpectedRangeU64<0, 10> = ExpectedRangeU64; - test_encoding(encoding, value); - - let encoding: ExpectedRangeU64<0, 16> = ExpectedRangeU64; - test_encoding(encoding, value); - } - - for i in 0..u8::MAX { - t(i); - } - - t(u16::MAX); - t(u32::MAX); - t(u64::MAX); - - #[derive(Copy, Clone, Debug, PartialEq, Encode, Decode)] - struct IntLessThan1(#[bitcode_hint(expected_range = "0..1")] T); - - for i in 0..1u8 { - let bits_required = bitcode::encode(&[IntLessThan1(i); 8]).unwrap().len(); - // 1 bits are required. - assert_eq!(bits_required, 1); - } - - for i in 1..10u8 { - let bits_required = bitcode::encode(&[IntLessThan1(i); 8]).unwrap().len(); - assert_eq!(bits_required, 9); - } - - #[derive(Copy, Clone, Debug, PartialEq, Encode, Decode)] - struct IntLessThan10(#[bitcode_hint(expected_range = "0..10")] T); - - for i in 0..10u8 { - let bits_required = bitcode::encode(&[IntLessThan10(i); 8]).unwrap().len(); - // Only 4 bits are required since there are invalid bit patterns to use. - assert_eq!(bits_required, 4); - } - - for i in 10..20u8 { - let bits_required = bitcode::encode(&[IntLessThan10(i); 8]).unwrap().len(); - assert_eq!(bits_required, 12); - } - - #[derive(Copy, Clone, Debug, PartialEq, Encode, Decode)] - struct IntLessThan16(#[bitcode_hint(expected_range = "0..16")] T); - - for i in 0..16u8 { - let bits_required = bitcode::encode(&[IntLessThan16(i); 8]).unwrap().len(); - // 5 bits are required since there aren't invalid bit patterns to use. 
- assert_eq!(bits_required, 5); - } - - for i in 16..32u8 { - let bits_required = bitcode::encode(&[IntLessThan16(i); 8]).unwrap().len(); - assert_eq!(bits_required, 9); - } - } -} diff --git a/src/encoding/gamma.rs b/src/encoding/gamma.rs deleted file mode 100644 index 53ae339..0000000 --- a/src/encoding/gamma.rs +++ /dev/null @@ -1,149 +0,0 @@ -use super::prelude::*; -use crate::nightly::ilog2_non_zero_u64; -use std::num::NonZeroU64; - -#[derive(Copy, Clone)] -pub struct Gamma; -impl Encoding for Gamma { - fn zigzag(self) -> bool { - true - } - - #[inline(always)] - fn write_u64(self, writer: &mut impl Write, word: Word) { - debug_assert!(BITS <= WORD_BITS); - if BITS != WORD_BITS { - debug_assert_eq!(word, word & ((1 << BITS) - 1)); - } - - // https://en.wikipedia.org/wiki/Elias_gamma_coding - // Gamma can't encode 0 so add 1. - if let Some(nz) = NonZeroU64::new(word.wrapping_add(1)) { - let zero_bits = ilog2_non_zero_u64(nz) as usize; - writer.write_zeros(zero_bits); - - // Special case max value as BITS zeros. - if BITS != 64 && word == (u64::MAX >> (64 - BITS)) { - return; - } - - let integer_bits = zero_bits + 1; - - // Rotate bits mod `integer_bits` instead of reversing since it's faster. - // 00001bbb -> 0000bbb1 - let rotated = (nz.get() << 1 & !((1 << 1) << zero_bits)) | 1; - writer.write_bits(rotated, integer_bits); - } else { - // Special case u64::MAX as as 64 zeros (based on u64::MAX + 1 == 0 so we skip branch in ilog2). 
- writer.write_zeros(64); - } - } - - #[inline(always)] - fn read_u64(self, reader: &mut impl Read) -> Result { - debug_assert!((1..=WORD_BITS).contains(&BITS)); - - let peek = reader.peek_bits()?; - let zero_bits = peek.trailing_zeros() as usize; - - let fast = zero_bits < BITS.min(u32::BITS as usize); - if fast { - let integer_bits = zero_bits + 1; - let gamma_bits = zero_bits + integer_bits; - reader.advance(gamma_bits); - - let rotated = peek >> zero_bits & ((1 << integer_bits) - 1); - - // Rotate bits mod `integer_bits` instead of reversing since it's faster. - // 0000bbb1 -> 00001bbb - let v = (rotated >> 1) | (1 << (integer_bits - 1)); - - // Gamma can't encode 0 so sub 1. - let v = v - 1; - Ok(v) - } else { - // The representation is > 64 bits or it's the max value. - #[cold] - fn slow(reader: &mut impl Read) -> Result { - // True if the representation can't be > 64 bits so it's the max value. - let always_special_case = BITS < u32::BITS as usize; - if always_special_case { - reader.advance(BITS); - return Ok(u64::MAX >> (64 - BITS)); - } - - let zero_bits = (reader.peek_bits()?.trailing_zeros() as usize).min(BITS); - reader.advance(zero_bits); - - // Max value is special cased as as BITS zeros. 
- if zero_bits == BITS { - return Ok(u64::MAX >> (64 - BITS)); - } - - let integer_bits = zero_bits + 1; - let rotated = reader.read_bits(integer_bits)?; - - let v = (rotated >> 1) | (1 << (integer_bits - 1)); - Ok(v - 1) - } - slow::(reader) - } - } -} - -#[cfg(all(test, not(miri)))] -mod benches { - mod u8 { - use crate::encoding::bench_prelude::*; - bench_encoding!(crate::encoding::Gamma, dataset::); - } - - mod u16 { - use crate::encoding::bench_prelude::*; - bench_encoding!(crate::encoding::Gamma, dataset::); - } - - mod u32 { - use crate::encoding::bench_prelude::*; - bench_encoding!(crate::encoding::Gamma, dataset::); - } - - mod u64 { - use crate::encoding::bench_prelude::*; - bench_encoding!(crate::encoding::Gamma, dataset::); - } -} - -#[cfg(all(test, debug_assertions, not(miri)))] -mod tests { - use super::*; - use crate::encoding::prelude::test_prelude::*; - - #[test] - fn test() { - fn t(value: V) { - test_encoding(Gamma, value) - } - - for i in 0..u8::MAX { - t(i); - } - - t(u16::MAX); - t(u32::MAX); - t(u64::MAX); - - t(-1i8); - t(-1i16); - t(-1i32); - t(-1i64); - - #[derive(Debug, PartialEq, Encode, Decode)] - struct GammaInt(#[bitcode_hint(gamma)] T); - - for i in -7..=7i64 { - // Zig-zag means that low magnitude signed ints are under one byte. 
- assert_eq!(bitcode::encode(&GammaInt(i)).unwrap().len(), 1); - } - } -} diff --git a/src/encoding/mod.rs b/src/encoding/mod.rs deleted file mode 100644 index 517f06e..0000000 --- a/src/encoding/mod.rs +++ /dev/null @@ -1,124 +0,0 @@ -use crate::{Decode, Encode}; -use prelude::*; -use std::num::NonZeroUsize; - -#[cfg(all(feature = "simdutf8", not(miri)))] -use simdutf8::basic::from_utf8; -#[cfg(not(all(feature = "simdutf8", not(miri))))] -use std::str::from_utf8; - -mod bit_string; -mod expect_normalized_float; -mod expected_range_u64; -mod gamma; -mod prelude; - -pub use bit_string::*; -pub use expect_normalized_float::ExpectNormalizedFloat; -pub use expected_range_u64::ExpectedRangeU64; -pub use gamma::Gamma; - -pub trait Encoding: Copy { - fn is_fixed(self) -> bool { - false - } - - fn zigzag(self) -> bool { - false - } - - #[inline(always)] - fn write_u64(self, writer: &mut impl Write, v: u64) { - writer.write_bits(v, BITS); - } - - #[inline(always)] - fn read_u64(self, reader: &mut impl Read) -> Result { - reader.read_bits(BITS) - } - - // TODO add implementations to Gamma and ExpectedRange. 
- #[inline(always)] - fn write_u128(self, writer: &mut impl Write, v: u128) { - debug_assert!((65..=128).contains(&BITS)); - - let lo = v as u64; - let hi = (v >> 64) as u64; - writer.write_bits(lo, 64); - writer.write_bits(hi, BITS - 64); - } - - #[inline(always)] - fn read_u128(self, reader: &mut impl Read) -> Result { - debug_assert!((65..=128).contains(&BITS)); - - let lo = reader.read_bits(64)?; - let hi = reader.read_bits(BITS - 64)?; - Ok(lo as u128 | ((hi as u128) << 64)) - } - - #[inline(always)] - fn write_f32(self, writer: &mut impl Write, v: f32) { - v.to_bits().encode(Fixed, writer).unwrap() - } - - #[inline(always)] - fn read_f32(self, reader: &mut impl Read) -> Result { - Ok(f32::from_bits(Decode::decode(Fixed, reader)?)) - } - - #[inline(always)] - fn write_f64(self, writer: &mut impl Write, v: f64) { - v.to_bits().encode(Fixed, writer).unwrap() - } - - #[inline(always)] - fn read_f64(self, reader: &mut impl Read) -> Result { - Ok(f64::from_bits(Decode::decode(Fixed, reader)?)) - } - - #[inline(always)] - fn write_str(self, writer: &mut impl Write, v: &str) { - self.write_byte_str(writer, v.as_bytes()); - } - - #[inline(always)] - fn read_str(self, reader: &mut impl Read) -> Result<&str> { - let len = usize::decode(Gamma, reader)?; - if let Some(len) = NonZeroUsize::new(len) { - from_utf8(self.read_bytes(reader, len)?).map_err(|_| E::Invalid("utf8").e()) - } else { - Ok("") - } - } - - #[inline(always)] - fn write_byte_str(self, writer: &mut impl Write, v: &[u8]) { - v.len().encode(Gamma, writer).unwrap(); - writer.write_bytes(v); - } - - #[inline(always)] - fn read_byte_str(self, reader: &mut impl Read) -> Result<&[u8]> { - let len = usize::decode(Gamma, reader)?; - if let Some(len) = NonZeroUsize::new(len) { - self.read_bytes(reader, len) - } else { - Ok(&[]) - } - } - - #[inline(always)] - fn read_bytes(self, reader: &mut impl Read, len: NonZeroUsize) -> Result<&[u8]> { - reader.read_bytes(len) - } -} - -#[derive(Copy, Clone)] -pub struct Fixed; 
- -impl Encoding for Fixed { - fn is_fixed(self) -> bool { - true - } -} diff --git a/src/encoding/prelude.rs b/src/encoding/prelude.rs deleted file mode 100644 index 4af49aa..0000000 --- a/src/encoding/prelude.rs +++ /dev/null @@ -1,120 +0,0 @@ -pub use crate::encoding::Encoding; -pub use crate::nightly::ilog2_u64; -pub use crate::read::Read; -pub use crate::word::*; -pub use crate::write::Write; -pub(crate) use crate::{Result, E}; - -#[cfg(test)] -pub mod test_prelude { - pub use super::*; - pub use crate::{Decode, Encode}; - pub use std::fmt::Debug; - - #[cfg(all(test, debug_assertions))] - pub fn test_encoding_inner< - B: crate::buffer::BufferTrait, - V: Encode + Decode + Debug + PartialEq, - >( - encoding: impl Encoding, - value: &V, - ) { - let mut buffer = B::default(); - - let mut writer = buffer.start_write(); - value.encode(encoding, &mut writer).unwrap(); - let bytes = buffer.finish_write(writer).to_owned(); - - let (mut reader, context) = buffer.start_read(&bytes); - assert_eq!(&V::decode(encoding, &mut reader).unwrap(), value); - B::finish_read(reader, context).unwrap(); - } - - #[cfg(all(test, debug_assertions))] - pub fn test_encoding( - encoding: impl Encoding, - value: V, - ) { - test_encoding_inner::(encoding, &value); - test_encoding_inner::(encoding, &value); - } -} - -#[cfg(test)] -pub mod bench_prelude { - use super::test_prelude::*; - use crate::buffer::BufferTrait; - use crate::word_buffer::WordBuffer; - use rand::distributions::Standard; - use rand::prelude::*; - use test::black_box; - - pub use super::*; - pub use test::Bencher; - - pub fn dataset() -> Vec - where - Standard: Distribution, - { - let mut rng = rand_chacha::ChaCha20Rng::from_seed(Default::default()); - (0..1000).map(|_| rng.gen()).collect() - } - - #[macro_export] - macro_rules! 
bench_encoding { - ($encoding:expr, $dataset:path) => { - #[bench] - fn encode(b: &mut Bencher) { - bench_encode(b, $encoding, $dataset()); - } - - #[bench] - fn decode(b: &mut Bencher) { - bench_decode(b, $encoding, $dataset()); - } - }; - } - pub use bench_encoding; - - pub fn bench_encode(b: &mut Bencher, encoding: impl Encoding, data: Vec) { - let mut buf = WordBuffer::with_capacity(16000); - let starting_cap = buf.capacity(); - - b.iter(|| { - let buf = black_box(&mut buf); - let data = black_box(data.as_slice()); - - let mut writer = buf.start_write(); - for v in data { - v.encode(encoding, &mut writer).unwrap(); - } - buf.finish_write(writer); - }); - - assert_eq!(buf.capacity(), starting_cap); - } - - pub fn bench_decode( - b: &mut Bencher, - encoding: impl Encoding, - data: Vec, - ) { - let mut buf = WordBuffer::default(); - - let mut writer = buf.start_write(); - for v in &data { - v.encode(encoding, &mut writer).unwrap(); - } - let bytes = buf.finish_write(writer).to_owned(); - - b.iter(|| { - let buf = black_box(&mut buf); - - let (mut reader, _) = buf.start_read(black_box(bytes.as_slice())); - for v in &data { - let decoded = T::decode(encoding, &mut reader).unwrap(); - assert_eq!(&decoded, v); - } - }) - } -} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..cd71127 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,47 @@ +#[cfg(debug_assertions)] +use std::borrow::Cow; +use std::fmt::{Debug, Display, Formatter}; + +/// Short version of `Err(error("..."))`. +pub fn err(msg: &'static str) -> Result { + Err(error(msg)) +} + +/// Creates an error with a message that might be displayed. +pub fn error(_msg: &'static str) -> Error { + #[cfg(debug_assertions)] + return Error(Cow::Borrowed(_msg)); + #[cfg(not(debug_assertions))] + Error(()) +} + +/// Creates an error from a `T:` [`Display`]. 
+#[cfg(feature = "serde")] +pub fn error_from_display(_t: impl Display) -> Error { + #[cfg(debug_assertions)] + return Error(Cow::Owned(_t.to_string())); + #[cfg(not(debug_assertions))] + Error(()) +} + +#[cfg(debug_assertions)] +type ErrorImpl = Cow<'static, str>; +#[cfg(not(debug_assertions))] +type ErrorImpl = (); + +/// Decoding / (De)serialization errors. +/// # Debug mode +/// In debug mode, the error contains a reason. +/// # Release mode +/// In release mode, the error is a zero-sized type for efficiency. +#[derive(Debug)] +pub struct Error(ErrorImpl); +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + #[cfg(debug_assertions)] + return f.write_str(&self.0); + #[cfg(not(debug_assertions))] + f.write_str("bitcode error") + } +} +impl std::error::Error for Error {} diff --git a/src/ext/arrayvec.rs b/src/ext/arrayvec.rs new file mode 100644 index 0000000..aca2042 --- /dev/null +++ b/src/ext/arrayvec.rs @@ -0,0 +1,220 @@ +use crate::coder::{decode_from_in_place, Decoder, Encoder, Result, View}; +use crate::derive::vec::{unsafe_wild_copy, VecDecoder, VecEncoder}; +use crate::derive::{Decode, Encode}; +use crate::error::err; +use crate::str::{StrDecoder, StrEncoder}; +use arrayvec::{ArrayString, ArrayVec}; +use std::mem::MaybeUninit; + +// TODO optimize ArrayVec impls and make ArrayString use them. +impl Encoder> for StrEncoder { + #[inline(always)] + fn encode(&mut self, t: &ArrayString) { + // Only lengths < 255 are fast to encode and avoid copying lots of memory for 1 byte strings. + // TODO miri doesn't like ArrayString::as_str().as_ptr(), replace with ArrayString::as_ptr() when available. + if N > 64 || cfg!(miri) { + self.encode(t.as_str()); + return; + } + + let s = t.as_str(); + self.0.lengths.encode_less_than_255(s.len()); + let primitives = self.0.elements.as_primitive().unwrap(); + primitives.reserve(N); // TODO Buffer::reserve impl additional * N so we can remove encode_vectored impl. 
+ let dst = primitives.end_ptr(); + + // Safety: `s.as_ptr()` points to `N` valid bytes since it's referencing an ArrayString. + // `dst` has enough space for `[T; N]` because we've reserved `N`. + unsafe { + *(dst as *mut MaybeUninit<[u8; N]>) = *(s.as_ptr() as *const MaybeUninit<[u8; N]>); + primitives.set_end_ptr(dst.add(s.len())); + } + } + #[inline(never)] + fn encode_vectored<'a>(&mut self, i: impl Iterator> + Clone) { + // Only lengths < 255 are fast to encode and avoid copying lots of memory for 1 byte strings. + // TODO miri doesn't like ArrayString::as_str().as_ptr(), replace with ArrayString::as_ptr() when available. + if N > 64 || cfg!(miri) { + self.encode_vectored(i.map(|t| t.as_str())); + return; + } + + // This encode_vectored impl is same as encode impl, but pulls the reserve out of the loop. + let primitives = self.0.elements.as_primitive().unwrap(); + primitives.reserve(i.size_hint().1.unwrap() * N); + let mut dst = primitives.end_ptr(); + for t in i { + let s = t.as_str(); + self.0.lengths.encode_less_than_255(s.len()); + // Safety: `s.as_ptr()` points to `N` valid bytes since it's referencing an ArrayString. + // `dst` has enough space for `[T; N]` because we've reserved `size_hint * N`. + unsafe { + *(dst as *mut MaybeUninit<[u8; N]>) = *(s.as_ptr() as *const MaybeUninit<[u8; N]>); + dst = dst.add(s.len()); + } + } + primitives.set_end_ptr(dst); + } +} +impl Encode for ArrayString { + type Encoder = StrEncoder; +} + +// TODO replace with StrDecoder that optimizes calls to LengthDecoder::decode. +#[derive(Default)] +pub struct ArrayStringDecoder<'a, const N: usize>(StrDecoder<'a>); +impl<'a, const N: usize> View<'a> for ArrayStringDecoder<'a, N> { + fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> { + self.0.populate(input, length)?; + // Safety: `length` was same length passed to populate. 
+ if unsafe { self.0.lengths.any_greater_than::(length) } { + return err("invalid ArrayString"); + } + Ok(()) + } +} +impl<'a, const N: usize> Decoder<'a, ArrayString> for ArrayStringDecoder<'a, N> { + decode_from_in_place!(ArrayString); + #[inline(always)] + fn decode_in_place(&mut self, out: &mut MaybeUninit>) { + let s: &str = self.0.decode(); + let array_string = out.write(ArrayString::new()); + + // Avoid copying lots of memory for 1 byte strings. + // TODO miri doesn't like ArrayString::as_mut_str().as_mut_ptr(), replace with ArrayString::as_mut_str() when available. + if N > 64 || cfg!(miri) { + // Safety: We've ensured `self.lengths.max_len() <= N` in populate. + unsafe { array_string.try_push_str(s).unwrap_unchecked() }; + return; + } + // Empty s points to no valid bytes, so we can't unsafe_wild_copy. + if s.is_empty() { + return; + } + // Safety: We just checked n != 0 and ensured `self.lengths.max_len() <= N` in populate. + // Also, `dst` has room for `[u8; N]` since it's an ArrayString. + unsafe { + let src = s.as_ptr(); + let dst = array_string.as_mut_str().as_mut_ptr(); + let n = s.len(); + unsafe_wild_copy!([u8; N], src, dst, n); + array_string.set_len(s.len()); + } + } +} +impl<'a, const N: usize> Decode<'a> for ArrayString { + type Decoder = ArrayStringDecoder<'a, N>; +} + +// Helps optimize out some checks in `LengthEncoder::encode`. +#[inline(always)] +fn as_slice_assert_len(t: &ArrayVec) -> &[T] { + let s = t.as_slice(); + // Safety: ArrayVec has length <= N. TODO replace with LengthDecoder. 
+ if s.len() > N { + unsafe { std::hint::unreachable_unchecked() }; + } + s +} + +impl Encoder> for VecEncoder { + #[inline(always)] + fn encode(&mut self, t: &ArrayVec) { + self.encode(as_slice_assert_len(t)) + } + #[inline(always)] + fn encode_vectored<'a>(&mut self, i: impl Iterator> + Clone) + where + ArrayVec: 'a, + { + self.encode_vectored(i.map(as_slice_assert_len)); + } +} +impl Encode for ArrayVec { + type Encoder = VecEncoder; +} + +pub struct ArrayVecDecoder<'a, T: Decode<'a>, const N: usize>(VecDecoder<'a, T>); +// Can't derive since it would bound T: Default. +impl<'a, T: Decode<'a>, const N: usize> Default for ArrayVecDecoder<'a, T, N> { + fn default() -> Self { + Self(Default::default()) + } +} +impl<'a, T: Decode<'a>, const N: usize> View<'a> for ArrayVecDecoder<'a, T, N> { + fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> { + self.0.populate(input, length)?; + // Safety: `length` was same length passed to populate. + if unsafe { self.0.lengths.any_greater_than::(length) } { + return err("invalid ArrayVec"); + } + Ok(()) + } +} +impl<'a, T: Decode<'a>, const N: usize> Decoder<'a, ArrayVec> for ArrayVecDecoder<'a, T, N> { + decode_from_in_place!(ArrayVec); + #[inline(always)] + fn decode_in_place(&mut self, out: &mut MaybeUninit>) { + // Safety: We've ensured self.lengths.max_len() <= N in populate. + unsafe { + let av = out.write(ArrayVec::new()); + let n = self.0.lengths.decode(); + for i in 0..n { + self.0 + .elements + .decode_in_place(&mut *(av.as_mut_ptr().add(i) as *mut MaybeUninit)); + } + av.set_len(n); + } + } +} +impl<'a, T: Decode<'a>, const N: usize> Decode<'a> for ArrayVec { + type Decoder = ArrayVecDecoder<'a, T, N>; +} + +#[cfg(test)] +mod tests { + use crate::{decode, encode}; + use arrayvec::{ArrayString, ArrayVec}; + + // Smaller set of tests for ArrayString than ArrayVec they share VecEncoder/LengthDecoder. 
+ #[test] + fn array_string() { + let mut v = ArrayString::<2>::default(); + v.push('0'); + v.push('1'); + let b = encode(&v); + assert!(decode::>(&b).is_err()); + assert_eq!(decode::>(&b).unwrap(), v); + assert_eq!(decode::>(&b).unwrap().as_str(), v.as_str()); + assert!(decode::>(&encode(&ArrayString::<0>::default())).is_ok()); + } + + #[test] + fn array_vec() { + let mut v = ArrayVec::::default(); + v.push(0); + v.push(1); + let b = encode(&v); + assert!(decode::>(&b).is_err()); + assert_eq!(decode::>(&b).unwrap(), v); + assert_eq!( + decode::>(&b).unwrap().as_slice(), + v.as_slice() + ); + assert_eq!( + decode::>(&b).unwrap().as_slice(), + v.as_slice() + ); + assert!(decode::>(&encode(&ArrayVec::::default())).is_ok()); + + // Make sure LengthDecoder::any_greater_than works on large lengths too. + let mut v = ArrayVec::::default(); + for i in 0..500 { + v.push(i as u8); + } + let b = encode(&v); + assert!(decode::>(&b).is_err()); + assert_eq!(decode::>(&b).unwrap(), v); + } +} diff --git a/src/ext/glam.rs b/src/ext/glam.rs new file mode 100644 index 0000000..b70fdb0 --- /dev/null +++ b/src/ext/glam.rs @@ -0,0 +1,42 @@ +use super::impl_struct; +use glam::*; + +trait Affine3AExt { + fn from_mat3a_translation(matrix3: Mat3A, translation: Vec3A) -> Self; +} +impl Affine3AExt for Affine3A { + fn from_mat3a_translation(matrix3: Mat3A, translation: Vec3A) -> Self { + Self { matrix3, translation } + } +} +impl_struct!(Affine2, from_mat2_translation, matrix2, Mat2, translation, Vec2); +impl_struct!(DAffine2, from_mat2_translation, matrix2, DMat2, translation, DVec2); +impl_struct!(Affine3A, from_mat3a_translation, matrix3, Mat3A, translation, Vec3A); +impl_struct!(DAffine3, from_mat3_translation, matrix3, DMat3, translation, DVec3); + +macro_rules! impl_vec { + ($t:ident, $new:ident, $e:ty, $($f:ident),+) => { + impl_struct!($t, $new, $($f, $e),+); + } +} +impl_vec!(Vec3A, new, f32, x, y, z); +impl_vec!(Mat3A, from_cols, Vec3A, x_axis, y_axis, z_axis); + +macro_rules! 
impl_glam { + ($e:ty, $v2:ident, $v3:ident, $v4:ident $(, $q:ident, $m2:ident, $m3:ident, $m4:ident)?) => { + impl_vec!($v2, new, $e, x, y); + impl_vec!($v3, new, $e, x, y, z); + impl_vec!($v4, new, $e, x, y, z, w); + $( + impl_vec!($q, from_xyzw, $e, x, y, z, w); + impl_vec!($m2, from_cols, $v2, x_axis, y_axis); + impl_vec!($m3, from_cols, $v3, x_axis, y_axis, z_axis); + impl_vec!($m4, from_cols, $v4, x_axis, y_axis, z_axis, w_axis); + )? + } +} +impl_glam!(f32, Vec2, Vec3, Vec4, Quat, Mat2, Mat3, Mat4); +impl_glam!(f64, DVec2, DVec3, DVec4, DQuat, DMat2, DMat3, DMat4); +impl_glam!(u32, UVec2, UVec3, UVec4); +impl_glam!(i32, IVec2, IVec3, IVec4); +impl_glam!(bool, BVec2, BVec3, BVec4); diff --git a/src/ext/mod.rs b/src/ext/mod.rs new file mode 100644 index 0000000..d1e080c --- /dev/null +++ b/src/ext/mod.rs @@ -0,0 +1,69 @@ +#[cfg(feature = "arrayvec")] +mod arrayvec; +#[cfg(feature = "glam")] +#[rustfmt::skip] // Makes impl_struct! calls way longer. +mod glam; + +#[allow(unused)] +macro_rules! 
impl_struct { + ($t:ident, $new:ident, $($f:ident, $ft:ty),+) => { + const _: () = { + #[derive(Default)] + pub struct StructEncoder { + $( + $f: <$ft as crate::Encode>::Encoder, + )+ + } + impl crate::coder::Encoder<$t> for StructEncoder { + #[inline(always)] + fn encode(&mut self, t: &$t) { + $( + self.$f.encode(&t.$f); + )+ + } + } + impl crate::coder::Buffer for StructEncoder { + fn collect_into(&mut self, out: &mut Vec) { + $( + self.$f.collect_into(out); + )+ + } + + fn reserve(&mut self, additional: std::num::NonZeroUsize) { + $( + self.$f.reserve(additional); + )+ + } + } + impl crate::Encode for $t { + type Encoder = StructEncoder; + } + + #[derive(Default)] + pub struct StructDecoder<'a> { + $( + $f: <$ft as crate::Decode<'a>>::Decoder, + )+ + } + impl<'a> crate::coder::View<'a> for StructDecoder<'a> { + fn populate(&mut self, input: &mut &'a [u8], length: usize) -> crate::coder::Result<()> { + $( + self.$f.populate(input, length)?; + )+ + Ok(()) + } + } + impl<'a> crate::coder::Decoder<'a, $t> for StructDecoder<'a> { + #[inline(always)] + fn decode(&mut self) -> $t { + $t::$new($(self.$f.decode()),+) + } + } + impl<'a> crate::Decode<'a> for $t { + type Decoder = StructDecoder<'a>; + } + }; + } +} +#[allow(unused)] +pub(crate) use impl_struct; diff --git a/src/f32.rs b/src/f32.rs new file mode 100644 index 0000000..f056912 --- /dev/null +++ b/src/f32.rs @@ -0,0 +1,149 @@ +use crate::coder::{Buffer, Decoder, Encoder, Result, View}; +use crate::consume::consume_byte_arrays; +use crate::fast::{FastSlice, NextUnchecked, PushUnchecked, VecImpl}; +use std::mem::MaybeUninit; +use std::num::NonZeroUsize; + +#[derive(Debug, Default)] +pub struct F32Encoder(VecImpl); + +impl Encoder for F32Encoder { + #[inline(always)] + fn encode(&mut self, t: &f32) { + unsafe { self.0.push_unchecked(*t) }; + } +} + +/// [`bytemuck`] doesn't implement [`MaybeUninit`] casts. 
Slightly different from +/// [`bytemuck::cast_slice_mut`] in that it will truncate partial elements instead of panicking. +fn chunks_uninit(m: &mut [MaybeUninit]) -> &mut [MaybeUninit] { + use std::mem::{align_of, size_of}; + assert_eq!(align_of::(), align_of::()); + assert_eq!(0, size_of::() % size_of::()); + let divisor = size_of::() / size_of::(); + // Safety: `align_of == align_of` and `size_of()` is a multiple of `size_of()` + unsafe { + std::slice::from_raw_parts_mut(m.as_mut_ptr() as *mut MaybeUninit, m.len() / divisor) + } +} + +impl Buffer for F32Encoder { + fn collect_into(&mut self, out: &mut Vec) { + let floats = self.0.as_slice(); + let byte_len = std::mem::size_of_val(floats); + out.reserve(byte_len); + let uninit = &mut out.spare_capacity_mut()[..byte_len]; + + let (mantissa, sign_exp) = uninit.split_at_mut(floats.len() * 3); + let mantissa: &mut [MaybeUninit<[u8; 3]>] = chunks_uninit(mantissa); + + // TODO SIMD version with PSHUFB. + const CHUNK_SIZE: usize = 4; + let chunks_len = floats.len() / CHUNK_SIZE; + let chunks_floats = chunks_len * CHUNK_SIZE; + let chunks: &[[u32; CHUNK_SIZE]] = bytemuck::cast_slice(&floats[..chunks_floats]); + let mantissa_chunks: &mut [MaybeUninit<[[u8; 4]; 3]>] = chunks_uninit(mantissa); + let sign_exp_chunks: &mut [MaybeUninit<[u8; 4]>] = chunks_uninit(sign_exp); + + for ci in 0..chunks_len { + let [a, b, c, d] = chunks[ci]; + + let m0 = a & 0xFF_FF_FF | (b << 24); + let m1 = ((b >> 8) & 0xFF_FF) | (c << 16); + let m2 = (c >> 16) & 0xFF | (d << 8); + let mantissa_chunk = &mut mantissa_chunks[ci]; + mantissa_chunk.write([m0.to_le_bytes(), m1.to_le_bytes(), m2.to_le_bytes()]); + + let se = (a >> 24) | ((b >> 24) << 8) | ((c >> 24) << 16) | ((d >> 24) << 24); + let sign_exp_chunk = &mut sign_exp_chunks[ci]; + sign_exp_chunk.write(se.to_le_bytes()); + } + + for i in chunks_floats..floats.len() { + let [m @ .., se] = floats[i].to_le_bytes(); + mantissa[i].write(m); + sign_exp[i].write(se); + } + + // Safety: We just 
initialized these elements in the loops above. + unsafe { out.set_len(out.len() + byte_len) }; + self.0.clear(); + } + + fn reserve(&mut self, additional: NonZeroUsize) { + self.0.reserve(additional.get()) + } +} + +#[derive(Default)] +pub struct F32Decoder<'a> { + // While it is true that this contains 1 bit of the exp we still call it mantissa. + mantissa: FastSlice<'a, [u8; 3]>, + sign_exp: FastSlice<'a, u8>, +} + +impl<'a> View<'a> for F32Decoder<'a> { + fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> { + let total: &[u8] = bytemuck::must_cast_slice(consume_byte_arrays::<4>(input, length)?); + let (mantissa, sign_exp) = total.split_at(length * 3); + let mantissa: &[[u8; 3]] = bytemuck::cast_slice(mantissa); + // Equivalent to `mantissa.into()` but satisfies miri when we read extra in decode. + self.mantissa = + unsafe { FastSlice::from_raw_parts(total.as_ptr() as *const [u8; 3], mantissa.len()) }; + self.sign_exp = sign_exp.into(); + Ok(()) + } +} + +impl<'a> Decoder<'a, f32> for F32Decoder<'a> { + #[inline(always)] + fn decode(&mut self) -> f32 { + let mantissa_ptr = unsafe { self.mantissa.next_unchecked_as_ptr() }; + + // Loading 4 bytes instead of 3 is 30% faster, so we read 1 extra byte after mantissa_ptr. + // Safety: The extra byte is within bounds because sign_exp comes after mantissa. 
+ let mantissa_extended = unsafe { *(mantissa_ptr as *const [u8; 4]) }; + let mantissa = u32::from_le_bytes(mantissa_extended) & 0xFF_FF_FF; + + let sign_exp = unsafe { self.sign_exp.next_unchecked() }; + f32::from_bits(mantissa | ((sign_exp as u32) << 24)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::prelude::*; + use rand_chacha::ChaCha20Rng; + + #[test] + fn test() { + for i in 1..16 { + let mut rng = ChaCha20Rng::from_seed(Default::default()); + let floats: Vec<_> = (0..i).map(|_| f32::from_bits(rng.gen())).collect(); + + let mut encoder = F32Encoder::default(); + encoder.reserve(NonZeroUsize::new(floats.len()).unwrap()); + for &f in &floats { + encoder.encode(&f); + } + let bytes = encoder.collect(); + + let mut decoder = F32Decoder::default(); + let mut slice = bytes.as_slice(); + decoder.populate(&mut slice, floats.len()).unwrap(); + assert!(slice.is_empty()); + for &f in &floats { + assert_eq!(f.to_bits(), decoder.decode().to_bits()); + } + } + } + + fn bench_data() -> Vec { + let mut rng = ChaCha20Rng::from_seed(Default::default()); + (0..crate::limit_bench_miri(1500001)) + .map(|_| rng.gen()) + .collect() + } + crate::bench_encode_decode!(f32_vec: Vec); +} diff --git a/src/fast.rs b/src/fast.rs new file mode 100644 index 0000000..7dc7bd2 --- /dev/null +++ b/src/fast.rs @@ -0,0 +1,522 @@ +use std::fmt::{Debug, Formatter}; +use std::marker::PhantomData; +use std::mem::MaybeUninit; + +pub type VecImpl = FastVec; +pub type SliceImpl<'a, T> = FastSlice<'a, T>; + +pub struct FastVec { + start: *mut T, // TODO NonNull/Unique? 
+ end: *mut T, + capacity: usize, + _spooky: PhantomData>, +} + +impl Debug for FastVec { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.as_slice().fmt(f) + } +} + +impl Default for FastVec { + fn default() -> Self { + Self::from(vec![]) + } +} + +impl Drop for FastVec { + fn drop(&mut self) { + unsafe { + drop(Vec::from(std::ptr::read(self))); + } + } +} + +impl From> for Vec { + fn from(fast: FastVec) -> Self { + let start = fast.start; + let length = fast.len(); + let capacity = fast.capacity; + std::mem::forget(fast); + unsafe { Vec::from_raw_parts(start, length, capacity) } + } +} + +impl From> for FastVec { + fn from(mut vec: Vec) -> Self { + let start = vec.as_mut_ptr(); + let end = unsafe { start.add(vec.len()) }; + let capacity = vec.capacity(); + std::mem::forget(vec); + Self { + start, + end, + capacity, + _spooky: Default::default(), + } + } +} + +impl FastVec { + pub fn as_slice(&self) -> &[T] { + unsafe { std::slice::from_raw_parts(self.start, self.len()) } + } + + pub fn as_mut_slice(&mut self) -> &mut [T] { + unsafe { std::slice::from_raw_parts_mut(self.start, self.len()) } + } + + pub fn clear(&mut self) { + self.mut_vec(Vec::clear); + } + + pub fn reserve(&mut self, additional: usize) { + // check copied from RawVec::grow_amortized + let len = self.len(); + if additional > self.capacity.wrapping_sub(len) { + #[cold] + #[inline(never)] + fn reserve_slow(me: &mut FastVec, additional: usize) { + me.mut_vec(|v| v.reserve(additional)); + } + reserve_slow(self, additional); + } + } + + pub fn resize(&mut self, new_len: usize, value: T) + where + T: Clone, + { + self.mut_vec(|v| v.resize(new_len, value)); + } + + /// Accesses the [`FastVec`] mutably as a [`Vec`]. TODO(unsound) panic in `f` causes double free. 
+ fn mut_vec(&mut self, f: impl FnOnce(&mut Vec)) { + unsafe { + let copied = std::ptr::read(self as *mut FastVec); + let mut vec = Vec::from(copied); + f(&mut vec); + let copied = FastVec::from(vec); + std::ptr::write(self as *mut FastVec, copied); + } + } + + fn len(&self) -> usize { + (self.end as usize - self.start as usize) / std::mem::size_of::() // TODO sub_ptr. + } + + /// Get a pointer to write to without incrementing length. + #[inline(always)] + pub fn end_ptr(&mut self) -> *mut T { + debug_assert!(self.len() <= self.capacity); + self.end + } + + /// Set the end_ptr after mutating it. + #[inline(always)] + pub fn set_end_ptr(&mut self, end: *mut T) { + self.end = end; + debug_assert!(self.len() <= self.capacity); + } + + /// Increments length by 1. + /// + /// Safety: + /// + /// Element at [`Self::end_ptr()`] must have been initialized. + #[inline(always)] + pub unsafe fn increment_len(&mut self) { + self.end = self.end.add(1); + debug_assert!(self.len() <= self.capacity); + } +} + +pub trait PushUnchecked { + /// Like [`Vec::push`] but without the possibility of allocating. + /// Safety: len must be < capacity. + unsafe fn push_unchecked(&mut self, t: T); +} + +impl PushUnchecked for FastVec { + #[inline(always)] + unsafe fn push_unchecked(&mut self, t: T) { + debug_assert!(self.len() < self.capacity); + std::ptr::write(self.end, t); + self.end = self.end.add(1); + } +} + +impl PushUnchecked for Vec { + #[inline(always)] + unsafe fn push_unchecked(&mut self, t: T) { + let n = self.len(); + debug_assert!(n < self.capacity()); + let end = self.as_mut_ptr().add(n); + std::ptr::write(end, t); + self.set_len(n + 1) + } +} + +/// Like [`FastVec`] but borrows a [`MaybeUninit<[T; N]>`] instead of heap allocating. Only accepts +/// `T: Copy` because it doesn't drop elements. 
+pub struct FastArrayVec<'a, T: Copy, const N: usize> { + start: *mut T, + end: *mut T, + _spooky: PhantomData<&'a mut T>, +} + +impl<'a, T: Copy, const N: usize> FastArrayVec<'a, T, N> { + #[inline(always)] + pub fn new(uninit: &'a mut MaybeUninit<[T; N]>) -> Self { + let start = uninit.as_mut_ptr() as *mut T; + Self { + start, + end: start, + _spooky: PhantomData, + } + } + + #[inline(always)] + pub fn as_slice(&self) -> &[T] { + let len = (self.end as usize - self.start as usize) / std::mem::size_of::(); + unsafe { std::slice::from_raw_parts(self.start, len) } + } +} + +impl<'a, T: Copy, const N: usize> PushUnchecked for FastArrayVec<'a, T, N> { + #[inline(always)] + unsafe fn push_unchecked(&mut self, t: T) { + std::ptr::write(self.end, t); + self.end = self.end.add(1); + } +} + +#[derive(Clone)] +pub struct FastSlice<'a, T> { + ptr: *const T, + #[cfg(debug_assertions)] + len: usize, // TODO could store end ptr to allow Debug and as_slice. + _spooky: PhantomData<&'a T>, +} + +impl Debug for FastSlice<'_, T> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str("FastSlice") // We don't have len so we can't debug elements. + } +} + +impl Default for FastSlice<'_, T> { + fn default() -> Self { + Self::from([].as_slice()) + } +} + +impl<'a, T> From<&'a [T]> for FastSlice<'a, T> { + fn from(slice: &'a [T]) -> Self { + Self { + ptr: slice.as_ptr(), + #[cfg(debug_assertions)] + len: slice.len(), + _spooky: PhantomData, + } + } +} + +impl<'a, T> FastSlice<'a, T> { + /// Safety: `ptr` and `len` must form a valid slice. + #[inline(always)] + pub unsafe fn from_raw_parts(ptr: *const T, len: usize) -> Self { + let _ = len; + Self { + ptr, + #[cfg(debug_assertions)] + len, + _spooky: PhantomData, + } + } + + /// Like [`NextUnchecked::next_unchecked`] but doesn't dereference the `T`. 
+ #[inline(always)] + pub unsafe fn next_unchecked_as_ptr(&mut self) -> *const T { + #[cfg(debug_assertions)] + { + self.len = self.len.checked_sub(1).unwrap(); + } + let p = self.ptr; + self.ptr = self.ptr.add(1); + p + } + + #[inline(always)] + pub unsafe fn advance(&mut self, n: usize) { + #[cfg(debug_assertions)] + { + self.len = self.len.checked_sub(n).unwrap(); + } + self.ptr = self.ptr.add(n); + } + + #[inline(always)] + pub fn as_ptr(&self) -> *const T { + self.ptr + } +} + +pub trait NextUnchecked<'a, T: Copy> { + /// Gets the next item out of the slice and sets the slice to the remaining elements. + /// Safety: can only call len times. + unsafe fn next_unchecked(&mut self) -> T; + + /// Consumes `length` elements of the slice. + /// Safety: length must be in bounds. + unsafe fn chunk_unchecked(&mut self, length: usize) -> &'a [T]; +} + +impl<'a, T: Copy> NextUnchecked<'a, T> for FastSlice<'a, T> { + #[inline(always)] + unsafe fn next_unchecked(&mut self) -> T { + #[cfg(debug_assertions)] + { + self.len = self.len.checked_sub(1).unwrap(); + } + let t = *self.ptr; + self.ptr = self.ptr.add(1); + t + } + + #[inline(always)] + unsafe fn chunk_unchecked(&mut self, length: usize) -> &'a [T] { + #[cfg(debug_assertions)] + { + self.len = self.len.checked_sub(length).unwrap(); + } + let slice = std::slice::from_raw_parts(self.ptr, length); + self.ptr = self.ptr.add(length); + slice + } +} + +impl<'a, T: Copy> NextUnchecked<'a, T> for &'a [T] { + #[inline(always)] + unsafe fn next_unchecked(&mut self) -> T { + let p = *self.get_unchecked(0); + *self = self.get_unchecked(1..); + p + } + + #[inline(always)] + unsafe fn chunk_unchecked(&mut self, length: usize) -> &'a [T] { + let slice = self.get_unchecked(0..length); + *self = self.get_unchecked(length..); + slice + } +} + +/// Maybe owned [`FastSlice`]. Saves its allocation even if borrowing something. 
+#[derive(Debug, Default)] +pub struct CowSlice<'borrowed, T> { + slice: SliceImpl<'borrowed, T>, // Lifetime is min of 'borrowed and &'me self. + vec: Vec, +} +impl<'borrowed, T> CowSlice<'borrowed, T> { + /// Creates a [`CowSlice`] with an allocation of `vec`. None of `vec`'s elements are kept. + pub fn with_allocation(mut vec: Vec) -> Self { + vec.clear(); + Self { + slice: [].as_slice().into(), + vec, + } + } + + /// Converts a [`CowSlice`] into its internal allocation. The [`Vec`] is empty. + pub fn into_allocation(mut self) -> Vec { + self.vec.clear(); + self.vec + } + + /// References the inner [`SliceImpl`] as a `[T]`. + /// Safety: `len` must be equal to the slices original len. + #[must_use] + pub unsafe fn as_slice<'me>(&'me self, len: usize) -> &'me [T] + where + 'borrowed: 'me, + { + #[cfg(debug_assertions)] + assert_eq!(self.slice.len, len); + std::slice::from_raw_parts(self.slice.ptr, len) + } + + /// References the inner [`SliceImpl`]. + #[must_use] + #[inline(always)] + pub fn ref_slice<'me>(&'me self) -> &'me SliceImpl<'me, T> + where + 'borrowed: 'me, + { + // Safety: 'me is min of 'borrowed and &'me self because of `where 'borrowed: 'me`. + let slice: &'me SliceImpl<'me, T> = unsafe { std::mem::transmute(&self.slice) }; + slice + } + + /// Mutates the inner [`SliceImpl`]. + #[must_use] + #[inline(always)] + pub fn mut_slice<'me>(&'me mut self) -> &'me mut SliceImpl<'me, T> + where + 'borrowed: 'me, + { + // Safety: 'me is min of 'borrowed and &'me self because of `where 'borrowed: 'me`. + let slice: &'me mut SliceImpl<'me, T> = unsafe { std::mem::transmute(&mut self.slice) }; + slice + } + + /// Equivalent to `self.set_owned().extend_from_slice(slice)` but without copying. + pub fn set_borrowed(&mut self, slice: &'borrowed [T]) { + self.slice = slice.into(); + } + + /// Equivalent to [`Self::set_borrowed`] but takes a [`SliceImpl`] instead of a `&[T]`. 
+ pub fn set_borrowed_slice_impl(&mut self, slice: SliceImpl<'borrowed, T>) { + self.slice = slice; + } + + /// Allows putting contents into a cleared `&mut Vec`. When `SetOwned` is dropped the + /// `CowSlice` will be updated to reference the new elements. + #[must_use] + pub fn set_owned(&mut self) -> SetOwned<'_, 'borrowed, T> { + // Clear self.slice before mutating self.vec, so we don't point to freed memory. + self.slice = [].as_slice().into(); + self.vec.clear(); + SetOwned(self) + } + + /// Mutates the owned [`Vec`]. + /// + /// **Panics** + /// + /// If self is not owned (set_owned hasn't been called). + pub fn mut_owned(&mut self, f: impl FnOnce(&mut Vec)) { + assert_eq!(self.slice.ptr, self.vec.as_ptr()); + // Clear self.slice before mutating self.vec, so we don't point to freed memory. + self.slice = [].as_slice().into(); + f(&mut self.vec); + // Safety: We clear `CowSlice.slice` whenever we mutate `CowSlice.vec`. + let slice: &'borrowed [T] = unsafe { std::mem::transmute(self.vec.as_slice()) }; + self.slice = slice.into(); + } +} + +pub struct SetOwned<'a, 'borrowed, T>(&'a mut CowSlice<'borrowed, T>); +impl<'borrowed, T> Drop for SetOwned<'_, 'borrowed, T> { + fn drop(&mut self) { + // Safety: We clear `CowSlice.slice` whenever we mutate `CowSlice.vec`. 
+ let slice: &'borrowed [T] = unsafe { std::mem::transmute(self.0.vec.as_slice()) }; + self.0.slice = slice.into(); + } +} +impl<'a, T> std::ops::Deref for SetOwned<'a, '_, T> { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.0.vec + } +} +impl<'a, T> std::ops::DerefMut for SetOwned<'a, '_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0.vec + } +} + +#[cfg(test)] +mod tests { + use super::*; + use test::{black_box, Bencher}; + + #[test] + fn test_as_slice() { + let mut vec = FastVec::default(); + vec.reserve(2); + unsafe { + vec.push_unchecked(1); + vec.push_unchecked(2); + } + assert_eq!(vec.as_slice(), [1, 2]); + } + + // TODO benchmark with u32 instead of just u8. + const N: usize = 1000; + + #[bench] + fn bench_next_unchecked(b: &mut Bencher) { + let src = vec![0u8; N]; + b.iter(|| { + let mut slice = src.as_slice(); + for _ in 0..black_box(N) { + unsafe { black_box(black_box(&mut slice).next_unchecked()) }; + } + }); + } + + #[bench] + fn bench_next_unchecked_fast(b: &mut Bencher) { + let src = vec![0u8; N]; + b.iter(|| { + let mut fast_slice = FastSlice::from(src.as_slice()); + for _ in 0..black_box(N) { + unsafe { black_box(black_box(&mut fast_slice).next_unchecked()) }; + } + }); + } + + #[bench] + fn bench_push_unchecked(b: &mut Bencher) { + let mut buffer = Vec::with_capacity(N); + b.iter(|| { + buffer.clear(); + let vec = black_box(&mut buffer); + for _ in 0..black_box(N) { + let v = black_box(&mut *vec); + unsafe { v.push_unchecked(black_box(0)) }; + } + }); + } + + #[bench] + fn bench_push_unchecked_fast(b: &mut Bencher) { + let mut buffer = Vec::with_capacity(N); + b.iter(|| { + buffer.clear(); + let mut vec = black_box(FastVec::from(std::mem::take(&mut buffer))); + for _ in 0..black_box(N) { + let v = black_box(&mut vec); + unsafe { v.push_unchecked(black_box(0)) }; + } + buffer = vec.into(); + }); + } + + #[bench] + fn bench_reserve(b: &mut Bencher) { + let mut buffer = Vec::::with_capacity(N); + b.iter(|| { 
+ buffer.clear(); + let vec = black_box(&mut buffer); + for _ in 0..black_box(N) { + black_box(&mut *vec).reserve(1); + } + }); + } + + #[bench] + fn bench_reserve_fast(b: &mut Bencher) { + let mut buffer = Vec::::with_capacity(N); + b.iter(|| { + buffer.clear(); + let mut vec = black_box(FastVec::from(std::mem::take(&mut buffer))); + for _ in 0..black_box(N) { + black_box(&mut vec).reserve(1); + } + buffer = vec.into(); + }); + } +} diff --git a/src/guard.rs b/src/guard.rs deleted file mode 100644 index 5cf50e8..0000000 --- a/src/guard.rs +++ /dev/null @@ -1,47 +0,0 @@ -use crate::encoding::Encoding; -use crate::read::Read; -use crate::{Decode, Result, E}; - -pub const ZST_LIMIT: usize = 1 << 16; - -fn check_zst_len(len: usize) -> Result<()> { - if len > ZST_LIMIT { - Err(E::Invalid("too many zst").e()) - } else { - Ok(()) - } -} - -// Used by deserialize. Guards against Vec<()> with huge len taking forever. -#[inline] -#[cfg(any(test, feature = "serde"))] -pub fn guard_zst(len: usize) -> Result<()> { - if std::mem::size_of::() == 0 { - check_zst_len(len) - } else { - Ok(()) - } -} - -// Used by decode. Guards against allocating huge Vec without enough remaining bits to fill it. -// Also guards against Vec<()> with huge len taking forever. -#[inline] -pub fn guard_len(len: usize, encoding: impl Encoding, reader: &impl Read) -> Result<()> { - // In #[derive(Decode)] we report serde types as 1 bit min even though they might serialize - // to 0. We do this so we can have large vectors past the ZST_LIMIT. We assume that any type - // that will serialize to nothing in serde has no size. - if T::DECODE_MIN == 0 || std::mem::size_of::() == 0 { - check_zst_len(len) - } else { - // If we are using an encoding other than fixed DECODE_MIN is invalid. - let min_bits = if encoding.is_fixed() { - T::DECODE_MIN - } else { - 1 - }; - - // We ensure that we have the minimum required bits so decoding doesn't allocate unbounded memory. 
- let bits = len.saturating_mul(min_bits); - reader.reserve_bits(bits) - } -} diff --git a/src/histogram.rs b/src/histogram.rs new file mode 100644 index 0000000..de7b39a --- /dev/null +++ b/src/histogram.rs @@ -0,0 +1,95 @@ +pub fn histogram(bytes: &[u8]) -> [usize; 256] { + if bytes.len() < 100 { + histogram_simple(bytes) + } else { + histogram_parallel(bytes) + } +} + +fn histogram_simple(bytes: &[u8]) -> [usize; 256] { + let mut histogram = [0; 256]; + for &v in bytes { + histogram[v as usize] += 1; + } + histogram +} + +fn histogram_parallel(bytes: &[u8]) -> [usize; 256] { + // Summing multiple 32 bit histograms is faster than a 64 bit histogram. + let mut total = [0; 256]; + for bytes in bytes.chunks(u32::MAX as usize) { + for (i, &v) in histogram_parallel_u32(bytes).iter().enumerate() { + total[i] += v as usize; + } + } + total +} + +// Based on https://github.com/facebook/zstd/blob/1518570c62b95136b6a69714012957cae5487a9a/lib/compress/hist.c#L66 +fn histogram_parallel_u32(bytes: &[u8]) -> [u32; 256] { + let mut histograms = [[0; 256]; 4]; + + let (chunks, remainder) = bytes.split_at(bytes.len() / 16 * 16); + let chunks16: &[[[u8; 4]; 4]] = bytemuck::cast_slice(chunks); + for chunk16 in chunks16 { + for chunk4 in chunk16 { + let c = u32::from_ne_bytes(*chunk4); + histograms[0][c as u8 as usize] += 1; + histograms[1][(c >> 8) as u8 as usize] += 1; + histograms[2][(c >> 16) as u8 as usize] += 1; + histograms[3][(c >> 24) as usize] += 1; + } + } + for &v in remainder { + histograms[0][v as usize] += 1; + } + + let (dst, src) = histograms.split_at_mut(1); + let dst = &mut dst[0]; + for i in 0..256 { + for src in src.iter() { + dst[i] += src[i]; + } + } + *dst +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::prelude::*; + use rand_chacha::ChaCha20Rng; + use test::{black_box, Bencher}; + + fn bench_data(n: usize) -> Vec { + let mut rng = ChaCha20Rng::from_seed(Default::default()); + std::iter::repeat_with(|| rng.gen_range(0..2)) + 
.take(crate::limit_bench_miri(n)) + .collect() + } + + fn bench_histogram_parallel(b: &mut Bencher, n: usize) { + let data = bench_data(n); + b.iter(|| histogram_parallel(black_box(&data))); + } + + fn bench_histogram_simple(b: &mut Bencher, n: usize) { + let data = bench_data(n); + b.iter(|| histogram_simple(black_box(&data))); + } + + macro_rules! bench { + ($name:ident, $($n:literal),+) => { + paste::paste! { + $( + #[bench] + fn [<$name _ $n>](b: &mut Bencher) { + $name(b, $n); + } + )+ + } + } + } + bench!(bench_histogram_parallel, 10, 100, 1000, 10000); + bench!(bench_histogram_simple, 10, 100, 1000, 10000); +} diff --git a/src/int.rs b/src/int.rs new file mode 100644 index 0000000..86584bb --- /dev/null +++ b/src/int.rs @@ -0,0 +1,143 @@ +use crate::coder::{Buffer, Decoder, Encoder, Result, View}; +use crate::error::err; +use crate::fast::{CowSlice, NextUnchecked, PushUnchecked, VecImpl}; +use crate::pack_ints::{pack_ints, unpack_ints, Int}; +use bytemuck::{CheckedBitPattern, NoUninit, Pod}; +use std::marker::PhantomData; +use std::num::NonZeroUsize; + +#[derive(Debug, Default)] +pub struct IntEncoder(VecImpl); + +/// Makes IntEncoder able to encode i32/f32/char. +impl Encoder
<P> for IntEncoder<T> {
 + // NOTE(review): the generic parameter lists on this impl were garbled by extraction;
 + // reconstructed as `impl<T: Int, P: Pod> Encoder<P> for IntEncoder<T>` from the
 + // surrounding bounds (`T: Int` storage, `Pod` casts below) — confirm against upstream.
 + #[inline(always)]
 + fn as_primitive(&mut self) -> Option<&mut VecImpl<P>> {
 + assert_eq!(std::mem::size_of::<T>(), std::mem::size_of::<P>());
 + // Safety: T and P are the same size, T is Pod, and we aren't reading P.
 + let vec: &mut VecImpl<P>
= unsafe { std::mem::transmute(&mut self.0) }; + Some(vec) + } + + #[inline(always)] + fn encode(&mut self, p: &P) { + // TODO swap byte order if big endian. + let t = bytemuck::must_cast(*p); + unsafe { self.0.push_unchecked(t) }; + } +} + +impl Buffer for IntEncoder { + fn collect_into(&mut self, out: &mut Vec) { + pack_ints(self.0.as_mut_slice(), out); + self.0.clear(); + } + + fn reserve(&mut self, additional: NonZeroUsize) { + self.0.reserve(additional.get()) + } +} + +#[derive(Debug, Default)] +pub struct IntDecoder<'a, T: Int>(CowSlice<'a, T::Ule>); + +impl<'a, T: Int> IntDecoder<'a, T> { + // For CheckedIntDecoder. + fn borrowed_clone<'me: 'a>(&'me self) -> IntDecoder<'me, T> { + let mut cow = CowSlice::default(); + cow.set_borrowed_slice_impl(self.0.ref_slice().clone()); + Self(cow) + } +} + +impl<'a, T: Int> View<'a> for IntDecoder<'a, T> { + fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> { + unpack_ints::(input, length, &mut self.0)?; + Ok(()) + } +} + +// Makes IntDecoder able to decode i32/f32 (but not char since it can fail). +impl<'a, T: Int, P: Pod> Decoder<'a, P> for IntDecoder<'a, T> { + #[inline(always)] + fn decode(&mut self) -> P { + let v = unsafe { self.0.mut_slice().next_unchecked() }; + // TODO swap byte order if big endian. + bytemuck::must_cast(v) + } +} + +/// For NonZeroU32, char, etc. +pub struct CheckedIntDecoder<'a, C, I: Int>(IntDecoder<'a, I>, PhantomData); + +// Can't bound C: Default since NonZeroU32/char don't implement it. 
+impl Default for CheckedIntDecoder<'_, C, I> { + fn default() -> Self { + Self(Default::default(), Default::default()) + } +} + +impl<'a, C: CheckedBitPattern, I: Int> View<'a> for CheckedIntDecoder<'a, C, I> +where + ::Bits: Pod, +{ + fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> { + assert_eq!(std::mem::size_of::(), std::mem::size_of::()); + self.0.populate(input, length)?; + + let mut decoder = self.0.borrowed_clone(); + if (0..length).any(|_| !C::is_valid_bit_pattern(&decoder.decode())) { + return err("invalid bit pattern"); + } + Ok(()) + } +} + +impl<'a, C: CheckedBitPattern, I: Int> Decoder<'a, C> for CheckedIntDecoder<'a, C, I> +where + ::Bits: Pod, +{ + #[inline(always)] + fn decode(&mut self) -> C { + let i: I = self.0.decode(); + + // Safety: populate ensures: + // - C and I are of the same size. + // - The checked bit pattern of C is valid. + unsafe { std::mem::transmute_copy(&i) } + } +} + +#[cfg(test)] +mod tests { + use crate::{decode, encode}; + use std::num::NonZeroU32; + + #[test] + fn non_zero_u32() { + assert!(decode::(&encode(&0u32)).is_err()); + assert!(decode::(&encode(&1u32)).is_ok()); + } + + #[test] + fn char_() { + assert!(decode::(&encode(&u32::MAX)).is_err()); + assert!(decode::(&encode(&0u32)).is_ok()); + } + + fn bench_data() -> Vec { + crate::random_data(1000) + } + crate::bench_encode_decode!(u16_vec: Vec<_>); +} + +#[cfg(test)] +mod test2 { + fn bench_data() -> Vec> { + crate::random_data::(125) + .into_iter() + .map(|n| (0..n / 54).map(|_| n as u16 * 255).collect()) + .collect() + } + crate::bench_encode_decode!(u16_vecs: Vec>); +} diff --git a/src/length.rs b/src/length.rs new file mode 100644 index 0000000..2654f2d --- /dev/null +++ b/src/length.rs @@ -0,0 +1,268 @@ +use crate::coder::{Buffer, Decoder, Encoder, Result, View}; +use crate::consume::consume_byte_arrays; +use crate::error::{err, error}; +use crate::fast::{CowSlice, NextUnchecked, SliceImpl, VecImpl}; +use crate::pack::{pack_bytes, 
unpack_bytes}; +use std::num::NonZeroUsize; + +#[derive(Debug, Default)] +pub struct LengthEncoder { + small: VecImpl, + large: Vec, // Not a FastVec because capacity isn't known. +} + +impl Encoder for LengthEncoder { + #[inline(always)] + fn encode(&mut self, &v: &usize) { + unsafe { + let end_ptr = self.small.end_ptr(); + if v < 255 { + *end_ptr = v as u8; + } else { + #[inline(never)] + #[cold] // TODO cold or only inline(never)? + unsafe fn encode_slow(end_ptr: *mut u8, large: &mut Vec, v: usize) { + *end_ptr = 255; + + // Swap bytes if big endian, so we can cast large to little endian &[u8]. + #[cfg(target_endian = "little")] + let v = v as u64; + #[cfg(target_endian = "big")] + let v = (v as u64).swap_bytes(); + large.push(v); + } + encode_slow(end_ptr, &mut self.large, v); + } + self.small.increment_len(); + } + } +} + +pub trait Len { + fn len(&self) -> usize; +} + +impl Len for &[T] { + #[inline(always)] + fn len(&self) -> usize { + <[T]>::len(self) + } +} + +impl Len for &str { + #[inline(always)] + fn len(&self) -> usize { + str::len(self) + } +} + +impl LengthEncoder { + /// Encodes a length known to be < `255`. + #[cfg(feature = "arrayvec")] + #[inline(always)] + pub fn encode_less_than_255(&mut self, n: usize) { + use crate::fast::PushUnchecked; + debug_assert!(n < 255); + unsafe { self.small.push_unchecked(n as u8) }; + } + + /// Encodes lengths less than `N`. Have to reserve `N * i.size_hint().1 elements`. + /// Skips calling encode for T::len() == 0. Returns `true` if it failed due to a length over `N`. + #[inline(always)] + pub fn encode_vectored_max_len( + &mut self, + i: impl Iterator, + mut enocde: impl FnMut(T), + ) -> bool { + debug_assert!(N <= 64); + let mut ptr = self.small.end_ptr(); + for t in i { + let n = t.len(); + unsafe { + *ptr = n as u8; + ptr = ptr.add(1); + } + if n == 0 { + continue; + } + if n > N { + // Don't set end ptr (elements won't be saved). 
+ return true; + } + enocde(t); + } + self.small.set_end_ptr(ptr); + false + } + + #[inline(always)] + pub fn encode_vectored_fallback( + &mut self, + i: impl Iterator, + mut reserve_and_encode_large: impl FnMut(T), + ) { + for v in i { + let n = v.len(); + self.encode(&n); + reserve_and_encode_large(v); + } + } +} + +impl Buffer for LengthEncoder { + fn collect_into(&mut self, out: &mut Vec) { + pack_bytes(self.small.as_mut_slice(), out); + self.small.clear(); + out.extend_from_slice(bytemuck::cast_slice(self.large.as_slice())); + self.large.clear(); + } + + fn reserve(&mut self, additional: NonZeroUsize) { + self.small.reserve(additional.get()); // All lengths inhabit small, only large ones inhabit large. + } +} + +#[derive(Debug, Default)] +pub struct LengthDecoder<'a> { + small: CowSlice<'a, u8>, + large: SliceImpl<'a, [u8; 8]>, + sum: usize, +} + +impl<'a> LengthDecoder<'a> { + pub fn length(&self) -> usize { + self.sum + } + + // For decoding lengths multiple times (e.g. ArrayVec, utf8 validation). + pub fn borrowed_clone<'me: 'a>(&'me self) -> LengthDecoder<'me> { + let mut small = CowSlice::default(); + small.set_borrowed_slice_impl(self.small.ref_slice().clone()); + Self { + small, + large: self.large.clone(), + sum: self.sum, + } + } + + /// Returns if any of the decoded lengths are > `N`. + /// Safety: `length` must be the `length` passed to populate. + #[cfg_attr(not(feature = "arrayvec"), allow(unused))] + pub unsafe fn any_greater_than(&self, length: usize) -> bool { + if N < 255 { + // Fast path: don't need to scan large lengths since there shouldn't be any. + // A large length will have a 255 in small which will be greater than N. 
+ self.small + .as_slice(length) + .iter() + .copied() + .max() + .unwrap_or(0) as usize + > N + } else { + let mut decoder = self.borrowed_clone(); + (0..length).any(|_| decoder.decode() > N) + } + } +} + +impl<'a> View<'a> for LengthDecoder<'a> { + fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> { + unpack_bytes(input, length, &mut self.small)?; + let small = unsafe { self.small.as_slice(length) }; + + // Summing &[u8] can't overflow since that would require > 2^56 bytes of memory. + let mut sum: u64 = small.iter().map(|&v| v as u64).sum(); + + // Fast path for small lengths: If sum(small) < 255 every small < 255 so large_length is 0. + if sum < 255 { + self.sum = sum as usize; + return Ok(()); + } + + // Every 255 byte indicates a large is present. + let large_length = small.iter().filter(|&&v| v == 255).count(); + let large: &[[u8; 8]] = consume_byte_arrays(input, large_length)?; + self.large = large.into(); + + // Can't overflow since sum includes large_length many 255s. + sum -= large_length as u64 * 255; + + // Summing &[u64] can overflow, so we check it. + for &v in large { + let v = u64::from_le_bytes(v); + sum = sum.checked_add(v).ok_or_else(|| error("length overflow"))?; + } + if sum >= HUGE_LEN { + return err("length overflow"); // Lets us optimize decode with unreachable_unchecked. 
+ } + self.sum = sum.try_into().map_err(|_| error("length > usize::MAX"))?; + Ok(()) + } +} + +// isize::MAX / (largest type we want to allocate without possibility of overflow) +const HUGE_LEN: u64 = 0x7FFFFFFF_FFFFFFFF / 4096; + +impl<'a> Decoder<'a, usize> for LengthDecoder<'a> { + #[inline(always)] + fn decode(&mut self) -> usize { + let length = unsafe { + let v = self.small.mut_slice().next_unchecked(); + + if v < 255 { + v as usize + } else { + #[cold] + unsafe fn cold(large: &mut SliceImpl<'_, [u8; 8]>) -> usize { + u64::from_le_bytes(large.next_unchecked()) as usize + } + cold(&mut self.large) + } + }; + + // Allows some checks in Vec::with_capacity to be removed if lto = true. + // Safety: sum < HUGE_LEN is checked in populate so all elements have to be < HUGE_LEN. + if length as u64 >= HUGE_LEN { + unsafe { std::hint::unreachable_unchecked() } + } + length + } +} + +#[cfg(test)] +mod tests { + use super::{LengthDecoder, LengthEncoder}; + use crate::coder::{Buffer, Decoder, Encoder, View}; + use std::num::NonZeroUsize; + + #[test] + fn test() { + let mut encoder = LengthEncoder::default(); + encoder.reserve(NonZeroUsize::new(3).unwrap()); + encoder.encode(&1); + encoder.encode(&255); + encoder.encode(&2); + let bytes = encoder.collect(); + + let mut decoder = LengthDecoder::default(); + decoder.populate(&mut bytes.as_slice(), 3).unwrap(); + assert_eq!(decoder.decode(), 1); + assert_eq!(decoder.decode(), 255); + assert_eq!(decoder.decode(), 2); + } + + #[cfg(target_pointer_width = "64")] // HUGE_LEN > u32::MAX + #[test] + fn huge_len() { + for (x, is_ok) in [(super::HUGE_LEN - 1, true), (super::HUGE_LEN, false)] { + let mut encoder = LengthEncoder::default(); + encoder.reserve(NonZeroUsize::new(1).unwrap()); + encoder.encode(&(x as usize)); + let bytes = encoder.collect(); + + let mut decoder = LengthDecoder::default(); + assert_eq!(decoder.populate(&mut bytes.as_slice(), 1).is_ok(), is_ok); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 
6918c36..4032d1a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,280 +1,98 @@ -#![cfg_attr(test, feature(test))] +#![allow(clippy::items_after_test_module, clippy::blocks_in_if_conditions)] #![cfg_attr(doc, feature(doc_cfg))] -#![forbid(unsafe_code)] -#![allow(clippy::items_after_test_module)] - -//! Bitcode is a crate for encoding and decoding using a tinier -//! binary serialization strategy. You can easily go from having -//! an object in memory, quickly serialize it to bytes, and then -//! deserialize it back just as fast! -//! -//! The format is not necessarily stable between versions. If you want -//! a stable format, consider [bincode](https://docs.rs/bincode/latest/bincode/). -//! -//! ### Usage -//! -//! ```edition2021 -//! // The object that we will encode. -//! let target: Vec = vec!["a".to_owned(), "b".to_owned(), "c".to_owned()]; -//! -//! let encoded: Vec = bitcode::encode(&target).unwrap(); -//! let decoded: Vec = bitcode::decode(&encoded).unwrap(); -//! assert_eq!(target, decoded); -//! ``` -//! -//! ### Advanced Usage -//! -//! Bitcode has several hints that can be applied. -//! Hints may have an effect on the encoded length. -//! Most importantly hints will never cause errors if they don't hold true. -//! -//! ```edition2021 -//! // We mark enum variants that are more likely with a higher frequency. -//! // This allows bitcode to use shorter encodings for them. -//! #[derive(Copy, Clone, bitcode::Encode, bitcode::Decode)] -//! enum Fruit { -//! #[bitcode_hint(frequency = 10)] -//! Apple, -//! #[bitcode_hint(frequency = 5)] -//! Banana, -//! // Unspecified frequencies are 1. -//! Blueberry, -//! Lime, -//! Lychee, -//! Watermelon, -//! } -//! -//! // A cart full of 16 apples takes 2 bytes to encode (1 bit per Apple). -//! let apple_cart: usize = bitcode::encode(&[Fruit::Apple; 16]).unwrap().len(); -//! assert_eq!(apple_cart, 2); -//! -//! // A cart full of 16 bananas takes 4 bytes to encode (2 bits per Banana). -//! 
let banana_cart: usize = bitcode::encode(&[Fruit::Banana; 16]).unwrap().len(); -//! assert_eq!(banana_cart, 4); -//! -//! // A cart full of 16 blueberries takes 8 bytes to encode (4 bits per Blueberry). -//! let blueberry_cart: usize = bitcode::encode(&[Fruit::Blueberry; 16]).unwrap().len(); -//! assert_eq!(blueberry_cart, 8); -//! ``` -//! -//! ```edition2021 -//! // We expect most user ages to be in the interval [10, 100), so we specify that as the expected -//! // range. If we're right most of the time, users will take fewer bits to encode. -//! #[derive(bitcode::Encode, bitcode::Decode)] -//! struct User { -//! #[bitcode_hint(expected_range = "10..100")] -//! age: u32 -//! } -//! -//! // A user with an age inside the expected range takes up to a byte to encode. -//! let expected_age: usize = bitcode::encode(&User { age: 42 }).unwrap().len(); -//! assert_eq!(expected_age, 1); -//! -//! // A user with an age outside the expected range takes more than 4 bytes to encode. -//! let unexpected_age: usize = bitcode::encode(&User { age: 31415926 }).unwrap().len(); -//! assert!(unexpected_age > 4); -//! ``` -//! -//! ```edition2021 -//! // We expect that most posts won't have that many views or likes, but some can. By using gamma -//! // encoding, posts with fewer views/likes will take fewer bits to encode. -//! #[derive(bitcode::Encode, bitcode::Decode)] -//! #[bitcode_hint(gamma)] -//! struct Post { -//! views: u64, -//! likes: u64, -//! } -//! -//! // An average post just takes 1 byte to encode. -//! let average_post = bitcode::encode(&Post { -//! views: 4, -//! likes: 1, -//! }).unwrap().len(); -//! assert_eq!(average_post, 1); -//! -//! // A popular post takes 11 bytes to encode, luckily these posts are rare. -//! let popular_post = bitcode::encode(&Post { -//! views: 27182818, -//! likes: 161803, -//! }).unwrap().len(); -//! assert_eq!(popular_post, 11) -//! 
``` - -// https://doc.rust-lang.org/beta/unstable-book/library-features/test.html -#[cfg(test)] -extern crate test; +#![cfg_attr(test, feature(test))] +#![doc = include_str!("../README.md")] // Fixes derive macro in tests/doc tests. #[cfg(test)] extern crate self as bitcode; +#[cfg(test)] +extern crate test; + +// Missing many calls to swap_bytes throughout the codebase. +#[cfg(target_endian = "big")] +compile_error!("big endian is not yet supported"); + +mod bool; +mod coder; +mod consume; +mod derive; +mod error; +mod ext; +mod f32; +mod fast; +mod histogram; +mod int; +mod length; +mod nightly; +mod pack; +mod pack_ints; +mod str; +mod u8_char; -pub use buffer::Buffer; -pub use code::{Decode, Encode}; -use std::fmt::{self, Display, Formatter}; +pub use crate::derive::*; +pub use crate::error::Error; #[cfg(feature = "derive")] +#[cfg_attr(doc, doc(cfg(feature = "derive")))] pub use bitcode_derive::{Decode, Encode}; -#[cfg(any(test, feature = "serde"))] -pub use crate::serde::{deserialize, serialize}; - -mod buffer; -mod code; -mod code_impls; -mod encoding; -mod guard; -mod nightly; -mod read; -mod register_buffer; -mod word; -mod word_buffer; -mod write; - -#[doc(hidden)] -pub mod __private; - -#[cfg(any(test, feature = "serde"))] +#[cfg(feature = "serde")] mod serde; +#[cfg(feature = "serde")] +pub use crate::serde::*; -#[cfg(all(test, not(miri)))] +#[cfg(test)] mod benches; #[cfg(test)] -mod bit_buffer; -#[cfg(all(test, debug_assertions))] -mod tests; +mod benches_borrowed; -/// Encodes a `T:` [`Encode`] into a [`Vec`]. -/// -/// Won't ever return `Err` unless using `#[bitcode(with_serde)]`. -/// -/// **Warning:** The format is subject to change between versions. -pub fn encode(t: &T) -> Result> -where - T: Encode, -{ - Ok(Buffer::new().encode(t)?.to_vec()) -} - -/// Decodes a [`&[u8]`][`prim@slice`] into an instance of `T:` [`Decode`]. -/// -/// **Warning:** The format is subject to change between versions. 
-pub fn decode(bytes: &[u8]) -> Result +#[cfg(test)] +fn random_data(n: usize) -> Vec where - T: Decode, + rand::distributions::Standard: rand::distributions::Distribution, { - Buffer::new().decode(bytes) + let n = limit_bench_miri(n); + use rand::prelude::*; + let mut rng = rand_chacha::ChaCha20Rng::from_seed(Default::default()); + (0..n).map(|_| rng.gen()).collect() } - -impl Buffer { - /// Encodes a `T:` [`Encode`] into a [`&[u8]`][`prim@slice`]. Can reuse the buffer's - /// allocations. - /// - /// Won't ever return `Err` unless using `#[bitcode(with_serde)]`. - /// - /// Even if you call `to_vec` on the [`&[u8]`][`prim@slice`], it's still more efficient than - /// [`encode`]. - /// - /// **Warning:** The format is subject to change between versions. - pub fn encode(&mut self, t: &T) -> Result<&[u8]> - where - T: Encode, - { - code::encode_internal(&mut self.0, t) - } - - /// Decodes a [`&[u8]`][`prim@slice`] into an instance of `T:` [`Decode`]. Can reuse - /// the buffer's allocations. - /// - /// **Warning:** The format is subject to change between versions. - pub fn decode(&mut self, bytes: &[u8]) -> Result - where - T: Decode, - { - code::decode_internal(&mut self.0, bytes) - } -} - -/// Decoding / (De)serialization errors. -/// -/// # Debug mode -/// -/// In debug mode, the error contains a reason. -/// -/// # Release mode -/// -/// In release mode, the error is a zero-sized type for efficiency. -#[derive(Debug)] -#[cfg_attr(test, derive(PartialEq))] -pub struct Error(ErrorImpl); - -#[cfg(not(debug_assertions))] -type ErrorImpl = (); - -#[cfg(debug_assertions)] -type ErrorImpl = E; - -impl Error { - /// Replaces an invalid message. E.g. read_variant_index calls read_len but converts - /// `E::Invalid("length")` to `E::Invalid("variant index")`. 
- #[cfg(any(test, feature = "serde"))] - pub(crate) fn map_invalid(self, _s: &'static str) -> Self { - #[cfg(debug_assertions)] - return Self(match self.0 { - E::Invalid(_) => E::Invalid(_s), - _ => self.0, - }); - #[cfg(not(debug_assertions))] - self - } - - // Doesn't implement PartialEq because that would be part of the public api. - pub(crate) fn same(&self, other: &Self) -> bool { - self.0 == other.0 - } -} - -/// Inner error that can be converted to [`Error`] with [`E::e`]. -#[derive(Debug, PartialEq)] -pub(crate) enum E { - #[allow(unused)] // Only used by serde feature. - Custom(String), - Eof, - ExpectedEof, - Invalid(&'static str), - #[allow(unused)] // Only used by serde feature. - NotSupported(&'static str), -} - -impl E { - fn e(self) -> Error { - #[cfg(debug_assertions)] - return Error(self); - #[cfg(not(debug_assertions))] - Error(()) - } -} - -type Result = std::result::Result; - -impl Display for Error { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - #[cfg(debug_assertions)] - return Display::fmt(&self.0, f); - #[cfg(not(debug_assertions))] - f.write_str("bitcode error") +#[cfg(test)] +fn limit_bench_miri(n: usize) -> usize { + if cfg!(miri) { + (n / 100).max(10).min(1000) + } else { + n } } - -impl Display for E { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - Self::Custom(s) => write!(f, "custom: {s}"), - Self::Eof => write!(f, "eof"), - Self::ExpectedEof => write!(f, "expected eof"), - Self::Invalid(s) => write!(f, "invalid {s}"), - Self::NotSupported(s) => write!(f, "{s} is not supported"), +#[cfg(test)] +macro_rules! bench_encode_decode { + ($($name:ident: $t:ty),+) => { + paste::paste! 
{ + $( + #[bench] + fn [](b: &mut test::Bencher) { + let data: $t = bench_data(); + let mut buffer = crate::EncodeBuffer::<_>::default(); + b.iter(|| { + test::black_box(buffer.encode(test::black_box(&data))); + }) + } + + #[bench] + fn [](b: &mut test::Bencher) { + let data: $t = bench_data(); + let encoded = crate::encode(&data); + let mut buffer = crate::DecodeBuffer::<_>::default(); + b.iter(|| { + let decoded: $t = buffer.decode(test::black_box(&encoded)).unwrap(); + debug_assert_eq!(data, decoded); + decoded + }) + } + )+ } } } - -impl std::error::Error for Error {} +#[cfg(test)] +pub(crate) use bench_encode_decode; diff --git a/src/nightly.rs b/src/nightly.rs index 4ea600b..2215297 100644 --- a/src/nightly.rs +++ b/src/nightly.rs @@ -1,46 +1,18 @@ -// Replacements for nightly features used while developing the crate. - -use std::num::NonZeroU64; - -#[inline(always)] -pub const fn div_ceil(me: usize, rhs: usize) -> usize { - let d = me / rhs; - let r = me % rhs; - if r > 0 && rhs > 0 { - d + 1 - } else { - d - } -} - -#[inline(always)] -pub const fn ilog2_u64(me: u64) -> u32 { - if cfg!(debug_assertions) && me == 0 { - panic!("log2 on zero") - } - u64::BITS - 1 - me.leading_zeros() -} - -// Faster than ilog2_u64 on CPUs that have bsr but not lzcnt. -#[inline(always)] -pub const fn ilog2_non_zero_u64(me: NonZeroU64) -> u32 { - u64::BITS - 1 - me.leading_zeros() -} - -/// `::min` isn't const yet. -pub const fn min(a: usize, b: usize) -> usize { - if a < b { - a - } else { - b - } -} - -/// `::max` isn't const yet. -pub const fn max(a: usize, b: usize) -> usize { - if a > b { - a - } else { - b - } +/// `#![feature(int_roundings)]` was stabilized in 1.73, but we want to avoid MSRV that high. +macro_rules! 
impl_div_ceil { + ($name:ident, $t:ty) => { + #[inline(always)] + #[allow(unused_comparisons)] // < 0 checks not required for unsigned + pub const fn $name(lhs: $t, rhs: $t) -> $t { + let d = lhs / rhs; + let r = lhs % rhs; + if (r > 0 && rhs > 0) || (r < 0 && rhs < 0) { + d + 1 + } else { + d + } + } + }; } +impl_div_ceil!(div_ceil_u8, u8); +impl_div_ceil!(div_ceil_usize, usize); diff --git a/src/pack.rs b/src/pack.rs new file mode 100644 index 0000000..c57af6c --- /dev/null +++ b/src/pack.rs @@ -0,0 +1,714 @@ +use crate::coder::Result; +use crate::consume::{consume_byte, consume_byte_arrays, consume_bytes}; +use crate::error::err; +use crate::fast::CowSlice; + +/// Possible states per byte in descending order. Each packed byte will use `log2(states)` bits. +#[repr(u8)] +#[derive(Copy, Clone, PartialEq, PartialOrd)] +enum Packing { + _256 = 0, + _16, + _6, + _4, + _3, + _2, +} + +impl Packing { + fn new(max: u8) -> Self { + match max { + // We could encode max 0 as nothing, but that could allocate unbounded memory when decoding. + 0..=1 => Self::_2, + 2 => Self::_3, + 3 => Self::_4, + 4..=5 => Self::_6, + 6..=15 => Self::_16, + _ => Self::_256, + } + } + + fn write(self, out: &mut Vec, offset_by_min: bool) { + // Encoded in such a way such that 0 is `Self::_256` and higher numbers are smaller packing. + // Also makes `Self::_256` with offset_by_min = true is unrepresentable. + out.push(self as u8 * 2 - offset_by_min as u8); + } + + fn read(input: &mut &[u8]) -> Result<(Self, bool)> { + let v = consume_byte(input)?; + let p_u8 = crate::nightly::div_ceil_u8(v, 2); + let offset_by_min = v & 1 != 0; + let p = match p_u8 { + 0 => Self::_256, + 1 => Self::_16, + 2 => Self::_6, + 3 => Self::_4, + 4 => Self::_3, + 5 => Self::_2, + _ => return invalid_packing(), + }; + debug_assert_eq!(p as u8, p_u8); + Ok((p, offset_by_min)) + } +} + +pub(crate) fn invalid_packing() -> Result { + err("invalid packing") +} + +/// Packs 8 bools per byte. 
+pub fn pack_bools(bools: &[bool], out: &mut Vec) { + pack_arithmetic::<2>(bytemuck::cast_slice(bools), out); +} + +/// Unpacks 8 bools per byte. `out` will be overwritten with the bools. +pub fn unpack_bools(input: &mut &[u8], length: usize, out: &mut CowSlice) -> Result<()> { + // TODO could borrow if length == 1. + let mut set_owned = out.set_owned(); + let out: &mut Vec = &mut set_owned; + // Safety: u8 and bool have same size/align and `out` will only contain bytes that are 0 or 1. + let out: &mut Vec = unsafe { std::mem::transmute(out) }; + unpack_arithmetic::<2>(input, length, out) +} + +/// Packs multiple bytes into single bytes and writes them to `out`. This only works if +/// `max - min < 16`, otherwise this just copies `bytes` to `out`. +/// +/// These particular tradeoffs were selected so input bytes don't span multiple output bytes to +/// avoid confusing bytewise compression algorithms (e.g. Deflate). +/// +/// Mutates `bytes` to avoid copying them. The remaining `bytes` should be considered garbage. +pub fn pack_bytes(bytes: &mut [u8], out: &mut Vec) { + // Pass through bytes.len() <= 1. + match bytes { + [] => return, + [v] => { + out.push(*v); + return; + } + _ => (), + } + + let mut min = 255; + let mut max = 0; + for &v in bytes.iter() { + min = min.min(v); + max = max.max(v); + } + + // If subtracting min from all bytes results in a better packing do it, otherwise don't bother. 
+ let p = Packing::new(max); + let p2 = Packing::new(max - min); + let p = if p2 > p && bytes.len() > 5 { + for b in bytes.iter_mut() { + *b -= min; + } + p2.write(out, true); + out.push(min); + p2 + } else { + p.write(out, false); + p + }; + + match p { + Packing::_256 => out.extend_from_slice(bytes), + Packing::_16 => pack_arithmetic::<16>(bytes, out), + Packing::_6 => pack_arithmetic::<6>(bytes, out), + Packing::_4 => pack_arithmetic::<4>(bytes, out), + Packing::_3 => pack_arithmetic::<3>(bytes, out), + Packing::_2 => pack_arithmetic::<2>(bytes, out), + } +} + +/// Opposite of `pack_bytes`. Needs to know the `length` in bytes. `out` is overwritten with the bytes. +pub fn unpack_bytes<'a>( + input: &mut &'a [u8], + length: usize, + out: &mut CowSlice<'a, u8>, +) -> Result<()> { + // Pass through length <= 1. + match length { + 0 => return Ok(()), + 1 => { + out.set_borrowed(consume_bytes(input, 1)?); + return Ok(()); + } + _ => (), + } + + let (p, offset_by_min) = Packing::read(input)?; + let min = offset_by_min.then(|| consume_byte(input)).transpose()?; + + if p == Packing::_256 { + debug_assert!(min.is_none()); // Packing::_256 with min should be unrepresentable. + out.set_borrowed(consume_bytes(input, length)?); + return Ok(()); + } + + let mut set_owned = out.set_owned(); + let out = &mut *set_owned; + match p { + Packing::_16 => unpack_arithmetic::<16>(input, length, out)?, + Packing::_6 => unpack_arithmetic::<6>(input, length, out)?, + Packing::_4 => unpack_arithmetic::<4>(input, length, out)?, + Packing::_3 => unpack_arithmetic::<3>(input, length, out)?, + Packing::_2 => unpack_arithmetic::<2>(input, length, out)?, + _ => unreachable!(), + } + if let Some(min) = min { + for v in out { + // TODO validate min such that overflow is impossible and numbers like 0 aren't valid. + *v = v.wrapping_add(min); + } + } + Ok(()) +} + +/// Like `pack_bytes` but all values are less than `N` so it can avoid encoding the packing. 
+pub fn pack_bytes_less_than(bytes: &[u8], out: &mut Vec) { + debug_assert!(bytes.iter().all(|&b| (b as usize) < N)); + match Packing::new(N.saturating_sub(1) as u8) { + Packing::_256 => out.extend_from_slice(bytes), + Packing::_16 => pack_arithmetic::<16>(bytes, out), + Packing::_6 => pack_arithmetic::<6>(bytes, out), + Packing::_4 => pack_arithmetic::<4>(bytes, out), + Packing::_3 => pack_arithmetic::<3>(bytes, out), + Packing::_2 => pack_arithmetic::<2>(bytes, out), + } +} + +/// Like `unpack_bytes` but all values are less than `N` so it can avoid encoding the packing. +/// Bytes returned by this function are guaranteed less than `N`. +/// +/// If `HISTOGRAM` is set to `N` it also returns a histogram of the output bytes. This is because +/// the histogram can be calculated much faster when operating on the packed bytes. +/// +/// If `HISTOGRAM` is set to `0` it only checks variants < `N` and doesn't calculate a histogram. +pub fn unpack_bytes_less_than<'a, const N: usize, const HISTOGRAM: usize>( + input: &mut &'a [u8], + length: usize, + out: &mut CowSlice<'a, u8>, +) -> Result<[usize; HISTOGRAM]> { + assert!(HISTOGRAM == N || HISTOGRAM == 0); + + /// Checks that `unpacked` bytes are less than `N`. All of `unpacked` is assumed to be < `FACTOR`. + /// `HISTOGRAM` must be 0. + fn check_less_than( + unpacked: &[u8], + ) -> Result<[usize; HISTOGRAM]> { + assert!(FACTOR >= N); + debug_assert!(unpacked.iter().all(|&v| (v as usize) < FACTOR)); + if FACTOR > N && unpacked.iter().copied().max().unwrap_or(0) as usize >= N { + return invalid_packing(); + } + Ok(std::array::from_fn(|_| unreachable!("HISTOGRAM not 0"))) + } + + /// Returns `Ok(histogram)` if buckets after `OUT` are 0. 
+ fn check_histogram( + histogram: [usize; IN], + ) -> Result<[usize; OUT]> { + let (histogram, remaining) = histogram.split_at(OUT); + if remaining.iter().copied().sum::() != 0 { + return invalid_packing(); + } + Ok(*<&[usize; OUT]>::try_from(histogram).unwrap()) + } + + let p = Packing::new(N.saturating_sub(1) as u8); + if p == Packing::_256 { + let bytes = consume_bytes(input, length)?; + out.set_borrowed(bytes); + return if HISTOGRAM == 0 { + check_less_than::(bytes) + } else { + check_histogram(crate::histogram::histogram(bytes)) + }; + } + + /// `FACTOR_POW_DIVISOR == (FACTOR as usize).pow(factor_to_divisor::() as u32)` but as a constant. + fn unpack_arithmetic_less_than< + const N: usize, + const HISTOGRAM: usize, + const FACTOR: usize, + const FACTOR_POW_DIVISOR: usize, + >( + input: &mut &[u8], + length: usize, + out: &mut Vec, + ) -> Result<[usize; HISTOGRAM]> { + assert!(HISTOGRAM == N || HISTOGRAM == 0); + assert!(FACTOR >= 2 && FACTOR >= N); + let divisor = factor_to_divisor::(); + assert_eq!(FACTOR.pow(divisor as u32), FACTOR_POW_DIVISOR); + + let original_input = *input; + unpack_arithmetic::(input, length, out)?; + if HISTOGRAM == 0 { + check_less_than::(out) + } else { + let floor = length / divisor; + let ceil = crate::nightly::div_ceil_usize(length, divisor); + let whole = &original_input[..floor]; + + // Can only `partial_with_garbage % FACTOR` partial_length times as the rest are undefined garbage. + let partial_length = length - floor * divisor; + let partial_with_garbage = original_input[floor..ceil].first().copied(); + + // POPCNT is much faster than histogram. 
+ let histogram = if FACTOR == 2 { + assert_eq!(N, 2); + assert_eq!(divisor, 8); + let mut one_count = 0; + let mut whole = whole; + while let Ok(chunk) = consume_byte_arrays(&mut whole, 1) { + one_count += u64::from_ne_bytes(chunk[0]).count_ones() as usize; + } + for &byte in whole { + one_count += byte.count_ones() as usize; + } + if let Some(partial_with_garbage) = partial_with_garbage { + // Set undefined garbage bits to zero. + let partial = partial_with_garbage << (divisor - partial_length); + one_count += partial.count_ones() as usize; + } + Ok(std::array::from_fn(|i| match i { + 0 => length - one_count, + 1 => one_count, + _ => unreachable!(), + })) + } else { + check_histogram(if whole.len() < 100 { + // Simple path: histogram of unpacked bytes. + let mut histogram = [0; FACTOR]; + for &v in out.iter() { + // Safety: unpack_arithmetic:: returns bytes < FACTOR. + unsafe { *histogram.get_unchecked_mut(v as usize) += 1 }; + } + histogram + } else { + // High throughput path: histogram of packed bytes (one time cost of ~100ns). + let packed_histogram = check_histogram::<256, FACTOR_POW_DIVISOR>( + crate::histogram::histogram(whole), + )?; + let mut histogram: [_; FACTOR] = unpack_histogram(&packed_histogram); + if let Some(mut partial_with_garbage) = partial_with_garbage { + // .min(divisor) does nothing, it's only improve code gen. 
+ for _ in 0..partial_length.min(divisor) { + histogram[partial_with_garbage as usize % FACTOR] += 1; + partial_with_garbage /= FACTOR as u8; + } + } + histogram + }) + }; + if let Ok(h) = histogram { + debug_assert_eq!( + h, + check_histogram(crate::histogram::histogram(out)).unwrap() + ); + } + histogram + } + } + + let mut set_owned = out.set_owned(); + let out = &mut *set_owned; + match p { + Packing::_16 => unpack_arithmetic_less_than::(input, length, out), + Packing::_6 => unpack_arithmetic_less_than::(input, length, out), + Packing::_4 => unpack_arithmetic_less_than::(input, length, out), + Packing::_3 => unpack_arithmetic_less_than::(input, length, out), + Packing::_2 => unpack_arithmetic_less_than::(input, length, out), + _ => unreachable!(), + } +} + +#[inline(never)] +fn unpack_histogram( + packed_histogram: &[usize; FACTOR_POW_DIVISOR], +) -> [usize; FACTOR] { + let divisor = factor_to_divisor::(); + assert_eq!(FACTOR.pow(divisor as u32), FACTOR_POW_DIVISOR); + std::array::from_fn(|i| { + let mut sum = 0; + for level in 0..divisor { + let width = FACTOR.pow(level as u32); + let runs = FACTOR_POW_DIVISOR / (width * FACTOR); + for run in 0..runs { + let run_start = run * (width * FACTOR) + i * width; + let section = &packed_histogram[run_start..run_start + width]; + sum += section.iter().copied().sum::(); + } + } + sum + }) +} + +#[inline(always)] +fn factor_to_divisor() -> usize { + match FACTOR { + 2 => 8, + 3 => 5, + 4 => 4, + 6 => 3, + 16 => 2, + _ => unreachable!(), + } +} + +/// Packs multiple bytes into one. All the bytes must be < `FACTOR`. +/// Factors 2,4,16 are bit packing. Factors 3,6 are arithmetic coding. 
+fn pack_arithmetic(bytes: &[u8], out: &mut Vec) { + debug_assert!(bytes.iter().all(|&v| v < FACTOR as u8)); + let divisor = factor_to_divisor::(); + + let floor = bytes.len() / divisor; + let ceil = (bytes.len() + (divisor - 1)) / divisor; + + out.reserve(ceil); + let packed = &mut out.spare_capacity_mut()[..ceil]; + + for i in 0..floor { + unsafe { + packed.get_unchecked_mut(i).write( + if FACTOR == 2 && cfg!(all(target_feature = "bmi2", not(miri))) { + // Could use on any pow2 FACTOR, but only 2 is faster (target-cpu=native). + let chunk = (bytes.as_ptr() as *const u8 as *const [u8; 8]).add(i); + let chunk = u64::from_le_bytes(*chunk); + std::arch::x86_64::_pext_u64(chunk, 0x0101010101010101) as u8 + } else { + let mut acc = 0; + for byte_index in 0..divisor { + let byte = *bytes.get_unchecked(i * divisor + byte_index); + acc += byte * (FACTOR as u8).pow(byte_index as u32); + } + acc + }, + ); + } + } + if floor < ceil { + let mut acc = 0; + for &v in bytes[floor * divisor..].iter().rev() { + acc *= FACTOR as u8; + acc += v; + } + packed[floor].write(acc); + } + // Safety: `ceil` elements after len were initialized by loops above. + unsafe { out.set_len(out.len() + ceil) }; +} + +/// Opposite of `pack_arithmetic`. `out` will be overwritten with the unpacked bytes. +fn unpack_arithmetic( + input: &mut &[u8], + unpacked_len: usize, + out: &mut Vec, +) -> Result<()> { + let divisor = factor_to_divisor::(); + + // TODO STRICT: check that packed.all(|&b| b < FACTOR.powi(divisor)). 
+ let floor = unpacked_len / divisor; + let ceil = crate::nightly::div_ceil_usize(unpacked_len, divisor); + let packed = consume_bytes(input, ceil)?; + + debug_assert!(out.is_empty()); + out.reserve(unpacked_len); + let unpacked = &mut out.spare_capacity_mut()[..unpacked_len]; + + for i in 0..floor { + unsafe { + let mut packed = *packed.get_unchecked(i); + if FACTOR == 2 && cfg!(all(target_feature = "bmi2", not(miri))) { + // Could use on any pow2 FACTOR, but only 2 is faster (target-cpu=native). + let chunk = std::arch::x86_64::_pdep_u64(packed as u64, 0x0101010101010101); + *(unpacked.as_mut_ptr() as *mut [u8; 8]).add(i) = chunk.to_le_bytes(); + } else { + for byte in unpacked.get_unchecked_mut(i * divisor..i * divisor + divisor) { + byte.write(packed % FACTOR as u8); + packed /= FACTOR as u8; + } + } + } + } + if floor < ceil { + let mut packed = packed[floor]; + for byte in unpacked[floor * divisor..].iter_mut() { + byte.write(packed % FACTOR as u8); + packed /= FACTOR as u8; + } + } + // Safety: `unpacked_len` elements were initialized by the loops above. 
+ unsafe { out.set_len(unpacked_len) }; + Ok(()) +} + +#[cfg(test)] +mod tests { + use paste::paste; + use test::{black_box, Bencher}; + + #[test] + fn test_pack_bytes() { + fn pack_bytes(bytes: &[u8]) -> Vec { + let mut out = vec![]; + super::pack_bytes(&mut bytes.to_owned(), &mut out); + out + } + + assert!(pack_bytes(&[1, 2, 3, 4, 5, 6, 7]).len() < 7); + assert!(pack_bytes(&[201, 202, 203, 204, 205, 206, 207]).len() < 7); + + for max in 0..255u8 { + for sub in [1, 2, 3, 4, 5, 15, 255] { + let min = max.saturating_sub(sub); + let original = [min, min, min, min, min, min, min, max]; + let packed = pack_bytes(&original); + + let mut slice = packed.as_slice(); + let mut out = crate::fast::CowSlice::default(); + super::unpack_bytes(&mut slice, original.len(), &mut out).unwrap(); + assert!(slice.is_empty()); + assert_eq!(original, unsafe { out.as_slice(original.len()) }); + } + } + } + + fn pack_arithmetic(bytes: &[u8]) -> Vec { + let mut out = vec![]; + super::pack_arithmetic::(bytes, &mut out); + out + } + + #[test] + fn test_pack_arithmetic() { + assert_eq!(pack_arithmetic::<2>(&[1, 0, 1, 0]), [0b0101]); + assert_eq!( + pack_arithmetic::<2>(&[1, 0, 1, 0, 1, 0, 1, 0]), + [0b01010101] + ); + assert_eq!( + pack_arithmetic::<2>(&[1, 0, 1, 0, 1, 0, 1, 0, 1]), + [0b01010101, 0b1] + ); + + assert_eq!(pack_arithmetic::<3>(&[0]), [0]); + assert_eq!(pack_arithmetic::<3>(&[0, 1]), [0 + 1 * 3]); + assert_eq!(pack_arithmetic::<3>(&[0, 1, 2]), [0 + 1 * 3 + 2 * 3 * 3]); + assert_eq!( + pack_arithmetic::<3>(&[2, 0, 0, 0, 0, 0, 1, 2]), + [2, 0 + 1 * 3 + 2 * 3 * 3] + ); + + assert_eq!(pack_arithmetic::<4>(&[1, 0]), [0b0001]); + assert_eq!(pack_arithmetic::<4>(&[1, 0, 1, 0]), [0b00010001]); + assert_eq!( + pack_arithmetic::<4>(&[1, 0, 1, 0, 1, 0]), + [0b00010001, 0b0001] + ); + + assert_eq!(pack_arithmetic::<6>(&[0]), [0]); + assert_eq!(pack_arithmetic::<6>(&[0, 1]), [0 + 1 * 6]); + assert_eq!(pack_arithmetic::<6>(&[0, 1, 2]), [0 + 1 * 6 + 2 * 6 * 6]); + assert_eq!( + 
pack_arithmetic::<6>(&[2, 0, 0, 0, 1, 2]), + [2, 0 + 1 * 6 + 2 * 6 * 6] + ); + + assert_eq!(pack_arithmetic::<16>(&[1]), [0b0001]); + assert_eq!(pack_arithmetic::<16>(&[1, 0]), [0b00000001]); + assert_eq!(pack_arithmetic::<16>(&[1, 0, 1]), [0b00000001, 0b0001]); + } + + #[test] + fn test_unpack_arithmetic() { + fn test(bytes: &[u8]) { + let packed = pack_arithmetic::(bytes); + + let mut input = packed.as_slice(); + let mut bytes2 = vec![]; + super::unpack_arithmetic::(&mut input, bytes.len(), &mut bytes2).unwrap(); + assert!(input.is_empty()); + assert_eq!(bytes, bytes2); + } + + test::<2>(&[1, 0, 1, 0]); + test::<2>(&[1, 0, 1, 0, 1, 0, 1, 0]); + test::<2>(&[1, 0, 1, 0, 1, 0, 1, 0, 1]); + + test::<3>(&[0]); + test::<3>(&[0, 1]); + test::<3>(&[0, 1, 2]); + test::<3>(&[2, 0, 0, 0, 0, 0, 1, 2]); + + test::<4>(&[1, 0]); + test::<4>(&[1, 0, 1, 0]); + test::<4>(&[1, 0, 1, 0, 1, 0]); + + test::<6>(&[0]); + test::<6>(&[0, 1]); + test::<6>(&[0, 1, 2]); + test::<6>(&[2, 0, 0, 0, 1, 2]); + + test::<16>(&[1]); + test::<16>(&[1, 0]); + test::<16>(&[1, 0, 1]); + } + + fn bench_pack_arithmetic(b: &mut Bencher) { + let bytes = vec![0; 1000]; + let mut out = Vec::with_capacity(bytes.len()); + b.iter(|| { + out.clear(); + super::pack_arithmetic::(&bytes, black_box(&mut out)); + }); + } + + fn bench_unpack_arithmetic(b: &mut Bencher) { + let unpacked_len = 1000; + let packed = pack_arithmetic::(&vec![0; unpacked_len]); + let mut out = Vec::with_capacity(unpacked_len); + + b.iter(|| { + let mut input = packed.as_slice(); + let input = black_box(&mut input); + let unpacked_len = black_box(unpacked_len); + out.clear(); + super::unpack_arithmetic::(input, unpacked_len, black_box(&mut out)).unwrap(); + }); + } + + macro_rules! bench_n { + ($bench:ident, $($n:literal),+) => { + paste! 
{ + $( + #[bench] + fn [<$bench $n>](b: &mut Bencher) { + $bench::<$n>(b); + } + )+ + } + } + } + bench_n!(bench_pack_arithmetic, 2, 3, 4, 6, 16); + bench_n!(bench_unpack_arithmetic, 2, 3, 4, 6, 16); + + fn test_pack_bytes_less_than_n() { + for n in [1, 11, 97, 991, 10007].into_iter().flat_map(|n_prime| { + let divisor = if FACTOR == 256 { + 1 + } else { + super::factor_to_divisor::() + }; + let n_factor = crate::nightly::div_ceil_usize(n_prime, divisor) * divisor; + [n_factor, n_prime] + }) { + let bytes: Vec<_> = crate::random_data(n) + .into_iter() + .map(|v: usize| (v % N as usize) as u8) + .collect(); + let n = bytes.len(); // random_data shrinks n on miri. + + println!("n {n}, N {N}, FACTOR {FACTOR}"); + if N != FACTOR { + let mut bytes = bytes.clone(); + bytes[n - 1] = (FACTOR - 1) as u8; // Make least 1 byte is out of bounds. + let mut packed = vec![]; + super::pack_bytes_less_than::(&bytes, &mut packed); + + assert!(super::unpack_bytes_less_than::( + &mut packed.as_slice(), + bytes.len(), + &mut crate::fast::CowSlice::default() + ) + .is_err()); + assert!(super::unpack_bytes_less_than::( + &mut packed.as_slice(), + bytes.len(), + &mut crate::fast::CowSlice::default() + ) + .is_err()); + } + + let mut packed = vec![]; + super::pack_bytes_less_than::(&bytes, &mut packed); + + let mut input = packed.as_slice(); + let mut unpacked = crate::fast::CowSlice::default(); + super::unpack_bytes_less_than::(&mut input, bytes.len(), &mut unpacked).unwrap(); + assert!(input.is_empty()); + assert_eq!(unsafe { unpacked.as_slice(bytes.len()) }, bytes); + + let mut input = packed.as_slice(); + let mut unpacked = crate::fast::CowSlice::default(); + let histogram = + super::unpack_bytes_less_than::(&mut input, bytes.len(), &mut unpacked) + .unwrap(); + assert!(input.is_empty()); + assert_eq!(unsafe { unpacked.as_slice(bytes.len()) }, bytes); + assert_eq!( + histogram.as_slice(), + &crate::histogram::histogram(&bytes)[..N] + ); + } + } + + macro_rules! 
test_pack_bytes_less_than_n { + ($($n:literal => $factor:literal),+) => { + $( + paste::paste! { + #[test] + fn []() { + test_pack_bytes_less_than_n::<$n, $factor>(); + } + } + )+ + } + } + // Test factors and +/- 1 to catch off by 1 errors. + test_pack_bytes_less_than_n!(2 => 2, 3 => 3, 4 => 4, 5 => 6, 6 => 6, 7 => 16); + test_pack_bytes_less_than_n!(15 => 16, 16 => 16, 17 => 256, 255 => 256, 256 => 256); + + macro_rules! bench_unpack_histogram { + ($($f:literal => $fpd:literal),+) => { + $( + paste::paste! { + #[bench] + fn [](b: &mut Bencher) { + b.iter(|| { + super::unpack_histogram::<$f, $fpd>(black_box(&[0; $fpd])) + }); + } + } + )+ + } + } + bench_unpack_histogram!(3 => 243, 4 => 256, 6 => 216, 16 => 256); + + macro_rules! bench_unpack_bytes_less_than { + ($($n:literal),+) => { + $( + paste::paste! { + #[bench] + fn [](b: &mut Bencher) { + let mut out = crate::fast::CowSlice::default(); + b.iter(|| { + super::unpack_bytes_less_than::<$n, 0>(black_box(&mut [0].as_slice()), black_box(1), black_box(&mut out)).unwrap(); + }); + } + + #[bench] + fn [](b: &mut Bencher) { + let mut out = crate::fast::CowSlice::default(); + b.iter(|| { + super::unpack_bytes_less_than::<$n, $n>(black_box(&mut [0].as_slice()), black_box(1), black_box(&mut out)).unwrap(); + }); + } + } + )+ + } + } + bench_unpack_bytes_less_than!(2, 3, 4, 6, 16, 256); +} diff --git a/src/pack_ints.rs b/src/pack_ints.rs new file mode 100644 index 0000000..41b11bd --- /dev/null +++ b/src/pack_ints.rs @@ -0,0 +1,461 @@ +use crate::coder::Result; +use crate::consume::{consume_byte, consume_byte_arrays}; +use crate::fast::CowSlice; +use crate::pack::{invalid_packing, pack_bytes, unpack_bytes}; +use bytemuck::Pod; + +/// Possible integer sizes in descending order. +/// TODO consider nonstandard sizes like 24. 
+#[repr(u8)] +#[derive(Copy, Clone, PartialEq, PartialOrd)] +enum Packing { + _128 = 0, + _64, + _32, + _16, + _8, +} + +impl Packing { + fn new(max: T) -> Self { + let max: u128 = max.into(); + #[allow(clippy::match_overlapping_arm)] // Just make sure not to reorder them. + match max { + ..=0xFF => Self::_8, + ..=0xFF_FF => Self::_16, + ..=0xFF_FF_FF_FF => Self::_32, + ..=0xFF_FF_FF_FF_FF_FF_FF_FF => Self::_64, + _ => Self::_128, + } + } + + fn write(self, out: &mut Vec, offset_by_min: bool) { + // Encoded in such a way such that 0 is no packing and higher numbers are smaller packing. + // Also makes no packing with offset_by_min = true is unrepresentable. + out.push((self as u8 - Self::new(T::MAX) as u8) * 2 - offset_by_min as u8); + } + + fn read(input: &mut &[u8]) -> Result<(Self, bool)> { + let v = consume_byte(input)?; + let p_u8 = crate::nightly::div_ceil_u8(v, 2) + Self::new(T::MAX) as u8; + let offset_by_min = v & 1 != 0; + let p = match p_u8 { + 0 => Self::_128, + 1 => Self::_64, + 2 => Self::_32, + 3 => Self::_16, + 4 => Self::_8, + _ => return invalid_packing(), + }; + debug_assert_eq!(p as u8, p_u8); + Ok((p, offset_by_min)) + } +} + +// Default bound makes #[derive(Default)] on IntEncoder/IntDecoder work. +pub trait Int: + Copy + Default + Into + Ord + Pod + Sized + std::ops::Sub + std::ops::SubAssign +{ + type Ule: Pod + Default; // Unaligned little endian. 
+ const MIN: Self; + const MAX: Self; + fn read(input: &mut &[u8]) -> Result; + fn write(v: Self, out: &mut Vec); + fn wrapping_add(lhs: Self::Ule, rhs: Self::Ule) -> Self::Ule; + #[cfg(test)] + fn from_unaligned(unaligned: Self::Ule) -> Self; + fn pack128(v: &[Self], out: &mut Vec); + fn pack64(v: &[Self], out: &mut Vec); + fn pack32(v: &[Self], out: &mut Vec); + fn pack16(v: &[Self], out: &mut Vec); + fn pack8(v: &mut [Self], out: &mut Vec); + fn unpack128<'a>(v: &'a [[u8; 16]], out: &mut CowSlice<'a, Self::Ule>) -> Result<()>; + fn unpack64<'a>(v: &'a [[u8; 8]], out: &mut CowSlice<'a, Self::Ule>) -> Result<()>; + fn unpack32<'a>(v: &'a [[u8; 4]], out: &mut CowSlice<'a, Self::Ule>) -> Result<()>; + fn unpack16<'a>(v: &'a [[u8; 2]], out: &mut CowSlice<'a, Self::Ule>) -> Result<()>; + fn unpack8<'a>( + input: &mut &'a [u8], + length: usize, + out: &mut CowSlice<'a, Self::Ule>, + ) -> Result<()>; +} + +macro_rules! impl_simple { + () => { + type Ule = [u8; std::mem::size_of::()]; + const MIN: Self = Self::MIN; + const MAX: Self = Self::MAX; + fn read(input: &mut &[u8]) -> Result { + Ok(consume_byte_arrays(input, 1)?[0]) + } + fn write(v: Self, out: &mut Vec) { + out.extend_from_slice(&v.to_le_bytes()); + } + fn wrapping_add(lhs: Self::Ule, rhs: Self::Ule) -> Self::Ule { + Self::from_le_bytes(lhs) + .wrapping_add(Self::from_le_bytes(rhs)) + .to_le_bytes() + } + #[cfg(test)] + fn from_unaligned(unaligned: Self::Ule) -> Self { + Self::from_le_bytes(unaligned) + } + }; +} +macro_rules! impl_unreachable { + ($t:ty, $pack:ident, $unpack:ident) => { + fn $pack(_: &[Self], _: &mut Vec) { + unimplemented!(); + } + fn $unpack<'a>(_: &'a [<$t as Int>::Ule], _: &mut CowSlice<'a, Self::Ule>) -> Result<()> { + invalid_packing() + } + }; +} +macro_rules! impl_self { + ($pack:ident, $unpack:ident) => { + fn $pack(v: &[Self], out: &mut Vec) { + out.extend_from_slice(bytemuck::cast_slice(&v)) // TODO big endian swap bytes. 
+ } + fn $unpack<'a>(v: &'a [Self::Ule], out: &mut CowSlice<'a, Self::Ule>) -> Result<()> { + out.set_borrowed(v); + Ok(()) + } + }; +} +macro_rules! impl_smaller { + ($t:ty, $pack:ident, $unpack:ident) => { + fn $pack(v: &[Self], out: &mut Vec) { + out.extend(v.iter().flat_map(|&v| (v as $t).to_le_bytes())) + } + fn $unpack<'a>(v: &'a [<$t as Int>::Ule], out: &mut CowSlice<'a, Self::Ule>) -> Result<()> { + let mut set_owned = out.set_owned(); + set_owned.extend( + v.iter() + .map(|&v| (<$t>::from_le_bytes(v) as Self).to_le_bytes()), + ); + Ok(()) + } + }; +} + +// Scratch space to bridge gap between pack_ints and pack_bytes. +// In theory, we could avoid this intermediate step, but it would result in a lot of generated code. +fn with_scratch(f: impl FnOnce(&mut Vec) -> T) -> T { + thread_local! { + static SCRATCH: std::cell::RefCell> = Default::default(); + } + SCRATCH.with(|s| { + let s = &mut s.borrow_mut(); + s.clear(); + f(s) + }) +} +macro_rules! impl_u8 { + () => { + fn pack8(v: &mut [Self], out: &mut Vec) { + with_scratch(|bytes| { + bytes.extend(v.iter().map(|&v| v as u8)); + pack_bytes(bytes, out); + }) + } + fn unpack8(input: &mut &[u8], length: usize, out: &mut CowSlice) -> Result<()> { + with_scratch(|allocation| { + // unpack_bytes might not result in a copy, but if it does we want to avoid an allocation. + let mut bytes = CowSlice::with_allocation(std::mem::take(allocation)); + unpack_bytes(input, length, &mut bytes)?; + // Safety: unpack_bytes ensures bytes has length of `length`. 
+ let slice = unsafe { bytes.as_slice(length) }; + out.set_owned() + .extend(slice.iter().map(|&v| (v as Self).to_le_bytes())); + *allocation = bytes.into_allocation(); + Ok(()) + }) + } + }; +} + +impl Int for u128 { + impl_simple!(); + impl_self!(pack128, unpack128); + impl_smaller!(u64, pack64, unpack64); + impl_smaller!(u32, pack32, unpack32); + impl_smaller!(u16, pack16, unpack16); + impl_u8!(); +} +impl Int for u64 { + impl_simple!(); + impl_unreachable!(u128, pack128, unpack128); + impl_self!(pack64, unpack64); + impl_smaller!(u32, pack32, unpack32); + impl_smaller!(u16, pack16, unpack16); + impl_u8!(); +} +impl Int for u32 { + impl_simple!(); + impl_unreachable!(u128, pack128, unpack128); + impl_unreachable!(u64, pack64, unpack64); + impl_self!(pack32, unpack32); + impl_smaller!(u16, pack16, unpack16); + impl_u8!(); +} +impl Int for u16 { + impl_simple!(); + impl_unreachable!(u128, pack128, unpack128); + impl_unreachable!(u64, pack64, unpack64); + impl_unreachable!(u32, pack32, unpack32); + impl_self!(pack16, unpack16); + impl_u8!(); +} +impl Int for u8 { + impl_simple!(); + impl_unreachable!(u128, pack128, unpack128); + impl_unreachable!(u64, pack64, unpack64); + impl_unreachable!(u32, pack32, unpack32); + impl_unreachable!(u16, pack16, unpack16); + // Doesn't use impl_u8!() because it would copy unnecessary. + fn pack8(v: &mut [Self], out: &mut Vec) { + pack_bytes(v, out); + } + fn unpack8(input: &mut &[u8], length: usize, out: &mut CowSlice<[u8; 1]>) -> Result<()> { + // Safety: [u8; 1] and u8 are the same from the perspective of CowSlice. + let out: &mut CowSlice = unsafe { std::mem::transmute(out) }; + unpack_bytes(input, length, out) + } +} + +fn minmax(v: &[T]) -> (T, T) { + let mut min = T::MAX; + let mut max = T::MIN; + for &v in v.iter() { + min = min.min(v); + max = max.max(v); + } + (min, max) +} + +/// Like [`pack_bytes`] but for larger integers. 
+pub fn pack_ints(ints: &mut [T], out: &mut Vec) { + // Passes through u8s and length <= 1 since they can't be compressed. + let p = if std::mem::size_of::() == 1 || ints.len() <= 1 { + Packing::new(T::MAX) + } else { + // Take a small sample to avoid wastefully scanning the whole slice. + let (sample, remaining) = ints.split_at(ints.len().min(16)); + let (min, max) = minmax(sample); + + // Only have to check packing(max - min) since it's always as good as just packing(max). + let none = Packing::new(T::MAX); + if Packing::new(max - min) == none { + none.write::(out, false); + none + } else { + let (remaining_min, remaining_max) = minmax(remaining); + let min = min.min(remaining_min); + let max = max.max(remaining_max); + + // If subtracting min from all ints results in a better packing do it, otherwise don't bother. + // TODO ensure packing never expands data unnecessarily. + let p = Packing::new(max); + let p2 = Packing::new(max - min); + if p2 > p && ints.len() > 5 { + for b in ints.iter_mut() { + *b -= min; + } + p2.write::(out, true); + T::write(min, out); + p2 + } else { + p.write::(out, false); + p + } + } + }; + + match p { + Packing::_128 => T::pack128(ints, out), + Packing::_64 => T::pack64(ints, out), + Packing::_32 => T::pack32(ints, out), + Packing::_16 => T::pack16(ints, out), + Packing::_8 => T::pack8(ints, out), + } +} + +/// Opposite of [`pack_ints`]. Unpacks into `T::Ule` aka unaligned little endian. +pub fn unpack_ints<'a, T: Int>( + input: &mut &'a [u8], + length: usize, + out: &mut CowSlice<'a, T::Ule>, +) -> Result<()> { + // Passes through u8s and length <= 1 since they can't be compressed. + let (p, min) = if std::mem::size_of::() == 1 || length <= 1 { + (Packing::new(T::MAX), None) + } else { + let (p, offset_by_min) = Packing::read::(input)?; + (p, offset_by_min.then(|| T::read(input)).transpose()?) 
+ }; + + match p { + Packing::_128 => T::unpack128(consume_byte_arrays(input, length)?, out), + Packing::_64 => T::unpack64(consume_byte_arrays(input, length)?, out), + Packing::_32 => T::unpack32(consume_byte_arrays(input, length)?, out), + Packing::_16 => T::unpack16(consume_byte_arrays(input, length)?, out), + Packing::_8 => T::unpack8(input, length, out), + }?; + if let Some(min) = min { + // Has to be owned to have min. + out.mut_owned(|out| { + for v in out { + *v = T::wrapping_add(*v, min); + } + }) + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::{pack_ints, unpack_ints, CowSlice, Int}; + use std::fmt::Debug; + use test::{black_box, Bencher}; + + fn t(ints: &[T]) -> Vec { + let mut out = vec![]; + pack_ints(&mut ints.to_owned(), &mut out); + + let mut slice = out.as_slice(); + let mut unpacked = CowSlice::default(); + let length = ints.len(); + unpack_ints::(&mut slice, length, &mut unpacked).unwrap(); + let unpacked = unsafe { unpacked.as_slice(length) }; + let unpacked: Vec<_> = unpacked.iter().copied().map(T::from_unaligned).collect(); + assert_eq!(unpacked, ints); + assert!(slice.is_empty()); + + let packing = out[0]; + let size = 100.0 * out.len() as f32 / std::mem::size_of_val(ints) as f32; + println!("{packing} {size:>5.1}%"); + out + } + + #[rustfmt::skip] + macro_rules! 
test { + ($name:ident, $t:ty) => { + #[test] + fn $name() { + type T = $t; + for increment in [0, 1, u8::MAX as u128 + 1, u16::MAX as u128 + 1, u32::MAX as u128 + 1, u64::MAX as u128 + 1] { + let Ok(increment) = T::try_from(increment) else { + continue; + }; + + for max in [0, u8::MAX as u128, u16::MAX as u128, u32::MAX as u128, u64::MAX as u128, u128::MAX as u128] { + let Ok(start) = T::try_from(max / 2) else { + continue; + }; + let s = format!("{start} {increment}"); + print!("{s:<25} => "); + t::(&std::array::from_fn::<_, 100, _>(|i| { + start + i as T * increment + })); + } + } + } + }; + } + test!(test_u008, u8); + test!(test_u016, u16); + test!(test_u032, u32); + test!(test_u064, u64); + test!(test_u128, u128); + + fn bench_pack_ints(b: &mut Bencher, src: &[T]) { + let mut ints = src.to_vec(); + let mut out = Vec::with_capacity(std::mem::size_of_val(src) + 10); + let starting_cap = out.capacity(); + b.iter(|| { + ints.copy_from_slice(&src); + out.clear(); + pack_ints(black_box(&mut ints), black_box(&mut out)); + }); + assert_eq!(out.capacity(), starting_cap); + } + + fn bench_unpack_ints(b: &mut Bencher, src: &[T]) { + let mut packed = vec![]; + pack_ints(&mut src.to_vec(), &mut packed); + let mut out = CowSlice::with_allocation(Vec::::with_capacity(src.len())); + b.iter(|| { + let length = src.len(); + unpack_ints::( + black_box(&mut packed.as_slice()), + length, + black_box(&mut out), + ) + .unwrap(); + debug_assert_eq!( + unsafe { out.as_slice(length) } + .iter() + .copied() + .map(T::from_unaligned) + .collect::>(), + src + ); + }); + } + + macro_rules! bench { + ($name:ident, $t:ident) => { + paste::paste! 
{ + #[bench] + fn [](b: &mut Bencher) { + bench_pack_ints::<$t>(b, &[0; 1000]); + } + + #[bench] + fn [](b: &mut Bencher) { + bench_pack_ints::<$t>(b, &[$t::MAX; 1000]); + } + + #[bench] + fn [](b: &mut Bencher) { + bench_pack_ints::<$t>(b, &crate::random_data(1000)); + } + + #[bench] + fn [](b: &mut Bencher) { + let src = vec![$t::MIN; 1000]; + let mut ints = src.clone(); + let mut out: Vec = Vec::with_capacity(std::mem::size_of_val(&ints) + 10); + b.iter(|| { + ints.copy_from_slice(&src); + let input = black_box(&mut ints); + out.clear(); + let out = black_box(&mut out); + out.extend_from_slice(bytemuck::cast_slice(&input)); + }); + } + + #[bench] + fn [](b: &mut Bencher) { + bench_unpack_ints::<$t>(b, &[0; 1000]); + } + + #[bench] + fn [](b: &mut Bencher) { + bench_unpack_ints::<$t>(b, &[$t::MAX; 1000]); + } + + #[bench] + fn [](b: &mut Bencher) { + bench_unpack_ints::<$t>(b, &crate::random_data(1000)); + } + } + }; + } + bench!(u008, u8); + bench!(u016, u16); + bench!(u032, u32); + bench!(u064, u64); + bench!(u128, u128); +} diff --git a/src/read.rs b/src/read.rs deleted file mode 100644 index 8df0c08..0000000 --- a/src/read.rs +++ /dev/null @@ -1,183 +0,0 @@ -use crate::encoding::ByteEncoding; -use crate::word::Word; -use crate::Result; -use std::num::NonZeroUsize; - -/// Abstracts over reading bits from a buffer. -pub trait Read { - /// Advances any amount of bits. Must never fail. - fn advance(&mut self, bits: usize); - /// Peeks 64 bits without reading them. Bits after EOF are zeroed. - fn peek_bits(&mut self) -> Result; - // Reads 1 bit. - fn read_bit(&mut self) -> Result; - /// Reads up to 64 bits. `bits` must be in range `1..=64`. - fn read_bits(&mut self, bits: usize) -> Result; - /// Reads `len` bytes. - fn read_bytes(&mut self, len: NonZeroUsize) -> Result<&[u8]>; - /// Reads `len` with a [`ByteEncoding`]. - fn read_encoded_bytes(&mut self, len: NonZeroUsize) -> Result<&[u8]>; - /// Ensures that at least `bits` remain. 
Never underreports remaining bits. - fn reserve_bits(&self, bits: usize) -> Result<()>; -} - -#[cfg(all(test, not(miri)))] -mod tests { - use crate::bit_buffer::BitBuffer; - use crate::buffer::BufferTrait; - use crate::nightly::div_ceil; - use crate::read::Read; - use crate::word_buffer::WordBuffer; - use paste::paste; - use std::num::NonZeroUsize; - use test::{black_box, Bencher}; - - fn bench_start_read(b: &mut Bencher) { - let bytes = vec![123u8; 6659]; - let mut buf = T::default(); - - b.iter(|| { - black_box(buf.start_read(black_box(bytes.as_slice()))); - }); - } - - // How many times each benchmark calls the function. - const TIMES: usize = 1000; - - fn bench_read_bit(b: &mut Bencher) { - let bytes = vec![123u8; div_ceil(TIMES, u8::BITS as usize)]; - let mut buf = T::default(); - let _ = buf.start_read(&bytes); - - b.iter(|| { - let buf = black_box(&mut buf); - let (mut reader, _) = buf.start_read(black_box(&bytes)); - for _ in 0..black_box(TIMES) { - black_box(reader.read_bit().unwrap()); - } - }); - } - - fn bench_read_bits(b: &mut Bencher, bits: usize) { - let bytes = vec![123u8; div_ceil(bits * TIMES, u8::BITS as usize)]; - let mut buf = T::default(); - let _ = buf.start_read(&bytes); - - b.iter(|| { - let buf = black_box(&mut buf); - let (mut reader, _) = buf.start_read(black_box(&bytes)); - for _ in 0..black_box(TIMES) { - black_box(reader.read_bits(bits).unwrap()); - } - }); - } - - fn bench_read_bytes(b: &mut Bencher, byte_count: usize) { - let bytes = vec![123u8; byte_count * TIMES + 1]; - let mut buf = T::default(); - let _ = buf.start_read(&bytes); - - let byte_count = NonZeroUsize::new(byte_count).expect("must be >= 0"); - b.iter(|| { - let buf = black_box(&mut buf); - let (mut reader, _) = buf.start_read(black_box(&bytes)); - reader.read_bit().unwrap(); // Make read_bytes unaligned. 
- for _ in 0..black_box(TIMES) { - black_box(reader.read_bytes(byte_count).unwrap()); - } - }); - } - - fn random_lengths(min: NonZeroUsize, max: NonZeroUsize) -> Vec { - use rand::prelude::*; - let mut rng = rand_chacha::ChaCha20Rng::from_seed(Default::default()); - - (0..TIMES) - .map(|_| NonZeroUsize::new(rng.gen_range(min.get()..=max.get())).unwrap()) - .collect() - } - - fn bench_read_bytes_range(b: &mut Bencher, min: usize, max: usize) { - let min = NonZeroUsize::new(min).expect("must be >= 0"); - let max = NonZeroUsize::new(max).expect("must be >= 0"); - - let lens = random_lengths(min, max); - let total_len: usize = lens.iter().map(|l| l.get()).sum(); - let bytes = vec![123u8; total_len + 1]; - - let mut buf = T::default(); - let _ = buf.start_read(&bytes); - - b.iter(|| { - let buf = black_box(&mut buf); - let (mut reader, _) = buf.start_read(black_box(&bytes)); - reader.read_bit().unwrap(); // Make read_bytes unaligned. - for &len in black_box(lens.as_slice()) { - black_box(reader.read_bytes(len).unwrap()); - } - }); - } - - macro_rules! bench_read_bits { - ($name:ident, $T:ty, $n:literal) => { - paste! { - #[bench] - fn [](b: &mut Bencher) { - bench_read_bits::<$T>(b, $n); - } - } - }; - } - - macro_rules! bench_read_bytes { - ($name:ident, $T:ty, $n:literal) => { - paste! { - #[bench] - fn [](b: &mut Bencher) { - bench_read_bytes::<$T>(b, $n); - } - } - }; - } - - macro_rules! bench_read_bytes_range { - ($name:ident, $T:ty, $min:literal, $max:literal) => { - paste! { - #[bench] - fn [](b: &mut Bencher) { - bench_read_bytes_range::<$T>(b, $min, $max); - } - } - }; - } - - macro_rules! bench_read { - ($name:ident, $T:ty) => { - paste! 
{ - #[bench] - fn [](b: &mut Bencher) { - bench_start_read::<$T>(b); - } - - #[bench] - fn [](b: &mut Bencher) { - bench_read_bit::<$T>(b); - } - } - - bench_read_bits!($name, $T, 5); - bench_read_bits!($name, $T, 41); - bench_read_bytes!($name, $T, 1); - bench_read_bytes!($name, $T, 10); - bench_read_bytes!($name, $T, 100); - bench_read_bytes!($name, $T, 1000); - bench_read_bytes!($name, $T, 10000); - - bench_read_bytes_range!($name, $T, 1, 8); - bench_read_bytes_range!($name, $T, 1, 16); - }; - } - - bench_read!(bit_buffer, BitBuffer); - bench_read!(word_buffer, WordBuffer); -} diff --git a/src/register_buffer.rs b/src/register_buffer.rs deleted file mode 100644 index e346d52..0000000 --- a/src/register_buffer.rs +++ /dev/null @@ -1,157 +0,0 @@ -use crate::encoding::ByteEncoding; -use crate::read::Read; -use crate::word::*; -use crate::write::Write; -use crate::Result; -use std::num::NonZeroUsize; - -/// A writer that can only hold 64 bits, but only uses registers instead of load/store. -pub struct RegisterWriter<'a, W: Write> { - pub writer: &'a mut W, - pub inner: Register, -} - -impl<'a, W: Write> RegisterWriter<'a, W> { - pub fn new(writer: &'a mut W) -> Self { - Self { - writer, - inner: Register::EMPTY, - } - } -} - -impl<'a, W: Write> RegisterWriter<'a, W> { - /// Writes the contents of the buffer to `writer` and clears the buffer. - #[inline(always)] - pub fn flush(&mut self) { - debug_assert!( - self.inner.index <= 64, - "too many bits written to RegisterBuffer" - ); - self.writer.write_bits(self.inner.value, self.inner.index); - self.inner = Register::EMPTY; - } -} - -/// A reader that can only hold 64 bits, but only uses registers instead of loads. -pub struct RegisterReader<'a, R: Read> { - pub reader: &'a mut R, - pub inner: Register, -} - -// The purpose of this drop impl is to advance the reader if we encounter an error to check for EOF. -// Since all errors are equal when debug_assertions is off, we don't care if the error is EOF or not. 
-#[cfg(debug_assertions)] -impl<'a, R: Read> Drop for RegisterReader<'a, R> { - fn drop(&mut self) { - self.advance_reader(); - } -} - -impl<'a, R: Read> RegisterReader<'a, R> { - pub fn new(reader: &'a mut R) -> Self { - Self { - reader, - inner: Register::EMPTY, - } - } - - /// Only advances the reader. Doesn't refill the buffer. - #[inline(always)] - pub fn advance_reader(&mut self) { - debug_assert!( - self.inner.index <= 64, - "too many bits read from RegisterBuffer" - ); - self.reader.advance(self.inner.index); - self.inner = Register::EMPTY; - } - - /// Advances the reader and refills the buffer. - #[inline(always)] - pub fn refill(&mut self) -> Result<()> { - self.advance_reader(); - self.inner.value = self.reader.peek_bits()?; - self.inner.index = 0; - Ok(()) - } -} - -/// The inner part of [`RegisterWriter`] or [`RegisterReader`]. Allows recursive types to compile -/// because their reader's type doesn't depend on their input reader's type. -#[derive(Copy, Clone)] -pub struct Register { - value: Word, - index: usize, -} - -impl Register { - const EMPTY: Self = Self { value: 0, index: 0 }; -} - -impl Write for Register { - type Revert = (); - fn get_revert(&mut self) -> Self::Revert { - unimplemented!() - } - fn revert(&mut self, _: Self::Revert) { - unimplemented!() - } - - #[inline(always)] - fn write_bit(&mut self, v: bool) { - self.write_bits(v as Word, 1); - } - - #[inline(always)] - fn write_bits(&mut self, word: Word, bits: usize) { - self.value |= word << self.index; - self.index += bits; - } - - fn write_bytes(&mut self, _: &[u8]) { - unimplemented!() - } - - fn write_encoded_bytes(&mut self, _: &[u8]) -> bool { - unimplemented!() - } -} - -impl Read for Register { - #[inline(always)] - fn advance(&mut self, bits: usize) { - self.index += bits; - } - - #[inline(always)] - fn peek_bits(&mut self) -> Result { - debug_assert!(self.index < 64); - let v = self.value >> self.index; - Ok(v) - } - - #[inline(always)] - fn read_bit(&mut self) -> Result { - 
Ok(self.read_bits(1)? != 0) - } - - #[inline(always)] - fn read_bits(&mut self, bits: usize) -> Result { - let v = self.peek_bits()? & (Word::MAX >> (WORD_BITS - bits)); - self.advance(bits); - Ok(v) - } - - fn read_bytes(&mut self, _: NonZeroUsize) -> Result<&[u8]> { - unimplemented!() - } - - fn read_encoded_bytes(&mut self, _: NonZeroUsize) -> Result<&[u8]> { - unimplemented!() - } - - fn reserve_bits(&self, _: usize) -> Result<()> { - unimplemented!() - } -} diff --git a/src/serde/de.rs b/src/serde/de.rs index f8e4bfb..23913c2 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -1,244 +1,424 @@ -use crate::buffer::BufferTrait; -use crate::encoding::{Encoding, Fixed, Gamma}; -use crate::guard::guard_zst; -use crate::read::Read; -use crate::{Decode, Error, Result, E}; +use crate::bool::BoolDecoder; +use crate::coder::{Decoder, Result, View}; +use crate::consume::expect_eof; +use crate::error::{err, error, Error}; +use crate::int::IntDecoder; +use crate::length::LengthDecoder; +use crate::serde::guard::guard_zst; +use crate::serde::variant::VariantDecoder; +use crate::serde::{default_box_slice, get_mut_or_resize, type_changed}; +use crate::str::StrDecoder; use serde::de::{ - DeserializeOwned, DeserializeSeed, EnumAccess, IntoDeserializer, MapAccess, SeqAccess, - VariantAccess, Visitor, + DeserializeSeed, EnumAccess, IntoDeserializer, MapAccess, SeqAccess, VariantAccess, Visitor, }; -use serde::Deserializer; -use std::num::NonZeroUsize; - -pub fn deserialize_internal( - buffer: &mut B, - bytes: &[u8], -) -> Result { - let (mut reader, context) = buffer.start_read(bytes); - let decode_result = deserialize_compat(Fixed, &mut reader); - B::finish_read_with_result(reader, context, decode_result) +use serde::{Deserialize, Deserializer}; + +// Redefine Result from crate::coder::Result to std::result::Result since the former isn't public. 
+mod inner { + use super::*; + use std::result::Result; + + /// Deserializes a [`&[u8]`][`prim@slice`] into an instance of `T:` [`Deserialize`]. + /// + /// **Warning:** The format is incompatible with [`encode`][`crate::encode`] and subject to + /// change between versions. + #[cfg_attr(doc, doc(cfg(feature = "serde")))] + pub fn deserialize<'de, T: Deserialize<'de>>(mut bytes: &'de [u8]) -> Result { + let mut decoder = SerdeDecoder::Unspecified2 { length: 1 }; + let t = T::deserialize(DecoderWrapper { + decoder: &mut decoder, + input: &mut bytes, + })?; + expect_eof(bytes)?; + Ok(t) + } } - -pub fn deserialize_compat( - encoding: impl Encoding, - reader: &mut impl Read, -) -> Result { - T::deserialize(BitcodeDeserializer { encoding, reader }) +pub use inner::deserialize; + +#[derive(Debug)] +enum SerdeDecoder<'a> { + Bool(BoolDecoder<'a>), + Enum(Box<(VariantDecoder<'a>, Vec>)>), // (variants, values) TODO only 1 allocation? + Map(Box<(LengthDecoder<'a>, (SerdeDecoder<'a>, SerdeDecoder<'a>))>), // (lengths, (keys, values)) + Seq(Box<(LengthDecoder<'a>, SerdeDecoder<'a>)>), // (lengths, values) + Str(StrDecoder<'a>), + Tuple(Box<[SerdeDecoder<'a>]>), // [field0, field1, ..] + U8(IntDecoder<'a, u8>), + U16(IntDecoder<'a, u16>), + U32(IntDecoder<'a, u32>), + U64(IntDecoder<'a, u64>), + U128(IntDecoder<'a, u128>), + Unspecified, + Unspecified2 { length: usize }, } -struct BitcodeDeserializer<'a, C, R> { - encoding: C, - reader: &'a mut R, +impl Default for SerdeDecoder<'_> { + fn default() -> Self { + Self::Unspecified + } } -macro_rules! 
reborrow { - ($e:expr) => { - BitcodeDeserializer { - encoding: $e.encoding, - reader: &mut *$e.reader, +impl<'a> View<'a> for SerdeDecoder<'a> { + fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> { + match self { + Self::Bool(d) => d.populate(input, length), + Self::Enum(d) => { + d.0.populate(input, length)?; + if let Some(max_variant_index) = d.0.max_variant_index() { + get_mut_or_resize(&mut d.1, max_variant_index as usize); + d.1.iter_mut() + .enumerate() + .try_for_each(|(i, variant)| variant.populate(input, d.0.length(i as u8))) + } else { + Ok(()) + } + } + Self::Map(d) => { + d.0.populate(input, length)?; + let length = d.0.length(); + d.1 .0.populate(input, length)?; + d.1 .1.populate(input, length) + } + Self::Seq(d) => { + d.0.populate(input, length)?; + let length = d.0.length(); + d.1.populate(input, length) + } + Self::Str(d) => d.populate(input, length), + Self::Tuple(d) => d.iter_mut().try_for_each(|d| d.populate(input, length)), + Self::U8(d) => d.populate(input, length), + Self::U16(d) => d.populate(input, length), + Self::U32(d) => d.populate(input, length), + Self::U64(d) => d.populate(input, length), + Self::U128(d) => d.populate(input, length), + Self::Unspecified => { + *self = Self::Unspecified2 { length }; + Ok(()) + } + Self::Unspecified2 { .. } => unreachable!(), // TODO } } } -impl BitcodeDeserializer<'_, C, R> { - fn read_len(self) -> Result { - usize::decode(Gamma, self.reader) - } +struct DecoderWrapper<'a, 'de> { + decoder: &'a mut SerdeDecoder<'de>, + input: &'a mut &'de [u8], +} - fn read_variant_index(self) -> Result { - u32::decode(Gamma, self.reader).map_err(|e| e.map_invalid("variant index")) - } +macro_rules! 
specify { + ($self:ident, $variant:ident) => { + match &mut *$self.decoder { + &mut SerdeDecoder::Unspecified2 { length } => { + #[cold] + fn cold(me: &mut DecoderWrapper, length: usize) -> Result<()> { + *me.decoder = SerdeDecoder::$variant(Default::default()); + me.decoder.populate(me.input, length) + } + cold(&mut $self, length)?; + } + _ => (), + } + }; } macro_rules! impl_de { - ($name:ident, $visit:ident) => { - fn $name(self, visitor: V) -> Result + ($deserialize:ident, $visit:ident, $t:ty, $variant:ident) => { + #[inline(always)] + fn $deserialize(mut self, v: V) -> Result where V: Visitor<'de>, { - visitor.$visit(Decode::decode(self.encoding, self.reader)?) + v.$visit({ + specify!(self, $variant); + match &mut *self.decoder { + SerdeDecoder::$variant(d) => d.decode(), + _ => return type_changed(), + } + }) } }; } -impl<'de, C: Encoding, R: Read> Deserializer<'de> for BitcodeDeserializer<'_, C, R> { +impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { type Error = Error; - fn deserialize_any(self, _visitor: V) -> Result + fn deserialize_any(self, _: V) -> Result where V: Visitor<'de>, { - Err(E::NotSupported("deserialize_any").e()) + err("deserialize_any is not supported") } - impl_de!(deserialize_bool, visit_bool); - impl_de!(deserialize_i8, visit_i8); - impl_de!(deserialize_i16, visit_i16); - impl_de!(deserialize_i32, visit_i32); - impl_de!(deserialize_i64, visit_i64); - impl_de!(deserialize_i128, visit_i128); - impl_de!(deserialize_u8, visit_u8); - impl_de!(deserialize_u16, visit_u16); - impl_de!(deserialize_u32, visit_u32); - impl_de!(deserialize_u64, visit_u64); - impl_de!(deserialize_u128, visit_u128); - impl_de!(deserialize_f32, visit_f32); - impl_de!(deserialize_f64, visit_f64); - impl_de!(deserialize_char, visit_char); - impl_de!(deserialize_string, visit_string); - - #[inline(always)] // Makes #[bitcode(with_serde)] ArrayString much faster. - fn deserialize_str(self, visitor: V) -> Result + // Use native decoders. 
+ impl_de!(deserialize_bool, visit_bool, bool, Bool); + impl_de!(deserialize_u8, visit_u8, u8, U8); + impl_de!(deserialize_u16, visit_u16, u16, U16); + impl_de!(deserialize_u32, visit_u32, u32, U32); + impl_de!(deserialize_u64, visit_u64, u64, U64); + impl_de!(deserialize_u128, visit_u128, u128, U128); + impl_de!(deserialize_str, visit_borrowed_str, &str, Str); + + // IntDecoder works on signed integers/floats (but not chars). + impl_de!(deserialize_i8, visit_i8, i8, U8); + impl_de!(deserialize_i16, visit_i16, i16, U16); + impl_de!(deserialize_i32, visit_i32, i32, U32); + impl_de!(deserialize_i64, visit_i64, i64, U64); + impl_de!(deserialize_i128, visit_i128, i128, U128); + impl_de!(deserialize_f32, visit_f32, f32, U32); + impl_de!(deserialize_f64, visit_f64, f64, U64); + + #[inline(always)] + fn deserialize_char(self, v: V) -> Result where V: Visitor<'de>, { - visitor.visit_str(self.encoding.read_str(self.reader)?) + v.visit_char(char::from_u32(u32::deserialize(self)?).ok_or_else(|| error("invalid char"))?) } - fn deserialize_bytes(self, visitor: V) -> Result + fn deserialize_string(self, v: V) -> Result where V: Visitor<'de>, { - let len = reborrow!(self).read_len()?; - let bytes = if let Some(len) = NonZeroUsize::new(len) { - self.reader.read_bytes(len)? - } else { - &[] - }; - - visitor.visit_bytes(bytes) + self.deserialize_str(v) } - fn deserialize_byte_buf(self, visitor: V) -> Result + #[inline(always)] + fn deserialize_bytes(self, v: V) -> Result where V: Visitor<'de>, { - self.deserialize_bytes(visitor) + self.deserialize_byte_buf(v) // TODO avoid allocation. } - fn deserialize_option(self, visitor: V) -> Result + fn deserialize_byte_buf(self, v: V) -> Result where V: Visitor<'de>, { - if self.reader.read_bit()? { - visitor.visit_some(self) - } else { - visitor.visit_none() - } + v.visit_byte_buf(>::deserialize(self)?) 
} - fn deserialize_unit(self, visitor: V) -> Result + #[inline(always)] + fn deserialize_option(mut self, v: V) -> Result where V: Visitor<'de>, { - visitor.visit_unit() + specify!(self, Enum); + let (decoder, variant_index) = match &mut *self.decoder { + SerdeDecoder::Enum(b) => { + let variant_index = b.0.decode(); + (&mut b.1[variant_index as usize], variant_index) + } + _ => return type_changed(), + }; + match variant_index { + 0 => v.visit_none(), + 1 => v.visit_some(DecoderWrapper { + decoder, + input: &mut *self.input, + }), + _ => err("invalid option"), + } } - fn deserialize_unit_struct(self, _name: &'static str, visitor: V) -> Result + #[inline(always)] + fn deserialize_unit(self, v: V) -> Result where V: Visitor<'de>, { - visitor.visit_unit() + v.visit_unit() } - fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result + #[inline(always)] + fn deserialize_unit_struct(self, _: &'static str, v: V) -> Result where V: Visitor<'de>, { - visitor.visit_newtype_struct(self) + v.visit_unit() } - fn deserialize_seq(self, visitor: V) -> Result + #[inline(always)] + fn deserialize_newtype_struct(self, _: &'static str, v: V) -> Result where V: Visitor<'de>, { - let len = reborrow!(self).read_len()?; - self.deserialize_tuple(len, visitor) + v.visit_newtype_struct(self) } - // based on https://github.com/bincode-org/bincode/blob/c44b5e364e7084cdbabf9f94b63a3c7f32b8fb68/src/de/mod.rs#L293-L330 - fn deserialize_tuple(self, len: usize, visitor: V) -> Result + fn deserialize_seq(mut self, v: V) -> Result where V: Visitor<'de>, { - struct Access<'a, E, R> { - deserializer: BitcodeDeserializer<'a, E, R>, + specify!(self, Seq); + let (decoder, len) = match &mut *self.decoder { + SerdeDecoder::Seq(b) => { + let len = b.0.decode(); + (&mut b.1, len) + } + _ => return type_changed(), + }; + + struct Access<'a, 'de> { + wrapper: DecoderWrapper<'a, 'de>, len: usize, } - - impl<'de, C: Encoding, R: Read> SeqAccess<'de> for Access<'_, C, R> { + impl<'de> 
SeqAccess<'de> for Access<'_, 'de> { type Error = Error; + #[inline(always)] fn next_element_seed(&mut self, seed: T) -> Result> where T: DeserializeSeed<'de>, { guard_zst::(self.len)?; - if self.len > 0 { - self.len -= 1; - let value = DeserializeSeed::deserialize(seed, reborrow!(self.deserializer))?; - Ok(Some(value)) - } else { - Ok(None) - } + self.len + .checked_sub(1) + .map(|len| { + self.len = len; + DeserializeSeed::deserialize( + seed, + DecoderWrapper { + decoder: &mut *self.wrapper.decoder, + input: &mut *self.wrapper.input, + }, + ) + }) + .transpose() } fn size_hint(&self) -> Option { Some(self.len) } } - - visitor.visit_seq(Access { - deserializer: self, + v.visit_seq(Access { + wrapper: DecoderWrapper { + decoder, + input: self.input, + }, len, }) } - fn deserialize_tuple_struct( - self, - _name: &'static str, - len: usize, - visitor: V, - ) -> Result + #[inline(always)] + fn deserialize_tuple(mut self, tuple_len: usize, v: V) -> Result where V: Visitor<'de>, { - self.deserialize_tuple(len, visitor) + if let &mut SerdeDecoder::Unspecified2 { length } = &mut *self.decoder { + #[cold] + fn cold(me: &mut DecoderWrapper, length: usize, tuple_len: usize) -> Result<()> { + *me.decoder = SerdeDecoder::Tuple(default_box_slice(tuple_len)); + me.decoder.populate(me.input, length) + } + cold(&mut self, length, tuple_len)?; + } + let decoders = match &mut *self.decoder { + SerdeDecoder::Tuple(d) => &mut **d, + _ => return type_changed(), + }; + assert_eq!(decoders.len(), tuple_len); // Removes multiple bounds checks. 
+ + struct Access<'a, 'de> { + decoders: &'a mut [SerdeDecoder<'de>], + input: &'a mut &'de [u8], + index: usize, + } + impl<'de> SeqAccess<'de> for Access<'_, 'de> { + type Error = Error; + + #[inline(always)] + fn next_element_seed(&mut self, seed: T) -> Result> + where + T: DeserializeSeed<'de>, + { + guard_zst::(self.decoders.len())?; + self.decoders + .get_mut(self.index) + .map(|decoder| { + self.index += 1; + DeserializeSeed::deserialize( + seed, + DecoderWrapper { + decoder, + input: &mut *self.input, + }, + ) + }) + .transpose() + } + + fn size_hint(&self) -> Option { + Some(self.decoders.len()) + } + } + v.visit_seq(Access { + decoders, + input: &mut *self.input, + index: 0, + }) + } + + #[inline(always)] + fn deserialize_tuple_struct(self, _: &'static str, len: usize, v: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_tuple(len, v) } - // based on https://github.com/bincode-org/bincode/blob/c44b5e364e7084cdbabf9f94b63a3c7f32b8fb68/src/de/mod.rs#L353-L400 - fn deserialize_map(self, visitor: V) -> Result + fn deserialize_map(mut self, v: V) -> Result where V: Visitor<'de>, { - struct Access<'a, E, R> { - deserializer: BitcodeDeserializer<'a, E, R>, + specify!(self, Map); + let (decoders, len) = match &mut *self.decoder { + SerdeDecoder::Map(b) => { + let len = b.0.decode(); + (&mut b.1, len) + } + _ => return type_changed(), + }; + struct Access<'a, 'de> { + decoders: &'a mut (SerdeDecoder<'de>, SerdeDecoder<'de>), + input: &'a mut &'de [u8], len: usize, } - impl<'de, C: Encoding, R: Read> MapAccess<'de> for Access<'_, C, R> { + impl<'de> MapAccess<'de> for Access<'_, 'de> { type Error = Error; + #[inline(always)] fn next_key_seed(&mut self, seed: K) -> Result> where K: DeserializeSeed<'de>, { guard_zst::(self.len)?; - if self.len > 0 { - self.len -= 1; - let key = DeserializeSeed::deserialize(seed, reborrow!(self.deserializer))?; - Ok(Some(key)) - } else { - Ok(None) - } + self.len + .checked_sub(1) + .map(|len| { + self.len = len; + 
DeserializeSeed::deserialize( + seed, + DecoderWrapper { + decoder: &mut self.decoders.0, + input: &mut *self.input, + }, + ) + }) + .transpose() } + #[inline(always)] fn next_value_seed(&mut self, seed: V) -> Result where V: DeserializeSeed<'de>, { - let value = DeserializeSeed::deserialize(seed, reborrow!(self.deserializer))?; - Ok(value) + DeserializeSeed::deserialize( + seed, + DecoderWrapper { + decoder: &mut self.decoders.1, + input: &mut *self.input, + }, + ) } fn size_hint(&self) -> Option { @@ -246,64 +426,51 @@ impl<'de, C: Encoding, R: Read> Deserializer<'de> for BitcodeDeserializer<'_, C, } } - let len = reborrow!(self).read_len()?; - visitor.visit_map(Access { - deserializer: self, + v.visit_map(Access { + decoders, + input: self.input, len, }) } + #[inline(always)] fn deserialize_struct( self, - _name: &'static str, + _: &'static str, fields: &'static [&'static str], - visitor: V, + v: V, ) -> Result where V: Visitor<'de>, { - self.deserialize_tuple(fields.len(), visitor) + self.deserialize_tuple(fields.len(), v) } - // based on https://github.com/bincode-org/bincode/blob/c44b5e364e7084cdbabf9f94b63a3c7f32b8fb68/src/de/mod.rs#L263-L291 + #[inline(always)] fn deserialize_enum( self, - _name: &'static str, - _variants: &'static [&'static str], - visitor: V, + _: &'static str, + _: &'static [&'static str], + v: V, ) -> Result where V: Visitor<'de>, { - impl<'a, 'de, C: Encoding, R: Read> EnumAccess<'de> for BitcodeDeserializer<'a, C, R> { - type Error = Error; - type Variant = BitcodeDeserializer<'a, C, R>; - - fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant)> - where - V: DeserializeSeed<'de>, - { - let idx = reborrow!(self).read_variant_index()?; - let val: Result<_> = seed.deserialize(idx.into_deserializer()); - Ok((val?, reborrow!(self))) - } - } - - visitor.visit_enum(self) + v.visit_enum(self) } - fn deserialize_identifier(self, _visitor: V) -> Result + fn deserialize_identifier(self, _: V) -> Result where V: Visitor<'de>, { - 
Err(E::NotSupported("deserialize_identifier").e()) + err("deserialize_identifier is not supported") } - fn deserialize_ignored_any(self, _visitor: V) -> Result + fn deserialize_ignored_any(self, _: V) -> Result where V: Visitor<'de>, { - Err(E::NotSupported("deserialize_ignored_any").e()) + err("deserialize_ignored_any is not supported") } fn is_human_readable(&self) -> bool { @@ -311,8 +478,34 @@ impl<'de, C: Encoding, R: Read> Deserializer<'de> for BitcodeDeserializer<'_, C, } } -// based on https://github.com/bincode-org/bincode/blob/c44b5e364e7084cdbabf9f94b63a3c7f32b8fb68/src/de/mod.rs#L461-L492 -impl<'de, C: Encoding, R: Read> VariantAccess<'de> for BitcodeDeserializer<'_, C, R> { +impl<'a, 'de> EnumAccess<'de> for DecoderWrapper<'a, 'de> { + type Error = Error; + type Variant = DecoderWrapper<'a, 'de>; + + fn variant_seed(mut self, seed: V) -> Result<(V::Value, Self::Variant)> + where + V: DeserializeSeed<'de>, + { + specify!(self, Enum); + let (decoder, variant_index) = match &mut *self.decoder { + SerdeDecoder::Enum(b) => { + let variant_index = b.0.decode(); + (&mut b.1[variant_index as usize], variant_index as u32) + } + _ => return type_changed(), + }; + let val: Result<_> = seed.deserialize(variant_index.into_deserializer()); + Ok(( + val?, + DecoderWrapper { + decoder, + input: &mut *self.input, + }, + )) + } +} + +impl<'de> VariantAccess<'de> for DecoderWrapper<'_, 'de> { type Error = Error; fn unit_variant(self) -> Result<()> { @@ -323,20 +516,76 @@ impl<'de, C: Encoding, R: Read> VariantAccess<'de> for BitcodeDeserializer<'_, C where T: DeserializeSeed<'de>, { - DeserializeSeed::deserialize(seed, self) + seed.deserialize(self) } - fn tuple_variant(self, len: usize, visitor: V) -> Result + fn tuple_variant(self, len: usize, v: V) -> Result where V: Visitor<'de>, { - Deserializer::deserialize_tuple(self, len, visitor) + self.deserialize_tuple(len, v) } - fn struct_variant(self, fields: &'static [&'static str], visitor: V) -> Result + fn 
struct_variant(self, fields: &'static [&'static str], v: V) -> Result where V: Visitor<'de>, { - Deserializer::deserialize_tuple(self, fields.len(), visitor) + self.deserialize_tuple(fields.len(), v) + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + + #[test] + fn deserialize() { + macro_rules! test { + ($v:expr, $t:ty) => { + let ser = crate::serialize::<$t>(&$v).unwrap(); + println!("{:<24} {ser:?}", stringify!($t)); + assert_eq!($v, crate::deserialize::<$t>(&ser).unwrap()); + }; + } + // Primitives + test!(5, u8); + test!(5, u16); + test!(5, u32); + test!(5, u64); + test!(5, u128); + test!(5, i8); + test!(5, i16); + test!(5, i32); + test!(5, i64); + test!(5, i128); + test!(true, bool); + test!('a', char); + + // Enums + test!(Some(true), Option); + test!(Ok(true), Result); + test!(vec![Ok(true), Err(2)], Vec>); + test!(vec![Err(1), Ok(false)], Vec>); + + // Maps + let mut map = BTreeMap::new(); + map.insert(1u8, 11u8); + map.insert(2u8, 22u8); + test!(map, BTreeMap); + + // Sequences + test!("abc".to_owned(), String); + test!(vec![1u8, 2u8, 3u8], Vec); + test!( + vec!["abc".to_owned(), "def".to_owned(), "ghi".to_owned()], + Vec + ); + + // Tuples + test!((1u8, 2u8, 3u8), (u8, u8, u8)); + test!([1u8, 2u8, 3u8], [u8; 3]); + + // Complex. + test!(vec![(None, 3), (Some(4), 5)], Vec<(Option, u8)>); } } diff --git a/src/serde/guard.rs b/src/serde/guard.rs new file mode 100644 index 0000000..73fa747 --- /dev/null +++ b/src/serde/guard.rs @@ -0,0 +1,22 @@ +use crate::coder::Result; +use crate::error::err; + +pub const ZST_LIMIT: usize = 1 << 16; + +fn check_zst_len(len: usize) -> Result<()> { + if len > ZST_LIMIT { + err("too many zero sized types") + } else { + Ok(()) + } +} + +// Used by deserialize. Guards against Vec<()> with huge len taking forever. 
+#[inline] +pub fn guard_zst(len: usize) -> Result<()> { + if std::mem::size_of::() == 0 { + check_zst_len(len) + } else { + Ok(()) + } +} diff --git a/src/serde/mod.rs b/src/serde/mod.rs index 196745c..577cec5 100644 --- a/src/serde/mod.rs +++ b/src/serde/mod.rs @@ -1,81 +1,52 @@ -use crate::{Buffer, Error, Result}; -use serde::{de::DeserializeOwned, Serialize}; +use crate::error::{err, error_from_display, Error}; use std::fmt::Display; -pub(crate) mod de; -pub(crate) mod ser; +mod de; +mod guard; +mod ser; +mod variant; -/// Serializes a `T:` [`Serialize`] into a [`Vec`]. -/// -/// **Warning:** The format is incompatible with [`decode`][`crate::decode`] and subject to change between versions. -#[cfg_attr(doc, doc(cfg(feature = "serde")))] -pub fn serialize(t: &T) -> Result> -where - T: Serialize, -{ - Ok(Buffer::new().serialize(t)?.to_vec()) -} +pub use de::*; +pub use ser::*; -/// Deserializes a [`&[u8]`][`prim@slice`] into an instance of `T:` [`Deserialize`][`serde::Deserialize`]. -/// -/// **Warning:** The format is incompatible with [`encode`][`crate::encode`] and subject to change between versions. -#[cfg_attr(doc, doc(cfg(feature = "serde")))] -pub fn deserialize(bytes: &[u8]) -> Result -where - T: DeserializeOwned, -{ - Buffer::new().deserialize(bytes) +fn type_changed() -> Result { + err("type changed") } -impl Buffer { - /// Serializes a `T:` [`Serialize`] into a [`&[u8]`][`prim@slice`]. Can reuse the buffer's - /// allocations. - /// - /// Even if you call `to_vec` on the [`&[u8]`][`prim@slice`], it's still more efficient than - /// [`serialize`]. - /// - /// **Warning:** The format is incompatible with [`decode`][`Buffer::decode`] and subject to change between versions. 
- #[cfg_attr(doc, doc(cfg(feature = "serde")))] - pub fn serialize(&mut self, t: &T) -> Result<&[u8]> - where - T: Serialize, - { - ser::serialize_internal(&mut self.0, t) - } +fn default_box_slice(len: usize) -> Box<[T]> { + let mut vec = vec![]; + vec.resize_with(len, Default::default); + vec.into() +} - /// Deserializes a [`&[u8]`][`prim@slice`] into an instance of `T:` [`Deserialize`][`serde::Deserialize`]. Can reuse - /// the buffer's allocations. - /// - /// **Warning:** The format is incompatible with [`encode`][`Buffer::encode`] and subject to change between versions. - #[cfg_attr(doc, doc(cfg(feature = "serde")))] - pub fn deserialize(&mut self, bytes: &[u8]) -> Result - where - T: DeserializeOwned, - { - de::deserialize_internal(&mut self.0, bytes) +#[inline(always)] +fn get_mut_or_resize(vec: &mut Vec, index: usize) -> &mut T { + if index >= vec.len() { + #[cold] + #[inline(never)] + fn cold(vec: &mut Vec, index: usize) { + vec.resize_with(index + 1, Default::default) + } + cold(vec, index); } + // Safety we've just resized `vec.len()` to be > than `index`. 
+ unsafe { vec.get_unchecked_mut(index) } } impl serde::ser::Error for Error { - fn custom(_msg: T) -> Self + fn custom(t: T) -> Self where T: Display, { - #[cfg(debug_assertions)] - return Self(crate::E::Custom(_msg.to_string())); - #[cfg(not(debug_assertions))] - Self(()) + error_from_display(t) } } impl serde::de::Error for Error { - fn custom(_msg: T) -> Self + fn custom(t: T) -> Self where T: Display, { - #[cfg(debug_assertions)] - return Self(crate::E::Custom(_msg.to_string())); - #[cfg(not(debug_assertions))] - Self(()) + error_from_display(t) } } diff --git a/src/serde/ser.rs b/src/serde/ser.rs index 7516173..2f3aaf9 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -1,109 +1,289 @@ -use crate::buffer::BufferTrait; -use crate::encoding::{Encoding, Fixed, Gamma}; -use crate::write::Write; -use crate::{Encode, Error, Result, E}; +use crate::bool::BoolEncoder; +use crate::coder::{Buffer, Encoder, Result}; +use crate::error::{err, error, Error}; +use crate::int::IntEncoder; +use crate::length::LengthEncoder; +use crate::serde::variant::VariantEncoder; +use crate::serde::{default_box_slice, get_mut_or_resize, type_changed}; +use crate::str::StrEncoder; use serde::ser::{ SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant, SerializeTuple, SerializeTupleStruct, SerializeTupleVariant, }; use serde::{Serialize, Serializer}; +use std::num::NonZeroUsize; -pub fn serialize_internal<'a>( - buffer: &'a mut impl BufferTrait, - t: &(impl Serialize + ?Sized), -) -> Result<&'a [u8]> { - let mut writer = buffer.start_write(); - serialize_compat(t, Fixed, &mut writer)?; - Ok(buffer.finish_write(writer)) -} +// Redefine Result from crate::coder::Result to std::result::Result since the former isn't public. +mod inner { + use super::*; + use std::result::Result; + + /// Serializes a `T:` [`Serialize`] into a [`Vec`]. + /// + /// **Warning:** The format is incompatible with [`decode`][`crate::decode`] and subject to + /// change between versions. 
+ #[cfg_attr(doc, doc(cfg(feature = "serde")))] + pub fn serialize(t: &T) -> Result, Error> { + let mut lazy = LazyEncoder::Unspecified { + reserved: NonZeroUsize::new(1), + }; + let mut index_alloc = 0; + t.serialize(EncoderWrapper { + lazy: &mut lazy, + index_alloc: &mut index_alloc, + })?; -pub fn serialize_compat( - t: &(impl Serialize + ?Sized), - encoding: impl Encoding, - writer: &mut impl Write, -) -> Result<()> { - t.serialize(BitcodeSerializer { encoding, writer }) + // If we just wrote out the buffers in field order we wouldn't be able to deserialize them + // since we might learn their types from serde in a different order. + // + // Consider the value: `[(vec![], 0u8), (vec![true], 1u8)]` + // We don't know that the Vec contains bool until we've already deserialized 0u8. + // Serde only tells us what is in sequences that aren't empty. + // + // Therefore, we have to reorder the buffers to match the order serde told us about them. + let mut buffers = default_box_slice(index_alloc); + lazy.reorder(&mut buffers); + + let mut bytes = vec![]; + for buffer in Vec::from(buffers).into_iter().flatten() { + buffer.collect_into(&mut bytes); + } + Ok(bytes) + } } +pub use inner::serialize; -struct BitcodeSerializer<'a, C, W> { - encoding: C, - writer: &'a mut W, +#[derive(Debug)] +enum SpecifiedEncoder { + Bool(BoolEncoder), + Enum(Box<(VariantEncoder, Vec)>), // (variants, values) TODO only 1 allocation? + Map(Box<(LengthEncoder, (LazyEncoder, LazyEncoder))>), // (lengths, (keys, values)) + Seq(Box<(LengthEncoder, LazyEncoder)>), // (lengths, values) + Str(StrEncoder), + Tuple(Box<[LazyEncoder]>), // [field0, field1, ..] + U8(IntEncoder), + U16(IntEncoder), + U32(IntEncoder), + U64(IntEncoder), + U128(IntEncoder), } -macro_rules! 
reborrow { - ($e:expr) => { - BitcodeSerializer { - encoding: $e.encoding, - writer: &mut *$e.writer, +impl SpecifiedEncoder { + fn reserve(&mut self, additional: NonZeroUsize) { + match self { + Self::Bool(v) => v.reserve(additional), + Self::Enum(v) => { + v.0.reserve(additional); + // We don't know the variants of the enums, so we can't reserve more. + } + Self::Map(v) => { + v.0.reserve(additional); + // We don't know the lengths of the maps, so we can't reserve more. + } + Self::Seq(v) => { + v.0.reserve(additional); + // We don't know the lengths of the sequences, so we can't reserve more. + } + Self::Str(v) => { + v.reserve(additional); + } + Self::Tuple(v) => v.iter_mut().for_each(|v| v.reserve_fast(additional.get())), + Self::U8(v) => v.reserve(additional), + Self::U16(v) => v.reserve(additional), + Self::U32(v) => v.reserve(additional), + Self::U64(v) => v.reserve(additional), + Self::U128(v) => v.reserve(additional), } } } -impl BitcodeSerializer<'_, C, W> { - fn write_len(self, len: usize) -> Result<()> { - len.encode(Gamma, self.writer) +#[derive(Debug)] +enum LazyEncoder { + Unspecified { + reserved: Option, + }, + Specified { + specified: SpecifiedEncoder, + index: usize, + }, +} + +impl Default for LazyEncoder { + fn default() -> Self { + Self::Unspecified { reserved: None } + } +} + +impl LazyEncoder { + fn reorder<'a>(&'a mut self, buffers: &mut [Option<&'a mut dyn Buffer>]) { + match self { + Self::Specified { specified, index } => { + buffers[*index] = Some(match specified { + SpecifiedEncoder::Bool(v) => v, + SpecifiedEncoder::Enum(v) => { + v.1.iter_mut().for_each(|v| v.reorder(buffers)); + &mut v.0 + } + SpecifiedEncoder::Map(v) => { + v.1 .0.reorder(buffers); + v.1 .1.reorder(buffers); + &mut v.0 + } + SpecifiedEncoder::Seq(v) => { + v.1.reorder(buffers); + &mut v.0 + } + SpecifiedEncoder::Str(v) => v, + SpecifiedEncoder::Tuple(v) => { + v.iter_mut().for_each(|v| v.reorder(buffers)); + return; // Has no buffer. 
+ } + SpecifiedEncoder::U8(v) => v, + SpecifiedEncoder::U16(v) => v, + SpecifiedEncoder::U32(v) => v, + SpecifiedEncoder::U64(v) => v, + SpecifiedEncoder::U128(v) => v, + }) + } + Self::Unspecified { .. } => (), + } + } + + /// OLD COMMENT: + /// Only reserves if the type is unspecified to save time. Speeds up large 1 time collections + /// without slowing down many small collections too much. Takes a `usize` instead of a + /// [`NonZeroUsize`] to avoid branching on len. + /// + /// Can't be reserve_fast anymore with push_within_capacity. + fn reserve_fast(&mut self, len: usize) { + match self { + Self::Specified { specified, .. } => { + if let Some(len) = NonZeroUsize::new(len) { + specified.reserve(len); + } + } + Self::Unspecified { reserved } => *reserved = NonZeroUsize::new(len), + } } +} - fn write_variant_index(self, variant_index: u32) -> Result<()> { - variant_index.encode(Gamma, self.writer) +macro_rules! specify { + ($wrapper:ident, $variant:ident) => {{ + let lazy = &mut *$wrapper.lazy; + match lazy { + LazyEncoder::Unspecified { reserved } => { + let reserved = *reserved; + #[cold] + fn cold<'a>( + me: &'a mut LazyEncoder, + index_alloc: &mut usize, + reserved: Option, + ) -> &'a mut SpecifiedEncoder { + let mut specified = SpecifiedEncoder::$variant(Default::default()); + if let Some(reserved) = reserved { + specified.reserve(reserved); + } + *me = LazyEncoder::Specified { + specified, + index: std::mem::replace(index_alloc, *index_alloc + 1), + }; + // TODO might be slower to put in cold fn. + if let LazyEncoder::Specified { specified, .. } = me { + specified + } else { + unreachable!(); + } + } + cold(lazy, &mut *$wrapper.index_alloc, reserved) + } + LazyEncoder::Specified { specified, .. 
} => specified, + } + }}; +} + +struct EncoderWrapper<'a> { + lazy: &'a mut LazyEncoder, + index_alloc: &'a mut usize, +} + +impl<'a> EncoderWrapper<'a> { + fn serialize_enum(self, variant_index: u32) -> Result> { + let variant_index = variant_index + .try_into() + .map_err(|_| error("enums with more than 256 variants are unsupported"))?; + match specify!(self, Enum) { + SpecifiedEncoder::Enum(b) => { + b.0.encode(&variant_index); + let lazy = get_mut_or_resize(&mut b.1, variant_index as usize); + lazy.reserve_fast(1); // TODO use push instead. + Ok(Self { + lazy, + index_alloc: self.index_alloc, + }) + } + _ => type_changed(), + } } } macro_rules! impl_ser { - ($name:ident, $a:ty) => { - #[inline(always)] - fn $name(self, v: $a) -> Result { - v.encode(self.encoding, self.writer) + ($name:ident, $t:ty, $variant:ident) => { + fn $name(self, v: $t) -> Result<()> { + match specify!(self, $variant) { + SpecifiedEncoder::$variant(b) => b.encode(&v), + _ => return type_changed(), + } + Ok(()) } }; } -impl Serializer for BitcodeSerializer<'_, C, W> { +impl<'a> Serializer for EncoderWrapper<'a> { type Ok = (); type Error = Error; - type SerializeSeq = Self; - type SerializeTuple = Self; - type SerializeTupleStruct = Self; - type SerializeTupleVariant = Self; - type SerializeMap = Self; - type SerializeStruct = Self; - type SerializeStructVariant = Self; - - impl_ser!(serialize_bool, bool); - impl_ser!(serialize_i8, i8); - impl_ser!(serialize_i16, i16); - impl_ser!(serialize_i32, i32); - impl_ser!(serialize_i64, i64); - impl_ser!(serialize_i128, i128); - impl_ser!(serialize_u8, u8); - impl_ser!(serialize_u16, u16); - impl_ser!(serialize_u32, u32); - impl_ser!(serialize_u64, u64); - impl_ser!(serialize_u128, u128); - impl_ser!(serialize_f32, f32); - impl_ser!(serialize_f64, f64); - impl_ser!(serialize_char, char); - impl_ser!(serialize_str, &str); + type SerializeSeq = EncoderWrapper<'a>; + type SerializeTuple = TupleSerializer<'a>; + type SerializeTupleStruct = 
TupleSerializer<'a>; + type SerializeTupleVariant = TupleSerializer<'a>; + type SerializeMap = MapSerializer<'a>; + type SerializeStruct = TupleSerializer<'a>; + type SerializeStructVariant = TupleSerializer<'a>; + + // Use native encoders. + impl_ser!(serialize_bool, bool, Bool); + impl_ser!(serialize_u8, u8, U8); + impl_ser!(serialize_u16, u16, U16); + impl_ser!(serialize_u32, u32, U32); + impl_ser!(serialize_u64, u64, U64); + impl_ser!(serialize_u128, u128, U128); + impl_ser!(serialize_str, &str, Str); + + // IntEncoder works on signed integers/floats/char. + impl_ser!(serialize_i8, i8, U8); + impl_ser!(serialize_i16, i16, U16); + impl_ser!(serialize_i32, i32, U32); + impl_ser!(serialize_i64, i64, U64); + impl_ser!(serialize_i128, i128, U128); + impl_ser!(serialize_f32, f32, U32); + impl_ser!(serialize_f64, f64, U64); + impl_ser!(serialize_char, char, U32); fn serialize_bytes(self, v: &[u8]) -> Result { - reborrow!(self).write_len(v.len())?; - self.writer.write_bytes(v); - Ok(()) + v.serialize(self) } #[inline(always)] fn serialize_none(self) -> Result { - self.writer.write_false(); + self.serialize_enum(0)?; Ok(()) } #[inline(always)] - fn serialize_some(self, value: &T) -> Result + fn serialize_some(self, v: &T) -> Result where T: Serialize, { - self.writer.write_bit(true); - value.serialize(self) + v.serialize(self.serialize_enum(1)?) } fn serialize_unit(self) -> Result { @@ -120,7 +300,8 @@ impl Serializer for BitcodeSerializer<'_, C, W> { variant_index: u32, _variant: &'static str, ) -> Result { - self.write_variant_index(variant_index) + self.serialize_enum(variant_index)?; + Ok(()) } fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result @@ -140,29 +321,73 @@ impl Serializer for BitcodeSerializer<'_, C, W> { where T: Serialize, { - reborrow!(self).write_variant_index(variant_index)?; - value.serialize(self) + value.serialize(self.serialize_enum(variant_index)?) 
} #[inline(always)] fn serialize_seq(self, len: Option) -> Result { let len = len.expect("sequence must have len"); - reborrow!(self).write_len(len)?; - Ok(self) + match specify!(self, Seq) { + SpecifiedEncoder::Seq(b) => { + b.0.encode(&len); + b.1.reserve_fast(len); + Ok(Self { + lazy: &mut b.1, + index_alloc: self.index_alloc, + }) + } + _ => type_changed(), + } } #[inline(always)] - fn serialize_tuple(self, _len: usize) -> Result { - Ok(self) + fn serialize_tuple(self, len: usize) -> Result { + let lazy = &mut *self.lazy; + let specified = match lazy { + &mut LazyEncoder::Unspecified { reserved } => { + #[cold] + fn cold( + me: &mut LazyEncoder, + reserved: Option, + len: usize, + ) -> &mut SpecifiedEncoder { + let mut specified = SpecifiedEncoder::Tuple(default_box_slice(len)); + if let Some(reserved) = reserved { + specified.reserve(reserved); + } + *me = LazyEncoder::Specified { + specified, + index: usize::MAX, // We never use this. + }; + // TODO might be slower to put in cold fn. + let LazyEncoder::Specified { specified: encoder, .. } = me else { + unreachable!(); + }; + encoder + } + cold(lazy, reserved, len) + } + LazyEncoder::Specified { specified, .. } => specified, + }; + match specified { + SpecifiedEncoder::Tuple(encoders) => { + assert_eq!(encoders.len(), len); // Removes multiple bounds checks. 
+ Ok(TupleSerializer { + encoders, + index_alloc: self.index_alloc, + }) + } + _ => type_changed(), + } } #[inline(always)] fn serialize_tuple_struct( self, _name: &'static str, - _len: usize, + len: usize, ) -> Result { - Ok(self) + self.serialize_tuple(len) } #[inline(always)] @@ -171,22 +396,31 @@ impl Serializer for BitcodeSerializer<'_, C, W> { _name: &'static str, variant_index: u32, _variant: &'static str, - _len: usize, + len: usize, ) -> Result { - reborrow!(self).write_variant_index(variant_index)?; - Ok(self) + self.serialize_enum(variant_index)?.serialize_tuple(len) } #[inline(always)] fn serialize_map(self, len: Option) -> Result { let len = len.expect("sequence must have len"); - reborrow!(self).write_len(len)?; - Ok(self) + match specify!(self, Map) { + SpecifiedEncoder::Map(b) => { + b.0.encode(&len); + b.1 .0.reserve_fast(len); + b.1 .1.reserve_fast(len); + Ok(MapSerializer { + encoders: &mut b.1, + index_alloc: self.index_alloc, + }) + } + _ => type_changed(), + } } #[inline(always)] - fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { - Ok(self) + fn serialize_struct(self, _name: &'static str, len: usize) -> Result { + self.serialize_tuple(len) } #[inline(always)] @@ -195,10 +429,9 @@ impl Serializer for BitcodeSerializer<'_, C, W> { _name: &'static str, variant_index: u32, _variant: &'static str, - _len: usize, + len: usize, ) -> Result { - reborrow!(self).write_variant_index(variant_index)?; - Ok(self) + self.serialize_enum(variant_index)?.serialize_tuple(len) } #[inline(always)] @@ -217,34 +450,62 @@ macro_rules! ok_error_end { }; } -macro_rules! impl_seq { +impl SerializeSeq for EncoderWrapper<'_> { + ok_error_end!(); + fn serialize_element(&mut self, value: &T) -> Result<()> { + value.serialize(EncoderWrapper { + lazy: &mut *self.lazy, + index_alloc: &mut *self.index_alloc, + }) + } +} + +struct TupleSerializer<'a> { + encoders: &'a mut [LazyEncoder], // [field0, field1, ..] 
+ index_alloc: &'a mut usize, +} + +macro_rules! impl_tuple { ($tr:ty, $fun:ident) => { - impl $tr for BitcodeSerializer<'_, C, W> { + impl $tr for TupleSerializer<'_> { ok_error_end!(); fn $fun(&mut self, value: &T) -> Result<()> { - value.serialize(reborrow!(self)) + let (lazy, remaining) = std::mem::take(&mut self.encoders) + .split_first_mut() + .expect("length mismatch"); + self.encoders = remaining; + value.serialize(EncoderWrapper { + lazy, + index_alloc: &mut *self.index_alloc, + }) } } }; } -impl_seq!(SerializeSeq, serialize_element); -impl_seq!(SerializeTuple, serialize_element); -impl_seq!(SerializeTupleStruct, serialize_field); -impl_seq!(SerializeTupleVariant, serialize_field); +impl_tuple!(SerializeTuple, serialize_element); +impl_tuple!(SerializeTupleStruct, serialize_field); +impl_tuple!(SerializeTupleVariant, serialize_field); macro_rules! impl_struct { ($tr:ty) => { - impl $tr for BitcodeSerializer<'_, C, W> { + impl $tr for TupleSerializer<'_> { ok_error_end!(); fn serialize_field(&mut self, _key: &'static str, value: &T) -> Result<()> where T: Serialize, { - value.serialize(reborrow!(self)) + let (lazy, remaining) = std::mem::take(&mut self.encoders) + .split_first_mut() + .expect("length mismatch"); + self.encoders = remaining; + value.serialize(EncoderWrapper { + lazy, + index_alloc: &mut *self.index_alloc, + }) } fn skip_field(&mut self, _key: &'static str) -> Result<()> { - Err(E::NotSupported("skip_field").e()) + err("skip field is not supported") } } }; @@ -252,19 +513,55 @@ macro_rules! 
impl_struct { impl_struct!(SerializeStruct); impl_struct!(SerializeStructVariant); -impl SerializeMap for BitcodeSerializer<'_, C, W> { +struct MapSerializer<'a> { + encoders: &'a mut (LazyEncoder, LazyEncoder), // (keys, values) + index_alloc: &'a mut usize, +} + +impl SerializeMap for MapSerializer<'_> { ok_error_end!(); fn serialize_key(&mut self, key: &T) -> Result<()> where T: Serialize, { - key.serialize(reborrow!(self)) + key.serialize(EncoderWrapper { + lazy: &mut self.encoders.0, + index_alloc: &mut *self.index_alloc, + }) } fn serialize_value(&mut self, value: &T) -> Result<()> where T: Serialize, { - value.serialize(reborrow!(self)) + value.serialize(EncoderWrapper { + lazy: &mut self.encoders.1, + index_alloc: &mut *self.index_alloc, + }) + } +} + +#[cfg(test)] +mod tests { + #[test] + fn enum_256_variants() { + enum Enum { + A, + B, + } + impl serde::Serialize for Enum { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let variant_index = match self { + Self::A => 255, + Self::B => 256, + }; + serializer.serialize_unit_variant("", variant_index, "") + } + } + assert!(crate::serialize(&Enum::A).is_ok()); + assert!(crate::serialize(&Enum::B).is_err()); } } diff --git a/src/serde/variant.rs b/src/serde/variant.rs new file mode 100644 index 0000000..3b4336e --- /dev/null +++ b/src/serde/variant.rs @@ -0,0 +1,71 @@ +use crate::coder::{Buffer, Decoder, Encoder, Result, View}; +use crate::fast::{CowSlice, NextUnchecked, PushUnchecked, VecImpl}; +use crate::pack::{pack_bytes, unpack_bytes}; +use std::marker::PhantomData; +use std::num::NonZeroUsize; + +#[derive(Debug, Default)] +pub struct VariantEncoder { + data: VecImpl, +} + +impl Encoder for VariantEncoder { + #[inline(always)] + fn encode(&mut self, v: &u8) { + unsafe { self.data.push_unchecked(*v) }; + } +} + +impl Buffer for VariantEncoder { + fn collect_into(&mut self, out: &mut Vec) { + pack_bytes(self.data.as_mut_slice(), out); + self.data.clear(); + } + + fn 
reserve(&mut self, additional: NonZeroUsize) { + self.data.reserve(additional.get()) + } +} + +#[derive(Debug, Default)] +pub struct VariantDecoder<'a> { + variants: CowSlice<'a, u8>, + histogram: Vec, + spooky: PhantomData<&'a ()>, +} + +impl VariantDecoder<'_> { + pub fn length(&self, variant_index: u8) -> usize { + self.histogram[variant_index as usize] + } + + /// Returns the max variant index if there were any variants. + pub fn max_variant_index(&self) -> Option { + self.histogram.len().checked_sub(1).map(|v| v as u8) + } +} + +impl<'a> View<'a> for VariantDecoder<'a> { + fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> { + unpack_bytes(input, length, &mut self.variants)?; + // Safety: unpack_bytes just initialized self.variants with length of `length`. + let variants = unsafe { self.variants.as_slice(length) }; + + let histogram = crate::histogram::histogram(variants); + let len = histogram + .iter() + .copied() + .rposition(|v| v != 0) + .map(|i| i + 1) + .unwrap_or(0); + self.histogram.clear(); + self.histogram.extend_from_slice(&histogram[..len]); + Ok(()) + } +} + +impl<'a> Decoder<'a, u8> for VariantDecoder<'a> { + fn decode(&mut self) -> u8 { + unsafe { self.variants.mut_slice().next_unchecked() } + } +} diff --git a/src/str.rs b/src/str.rs new file mode 100644 index 0000000..f786d7a --- /dev/null +++ b/src/str.rs @@ -0,0 +1,216 @@ +use crate::coder::{Buffer, Decoder, Encoder, Result, View}; +use crate::consume::consume_bytes; +use crate::derive::vec::VecEncoder; +use crate::error::err; +use crate::fast::{NextUnchecked, SliceImpl}; +use crate::length::LengthDecoder; +use crate::u8_char::U8Char; +use std::num::NonZeroUsize; +use std::str::{from_utf8, from_utf8_unchecked}; + +#[derive(Debug, Default)] +pub struct StrEncoder(pub(crate) VecEncoder); // pub(crate) for arrayvec.rs + +#[inline(always)] +fn str_as_u8_chars(s: &str) -> &[U8Char] { + bytemuck::must_cast_slice(s.as_bytes()) +} + +impl Buffer for StrEncoder { + fn 
collect_into(&mut self, out: &mut Vec) { + self.0.collect_into(out); + } + + fn reserve(&mut self, additional: NonZeroUsize) { + self.0.reserve(additional); + } +} + +impl Encoder for StrEncoder { + #[inline(always)] + fn encode(&mut self, t: &str) { + self.0.encode(str_as_u8_chars(t)); + } + + #[inline(always)] + fn encode_vectored<'a>(&mut self, i: impl Iterator + Clone) { + self.0.encode_vectored(i.map(str_as_u8_chars)) + } +} + +// TODO find a way to remove this shim. +impl<'b> Encoder<&'b str> for StrEncoder { + #[inline(always)] + fn encode(&mut self, t: &&str) { + self.encode(*t); + } + + #[inline(always)] + fn encode_vectored<'a>(&mut self, i: impl Iterator + Clone) + where + &'b str: 'a, + { + self.encode_vectored(i.copied()) + } +} + +impl Encoder for StrEncoder { + #[inline(always)] + fn encode(&mut self, t: &String) { + self.encode(t.as_str()); + } + + #[inline(always)] + fn encode_vectored<'a>(&mut self, i: impl Iterator + Clone) + where + String: 'a, + { + self.encode_vectored(i.map(String::as_str)) + } +} + +// Doesn't use VecDecoder because can't decode &[u8]. +#[derive(Debug, Default)] +pub struct StrDecoder<'a> { + // pub(crate) for arrayvec::ArrayString. + pub(crate) lengths: LengthDecoder<'a>, + strings: SliceImpl<'a, u8>, +} + +impl<'a> View<'a> for StrDecoder<'a> { + fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> { + // TODO take NonZeroUsize length in View::populate. + let Some(length) = NonZeroUsize::new(length) else { + return Ok(()); + }; + self.lengths.populate(input, length.get())?; + let bytes = consume_bytes(input, self.lengths.length())?; + + // Fast path: If bytes are ASCII then they're valid UTF-8 and no char boundary can be invalid. + // TODO(optimization): + // - Worst case when bytes doesn't fit in CPU cache, this will load bytes 3 times from RAM. + // - We should subdivide it into chunks in that case. 
+ if is_ascii_simd(bytes) + || from_utf8(bytes).is_ok_and(|s| { + // Check that gaps between individual strings are on char boundaries in larger string. + // Indices 0 and s.len() are not checked since s: &str guarantees them. + let mut length_decoder = self.lengths.borrowed_clone(); + let mut end = 0; + for _ in 0..length.get() - 1 { + end += length_decoder.decode(); + // TODO(optimization) is_char_boundary has unnecessary checks. + if !s.is_char_boundary(end) { + return false; + } + } + true + }) + { + self.strings = bytes.into(); + Ok(()) + } else { + err("invalid utf8") + } + } +} + +impl<'a> Decoder<'a, &'a str> for StrDecoder<'a> { + #[inline(always)] + fn decode(&mut self) -> &'a str { + let bytes = unsafe { self.strings.chunk_unchecked(self.lengths.decode()) }; + debug_assert!(from_utf8(bytes).is_ok()); + + // Safety: `bytes` is valid UTF-8 because populate checked that `self.strings` is valid UTF-8 + // and that every sub string starts and ends on char boundaries. + unsafe { from_utf8_unchecked(bytes) } + } +} + +impl<'a> Decoder<'a, String> for StrDecoder<'a> { + #[inline(always)] + fn decode(&mut self) -> String { + let v: &str = self.decode(); + v.to_owned() + } +} + +/// Tests 128 bytes a time instead of `<[u8]>::is_ascii` which only tests 8. +/// 390% faster on 8KB, 27% faster on 1GB (RAM bottleneck). 
+fn is_ascii_simd(v: &[u8]) -> bool { + const CHUNK: usize = 128; + let chunks_exact = v.chunks_exact(CHUNK); + let remainder = chunks_exact.remainder(); + for chunk in chunks_exact { + let mut any = false; + for &v in chunk { + any |= v & 0x80 != 0 + } + if any { + debug_assert!(!chunk.is_ascii()); + return false; + } + } + debug_assert!(v[..v.len() - remainder.len()].is_ascii()); + remainder.is_ascii() +} + +#[cfg(test)] +mod tests { + use super::is_ascii_simd; + use crate::u8_char::U8Char; + use crate::{decode, encode}; + use test::{black_box, Bencher}; + + #[test] + fn utf8_validation() { + // Check from_utf8: + assert!(decode::<&str>(&encode(&vec![U8Char(255u8)])).is_err()); + assert_eq!(decode::<&str>(&encode("\0")).unwrap(), "\0"); + assert_eq!(decode::<&str>(&encode(&"☺".to_owned())).unwrap(), "☺"); + + let c = "☺"; + let full = super::str_as_u8_chars(c); + let start = &full[..1]; + let end = &full[1..]; + + // Check is_char_boundary: + assert!(decode::<[&str; 2]>(&encode(&[start.to_vec(), end.to_vec()])).is_err()); + assert_eq!(decode::<[&str; 2]>(&encode(&[c, c])).unwrap(), [c, c]); + } + + #[test] + fn test_is_ascii_simd() { + assert!(is_ascii_simd(&[0x7F; 128])); + assert!(!is_ascii_simd(&[0xFF; 128])); + } + + #[bench] + fn bench_is_ascii(b: &mut Bencher) { + b.iter(|| black_box(&[0; 8192]).is_ascii()) + } + + #[bench] + fn bench_is_ascii_simd(b: &mut Bencher) { + b.iter(|| is_ascii_simd(black_box(&[0; 8192]))) + } + + type S = &'static str; + fn bench_data() -> (S, S, S, S, S, S, S) { + ("a", "b", "c", "d", "e", "f", "g") + } + crate::bench_encode_decode!(str_tuple: (&str, &str, &str, &str, &str, &str, &str)); +} + +#[cfg(test)] +mod tests2 { + fn bench_data() -> Vec { + crate::random_data::(40000) + .into_iter() + .map(|n| { + let n = (8 + n / 32) as usize; + " ".repeat(n) + }) + .collect() + } + crate::bench_encode_decode!(str_vec: Vec); +} diff --git a/src/tests.rs b/src/tests.rs deleted file mode 100644 index 7594fc7..0000000 --- a/src/tests.rs 
+++ /dev/null @@ -1,560 +0,0 @@ -use crate::code::{decode_internal, encode_internal}; -use crate::serde::de::deserialize_internal; -use crate::serde::ser::serialize_internal; -use crate::word_buffer::WordBuffer; -use crate::{Buffer, Decode, Encode, E}; -use paste::paste; -use serde::de::DeserializeOwned; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::ffi::{CStr, CString}; -use std::fmt::Debug; - -#[cfg(not(miri))] -use crate::bit_buffer::BitBuffer; - -#[test] -fn test_buffer_with_capacity() { - assert_eq!(Buffer::with_capacity(0).capacity(), 0); - - let mut buf = Buffer::with_capacity(1016); - let enough_cap = buf.capacity(); - let bytes = buf.serialize(&"a".repeat(997 + 16)).unwrap(); - assert_eq!(bytes.len(), enough_cap); - assert_eq!(buf.capacity(), enough_cap); - - let mut buf = Buffer::with_capacity(1016); - let small_cap = buf.capacity(); - let bytes = buf.serialize(&"a".repeat(997 + 19)).unwrap(); - assert_ne!(bytes.len(), small_cap); - assert_ne!(buf.capacity(), small_cap); -} - -macro_rules! impl_the_same { - ($ser_trait:ident, $de_trait:ident, $ser:ident, $de:ident) => { - paste! 
{ - fn [< the_same_ $ser>] ( - t: T, - buf: &mut Buffer, - ) { - let serialized = { - let a = [<$ser _internal>](&mut WordBuffer::default(), &t) - .unwrap() - .to_vec(); - let b = buf.$ser(&t).unwrap().to_vec(); - assert_eq!(a, b); - - #[cfg(not(miri))] - { - let c = [<$ser _internal>](&mut BitBuffer::default(), &t) - .unwrap() - .to_vec(); - assert_eq!(a, c); - } - a - }; - - let a: T = - [<$de _internal>](&mut WordBuffer::default(), &serialized).expect("WordBuffer error"); - let b: T = buf - .$de(&serialized) - .expect("Buffer::deserialize error"); - - assert_eq!(t, a); - assert_eq!(t, b); - - #[cfg(not(miri))] - { - let c: T = - [<$de _internal>](&mut BitBuffer::default(), &serialized).expect("BitBuffer error"); - assert_eq!(t, c); - } - - let mut bytes = serialized.clone(); - bytes.push(0); - #[cfg(not(miri))] - assert_eq!( - [<$de _internal>]::(&mut Default::default(), &bytes), - Err(E::ExpectedEof.e()) - ); - assert_eq!( - [<$de _internal>]::(&mut Default::default(), &bytes), - Err(E::ExpectedEof.e()) - ); - - let mut bytes = serialized.clone(); - if bytes.pop().is_some() { - #[cfg(not(miri))] - assert_eq!( - [<$de _internal>]::(&mut Default::default(), &bytes), - Err(E::Eof.e()) - ); - assert_eq!( - [<$de _internal>]::(&mut Default::default(), &bytes), - Err(E::Eof.e()) - ); - } - } - } - } -} - -impl_the_same!(Serialize, DeserializeOwned, serialize, deserialize); -impl_the_same!(Encode, Decode, encode, decode); - -fn the_same_once( - t: T, -) { - let mut buf = Buffer::new(); - the_same_serialize(t.clone(), &mut buf); - the_same_encode(t, &mut buf); -} - -fn the_same(t: T) { - the_same_once(t.clone()); - - let mut buf = Buffer::new(); - - #[cfg(miri)] - const END: usize = 2; - #[cfg(not(miri))] - const END: usize = 65; - for i in 0..END { - let input = vec![t.clone(); i]; - the_same_serialize(input.clone(), &mut buf); - the_same_encode(input, &mut buf); - } -} - -#[test] -fn fuzz1() { - let bytes = &[64]; - assert!(crate::decode::>(bytes).is_err()); - 
assert!(crate::serde::deserialize::>(bytes).is_err()); -} - -#[test] -fn fuzz2() { - let bytes = &[0, 0, 0, 1]; - assert!(crate::decode::>(bytes).is_err()); - assert!(crate::serde::deserialize::>(bytes).is_err()); -} - -#[test] -fn fuzz3() { - use bitvec::prelude::*; - - #[rustfmt::skip] - let bits = bitvec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; - let mut bits2 = BitVec::::new(); - bits2.extend_from_bitslice(&bits); - let bytes = bits2.as_raw_slice(); - - assert!(crate::decode::>(bytes).is_err()); - assert!(crate::serde::deserialize::>(bytes).is_err()); -} - -#[test] -fn test_reddit() { - #[derive(Serialize)] - #[allow(dead_code)] - enum Variant { - Three = 3, - Zero = 0, - Two = 2, - One = 1, - } - - assert_eq!(crate::serde::serialize(&Variant::Three).unwrap().len(), 1); -} - -#[test] -fn test_zst_vec() { - for i in (0..100).step_by(9) { - the_same(vec![(); i]); - } -} - -#[test] -fn test_long_string() { - the_same("abcde".repeat(25)) -} - -#[test] -fn test_array_string() { - use arrayvec::ArrayString; - - // Serialize one field has with serde. - #[derive(Clone, Debug, PartialEq, Encode, Decode, Serialize, Deserialize)] - struct MyStruct1 { - #[bitcode_hint(ascii_lowercase)] - #[bitcode(with_serde)] - inner: ArrayString, - #[bitcode_hint(gamma)] - foo: i32, - } - - for i in 0..=20 { - let short = MyStruct1 { - inner: ArrayString::<20>::from(&"a".repeat(i)).unwrap(), - foo: 5, - }; - the_same_once(short); - } - - // not ascii_lowercase - let short = MyStruct1 { - inner: ArrayString::<5>::from(&"A".repeat(5)).unwrap(), - foo: 5, - }; - the_same_once(short); - - // Serialize whole struct with serde. 
- #[derive(Clone, Debug, PartialEq, Encode, Decode, Serialize, Deserialize)] - #[bitcode_hint(ascii)] - #[bitcode(with_serde)] - struct MyStruct2 { - inner: ArrayString, - } - - let long = MyStruct2 { - inner: ArrayString::<150>::from(&"abcde".repeat(30)).unwrap(), - }; - the_same_once(long); - - // Serialize whole variant with serde. - #[derive(Clone, Debug, PartialEq, Encode, Decode, Serialize, Deserialize)] - enum MyEnum { - #[bitcode(with_serde)] - A(ArrayString), - } - - let medium = MyEnum::A(ArrayString::<25>::from(&"abcde".repeat(5)).unwrap()); - the_same_once(medium); -} - -#[test] -#[cfg_attr(miri, ignore)] -fn test_zst() { - use crate::guard::ZST_LIMIT; - fn is_ok(v: Vec) -> bool { - let ser = crate::serialize(&v).unwrap(); - let a = crate::deserialize::>(&ser).is_ok(); - let b = crate::decode::>(&ser).is_ok(); - assert_eq!(a, b); - b - } - assert!(is_ok(vec![0u8; ZST_LIMIT])); - assert!(is_ok(vec![0u8; ZST_LIMIT])); - assert!(!is_ok(vec![(); ZST_LIMIT + 1])); - assert!(is_ok(vec![0u8; ZST_LIMIT + 1])); -} - -#[test] -fn test_chars() { - #[cfg(not(miri))] - const STEP: usize = char::MAX as usize / 1000; - - #[cfg(miri)] - const STEP: usize = char::MAX as usize / 100; - - let chars = (0..=char::MAX as u32) - .step_by(STEP) - .filter_map(char::from_u32) - .collect::>(); - the_same_once(chars); -} - -#[test] -fn test_char1() { - let c = char::from_u32(11141).unwrap(); - the_same(c) -} - -#[test] -fn test_expected_range() { - #[derive(PartialEq, Debug, Clone, Serialize, Deserialize, Encode, Decode)] - struct LargeU64(#[bitcode_hint(expected_range = "10..1000000000")] u64); - - let mut i = 0; - loop { - the_same_once(LargeU64(i)); - if let Some(new) = i.checked_add(1).and_then(|i| i.checked_mul(2)) { - i = new; - } else { - break; - } - } -} - -#[test] -fn test_weird_tuple() { - let value = (1u8, Option::<()>::None); - println!( - "{} {:?}", - <(u8, Option<()>)>::DECODE_MIN, - crate::encode(&value).unwrap() - ); - the_same(value); -} - -#[test] -fn 
test_gamma_bytes() { - #[derive(Encode, Decode, Serialize, Deserialize, PartialEq, Debug, Clone)] - #[bitcode_hint(gamma)] - struct Bytes(Vec); - - the_same_once(Bytes(vec![0u8; 20])); - the_same_once(Bytes(vec![255u8; 20])); -} - -#[test] -fn test_name_conflict() { - mod decode { - #[allow(unused_imports)] - use musli::Decode; - - #[derive(bitcode::Decode)] - struct Struct { - #[allow(unused)] - field: u64, - } - } - - mod encode { - #[allow(unused_imports)] - use musli::Encode; - - #[derive(bitcode::Encode)] - struct Struct { - #[allow(unused)] - field: u64, - } - } -} - -#[test] -fn test_c_string() { - the_same_once(CString::new(vec![]).unwrap()); - the_same_once(CString::new((1..=255).collect::>()).unwrap()); - - let bytes = vec![1, 2, 3, 255, 0]; - let c_str = CStr::from_bytes_with_nul(&bytes).unwrap(); - let encoded = crate::encode(c_str).unwrap(); - let decoded = crate::decode::(&encoded).unwrap(); - assert_eq!(decoded.as_c_str(), c_str) -} - -#[test] -fn test_numbers_extra() { - macro_rules! test { - ($t:ident) => { - the_same(5 as $t); - the_same($t::MAX - 5); - the_same($t::MAX); - }; - } - - test!(u64); - test!(u128); - - macro_rules! 
test_signed { - ($t:ident) => { - test!($t); - the_same(-5 as $t); - the_same($t::MIN); - the_same($t::MIN + 5); - }; - } - - test_signed!(i64); - test_signed!(i128); -} - -// Everything below this comment was derived from bincode: -// https://github.com/bincode-org/bincode/blob/v1.x/tests/test.rs - -#[test] -fn test_numbers() { - // unsigned positive - the_same(5u8); - the_same(5u16); - the_same(5u32); - the_same(5u64); - the_same(5usize); - // signed positive - the_same(5i8); - the_same(5i16); - the_same(5i32); - the_same(5i64); - the_same(5isize); - // signed negative - the_same(-5i8); - the_same(-5i16); - the_same(-5i32); - the_same(-5i64); - the_same(-5isize); - // floating - the_same(-100f32); - the_same(0f32); - the_same(5f32); - the_same(-100f64); - the_same(5f64); -} - -#[test] -fn test_string() { - the_same("".to_string()); - the_same("a".to_string()); -} - -#[test] -fn test_tuple() { - the_same((1isize,)); - the_same((1isize, 2isize, 3isize)); - the_same((1isize, "foo".to_string(), ())); -} - -#[test] -fn test_basic_struct() { - #[derive(Encode, Decode, Serialize, Deserialize, PartialEq, Debug, Clone)] - struct Easy { - x: isize, - s: String, - y: usize, - } - the_same(Easy { - x: -4, - s: "foo".to_string(), - y: 10, - }); -} - -#[test] -fn test_nested_struct() { - #[derive(Encode, Decode, Serialize, Deserialize, PartialEq, Debug, Clone)] - struct Easy { - x: isize, - s: String, - y: usize, - } - #[derive(Encode, Decode, Serialize, Deserialize, PartialEq, Debug, Clone)] - struct Nest { - f: Easy, - b: usize, - s: Easy, - } - - the_same(Nest { - f: Easy { - x: -1, - s: "foo".to_string(), - y: 20, - }, - b: 100, - s: Easy { - x: -100, - s: "bar".to_string(), - y: 20, - }, - }); -} - -#[test] -fn test_struct_newtype() { - #[derive(Encode, Decode, Serialize, Deserialize, PartialEq, Debug, Clone)] - struct NewtypeStr(usize); - - the_same(NewtypeStr(5)); -} - -#[test] -fn test_struct_tuple() { - #[derive(Encode, Decode, Serialize, Deserialize, PartialEq, 
Debug, Clone)] - struct TubStr(usize, String, f32); - - the_same(TubStr(5, "hello".to_string(), 3.2)); -} - -#[test] -fn test_option() { - the_same(Some(5usize)); - the_same(Some("foo bar".to_string())); - the_same(None::); -} - -#[test] -fn test_enum() { - #[derive(Encode, Decode, Serialize, Deserialize, PartialEq, Debug, Clone)] - enum TestEnum { - NoArg, - OneArg(usize), - Args(usize, usize), - AnotherNoArg, - StructLike { x: usize, y: f32 }, - } - the_same(TestEnum::NoArg); - the_same(TestEnum::OneArg(4)); - the_same(TestEnum::Args(4, 5)); - the_same(TestEnum::AnotherNoArg); - the_same(TestEnum::StructLike { - x: 4, - y: std::f32::consts::PI, - }); - the_same(vec![ - TestEnum::NoArg, - TestEnum::OneArg(5), - TestEnum::AnotherNoArg, - TestEnum::StructLike { x: 4, y: 1.4 }, - ]); -} - -#[test] -fn test_vec() { - let v: Vec = vec![]; - the_same(v); - the_same(vec![1u64]); - the_same(vec![1u64, 2, 3, 4, 5, 6]); -} - -#[test] -fn test_map() { - let mut m = HashMap::new(); - m.insert(4u64, "foo".to_string()); - m.insert(0u64, "bar".to_string()); - the_same(m); -} - -#[test] -fn test_bool() { - the_same(true); - the_same(false); -} - -#[test] -fn test_unicode() { - the_same("å".to_string()); - the_same("aåååååååa".to_string()); -} - -#[test] -fn test_fixed_size_array() { - the_same([24u32; 32]); - the_same([1u64, 2, 3, 4, 5, 6, 7, 8]); - the_same([0u8; 19]); -} - -#[test] -fn expected_range_bug() { - #[derive(Encode, Decode, Serialize, Deserialize, PartialEq, Debug, Clone)] - pub struct UVec2 { - x: u16, - y: u16, - } - - #[derive(Encode, Decode, Serialize, Deserialize, PartialEq, Debug, Clone)] - pub struct Wrapper(#[bitcode_hint(expected_range = "0..31")] UVec2); - - let val = Wrapper(UVec2 { x: 500, y: 512 }); - the_same(val); -} diff --git a/src/u8_char.rs b/src/u8_char.rs new file mode 100644 index 0000000..d768b5d --- /dev/null +++ b/src/u8_char.rs @@ -0,0 +1,43 @@ +use crate::coder::{Buffer, Encoder}; +use crate::derive::Encode; +use 
crate::fast::{PushUnchecked, VecImpl}; +use std::num::NonZeroUsize; + +/// Represents a single byte of a string, unlike u8 which represents an integer. +#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +pub struct U8Char(pub u8); + +// Could derive with bytemuck/derive. +unsafe impl bytemuck::Zeroable for U8Char {} +unsafe impl bytemuck::Pod for U8Char {} + +impl Encode for U8Char { + type Encoder = U8CharEncoder; +} + +#[derive(Debug, Default)] +pub struct U8CharEncoder(VecImpl); + +impl Encoder for U8CharEncoder { + #[inline(always)] + fn as_primitive(&mut self) -> Option<&mut VecImpl> { + Some(&mut self.0) + } + + #[inline(always)] + fn encode(&mut self, &v: &U8Char) { + unsafe { self.0.push_unchecked(v) } + } +} + +impl Buffer for U8CharEncoder { + fn collect_into(&mut self, out: &mut Vec) { + out.extend_from_slice(bytemuck::must_cast_slice(self.0.as_slice())); + self.0.clear(); + } + + fn reserve(&mut self, additional: NonZeroUsize) { + self.0.reserve(additional.get()) + } +} diff --git a/src/word.rs b/src/word.rs deleted file mode 100644 index 88c954f..0000000 --- a/src/word.rs +++ /dev/null @@ -1,5 +0,0 @@ -/// How much data is copied in write_bits/read_bits. -/// Can't be changed to another size without significant code changes. -pub type Word = u64; -pub const WORD_BITS: usize = Word::BITS as usize; -pub const WORD_BYTES: usize = std::mem::size_of::(); diff --git a/src/word_buffer.rs b/src/word_buffer.rs deleted file mode 100644 index 8d7478a..0000000 --- a/src/word_buffer.rs +++ /dev/null @@ -1,595 +0,0 @@ -use crate::buffer::BufferTrait; -use crate::encoding::ByteEncoding; -use crate::nightly::div_ceil; -use crate::read::Read; -use crate::word::*; -use crate::write::Write; -use crate::{Result, E}; -use from_bytes_or_zeroed::FromBytesOrZeroed; -use std::array; -use std::num::NonZeroUsize; - -/// A fast [`Buffer`] that operates on [`Word`]s. 
-#[derive(Debug, Default)] -pub struct WordBuffer { - allocation: Allocation, - read_bytes_buf: Box<[Word]>, -} - -#[derive(Debug, Default)] -struct Allocation { - allocation: Vec, - written_words: usize, -} - -impl Allocation { - fn as_mut_slice(&mut self) -> &mut [Word] { - self.allocation.as_mut_slice() - } - - fn take_box(&mut self) -> Box<[Word]> { - let vec = std::mem::take(&mut self.allocation); - let mut box_ = if vec.capacity() == vec.len() { - vec - } else { - // Must have been created by start_read. We need len and capacity to be equal to make - // into_boxed_slice zero cost. If we zeroed up to capacity we could have a situation - // where reading/writing to same buffer causes the whole capacity to be zeroed each - // write (even if only a small portion of the buffer is used). - vec![] - } - .into_boxed_slice(); - - // Zero all the words that we could have written to. - let written_words = self.written_words.min(box_.len()); - box_[0..written_words].fill(0); - self.written_words = 0; - debug_assert!(box_.iter().all(|&w| w == 0)); - - box_ - } - - fn replace_box(&mut self, box_: Box<[Word]>, written_words: usize) { - self.allocation = box_.into(); - self.written_words = written_words; - } - - fn make_vec(&mut self) -> &mut Vec { - self.written_words = usize::MAX; - &mut self.allocation - } -} - -pub struct WordContext { - input_bytes: usize, -} - -impl WordBuffer { - /// Extra [`Word`]s appended to the end of the input to make deserialization faster. - /// 1 for peek_reserved_bits and another for read_zeros (which calls peek_reserved_bits). - const READ_PADDING: usize = 2; -} - -impl BufferTrait for WordBuffer { - type Writer = WordWriter; - type Reader<'a> = WordReader<'a>; - type Context = WordContext; - - fn capacity(&self) -> usize { - // Subtract the padding of 1 (added by alloc_index_plus_one). 
- self.allocation.allocation.capacity().saturating_sub(1) * WORD_BYTES - } - - fn with_capacity(cap: usize) -> Self { - let mut me = Self::default(); - if cap == 0 { - return me; - } - let mut writer = Self::Writer::default(); - - // Convert len to index by subtracting 1. - Self::Writer::alloc_index_plus_one(&mut writer.words, div_ceil(cap, WORD_BYTES) - 1); - me.allocation.replace_box(writer.words, 0); - me - } - - fn start_write(&mut self) -> Self::Writer { - let words = self.allocation.take_box(); - Self::Writer { words, index: 0 } - } - - fn finish_write(&mut self, mut writer: Self::Writer) -> &[u8] { - // write_zeros doesn't allocate, but it moves index so we allocate up to index at the end. - let index = writer.index / WORD_BITS; - if index >= writer.words.len() { - // TODO could allocate exact amount instead of regular growth strategy. - Self::Writer::alloc_index_plus_one(&mut writer.words, index); - } - - let Self::Writer { words, index } = writer; - let written_words = div_ceil(index, WORD_BITS); - - self.allocation.replace_box(words, written_words); - let written_words = &mut self.allocation.as_mut_slice()[..written_words]; - - // Swap bytes in each word (that was written to) if big endian. - if cfg!(target_endian = "big") { - written_words.iter_mut().for_each(|w| *w = w.swap_bytes()); - } - - let written_bytes = div_ceil(index, u8::BITS as usize); - &bytemuck::cast_slice(written_words)[..written_bytes] - } - - fn start_read<'a>(&'a mut self, bytes: &'a [u8]) -> (Self::Reader<'a>, Self::Context) { - let words = self.allocation.make_vec(); - words.clear(); - - // u8s rounded up to u64s plus 1 u64 padding. - let capacity = div_ceil(bytes.len(), WORD_BYTES) + Self::READ_PADDING; - words.reserve_exact(capacity); - - // Fast hot loop (would be nicer with array_chunks, but that requires nightly). 
- let chunks = bytes.chunks_exact(WORD_BYTES); - let remainder = chunks.remainder(); - words.extend(chunks.map(|chunk| { - let chunk: &[u8; 8] = chunk.try_into().unwrap(); - Word::from_le_bytes(*chunk) - })); - - // Remaining bytes. - if !remainder.is_empty() { - words.push(u64::from_le_bytes(array::from_fn(|i| { - remainder.get(i).copied().unwrap_or_default() - }))); - } - - // Padding so peek_reserved_bits doesn't ever go out of bounds. - words.extend([0; Self::READ_PADDING]); - debug_assert_eq!(words.len(), capacity); - - let reader = WordReader { - inner: WordReaderInner { words, index: 0 }, - read_bytes_buf: &mut self.read_bytes_buf, - }; - let context = WordContext { - input_bytes: bytes.len(), - }; - (reader, context) - } - - fn finish_read(reader: Self::Reader<'_>, context: Self::Context) -> Result<()> { - let read = reader.inner.index; - let bytes_read = div_ceil(read, u8::BITS as usize); - let index = read / WORD_BITS; - let bits_written = read % WORD_BITS; - - if bits_written != 0 && reader.inner.words[index] & !((1 << bits_written) - 1) != 0 { - return Err(E::ExpectedEof.e()); - } - - use std::cmp::Ordering::*; - match bytes_read.cmp(&context.input_bytes) { - Less => Err(E::ExpectedEof.e()), - Equal => Ok(()), - Greater => { - // It is possible that we read more bytes than we have (bytes are rounded up to words). - // We don't check this while deserializing to avoid degrading performance. - Err(E::Eof.e()) - } - } - } -} - -#[derive(Default)] -pub struct WordWriter { - words: Box<[Word]>, - index: usize, -} - -impl WordWriter { - /// Allocates at least `words` of zeroed memory. - fn alloc(words: &mut Box<[Word]>, len: usize) { - let new_cap = len.next_power_of_two().max(16); - - // TODO find a way to use Allocator::grow_zeroed safely (new bytemuck api?). 
- let new = bytemuck::allocation::zeroed_slice_box(new_cap); - - let previous = std::mem::replace(words, new); - words[..previous.len()].copy_from_slice(&previous); - } - - // Allocates up to an `index + 1` in words if a bounds check fails. - // Returns a mutable array of [index, index + 1] to avoid bounds checks near hot code. - #[cold] - fn alloc_index_plus_one(words: &mut Box<[Word]>, index: usize) -> &mut [Word; 2] { - let end = index + 2; - Self::alloc(words, end); - (&mut words[index..end]).try_into().unwrap() - } - - /// Ensures that space for `bytes` is allocated.\ - #[inline(always)] - fn reserve_write_bytes(&mut self, bytes: usize) { - let index = self.index / WORD_BITS + bytes / WORD_BYTES + 1; - if index >= self.words.len() { - Self::alloc_index_plus_one(&mut self.words, index); - } - } - - #[inline(always)] - fn write_bits_inner( - &mut self, - word: Word, - bits: usize, - out_of_bounds: fn(&mut Box<[Word]>, usize) -> &mut [Word; 2], - ) { - debug_assert!(bits <= WORD_BITS); - if bits != WORD_BITS { - debug_assert_eq!(word, word & ((1 << bits) - 1)); - } - - let bit_index = self.index; - self.index += bits; - - let index = bit_index / WORD_BITS; - let bit_remainder = bit_index % WORD_BITS; - - // Only requires 1 branch in hot path. 
- let slice = if let Some(w) = self.words.get_mut(index..index + 2) { - w.try_into().unwrap() - } else { - out_of_bounds(&mut self.words, index) - }; - slice[0] |= word << bit_remainder; - slice[1] = (word >> 1) >> (WORD_BITS - bit_remainder - 1); - } - - #[inline(always)] - fn write_reserved_bits(&mut self, word: Word, bits: usize) { - self.write_bits_inner(word, bits, |_, _| unreachable!()); - } - - fn write_reserved_words(&mut self, src: &[Word]) { - debug_assert!(!src.is_empty()); - - let bit_start = self.index; - let bit_end = self.index + src.len() * WORD_BITS; - self.index = bit_end; - - let start = bit_start / WORD_BITS; - let end = div_ceil(bit_end, WORD_BITS); - - let shl = bit_start % WORD_BITS; - let shr = WORD_BITS - shl; - - if shl == 0 { - self.words[start..end].copy_from_slice(src) - } else { - let after_start = start + 1; - let before_end = end - 1; - - let dst = &mut self.words[after_start..before_end]; - - // Do bounds check outside loop. Makes compiler go brrr - assert!(dst.len() < src.len()); - - for (i, w) in dst.iter_mut().enumerate() { - let a = src[i]; - let b = src[i + 1]; - debug_assert_eq!(*w, 0); - *w = (a >> shr) | (b << shl) - } - - self.words[start] |= src[0] << shl; - debug_assert_eq!(self.words[before_end], 0); - self.words[before_end] = *src.last().unwrap() >> shr - } - } -} - -impl Write for WordWriter { - type Revert = usize; - #[inline(always)] - fn get_revert(&mut self) -> Self::Revert { - self.index - } - #[cold] - fn revert(&mut self, revert: Self::Revert) { - // min with self.words.len() since if writing zeros, the words might not have been allocated. - let start = div_ceil(revert, WORD_BITS).min(self.words.len()); - let end = (div_ceil(self.index, WORD_BITS)).min(self.words.len()); - - // Zero whole words. - self.words[start..end].fill(0); - - // Zero remaining bits. Might not have been allocated if writing zeros. 
- let i = revert / WORD_BITS; - if i < self.words.len() { - let keep_up_to = revert % WORD_BITS; - self.words[i] &= (1 << keep_up_to) - 1; - } - self.index = revert; - } - - #[inline(always)] - fn write_bit(&mut self, v: bool) { - let bit_index = self.index; - self.index += 1; - - let index = bit_index / WORD_BITS; - let bit_remainder = bit_index % WORD_BITS; - - *if let Some(w) = self.words.get_mut(index) { - w - } else { - &mut Self::alloc_index_plus_one(&mut self.words, index)[0] - } |= (v as Word) << bit_remainder; - } - - #[inline(always)] - fn write_bits(&mut self, word: Word, bits: usize) { - self.write_bits_inner(word, bits, Self::alloc_index_plus_one); - } - - #[inline(always)] - fn write_bytes(&mut self, bytes: &[u8]) { - #[inline(always)] - fn write_0_to_8_bytes(me: &mut WordWriter, bytes: &[u8]) { - debug_assert!(bytes.len() <= 8); - me.write_reserved_bits( - u64::from_le_bytes_or_zeroed(bytes), - bytes.len() * u8::BITS as usize, - ); - } - - // Slower for small inputs. Doesn't work on big endian since it bytemucks u64 to bytes. - #[inline(never)] - fn write_many_bytes(me: &mut WordWriter, bytes: &[u8]) { - assert!(!cfg!(target_endian = "big")); - - // TODO look into align_to specification to see if any special cases are required. - let (a, b, c) = bytemuck::pod_align_to::(bytes); - write_0_to_8_bytes(me, a); - me.write_reserved_words(b); - write_0_to_8_bytes(me, c); - } - - if bytes.is_empty() { - return; - } - - self.reserve_write_bytes(bytes.len()); - - // Fast case for short bytes. Both methods are about the same speed at 75 bytes. - // write_many_bytes doesn't work on big endian. - if bytes.len() < 75 || cfg!(target_endian = "big") { - let mut bytes = bytes; - while bytes.len() > 8 { - let b8: &[u8; 8] = bytes[0..8].try_into().unwrap(); - self.write_reserved_bits(Word::from_le_bytes(*b8), WORD_BITS); - bytes = &bytes[8..] 
- } - write_0_to_8_bytes(self, bytes); - } else { - write_many_bytes(self, bytes) - } - } - - #[inline(always)] - fn write_encoded_bytes(&mut self, mut bytes: &[u8]) -> bool { - // TODO could reserve bytes.len() * C::BITS_PER_BYTE. - - while bytes.len() > 8 { - let (bytes8, remaining) = bytes.split_at(8); - let bytes8: &[u8; 8] = bytes8.try_into().unwrap(); - bytes = remaining; - - let word = Word::from_le_bytes(*bytes8); - if !C::validate(word, WORD_BYTES) { - return false; - } - self.write_bits(C::pack(word), WORD_BYTES * C::BITS_PER_BYTE); - } - - let word = Word::from_le_bytes_or_zeroed(bytes); - if !C::validate(word, bytes.len()) { - return false; - } - self.write_bits(C::pack(word), bytes.len() * C::BITS_PER_BYTE); - true - } - - #[inline(always)] - fn write_zeros(&mut self, bits: usize) { - debug_assert!(bits <= WORD_BITS); - self.index += bits; - } -} - -struct WordReaderInner<'a> { - words: &'a [Word], - index: usize, -} - -impl WordReaderInner<'_> { - #[inline(always)] - fn peek_reserved_bits(&self, bits: usize) -> Word { - debug_assert!((1..=WORD_BITS).contains(&bits)); - let bit_index = self.index; - - let index = bit_index / WORD_BITS; - let bit_remainder = bit_index % WORD_BITS; - - let a = self.words[index] >> bit_remainder; - let b = (self.words[index + 1] << 1) << (WORD_BITS - 1 - bit_remainder); - - // Clear bits at end (don't need to do in ser because bits at end are zeroed). - let extra_bits = WORD_BITS - bits; - ((a | b) << extra_bits) >> extra_bits - } - - #[inline(always)] - fn read_reserved_bits(&mut self, bits: usize) -> Word { - let v = self.peek_reserved_bits(bits); - self.index += bits; - v - } - - /// Faster [`Read::reserve_bits`] that can elide bounds checks for `bits` in range `1..=64`. - #[inline(always)] - fn reserve_1_to_64_bits(&self, bits: usize) -> Result<()> { - debug_assert!((1..=WORD_BITS).contains(&bits)); - - let read = self.index / WORD_BITS; - let len = self.words.len(); - if read + 1 >= len { - // TODO hint as unlikely. 
- Err(E::Eof.e()) - } else { - Ok(()) - } - } -} - -pub struct WordReader<'a> { - inner: WordReaderInner<'a>, - read_bytes_buf: &'a mut Box<[Word]>, -} - -impl<'a> Read for WordReader<'a> { - #[inline(always)] - fn advance(&mut self, bits: usize) { - self.inner.index += bits; - } - - #[inline(always)] - fn peek_bits(&mut self) -> Result { - self.inner.reserve_1_to_64_bits(64)?; - Ok(self.inner.peek_reserved_bits(64)) - } - - #[inline(always)] - fn read_bit(&mut self) -> Result { - self.inner.reserve_1_to_64_bits(1)?; - - let bit_index = self.inner.index; - self.inner.index += 1; - - let index = bit_index / WORD_BITS; - let bit_remainder = bit_index % WORD_BITS; - - Ok((self.inner.words[index] & (1 << bit_remainder)) != 0) - } - - #[inline(always)] - fn read_bits(&mut self, bits: usize) -> Result { - self.inner.reserve_1_to_64_bits(bits)?; - Ok(self.inner.read_reserved_bits(bits)) - } - - #[inline(never)] - fn read_bytes(&mut self, len: NonZeroUsize) -> Result<&[u8]> { - // We read the `[u8]` as `[Word]` and then truncate it. - let len = len.get(); - let words_len = (len - 1) / WORD_BYTES + 1; - let src_len = words_len + 1; - - let start = self.inner.index / WORD_BITS; - let src = if let Some(src) = self.inner.words.get(start..start + src_len) { - src - } else { - return Err(E::Eof.e()); - }; - - // Only allocate after src is reserved to prevent memory exhaustion attacks. - let buf = &mut *self.read_bytes_buf; - let dst = if let Some(slice) = buf.get_mut(..words_len) { - slice - } else { - alloc_read_bytes_buf(buf, words_len); - &mut buf[..words_len] - }; - - // If offset is 0 we would shl by 64 which is invalid so we just copy the slice. If shl by - // 64 resulted in 0 we wouldn't need this special case. 
- let offset = self.inner.index % WORD_BITS; - if offset == 0 { - let src = &src[..words_len]; - dst.copy_from_slice(src); - } else { - let shl = WORD_BITS - offset; - let shr = offset; - - for (i, w) in dst.iter_mut().enumerate() { - *w = (src[i] >> shr) | (src[i + 1] << shl); - } - } - self.inner.index += len * u8::BITS as usize; - - // Swap bytes in each word (that was written to) if big endian and bytemuck to bytes. - if cfg!(target_endian = "big") { - dst.iter_mut().for_each(|w| *w = w.swap_bytes()); - } - Ok(&bytemuck::cast_slice(self.read_bytes_buf)[..len]) - } - - #[inline(always)] - fn read_encoded_bytes(&mut self, len: NonZeroUsize) -> Result<&[u8]> { - let len = len.get(); - let whole_words_len = (len - 1) / WORD_BYTES; - let word_len = whole_words_len + 1; - - // Only allocate after reserved to prevent memory exhaustion attacks. - let read = self.inner.index / WORD_BITS + 2 + whole_words_len * C::BITS_PER_BYTE / 8; - if read >= self.inner.words.len() { - return Err(E::Eof.e()); - } - - let buf = &mut *self.read_bytes_buf; - let words = if let Some(slice) = buf.get_mut(..word_len) { - slice - } else { - alloc_read_bytes_buf(buf, word_len); - &mut buf[..word_len] - }; - - let whole_words = &mut words[..whole_words_len]; - for w in whole_words { - *w = C::unpack(self.inner.peek_reserved_bits(WORD_BITS)); - self.inner.index += WORD_BYTES * C::BITS_PER_BYTE; - } - - let remaining_bytes = len - whole_words_len * WORD_BYTES; - debug_assert!((1..=8).contains(&remaining_bytes)); - *words.last_mut().unwrap() = C::unpack(self.inner.peek_reserved_bits(WORD_BITS)); - self.inner.index += remaining_bytes * C::BITS_PER_BYTE; - - // Swap bytes in each word (that was written to) if big endian and bytemuck to bytes. 
- if cfg!(target_endian = "big") { - words.iter_mut().for_each(|w| *w = w.swap_bytes()); - } - Ok(&bytemuck::cast_slice(self.read_bytes_buf)[..len]) - } - - #[inline(always)] - fn reserve_bits(&self, bits: usize) -> Result<()> { - // TODO could make this overestimate remaining bits by a small amount to simplify logic. - let whole_words_len = bits / WORD_BITS; - let words_len = whole_words_len + 1; - - let read = self.inner.index / WORD_BITS + words_len; - if read >= self.inner.words.len() { - // TODO hint as unlikely. - Err(E::Eof.e()) - } else { - Ok(()) - } - } -} - -#[cold] -fn alloc_read_bytes_buf(buf: &mut Box<[Word]>, len: usize) { - let new_cap = len.next_power_of_two().max(16); - *buf = bytemuck::allocation::zeroed_slice_box(new_cap); -} diff --git a/src/write.rs b/src/write.rs deleted file mode 100644 index b3a13d2..0000000 --- a/src/write.rs +++ /dev/null @@ -1,186 +0,0 @@ -use crate::encoding::ByteEncoding; -use crate::word::Word; - -/// Abstracts over writing bits to a buffer. -pub trait Write { - /// Reverts allow reverting writes. Doing so can be expensive so make sure it only happens in - /// the cold path. This is useful for writing/validating data at the same time. - type Revert; - fn get_revert(&mut self) -> Self::Revert; - fn revert(&mut self, revert: Self::Revert); - - /// Writes a bit. If `v` is always `false` use [`Self::write_false`]. - fn write_bit(&mut self, v: bool); - /// Writes up to 64 bits. The index of `word`'s most significant 1 must be < `bits`. - /// `bits` must be in range `0..=64`. - fn write_bits(&mut self, word: Word, bits: usize); - /// Writes `bytes`. - fn write_bytes(&mut self, bytes: &[u8]); - /// Writes `bytes` with a [`ByteEncoding`]. Returns if the bytes are valid according to - /// [`ByteEncoding::validate`]. - fn write_encoded_bytes(&mut self, bytes: &[u8]) -> bool; - /// Writes `false`. Might be faster than `writer.write_bit(false)`. 
- #[inline(always)] - fn write_false(&mut self) { - self.write_zeros(1); - } - /// Writes up to 64 zero bits. Might be faster than `writer.write_bits(0, bits)`. - fn write_zeros(&mut self, bits: usize) { - self.write_bits(0, bits); - } -} - -#[cfg(all(test, not(miri)))] -mod tests { - use super::*; - use crate::bit_buffer::BitBuffer; - use crate::buffer::BufferTrait; - use crate::word_buffer::WordBuffer; - use paste::paste; - use test::{black_box, Bencher}; - - // How many times each benchmark calls the function. - const TIMES: usize = 1000; - - #[bench] - fn bench_vec(b: &mut Bencher) { - let mut vec = vec![]; - b.iter(|| { - let vec = black_box(&mut vec); - vec.clear(); - for _ in 0..TIMES { - vec.push(black_box(0b10101u8)) - } - black_box(vec); - }); - } - - fn bench_write_bit(b: &mut Bencher) { - let mut buf = T::default(); - b.iter(|| { - let buf = black_box(&mut buf); - let mut writer = buf.start_write(); - for _ in 0..TIMES { - writer.write_bit(black_box(true)) - } - buf.finish_write(writer); - }); - } - - fn bench_write_bytes(b: &mut Bencher, bytes: usize) { - let v = vec![123u8; bytes]; - let mut buf = T::default(); - b.iter(|| { - let buf = black_box(&mut buf); - let mut writer = buf.start_write(); - writer.write_bit(true); // Make write_bytes unaligned. - for _ in 0..TIMES { - writer.write_bytes(black_box(v.as_slice())) - } - buf.finish_write(writer); - }); - } - - fn bench_write_bytes_range(b: &mut Bencher, min: usize, max: usize) { - use rand::prelude::*; - - let mut rng = rand_chacha::ChaCha20Rng::from_seed(Default::default()); - let v: Vec> = (0..TIMES) - .map(|_| (0..rng.gen_range(min..=max)).map(|i| i as u8).collect()) - .collect(); - - let mut buf = T::default(); - b.iter(|| { - let buf = black_box(&mut buf); - let mut writer = buf.start_write(); - writer.write_bit(true); // Make write_bytes unaligned. 
- for v in black_box(v.as_slice()) { - writer.write_bytes(v) - } - buf.finish_write(writer); - }); - } - - fn bench_write_bits(b: &mut Bencher, bits: usize) { - let v = Word::MAX >> (Word::BITS as usize - bits); - let mut buf = T::default(); - b.iter(|| { - let buf = black_box(&mut buf); - let mut writer = buf.start_write(); - for _ in 0..TIMES { - writer.write_bits(black_box(v), black_box(bits)) - } - buf.finish_write(writer); - }); - } - - #[bench] - fn bench_word_buffer_write_false(b: &mut Bencher) { - let mut buf = WordBuffer::default(); - b.iter(|| { - let buf = black_box(&mut buf); - let mut writer = buf.start_write(); - for _ in 0..TIMES { - writer.write_false() - } - buf.finish_write(writer); - }); - } - - macro_rules! bench_write_bits { - ($name:ident, $T:ty, $n:literal) => { - paste! { - #[bench] - fn [](b: &mut Bencher) { - bench_write_bits::<$T>(b, $n); - } - } - }; - } - - macro_rules! bench_write_bytes { - ($name:ident, $T:ty, $n:literal) => { - paste! { - #[bench] - fn [](b: &mut Bencher) { - bench_write_bytes::<$T>(b, $n); - } - } - }; - } - - macro_rules! bench_write_bytes_range { - ($name:ident, $T:ty, $min:literal, $max:literal) => { - paste! { - #[bench] - fn [](b: &mut Bencher) { - bench_write_bytes_range::<$T>(b, $min, $max); - } - } - }; - } - - macro_rules! bench_write { - ($name:ident, $T:ty) => { - paste! 
{ - #[bench] - fn [](b: &mut Bencher) { - bench_write_bit::<$T>(b); - } - } - - bench_write_bits!($name, $T, 5); - bench_write_bits!($name, $T, 41); - bench_write_bytes!($name, $T, 1); - bench_write_bytes!($name, $T, 10); - bench_write_bytes!($name, $T, 100); - bench_write_bytes!($name, $T, 1000); - - bench_write_bytes_range!($name, $T, 0, 8); - bench_write_bytes_range!($name, $T, 0, 16); - }; - } - - bench_write!(bit_buffer, BitBuffer); - bench_write!(word_buffer, WordBuffer); -} From 3b00907f8d1342da69818999e1d539db36c9492e Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Fri, 23 Feb 2024 12:18:46 -0800 Subject: [PATCH 02/45] Remove derive macro decode impl and use decode_in_place in array/Option. --- bitcode_derive/src/decode.rs | 22 +++++++++------------- src/coder.rs | 23 +++++++++-------------- src/derive/array.rs | 5 ----- src/derive/impls.rs | 7 ------- src/derive/option.rs | 7 ++++--- src/derive/vec.rs | 2 -- src/ext/arrayvec.rs | 4 +--- src/ext/mod.rs | 1 + 8 files changed, 24 insertions(+), 47 deletions(-) diff --git a/bitcode_derive/src/decode.rs b/bitcode_derive/src/decode.rs index 7b5bdac..622e2fd 100644 --- a/bitcode_derive/src/decode.rs +++ b/bitcode_derive/src/decode.rs @@ -27,11 +27,11 @@ enum Item { } impl Item { - const ALL: [Self; 5] = [ + const ALL: [Self; 4] = [ Self::Type, Self::Default, Self::Populate, - Self::Decode, + // No Self::Decode since it's only used for enum variants, not top level struct/enum. Self::DecodeInPlace, ]; const COUNT: usize = Self::ALL.len(); @@ -58,6 +58,7 @@ impl Item { Self::Populate => quote! { self.#global_field_name.populate(input, __length)?; }, + // Only used by enum variants. Self::Decode => quote! { let #field_name = self.#global_field_name.decode(); }, @@ -167,7 +168,8 @@ impl Item { #inners } } - Self::Decode | Self::DecodeInPlace => { + Self::Decode => unimplemented!(), + Self::DecodeInPlace => { if never { return quote! { // Safety: View::populate will error on length != 0 so decode won't be called. 
@@ -264,7 +266,7 @@ impl Item { struct Output([TokenStream; Item::COUNT]); impl Output { - fn make_ghost(mut self) -> Self { + fn haunt(mut self) -> Self { let type_ = &mut self.0[Item::Type as usize]; if type_.is_empty() { let de = de_lifetime(); @@ -320,7 +322,7 @@ pub fn derive_impl(mut input: DeriveInput) -> Result { } Data::Union(u) => err(&u.union_token, "unions are not supported")?, }) - .make_ghost(); + .haunt(); bounds.apply_to_generics(&mut generics); let input_generics = generics.clone(); @@ -347,7 +349,7 @@ pub fn derive_impl(mut input: DeriveInput) -> Result { // Push de_param after bounding 'de: 'a. let de_param = GenericParam::Lifetime(LifetimeParam::new(de.clone())); - generics.params.push(de_param.clone()); // TODO bound to other lifetimes. + generics.params.push(de_param.clone()); generics .make_where_clause() .predicates @@ -361,8 +363,7 @@ pub fn derive_impl(mut input: DeriveInput) -> Result { generics.params.push(de_param); // Re-add de_param since remove_lifetimes removed it. let (decoder_impl_generics, decoder_generics, decoder_where_clause) = generics.split_for_impl(); - let Output([type_body, default_body, populate_body, decode_body, decode_in_place_body]) = - output; + let Output([type_body, default_body, populate_body, decode_in_place_body]) = output; let decoder_ident = Ident::new(&format!("{ident}Decoder"), Span::call_site()); let decoder_ty = quote! 
{ #decoder_ident #decoder_generics }; let private = private(); @@ -395,11 +396,6 @@ pub fn derive_impl(mut input: DeriveInput) -> Result { } impl #impl_generics #private::Decoder<#de, #input_ty> for #decoder_ty #where_clause { - #[cfg_attr(not(debug_assertions), inline(always))] - fn decode(&mut self) -> #input_ty { - #decode_body - } - #[cfg_attr(not(debug_assertions), inline(always))] fn decode_in_place(&mut self, out: &mut std::mem::MaybeUninit<#input_ty>) { #decode_in_place_body diff --git a/src/coder.rs b/src/coder.rs index 885534c..2bcdd43 100644 --- a/src/coder.rs +++ b/src/coder.rs @@ -53,6 +53,8 @@ pub trait View<'a> { fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()>; } +/// One of [`Decoder::decode`] and [`Decoder::decode_in_place`] must be implemented or calling +/// either one will result in infinite recursion and a stack overflow. pub trait Decoder<'a, T>: View<'a> + Default { /// Returns a `Some(ptr)` to the current element if it can be decoded by copying. #[inline(always)] @@ -69,30 +71,23 @@ pub trait Decoder<'a, T>: View<'a> + Default { } /// Decodes a single value. Can't error since `View::populate` has already validated the input. - fn decode(&mut self) -> T; + /// Prefer decode for primitives (since it's simpler) and decode_in_place for array/struct/tuple. + #[inline(always)] + fn decode(&mut self) -> T { + let mut out = MaybeUninit::uninit(); + self.decode_in_place(&mut out); + unsafe { out.assume_init() } + } /// [`Self::decode`] without redundant copies. Only downside is panics will leak the value. /// The only panics out of our control are Hash/Ord/PartialEq for BinaryHeap/BTreeMap/HashMap. /// E.g. if a user PartialEq panics we will leak some memory which is an acceptable tradeoff. - /// TODO make this required and add default impl for Self::decode. #[inline(always)] fn decode_in_place(&mut self, out: &mut MaybeUninit) { out.write(self.decode()); } } -macro_rules! 
decode_from_in_place { - ($t:ty) => { - #[inline(always)] - fn decode(&mut self) -> $t { - let mut out = std::mem::MaybeUninit::uninit(); - self.decode_in_place(&mut out); - unsafe { out.assume_init() } - } - }; -} -pub(crate) use decode_from_in_place; - #[doc(hidden)] #[macro_export] macro_rules! __private_uninit_field { diff --git a/src/derive/array.rs b/src/derive/array.rs index eccdb9c..40a1b50 100644 --- a/src/derive/array.rs +++ b/src/derive/array.rs @@ -56,11 +56,6 @@ impl<'a, T: Decode<'a>, const N: usize> View<'a> for ArrayDecoder<'a, T, N> { } impl<'a, T: Decode<'a>, const N: usize> Decoder<'a, [T; N]> for ArrayDecoder<'a, T, N> { - #[inline(always)] - fn decode(&mut self) -> [T; N] { - std::array::from_fn(|_| self.0.decode()) - } - #[inline(always)] fn decode_in_place(&mut self, out: &mut MaybeUninit<[T; N]>) { // Safety: Equivalent to nightly MaybeUninit::transpose. diff --git a/src/derive/impls.rs b/src/derive/impls.rs index 98e731d..d36e903 100644 --- a/src/derive/impls.rs +++ b/src/derive/impls.rs @@ -239,13 +239,6 @@ macro_rules! 
impl_tuples { } impl<'a, $($name: Decode<'a>,)*> Decoder<'a, ($($name,)*)> for TupleDecoder<'a, $($name,)*> { - #[inline(always)] - fn decode(&mut self) -> ($($name,)*) { - ( - $(self.$n.decode(),)* - ) - } - #[inline(always)] fn decode_in_place(&mut self, out: &mut MaybeUninit<($($name,)*)>) { $( diff --git a/src/derive/option.rs b/src/derive/option.rs index 368749a..0e1e374 100644 --- a/src/derive/option.rs +++ b/src/derive/option.rs @@ -2,6 +2,7 @@ use crate::coder::{Buffer, Decoder, Encoder, Result, View, MAX_VECTORED_CHUNK}; use crate::derive::variant::{VariantDecoder, VariantEncoder}; use crate::derive::{Decode, Encode}; use crate::fast::{FastArrayVec, PushUnchecked}; +use std::mem::MaybeUninit; use std::num::NonZeroUsize; #[derive(Debug)] @@ -110,11 +111,11 @@ impl<'a, T: Decode<'a>> View<'a> for OptionDecoder<'a, T> { impl<'a, T: Decode<'a>> Decoder<'a, Option> for OptionDecoder<'a, T> { #[inline(always)] - fn decode(&mut self) -> Option { + fn decode_in_place(&mut self, out: &mut MaybeUninit>) { if self.variants.decode() != 0 { - Some(self.some.decode()) + out.write(Some(self.some.decode())); } else { - None + out.write(None); } } } diff --git a/src/derive/vec.rs b/src/derive/vec.rs index 0f773c0..45d619d 100644 --- a/src/derive/vec.rs +++ b/src/derive/vec.rs @@ -290,8 +290,6 @@ impl Encoder> for VecEncoder { } } impl<'a, T: Decode<'a>> Decoder<'a, Vec> for VecDecoder<'a, T> { - crate::coder::decode_from_in_place!(Vec); - #[inline(always)] fn decode_in_place(&mut self, out: &mut MaybeUninit>) { let length = self.lengths.decode(); diff --git a/src/ext/arrayvec.rs b/src/ext/arrayvec.rs index aca2042..e7a2b2a 100644 --- a/src/ext/arrayvec.rs +++ b/src/ext/arrayvec.rs @@ -1,4 +1,4 @@ -use crate::coder::{decode_from_in_place, Decoder, Encoder, Result, View}; +use crate::coder::{Decoder, Encoder, Result, View}; use crate::derive::vec::{unsafe_wild_copy, VecDecoder, VecEncoder}; use crate::derive::{Decode, Encode}; use crate::error::err; @@ -74,7 +74,6 @@ 
impl<'a, const N: usize> View<'a> for ArrayStringDecoder<'a, N> { } } impl<'a, const N: usize> Decoder<'a, ArrayString> for ArrayStringDecoder<'a, N> { - decode_from_in_place!(ArrayString); #[inline(always)] fn decode_in_place(&mut self, out: &mut MaybeUninit>) { let s: &str = self.0.decode(); @@ -152,7 +151,6 @@ impl<'a, T: Decode<'a>, const N: usize> View<'a> for ArrayVecDecoder<'a, T, N> { } } impl<'a, T: Decode<'a>, const N: usize> Decoder<'a, ArrayVec> for ArrayVecDecoder<'a, T, N> { - decode_from_in_place!(ArrayVec); #[inline(always)] fn decode_in_place(&mut self, out: &mut MaybeUninit>) { // Safety: We've ensured self.lengths.max_len() <= N in populate. diff --git a/src/ext/mod.rs b/src/ext/mod.rs index d1e080c..254621f 100644 --- a/src/ext/mod.rs +++ b/src/ext/mod.rs @@ -54,6 +54,7 @@ macro_rules! impl_struct { } } impl<'a> crate::coder::Decoder<'a, $t> for StructDecoder<'a> { + // TODO use decode_in_place instead. #[inline(always)] fn decode(&mut self) -> $t { $t::$new($(self.$f.decode()),+) From f16569429de177adaae5ae6b14801c08d3ced688 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Fri, 23 Feb 2024 12:38:41 -0800 Subject: [PATCH 03/45] Fix typo. --- src/coder.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/coder.rs b/src/coder.rs index 2bcdd43..f492a06 100644 --- a/src/coder.rs +++ b/src/coder.rs @@ -64,7 +64,7 @@ pub trait Decoder<'a, T>: View<'a> + Default { /// Assuming [`Self::as_primitive_ptr`] returns `Some(ptr)`, this advances `ptr` by `n`. /// # Safety - /// All advances and decodes must not pass `Self::populate(_, length)`. + /// All advances and decodes must not pass `self.populate(_, length)`. unsafe fn as_primitive_advance(&mut self, n: usize) { let _ = n; unreachable!(); From 2148646b9e59237bd987e83e03f47b1e43b2ec5c Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Fri, 23 Feb 2024 13:50:28 -0800 Subject: [PATCH 04/45] Implement Result. 
--- fuzz/fuzz_targets/fuzz.rs | 3 +- src/derive/impls.rs | 7 +++ src/derive/mod.rs | 1 + src/derive/option.rs | 2 +- src/derive/result.rs | 103 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 114 insertions(+), 2 deletions(-) create mode 100644 src/derive/result.rs diff --git a/fuzz/fuzz_targets/fuzz.rs b/fuzz/fuzz_targets/fuzz.rs index 2f91781..2a936bc 100644 --- a/fuzz/fuzz_targets/fuzz.rs +++ b/fuzz/fuzz_targets/fuzz.rs @@ -74,7 +74,8 @@ fuzz_target!(|data: &[u8]| { Vec<$typ>, HashMap, ArrayVec<$typ, 0>, - ArrayVec<$typ, 5> + ArrayVec<$typ, 5>, + Result<$typ, u32> ); } #[allow(unused)] diff --git a/src/derive/impls.rs b/src/derive/impls.rs index d36e903..d35c440 100644 --- a/src/derive/impls.rs +++ b/src/derive/impls.rs @@ -4,6 +4,7 @@ use crate::derive::array::{ArrayDecoder, ArrayEncoder}; use crate::derive::empty::EmptyCoder; use crate::derive::map::{MapDecoder, MapEncoder}; use crate::derive::option::{OptionDecoder, OptionEncoder}; +use crate::derive::result::{ResultDecoder, ResultEncoder}; use crate::derive::smart_ptr::{DerefEncoder, FromDecoder}; use crate::derive::vec::{VecDecoder, VecEncoder}; use crate::derive::{Decode, Encode}; @@ -160,6 +161,12 @@ impl<'a, K: Decode<'a> + Eq + Hash, V: Decode<'a>, S: BuildHasher + Default> Dec type Decoder = MapDecoder<'a, K, V>; } +impl Encode for std::result::Result { + type Encoder = ResultEncoder; +} +impl<'a, T: Decode<'a>, E: Decode<'a>> Decode<'a> for std::result::Result { + type Decoder = ResultDecoder<'a, T, E>; +} impl Encode for PhantomData { type Encoder = EmptyCoder; } diff --git a/src/derive/mod.rs b/src/derive/mod.rs index d5167bd..b04289a 100644 --- a/src/derive/mod.rs +++ b/src/derive/mod.rs @@ -8,6 +8,7 @@ mod empty; mod impls; mod map; mod option; +mod result; mod smart_ptr; mod variant; pub(crate) mod vec; diff --git a/src/derive/option.rs b/src/derive/option.rs index 0e1e374..4546f28 100644 --- a/src/derive/option.rs +++ b/src/derive/option.rs @@ -38,7 +38,7 @@ impl Encoder> for 
OptionEncoder { // Types with many vectorized encoders benefit from a &[&T] since encode_vectorized is still // faster even with the extra indirection. TODO vectored encoder count >= 8 instead of size_of. if std::mem::size_of::() >= 64 { - let mut uninit = std::mem::MaybeUninit::uninit(); + let mut uninit = MaybeUninit::uninit(); let mut refs = FastArrayVec::<_, MAX_VECTORED_CHUNK>::new(&mut uninit); for t in i { diff --git a/src/derive/result.rs b/src/derive/result.rs new file mode 100644 index 0000000..842acaa --- /dev/null +++ b/src/derive/result.rs @@ -0,0 +1,103 @@ +use crate::coder::{Buffer, Decoder, Encoder, View}; +use crate::derive::variant::{VariantDecoder, VariantEncoder}; +use crate::derive::{Decode, Encode}; +use crate::error::Error; +use std::mem::MaybeUninit; +use std::num::NonZeroUsize; + +#[derive(Debug)] +pub struct ResultEncoder { + variants: VariantEncoder<2>, + ok: T::Encoder, + err: E::Encoder, +} + +// Can't derive since it would bound T + E: Default. +impl Default for ResultEncoder { + fn default() -> Self { + Self { + variants: Default::default(), + ok: Default::default(), + err: Default::default(), + } + } +} + +impl Encoder> for ResultEncoder { + #[inline(always)] + fn encode(&mut self, t: &Result) { + self.variants.encode(&(t.is_err() as u8)); + match t { + Ok(t) => { + self.ok.reserve(NonZeroUsize::new(1).unwrap()); + self.ok.encode(t); + } + Err(t) => { + self.err.reserve(NonZeroUsize::new(1).unwrap()); + self.err.encode(t); + } + } + } + // TODO implement encode_vectored if we can avoid lots of code duplication with OptionEncoder. +} + +impl Buffer for ResultEncoder { + fn collect_into(&mut self, out: &mut Vec) { + self.variants.collect_into(out); + self.ok.collect_into(out); + self.err.collect_into(out); + } + + fn reserve(&mut self, additional: NonZeroUsize) { + self.variants.reserve(additional); + // We don't know how many are Ok or Err, so we can't reserve more. 
+ } +} + +#[derive(Debug)] +pub struct ResultDecoder<'a, T: Decode<'a>, E: Decode<'a>> { + variants: VariantDecoder<'a, 2, false>, + ok: T::Decoder, + err: E::Decoder, +} + +// Can't derive since it would bound T: Default. +impl<'a, T: Decode<'a>, E: Decode<'a>> Default for ResultDecoder<'a, T, E> { + fn default() -> Self { + Self { + variants: Default::default(), + ok: Default::default(), + err: Default::default(), + } + } +} + +impl<'a, T: Decode<'a>, E: Decode<'a>> View<'a> for ResultDecoder<'a, T, E> { + fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<(), Error> { + self.variants.populate(input, length)?; + self.ok.populate(input, self.variants.length(0))?; + self.err.populate(input, self.variants.length(1)) + } +} + +impl<'a, T: Decode<'a>, E: Decode<'a>> Decoder<'a, Result> for ResultDecoder<'a, T, E> { + #[inline(always)] + fn decode_in_place(&mut self, out: &mut MaybeUninit>) { + if self.variants.decode() == 0 { + out.write(Ok(self.ok.decode())); + } else { + out.write(Err(self.err.decode())); + } + } +} + +#[cfg(test)] +mod tests { + fn bench_data() -> Vec> { + crate::random_data::<(bool, u32, u8)>(1000) + .into_iter() + .map(|(is_ok, ok, err)| if is_ok { Ok(ok) } else { Err(err) }) + .collect() + } + crate::bench_encode_decode!(result_vec: Vec<_>); +} From eb811da47f2d48e88191827ff8ac0f0c59bdfb6f Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Fri, 23 Feb 2024 14:52:15 -0800 Subject: [PATCH 05/45] Optimize OptionEncoder::encode_vectored. --- src/coder.rs | 6 +++++- src/derive/option.rs | 34 +++++++++++++++++++++------------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/src/coder.rs b/src/coder.rs index f492a06..0c1f3c8 100644 --- a/src/coder.rs +++ b/src/coder.rs @@ -37,7 +37,11 @@ pub trait Encoder: Buffer + Default { fn encode(&mut self, t: &T); /// Calls [`Self::encode`] once for every item in `i`. Only use this with **FAST** iterators. 
- // #[inline(always)] + /// # Safety + /// `i` must have an accurate `i.size_hint().1.unwrap()` that != 0 and is <= [`MAX_VECTORED_CHUNK`]. + /// Currently, the non-map iterators that uphold these requirements are: + /// - vec.rs + /// - option.rs fn encode_vectored<'a>(&mut self, i: impl Iterator + Clone) where T: 'a, diff --git a/src/derive/option.rs b/src/derive/option.rs index 4546f28..7365a9e 100644 --- a/src/derive/option.rs +++ b/src/derive/option.rs @@ -44,7 +44,7 @@ impl Encoder> for OptionEncoder { for t in i { self.variants.encode(&(t.is_some() as u8)); if let Some(t) = t { - // Safety: Even if all `Some` won't write more than MAX_VECTORED_CHUNK elements. + // Safety: encode_vectored guarantees less than `MAX_VECTORED_CHUNK` items. unsafe { refs.push_unchecked(t) }; } } @@ -56,19 +56,18 @@ impl Encoder> for OptionEncoder { self.some.reserve(some_count); self.some.encode_vectored(refs.iter().copied()); } else { - let mut some_count = 0; - for t in i.clone() { - let is_some = t.is_some() as u8; - some_count += is_some as usize; - self.variants.encode(&is_some); - } + // Safety: encode_vectored guarantees `i.size_hint().1.unwrap() != 0`. + let size_hint = + unsafe { NonZeroUsize::new(i.size_hint().1.unwrap()).unwrap_unchecked() }; + // size_of::() is small, so we can just assume all elements are Some. + // This will waste a maximum of `MAX_VECTORED_CHUNK * size_of::()` bytes. 
+ self.some.reserve(size_hint); - let Some(some_sum) = NonZeroUsize::new(some_count) else { - return; - }; - self.some.reserve(some_sum); - for t in i.flatten() { - self.some.encode(t); + for option in i { + self.variants.encode(&(option.is_some() as u8)); + if let Some(t) = option { + self.some.encode(t); + } } } } @@ -128,3 +127,12 @@ mod tests { } crate::bench_encode_decode!(option_vec: Vec<_>); } + +#[cfg(test)] +mod tests2 { + #[rustfmt::skip] + fn bench_data() -> Vec> { + crate::random_data(1000) + } + crate::bench_encode_decode!(option_u16_vec: Vec<_>); +} From e9f63b5783b33c6ff4706eb83882c6907bda2ec0 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Fri, 23 Feb 2024 15:18:21 -0800 Subject: [PATCH 06/45] Fix doc warning on comment. --- src/coder.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/coder.rs b/src/coder.rs index 0c1f3c8..79f7f18 100644 --- a/src/coder.rs +++ b/src/coder.rs @@ -38,7 +38,7 @@ pub trait Encoder: Buffer + Default { /// Calls [`Self::encode`] once for every item in `i`. Only use this with **FAST** iterators. /// # Safety - /// `i` must have an accurate `i.size_hint().1.unwrap()` that != 0 and is <= [`MAX_VECTORED_CHUNK`]. + /// `i` must have an accurate `i.size_hint().1.unwrap()` that != 0 and is <= `MAX_VECTORED_CHUNK`. /// Currently, the non-map iterators that uphold these requirements are: /// - vec.rs /// - option.rs From 12430a5ffdec15e4d596d4dcaef4a08173341a0b Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Sun, 25 Feb 2024 12:24:59 -0800 Subject: [PATCH 07/45] Compile on big endian platforms. --- src/int.rs | 4 +-- src/length.rs | 14 ++++----- src/lib.rs | 4 --- src/pack.rs | 24 +++++++++++----- src/pack_ints.rs | 75 +++++++++++++++++++++++++++--------------------- 5 files changed, 67 insertions(+), 54 deletions(-) diff --git a/src/int.rs b/src/int.rs index 86584bb..31e0bb2 100644 --- a/src/int.rs +++ b/src/int.rs @@ -21,7 +21,6 @@ impl Encoder

for IntEncoder { #[inline(always)] fn encode(&mut self, p: &P) { - // TODO swap byte order if big endian. let t = bytemuck::must_cast(*p); unsafe { self.0.push_unchecked(t) }; } @@ -39,7 +38,7 @@ impl Buffer for IntEncoder { } #[derive(Debug, Default)] -pub struct IntDecoder<'a, T: Int>(CowSlice<'a, T::Ule>); +pub struct IntDecoder<'a, T: Int>(CowSlice<'a, T::Une>); impl<'a, T: Int> IntDecoder<'a, T> { // For CheckedIntDecoder. @@ -62,7 +61,6 @@ impl<'a, T: Int, P: Pod> Decoder<'a, P> for IntDecoder<'a, T> { #[inline(always)] fn decode(&mut self) -> P { let v = unsafe { self.0.mut_slice().next_unchecked() }; - // TODO swap byte order if big endian. bytemuck::must_cast(v) } } diff --git a/src/length.rs b/src/length.rs index 2654f2d..3c9e4a4 100644 --- a/src/length.rs +++ b/src/length.rs @@ -8,7 +8,7 @@ use std::num::NonZeroUsize; #[derive(Debug, Default)] pub struct LengthEncoder { small: VecImpl, - large: Vec, // Not a FastVec because capacity isn't known. + large: Vec, // TODO IntEncoder (handles endian and uses smaller integers). } impl Encoder for LengthEncoder { @@ -19,16 +19,16 @@ impl Encoder for LengthEncoder { if v < 255 { *end_ptr = v as u8; } else { + #[cold] #[inline(never)] - #[cold] // TODO cold or only inline(never)? unsafe fn encode_slow(end_ptr: *mut u8, large: &mut Vec, v: usize) { *end_ptr = 255; // Swap bytes if big endian, so we can cast large to little endian &[u8]. - #[cfg(target_endian = "little")] - let v = v as u64; - #[cfg(target_endian = "big")] - let v = (v as u64).swap_bytes(); + let mut v = v as u64; + if cfg!(target_endian = "big") { + v = v.swap_bytes(); + } large.push(v); } encode_slow(end_ptr, &mut self.large, v); @@ -125,7 +125,7 @@ impl Buffer for LengthEncoder { #[derive(Debug, Default)] pub struct LengthDecoder<'a> { small: CowSlice<'a, u8>, - large: SliceImpl<'a, [u8; 8]>, + large: SliceImpl<'a, [u8; 8]>, // TODO IntDecoder. 
sum: usize, } diff --git a/src/lib.rs b/src/lib.rs index 4032d1a..aa0675c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,10 +9,6 @@ extern crate self as bitcode; #[cfg(test)] extern crate test; -// Missing many calls to swap_bytes throughout the codebase. -#[cfg(target_endian = "big")] -compile_error!("big endian is not yet supported"); - mod bool; mod coder; mod consume; diff --git a/src/pack.rs b/src/pack.rs index c57af6c..d381bf4 100644 --- a/src/pack.rs +++ b/src/pack.rs @@ -374,10 +374,15 @@ fn pack_arithmetic(bytes: &[u8], out: &mut Vec) { unsafe { packed.get_unchecked_mut(i).write( if FACTOR == 2 && cfg!(all(target_feature = "bmi2", not(miri))) { - // Could use on any pow2 FACTOR, but only 2 is faster (target-cpu=native). - let chunk = (bytes.as_ptr() as *const u8 as *const [u8; 8]).add(i); - let chunk = u64::from_le_bytes(*chunk); - std::arch::x86_64::_pext_u64(chunk, 0x0101010101010101) as u8 + #[cfg(not(target_feature = "bmi2"))] + unreachable!(); + #[cfg(target_feature = "bmi2")] + { + // Could use on any pow2 FACTOR, but only 2 is faster (target-cpu=native). + let chunk = (bytes.as_ptr() as *const u8 as *const [u8; 8]).add(i); + let chunk = u64::from_le_bytes(*chunk); + std::arch::x86_64::_pext_u64(chunk, 0x0101010101010101) as u8 + } } else { let mut acc = 0; for byte_index in 0..divisor { @@ -422,9 +427,14 @@ fn unpack_arithmetic( unsafe { let mut packed = *packed.get_unchecked(i); if FACTOR == 2 && cfg!(all(target_feature = "bmi2", not(miri))) { - // Could use on any pow2 FACTOR, but only 2 is faster (target-cpu=native). - let chunk = std::arch::x86_64::_pdep_u64(packed as u64, 0x0101010101010101); - *(unpacked.as_mut_ptr() as *mut [u8; 8]).add(i) = chunk.to_le_bytes(); + #[cfg(not(target_feature = "bmi2"))] + unreachable!(); + #[cfg(target_feature = "bmi2")] + { + // Could use on any pow2 FACTOR, but only 2 is faster (target-cpu=native). 
+ let chunk = std::arch::x86_64::_pdep_u64(packed as u64, 0x0101010101010101); + *(unpacked.as_mut_ptr() as *mut [u8; 8]).add(i) = chunk.to_le_bytes(); + } } else { for byte in unpacked.get_unchecked_mut(i * divisor..i * divisor + divisor) { byte.write(packed % FACTOR as u8); diff --git a/src/pack_ints.rs b/src/pack_ints.rs index 41b11bd..92017af 100644 --- a/src/pack_ints.rs +++ b/src/pack_ints.rs @@ -56,49 +56,48 @@ impl Packing { pub trait Int: Copy + Default + Into + Ord + Pod + Sized + std::ops::Sub + std::ops::SubAssign { - type Ule: Pod + Default; // Unaligned little endian. + // Unaligned native endian. TODO could be aligned on big endian since we always have to copy. + type Une: Pod + Default; const MIN: Self; const MAX: Self; - fn read(input: &mut &[u8]) -> Result; + fn read(input: &mut &[u8]) -> Result; fn write(v: Self, out: &mut Vec); - fn wrapping_add(lhs: Self::Ule, rhs: Self::Ule) -> Self::Ule; + fn wrapping_add(self, rhs: Self::Une) -> Self::Une; #[cfg(test)] - fn from_unaligned(unaligned: Self::Ule) -> Self; + fn from_unaligned(unaligned: Self::Une) -> Self; fn pack128(v: &[Self], out: &mut Vec); fn pack64(v: &[Self], out: &mut Vec); fn pack32(v: &[Self], out: &mut Vec); fn pack16(v: &[Self], out: &mut Vec); fn pack8(v: &mut [Self], out: &mut Vec); - fn unpack128<'a>(v: &'a [[u8; 16]], out: &mut CowSlice<'a, Self::Ule>) -> Result<()>; - fn unpack64<'a>(v: &'a [[u8; 8]], out: &mut CowSlice<'a, Self::Ule>) -> Result<()>; - fn unpack32<'a>(v: &'a [[u8; 4]], out: &mut CowSlice<'a, Self::Ule>) -> Result<()>; - fn unpack16<'a>(v: &'a [[u8; 2]], out: &mut CowSlice<'a, Self::Ule>) -> Result<()>; + fn unpack128<'a>(v: &'a [[u8; 16]], out: &mut CowSlice<'a, Self::Une>) -> Result<()>; + fn unpack64<'a>(v: &'a [[u8; 8]], out: &mut CowSlice<'a, Self::Une>) -> Result<()>; + fn unpack32<'a>(v: &'a [[u8; 4]], out: &mut CowSlice<'a, Self::Une>) -> Result<()>; + fn unpack16<'a>(v: &'a [[u8; 2]], out: &mut CowSlice<'a, Self::Une>) -> Result<()>; fn unpack8<'a>( 
input: &mut &'a [u8], length: usize, - out: &mut CowSlice<'a, Self::Ule>, + out: &mut CowSlice<'a, Self::Une>, ) -> Result<()>; } macro_rules! impl_simple { () => { - type Ule = [u8; std::mem::size_of::()]; + type Une = [u8; std::mem::size_of::()]; const MIN: Self = Self::MIN; const MAX: Self = Self::MAX; - fn read(input: &mut &[u8]) -> Result { - Ok(consume_byte_arrays(input, 1)?[0]) + fn read(input: &mut &[u8]) -> Result { + Ok(Self::from_le_bytes(consume_byte_arrays(input, 1)?[0])) } fn write(v: Self, out: &mut Vec) { out.extend_from_slice(&v.to_le_bytes()); } - fn wrapping_add(lhs: Self::Ule, rhs: Self::Ule) -> Self::Ule { - Self::from_le_bytes(lhs) - .wrapping_add(Self::from_le_bytes(rhs)) - .to_le_bytes() + fn wrapping_add(self, rhs: Self::Une) -> Self::Une { + self.wrapping_add(Self::from_ne_bytes(rhs)).to_ne_bytes() } #[cfg(test)] - fn from_unaligned(unaligned: Self::Ule) -> Self { - Self::from_le_bytes(unaligned) + fn from_unaligned(unaligned: Self::Une) -> Self { + Self::from_ne_bytes(unaligned) } }; } @@ -107,7 +106,7 @@ macro_rules! impl_unreachable { fn $pack(_: &[Self], _: &mut Vec) { unimplemented!(); } - fn $unpack<'a>(_: &'a [<$t as Int>::Ule], _: &mut CowSlice<'a, Self::Ule>) -> Result<()> { + fn $unpack<'a>(_: &'a [<$t as Int>::Une], _: &mut CowSlice<'a, Self::Une>) -> Result<()> { invalid_packing() } }; @@ -115,10 +114,21 @@ macro_rules! impl_unreachable { macro_rules! impl_self { ($pack:ident, $unpack:ident) => { fn $pack(v: &[Self], out: &mut Vec) { - out.extend_from_slice(bytemuck::cast_slice(&v)) // TODO big endian swap bytes. + // If we're little endian we can copy directly because we encode in little endian. 
+ if cfg!(target_endian = "little") { + out.extend_from_slice(bytemuck::cast_slice(&v)); + } else { + out.extend(v.iter().flat_map(|&v| v.to_le_bytes())); + } } - fn $unpack<'a>(v: &'a [Self::Ule], out: &mut CowSlice<'a, Self::Ule>) -> Result<()> { - out.set_borrowed(v); + fn $unpack<'a>(v: &'a [Self::Une], out: &mut CowSlice<'a, Self::Une>) -> Result<()> { + // If we're little endian we can borrow the input since we encode in little endian. + if cfg!(target_endian = "little") { + out.set_borrowed(v); + } else { + out.set_owned() + .extend(v.iter().map(|&v| Self::from_le_bytes(v).to_ne_bytes())); + } Ok(()) } }; @@ -128,11 +138,10 @@ macro_rules! impl_smaller { fn $pack(v: &[Self], out: &mut Vec) { out.extend(v.iter().flat_map(|&v| (v as $t).to_le_bytes())) } - fn $unpack<'a>(v: &'a [<$t as Int>::Ule], out: &mut CowSlice<'a, Self::Ule>) -> Result<()> { - let mut set_owned = out.set_owned(); - set_owned.extend( + fn $unpack<'a>(v: &'a [<$t as Int>::Une], out: &mut CowSlice<'a, Self::Une>) -> Result<()> { + out.set_owned().extend( v.iter() - .map(|&v| (<$t>::from_le_bytes(v) as Self).to_le_bytes()), + .map(|&v| (<$t>::from_le_bytes(v) as Self).to_ne_bytes()), ); Ok(()) } @@ -159,7 +168,7 @@ macro_rules! impl_u8 { pack_bytes(bytes, out); }) } - fn unpack8(input: &mut &[u8], length: usize, out: &mut CowSlice) -> Result<()> { + fn unpack8(input: &mut &[u8], length: usize, out: &mut CowSlice) -> Result<()> { with_scratch(|allocation| { // unpack_bytes might not result in a copy, but if it does we want to avoid an allocation. let mut bytes = CowSlice::with_allocation(std::mem::take(allocation)); @@ -167,7 +176,7 @@ macro_rules! impl_u8 { // Safety: unpack_bytes ensures bytes has length of `length`. 
let slice = unsafe { bytes.as_slice(length) }; out.set_owned() - .extend(slice.iter().map(|&v| (v as Self).to_le_bytes())); + .extend(slice.iter().map(|&v| (v as Self).to_ne_bytes())); *allocation = bytes.into_allocation(); Ok(()) }) @@ -234,7 +243,7 @@ fn minmax(v: &[T]) -> (T, T) { (min, max) } -/// Like [`pack_bytes`] but for larger integers. +/// Like [`pack_bytes`] but for larger integers. Handles endian conversion. pub fn pack_ints(ints: &mut [T], out: &mut Vec) { // Passes through u8s and length <= 1 since they can't be compressed. let p = if std::mem::size_of::() == 1 || ints.len() <= 1 { @@ -281,11 +290,11 @@ pub fn pack_ints(ints: &mut [T], out: &mut Vec) { } } -/// Opposite of [`pack_ints`]. Unpacks into `T::Ule` aka unaligned little endian. +/// Opposite of [`pack_ints`]. Unpacks into `T::Une` aka unaligned native endian. pub fn unpack_ints<'a, T: Int>( input: &mut &'a [u8], length: usize, - out: &mut CowSlice<'a, T::Ule>, + out: &mut CowSlice<'a, T::Une>, ) -> Result<()> { // Passes through u8s and length <= 1 since they can't be compressed. let (p, min) = if std::mem::size_of::() == 1 || length <= 1 { @@ -306,7 +315,7 @@ pub fn unpack_ints<'a, T: Int>( // Has to be owned to have min. out.mut_owned(|out| { for v in out { - *v = T::wrapping_add(*v, min); + *v = min.wrapping_add(*v); } }) } @@ -384,7 +393,7 @@ mod tests { fn bench_unpack_ints(b: &mut Bencher, src: &[T]) { let mut packed = vec![]; pack_ints(&mut src.to_vec(), &mut packed); - let mut out = CowSlice::with_allocation(Vec::::with_capacity(src.len())); + let mut out = CowSlice::with_allocation(Vec::::with_capacity(src.len())); b.iter(|| { let length = src.len(); unpack_ints::( From 8c2e5580d4e3e4880079f83abb5c96dacf0c3f38 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Sun, 25 Feb 2024 12:38:49 -0800 Subject: [PATCH 08/45] Fix compile on 32 bit x86. 
--- src/pack.rs | 50 +++++++++++++++++++++++++++----------------------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/src/pack.rs b/src/pack.rs index d381bf4..e24161a 100644 --- a/src/pack.rs +++ b/src/pack.rs @@ -358,6 +358,12 @@ fn factor_to_divisor() -> usize { } } +const BMI2: bool = cfg!(all( + target_arch = "x86_64", + target_feature = "bmi2", + not(miri) +)); + /// Packs multiple bytes into one. All the bytes must be < `FACTOR`. /// Factors 2,4,16 are bit packing. Factors 3,6 are arithmetic coding. fn pack_arithmetic(bytes: &[u8], out: &mut Vec) { @@ -372,26 +378,24 @@ fn pack_arithmetic(bytes: &[u8], out: &mut Vec) { for i in 0..floor { unsafe { - packed.get_unchecked_mut(i).write( - if FACTOR == 2 && cfg!(all(target_feature = "bmi2", not(miri))) { - #[cfg(not(target_feature = "bmi2"))] - unreachable!(); - #[cfg(target_feature = "bmi2")] - { - // Could use on any pow2 FACTOR, but only 2 is faster (target-cpu=native). - let chunk = (bytes.as_ptr() as *const u8 as *const [u8; 8]).add(i); - let chunk = u64::from_le_bytes(*chunk); - std::arch::x86_64::_pext_u64(chunk, 0x0101010101010101) as u8 - } - } else { - let mut acc = 0; - for byte_index in 0..divisor { - let byte = *bytes.get_unchecked(i * divisor + byte_index); - acc += byte * (FACTOR as u8).pow(byte_index as u32); - } - acc - }, - ); + packed.get_unchecked_mut(i).write(if FACTOR == 2 && BMI2 { + #[cfg(not(all(target_arch = "x86_64", target_feature = "bmi2")))] + unreachable!(); + #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))] + { + // Could use on any pow2 FACTOR, but only 2 is faster (target-cpu=native). 
+ let chunk = (bytes.as_ptr() as *const u8 as *const [u8; 8]).add(i); + let chunk = u64::from_le_bytes(*chunk); + std::arch::x86_64::_pext_u64(chunk, 0x0101010101010101) as u8 + } + } else { + let mut acc = 0; + for byte_index in 0..divisor { + let byte = *bytes.get_unchecked(i * divisor + byte_index); + acc += byte * (FACTOR as u8).pow(byte_index as u32); + } + acc + }); } } if floor < ceil { @@ -426,10 +430,10 @@ fn unpack_arithmetic( for i in 0..floor { unsafe { let mut packed = *packed.get_unchecked(i); - if FACTOR == 2 && cfg!(all(target_feature = "bmi2", not(miri))) { - #[cfg(not(target_feature = "bmi2"))] + if FACTOR == 2 && BMI2 { + #[cfg(not(all(target_arch = "x86_64", target_feature = "bmi2")))] unreachable!(); - #[cfg(target_feature = "bmi2")] + #[cfg(all(target_arch = "x86_64", target_feature = "bmi2"))] { // Could use on any pow2 FACTOR, but only 2 is faster (target-cpu=native). let chunk = std::arch::x86_64::_pdep_u64(packed as u64, 0x0101010101010101); From 5bdc22ba943d0ba8de092a763327b8167656611f Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Sun, 25 Feb 2024 16:50:31 -0800 Subject: [PATCH 09/45] Implement usize, use IntEncoder in LengthEncoder. --- fuzz/fuzz_targets/fuzz.rs | 4 +- src/derive/impls.rs | 8 +- src/derive/mod.rs | 2 + src/error.rs | 1 + src/fast.rs | 7 +- src/int.rs | 4 +- src/length.rs | 40 ++++---- src/pack_ints.rs | 198 +++++++++++++++++++++++++++++++------- 8 files changed, 195 insertions(+), 69 deletions(-) diff --git a/fuzz/fuzz_targets/fuzz.rs b/fuzz/fuzz_targets/fuzz.rs index 2a936bc..f4285fd 100644 --- a/fuzz/fuzz_targets/fuzz.rs +++ b/fuzz/fuzz_targets/fuzz.rs @@ -160,8 +160,8 @@ fuzz_target!(|data: &[u8]| { i64, u128, i128, - // usize, - // isize, + usize, + isize, BitsEqualF32, BitsEqualF64, Vec, diff --git a/src/derive/impls.rs b/src/derive/impls.rs index d35c440..d19e289 100644 --- a/src/derive/impls.rs +++ b/src/derive/impls.rs @@ -44,8 +44,8 @@ macro_rules! 
impl_int { )+ } } -impl_int!(u8 => u8, u16 => u16, u32 => u32, u64 => u64, u128 => u128); -impl_int!(i8 => u8, i16 => u16, i32 => u32, i64 => u64, i128 => u128); +impl_int!(u8 => u8, u16 => u16, u32 => u32, u64 => u64, u128 => u128, usize => usize); +impl_int!(i8 => u8, i16 => u16, i32 => u32, i64 => u64, i128 => u128, isize => usize); impl_int!(f64 => u64); // Totally an int... macro_rules! impl_checked_int { @@ -60,8 +60,8 @@ macro_rules! impl_checked_int { )+ } } -impl_checked_int!(NonZeroU8 => u8, NonZeroU16 => u16, NonZeroU32 => u32, NonZeroU64 => u64, NonZeroU128 => u128); -impl_checked_int!(NonZeroI8 => u8, NonZeroI16 => u16, NonZeroI32 => u32, NonZeroI64 => u64, NonZeroI128 => u128); +impl_checked_int!(NonZeroU8 => u8, NonZeroU16 => u16, NonZeroU32 => u32, NonZeroU64 => u64, NonZeroU128 => u128, NonZeroUsize => usize); +impl_checked_int!(NonZeroI8 => u8, NonZeroI16 => u16, NonZeroI32 => u32, NonZeroI64 => u64, NonZeroI128 => u128, NonZeroIsize => usize); impl_checked_int!(char => u32); macro_rules! impl_t { diff --git a/src/derive/mod.rs b/src/derive/mod.rs index b04289a..787c807 100644 --- a/src/derive/mod.rs +++ b/src/derive/mod.rs @@ -159,6 +159,8 @@ mod tests { test!(("abc", "123"), (&str, &str)); test!(Vec::>::new(), Vec>); test!(vec![None, Some(1), None], Vec>); + test!((0usize, 1isize), (usize, isize)); + test!(vec![true; 255], Vec); } #[derive(Encode, Decode)] diff --git a/src/error.rs b/src/error.rs index cd71127..d87c47a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -35,6 +35,7 @@ type ErrorImpl = (); /// # Release mode /// In release mode, the error is a zero-sized type for efficiency. 
#[derive(Debug)] +#[cfg_attr(test, derive(PartialEq))] pub struct Error(ErrorImpl); impl Display for Error { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { diff --git a/src/fast.rs b/src/fast.rs index 7dc7bd2..4637439 100644 --- a/src/fast.rs +++ b/src/fast.rs @@ -395,14 +395,15 @@ impl<'borrowed, T> CowSlice<'borrowed, T> { /// **Panics** /// /// If self is not owned (set_owned hasn't been called). - pub fn mut_owned(&mut self, f: impl FnOnce(&mut Vec)) { - assert_eq!(self.slice.ptr, self.vec.as_ptr()); + pub fn mut_owned(&mut self, f: impl FnOnce(&mut Vec) -> R) -> R { + assert!(std::ptr::eq(self.slice.ptr, self.vec.as_ptr()), "not owned"); // Clear self.slice before mutating self.vec, so we don't point to freed memory. self.slice = [].as_slice().into(); - f(&mut self.vec); + let ret = f(&mut self.vec); // Safety: We clear `CowSlice.slice` whenever we mutate `CowSlice.vec`. let slice: &'borrowed [T] = unsafe { std::mem::transmute(self.vec.as_slice()) }; self.slice = slice.into(); + ret } } diff --git a/src/int.rs b/src/int.rs index 31e0bb2..2df9373 100644 --- a/src/int.rs +++ b/src/int.rs @@ -41,8 +41,8 @@ impl Buffer for IntEncoder { pub struct IntDecoder<'a, T: Int>(CowSlice<'a, T::Une>); impl<'a, T: Int> IntDecoder<'a, T> { - // For CheckedIntDecoder. - fn borrowed_clone<'me: 'a>(&'me self) -> IntDecoder<'me, T> { + // For CheckedIntDecoder/LengthDecoder. 
+ pub(crate) fn borrowed_clone<'me: 'a>(&'me self) -> IntDecoder<'me, T> { let mut cow = CowSlice::default(); cow.set_borrowed_slice_impl(self.0.ref_slice().clone()); Self(cow) diff --git a/src/length.rs b/src/length.rs index 3c9e4a4..f82de51 100644 --- a/src/length.rs +++ b/src/length.rs @@ -1,14 +1,14 @@ use crate::coder::{Buffer, Decoder, Encoder, Result, View}; -use crate::consume::consume_byte_arrays; use crate::error::{err, error}; -use crate::fast::{CowSlice, NextUnchecked, SliceImpl, VecImpl}; +use crate::fast::{CowSlice, NextUnchecked, VecImpl}; +use crate::int::{IntDecoder, IntEncoder}; use crate::pack::{pack_bytes, unpack_bytes}; use std::num::NonZeroUsize; #[derive(Debug, Default)] pub struct LengthEncoder { small: VecImpl, - large: Vec, // TODO IntEncoder (handles endian and uses smaller integers). + large: IntEncoder, } impl Encoder for LengthEncoder { @@ -21,15 +21,10 @@ impl Encoder for LengthEncoder { } else { #[cold] #[inline(never)] - unsafe fn encode_slow(end_ptr: *mut u8, large: &mut Vec, v: usize) { + unsafe fn encode_slow(end_ptr: *mut u8, large: &mut IntEncoder, v: usize) { *end_ptr = 255; - - // Swap bytes if big endian, so we can cast large to little endian &[u8]. - let mut v = v as u64; - if cfg!(target_endian = "big") { - v = v.swap_bytes(); - } - large.push(v); + large.reserve(NonZeroUsize::new(1).unwrap()); + large.encode(&v); } encode_slow(end_ptr, &mut self.large, v); } @@ -113,8 +108,7 @@ impl Buffer for LengthEncoder { fn collect_into(&mut self, out: &mut Vec) { pack_bytes(self.small.as_mut_slice(), out); self.small.clear(); - out.extend_from_slice(bytemuck::cast_slice(self.large.as_slice())); - self.large.clear(); + self.large.collect_into(out); } fn reserve(&mut self, additional: NonZeroUsize) { @@ -125,7 +119,7 @@ impl Buffer for LengthEncoder { #[derive(Debug, Default)] pub struct LengthDecoder<'a> { small: CowSlice<'a, u8>, - large: SliceImpl<'a, [u8; 8]>, // TODO IntDecoder. 
+ large: IntDecoder<'a, usize>, sum: usize, } @@ -140,7 +134,7 @@ impl<'a> LengthDecoder<'a> { small.set_borrowed_slice_impl(self.small.ref_slice().clone()); Self { small, - large: self.large.clone(), + large: self.large.borrowed_clone(), sum: self.sum, } } @@ -182,16 +176,18 @@ impl<'a> View<'a> for LengthDecoder<'a> { // Every 255 byte indicates a large is present. let large_length = small.iter().filter(|&&v| v == 255).count(); - let large: &[[u8; 8]] = consume_byte_arrays(input, large_length)?; - self.large = large.into(); + self.large.populate(input, large_length)?; // Can't overflow since sum includes large_length many 255s. sum -= large_length as u64 * 255; // Summing &[u64] can overflow, so we check it. - for &v in large { - let v = u64::from_le_bytes(v); - sum = sum.checked_add(v).ok_or_else(|| error("length overflow"))?; + let mut decoder = self.large.borrowed_clone(); + for _ in 0..large_length { + let v: usize = decoder.decode(); + sum = sum + .checked_add(v as u64) + .ok_or_else(|| error("length overflow"))?; } if sum >= HUGE_LEN { return err("length overflow"); // Lets us optimize decode with unreachable_unchecked. @@ -214,8 +210,8 @@ impl<'a> Decoder<'a, usize> for LengthDecoder<'a> { v as usize } else { #[cold] - unsafe fn cold(large: &mut SliceImpl<'_, [u8; 8]>) -> usize { - u64::from_le_bytes(large.next_unchecked()) as usize + unsafe fn cold(large: &mut IntDecoder<'_, usize>) -> usize { + large.decode() } cold(&mut self.large) } diff --git a/src/pack_ints.rs b/src/pack_ints.rs index 92017af..d4fb4ec 100644 --- a/src/pack_ints.rs +++ b/src/pack_ints.rs @@ -1,7 +1,9 @@ use crate::coder::Result; use crate::consume::{consume_byte, consume_byte_arrays}; +use crate::error::error; use crate::fast::CowSlice; use crate::pack::{invalid_packing, pack_bytes, unpack_bytes}; +use crate::Error; use bytemuck::Pod; /// Possible integer sizes in descending order. 
@@ -18,7 +20,7 @@ enum Packing { impl Packing { fn new(max: T) -> Self { - let max: u128 = max.into(); + let max: u128 = max.try_into().unwrap_or_else(|_| unreachable!()); // From isn't implemented for u128. #[allow(clippy::match_overlapping_arm)] // Just make sure not to reorder them. match max { ..=0xFF => Self::_8, @@ -29,20 +31,43 @@ impl Packing { } } + fn no_packing() -> Self { + // usize must encode like u64. + if T::IS_USIZE { + Self::new(u64::MAX) + } else { + Self::new(T::MAX) + } + } + fn write(self, out: &mut Vec, offset_by_min: bool) { // Encoded in such a way such that 0 is no packing and higher numbers are smaller packing. // Also makes no packing with offset_by_min = true is unrepresentable. - out.push((self as u8 - Self::new(T::MAX) as u8) * 2 - offset_by_min as u8); + out.push((self as u8 - Self::no_packing::() as u8) * 2 - offset_by_min as u8); } fn read(input: &mut &[u8]) -> Result<(Self, bool)> { let v = consume_byte(input)?; - let p_u8 = crate::nightly::div_ceil_u8(v, 2) + Self::new(T::MAX) as u8; + let p_u8 = crate::nightly::div_ceil_u8(v, 2) + Self::no_packing::() as u8; let offset_by_min = v & 1 != 0; let p = match p_u8 { 0 => Self::_128, - 1 => Self::_64, - 2 => Self::_32, + 1 => { + if T::IS_USIZE && cfg!(target_pointer_width = "32") { + return Err(usize_too_big()); + } else { + Self::_64 + } + } + 2 => { + if offset_by_min && T::IS_USIZE && cfg!(target_pointer_width = "32") { + // Offsetting u32 would result in u64. If we didn't have this check the + // mut_owned() call would panic (since on 32 bit usize borrows u32). + return Err(usize_too_big()); + } else { + Self::_32 + } + } 3 => Self::_16, 4 => Self::_8, _ => return invalid_packing(), @@ -52,10 +77,16 @@ impl Packing { } } +pub(crate) fn usize_too_big() -> Error { + error("encountered a usize greater than u32::MAX on a 32 bit platform") +} + // Default bound makes #[derive(Default)] on IntEncoder/IntDecoder work. 
pub trait Int: - Copy + Default + Into + Ord + Pod + Sized + std::ops::Sub + std::ops::SubAssign + Copy + Default + TryInto + Ord + Pod + Sized + std::ops::Sub { + // usize must encode like u64, so it needs a special case. + const IS_USIZE: bool = false; // Unaligned native endian. TODO could be aligned on big endian since we always have to copy. type Une: Pod + Default; const MIN: Self; @@ -63,7 +94,6 @@ pub trait Int: fn read(input: &mut &[u8]) -> Result; fn write(v: Self, out: &mut Vec); fn wrapping_add(self, rhs: Self::Une) -> Self::Une; - #[cfg(test)] fn from_unaligned(unaligned: Self::Une) -> Self; fn pack128(v: &[Self], out: &mut Vec); fn pack64(v: &[Self], out: &mut Vec); @@ -87,15 +117,24 @@ macro_rules! impl_simple { const MIN: Self = Self::MIN; const MAX: Self = Self::MAX; fn read(input: &mut &[u8]) -> Result { - Ok(Self::from_le_bytes(consume_byte_arrays(input, 1)?[0])) + if Self::IS_USIZE { + u64::from_le_bytes(consume_byte_arrays(input, 1)?[0]) + .try_into() + .map_err(|_| usize_too_big()) + } else { + Ok(Self::from_le_bytes(consume_byte_arrays(input, 1)?[0])) + } } fn write(v: Self, out: &mut Vec) { - out.extend_from_slice(&v.to_le_bytes()); + if Self::IS_USIZE { + out.extend_from_slice(&(v as u64).to_le_bytes()); + } else { + out.extend_from_slice(&v.to_le_bytes()); + } } fn wrapping_add(self, rhs: Self::Une) -> Self::Une { self.wrapping_add(Self::from_ne_bytes(rhs)).to_ne_bytes() } - #[cfg(test)] fn from_unaligned(unaligned: Self::Une) -> Self { Self::from_ne_bytes(unaligned) } @@ -104,10 +143,10 @@ macro_rules! impl_simple { macro_rules! impl_unreachable { ($t:ty, $pack:ident, $unpack:ident) => { fn $pack(_: &[Self], _: &mut Vec) { - unimplemented!(); + unreachable!(); // Packings that increase size won't be chosen. } fn $unpack<'a>(_: &'a [<$t as Int>::Une], _: &mut CowSlice<'a, Self::Une>) -> Result<()> { - invalid_packing() + unreachable!(); // Packings that increase size are unrepresentable. } }; } @@ -116,7 +155,7 @@ macro_rules! 
impl_self { fn $pack(v: &[Self], out: &mut Vec) { // If we're little endian we can copy directly because we encode in little endian. if cfg!(target_endian = "little") { - out.extend_from_slice(bytemuck::cast_slice(&v)); + out.extend_from_slice(bytemuck::must_cast_slice(&v)); } else { out.extend(v.iter().flat_map(|&v| v.to_le_bytes())); } @@ -184,6 +223,24 @@ macro_rules! impl_u8 { }; } +impl Int for usize { + const IS_USIZE: bool = true; + impl_simple!(); + impl_unreachable!(u128, pack128, unpack128); + + #[cfg(target_pointer_width = "64")] + impl_self!(pack64, unpack64); + #[cfg(target_pointer_width = "64")] + impl_smaller!(u32, pack32, unpack32); + + #[cfg(target_pointer_width = "32")] + impl_unreachable!(u64, pack64, unpack64); + #[cfg(target_pointer_width = "32")] + impl_self!(pack32, unpack32); + + impl_smaller!(u16, pack16, unpack16); + impl_u8!(); +} impl Int for u128 { impl_simple!(); impl_self!(pack128, unpack128); @@ -243,10 +300,21 @@ fn minmax(v: &[T]) -> (T, T) { (min, max) } +fn skip_packing(length: usize) -> bool { + // Be careful using size_of:: since usize can be 4 or 8. + if std::mem::size_of::() == 1 { + return true; // u8s can't be packed by pack_ints (only pack_bytes). + } + if length == 0 { + return true; // Can't pack 0 ints. + } + // Packing a single u16 is pointless (takes at least 2 bytes). + std::mem::size_of::() == 2 && length == 1 +} + /// Like [`pack_bytes`] but for larger integers. Handles endian conversion. pub fn pack_ints(ints: &mut [T], out: &mut Vec) { - // Passes through u8s and length <= 1 since they can't be compressed. - let p = if std::mem::size_of::() == 1 || ints.len() <= 1 { + let p = if skip_packing::(ints.len()) { Packing::new(T::MAX) } else { // Take a small sample to avoid wastefully scanning the whole slice. 
@@ -269,7 +337,7 @@ pub fn pack_ints(ints: &mut [T], out: &mut Vec) { let p2 = Packing::new(max - min); if p2 > p && ints.len() > 5 { for b in ints.iter_mut() { - *b -= min; + *b = *b - min; } p2.write::(out, true); T::write(min, out); @@ -296,8 +364,7 @@ pub fn unpack_ints<'a, T: Int>( length: usize, out: &mut CowSlice<'a, T::Une>, ) -> Result<()> { - // Passes through u8s and length <= 1 since they can't be compressed. - let (p, min) = if std::mem::size_of::() == 1 || length <= 1 { + let (p, min) = if skip_packing::(length) { (Packing::new(T::MAX), None) } else { let (p, offset_by_min) = Packing::read::(input)?; @@ -314,32 +381,90 @@ pub fn unpack_ints<'a, T: Int>( if let Some(min) = min { // Has to be owned to have min. out.mut_owned(|out| { - for v in out { + for v in out.iter_mut() { *v = min.wrapping_add(*v); } + // If a + b < b overflow occurred. + let overflow = || out.iter().any(|v| T::from_unaligned(*v) < min); + + // We only care about overflow if it changes results on 32 bit and 64 bit: + // 1 + u32::MAX as usize overflows on 32 bit but works on 64 bit. + if !T::IS_USIZE || cfg!(target_pointer_width = "64") { + return Ok(()); + } + + // Fast path, overflow is impossible if max(a) + b doesn't overflow. + let max_before_offset = match p { + Packing::_8 => u8::MAX as u128, + Packing::_16 => u16::MAX as u128, + _ => unreachable!(), // _32, _64, _128 won't be returned from Packing::read::() with offset_by_min == true. 
+ }; + let min = min.try_into().unwrap_or_else(|_| unreachable!()); + if max_before_offset + min <= usize::MAX as u128 { + debug_assert!(!overflow()); + return Ok(()); + } + if overflow() { + return Err(usize_too_big()); + } + Ok(()) }) + } else { + Ok(()) } - Ok(()) } #[cfg(test)] mod tests { - use super::{pack_ints, unpack_ints, CowSlice, Int}; + use super::{usize_too_big, CowSlice, Int, Result}; use std::fmt::Debug; use test::{black_box, Bencher}; - fn t(ints: &[T]) -> Vec { + pub fn pack_ints(ints: &[T]) -> Vec { let mut out = vec![]; - pack_ints(&mut ints.to_owned(), &mut out); - - let mut slice = out.as_slice(); - let mut unpacked = CowSlice::default(); - let length = ints.len(); - unpack_ints::(&mut slice, length, &mut unpacked).unwrap(); - let unpacked = unsafe { unpacked.as_slice(length) }; - let unpacked: Vec<_> = unpacked.iter().copied().map(T::from_unaligned).collect(); + super::pack_ints(&mut ints.to_vec(), &mut out); + assert_eq!(ints, unpack_ints(&out, ints.len()).unwrap()); + out + } + pub fn unpack_ints(mut packed: &[u8], length: usize) -> Result> { + let mut out = CowSlice::default(); + super::unpack_ints::(&mut packed, length, &mut out)?; + assert!(packed.is_empty()); + let unpacked = unsafe { out.as_slice(length) }; + Ok(unpacked.iter().copied().map(T::from_unaligned).collect()) + } + const COUNTING: [usize; 8] = [0usize, 1, 2, 3, 4, 5, 6, 7]; + + #[test] + fn test_usize_eq_u64() { + let a = COUNTING; + let b = a.map(|v| v as u64); + assert_eq!(pack_ints(&a), pack_ints(&b)); + let a = COUNTING.map(|v| v + 1000); + let b = a.map(|a| a as u64); + assert_eq!(pack_ints(&a), pack_ints(&b)); + } + + #[test] + fn test_usize_too_big() { + for scale in [1, 1 << 8, 1 << 16, 1 << 32] { + println!("scale {scale}"); + let a = COUNTING.map(|v| v as u64 * scale + u32::MAX as u64); + let packed = pack_ints(&a); + let b = unpack_ints::(&packed, a.len()); + if cfg!(target_pointer_width = "64") { + let b = b.unwrap(); + assert_eq!(a, std::array::from_fn(|i| b[i] 
as u64)); + } else { + assert_eq!(b.unwrap_err(), usize_too_big()); + } + } + } + + fn t(ints: &[T]) -> Vec { + let out = pack_ints(&mut ints.to_owned()); + let unpacked = unpack_ints::(&out, ints.len()).unwrap(); assert_eq!(unpacked, ints); - assert!(slice.is_empty()); let packing = out[0]; let size = 100.0 * out.len() as f32 / std::mem::size_of_val(ints) as f32; @@ -377,6 +502,7 @@ mod tests { test!(test_u032, u32); test!(test_u064, u64); test!(test_u128, u128); + test!(test_usize, usize); fn bench_pack_ints(b: &mut Bencher, src: &[T]) { let mut ints = src.to_vec(); @@ -385,18 +511,17 @@ mod tests { b.iter(|| { ints.copy_from_slice(&src); out.clear(); - pack_ints(black_box(&mut ints), black_box(&mut out)); + super::pack_ints(black_box(&mut ints), black_box(&mut out)); }); assert_eq!(out.capacity(), starting_cap); } fn bench_unpack_ints(b: &mut Bencher, src: &[T]) { - let mut packed = vec![]; - pack_ints(&mut src.to_vec(), &mut packed); + let packed = pack_ints(&mut src.to_vec()); let mut out = CowSlice::with_allocation(Vec::::with_capacity(src.len())); b.iter(|| { let length = src.len(); - unpack_ints::( + super::unpack_ints::( black_box(&mut packed.as_slice()), length, black_box(&mut out), @@ -441,7 +566,7 @@ mod tests { let input = black_box(&mut ints); out.clear(); let out = black_box(&mut out); - out.extend_from_slice(bytemuck::cast_slice(&input)); + out.extend_from_slice(bytemuck::must_cast_slice(&input)); }); } @@ -467,4 +592,5 @@ mod tests { bench!(u032, u32); bench!(u064, u64); bench!(u128, u128); + bench!(usize, usize); } From 2e2e842c6d56912e66a73d241de19302098d9504 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Mon, 26 Feb 2024 12:11:57 -0800 Subject: [PATCH 10/45] Optimize FastVec::reserve by storing capacity ptr instead of usize. 
--- src/fast.rs | 81 ++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 58 insertions(+), 23 deletions(-) diff --git a/src/fast.rs b/src/fast.rs index 4637439..ff0833b 100644 --- a/src/fast.rs +++ b/src/fast.rs @@ -5,10 +5,11 @@ use std::mem::MaybeUninit; pub type VecImpl = FastVec; pub type SliceImpl<'a, T> = FastSlice<'a, T>; +/// Implementation of [`Vec`] that optimizes push_unchecked at the cost of as_slice being slower. pub struct FastVec { - start: *mut T, // TODO NonNull/Unique? - end: *mut T, - capacity: usize, + start: *mut T, // vec.as_mut_ptr() + end: *mut T, // vec.as_mut_ptr().add(vec.len()) + capacity: *mut T, // vec.as_mut_ptr().add(vec.capacity()) _spooky: PhantomData>, } @@ -32,11 +33,18 @@ impl Drop for FastVec { } } +/// Replacement for `feature = "ptr_sub_ptr"` which isn't yet stable. +#[inline(always)] +fn sub_ptr(ptr: *mut T, origin: *mut T) -> usize { + // unsafe { ptr.sub_ptr(origin) } + (ptr as usize - origin as usize) / std::mem::size_of::() +} + impl From> for Vec { fn from(fast: FastVec) -> Self { let start = fast.start; let length = fast.len(); - let capacity = fast.capacity; + let capacity = sub_ptr(fast.capacity, fast.start); std::mem::forget(fast); unsafe { Vec::from_raw_parts(start, length, capacity) } } @@ -44,9 +52,10 @@ impl From> for Vec { impl From> for FastVec { fn from(mut vec: Vec) -> Self { + assert_ne!(std::mem::size_of::(), 0); let start = vec.as_mut_ptr(); let end = unsafe { start.add(vec.len()) }; - let capacity = vec.capacity(); + let capacity = unsafe { start.add(vec.capacity()) }; std::mem::forget(vec); Self { start, @@ -58,6 +67,10 @@ impl From> for FastVec { } impl FastVec { + fn len(&self) -> usize { + sub_ptr(self.end, self.start) + } + pub fn as_slice(&self) -> &[T] { unsafe { std::slice::from_raw_parts(self.start, self.len()) } } @@ -71,9 +84,7 @@ impl FastVec { } pub fn reserve(&mut self, additional: usize) { - // check copied from RawVec::grow_amortized - let len = self.len(); - if 
additional > self.capacity.wrapping_sub(len) { + if additional > sub_ptr(self.capacity, self.end) { #[cold] #[inline(never)] fn reserve_slow(me: &mut FastVec, additional: usize) { @@ -101,14 +112,10 @@ impl FastVec { } } - fn len(&self) -> usize { - (self.end as usize - self.start as usize) / std::mem::size_of::() // TODO sub_ptr. - } - /// Get a pointer to write to without incrementing length. #[inline(always)] pub fn end_ptr(&mut self) -> *mut T { - debug_assert!(self.len() <= self.capacity); + debug_assert!(self.end <= self.capacity); self.end } @@ -116,7 +123,7 @@ impl FastVec { #[inline(always)] pub fn set_end_ptr(&mut self, end: *mut T) { self.end = end; - debug_assert!(self.len() <= self.capacity); + debug_assert!(self.end <= self.capacity); } /// Increments length by 1. @@ -127,7 +134,7 @@ impl FastVec { #[inline(always)] pub unsafe fn increment_len(&mut self) { self.end = self.end.add(1); - debug_assert!(self.len() <= self.capacity); + debug_assert!(self.end <= self.capacity); } } @@ -140,7 +147,7 @@ pub trait PushUnchecked { impl PushUnchecked for FastVec { #[inline(always)] unsafe fn push_unchecked(&mut self, t: T) { - debug_assert!(self.len() < self.capacity); + debug_assert!(self.end < self.capacity); std::ptr::write(self.end, t); self.end = self.end.add(1); } @@ -444,12 +451,12 @@ mod tests { assert_eq!(vec.as_slice(), [1, 2]); } - // TODO benchmark with u32 instead of just u8. 
const N: usize = 1000; + type VecT = Vec; #[bench] fn bench_next_unchecked(b: &mut Bencher) { - let src = vec![0u8; N]; + let src: VecT = vec![0; N]; b.iter(|| { let mut slice = src.as_slice(); for _ in 0..black_box(N) { @@ -460,7 +467,7 @@ mod tests { #[bench] fn bench_next_unchecked_fast(b: &mut Bencher) { - let src = vec![0u8; N]; + let src: VecT = vec![0; N]; b.iter(|| { let mut fast_slice = FastSlice::from(src.as_slice()); for _ in 0..black_box(N) { @@ -469,9 +476,37 @@ mod tests { }); } + #[bench] + fn bench_push(b: &mut Bencher) { + let mut buffer = VecT::with_capacity(N); + b.iter(|| { + buffer.clear(); + let vec = black_box(&mut buffer); + for _ in 0..black_box(N) { + let v = black_box(&mut *vec); + v.push(black_box(0)); + } + }); + } + + #[bench] + fn bench_push_fast(b: &mut Bencher) { + let mut buffer = VecT::with_capacity(N); + b.iter(|| { + buffer.clear(); + let mut vec = black_box(FastVec::from(std::mem::take(&mut buffer))); + for _ in 0..black_box(N) { + let v = black_box(&mut vec); + v.reserve(1); + unsafe { v.push_unchecked(black_box(0)) }; + } + buffer = vec.into(); + }); + } + #[bench] fn bench_push_unchecked(b: &mut Bencher) { - let mut buffer = Vec::with_capacity(N); + let mut buffer = VecT::with_capacity(N); b.iter(|| { buffer.clear(); let vec = black_box(&mut buffer); @@ -484,7 +519,7 @@ mod tests { #[bench] fn bench_push_unchecked_fast(b: &mut Bencher) { - let mut buffer = Vec::with_capacity(N); + let mut buffer = VecT::with_capacity(N); b.iter(|| { buffer.clear(); let mut vec = black_box(FastVec::from(std::mem::take(&mut buffer))); @@ -498,7 +533,7 @@ mod tests { #[bench] fn bench_reserve(b: &mut Bencher) { - let mut buffer = Vec::::with_capacity(N); + let mut buffer = VecT::with_capacity(N); b.iter(|| { buffer.clear(); let vec = black_box(&mut buffer); @@ -510,7 +545,7 @@ mod tests { #[bench] fn bench_reserve_fast(b: &mut Bencher) { - let mut buffer = Vec::::with_capacity(N); + let mut buffer = VecT::with_capacity(N); b.iter(|| { 
buffer.clear(); let mut vec = black_box(FastVec::from(std::mem::take(&mut buffer))); From 70882caed17c0a2df9634ae62921030c3198abfc Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Mon, 26 Feb 2024 12:42:38 -0800 Subject: [PATCH 11/45] Don't pack <= 2 bytes since it would take 2 or 3 bytes. --- src/derive/mod.rs | 2 ++ src/pack.rs | 26 ++++++++++---------------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/derive/mod.rs b/src/derive/mod.rs index 787c807..d4ab2e2 100644 --- a/src/derive/mod.rs +++ b/src/derive/mod.rs @@ -161,6 +161,8 @@ mod tests { test!(vec![None, Some(1), None], Vec>); test!((0usize, 1isize), (usize, isize)); test!(vec![true; 255], Vec); + test!([0u8, 1u8], [u8; 2]); + test!([0u8, 1u8, 2u8], [u8; 3]); } #[derive(Encode, Decode)] diff --git a/src/pack.rs b/src/pack.rs index e24161a..1425080 100644 --- a/src/pack.rs +++ b/src/pack.rs @@ -71,6 +71,10 @@ pub fn unpack_bools(input: &mut &[u8], length: usize, out: &mut CowSlice) unpack_arithmetic::<2>(input, length, out) } +fn skip_packing(length: usize) -> bool { + length <= 2 // Packing takes at least 2 bytes, so it can only expand <= 2 bytes. +} + /// Packs multiple bytes into single bytes and writes them to `out`. This only works if /// `max - min < 16`, otherwise this just copies `bytes` to `out`. /// @@ -79,14 +83,9 @@ pub fn unpack_bools(input: &mut &[u8], length: usize, out: &mut CowSlice) /// /// Mutates `bytes` to avoid copying them. The remaining `bytes` should be considered garbage. pub fn pack_bytes(bytes: &mut [u8], out: &mut Vec) { - // Pass through bytes.len() <= 1. - match bytes { - [] => return, - [v] => { - out.push(*v); - return; - } - _ => (), + if skip_packing(bytes.len()) { + out.extend_from_slice(bytes); + return; } let mut min = 255; @@ -127,14 +126,9 @@ pub fn unpack_bytes<'a>( length: usize, out: &mut CowSlice<'a, u8>, ) -> Result<()> { - // Pass through length <= 1. 
- match length { - 0 => return Ok(()), - 1 => { - out.set_borrowed(consume_bytes(input, 1)?); - return Ok(()); - } - _ => (), + if skip_packing(length) { + out.set_borrowed(consume_bytes(input, length)?); + return Ok(()); } let (p, offset_by_min) = Packing::read(input)?; From f0ef93469f40dee302b8a7649f66c4039a2b59ca Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Mon, 26 Feb 2024 13:46:24 -0800 Subject: [PATCH 12/45] Clarify doc. --- src/serde/de.rs | 2 +- src/serde/ser.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serde/de.rs b/src/serde/de.rs index 23913c2..29718e3 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -21,7 +21,7 @@ mod inner { /// Deserializes a [`&[u8]`][`prim@slice`] into an instance of `T:` [`Deserialize`]. /// /// **Warning:** The format is incompatible with [`encode`][`crate::encode`] and subject to - /// change between versions. + /// change between major versions. #[cfg_attr(doc, doc(cfg(feature = "serde")))] pub fn deserialize<'de, T: Deserialize<'de>>(mut bytes: &'de [u8]) -> Result { let mut decoder = SerdeDecoder::Unspecified2 { length: 1 }; diff --git a/src/serde/ser.rs b/src/serde/ser.rs index 2f3aaf9..25b0da1 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -21,7 +21,7 @@ mod inner { /// Serializes a `T:` [`Serialize`] into a [`Vec`]. /// /// **Warning:** The format is incompatible with [`decode`][`crate::decode`] and subject to - /// change between versions. + /// change between major versions. #[cfg_attr(doc, doc(cfg(feature = "serde")))] pub fn serialize(t: &T) -> Result, Error> { let mut lazy = LazyEncoder::Unspecified { From 88fe307bb16b24827b3d3139d4f96e634a0f21ff Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Tue, 27 Feb 2024 12:41:36 -0800 Subject: [PATCH 13/45] Update lz4_flex and use unsafe encode/decode for more accurate benchmarks. 
--- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 0a9639d..23e2d54 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,7 @@ serde = { version = "1.0", optional = true } arrayvec = { version = "0.7", features = ["serde"] } bincode = "1.3.3" flate2 = "1.0.28" -lz4_flex = "0.10.0" +lz4_flex = { version = "0.11.2", default-features = false } paste = "1.0.14" rand = "0.8.5" rand_chacha = "0.3.1" From a740a0ca9a91a59e03b9262fdf8b7e656b83fdc4 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Tue, 27 Feb 2024 12:43:54 -0800 Subject: [PATCH 14/45] Fix comment. --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 23e2d54..60fe026 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,6 @@ zstd = "0.13.0" derive = [ "bitcode_derive" ] default = [ "derive" ] -# TODO halfs speed of benches_borrowed::bench_splitcode_decode +# TODO halfs speed of benches_borrowed::bench_bitcode_decode #[profile.bench] #lto = true From 2bfd00377b0d369ca9d283d4dbd581329683f910 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Tue, 27 Feb 2024 13:04:08 -0800 Subject: [PATCH 15/45] Add back github workflow (without miri big-endian for now). --- .github/workflows/build.yml | 46 +++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 .github/workflows/build.yml diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..df9d6c7 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,46 @@ +name: Build + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + - uses: actions-rs/toolchain@v1 + with: + # Nightly toolchain must ship the `rust-std` component for + # `i686-unknown-linux-gnu` and `mips64-unknown-linux-gnuabi64`. 
+ # In practice, `rust-std` almost always ships for + # `i686-unknown-linux-gnu` so we just need to check this page for a + # compatible nightly: + # https://rust-lang.github.io/rustup-components-history/mips64-unknown-linux-gnuabi64.html + toolchain: nightly-2023-07-04 + override: true + components: rustfmt, miri + - name: Lint + run: cargo fmt --check + - name: Test (debug) + run: cargo test + - name: Install i686 and GCC multilib + run: rustup target add i686-unknown-linux-gnu && sudo apt update && sudo apt install -y gcc-multilib + - name: Test (32-bit) + run: cargo test --target i686-unknown-linux-gnu + - name: Setup Miri + run: cargo miri setup + - name: Test (miri) + run: cargo miri test + - name: Setup Miri (big-endian) + run: rustup target add mips64-unknown-linux-gnuabi64 && cargo miri setup --target mips64-unknown-linux-gnuabi64 +# TODO miri big-endian (zstd doesn't compile) +# - name: Test (miri big-endian) +# run: cargo miri test --target mips64-unknown-linux-gnuabi64 From 5b2864d56fcc789144ec9064240a8b14d8aa237d Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Tue, 27 Feb 2024 13:08:58 -0800 Subject: [PATCH 16/45] Use older nightly due to fmt changing. --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index df9d6c7..dad80cf 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -24,7 +24,7 @@ jobs: # `i686-unknown-linux-gnu` so we just need to check this page for a # compatible nightly: # https://rust-lang.github.io/rustup-components-history/mips64-unknown-linux-gnuabi64.html - toolchain: nightly-2023-07-04 + toolchain: nightly-2023-04-25 override: true components: rustfmt, miri - name: Lint From b13efde315a5b3853bf500d60c52a81a6a11097c Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Wed, 28 Feb 2024 13:57:44 -0800 Subject: [PATCH 17/45] Pack signed integers, use F32Encoder in serialize. 
--- src/derive/impls.rs | 25 +-- src/derive/mod.rs | 7 +- src/f32.rs | 2 +- src/pack.rs | 118 ++++++++++---- src/pack_ints.rs | 368 ++++++++++++++++++++++++++++---------------- src/serde/de.rs | 16 +- src/serde/ser.rs | 38 +++-- 7 files changed, 385 insertions(+), 189 deletions(-) diff --git a/src/derive/impls.rs b/src/derive/impls.rs index d19e289..a431ef3 100644 --- a/src/derive/impls.rs +++ b/src/derive/impls.rs @@ -32,21 +32,26 @@ impl_both!(f32, F32Encoder, F32Decoder); impl_both!(String, StrEncoder, StrDecoder); macro_rules! impl_int { - ($($a:ty => $b:ty),+) => { + ($($t:ty),+) => { $( - impl Encode for $a { - type Encoder = IntEncoder<$b>; + impl Encode for $t { + type Encoder = IntEncoder<$t>; } - - impl<'a> Decode<'a> for $a { - type Decoder = IntDecoder<'a, $b>; + impl<'a> Decode<'a> for $t { + type Decoder = IntDecoder<'a, $t>; } )+ } } -impl_int!(u8 => u8, u16 => u16, u32 => u32, u64 => u64, u128 => u128, usize => usize); -impl_int!(i8 => u8, i16 => u16, i32 => u32, i64 => u64, i128 => u128, isize => usize); -impl_int!(f64 => u64); // Totally an int... +impl_int!(u8, u16, u32, u64, u128, usize); +impl_int!(i8, i16, i32, i64, i128, isize); +// TODO F64Encoder (once F32Encoder is sufficiently optimized). +impl Encode for f64 { + type Encoder = IntEncoder; +} +impl<'a> Decode<'a> for f64 { + type Decoder = IntDecoder<'a, u64>; +} macro_rules! impl_checked_int { ($($a:ty => $b:ty),+) => { @@ -61,7 +66,7 @@ macro_rules! impl_checked_int { } } impl_checked_int!(NonZeroU8 => u8, NonZeroU16 => u16, NonZeroU32 => u32, NonZeroU64 => u64, NonZeroU128 => u128, NonZeroUsize => usize); -impl_checked_int!(NonZeroI8 => u8, NonZeroI16 => u16, NonZeroI32 => u32, NonZeroI64 => u64, NonZeroI128 => u128, NonZeroIsize => usize); +impl_checked_int!(NonZeroI8 => i8, NonZeroI16 => i16, NonZeroI32 => i32, NonZeroI64 => i64, NonZeroI128 => i128, NonZeroIsize => isize); impl_checked_int!(char => u32); macro_rules! 
impl_t { diff --git a/src/derive/mod.rs b/src/derive/mod.rs index d4ab2e2..4f1439a 100644 --- a/src/derive/mod.rs +++ b/src/derive/mod.rs @@ -159,10 +159,11 @@ mod tests { test!(("abc", "123"), (&str, &str)); test!(Vec::>::new(), Vec>); test!(vec![None, Some(1), None], Vec>); - test!((0usize, 1isize), (usize, isize)); + test!((0, 1), (usize, isize)); test!(vec![true; 255], Vec); - test!([0u8, 1u8], [u8; 2]); - test!([0u8, 1u8, 2u8], [u8; 3]); + test!([0, 1], [u8; 2]); + test!([0, 1, 2], [u8; 3]); + test!([0, -1, 0, -1, 0, -1, 0], [i8; 7]); } #[derive(Encode, Decode)] diff --git a/src/f32.rs b/src/f32.rs index f056912..782edb9 100644 --- a/src/f32.rs +++ b/src/f32.rs @@ -75,7 +75,7 @@ impl Buffer for F32Encoder { } } -#[derive(Default)] +#[derive(Debug, Default)] pub struct F32Decoder<'a> { // While it is true that this contains 1 bit of the exp we still call it mantissa. mantissa: FastSlice<'a, [u8; 3]>, diff --git a/src/pack.rs b/src/pack.rs index 1425080..9e0848c 100644 --- a/src/pack.rs +++ b/src/pack.rs @@ -2,6 +2,7 @@ use crate::coder::Result; use crate::consume::{consume_byte, consume_byte_arrays, consume_bytes}; use crate::error::err; use crate::fast::CowSlice; +use crate::pack_ints::SizedInt; /// Possible states per byte in descending order. Each packed byte will use `log2(states)` bits. #[repr(u8)] @@ -75,6 +76,10 @@ fn skip_packing(length: usize) -> bool { length <= 2 // Packing takes at least 2 bytes, so it can only expand <= 2 bytes. } +pub trait Byte: SizedInt {} +impl Byte for u8 {} +impl Byte for i8 {} + /// Packs multiple bytes into single bytes and writes them to `out`. This only works if /// `max - min < 16`, otherwise this just copies `bytes` to `out`. /// @@ -82,32 +87,47 @@ fn skip_packing(length: usize) -> bool { /// avoid confusing bytewise compression algorithms (e.g. Deflate). /// /// Mutates `bytes` to avoid copying them. The remaining `bytes` should be considered garbage. 
-pub fn pack_bytes(bytes: &mut [u8], out: &mut Vec) { +pub fn pack_bytes(bytes: &mut [T], out: &mut Vec) { if skip_packing(bytes.len()) { - out.extend_from_slice(bytes); + out.extend_from_slice(bytemuck::must_cast_slice(bytes)); return; } + let (min, max) = crate::pack_ints::minmax(bytes); - let mut min = 255; - let mut max = 0; - for &v in bytes.iter() { - min = min.min(v); - max = max.max(v); - } + // i8 packs as u8 if positive. + let basic_packing = if min >= T::default() { + Packing::new(bytemuck::must_cast(max)) + } else { + Packing::_256 // Any negative i8 as u8 is > 15 and can't be packed without offset_packing. + }; + // u8::wrapping_sub == i8::wrapping_sub, so we can use u8s from here onward. + let min: u8 = bytemuck::must_cast(min); + let max: u8 = bytemuck::must_cast(max); + let bytes: &mut [u8] = bytemuck::must_cast_slice_mut(bytes); + pack_bytes_unsigned(bytes, out, basic_packing, min, max); +} + +/// [`pack_bytes`] but after i8s have been cast to u8s. +fn pack_bytes_unsigned( + bytes: &mut [u8], + out: &mut Vec, + basic_packing: Packing, + min: u8, + max: u8, +) { // If subtracting min from all bytes results in a better packing do it, otherwise don't bother. - let p = Packing::new(max); - let p2 = Packing::new(max - min); - let p = if p2 > p && bytes.len() > 5 { + let offset_packing = Packing::new(max.wrapping_sub(min)); + let p = if offset_packing > basic_packing && bytes.len() > 5 { for b in bytes.iter_mut() { - *b -= min; + *b = b.wrapping_sub(min); } - p2.write(out, true); + offset_packing.write(out, true); out.push(min); - p2 + offset_packing } else { - p.write(out, false); - p + basic_packing.write(out, false); + basic_packing }; match p { @@ -121,7 +141,18 @@ pub fn pack_bytes(bytes: &mut [u8], out: &mut Vec) { } /// Opposite of `pack_bytes`. Needs to know the `length` in bytes. `out` is overwritten with the bytes. 
-pub fn unpack_bytes<'a>( +pub fn unpack_bytes<'a, T: Byte>( + input: &mut &'a [u8], + length: usize, + out: &mut CowSlice<'a, T>, +) -> Result<()> { + // Safety: T is u8 or i8 which have same size/align and are Copy. + let out: &mut CowSlice<'a, u8> = unsafe { std::mem::transmute(out) }; + unpack_bytes_unsigned(input, length, out) +} + +/// [`unpack_bytes`] but after i8s have been cast to u8s. +fn unpack_bytes_unsigned<'a>( input: &mut &'a [u8], length: usize, out: &mut CowSlice<'a, u8>, @@ -152,7 +183,6 @@ pub fn unpack_bytes<'a>( } if let Some(min) = min { for v in out { - // TODO validate min such that overflow is impossible and numbers like 0 aren't valid. *v = v.wrapping_add(min); } } @@ -458,28 +488,52 @@ mod tests { use paste::paste; use test::{black_box, Bencher}; - #[test] - fn test_pack_bytes() { - fn pack_bytes(bytes: &[u8]) -> Vec { - let mut out = vec![]; - super::pack_bytes(&mut bytes.to_owned(), &mut out); - out - } + fn pack_bytes(bytes: &[T]) -> Vec { + let mut out = vec![]; + super::pack_bytes(&mut bytes.to_owned(), &mut out); + out + } + + fn unpack_bytes(mut packed: &[u8], length: usize) -> Vec { + let mut out = crate::fast::CowSlice::default(); + super::unpack_bytes(&mut packed, length, &mut out).unwrap(); + assert!(packed.is_empty()); + unsafe { out.as_slice(length).to_vec() } + } - assert!(pack_bytes(&[1, 2, 3, 4, 5, 6, 7]).len() < 7); - assert!(pack_bytes(&[201, 202, 203, 204, 205, 206, 207]).len() < 7); + #[test] + fn test_pack_bytes_u8() { + assert_eq!(pack_bytes(&[1u8, 2, 3, 4, 5, 6, 7]).len(), 5); + assert_eq!(pack_bytes(&[201u8, 202, 203, 204, 205, 206, 207]).len(), 6); for max in 0..255u8 { for sub in [1, 2, 3, 4, 5, 15, 255] { let min = max.saturating_sub(sub); let original = [min, min, min, min, min, min, min, max]; let packed = pack_bytes(&original); + let unpacked = unpack_bytes(&packed, original.len()); + assert_eq!(original.as_slice(), unpacked.as_slice()); + } + } + } - let mut slice = packed.as_slice(); - let mut out = 
crate::fast::CowSlice::default(); - super::unpack_bytes(&mut slice, original.len(), &mut out).unwrap(); - assert!(slice.is_empty()); - assert_eq!(original, unsafe { out.as_slice(original.len()) }); + #[test] + fn test_pack_bytes_i8() { + assert_eq!(pack_bytes(&[1i8, 2, 3, 4, 5, 6, 7]).len(), 5); + assert_eq!(pack_bytes(&[-1i8, -2, -3, -4, -5, -6, -7]).len(), 6); + assert_eq!(pack_bytes(&[-3i8, -2, -1, 0, 1, 2, 3]).len(), 6); + assert_eq!( + pack_bytes(&[0i8, -1, 0, -1, 0, -1, 0]), + [9, (-1i8) as u8, 0b1010101] + ); + + for max in i8::MIN..i8::MAX { + for sub in [1, 2, 3, 4, 5, 15, 127] { + let min = max.saturating_sub(sub); + let original = [min, min, min, min, min, min, min, max]; + let packed = pack_bytes(&original); + let unpacked = unpack_bytes(&packed, original.len()); + assert_eq!(original.as_slice(), unpacked.as_slice()); } } } diff --git a/src/pack_ints.rs b/src/pack_ints.rs index d4fb4ec..422b256 100644 --- a/src/pack_ints.rs +++ b/src/pack_ints.rs @@ -19,7 +19,7 @@ enum Packing { } impl Packing { - fn new(max: T) -> Self { + fn new(max: T) -> Self { let max: u128 = max.try_into().unwrap_or_else(|_| unreachable!()); // From isn't implemented for u128. #[allow(clippy::match_overlapping_arm)] // Just make sure not to reorder them. match max { @@ -31,43 +31,20 @@ impl Packing { } } - fn no_packing() -> Self { - // usize must encode like u64. - if T::IS_USIZE { - Self::new(u64::MAX) - } else { - Self::new(T::MAX) - } - } - - fn write(self, out: &mut Vec, offset_by_min: bool) { + fn write(self, out: &mut Vec, offset_by_min: bool) { // Encoded in such a way such that 0 is no packing and higher numbers are smaller packing. // Also makes no packing with offset_by_min = true is unrepresentable. 
- out.push((self as u8 - Self::no_packing::() as u8) * 2 - offset_by_min as u8); + out.push((self as u8 - Self::new(T::MAX) as u8) * 2 - offset_by_min as u8); } - fn read(input: &mut &[u8]) -> Result<(Self, bool)> { + fn read(input: &mut &[u8]) -> Result<(Self, bool)> { let v = consume_byte(input)?; - let p_u8 = crate::nightly::div_ceil_u8(v, 2) + Self::no_packing::() as u8; + let p_u8 = crate::nightly::div_ceil_u8(v, 2) + Self::new(T::MAX) as u8; let offset_by_min = v & 1 != 0; let p = match p_u8 { 0 => Self::_128, - 1 => { - if T::IS_USIZE && cfg!(target_pointer_width = "32") { - return Err(usize_too_big()); - } else { - Self::_64 - } - } - 2 => { - if offset_by_min && T::IS_USIZE && cfg!(target_pointer_width = "32") { - // Offsetting u32 would result in u64. If we didn't have this check the - // mut_owned() call would panic (since on 32 bit usize borrows u32). - return Err(usize_too_big()); - } else { - Self::_32 - } - } + 1 => Self::_64, + 2 => Self::_32, 3 => Self::_16, 4 => Self::_8, _ => return invalid_packing(), @@ -77,24 +54,102 @@ impl Packing { } } -pub(crate) fn usize_too_big() -> Error { - error("encountered a usize greater than u32::MAX on a 32 bit platform") +fn usize_too_big() -> Error { + error("encountered a isize/usize with more than 32 bits on a 32 bit platform") } -// Default bound makes #[derive(Default)] on IntEncoder/IntDecoder work. -pub trait Int: - Copy + Default + TryInto + Ord + Pod + Sized + std::ops::Sub -{ - // usize must encode like u64, so it needs a special case. - const IS_USIZE: bool = false; +pub trait Int: Copy + std::fmt::Debug + Default + Ord + Pod + Sized { // Unaligned native endian. TODO could be aligned on big endian since we always have to copy. 
type Une: Pod + Default; + type Int: SizedInt; + fn from_unaligned(unaligned: Self::Une) -> Self { + bytemuck::must_cast(unaligned) + } + fn to_unaligned(self) -> Self::Une { + bytemuck::must_cast(self) + } + fn with_input(ints: &mut [Self], f: impl FnOnce(&mut [Self::Int])); + fn with_output<'a>( + out: &mut CowSlice<'a, Self::Une>, + length: usize, + f: impl FnOnce(&mut CowSlice<'a, ::Une>) -> Result<()>, + ) -> Result<()>; +} +macro_rules! impl_usize_and_isize { + ($($isize:ident => $i64:ident),+) => { + $( + impl Int for $isize { + type Une = [u8; std::mem::size_of::()]; + type Int = $i64; + fn with_input(ints: &mut [Self], f: impl FnOnce(&mut [Self::Int])) { + if cfg!(target_pointer_width = "64") { + f(bytemuck::cast_slice_mut(ints)) + } else { + // 32 bit isize to i64 requires conversion. TODO reuse allocation. + let mut ints: Vec<$i64> = ints.iter().map(|&v| v as $i64).collect(); + f(&mut ints); + } + } + fn with_output<'a>(out: &mut CowSlice<'a, Self::Une>, length: usize, f: impl FnOnce(&mut CowSlice<'a, ::Une>) -> Result<()>) -> Result<()> { + if cfg!(target_pointer_width = "64") { + // Safety: isize::Une == i64::Une on 64 bit. + f(unsafe { std::mem::transmute(out) }) + } else { + // i64 to 32 bit isize on requires checked conversion. TODO reuse allocations. + let mut out_i64 = CowSlice::default(); + f(&mut out_i64)?; + let out_i64 = unsafe { out_i64.as_slice(length) }; + let out_isize: Result> = out_i64.iter().map(|&v| $i64::from_unaligned(v).try_into().map(Self::to_unaligned).map_err(|_| usize_too_big())).collect(); + *out.set_owned() = out_isize?; + Ok(()) + } + } + } + )+ + } +} +impl_usize_and_isize!(usize => u64, isize => i64); + +/// An [`Int`] that has a fixed size independent of platform (not usize). +pub trait SizedInt: Int { + type Unsigned: SizedUInt; const MIN: Self; const MAX: Self; + fn to_unsigned(self) -> Self::Unsigned { + bytemuck::must_cast(self) + } +} + +macro_rules! 
impl_int { + ($($int:ident => $uint:ident),+) => { + $( + impl Int for $int { + type Une = [u8; std::mem::size_of::()]; + type Int = Self; + fn with_input(ints: &mut [Self], f: impl FnOnce(&mut [Self::Int])) { + f(ints) + } + fn with_output<'a>(out: &mut CowSlice<'a, Self::Une>, _: usize, f: impl FnOnce(&mut CowSlice<'a, ::Une>) -> Result<()>) -> Result<()> { + f(out) + } + } + impl SizedInt for $int { + type Unsigned = $uint; + const MIN: Self = Self::MIN; + const MAX: Self = Self::MAX; + } + )+ + } +} +impl_int!(u8 => u8, u16 => u16, u32 => u32, u64 => u64, u128 => u128); +impl_int!(i8 => u8, i16 => u16, i32 => u32, i64 => u64, i128 => u128); + +/// A [`SizedInt`] that is unsigned. +pub trait SizedUInt: SizedInt + TryInto { fn read(input: &mut &[u8]) -> Result; fn write(v: Self, out: &mut Vec); fn wrapping_add(self, rhs: Self::Une) -> Self::Une; - fn from_unaligned(unaligned: Self::Une) -> Self; + fn wrapping_sub(self, rhs: Self) -> Self; fn pack128(v: &[Self], out: &mut Vec); fn pack64(v: &[Self], out: &mut Vec); fn pack32(v: &[Self], out: &mut Vec); @@ -113,30 +168,17 @@ pub trait Int: macro_rules! 
impl_simple { () => { - type Une = [u8; std::mem::size_of::()]; - const MIN: Self = Self::MIN; - const MAX: Self = Self::MAX; fn read(input: &mut &[u8]) -> Result { - if Self::IS_USIZE { - u64::from_le_bytes(consume_byte_arrays(input, 1)?[0]) - .try_into() - .map_err(|_| usize_too_big()) - } else { - Ok(Self::from_le_bytes(consume_byte_arrays(input, 1)?[0])) - } + Ok(Self::from_le_bytes(consume_byte_arrays(input, 1)?[0])) } fn write(v: Self, out: &mut Vec) { - if Self::IS_USIZE { - out.extend_from_slice(&(v as u64).to_le_bytes()); - } else { - out.extend_from_slice(&v.to_le_bytes()); - } + out.extend_from_slice(&v.to_le_bytes()); } fn wrapping_add(self, rhs: Self::Une) -> Self::Une { self.wrapping_add(Self::from_ne_bytes(rhs)).to_ne_bytes() } - fn from_unaligned(unaligned: Self::Une) -> Self { - Self::from_ne_bytes(unaligned) + fn wrapping_sub(self, rhs: Self) -> Self { + self.wrapping_sub(rhs) } }; } @@ -223,25 +265,7 @@ macro_rules! impl_u8 { }; } -impl Int for usize { - const IS_USIZE: bool = true; - impl_simple!(); - impl_unreachable!(u128, pack128, unpack128); - - #[cfg(target_pointer_width = "64")] - impl_self!(pack64, unpack64); - #[cfg(target_pointer_width = "64")] - impl_smaller!(u32, pack32, unpack32); - - #[cfg(target_pointer_width = "32")] - impl_unreachable!(u64, pack64, unpack64); - #[cfg(target_pointer_width = "32")] - impl_self!(pack32, unpack32); - - impl_smaller!(u16, pack16, unpack16); - impl_u8!(); -} -impl Int for u128 { +impl SizedUInt for u128 { impl_simple!(); impl_self!(pack128, unpack128); impl_smaller!(u64, pack64, unpack64); @@ -249,7 +273,7 @@ impl Int for u128 { impl_smaller!(u16, pack16, unpack16); impl_u8!(); } -impl Int for u64 { +impl SizedUInt for u64 { impl_simple!(); impl_unreachable!(u128, pack128, unpack128); impl_self!(pack64, unpack64); @@ -257,7 +281,7 @@ impl Int for u64 { impl_smaller!(u16, pack16, unpack16); impl_u8!(); } -impl Int for u32 { +impl SizedUInt for u32 { impl_simple!(); impl_unreachable!(u128, pack128, 
unpack128); impl_unreachable!(u64, pack64, unpack64); @@ -265,7 +289,7 @@ impl Int for u32 { impl_smaller!(u16, pack16, unpack16); impl_u8!(); } -impl Int for u16 { +impl SizedUInt for u16 { impl_simple!(); impl_unreachable!(u128, pack128, unpack128); impl_unreachable!(u64, pack64, unpack64); @@ -273,7 +297,7 @@ impl Int for u16 { impl_self!(pack16, unpack16); impl_u8!(); } -impl Int for u8 { +impl SizedUInt for u8 { impl_simple!(); impl_unreachable!(u128, pack128, unpack128); impl_unreachable!(u64, pack64, unpack64); @@ -290,7 +314,7 @@ impl Int for u8 { } } -fn minmax(v: &[T]) -> (T, T) { +pub fn minmax(v: &[T]) -> (T, T) { let mut min = T::MAX; let mut max = T::MIN; for &v in v.iter() { @@ -300,7 +324,7 @@ fn minmax(v: &[T]) -> (T, T) { (min, max) } -fn skip_packing(length: usize) -> bool { +fn skip_packing(length: usize) -> bool { // Be careful using size_of:: since usize can be 4 or 8. if std::mem::size_of::() == 1 { return true; // u8s can't be packed by pack_ints (only pack_bytes). @@ -314,39 +338,74 @@ fn skip_packing(length: usize) -> bool { /// Like [`pack_bytes`] but for larger integers. Handles endian conversion. pub fn pack_ints(ints: &mut [T], out: &mut Vec) { - let p = if skip_packing::(ints.len()) { - Packing::new(T::MAX) + T::with_input(ints, |ints| pack_ints_sized(ints, out)); +} + +/// [`pack_ints`] but after isize has been converted to i64. +fn pack_ints_sized(ints: &mut [T], out: &mut Vec) { + // Handle i8 right away since pack_bytes needs to know that it's signed. + // If we didn't have this special case [0i8, -1, 0, -1, 0, -1] couldn't be packed. + // Doesn't affect larger signed ints because they're made positive before pack_bytes:: is called. 
+ if std::mem::size_of::() == 1 && T::MIN < T::default() { + let ints: &mut [i8] = bytemuck::must_cast_slice_mut(ints); + pack_bytes(ints, out); + return; + }; + + let (basic_packing, min_max) = if skip_packing::(ints.len()) { + (Packing::new(T::Unsigned::MAX), None) } else { // Take a small sample to avoid wastefully scanning the whole slice. let (sample, remaining) = ints.split_at(ints.len().min(16)); let (min, max) = minmax(sample); - // Only have to check packing(max - min) since it's always as good as just packing(max). - let none = Packing::new(T::MAX); - if Packing::new(max - min) == none { - none.write::(out, false); - none + // Only have to check packing(max - min) since it's always as good as packing(max). + let none = Packing::new(T::Unsigned::MAX); + if Packing::new(max.to_unsigned().wrapping_sub(min.to_unsigned())) == none { + none.write::(out, false); + (none, None) } else { let (remaining_min, remaining_max) = minmax(remaining); let min = min.min(remaining_min); let max = max.max(remaining_max); - // If subtracting min from all ints results in a better packing do it, otherwise don't bother. - // TODO ensure packing never expands data unnecessarily. - let p = Packing::new(max); - let p2 = Packing::new(max - min); - if p2 > p && ints.len() > 5 { - for b in ints.iter_mut() { - *b = *b - min; - } - p2.write::(out, true); - T::write(min, out); - p2 + // Signed ints pack as unsigned ints if positive. + let basic_packing = if min >= T::default() { + Packing::new(max.to_unsigned()) } else { - p.write::(out, false); - p + none // Any negative can't be packed without offset_packing. + }; + (basic_packing, Some((min, max))) + } + }; + let ints = bytemuck::must_cast_slice_mut(ints); + let min_max = min_max.map(|(min, max)| (min.to_unsigned(), max.to_unsigned())); + pack_ints_sized_unsigned::(ints, out, basic_packing, min_max); +} + +/// [`pack_ints_sized`] but after signed integers have been cast to unsigned. 
+fn pack_ints_sized_unsigned( + ints: &mut [T], + out: &mut Vec, + basic_packing: Packing, + min_max: Option<(T, T)>, +) { + let p = if let Some((min, max)) = min_max { + // If subtracting min from all ints results in a better packing do it, otherwise don't bother. + let offset_packing = Packing::new(max.wrapping_sub(min)); + if offset_packing > basic_packing && ints.len() > 5 { + for b in ints.iter_mut() { + *b = b.wrapping_sub(min); } + offset_packing.write::(out, true); + T::write(min, out); + offset_packing + } else { + basic_packing.write::(out, false); + basic_packing } + } else { + basic_packing }; match p { @@ -363,6 +422,28 @@ pub fn unpack_ints<'a, T: Int>( input: &mut &'a [u8], length: usize, out: &mut CowSlice<'a, T::Une>, +) -> Result<()> { + T::with_output(out, length, |out| { + unpack_ints_sized::(input, length, out) + }) +} + +/// [`unpack_ints`] but after isize has been converted to i64. +fn unpack_ints_sized<'a, T: SizedInt>( + input: &mut &'a [u8], + length: usize, + out: &mut CowSlice<'a, T::Une>, +) -> Result<()> { + // Safety: T::Une and T::Unsigned::Une are the same type. + let out: &mut CowSlice<'a, _> = unsafe { std::mem::transmute(out) }; + unpack_ints_sized_unsigned::(input, length, out) +} + +/// [`unpack_ints_sized`] but after signed integers have been cast to unsigned. +fn unpack_ints_sized_unsigned<'a, T: SizedUInt>( + input: &mut &'a [u8], + length: usize, + out: &mut CowSlice<'a, T::Une>, ) -> Result<()> { let (p, min) = if skip_packing::(length) { (Packing::new(T::MAX), None) @@ -384,34 +465,9 @@ pub fn unpack_ints<'a, T: Int>( for v in out.iter_mut() { *v = min.wrapping_add(*v); } - // If a + b < b overflow occurred. - let overflow = || out.iter().any(|v| T::from_unaligned(*v) < min); - - // We only care about overflow if it changes results on 32 bit and 64 bit: - // 1 + u32::MAX as usize overflows on 32 bit but works on 64 bit. 
- if !T::IS_USIZE || cfg!(target_pointer_width = "64") { - return Ok(()); - } - - // Fast path, overflow is impossible if max(a) + b doesn't overflow. - let max_before_offset = match p { - Packing::_8 => u8::MAX as u128, - Packing::_16 => u16::MAX as u128, - _ => unreachable!(), // _32, _64, _128 won't be returned from Packing::read::() with offset_by_min == true. - }; - let min = min.try_into().unwrap_or_else(|_| unreachable!()); - if max_before_offset + min <= usize::MAX as u128 { - debug_assert!(!overflow()); - return Ok(()); - } - if overflow() { - return Err(usize_too_big()); - } - Ok(()) }) - } else { - Ok(()) } + Ok(()) } #[cfg(test)] @@ -448,7 +504,6 @@ mod tests { #[test] fn test_usize_too_big() { for scale in [1, 1 << 8, 1 << 16, 1 << 32] { - println!("scale {scale}"); let a = COUNTING.map(|v| v as u64 * scale + u32::MAX as u64); let packed = pack_ints(&a); let b = unpack_ints::(&packed, a.len()); @@ -461,7 +516,38 @@ mod tests { } } - fn t(ints: &[T]) -> Vec { + #[test] + fn test_isize_too_big() { + for scale in [1, 1 << 8, 1 << 16, 1 << 32] { + let a = COUNTING.map(|v| v as i64 * scale + i32::MAX as i64); + let packed = pack_ints(&a); + let b = unpack_ints::(&packed, a.len()); + if cfg!(target_pointer_width = "64") { + let b = b.unwrap(); + assert_eq!(a, std::array::from_fn(|i| b[i] as i64)); + } else { + assert_eq!(b.unwrap_err(), usize_too_big()); + } + } + } + + #[test] + fn test_i8_special_case() { + assert_eq!( + pack_ints(&[0i8, -1, 0, -1, 0, -1, 0]), + [9, (-1i8) as u8, 0b1010101] + ); + } + + #[test] + fn test_isize_sign_extension() { + assert_eq!( + pack_ints(&[0isize, -1, 0, -1, 0, -1, 0]), + [5, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 10, 0b1010101] + ); + } + + fn test_inner(ints: &[T]) -> Vec { let out = pack_ints(&mut ints.to_owned()); let unpacked = unpack_ints::(&out, ints.len()).unwrap(); assert_eq!(unpacked, ints); @@ -483,13 +569,25 @@ mod tests { continue; }; - for max in [0, u8::MAX as u128, u16::MAX as u128, u32::MAX as 
u128, u64::MAX as u128, u128::MAX as u128] { - let Ok(start) = T::try_from(max / 2) else { + for max in [ + i128::MIN, i64::MIN as i128, i32::MIN as i128, i16::MIN as i128, i8::MIN as i128, -1, + 0, i8::MAX as i128, i16::MAX as i128, i32::MAX as i128, i64::MAX as i128, i128::MAX + ] { + if max == T::MAX as i128 { + continue; + } + let Ok(start) = T::try_from(max) else { continue; }; let s = format!("{start} {increment}"); + if increment == 1 { + print!("{s:<19} mod 2 => "); + test_inner::(&std::array::from_fn::<_, 100, _>(|i| { + start + (i as T % 2) * increment + })); + } print!("{s:<25} => "); - t::(&std::array::from_fn::<_, 100, _>(|i| { + test_inner::(&std::array::from_fn::<_, 100, _>(|i| { start + i as T * increment })); } @@ -503,6 +601,12 @@ mod tests { test!(test_u064, u64); test!(test_u128, u128); test!(test_usize, usize); + test!(test_i008, i8); + test!(test_i016, i16); + test!(test_i032, i32); + test!(test_i064, i64); + test!(test_i128, i128); + test!(test_isize, isize); fn bench_pack_ints(b: &mut Bencher, src: &[T]) { let mut ints = src.to_vec(); diff --git a/src/serde/de.rs b/src/serde/de.rs index 29718e3..9e29b82 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -2,6 +2,7 @@ use crate::bool::BoolDecoder; use crate::coder::{Decoder, Result, View}; use crate::consume::expect_eof; use crate::error::{err, error, Error}; +use crate::f32::F32Decoder; use crate::int::IntDecoder; use crate::length::LengthDecoder; use crate::serde::guard::guard_zst; @@ -39,6 +40,8 @@ pub use inner::deserialize; enum SerdeDecoder<'a> { Bool(BoolDecoder<'a>), Enum(Box<(VariantDecoder<'a>, Vec>)>), // (variants, values) TODO only 1 allocation? + F32(F32Decoder<'a>), + // We don't need signed integer decoders here because unsigned ones work the same. 
Map(Box<(LengthDecoder<'a>, (SerdeDecoder<'a>, SerdeDecoder<'a>))>), // (lengths, (keys, values)) Seq(Box<(LengthDecoder<'a>, SerdeDecoder<'a>)>), // (lengths, values) Str(StrDecoder<'a>), @@ -73,6 +76,7 @@ impl<'a> View<'a> for SerdeDecoder<'a> { Ok(()) } } + Self::F32(d) => d.populate(input, length), Self::Map(d) => { d.0.populate(input, length)?; let length = d.0.length(); @@ -151,6 +155,7 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { // Use native decoders. impl_de!(deserialize_bool, visit_bool, bool, Bool); + impl_de!(deserialize_f32, visit_f32, f32, F32); impl_de!(deserialize_u8, visit_u8, u8, U8); impl_de!(deserialize_u16, visit_u16, u16, U16); impl_de!(deserialize_u32, visit_u32, u32, U32); @@ -158,13 +163,12 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { impl_de!(deserialize_u128, visit_u128, u128, U128); impl_de!(deserialize_str, visit_borrowed_str, &str, Str); - // IntDecoder works on signed integers/floats (but not chars). + // IntDecoder works on signed integers/f64 (but not chars). impl_de!(deserialize_i8, visit_i8, i8, U8); impl_de!(deserialize_i16, visit_i16, i16, U16); impl_de!(deserialize_i32, visit_i32, i32, U32); impl_de!(deserialize_i64, visit_i64, i64, U64); impl_de!(deserialize_i128, visit_i128, i128, U128); - impl_de!(deserialize_f32, visit_f32, f32, U32); impl_de!(deserialize_f64, visit_f64, f64, U64); #[inline(always)] @@ -576,6 +580,14 @@ mod tests { // Sequences test!("abc".to_owned(), String); test!(vec![1u8, 2u8, 3u8], Vec); + // Make sure signed integers are being packed properly (output should end in 85). + test!(vec![0, -1, 0, -1, 0, -1, 0], Vec); + test!(vec![0, -1, 0, -1, 0, -1, 0], Vec); + test!(vec![0, -1, 0, -1, 0, -1, 0], Vec); + test!(vec![0, -1, 0, -1, 0, -1, 0], Vec); + test!(vec![0, -1, 0, -1, 0, -1, 0], Vec); + // Make sure f32 sign_exp is grouped (output should end in 4x 63). 
+ test!(vec![1.0; 4], Vec); test!( vec!["abc".to_owned(), "def".to_owned(), "ghi".to_owned()], Vec diff --git a/src/serde/ser.rs b/src/serde/ser.rs index 25b0da1..2950f3a 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -1,6 +1,7 @@ use crate::bool::BoolEncoder; use crate::coder::{Buffer, Encoder, Result}; use crate::error::{err, error, Error}; +use crate::f32::F32Encoder; use crate::int::IntEncoder; use crate::length::LengthEncoder; use crate::serde::variant::VariantEncoder; @@ -57,8 +58,15 @@ pub use inner::serialize; enum SpecifiedEncoder { Bool(BoolEncoder), Enum(Box<(VariantEncoder, Vec)>), // (variants, values) TODO only 1 allocation? + F32(F32Encoder), + // Serialize needs separate signed integer encoders to be able to pack [0, -1, 0, -1, 0, -1]. + I8(IntEncoder), + I16(IntEncoder), + I32(IntEncoder), + I64(IntEncoder), + I128(IntEncoder), Map(Box<(LengthEncoder, (LazyEncoder, LazyEncoder))>), // (lengths, (keys, values)) - Seq(Box<(LengthEncoder, LazyEncoder)>), // (lengths, values) + Seq(Box<(LengthEncoder, LazyEncoder)>), // (lengths, values) Str(StrEncoder), Tuple(Box<[LazyEncoder]>), // [field0, field1, ..] U8(IntEncoder), @@ -76,6 +84,12 @@ impl SpecifiedEncoder { v.0.reserve(additional); // We don't know the variants of the enums, so we can't reserve more. } + Self::F32(v) => v.reserve(additional), + Self::I8(v) => v.reserve(additional), + Self::I16(v) => v.reserve(additional), + Self::I32(v) => v.reserve(additional), + Self::I64(v) => v.reserve(additional), + Self::I128(v) => v.reserve(additional), Self::Map(v) => { v.0.reserve(additional); // We don't know the lengths of the maps, so we can't reserve more. 
@@ -124,6 +138,12 @@ impl LazyEncoder { v.1.iter_mut().for_each(|v| v.reorder(buffers)); &mut v.0 } + SpecifiedEncoder::F32(v) => v, + SpecifiedEncoder::I8(v) => v, + SpecifiedEncoder::I16(v) => v, + SpecifiedEncoder::I32(v) => v, + SpecifiedEncoder::I64(v) => v, + SpecifiedEncoder::I128(v) => v, SpecifiedEncoder::Map(v) => { v.1 .0.reorder(buffers); v.1 .1.reorder(buffers); @@ -251,20 +271,20 @@ impl<'a> Serializer for EncoderWrapper<'a> { // Use native encoders. impl_ser!(serialize_bool, bool, Bool); + impl_ser!(serialize_f32, f32, F32); + impl_ser!(serialize_i8, i8, I8); + impl_ser!(serialize_i16, i16, I16); + impl_ser!(serialize_i32, i32, I32); + impl_ser!(serialize_i64, i64, I64); + impl_ser!(serialize_i128, i128, I128); + impl_ser!(serialize_str, &str, Str); impl_ser!(serialize_u8, u8, U8); impl_ser!(serialize_u16, u16, U16); impl_ser!(serialize_u32, u32, U32); impl_ser!(serialize_u64, u64, U64); impl_ser!(serialize_u128, u128, U128); - impl_ser!(serialize_str, &str, Str); - // IntEncoder works on signed integers/floats/char. - impl_ser!(serialize_i8, i8, U8); - impl_ser!(serialize_i16, i16, U16); - impl_ser!(serialize_i32, i32, U32); - impl_ser!(serialize_i64, i64, U64); - impl_ser!(serialize_i128, i128, U128); - impl_ser!(serialize_f32, f32, U32); + // IntEncoder works on f64/char. 
impl_ser!(serialize_f64, f64, U64); impl_ser!(serialize_char, char, U32); From 29e4f71dbe848b50f15e9f4e09073d4cf9ec8240 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Wed, 28 Feb 2024 14:04:41 -0800 Subject: [PATCH 18/45] cargo test --all-features, cargo check --no-default-features --- .github/workflows/build.yml | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index dad80cf..d946109 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -29,16 +29,20 @@ jobs: components: rustfmt, miri - name: Lint run: cargo fmt --check - - name: Test (debug) + - name: Check (no-default-features) + run: cargo check --no-default-features + - name: Test run: cargo test + - name: Test (all-features) + run: cargo test --all-features - name: Install i686 and GCC multilib run: rustup target add i686-unknown-linux-gnu && sudo apt update && sudo apt install -y gcc-multilib - - name: Test (32-bit) - run: cargo test --target i686-unknown-linux-gnu + - name: Test (32-bit all-features) + run: cargo test --target i686-unknown-linux-gnu --all-features - name: Setup Miri run: cargo miri setup - - name: Test (miri) - run: cargo miri test + - name: Test (miri all-features) + run: cargo miri test --all-features - name: Setup Miri (big-endian) run: rustup target add mips64-unknown-linux-gnuabi64 && cargo miri setup --target mips64-unknown-linux-gnuabi64 # TODO miri big-endian (zstd doesn't compile) From e86c94e4e491995a60d351ce74a928c36173939d Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Wed, 28 Feb 2024 14:29:55 -0800 Subject: [PATCH 19/45] Fix comment. 
--- src/ext/arrayvec.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ext/arrayvec.rs b/src/ext/arrayvec.rs index e7a2b2a..a2ad2ef 100644 --- a/src/ext/arrayvec.rs +++ b/src/ext/arrayvec.rs @@ -80,7 +80,7 @@ impl<'a, const N: usize> Decoder<'a, ArrayString> for ArrayStringDecoder<'a, let array_string = out.write(ArrayString::new()); // Avoid copying lots of memory for 1 byte strings. - // TODO miri doesn't like ArrayString::as_mut_str().as_mut_ptr(), replace with ArrayString::as_mut_str() when available. + // TODO miri doesn't like ArrayString::as_mut_str().as_mut_ptr(), replace with ArrayString::as_mut_ptr() when available. if N > 64 || cfg!(miri) { // Safety: We've ensured `self.lengths.max_len() <= N` in populate. unsafe { array_string.try_push_str(s).unwrap_unchecked() }; From 863bd1485cd3e9c07dc4d1a1b26b87cf31056c97 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Wed, 28 Feb 2024 14:36:56 -0800 Subject: [PATCH 20/45] Release 0.6.0-alpha.1 --- Cargo.toml | 5 ++--- bitcode_derive/Cargo.toml | 3 +-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 60fe026..9b6dcbe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,17 +6,16 @@ members = [ [package] name = "bitcode" authors = [ "Cai Bear", "Finn Bear" ] -version = "0.6.0" +version = "0.6.0-alpha.1" edition = "2021" license = "MIT OR Apache-2.0" repository = "https://github.com/SoftbearStudios/bitcode" description = "bitcode is a bitwise binary serializer" exclude = ["fuzz/"] -publish = false # TODO remove when ready (also remove in bitcode_derive). 
[dependencies] arrayvec = { version = "0.7", default-features = false, optional = true } -bitcode_derive = { version = "0.6.0", path = "./bitcode_derive", optional = true } +bitcode_derive = { version = "0.6.0-alpha.1", path = "./bitcode_derive", optional = true } bytemuck = { version = "1.14", features = [ "min_const_generics", "must_cast" ] } glam = { version = "0.22", default-features = false, features = [ "std" ], optional = true } serde = { version = "1.0", optional = true } diff --git a/bitcode_derive/Cargo.toml b/bitcode_derive/Cargo.toml index 3bb0d19..0c0a414 100644 --- a/bitcode_derive/Cargo.toml +++ b/bitcode_derive/Cargo.toml @@ -1,12 +1,11 @@ [package] name = "bitcode_derive" authors = [ "Cai Bear", "Finn Bear" ] -version = "0.6.0" +version = "0.6.0-alpha.1" edition = "2021" license = "MIT OR Apache-2.0" repository = "https://github.com/SoftbearStudios/bitcode/" description = "Implementation of #[derive(Encode, Decode)] for bitcode" -publish = false # TODO remove when ready [lib] proc-macro = true From e9f17740c408a7b0767754c3f38851c5dcc75973 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Thu, 29 Feb 2024 15:54:11 -0800 Subject: [PATCH 21/45] Fix #18 and release new version. 
--- Cargo.toml | 7 +++++-- src/lib.rs | 3 +-- src/serde/de.rs | 1 - src/serde/ser.rs | 1 - 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9b6dcbe..0efbfde 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ members = [ [package] name = "bitcode" authors = [ "Cai Bear", "Finn Bear" ] -version = "0.6.0-alpha.1" +version = "0.6.0-alpha.2" edition = "2021" license = "MIT OR Apache-2.0" repository = "https://github.com/SoftbearStudios/bitcode" @@ -21,7 +21,7 @@ glam = { version = "0.22", default-features = false, features = [ "std" ], optio serde = { version = "1.0", optional = true } [dev-dependencies] -arrayvec = { version = "0.7", features = ["serde"] } +arrayvec = { version = "0.7", features = [ "serde" ] } bincode = "1.3.3" flate2 = "1.0.28" lz4_flex = { version = "0.11.2", default-features = false } @@ -35,6 +35,9 @@ zstd = "0.13.0" derive = [ "bitcode_derive" ] default = [ "derive" ] +[package.metadata.docs.rs] +features = [ "derive", "serde" ] + # TODO halfs speed of benches_borrowed::bench_bitcode_decode #[profile.bench] #lto = true diff --git a/src/lib.rs b/src/lib.rs index aa0675c..b1de80d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,5 @@ #![allow(clippy::items_after_test_module, clippy::blocks_in_if_conditions)] -#![cfg_attr(doc, feature(doc_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] #![cfg_attr(test, feature(test))] #![doc = include_str!("../README.md")] @@ -30,7 +30,6 @@ pub use crate::derive::*; pub use crate::error::Error; #[cfg(feature = "derive")] -#[cfg_attr(doc, doc(cfg(feature = "derive")))] pub use bitcode_derive::{Decode, Encode}; #[cfg(feature = "serde")] diff --git a/src/serde/de.rs b/src/serde/de.rs index 9e29b82..41ec728 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -23,7 +23,6 @@ mod inner { /// /// **Warning:** The format is incompatible with [`encode`][`crate::encode`] and subject to /// change between major versions. 
- #[cfg_attr(doc, doc(cfg(feature = "serde")))] pub fn deserialize<'de, T: Deserialize<'de>>(mut bytes: &'de [u8]) -> Result { let mut decoder = SerdeDecoder::Unspecified2 { length: 1 }; let t = T::deserialize(DecoderWrapper { diff --git a/src/serde/ser.rs b/src/serde/ser.rs index 2950f3a..c893f96 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -23,7 +23,6 @@ mod inner { /// /// **Warning:** The format is incompatible with [`decode`][`crate::decode`] and subject to /// change between major versions. - #[cfg_attr(doc, doc(cfg(feature = "serde")))] pub fn serialize(t: &T) -> Result, Error> { let mut lazy = LazyEncoder::Unspecified { reserved: NonZeroUsize::new(1), From b4e4c95807450e8f2dbe1e9701e16bdefaf4318b Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Sat, 9 Mar 2024 16:47:31 -0800 Subject: [PATCH 22/45] Optimize serialize by 40%, document unsound code (to be fixed). --- src/serde/ser.rs | 254 +++++++++++++++++++++++++---------------------- 1 file changed, 137 insertions(+), 117 deletions(-) diff --git a/src/serde/ser.rs b/src/serde/ser.rs index c893f96..3e2a4df 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -5,7 +5,7 @@ use crate::f32::F32Encoder; use crate::int::IntEncoder; use crate::length::LengthEncoder; use crate::serde::variant::VariantEncoder; -use crate::serde::{default_box_slice, get_mut_or_resize, type_changed}; +use crate::serde::{default_box_slice, get_mut_or_resize}; use crate::str::StrEncoder; use serde::ser::{ SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant, SerializeTuple, @@ -174,6 +174,7 @@ impl LazyEncoder { /// [`NonZeroUsize`] to avoid branching on len. /// /// Can't be reserve_fast anymore with push_within_capacity. + #[inline(always)] fn reserve_fast(&mut self, len: usize) { match self { Self::Specified { specified, .. } => { @@ -190,33 +191,39 @@ macro_rules! 
specify { ($wrapper:ident, $variant:ident) => {{ let lazy = &mut *$wrapper.lazy; match lazy { - LazyEncoder::Unspecified { reserved } => { - let reserved = *reserved; + // Check if we're already the correct encoder. This results in 1 branch in the hot path. + LazyEncoder::Specified { specified: SpecifiedEncoder::$variant(_), .. } => (), + _ => { + // Either create the correct encoder if unspecified or panic if we already have an + // encoder since it must be a different type. #[cold] fn cold<'a>( me: &'a mut LazyEncoder, index_alloc: &mut usize, - reserved: Option, - ) -> &'a mut SpecifiedEncoder { - let mut specified = SpecifiedEncoder::$variant(Default::default()); - if let Some(reserved) = reserved { - specified.reserve(reserved); - } + ) { + let &mut LazyEncoder::Unspecified { reserved } = me else { + panic!("type changed"); + }; *me = LazyEncoder::Specified { - specified, + specified: SpecifiedEncoder::$variant(Default::default()), index: std::mem::replace(index_alloc, *index_alloc + 1), }; - // TODO might be slower to put in cold fn. - if let LazyEncoder::Specified { specified, .. } = me { - specified - } else { + let LazyEncoder::Specified { specified, .. } = me else { unreachable!(); + }; + if let Some(reserved) = reserved { + specified.reserve(reserved); } } - cold(lazy, &mut *$wrapper.index_alloc, reserved) + cold(lazy, &mut *$wrapper.index_alloc); } - LazyEncoder::Specified { specified, .. } => specified, } + let LazyEncoder::Specified { specified: SpecifiedEncoder::$variant(b), .. } = lazy else { + // Safety: `cold` gets called when lazy isn't the correct encoder. `cold` either diverges + // or sets lazy to the correct encoder. 
+ unsafe { std::hint::unreachable_unchecked() }; + }; + b }}; } @@ -226,32 +233,27 @@ struct EncoderWrapper<'a> { } impl<'a> EncoderWrapper<'a> { + #[inline(always)] fn serialize_enum(self, variant_index: u32) -> Result> { let variant_index = variant_index .try_into() .map_err(|_| error("enums with more than 256 variants are unsupported"))?; - match specify!(self, Enum) { - SpecifiedEncoder::Enum(b) => { - b.0.encode(&variant_index); - let lazy = get_mut_or_resize(&mut b.1, variant_index as usize); - lazy.reserve_fast(1); // TODO use push instead. - Ok(Self { - lazy, - index_alloc: self.index_alloc, - }) - } - _ => type_changed(), - } + let b = specify!(self, Enum); + b.0.encode(&variant_index); + let lazy = get_mut_or_resize(&mut b.1, variant_index as usize); + lazy.reserve_fast(1); // TODO use push instead. + Ok(Self { + lazy, + index_alloc: self.index_alloc, + }) } } macro_rules! impl_ser { ($name:ident, $t:ty, $variant:ident) => { + // TODO #[inline(always)] makes benchmark slower because collect_seq isn't inlined. fn $name(self, v: $t) -> Result<()> { - match specify!(self, $variant) { - SpecifiedEncoder::$variant(b) => b.encode(&v), - _ => return type_changed(), - } + specify!(self, $variant).encode(&v); Ok(()) } }; @@ -287,6 +289,7 @@ impl<'a> Serializer for EncoderWrapper<'a> { impl_ser!(serialize_f64, f64, U64); impl_ser!(serialize_char, char, U32); + #[inline(always)] fn serialize_bytes(self, v: &[u8]) -> Result { v.serialize(self) } @@ -305,14 +308,17 @@ impl<'a> Serializer for EncoderWrapper<'a> { v.serialize(self.serialize_enum(1)?) 
} + #[inline(always)] fn serialize_unit(self) -> Result { Ok(()) } + #[inline(always)] fn serialize_unit_struct(self, _name: &'static str) -> Result { Ok(()) } + #[inline(always)] fn serialize_unit_variant( self, _name: &'static str, @@ -323,6 +329,7 @@ impl<'a> Serializer for EncoderWrapper<'a> { Ok(()) } + #[inline(always)] fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result where T: Serialize, @@ -330,6 +337,7 @@ impl<'a> Serializer for EncoderWrapper<'a> { value.serialize(self) } + #[inline(always)] fn serialize_newtype_variant( self, _name: &'static str, @@ -346,58 +354,56 @@ impl<'a> Serializer for EncoderWrapper<'a> { #[inline(always)] fn serialize_seq(self, len: Option) -> Result { let len = len.expect("sequence must have len"); - match specify!(self, Seq) { - SpecifiedEncoder::Seq(b) => { - b.0.encode(&len); - b.1.reserve_fast(len); - Ok(Self { - lazy: &mut b.1, - index_alloc: self.index_alloc, - }) - } - _ => type_changed(), - } + let b = specify!(self, Seq); + b.0.encode(&len); + b.1.reserve_fast(len); + Ok(Self { + lazy: &mut b.1, + index_alloc: self.index_alloc, + }) } #[inline(always)] fn serialize_tuple(self, len: usize) -> Result { + // Copy of specify! macro that takes an additional len parameter to cold. let lazy = &mut *self.lazy; - let specified = match lazy { - &mut LazyEncoder::Unspecified { reserved } => { + match lazy { + LazyEncoder::Specified { + specified: SpecifiedEncoder::Tuple(_), + .. + } => (), + _ => { #[cold] - fn cold( - me: &mut LazyEncoder, - reserved: Option, - len: usize, - ) -> &mut SpecifiedEncoder { - let mut specified = SpecifiedEncoder::Tuple(default_box_slice(len)); - if let Some(reserved) = reserved { - specified.reserve(reserved); - } + fn cold(me: &mut LazyEncoder, len: usize) { + let &mut LazyEncoder::Unspecified { reserved } = me else { + panic!("type changed"); + }; *me = LazyEncoder::Specified { - specified, - index: usize::MAX, // We never use this. 
+ specified: SpecifiedEncoder::Tuple(default_box_slice(len)), + index: usize::MAX, // We never use index for SpecifiedEncoder::Tuple. }; - // TODO might be slower to put in cold fn. - let LazyEncoder::Specified { specified: encoder, .. } = me else { + let LazyEncoder::Specified { specified, .. } = me else { unreachable!(); }; - encoder + if let Some(reserved) = reserved { + specified.reserve(reserved); + } } - cold(lazy, reserved, len) + cold(lazy, len); } - LazyEncoder::Specified { specified, .. } => specified, }; - match specified { - SpecifiedEncoder::Tuple(encoders) => { - assert_eq!(encoders.len(), len); // Removes multiple bounds checks. - Ok(TupleSerializer { - encoders, - index_alloc: self.index_alloc, - }) - } - _ => type_changed(), - } + let LazyEncoder::Specified { + specified: SpecifiedEncoder::Tuple(encoders), + .. + } = lazy else { + // Safety: see specify! macro which this is based on. + unsafe { std::hint::unreachable_unchecked() }; + }; + assert!(encoders.len() == len, "type changed"); // Removes multiple bounds checks. + Ok(TupleSerializer { + encoders, + index_alloc: self.index_alloc, + }) } #[inline(always)] @@ -423,18 +429,14 @@ impl<'a> Serializer for EncoderWrapper<'a> { #[inline(always)] fn serialize_map(self, len: Option) -> Result { let len = len.expect("sequence must have len"); - match specify!(self, Map) { - SpecifiedEncoder::Map(b) => { - b.0.encode(&len); - b.1 .0.reserve_fast(len); - b.1 .1.reserve_fast(len); - Ok(MapSerializer { - encoders: &mut b.1, - index_alloc: self.index_alloc, - }) - } - _ => type_changed(), - } + let b = specify!(self, Map); + b.0.encode(&len); + b.1 .0.reserve_fast(len); + b.1 .1.reserve_fast(len); + Ok(MapSerializer { + encoders: &mut b.1, + index_alloc: self.index_alloc, + }) } #[inline(always)] @@ -471,6 +473,8 @@ macro_rules! ok_error_end { impl SerializeSeq for EncoderWrapper<'_> { ok_error_end!(); + // TODO(unsound): could be called more than len times by buggy safe code but we only reserved len. 
+ #[inline(always)] fn serialize_element(&mut self, value: &T) -> Result<()> { value.serialize(EncoderWrapper { lazy: &mut *self.lazy, @@ -485,10 +489,11 @@ struct TupleSerializer<'a> { } macro_rules! impl_tuple { - ($tr:ty, $fun:ident) => { + ($tr:ty, $fun:ident $(, $key:ident)?) => { impl $tr for TupleSerializer<'_> { ok_error_end!(); - fn $fun(&mut self, value: &T) -> Result<()> { + #[inline(always)] + fn $fun(&mut self, $($key: &'static str,)? value: &T) -> Result<()> { let (lazy, remaining) = std::mem::take(&mut self.encoders) .split_first_mut() .expect("length mismatch"); @@ -498,39 +503,20 @@ macro_rules! impl_tuple { index_alloc: &mut *self.index_alloc, }) } + + $( + fn skip_field(&mut self, $key: &'static str) -> Result<()> { + err("skip field is not supported") + } + )? } }; } impl_tuple!(SerializeTuple, serialize_element); impl_tuple!(SerializeTupleStruct, serialize_field); impl_tuple!(SerializeTupleVariant, serialize_field); - -macro_rules! impl_struct { - ($tr:ty) => { - impl $tr for TupleSerializer<'_> { - ok_error_end!(); - fn serialize_field(&mut self, _key: &'static str, value: &T) -> Result<()> - where - T: Serialize, - { - let (lazy, remaining) = std::mem::take(&mut self.encoders) - .split_first_mut() - .expect("length mismatch"); - self.encoders = remaining; - value.serialize(EncoderWrapper { - lazy, - index_alloc: &mut *self.index_alloc, - }) - } - - fn skip_field(&mut self, _key: &'static str) -> Result<()> { - err("skip field is not supported") - } - } - }; -} -impl_struct!(SerializeStruct); -impl_struct!(SerializeStructVariant); +impl_tuple!(SerializeStruct, serialize_field, _key); +impl_tuple!(SerializeStructVariant, serialize_field, _key); struct MapSerializer<'a> { encoders: &'a mut (LazyEncoder, LazyEncoder), // (keys, values) @@ -539,6 +525,8 @@ struct MapSerializer<'a> { impl SerializeMap for MapSerializer<'_> { ok_error_end!(); + // TODO(unsound): could be called more than len times by buggy safe code but we only reserved len. 
+ #[inline(always)] fn serialize_key(&mut self, key: &T) -> Result<()> where T: Serialize, @@ -549,6 +537,8 @@ impl SerializeMap for MapSerializer<'_> { }) } + // TODO(unsound): could be called more than len times by buggy safe code but we only reserved len. + #[inline(always)] fn serialize_value(&mut self, value: &T) -> Result<()> where T: Serialize, @@ -562,17 +552,17 @@ impl SerializeMap for MapSerializer<'_> { #[cfg(test)] mod tests { + use serde::ser::SerializeTuple; + use serde::{Serialize, Serializer}; + #[test] fn enum_256_variants() { enum Enum { A, B, } - impl serde::Serialize for Enum { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { + impl Serialize for Enum { + fn serialize(&self, serializer: S) -> Result { let variant_index = match self { Self::A => 255, Self::B => 256, @@ -583,4 +573,34 @@ mod tests { assert!(crate::serialize(&Enum::A).is_ok()); assert!(crate::serialize(&Enum::B).is_err()); } + + #[test] + #[should_panic(expected = "type changed")] + fn test_type_changed() { + struct BoolOrU8(bool); + impl Serialize for BoolOrU8 { + fn serialize(&self, serializer: S) -> Result { + if self.0 { + serializer.serialize_bool(false) + } else { + serializer.serialize_u8(1) + } + } + } + let _ = crate::serialize(&vec![BoolOrU8(false), BoolOrU8(true)]); + } + + #[test] + #[should_panic(expected = "type changed")] + fn test_tuple_len_changed() { + struct TupleN(usize); + impl Serialize for TupleN { + fn serialize(&self, serializer: S) -> Result { + let mut tuple = serializer.serialize_tuple(self.0)?; + (0..self.0).try_for_each(|_| tuple.serialize_element(&false))?; + tuple.end() + } + } + let _ = crate::serialize(&vec![TupleN(1), TupleN(2)]); + } } From 656069597c9e671a2f8befecd25c16298eb71588 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Sat, 9 Mar 2024 20:07:07 -0800 Subject: [PATCH 23/45] Optimize deserialize by 30%, document unsound code (to be fixed). 
--- src/serde/de.rs | 222 +++++++++++++++++++++++++---------------------- src/serde/mod.rs | 10 ++- src/serde/ser.rs | 12 +-- 3 files changed, 131 insertions(+), 113 deletions(-) diff --git a/src/serde/de.rs b/src/serde/de.rs index 41ec728..c25429f 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -24,7 +24,7 @@ mod inner { /// **Warning:** The format is incompatible with [`encode`][`crate::encode`] and subject to /// change between major versions. pub fn deserialize<'de, T: Deserialize<'de>>(mut bytes: &'de [u8]) -> Result { - let mut decoder = SerdeDecoder::Unspecified2 { length: 1 }; + let mut decoder = SerdeDecoder::Unspecified { length: 1 }; let t = T::deserialize(DecoderWrapper { decoder: &mut decoder, input: &mut bytes, @@ -50,13 +50,13 @@ enum SerdeDecoder<'a> { U32(IntDecoder<'a, u32>), U64(IntDecoder<'a, u64>), U128(IntDecoder<'a, u128>), - Unspecified, - Unspecified2 { length: usize }, + Unpopulated, + Unspecified { length: usize }, } impl Default for SerdeDecoder<'_> { fn default() -> Self { - Self::Unspecified + Self::Unpopulated } } @@ -94,11 +94,11 @@ impl<'a> View<'a> for SerdeDecoder<'a> { Self::U32(d) => d.populate(input, length), Self::U64(d) => d.populate(input, length), Self::U128(d) => d.populate(input, length), - Self::Unspecified => { - *self = Self::Unspecified2 { length }; + Self::Unpopulated => { + *self = Self::Unspecified { length }; Ok(()) } - Self::Unspecified2 { .. } => unreachable!(), // TODO + Self::Unspecified { .. } => unreachable!(), } } } @@ -110,16 +110,29 @@ struct DecoderWrapper<'a, 'de> { macro_rules! specify { ($self:ident, $variant:ident) => { - match &mut *$self.decoder { - &mut SerdeDecoder::Unspecified2 { length } => { - #[cold] - fn cold(me: &mut DecoderWrapper, length: usize) -> Result<()> { - *me.decoder = SerdeDecoder::$variant(Default::default()); - me.decoder.populate(me.input, length) + { + match &mut $self.decoder { + // Check if it's already the correct decoder. This results in 1 branch in the hot path. 
+ SerdeDecoder::$variant(_) => (), + _ => { + // Either create the correct decoder if unspecified or diverge via panic/error. + #[cold] + fn cold<'de>(decoder: &mut SerdeDecoder<'de>, input: &mut &'de[u8]) -> Result<()> { + let &mut SerdeDecoder::Unspecified { length } = decoder else { + type_changed!(); + }; + *decoder = SerdeDecoder::$variant(Default::default()); + decoder.populate(input, length) + } + cold(&mut *$self.decoder, &mut *$self.input)?; } - cold(&mut $self, length)?; } - _ => (), + let SerdeDecoder::$variant(d) = &mut *$self.decoder else { + // Safety: `cold` gets called when decoder isn't the correct decoder. `cold` either + // errors or sets lazy to the correct decoder. + unsafe { std::hint::unreachable_unchecked() }; + }; + d } }; } @@ -131,13 +144,7 @@ macro_rules! impl_de { where V: Visitor<'de>, { - v.$visit({ - specify!(self, $variant); - match &mut *self.decoder { - SerdeDecoder::$variant(d) => d.decode(), - _ => return type_changed(), - } - }) + v.$visit(specify!(self, $variant).decode()) } }; } @@ -178,6 +185,7 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { v.visit_char(char::from_u32(u32::deserialize(self)?).ok_or_else(|| error("invalid char"))?) } + #[inline(always)] fn deserialize_string(self, v: V) -> Result where V: Visitor<'de>, @@ -193,6 +201,7 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { self.deserialize_byte_buf(v) // TODO avoid allocation. 
} + #[inline(always)] fn deserialize_byte_buf(self, v: V) -> Result where V: Visitor<'de>, @@ -205,14 +214,11 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { where V: Visitor<'de>, { - specify!(self, Enum); - let (decoder, variant_index) = match &mut *self.decoder { - SerdeDecoder::Enum(b) => { - let variant_index = b.0.decode(); - (&mut b.1[variant_index as usize], variant_index) - } - _ => return type_changed(), - }; + let (variant_decoder, decoders) = &mut **specify!(self, Enum); + let variant_index = variant_decoder.decode(); + // Safety: populate guarantees `variant_decoder.max_variant_index() < decoders.len()`. + let decoder = unsafe { decoders.get_unchecked_mut(variant_index as usize) }; + match variant_index { 0 => v.visit_none(), 1 => v.visit_some(DecoderWrapper { @@ -251,14 +257,8 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { where V: Visitor<'de>, { - specify!(self, Seq); - let (decoder, len) = match &mut *self.decoder { - SerdeDecoder::Seq(b) => { - let len = b.0.decode(); - (&mut b.1, len) - } - _ => return type_changed(), - }; + let (length_decoder, decoder) = &mut **specify!(self, Seq); + let len = length_decoder.decode(); struct Access<'a, 'de> { wrapper: DecoderWrapper<'a, 'de>, @@ -273,21 +273,21 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { T: DeserializeSeed<'de>, { guard_zst::(self.len)?; - self.len - .checked_sub(1) - .map(|len| { - self.len = len; - DeserializeSeed::deserialize( - seed, - DecoderWrapper { - decoder: &mut *self.wrapper.decoder, - input: &mut *self.wrapper.input, - }, - ) - }) - .transpose() + if self.len != 0 { + self.len -= 1; + Ok(Some(DeserializeSeed::deserialize( + seed, + DecoderWrapper { + decoder: &mut *self.wrapper.decoder, + input: &mut *self.wrapper.input, + }, + )?)) + } else { + Ok(None) + } } + #[inline(always)] fn size_hint(&self) -> Option { Some(self.len) } @@ -302,23 +302,36 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { } #[inline(always)] - fn 
deserialize_tuple(mut self, tuple_len: usize, v: V) -> Result + fn deserialize_tuple(mut self, len: usize, v: V) -> Result where V: Visitor<'de>, { - if let &mut SerdeDecoder::Unspecified2 { length } = &mut *self.decoder { - #[cold] - fn cold(me: &mut DecoderWrapper, length: usize, tuple_len: usize) -> Result<()> { - *me.decoder = SerdeDecoder::Tuple(default_box_slice(tuple_len)); - me.decoder.populate(me.input, length) + // Copy of specify! macro that takes an additional len parameter to cold. + match &mut self.decoder { + SerdeDecoder::Tuple(_) => (), + _ => { + #[cold] + fn cold<'de>( + decoder: &mut SerdeDecoder<'de>, + input: &mut &'de [u8], + len: usize, + ) -> Result<()> { + let &mut SerdeDecoder::Unspecified { length } = decoder else { + type_changed!(); + }; + *decoder = SerdeDecoder::Tuple(default_box_slice(len)); + decoder.populate(input, length) + } + cold(&mut *self.decoder, &mut *self.input, len)?; } - cold(&mut self, length, tuple_len)?; } - let decoders = match &mut *self.decoder { - SerdeDecoder::Tuple(d) => &mut **d, - _ => return type_changed(), + let SerdeDecoder::Tuple(decoders) = &mut *self.decoder else { + // Safety: see specify! macro which this is based on. + unsafe { std::hint::unreachable_unchecked() }; }; - assert_eq!(decoders.len(), tuple_len); // Removes multiple bounds checks. + if decoders.len() != len { + type_changed!(); // Removes multiple bounds checks. 
+ } struct Access<'a, 'de> { decoders: &'a mut [SerdeDecoder<'de>], @@ -333,22 +346,21 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { where T: DeserializeSeed<'de>, { - guard_zst::(self.decoders.len())?; - self.decoders - .get_mut(self.index) - .map(|decoder| { - self.index += 1; - DeserializeSeed::deserialize( - seed, - DecoderWrapper { - decoder, - input: &mut *self.input, - }, - ) - }) - .transpose() + if let Some(decoder) = self.decoders.get_mut(self.index) { + self.index += 1; + Ok(Some(DeserializeSeed::deserialize( + seed, + DecoderWrapper { + decoder, + input: &mut *self.input, + }, + )?)) + } else { + Ok(None) + } } + #[inline(always)] fn size_hint(&self) -> Option { Some(self.decoders.len()) } @@ -372,14 +384,9 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { where V: Visitor<'de>, { - specify!(self, Map); - let (decoders, len) = match &mut *self.decoder { - SerdeDecoder::Map(b) => { - let len = b.0.decode(); - (&mut b.1, len) - } - _ => return type_changed(), - }; + let (length_decoder, decoders) = &mut **specify!(self, Map); + let len = length_decoder.decode(); + struct Access<'a, 'de> { decoders: &'a mut (SerdeDecoder<'de>, SerdeDecoder<'de>), input: &'a mut &'de [u8], @@ -395,21 +402,21 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { K: DeserializeSeed<'de>, { guard_zst::(self.len)?; - self.len - .checked_sub(1) - .map(|len| { - self.len = len; - DeserializeSeed::deserialize( - seed, - DecoderWrapper { - decoder: &mut self.decoders.0, - input: &mut *self.input, - }, - ) - }) - .transpose() + if self.len != 0 { + self.len -= 1; + Ok(Some(DeserializeSeed::deserialize( + seed, + DecoderWrapper { + decoder: &mut self.decoders.0, + input: &mut *self.input, + }, + )?)) + } else { + Ok(None) + } } + // TODO(unsound): could be called more than len times by buggy safe code and go out of bounds. 
#[inline(always)] fn next_value_seed(&mut self, seed: V) -> Result where @@ -424,6 +431,7 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { ) } + #[inline(always)] fn size_hint(&self) -> Option { Some(self.len) } @@ -476,6 +484,7 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { err("deserialize_ignored_any is not supported") } + #[inline(always)] fn is_human_readable(&self) -> bool { false } @@ -485,18 +494,17 @@ impl<'a, 'de> EnumAccess<'de> for DecoderWrapper<'a, 'de> { type Error = Error; type Variant = DecoderWrapper<'a, 'de>; + #[inline(always)] fn variant_seed(mut self, seed: V) -> Result<(V::Value, Self::Variant)> where V: DeserializeSeed<'de>, { - specify!(self, Enum); - let (decoder, variant_index) = match &mut *self.decoder { - SerdeDecoder::Enum(b) => { - let variant_index = b.0.decode(); - (&mut b.1[variant_index as usize], variant_index as u32) - } - _ => return type_changed(), - }; + let (variant_decoder, decoders) = &mut **specify!(self, Enum); + let variant_index = variant_decoder.decode(); + // Safety: populate guarantees `variant_decoder.max_variant_index() < decoders.len()`. 
+ let decoder = unsafe { decoders.get_unchecked_mut(variant_index as usize) }; + let variant_index = variant_index as u32; + let val: Result<_> = seed.deserialize(variant_index.into_deserializer()); Ok(( val?, @@ -511,10 +519,12 @@ impl<'a, 'de> EnumAccess<'de> for DecoderWrapper<'a, 'de> { impl<'de> VariantAccess<'de> for DecoderWrapper<'_, 'de> { type Error = Error; + #[inline(always)] fn unit_variant(self) -> Result<()> { Ok(()) } + #[inline(always)] fn newtype_variant_seed(self, seed: T) -> Result where T: DeserializeSeed<'de>, @@ -522,6 +532,7 @@ impl<'de> VariantAccess<'de> for DecoderWrapper<'_, 'de> { seed.deserialize(self) } + #[inline(always)] fn tuple_variant(self, len: usize, v: V) -> Result where V: Visitor<'de>, @@ -529,6 +540,7 @@ impl<'de> VariantAccess<'de> for DecoderWrapper<'_, 'de> { self.deserialize_tuple(len, v) } + #[inline(always)] fn struct_variant(self, fields: &'static [&'static str], v: V) -> Result where V: Visitor<'de>, diff --git a/src/serde/mod.rs b/src/serde/mod.rs index 577cec5..50e6e40 100644 --- a/src/serde/mod.rs +++ b/src/serde/mod.rs @@ -1,4 +1,4 @@ -use crate::error::{err, error_from_display, Error}; +use crate::error::{error_from_display, Error}; use std::fmt::Display; mod de; @@ -9,9 +9,13 @@ mod variant; pub use de::*; pub use ser::*; -fn type_changed() -> Result { - err("type changed") +// Use macro instead of function because ! type isn't stable. +macro_rules! 
type_changed { + () => { + panic!("type changed") + }; } +use type_changed; fn default_box_slice(len: usize) -> Box<[T]> { let mut vec = vec![]; diff --git a/src/serde/ser.rs b/src/serde/ser.rs index 3e2a4df..1575466 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -5,7 +5,7 @@ use crate::f32::F32Encoder; use crate::int::IntEncoder; use crate::length::LengthEncoder; use crate::serde::variant::VariantEncoder; -use crate::serde::{default_box_slice, get_mut_or_resize}; +use crate::serde::{default_box_slice, get_mut_or_resize, type_changed}; use crate::str::StrEncoder; use serde::ser::{ SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant, SerializeTuple, @@ -191,7 +191,7 @@ macro_rules! specify { ($wrapper:ident, $variant:ident) => {{ let lazy = &mut *$wrapper.lazy; match lazy { - // Check if we're already the correct encoder. This results in 1 branch in the hot path. + // Check if it's already the correct encoder. This results in 1 branch in the hot path. LazyEncoder::Specified { specified: SpecifiedEncoder::$variant(_), .. } => (), _ => { // Either create the correct encoder if unspecified or panic if we already have an @@ -202,7 +202,7 @@ macro_rules! specify { index_alloc: &mut usize, ) { let &mut LazyEncoder::Unspecified { reserved } = me else { - panic!("type changed"); + type_changed!(); }; *me = LazyEncoder::Specified { specified: SpecifiedEncoder::$variant(Default::default()), @@ -376,7 +376,7 @@ impl<'a> Serializer for EncoderWrapper<'a> { #[cold] fn cold(me: &mut LazyEncoder, len: usize) { let &mut LazyEncoder::Unspecified { reserved } = me else { - panic!("type changed"); + type_changed!(); }; *me = LazyEncoder::Specified { specified: SpecifiedEncoder::Tuple(default_box_slice(len)), @@ -399,7 +399,9 @@ impl<'a> Serializer for EncoderWrapper<'a> { // Safety: see specify! macro which this is based on. unsafe { std::hint::unreachable_unchecked() }; }; - assert!(encoders.len() == len, "type changed"); // Removes multiple bounds checks. 
+ if encoders.len() != len { + type_changed!(); // Removes multiple bounds checks. + } Ok(TupleSerializer { encoders, index_alloc: self.index_alloc, From a69aedb6f2222e32dce81d26d37a243675f8216d Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Sat, 9 Mar 2024 20:11:06 -0800 Subject: [PATCH 24/45] Rename len variable because it's too similar to length. --- src/serde/de.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/serde/de.rs b/src/serde/de.rs index c25429f..3b9d007 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -302,11 +302,11 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { } #[inline(always)] - fn deserialize_tuple(mut self, len: usize, v: V) -> Result + fn deserialize_tuple(mut self, tuple_len: usize, v: V) -> Result where V: Visitor<'de>, { - // Copy of specify! macro that takes an additional len parameter to cold. + // Copy of specify! macro that takes an additional tuple_len parameter to cold. match &mut self.decoder { SerdeDecoder::Tuple(_) => (), _ => { @@ -314,22 +314,22 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { fn cold<'de>( decoder: &mut SerdeDecoder<'de>, input: &mut &'de [u8], - len: usize, + tuple_len: usize, ) -> Result<()> { let &mut SerdeDecoder::Unspecified { length } = decoder else { type_changed!(); }; - *decoder = SerdeDecoder::Tuple(default_box_slice(len)); + *decoder = SerdeDecoder::Tuple(default_box_slice(tuple_len)); decoder.populate(input, length) } - cold(&mut *self.decoder, &mut *self.input, len)?; + cold(&mut *self.decoder, &mut *self.input, tuple_len)?; } } let SerdeDecoder::Tuple(decoders) = &mut *self.decoder else { // Safety: see specify! macro which this is based on. unsafe { std::hint::unreachable_unchecked() }; }; - if decoders.len() != len { + if decoders.len() != tuple_len { type_changed!(); // Removes multiple bounds checks. 
} From 05ca48267c0b268980b408ec340ab4f13aeb2981 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Sun, 10 Mar 2024 14:06:31 -0700 Subject: [PATCH 25/45] Remove pointless Box. --- src/serde/de.rs | 6 +++--- src/serde/ser.rs | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/serde/de.rs b/src/serde/de.rs index 3b9d007..7dca37d 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -38,7 +38,7 @@ pub use inner::deserialize; #[derive(Debug)] enum SerdeDecoder<'a> { Bool(BoolDecoder<'a>), - Enum(Box<(VariantDecoder<'a>, Vec>)>), // (variants, values) TODO only 1 allocation? + Enum((VariantDecoder<'a>, Vec>)), // (variants, values) F32(F32Decoder<'a>), // We don't need signed integer decoders here because unsigned ones work the same. Map(Box<(LengthDecoder<'a>, (SerdeDecoder<'a>, SerdeDecoder<'a>))>), // (lengths, (keys, values)) @@ -214,7 +214,7 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { where V: Visitor<'de>, { - let (variant_decoder, decoders) = &mut **specify!(self, Enum); + let (variant_decoder, decoders) = specify!(self, Enum); let variant_index = variant_decoder.decode(); // Safety: populate guarantees `variant_decoder.max_variant_index() < decoders.len()`. let decoder = unsafe { decoders.get_unchecked_mut(variant_index as usize) }; @@ -499,7 +499,7 @@ impl<'a, 'de> EnumAccess<'de> for DecoderWrapper<'a, 'de> { where V: DeserializeSeed<'de>, { - let (variant_decoder, decoders) = &mut **specify!(self, Enum); + let (variant_decoder, decoders) = specify!(self, Enum); let variant_index = variant_decoder.decode(); // Safety: populate guarantees `variant_decoder.max_variant_index() < decoders.len()`. 
let decoder = unsafe { decoders.get_unchecked_mut(variant_index as usize) }; diff --git a/src/serde/ser.rs b/src/serde/ser.rs index 1575466..172de26 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -56,7 +56,7 @@ pub use inner::serialize; #[derive(Debug)] enum SpecifiedEncoder { Bool(BoolEncoder), - Enum(Box<(VariantEncoder, Vec)>), // (variants, values) TODO only 1 allocation? + Enum((VariantEncoder, Vec)), // (variants, values) F32(F32Encoder), // Serialize needs separate signed integer encoders to be able to pack [0, -1, 0, -1, 0, -1]. I8(IntEncoder), From 81d637e7e7bf5406dd13efc3fc425160dc1e9a0c Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Sun, 10 Mar 2024 18:27:53 -0700 Subject: [PATCH 26/45] Fix double free on panic in FastVec::{clear, reserve}. --- src/fast.rs | 48 +++++++++++++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/src/fast.rs b/src/fast.rs index ff0833b..9080bc5 100644 --- a/src/fast.rs +++ b/src/fast.rs @@ -1,6 +1,6 @@ use std::fmt::{Debug, Formatter}; use std::marker::PhantomData; -use std::mem::MaybeUninit; +use std::mem::{ManuallyDrop, MaybeUninit}; pub type VecImpl = FastVec; pub type SliceImpl<'a, T> = FastSlice<'a, T>; @@ -80,7 +80,14 @@ impl FastVec { } pub fn clear(&mut self) { - self.mut_vec(Vec::clear); + // Safety: same as `Vec::clear` except `self.end = self.start` instead of `self.len = 0` but + // these are equivalent operations. Can't use `self.mut_vec(Vec::clear)` because T::drop + // panicking would double free elements. + unsafe { + let elems: *mut [T] = self.as_mut_slice(); + self.end = self.start; + std::ptr::drop_in_place(elems); + } } pub fn reserve(&mut self, additional: usize) { @@ -88,28 +95,31 @@ impl FastVec { #[cold] #[inline(never)] fn reserve_slow(me: &mut FastVec, additional: usize) { - me.mut_vec(|v| v.reserve(additional)); + // Safety: `Vec::reserve` panics on OOM without freeing Vec, so Vec is unmodified. 
+ unsafe { + me.mut_vec(|v| { + // Optimizes out a redundant check in `Vec::reserve`. + // Safety: we've already ensured this condition before calling reserve_slow. + if additional <= v.capacity().wrapping_sub(v.len()) { + std::hint::unreachable_unchecked(); + } + v.reserve(additional); + }); + } } reserve_slow(self, additional); } } - pub fn resize(&mut self, new_len: usize, value: T) - where - T: Clone, - { - self.mut_vec(|v| v.resize(new_len, value)); - } - - /// Accesses the [`FastVec`] mutably as a [`Vec`]. TODO(unsound) panic in `f` causes double free. - fn mut_vec(&mut self, f: impl FnOnce(&mut Vec)) { - unsafe { - let copied = std::ptr::read(self as *mut FastVec); - let mut vec = Vec::from(copied); - f(&mut vec); - let copied = FastVec::from(vec); - std::ptr::write(self as *mut FastVec, copied); - } + /// Accesses the [`FastVec`] mutably as a [`Vec`]. + /// # Safety + /// If `f` panics the [`Vec`] must be unmodified. + unsafe fn mut_vec(&mut self, f: impl FnOnce(&mut Vec)) { + let copied = std::ptr::read(self as *mut FastVec); + let mut vec = ManuallyDrop::new(Vec::from(copied)); + f(&mut vec); + let copied = FastVec::from(ManuallyDrop::into_inner(vec)); + std::ptr::write(self as *mut FastVec, copied); } /// Get a pointer to write to without incrementing length. From a291a8be0e67a2b4112012c4ede5375fc9ff42d9 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Sun, 10 Mar 2024 18:49:01 -0700 Subject: [PATCH 27/45] Add bincode benchmarks to cargo bench. 
--- src/benches.rs | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/benches.rs b/src/benches.rs index 7e9c1c5..c3f2a5c 100644 --- a/src/benches.rs +++ b/src/benches.rs @@ -1,5 +1,6 @@ use rand::prelude::*; use rand_chacha::ChaCha20Rng; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use test::black_box; @@ -102,11 +103,18 @@ fn random_data(n: usize) -> Vec { (0..n).map(|_| rng.gen()).collect() } +// Use bincode fixint for benchmarks because it's faster than varint. +fn bincode_serialize(v: &(impl Serialize + ?Sized)) -> Vec { + bincode::serialize(v).unwrap() +} +fn bincode_deserialize(v: &[u8]) -> T { + bincode::deserialize(v).unwrap() +} + #[cfg(feature = "derive")] fn bitcode_encode(v: &(impl crate::Encode + ?Sized)) -> Vec { crate::encode(v) } - #[cfg(feature = "derive")] fn bitcode_decode(v: &[u8]) -> T { crate::decode(v).unwrap() @@ -116,9 +124,8 @@ fn bitcode_decode(v: &[u8]) -> T { fn bitcode_serialize(v: &(impl Serialize + ?Sized)) -> Vec { crate::serialize(v).unwrap() } - #[cfg(feature = "serde")] -fn bitcode_deserialize(v: &[u8]) -> T { +fn bitcode_deserialize(v: &[u8]) -> T { crate::deserialize(v).unwrap() } @@ -153,6 +160,7 @@ macro_rules! 
bench { } } +bench!(serialize, deserialize, bincode); #[cfg(feature = "serde")] bench!(serialize, deserialize, bitcode); #[cfg(feature = "derive")] @@ -225,11 +233,7 @@ mod tests { println!("| Format | Compression | Size (bytes) | Serialize (ns) | Deserialize (ns) |"); println!("|------------------|--------------|--------------|----------------|------------------|"); - print_results( - "bincode", - |v| bincode::serialize(v).unwrap(), - |v| bincode::deserialize(v).unwrap(), - ); + print_results("bincode", bincode_serialize, bincode_deserialize); print_results( "bincode-varint", |v| bincode::DefaultOptions::new().serialize(v).unwrap(), From fb8508d16e189f31d6b8c967b29cbab17f5a3645 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Sun, 10 Mar 2024 19:12:18 -0700 Subject: [PATCH 28/45] Move Length{Decoder, Encoder} out of Box for Map/Seq encoders to improve locality. --- src/serde/de.rs | 8 ++++---- src/serde/ser.rs | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/serde/de.rs b/src/serde/de.rs index 7dca37d..37ff8a8 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -41,8 +41,8 @@ enum SerdeDecoder<'a> { Enum((VariantDecoder<'a>, Vec>)), // (variants, values) F32(F32Decoder<'a>), // We don't need signed integer decoders here because unsigned ones work the same. - Map(Box<(LengthDecoder<'a>, (SerdeDecoder<'a>, SerdeDecoder<'a>))>), // (lengths, (keys, values)) - Seq(Box<(LengthDecoder<'a>, SerdeDecoder<'a>)>), // (lengths, values) + Map((LengthDecoder<'a>, Box<(SerdeDecoder<'a>, SerdeDecoder<'a>)>)), // (lengths, (keys, values)) + Seq((LengthDecoder<'a>, Box>)), // (lengths, values) Str(StrDecoder<'a>), Tuple(Box<[SerdeDecoder<'a>]>), // [field0, field1, ..] 
U8(IntDecoder<'a, u8>), @@ -257,7 +257,7 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { where V: Visitor<'de>, { - let (length_decoder, decoder) = &mut **specify!(self, Seq); + let (length_decoder, decoder) = specify!(self, Seq); let len = length_decoder.decode(); struct Access<'a, 'de> { @@ -384,7 +384,7 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { where V: Visitor<'de>, { - let (length_decoder, decoders) = &mut **specify!(self, Map); + let (length_decoder, decoders) = specify!(self, Map); let len = length_decoder.decode(); struct Access<'a, 'de> { diff --git a/src/serde/ser.rs b/src/serde/ser.rs index 172de26..5183bda 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -64,8 +64,8 @@ enum SpecifiedEncoder { I32(IntEncoder), I64(IntEncoder), I128(IntEncoder), - Map(Box<(LengthEncoder, (LazyEncoder, LazyEncoder))>), // (lengths, (keys, values)) - Seq(Box<(LengthEncoder, LazyEncoder)>), // (lengths, values) + Map((LengthEncoder, Box<(LazyEncoder, LazyEncoder)>)), // (lengths, (keys, values)) + Seq((LengthEncoder, Box)), // (lengths, values) Str(StrEncoder), Tuple(Box<[LazyEncoder]>), // [field0, field1, ..] U8(IntEncoder), From 6ab27a68a19e3c4da8afb92723600b62890cdc4d Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Mon, 11 Mar 2024 20:32:57 -0700 Subject: [PATCH 29/45] Add non generic Buffer type which fixes DecodeBuffer lifetime issue. --- fuzz/fuzz_targets/fuzz.rs | 9 +- src/benches_borrowed.rs | 14 ++- src/buffer.rs | 247 ++++++++++++++++++++++++++++++++++++++ src/coder.rs | 1 + src/derive/mod.rs | 100 +++++---------- src/lib.rs | 24 +++- 6 files changed, 311 insertions(+), 84 deletions(-) create mode 100644 src/buffer.rs diff --git a/fuzz/fuzz_targets/fuzz.rs b/fuzz/fuzz_targets/fuzz.rs index f4285fd..b811cba 100644 --- a/fuzz/fuzz_targets/fuzz.rs +++ b/fuzz/fuzz_targets/fuzz.rs @@ -12,6 +12,7 @@ fuzz_target!(|data: &[u8]| { return; } let (start, data) = data.split_at(3); + let mut buffer = bitcode::Buffer::default(); macro_rules! 
test { ($typ1: expr, $typ2: expr, $data: expr, $($typ: ty),*) => { @@ -20,13 +21,11 @@ fuzz_target!(|data: &[u8]| { $( if j == $typ1 { if $typ2 == 0 { - let mut encode_buffer = bitcode::EncodeBuffer::<$typ>::default(); - let mut decode_buffer = bitcode::DecodeBuffer::<$typ>::default(); - let mut previous = None; for _ in 0..2 { - let current = if let Ok(de) = decode_buffer.decode(data) { - let data2 = encode_buffer.encode(&de); + let data = data.to_vec(); // Detect dangling pointers to data in buffer. + let current = if let Ok(de) = buffer.decode::<$typ>(&data) { + let data2 = buffer.encode::<$typ>(&de); let de2 = bitcode::decode::<$typ>(&data2).unwrap(); assert_eq!(de, de2); true diff --git a/src/benches_borrowed.rs b/src/benches_borrowed.rs index e0f47f9..d8bed42 100644 --- a/src/benches_borrowed.rs +++ b/src/benches_borrowed.rs @@ -83,7 +83,7 @@ fn bench_bincode_deserialize(b: &mut Bencher) { fn bench_bitcode_encode(b: &mut Bencher) { let data = bench_data(); let data = bench_data2(&data); - let mut buffer = crate::EncodeBuffer::default(); + let mut buffer = crate::Buffer::default(); b.iter(|| { black_box(buffer.encode(black_box(&data))); @@ -95,13 +95,17 @@ fn bench_bitcode_encode(b: &mut Bencher) { fn bench_bitcode_decode(b: &mut Bencher) { let data = bench_data(); let data = bench_data2(&data); - let mut encode_buffer = crate::EncodeBuffer::default(); + let mut encode_buffer = crate::Buffer::default(); let bytes = encode_buffer.encode(&data); - let mut decode_buffer = crate::DecodeBuffer::>::default(); - assert_eq!(decode_buffer.decode(bytes).unwrap(), data); + let mut decode_buffer = crate::Buffer::default(); + assert_eq!(decode_buffer.decode::>(bytes).unwrap(), data); b.iter(|| { - black_box(decode_buffer.decode(black_box(bytes)).unwrap()); + black_box( + decode_buffer + .decode::>(black_box(bytes)) + .unwrap(), + ); }) } diff --git a/src/buffer.rs b/src/buffer.rs new file mode 100644 index 0000000..10e3d19 --- /dev/null +++ b/src/buffer.rs @@ -0,0 +1,247 
@@ +use std::any::TypeId; + +/// A buffer for reusing allocations between calls to [`Buffer::encode`] and/or [`Buffer::decode`]. +/// TODO Send + Sync +/// +/// ```rust +/// use bitcode::{Buffer, Encode, Decode}; +/// +/// fn main() { +/// let original = "Hello world!"; +/// +/// let mut buffer = Buffer::new(); +/// buffer.encode(&original); +/// let encoded: &[u8] = buffer.encode(&original); // Won't allocate +/// +/// let mut buffer = Buffer::new(); +/// buffer.decode::<&str>(&encoded).unwrap(); +/// let decoded: &str = buffer.decode(&encoded).unwrap(); // Won't allocate +/// assert_eq!(original, decoded); +/// } +/// ``` +#[derive(Default)] +pub struct Buffer { + pub(crate) registry: Registry, + pub(crate) out: Vec, // Isn't stored in registry because all encoders can share this. +} + +impl Buffer { + /// Constructs a new buffer. + pub fn new() -> Self { + Self::default() + } +} + +// Set of arbitrary types. +#[derive(Default)] +pub(crate) struct Registry(Vec<(TypeId, ErasedBox)>); + +impl Registry { + /// Gets a `&mut T` if it already exists or initializes one with [`Default`]. + #[cfg(test)] + pub(crate) fn get(&mut self) -> &mut T { + // Safety: T is static. + unsafe { self.get_non_static::() } + } + + /// Like [`Registry::get`] but can get non-static types. + /// # Safety + /// Lifetimes are the responsibility of the caller. `&'static [u8]` and `&'a [u8]` are the same + /// type from the perspective of this function. + pub(crate) unsafe fn get_non_static(&mut self) -> &mut T { + // Use sorted Vec + binary search because we expect fewer insertions than lookups. + // We could use a HashMap, but that seems like overkill. + let type_id = non_static_type_id::(); + let i = match self.0.binary_search_by_key(&type_id, |(k, _)| *k) { + Ok(i) => i, + Err(i) => { + #[cold] + #[inline(never)] + unsafe fn cold(me: &mut Registry, i: usize) { + let type_id = non_static_type_id::(); + // Safety: caller of `Registry::get` upholds any lifetime requirements. 
+ let erased = ErasedBox::new(T::default()); + me.0.insert(i, (type_id, erased)); + } + cold::(self, i); + i + } + }; + // Safety: binary_search_by_key either found item at `i` or cold initialized item at `i`. + let item = &mut self.0.get_unchecked_mut(i).1; + // Safety: type_id uniquely identifies the type, so the entry with equal type_id is a T. + item.cast_unchecked_mut() + } +} + +/// Ignores lifetimes in `T` when determining its [`TypeId`]. +/// https://github.com/rust-lang/rust/issues/41875#issuecomment-317292888 +fn non_static_type_id() -> TypeId { + use std::marker::PhantomData; + trait NonStaticAny { + fn get_type_id(&self) -> TypeId + where + Self: 'static; + } + impl NonStaticAny for PhantomData { + fn get_type_id(&self) -> TypeId + where + Self: 'static, + { + TypeId::of::() + } + } + let phantom_data = PhantomData::; + NonStaticAny::get_type_id(unsafe { + std::mem::transmute::<&dyn NonStaticAny, &(dyn NonStaticAny + 'static)>(&phantom_data) + }) +} + +/// `Box` but of an unknown runtime `T`, requires unsafe to get the `T` back out. +struct ErasedBox { + ptr: *mut (), // Box + drop: *const (), // unsafe fn(*mut Box) +} + +impl ErasedBox { + /// Allocates a [`Box`] which doesn't know its own type. Only works on `T: Sized`. + /// # Safety + /// Ignores lifetimes so drop may be called after `T`'s lifetime has expired. + unsafe fn new(t: T) -> Self { + let ptr = Box::into_raw(Box::new(t)) as *mut (); + let drop = std::ptr::drop_in_place::> as *const (); + Self { ptr, drop } + } + + /// Casts to a `&mut T`. + /// # Safety + /// `T` must be the same `T` passed to [`ErasedBox::new`]. + unsafe fn cast_unchecked_mut(&mut self) -> &mut T { + &mut *(self.ptr as *mut T) + } +} + +impl Drop for ErasedBox { + fn drop(&mut self) { + // Safety: `ErasedBox::new` put a `Box` in self.ptr and an `unsafe fn(*mut Box)` in self.drop. + unsafe { + let drop: unsafe fn(*mut *mut ()) = std::mem::transmute(self.drop); + drop((&mut self.ptr) as *mut *mut ()); // Pass *mut Box. 
+ } + } +} + +#[cfg(test)] +mod tests { + use super::{non_static_type_id, Buffer, ErasedBox, Registry}; + use test::{black_box, Bencher}; + + #[test] + fn buffer() { + let mut b = Buffer::new(); + assert_eq!(b.encode(&false), &[0]); + assert_eq!(b.encode(&true), &[1]); + assert_eq!(b.decode::(&[0]).unwrap(), false); + assert_eq!(b.decode::(&[1]).unwrap(), true); + } + + #[test] + fn registry() { + let mut r = Registry::default(); + assert_eq!(*r.get::(), 0); + *r.get::() = 1; + assert_eq!(*r.get::(), 1); + + assert_eq!(*r.get::(), 0); + *r.get::() = 5; + assert_eq!(*r.get::(), 5); + + assert_eq!(*r.get::(), 1); + } + + #[test] + fn type_id() { + assert_ne!(non_static_type_id::(), non_static_type_id::()); + assert_ne!(non_static_type_id::<()>(), non_static_type_id::<[(); 1]>()); + assert_ne!( + non_static_type_id::<&'static mut [u8]>(), + non_static_type_id::<&'static [u8]>() + ); + assert_ne!( + non_static_type_id::<*mut u8>(), + non_static_type_id::<*const u8>() + ); + fn f<'a>(_: &'a ()) { + assert_eq!( + non_static_type_id::<&'static [u8]>(), + non_static_type_id::<&'a [u8]>() + ); + assert_eq!( + non_static_type_id::<&'static ()>(), + non_static_type_id::<&'a ()>() + ); + } + f(&()); + } + + #[test] + fn erased_box() { + use std::rc::Rc; + let rc = Rc::new(()); + struct TestDrop(Rc<()>); + let b = unsafe { ErasedBox::new(TestDrop(Rc::clone(&rc))) }; + assert_eq!(Rc::strong_count(&rc), 2); + drop(b); + assert_eq!(Rc::strong_count(&rc), 1); + } + + macro_rules! 
register10 { + ($registry:ident $(, $t:literal)*) => { + $( + $registry.get::<[u8; $t]>(); + $registry.get::<[i8; $t]>(); + $registry.get::<[u16; $t]>(); + $registry.get::<[i16; $t]>(); + $registry.get::<[u32; $t]>(); + $registry.get::<[i32; $t]>(); + $registry.get::<[u64; $t]>(); + $registry.get::<[i64; $t]>(); + $registry.get::<[u128; $t]>(); + $registry.get::<[i128; $t]>(); + )* + } + } + type T = [u8; 1]; + + #[bench] + fn bench_registry1_get(b: &mut Bencher) { + let mut r = Registry::default(); + r.get::(); + assert_eq!(r.0.len(), 1); + b.iter(|| { + black_box(*black_box(&mut r).get::()); + }) + } + + #[bench] + fn bench_registry10_get(b: &mut Bencher) { + let mut r = Registry::default(); + r.get::(); + register10!(r, 1); + assert_eq!(r.0.len(), 10); + b.iter(|| { + black_box(*black_box(&mut r).get::()); + }) + } + + #[bench] + fn bench_registry100_get(b: &mut Bencher) { + let mut r = Registry::default(); + r.get::(); + register10!(r, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + assert_eq!(r.0.len(), 100); + b.iter(|| { + black_box(*black_box(&mut r).get::()); + }) + } +} diff --git a/src/coder.rs b/src/coder.rs index 79f7f18..0a5045e 100644 --- a/src/coder.rs +++ b/src/coder.rs @@ -4,6 +4,7 @@ use std::num::NonZeroUsize; pub type Result = std::result::Result; +/// TODO pick different name because it aliases with [`crate::buffer::Buffer`]. pub trait Buffer { /// Convenience function for `collect_into`. fn collect(&mut self) -> Vec { diff --git a/src/derive/mod.rs b/src/derive/mod.rs index 4f1439a..75fc6e8 100644 --- a/src/derive/mod.rs +++ b/src/derive/mod.rs @@ -47,18 +47,24 @@ pub trait Decode<'a>: Sized { pub trait DecodeOwned: for<'de> Decode<'de> {} impl DecodeOwned for T where T: for<'de> Decode<'de> {} +// Stop #[inline(always)] of Encoder::encode/Decoder::decode since 90% of the time is spent in these +// functions, and we don't want extra code interfering with optimizations. 
+#[inline(never)] +fn encode_inline_never(encoder: &mut T::Encoder, t: &T) { + encoder.encode(t); +} +#[inline(never)] +fn decode_inline_never<'a, T: Decode<'a>>(decoder: &mut T::Decoder) -> T { + decoder.decode() +} + /// Encodes a `T:` [`Encode`] into a [`Vec`]. /// /// **Warning:** The format is subject to change between major versions. pub fn encode(t: &T) -> Vec { let mut encoder = T::Encoder::default(); encoder.reserve(NonZeroUsize::new(1).unwrap()); - - #[inline(never)] - fn encode_inner(encoder: &mut T::Encoder, t: &T) { - encoder.encode(t); - } - encode_inner(&mut encoder, t); + encode_inline_never(&mut encoder, t); encoder.collect() } @@ -69,76 +75,32 @@ pub fn decode<'a, T: Decode<'a>>(mut bytes: &'a [u8]) -> Result { let mut decoder = T::Decoder::default(); decoder.populate(&mut bytes, 1)?; expect_eof(bytes)?; - #[inline(never)] - fn decode_inner<'a, T: Decode<'a>>(decoder: &mut T::Decoder) -> T { - decoder.decode() - } - Ok(decode_inner(&mut decoder)) -} - -/// A buffer for reusing allocations between multiple calls to [`EncodeBuffer::encode`]. -pub struct EncodeBuffer { - encoder: T::Encoder, - out: Vec, -} - -// #[derive(Default)] bounds T: Default. -impl Default for EncodeBuffer { - fn default() -> Self { - Self { - encoder: Default::default(), - out: Default::default(), - } - } + Ok(decode_inline_never(&mut decoder)) } -impl EncodeBuffer { - /// Encodes a `T:` [`Encode`] into a [`&[u8]`][`prim@slice`]. - /// - /// Can reuse allocations when called multiple times on the same [`EncodeBuffer`]. - /// - /// **Warning:** The format is subject to change between major versions. - pub fn encode<'a>(&'a mut self, t: &T) -> &'a [u8] { - // TODO dedup with encode. - self.encoder.reserve(NonZeroUsize::new(1).unwrap()); - #[inline(never)] - fn encode_inner(encoder: &mut T::Encoder, t: &T) { - encoder.encode(t); - } - encode_inner(&mut self.encoder, t); +impl crate::buffer::Buffer { + /// Like [`encode`], but saves allocations between calls. 
+ pub fn encode<'a, T: Encode + ?Sized>(&'a mut self, t: &T) -> &'a [u8] { + // Safety: Encoders don't have any lifetimes (they don't contain T either). + let encoder = unsafe { self.registry.get_non_static::() }; + encoder.reserve(NonZeroUsize::new(1).unwrap()); + encode_inline_never(encoder, t); self.out.clear(); - self.encoder.collect_into(&mut self.out); + encoder.collect_into(&mut self.out); self.out.as_slice() } -} - -/// A buffer for reusing allocations between multiple calls to [`DecodeBuffer::decode`]. -/// -/// TODO don't bound [`DecodeBuffer`] to decode's `&'a [u8]`. -pub struct DecodeBuffer<'a, T: Decode<'a>>(>::Decoder); - -impl<'a, T: Decode<'a>> Default for DecodeBuffer<'a, T> { - fn default() -> Self { - Self(Default::default()) - } -} -impl<'a, T: Decode<'a>> DecodeBuffer<'a, T> { - /// Decodes a [`&[u8]`][`prim@slice`] into an instance of `T:` [`Decode`]. - /// - /// Can reuse allocations when called multiple times on the same [`DecodeBuffer`]. - /// - /// **Warning:** The format is subject to change between major versions. - pub fn decode(&mut self, mut bytes: &'a [u8]) -> Result { - // TODO dedup with decode. - self.0.populate(&mut bytes, 1)?; + /// Like [`decode`], but saves allocations between calls. + pub fn decode<'a, T: Decode<'a>>(&mut self, mut bytes: &'a [u8]) -> Result { + // Safety: Decoders have dangling pointers to `bytes` from previous calls which haven't been + // cleared. This isn't an issue in practice because they remain as pointers in FastSlice and + // aren't dereferenced. If we wanted to be safer we could clear all the decoders but this + // would result in lots of extra code to maintain and a performance/binary size hit. + // To detect misuse we run miri tests/cargo fuzz where bytes goes out of scope between calls. 
+ let decoder = unsafe { self.registry.get_non_static::() }; + decoder.populate(&mut bytes, 1)?; expect_eof(bytes)?; - #[inline(never)] - fn decode_inner<'a, T: Decode<'a>>(decoder: &mut T::Decoder) -> T { - decoder.decode() - } - let ret = decode_inner(&mut self.0); - Ok(ret) + Ok(decode_inline_never(decoder)) } } diff --git a/src/lib.rs b/src/lib.rs index b1de80d..a8bb1af 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ extern crate self as bitcode; extern crate test; mod bool; +mod buffer; mod coder; mod consume; mod derive; @@ -26,6 +27,7 @@ mod pack_ints; mod str; mod u8_char; +pub use crate::buffer::Buffer; pub use crate::derive::*; pub use crate::error::Error; @@ -68,7 +70,7 @@ macro_rules! bench_encode_decode { #[bench] fn [](b: &mut test::Bencher) { let data: $t = bench_data(); - let mut buffer = crate::EncodeBuffer::<_>::default(); + let mut buffer = crate::Buffer::default(); b.iter(|| { test::black_box(buffer.encode(test::black_box(&data))); }) @@ -78,12 +80,24 @@ macro_rules! bench_encode_decode { fn [](b: &mut test::Bencher) { let data: $t = bench_data(); let encoded = crate::encode(&data); - let mut buffer = crate::DecodeBuffer::<_>::default(); - b.iter(|| { + let mut buffer = crate::Buffer::default(); + + let mut f = || { + #[cfg(miri)] // Make sure dangling pointers aren't read due to Buffer. + let encoded = encoded.clone(); + let decoded: $t = buffer.decode(test::black_box(&encoded)).unwrap(); debug_assert_eq!(data, decoded); - decoded - }) + test::black_box(decoded); + }; + + // Make sure f gets called at least twice (b.iter() calls once with miri). + if cfg!(miri) { + f(); + f(); + } else { + b.iter(f); + } } )+ } From 58ae910b9ad10f9472560de77e36bd710a433f0d Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Tue, 12 Mar 2024 15:34:26 -0700 Subject: [PATCH 30/45] Add some inlines. 
--- src/pack_ints.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pack_ints.rs b/src/pack_ints.rs index 422b256..a43fe4f 100644 --- a/src/pack_ints.rs +++ b/src/pack_ints.rs @@ -62,9 +62,11 @@ pub trait Int: Copy + std::fmt::Debug + Default + Ord + Pod + Sized { // Unaligned native endian. TODO could be aligned on big endian since we always have to copy. type Une: Pod + Default; type Int: SizedInt; + #[inline] fn from_unaligned(unaligned: Self::Une) -> Self { bytemuck::must_cast(unaligned) } + #[inline] fn to_unaligned(self) -> Self::Une { bytemuck::must_cast(self) } @@ -174,9 +176,11 @@ macro_rules! impl_simple { fn write(v: Self, out: &mut Vec) { out.extend_from_slice(&v.to_le_bytes()); } + #[inline] fn wrapping_add(self, rhs: Self::Une) -> Self::Une { self.wrapping_add(Self::from_ne_bytes(rhs)).to_ne_bytes() } + #[inline] fn wrapping_sub(self, rhs: Self) -> Self { self.wrapping_sub(rhs) } From d53b292a694eedbae6c67c1b8c78fe1d55221e2f Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Tue, 12 Mar 2024 15:44:22 -0700 Subject: [PATCH 31/45] Fast path for tuple/struct with 1 field. --- src/serde/de.rs | 10 ++++++++++ src/serde/ser.rs | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/src/serde/de.rs b/src/serde/de.rs index 37ff8a8..3cb597f 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -306,6 +306,15 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { where V: Visitor<'de>, { + // Fast path: avoid overhead of tuple for 1 element. + if tuple_len == 1 { + return v.visit_seq(Access { + decoders: std::slice::from_mut(self.decoder), + input: self.input, + index: 0, + }); + } + // Copy of specify! macro that takes an additional tuple_len parameter to cold. 
match &mut self.decoder { SerdeDecoder::Tuple(_) => (), @@ -365,6 +374,7 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { Some(self.decoders.len()) } } + v.visit_seq(Access { decoders, input: &mut *self.input, diff --git a/src/serde/ser.rs b/src/serde/ser.rs index 5183bda..9af687f 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -365,6 +365,14 @@ impl<'a> Serializer for EncoderWrapper<'a> { #[inline(always)] fn serialize_tuple(self, len: usize) -> Result { + // Fast path: avoid overhead of tuple for 1 element. + if len == 1 { + return Ok(TupleSerializer { + encoders: std::slice::from_mut(self.lazy), + index_alloc: self.index_alloc, + }); + } + // Copy of specify! macro that takes an additional len parameter to cold. let lazy = &mut *self.lazy; match lazy { From 26960e857029958946ab389b65b9365e453acabe Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Tue, 12 Mar 2024 22:30:03 -0700 Subject: [PATCH 32/45] Replace 4 unsafe transmutes with a safe abstraction. --- src/fast.rs | 14 ++++++++++++++ src/int.rs | 9 +++++---- src/pack.rs | 4 +--- src/pack_ints.rs | 17 ++++++++--------- 4 files changed, 28 insertions(+), 16 deletions(-) diff --git a/src/fast.rs b/src/fast.rs index 9080bc5..a9ef379 100644 --- a/src/fast.rs +++ b/src/fast.rs @@ -422,6 +422,20 @@ impl<'borrowed, T> CowSlice<'borrowed, T> { self.slice = slice.into(); ret } + + /// Casts `&mut CowSlice` to `&mut CowSlice`. + #[inline] + pub fn cast_mut(&mut self) -> &mut CowSlice<'borrowed, B> + where + T: bytemuck::Pod, + B: bytemuck::Pod, + { + use std::mem::*; + assert_eq!(size_of::(), size_of::()); + assert_eq!(align_of::(), align_of::()); + // Safety: size/align are equal and both are bytemuck::Pod. + unsafe { transmute(self) } + } } pub struct SetOwned<'a, 'borrowed, T>(&'a mut CowSlice<'borrowed, T>); diff --git a/src/int.rs b/src/int.rs index 2df9373..e6ce2a0 100644 --- a/src/int.rs +++ b/src/int.rs @@ -13,10 +13,11 @@ pub struct IntEncoder(VecImpl); impl Encoder

for IntEncoder { #[inline(always)] fn as_primitive(&mut self) -> Option<&mut VecImpl

> { - assert_eq!(std::mem::size_of::(), std::mem::size_of::

()); - // Safety: T and P are the same size, T is Pod, and we aren't reading P. - let vec: &mut VecImpl

= unsafe { std::mem::transmute(&mut self.0) }; - Some(vec) + use std::mem::*; + assert_eq!(align_of::(), align_of::

()); + assert_eq!(size_of::(), size_of::

()); + // Safety: size/align are equal, T: Int implies Pod, and caller isn't reading P which may be NonZero. + unsafe { Some(transmute(&mut self.0)) } } #[inline(always)] diff --git a/src/pack.rs b/src/pack.rs index 9e0848c..f191f47 100644 --- a/src/pack.rs +++ b/src/pack.rs @@ -146,9 +146,7 @@ pub fn unpack_bytes<'a, T: Byte>( length: usize, out: &mut CowSlice<'a, T>, ) -> Result<()> { - // Safety: T is u8 or i8 which have same size/align and are Copy. - let out: &mut CowSlice<'a, u8> = unsafe { std::mem::transmute(out) }; - unpack_bytes_unsigned(input, length, out) + unpack_bytes_unsigned(input, length, out.cast_mut()) } /// [`unpack_bytes`] but after i8s have been cast to u8s. diff --git a/src/pack_ints.rs b/src/pack_ints.rs index a43fe4f..0685e50 100644 --- a/src/pack_ints.rs +++ b/src/pack_ints.rs @@ -94,8 +94,7 @@ macro_rules! impl_usize_and_isize { } fn with_output<'a>(out: &mut CowSlice<'a, Self::Une>, length: usize, f: impl FnOnce(&mut CowSlice<'a, ::Une>) -> Result<()>) -> Result<()> { if cfg!(target_pointer_width = "64") { - // Safety: isize::Une == i64::Une on 64 bit. - f(unsafe { std::mem::transmute(out) }) + f(out.cast_mut()) } else { // i64 to 32 bit isize on requires checked conversion. TODO reuse allocations. let mut out_i64 = CowSlice::default(); @@ -311,10 +310,12 @@ impl SizedUInt for u8 { fn pack8(v: &mut [Self], out: &mut Vec) { pack_bytes(v, out); } - fn unpack8(input: &mut &[u8], length: usize, out: &mut CowSlice<[u8; 1]>) -> Result<()> { - // Safety: [u8; 1] and u8 are the same from the perspective of CowSlice. 
- let out: &mut CowSlice = unsafe { std::mem::transmute(out) }; - unpack_bytes(input, length, out) + fn unpack8<'a>( + input: &mut &'a [u8], + length: usize, + out: &mut CowSlice<'a, [u8; 1]>, + ) -> Result<()> { + unpack_bytes(input, length, out.cast_mut::()) } } @@ -438,9 +439,7 @@ fn unpack_ints_sized<'a, T: SizedInt>( length: usize, out: &mut CowSlice<'a, T::Une>, ) -> Result<()> { - // Safety: T::Une and T::Unsigned::Une are the same type. - let out: &mut CowSlice<'a, _> = unsafe { std::mem::transmute(out) }; - unpack_ints_sized_unsigned::(input, length, out) + unpack_ints_sized_unsigned::(input, length, out.cast_mut()) } /// [`unpack_ints_sized`] but after signed integers have been cast to unsigned. From 9aaaf39e4db57e57b03a84eedeba878b088b8a82 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Wed, 13 Mar 2024 13:33:28 -0700 Subject: [PATCH 33/45] Fix miri big-endian not compiling. --- .github/workflows/build.yml | 5 ++--- Cargo.toml | 3 +++ src/benches.rs | 4 ++++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d946109..1f90980 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -45,6 +45,5 @@ jobs: run: cargo miri test --all-features - name: Setup Miri (big-endian) run: rustup target add mips64-unknown-linux-gnuabi64 && cargo miri setup --target mips64-unknown-linux-gnuabi64 -# TODO miri big-endian (zstd doesn't compile) -# - name: Test (miri big-endian) -# run: cargo miri test --target mips64-unknown-linux-gnuabi64 + - name: Test (miri big-endian) + run: cargo miri test --target mips64-unknown-linux-gnuabi64 diff --git a/Cargo.toml b/Cargo.toml index 0efbfde..eb3abe3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,9 @@ paste = "1.0.14" rand = "0.8.5" rand_chacha = "0.3.1" serde = { version = "1.0", features = [ "derive" ] } + +# zstd doesn't compile with miri big-endian. 
+[target.'cfg(not(miri))'.dev-dependencies] zstd = "0.13.0" [features] diff --git a/src/benches.rs b/src/benches.rs index c3f2a5c..e546954 100644 --- a/src/benches.rs +++ b/src/benches.rs @@ -258,7 +258,9 @@ mod compression { ("lz4", lz4_encode, lz4_decode), ("deflate-fast", deflate_fast_encode, deflate_decode), ("deflate-best", deflate_best_encode, deflate_decode), + #[cfg(not(miri))] // zstd doesn't compile with miri big-endian. ("zstd-0", zstd_encode::<0>, zstd_decode), + #[cfg(not(miri))] ("zstd-22", zstd_encode::<22>, zstd_decode), ]; @@ -288,10 +290,12 @@ mod compression { bytes } + #[cfg(not(miri))] fn zstd_encode(v: &[u8]) -> Vec { zstd::stream::encode_all(v, LEVEL).unwrap() } + #[cfg(not(miri))] fn zstd_decode(v: &[u8]) -> Vec { zstd::stream::decode_all(v).unwrap() } From 99b7b678af0ac909a50399c639e2f97770f29383 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Wed, 13 Mar 2024 18:09:05 -0700 Subject: [PATCH 34/45] Handle errors instead of panicking in derive macro. --- bitcode_derive/src/attribute.rs | 5 +- bitcode_derive/src/bound.rs | 7 +- bitcode_derive/src/decode.rs | 275 ++++++++++++-------------------- bitcode_derive/src/encode.rs | 254 ++++++++++------------------- bitcode_derive/src/lib.rs | 37 ++--- bitcode_derive/src/shared.rs | 157 +++++++++++++++++- src/derive/mod.rs | 1 + 7 files changed, 362 insertions(+), 374 deletions(-) diff --git a/bitcode_derive/src/attribute.rs b/bitcode_derive/src/attribute.rs index a9e944d..31980f1 100644 --- a/bitcode_derive/src/attribute.rs +++ b/bitcode_derive/src/attribute.rs @@ -42,12 +42,12 @@ impl BitcodeAttr { return err(nested, "duplicate"); } *b = Some(bound_type); + Ok(()) } else { - return err(nested, "can only apply bound to fields"); + err(nested, "can only apply bound to fields") } } } - Ok(()) } } @@ -81,7 +81,6 @@ impl BitcodeAttrs { Ok(ret) } - #[allow(unused)] // TODO pub fn parse_variant(attrs: &[Attribute], _derive_attrs: &Self) -> Result { let mut ret = Self::new(AttrType::Variant); 
ret.parse_inner(attrs)?; diff --git a/bitcode_derive/src/bound.rs b/bitcode_derive/src/bound.rs index 232a742..99f6d83 100644 --- a/bitcode_derive/src/bound.rs +++ b/bitcode_derive/src/bound.rs @@ -23,15 +23,16 @@ impl FieldBounds { } } - pub fn apply_to_generics(self, generics: &mut syn::Generics) { + pub fn added_to(self, mut generics: syn::Generics) -> syn::Generics { for (bound, (fields, extra_bound_types)) in self.bounds { - *generics = with_bound(&fields, extra_bound_types, generics, &bound); + generics = with_bound(&fields, extra_bound_types, &generics, &bound); } + generics } } // Based on https://github.com/serde-rs/serde/blob/0c6a2bbf794abe966a4763f5b7ff23acb535eb7f/serde_derive/src/bound.rs#L94-L314 -pub fn with_bound( +fn with_bound( fields: &[syn::Field], extra_bound_types: Vec, generics: &syn::Generics, diff --git a/bitcode_derive/src/decode.rs b/bitcode_derive/src/decode.rs index 622e2fd..ab4b765 100644 --- a/bitcode_derive/src/decode.rs +++ b/bitcode_derive/src/decode.rs @@ -1,14 +1,10 @@ -use crate::attribute::BitcodeAttrs; -use crate::bound::FieldBounds; -use crate::shared::{ - destructure_fields, field_name, remove_lifetimes, replace_lifetimes, ReplaceSelves, -}; -use crate::{err, private}; +use crate::private; +use crate::shared::{remove_lifetimes, replace_lifetimes, variant_index}; use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; use syn::{ - parse_quote, Data, DeriveInput, Fields, GenericParam, Lifetime, LifetimeParam, Path, - PredicateLifetime, Result, Type, WherePredicate, + parse_quote, GenericParam, Generics, Lifetime, LifetimeParam, Path, PredicateLifetime, Type, + WherePredicate, }; const DE_LIFETIME: &str = "__de"; @@ -17,8 +13,7 @@ fn de_lifetime() -> Lifetime { } #[derive(Copy, Clone)] -#[repr(u8)] -enum Item { +pub enum Item { Type, Default, Populate, @@ -35,7 +30,9 @@ impl Item { Self::DecodeInPlace, ]; const COUNT: usize = Self::ALL.len(); +} +impl crate::shared::Item for Item { fn field_impl( self, field_name: 
TokenStream, @@ -89,11 +86,11 @@ impl Item { } } - pub fn variant_impls( + fn enum_impl( self, variant_count: usize, - mut pattern: impl FnMut(usize) -> TokenStream, - mut inner: impl FnMut(Self, usize) -> TokenStream, + pattern: impl Fn(usize) -> TokenStream, + inner: impl Fn(Self, usize) -> TokenStream, ) -> TokenStream { // if variant_count is 0 or 1 variants don't have to be decoded. let decode_variants = variant_count > 1; @@ -146,9 +143,7 @@ impl Item { if inner.is_empty() { quote! {} } else { - let i: u8 = i - .try_into() - .expect("enums with more than 256 variants are not supported"); // TODO don't panic. + let i = variant_index(i); let length = decode_variants .then(|| { quote! { @@ -176,7 +171,7 @@ impl Item { unsafe { std::hint::unreachable_unchecked() } }; } - let mut pattern = |i: usize| { + let pattern = |i: usize| { let pattern = pattern(i); matches!(self, Self::DecodeInPlace) .then(|| { @@ -194,7 +189,7 @@ impl Item { .map(|i| { let inner = inner(item, i); let pattern = pattern(i); - let i: u8 = i.try_into().unwrap(); // Already checked in reserve impl. + let i = variant_index(i); quote! { #i => { #inner @@ -225,184 +220,110 @@ impl Item { } } } +} - // TODO dedup with encode.rs - fn field_impls( - self, - global_prefix: Option<&str>, - fields: &Fields, - parent_attrs: &BitcodeAttrs, - bounds: &mut FieldBounds, - ) -> Result { - fields - .iter() - .enumerate() - .map(move |(i, field)| { - let field_attrs = BitcodeAttrs::parse_field(&field.attrs, parent_attrs)?; - - let name = field_name(i, field, false); - let real_name = field_name(i, field, true); +pub struct Decode; +impl crate::shared::Derive<{ Item::COUNT }> for Decode { + type Item = Item; + const ALL: [Self::Item; Item::COUNT] = Item::ALL; - let global_name = global_prefix - .map(|global_prefix| { - let ident = - Ident::new(&format!("{global_prefix}{name}"), Span::call_site()); - quote! 
{ #ident } - }) - .unwrap_or_else(|| name.clone()); + fn bound(&self) -> Path { + let private = private(); + let de = de_lifetime(); + parse_quote!(#private::Decode<#de>) + } - let field_impl = self.field_impl(name, global_name, real_name, &field.ty); + fn derive_impl( + &self, + output: [TokenStream; Item::COUNT], + ident: Ident, + mut generics: Generics, + ) -> TokenStream { + let input_generics = generics.clone(); + let (_, input_generics, _) = input_generics.split_for_impl(); + let input_ty = quote! { #ident #input_generics }; - let private = private(); - let de = de_lifetime(); - let bound: Path = parse_quote!(#private::Decode<#de>); - bounds.add_bound_type(field.clone(), &field_attrs, bound); - Ok(field_impl) - }) - .collect() - } -} + // Add 'de lifetime after isolating input_generics. + let de = de_lifetime(); + let de_where_predicate = WherePredicate::Lifetime(PredicateLifetime { + lifetime: de.clone(), + colon_token: parse_quote!(:), + bounds: generics + .params + .iter() + .filter_map(|p| { + if let GenericParam::Lifetime(p) = p { + Some(p.lifetime.clone()) + } else { + None + } + }) + .collect(), + }); -struct Output([TokenStream; Item::COUNT]); + // Push de_param after bounding 'de: 'a. + let de_param = GenericParam::Lifetime(LifetimeParam::new(de.clone())); + generics.params.push(de_param.clone()); + generics + .make_where_clause() + .predicates + .push(de_where_predicate); -impl Output { - fn haunt(mut self) -> Self { - let type_ = &mut self.0[Item::Type as usize]; - if type_.is_empty() { - let de = de_lifetime(); - *type_ = quote! { __spooky: std::marker::PhantomData<&#de ()>, }; - } - let default = &mut self.0[Item::Default as usize]; - if default.is_empty() { - *default = quote! 
{ __spooky: Default::default(), }; - } - self - } -} + let combined_generics = generics.clone(); + let (impl_generics, _, where_clause) = combined_generics.split_for_impl(); -pub fn derive_impl(mut input: DeriveInput) -> Result { - let attrs = BitcodeAttrs::parse_derive(&input.attrs)?; - let mut generics = input.generics; - let mut bounds = FieldBounds::default(); + // Decoder can't contain any lifetimes from input (which would limit reuse of decoder). + remove_lifetimes(&mut generics); + generics.params.push(de_param); // Re-add de_param since remove_lifetimes removed it. + let (decoder_impl_generics, decoder_generics, decoder_where_clause) = + generics.split_for_impl(); - let ident = input.ident; - syn::visit_mut::visit_data_mut(&mut ReplaceSelves(&ident), &mut input.data); - let output = (match input.data { - Data::Struct(data_struct) => { - let destructure_fields = &destructure_fields(&data_struct.fields); - Output(Item::ALL.map(|item| { - let field_impls = item - .field_impls(None, &data_struct.fields, &attrs, &mut bounds) - .unwrap(); // TODO don't unwrap - item.struct_impl(&ident, destructure_fields, &field_impls) - })) + let [mut type_body, mut default_body, populate_body, decode_in_place_body] = output; + if type_body.is_empty() { + type_body = quote! { __spooky: std::marker::PhantomData<&#de ()>, }; } - Data::Enum(data_enum) => { - let variant_count = data_enum.variants.len(); - Output(Item::ALL.map(|item| { - item.variant_impls( - variant_count, - |i| { - let variant = &data_enum.variants[i]; - let variant_name = &variant.ident; - let destructure_fields = destructure_fields(&variant.fields); - quote! { - #ident::#variant_name #destructure_fields - } - }, - |item, i| { - let variant = &data_enum.variants[i]; - let global_prefix = format!("{}_", &variant.ident); - let attrs = BitcodeAttrs::parse_variant(&variant.attrs, &attrs).unwrap(); // TODO don't unwrap. 
- item.field_impls(Some(&global_prefix), &variant.fields, &attrs, &mut bounds) - .unwrap() // TODO don't unwrap. - }, - ) - })) + if default_body.is_empty() { + default_body = quote! { __spooky: Default::default(), }; } - Data::Union(u) => err(&u.union_token, "unions are not supported")?, - }) - .haunt(); - bounds.apply_to_generics(&mut generics); - let input_generics = generics.clone(); - let (_, input_generics, _) = input_generics.split_for_impl(); - let input_ty = quote! { #ident #input_generics }; + let decoder_ident = Ident::new(&format!("{ident}Decoder"), Span::call_site()); + let decoder_ty = quote! { #decoder_ident #decoder_generics }; + let private = private(); - // Add 'de lifetime after isolating input_generics. - let de = de_lifetime(); - let de_where_predicate = WherePredicate::Lifetime(PredicateLifetime { - lifetime: de.clone(), - colon_token: parse_quote!(:), - bounds: generics - .params - .iter() - .filter_map(|p| { - if let GenericParam::Lifetime(p) = p { - Some(p.lifetime.clone()) - } else { - None + quote! { + const _: () = { + impl #impl_generics #private::Decode<#de> for #input_ty #where_clause { + type Decoder = #decoder_ty; } - }) - .collect(), - }); - - // Push de_param after bounding 'de: 'a. - let de_param = GenericParam::Lifetime(LifetimeParam::new(de.clone())); - generics.params.push(de_param.clone()); - generics - .make_where_clause() - .predicates - .push(de_where_predicate); - - let combined_generics = generics.clone(); - let (impl_generics, _, where_clause) = combined_generics.split_for_impl(); - - // Decoder can't contain any lifetimes from input (which would limit reuse of decoder). - remove_lifetimes(&mut generics); - generics.params.push(de_param); // Re-add de_param since remove_lifetimes removed it. 
- let (decoder_impl_generics, decoder_generics, decoder_where_clause) = generics.split_for_impl(); - - let Output([type_body, default_body, populate_body, decode_in_place_body]) = output; - let decoder_ident = Ident::new(&format!("{ident}Decoder"), Span::call_site()); - let decoder_ty = quote! { #decoder_ident #decoder_generics }; - let private = private(); - let ret = quote! { - const _: () = { - impl #impl_generics #private::Decode<#de> for #input_ty #where_clause { - type Decoder = #decoder_ty; - } - - #[allow(non_snake_case)] - pub struct #decoder_ident #decoder_impl_generics #decoder_where_clause { - #type_body - } + #[allow(non_snake_case)] + pub struct #decoder_ident #decoder_impl_generics #decoder_where_clause { + #type_body + } - // Avoids bounding #impl_generics: Default. - impl #decoder_impl_generics std::default::Default for #decoder_ty #decoder_where_clause { - fn default() -> Self { - Self { - #default_body + // Avoids bounding #impl_generics: Default. + impl #decoder_impl_generics std::default::Default for #decoder_ty #decoder_where_clause { + fn default() -> Self { + Self { + #default_body + } } } - } - impl #decoder_impl_generics #private::View<#de> for #decoder_ty #decoder_where_clause { - fn populate(&mut self, input: &mut &#de [u8], __length: usize) -> #private::Result<()> { - #populate_body - Ok(()) + impl #decoder_impl_generics #private::View<#de> for #decoder_ty #decoder_where_clause { + fn populate(&mut self, input: &mut &#de [u8], __length: usize) -> #private::Result<()> { + #populate_body + Ok(()) + } } - } - impl #impl_generics #private::Decoder<#de, #input_ty> for #decoder_ty #where_clause { - #[cfg_attr(not(debug_assertions), inline(always))] - fn decode_in_place(&mut self, out: &mut std::mem::MaybeUninit<#input_ty>) { - #decode_in_place_body + impl #impl_generics #private::Decoder<#de, #input_ty> for #decoder_ty #where_clause { + #[cfg_attr(not(debug_assertions), inline(always))] + fn decode_in_place(&mut self, out: &mut 
std::mem::MaybeUninit<#input_ty>) { + #decode_in_place_body + } } - } - }; - }; - // panic!("{ret}"); - Ok(ret) + }; + } + } } diff --git a/bitcode_derive/src/encode.rs b/bitcode_derive/src/encode.rs index 9338bdc..73b312f 100644 --- a/bitcode_derive/src/encode.rs +++ b/bitcode_derive/src/encode.rs @@ -1,15 +1,11 @@ -use crate::attribute::BitcodeAttrs; -use crate::bound::FieldBounds; -use crate::shared::{ - destructure_fields, field_name, remove_lifetimes, replace_lifetimes, ReplaceSelves, -}; -use crate::{err, private}; +use crate::private; +use crate::shared::{remove_lifetimes, replace_lifetimes, variant_index}; use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; -use syn::{parse_quote, Data, DeriveInput, Fields, Path, Result, Type}; +use syn::{parse_quote, Generics, Path, Type}; #[derive(Copy, Clone)] -enum Item { +pub enum Item { Type, Default, Encode, @@ -17,7 +13,6 @@ enum Item { CollectInto, Reserve, } - impl Item { const ALL: [Self; 6] = [ Self::Type, @@ -28,7 +23,8 @@ impl Item { Self::Reserve, ]; const COUNT: usize = Self::ALL.len(); - +} +impl crate::shared::Item for Item { fn field_impl( self, field_name: TokenStream, @@ -101,11 +97,11 @@ impl Item { } } - pub fn variant_impls( + fn enum_impl( self, variant_count: usize, - mut pattern: impl FnMut(usize) -> TokenStream, - mut inner: impl FnMut(Self, usize) -> TokenStream, + pattern: impl Fn(usize) -> TokenStream, + inner: impl Fn(Self, usize) -> TokenStream, ) -> TokenStream { // if variant_count is 0 or 1 variants don't have to be encoded. let encode_variants = variant_count > 1; @@ -139,9 +135,7 @@ impl Item { let variants: TokenStream = (0..variant_count) .map(|i| { let pattern = pattern(i); - let i: u8 = i - .try_into() - .expect("enums with more than 256 variants are not supported"); // TODO don't panic. + let i = variant_index(i); quote! 
{ #pattern => #i, } @@ -157,8 +151,8 @@ impl Item { .unwrap_or_default(); let inners: TokenStream = (0..variant_count) .map(|i| { - // We don't know the exact number of this variant since there are more than one so we have to - // reserve one at a time. + // We don't know the exact number of this variant since there is more than + // one, so we have to reserve one at a time. let reserve = encode_variants .then(|| { let reserve = inner(Self::Reserve, i); @@ -189,7 +183,13 @@ impl Item { }) .unwrap_or_default() } - Self::EncodeVectored => unimplemented!(), // TODO encode enum vectored. + // This is a copy of Encode::encode_vectored's default impl (which provides no speedup). + // TODO optimize enum encode_vectored. + Self::EncodeVectored => quote! { + for t in i { + self.encode(t); + } + }, Self::CollectInto => { let variants = encode_variants .then(|| { @@ -217,170 +217,86 @@ impl Item { } } } - - fn field_impls( - self, - global_prefix: Option<&str>, - fields: &Fields, - parent_attrs: &BitcodeAttrs, - bounds: &mut FieldBounds, - ) -> Result { - fields - .iter() - .enumerate() - .map(move |(i, field)| { - let field_attrs = BitcodeAttrs::parse_field(&field.attrs, parent_attrs)?; - - let name = field_name(i, field, false); - let real_name = field_name(i, field, true); - - let global_name = global_prefix - .map(|global_prefix| { - let ident = - Ident::new(&format!("{global_prefix}{name}"), Span::call_site()); - quote! 
{ #ident } - }) - .unwrap_or_else(|| name.clone()); - - let field_impl = self.field_impl(name, global_name, real_name, &field.ty); - let private = private(); - let bound: Path = parse_quote!(#private::Encode); - bounds.add_bound_type(field.clone(), &field_attrs, bound); - Ok(field_impl) - }) - .collect() - } } -struct Output([TokenStream; Item::COUNT]); - -pub fn derive_impl(mut input: DeriveInput) -> Result { - let attrs = BitcodeAttrs::parse_derive(&input.attrs)?; - let mut generics = input.generics; - let mut bounds = FieldBounds::default(); - - let ident = input.ident; - syn::visit_mut::visit_data_mut(&mut ReplaceSelves(&ident), &mut input.data); - - let (output, is_encode_vectored) = match input.data { - Data::Struct(data_struct) => { - let destructure_fields = &destructure_fields(&data_struct.fields); - ( - Output(Item::ALL.map(|item| { - let field_impls = item - .field_impls(None, &data_struct.fields, &attrs, &mut bounds) - .unwrap(); // TODO don't unwrap - item.struct_impl(&ident, destructure_fields, &field_impls) - })), - true, - ) - } - Data::Enum(data_enum) => { - let variant_count = data_enum.variants.len(); - ( - Output(Item::ALL.map(|item| { - if matches!(item, Item::EncodeVectored) { - return Default::default(); // Unimplemented for now. - } - - item.variant_impls( - variant_count, - |i| { - let variant = &data_enum.variants[i]; - let variant_name = &variant.ident; - let destructure_fields = destructure_fields(&variant.fields); - quote! { - #ident::#variant_name #destructure_fields - } - }, - |item, i| { - let variant = &data_enum.variants[i]; - let global_prefix = format!("{}_", &variant.ident); - let attrs = - BitcodeAttrs::parse_variant(&variant.attrs, &attrs).unwrap(); // TODO don't unwrap. - item.field_impls( - Some(&global_prefix), - &variant.fields, - &attrs, - &mut bounds, - ) - .unwrap() // TODO don't unwrap. 
- }, - ) - })), - false, - ) - } - Data::Union(u) => err(&u.union_token, "unions are not supported")?, - }; +pub struct Encode; +impl crate::shared::Derive<{ Item::COUNT }> for Encode { + type Item = Item; + const ALL: [Self::Item; Item::COUNT] = Item::ALL; - bounds.apply_to_generics(&mut generics); - let input_generics = generics.clone(); - let (impl_generics, input_generics, where_clause) = input_generics.split_for_impl(); - let input_ty = quote! { #ident #input_generics }; + fn bound(&self) -> Path { + let private = private(); + parse_quote!(#private::Encode) + } - // Encoder can't contain any lifetimes from input (which would limit reuse of encoder). - remove_lifetimes(&mut generics); - let (encoder_impl_generics, encoder_generics, encoder_where_clause) = generics.split_for_impl(); + fn derive_impl( + &self, + output: [TokenStream; Item::COUNT], + ident: Ident, + mut generics: Generics, + ) -> TokenStream { + let input_generics = generics.clone(); + let (impl_generics, input_generics, where_clause) = input_generics.split_for_impl(); + let input_ty = quote! { #ident #input_generics }; - let Output( - [type_body, default_body, encode_body, encode_vectored_body, collect_into_body, reserve_body], - ) = output; - let encoder_ident = Ident::new(&format!("{ident}Encoder"), Span::call_site()); - let encoder_ty = quote! { #encoder_ident #encoder_generics }; - let private = private(); + // Encoder can't contain any lifetimes from input (which would limit reuse of encoder). + remove_lifetimes(&mut generics); + let (encoder_impl_generics, encoder_generics, encoder_where_clause) = + generics.split_for_impl(); - let encode_vectored = is_encode_vectored.then(|| quote! 
{ - // #[cfg_attr(not(debug_assertions), inline(always))] - // #[inline(never)] - fn encode_vectored<'__v>(&mut self, i: impl Iterator + Clone) where #input_ty: '__v { - #[allow(unused_imports)] - use #private::Buffer as _; - #encode_vectored_body - } - }); + let [type_body, default_body, encode_body, encode_vectored_body, collect_into_body, reserve_body] = + output; + let encoder_ident = Ident::new(&format!("{ident}Encoder"), Span::call_site()); + let encoder_ty = quote! { #encoder_ident #encoder_generics }; + let private = private(); - let ret = quote! { - const _: () = { - impl #impl_generics #private::Encode for #input_ty #where_clause { - type Encoder = #encoder_ty; - } + quote! { + const _: () = { + impl #impl_generics #private::Encode for #input_ty #where_clause { + type Encoder = #encoder_ty; + } - #[allow(non_snake_case)] - pub struct #encoder_ident #encoder_impl_generics #encoder_where_clause { - #type_body - } + #[allow(non_snake_case)] + pub struct #encoder_ident #encoder_impl_generics #encoder_where_clause { + #type_body + } - // Avoids bounding #impl_generics: Default. - impl #encoder_impl_generics std::default::Default for #encoder_ty #encoder_where_clause { - fn default() -> Self { - Self { - #default_body + // Avoids bounding #impl_generics: Default. 
+ impl #encoder_impl_generics std::default::Default for #encoder_ty #encoder_where_clause { + fn default() -> Self { + Self { + #default_body + } } } - } - impl #impl_generics #private::Encoder<#input_ty> for #encoder_ty #where_clause { - #[cfg_attr(not(debug_assertions), inline(always))] - fn encode(&mut self, v: &#input_ty) { - #[allow(unused_imports)] - use #private::Buffer as _; - #encode_body - } - #encode_vectored - } + impl #impl_generics #private::Encoder<#input_ty> for #encoder_ty #where_clause { + #[cfg_attr(not(debug_assertions), inline(always))] + fn encode(&mut self, v: &#input_ty) { + #[allow(unused_imports)] + use #private::Buffer as _; + #encode_body + } - impl #encoder_impl_generics #private::Buffer for #encoder_ty #encoder_where_clause { - fn collect_into(&mut self, out: &mut Vec) { - #collect_into_body + // #[cfg_attr(not(debug_assertions), inline(always))] + // #[inline(never)] + fn encode_vectored<'__v>(&mut self, i: impl Iterator + Clone) where #input_ty: '__v { + #[allow(unused_imports)] + use #private::Buffer as _; + #encode_vectored_body + } } - fn reserve(&mut self, __additional: std::num::NonZeroUsize) { - #reserve_body + impl #encoder_impl_generics #private::Buffer for #encoder_ty #encoder_where_clause { + fn collect_into(&mut self, out: &mut Vec) { + #collect_into_body + } + + fn reserve(&mut self, __additional: std::num::NonZeroUsize) { + #reserve_body + } } - } - }; - }; - // panic!("{ret}"); - Ok(ret) + }; + } + } } diff --git a/bitcode_derive/src/lib.rs b/bitcode_derive/src/lib.rs index a566032..2874eca 100644 --- a/bitcode_derive/src/lib.rs +++ b/bitcode_derive/src/lib.rs @@ -1,7 +1,10 @@ +use crate::decode::Decode; +use crate::encode::Encode; +use crate::shared::Derive; use proc_macro::TokenStream; use quote::quote; use syn::spanned::Spanned; -use syn::{parse_macro_input, DeriveInput}; +use syn::{parse_macro_input, DeriveInput, Error}; mod attribute; mod bound; @@ -9,27 +12,25 @@ mod decode; mod encode; mod shared; 
-#[proc_macro_derive(Encode, attributes(bitcode))] -pub fn derive_encode(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); - encode::derive_impl(input) - .unwrap_or_else(syn::Error::into_compile_error) - .into() +macro_rules! derive { + ($fn_name:ident, $trait_:ident) => { + #[proc_macro_derive($trait_, attributes(bitcode))] + pub fn $fn_name(input: TokenStream) -> TokenStream { + $trait_ + .derive(parse_macro_input!(input as DeriveInput)) + .unwrap_or_else(Error::into_compile_error) + .into() + } + }; } +derive!(derive_encode, Encode); +derive!(derive_decode, Decode); -#[proc_macro_derive(Decode, attributes(bitcode))] -pub fn derive_decode(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); - decode::derive_impl(input) - .unwrap_or_else(syn::Error::into_compile_error) - .into() +pub(crate) fn error(spanned: &impl Spanned, s: &str) -> Error { + Error::new(spanned.span(), s.to_owned()) } -pub(crate) fn error(spanned: &impl Spanned, s: &str) -> syn::Error { - syn::Error::new(spanned.span(), s.to_owned()) -} - -pub(crate) fn err(spanned: &impl Spanned, s: &str) -> Result { +pub(crate) fn err(spanned: &impl Spanned, s: &str) -> Result { Err(error(spanned, s)) } diff --git a/bitcode_derive/src/shared.rs b/bitcode_derive/src/shared.rs index 5841642..0bd28a6 100644 --- a/bitcode_derive/src/shared.rs +++ b/bitcode_derive/src/shared.rs @@ -1,9 +1,158 @@ +use crate::attribute::BitcodeAttrs; +use crate::bound::FieldBounds; +use crate::err; use proc_macro2::{Ident, Span, TokenStream}; use quote::{quote, ToTokens}; use syn::visit_mut::VisitMut; -use syn::{Field, Fields, GenericParam, Generics, Index, Lifetime, Type, WherePredicate}; +use syn::{ + Data, DataStruct, DeriveInput, Field, Fields, GenericParam, Generics, Index, Lifetime, Path, + Result, Type, WherePredicate, +}; -pub fn destructure_fields(fields: &Fields) -> TokenStream { +type VariantIndex = u8; +pub fn variant_index(i: usize) -> 
VariantIndex { + i.try_into().unwrap() +} + +pub trait Item: Copy + Sized { + fn field_impl( + self, + field_name: TokenStream, + global_field_name: TokenStream, + real_field_name: TokenStream, + field_type: &Type, + ) -> TokenStream; + + fn struct_impl( + self, + ident: &Ident, + destructure_fields: &TokenStream, + do_fields: &TokenStream, + ) -> TokenStream; + + fn enum_impl( + self, + variant_count: usize, + pattern: impl Fn(usize) -> TokenStream, + inner: impl Fn(Self, usize) -> TokenStream, + ) -> TokenStream; + + fn field_impls(self, global_prefix: Option<&str>, fields: &Fields) -> TokenStream { + fields + .iter() + .enumerate() + .map(move |(i, field)| { + let name = field_name(i, field, false); + let real_name = field_name(i, field, true); + let global_name = global_prefix + .map(|global_prefix| { + let ident = + Ident::new(&format!("{global_prefix}{name}"), Span::call_site()); + quote! { #ident } + }) + .unwrap_or_else(|| name.clone()); + + self.field_impl(name, global_name, real_name, &field.ty) + }) + .collect() + } +} + +pub trait Derive { + type Item: Item; + const ALL: [Self::Item; ITEM_COUNT]; + + /// `Encode` in `T: Encode`. + fn bound(&self) -> Path; + + /// Generates the derive implementation. 
+ fn derive_impl( + &self, + output: [TokenStream; ITEM_COUNT], + ident: Ident, + generics: Generics, + ) -> TokenStream; + + fn field_attrs( + &self, + fields: &Fields, + attrs: &BitcodeAttrs, + bounds: &mut FieldBounds, + ) -> Result> { + fields + .iter() + .map(|field| { + let field_attrs = BitcodeAttrs::parse_field(&field.attrs, attrs)?; + bounds.add_bound_type(field.clone(), &field_attrs, self.bound()); + Ok(field_attrs) + }) + .collect() + } + + fn derive(&self, mut input: DeriveInput) -> Result { + let attrs = BitcodeAttrs::parse_derive(&input.attrs)?; + let ident = input.ident; + syn::visit_mut::visit_data_mut(&mut ReplaceSelves(&ident), &mut input.data); + let mut bounds = FieldBounds::default(); + + let output = match input.data { + Data::Struct(DataStruct { ref fields, .. }) => { + // Only used for adding `bounds`. Would be used by `#[bitcode(with_serde)]`. + let field_attrs = self.field_attrs(fields, &attrs, &mut bounds)?; + let _ = field_attrs; + + let destructure_fields = &destructure_fields(fields); + Self::ALL.map(|item| { + let field_impls = item.field_impls(None, fields); + item.struct_impl(&ident, destructure_fields, &field_impls) + }) + } + Data::Enum(data_enum) => { + let max_variants = VariantIndex::MAX as usize + 1; + if data_enum.variants.len() > max_variants { + return err( + &ident, + &format!("enums with more than {max_variants} variants are not supported"), + ); + } + + // Only used for adding `bounds`. Would be used by `#[bitcode(with_serde)]`. + let variant_attrs = data_enum + .variants + .iter() + .map(|variant| { + let attrs = BitcodeAttrs::parse_variant(&variant.attrs, &attrs)?; + self.field_attrs(&variant.fields, &attrs, &mut bounds) + }) + .collect::>>()?; + let _ = variant_attrs; + + Self::ALL.map(|item| { + item.enum_impl( + data_enum.variants.len(), + |i| { + let variant = &data_enum.variants[i]; + let variant_name = &variant.ident; + let destructure_fields = destructure_fields(&variant.fields); + quote! 
{ + #ident::#variant_name #destructure_fields + } + }, + |item, i| { + let variant = &data_enum.variants[i]; + let global_prefix = format!("{}_", &variant.ident); + item.field_impls(Some(&global_prefix), &variant.fields) + }, + ) + }) + } + Data::Union(_) => err(&ident, "unions are not supported")?, + }; + Ok(self.derive_impl(output, ident, bounds.added_to(input.generics))) + } +} + +fn destructure_fields(fields: &Fields) -> TokenStream { let field_names = fields .iter() .enumerate() @@ -19,7 +168,7 @@ pub fn destructure_fields(fields: &Fields) -> TokenStream { } } -pub fn field_name(i: usize, field: &Field, real: bool) -> TokenStream { +fn field_name(i: usize, field: &Field, real: bool) -> TokenStream { field .ident .as_ref() @@ -60,7 +209,7 @@ impl VisitMut for ReplaceLifetimes<'_> { } } -pub struct ReplaceSelves<'a>(pub &'a Ident); +struct ReplaceSelves<'a>(pub &'a Ident); impl VisitMut for ReplaceSelves<'_> { fn visit_ident_mut(&mut self, ident: &mut Ident) { if ident == "Self" { diff --git a/src/derive/mod.rs b/src/derive/mod.rs index 75fc6e8..67a05b8 100644 --- a/src/derive/mod.rs +++ b/src/derive/mod.rs @@ -136,6 +136,7 @@ mod tests { A(u8), } + // cargo expand --lib --tests | grep -A15 Two #[derive(Encode, Decode)] enum Two { A(u8), From b8acaa19951731551d9eefac1440e4f636a6ef30 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Wed, 13 Mar 2024 18:15:09 -0700 Subject: [PATCH 35/45] Remove fn main from doc tests. 
--- README.md | 18 ++++++++---------- src/buffer.rs | 18 ++++++++---------- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 9197745..b725942 100644 --- a/README.md +++ b/README.md @@ -25,16 +25,14 @@ struct Foo<'a> { y: &'a str, } -fn main() { - let original = Foo { - x: 10, - y: "abc", - }; - - let encoded: Vec = bitcode::encode(&original); // No error - let decoded: Foo<'_> = bitcode::decode(&encoded).unwrap(); - assert_eq!(original, decoded); -} +let original = Foo { + x: 10, + y: "abc", +}; + +let encoded: Vec = bitcode::encode(&original); // No error +let decoded: Foo<'_> = bitcode::decode(&encoded).unwrap(); +assert_eq!(original, decoded); ``` ## Tuple vs Array diff --git a/src/buffer.rs b/src/buffer.rs index 10e3d19..cc8f1a2 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -6,18 +6,16 @@ use std::any::TypeId; /// ```rust /// use bitcode::{Buffer, Encode, Decode}; /// -/// fn main() { -/// let original = "Hello world!"; +/// let original = "Hello world!"; /// -/// let mut buffer = Buffer::new(); -/// buffer.encode(&original); -/// let encoded: &[u8] = buffer.encode(&original); // Won't allocate +/// let mut buffer = Buffer::new(); +/// buffer.encode(&original); +/// let encoded: &[u8] = buffer.encode(&original); // Won't allocate /// -/// let mut buffer = Buffer::new(); -/// buffer.decode::<&str>(&encoded).unwrap(); -/// let decoded: &str = buffer.decode(&encoded).unwrap(); // Won't allocate -/// assert_eq!(original, decoded); -/// } +/// let mut buffer = Buffer::new(); +/// buffer.decode::<&str>(&encoded).unwrap(); +/// let decoded: &str = buffer.decode(&encoded).unwrap(); // Won't allocate +/// assert_eq!(original, decoded); /// ``` #[derive(Default)] pub struct Buffer { From 7f2104c26e8da45e28290dfd0dfd6e49cf6c03fe Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Wed, 13 Mar 2024 18:23:57 -0700 Subject: [PATCH 36/45] Improve impl Debug for Error. 
--- src/error.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/error.rs b/src/error.rs index d87c47a..19f6e35 100644 --- a/src/error.rs +++ b/src/error.rs @@ -34,9 +34,13 @@ type ErrorImpl = (); /// In debug mode, the error contains a reason. /// # Release mode /// In release mode, the error is a zero-sized type for efficiency. -#[derive(Debug)] #[cfg_attr(test, derive(PartialEq))] pub struct Error(ErrorImpl); +impl Debug for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Error({:?})", self.to_string()) + } +} impl Display for Error { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { #[cfg(debug_assertions)] From 9279196177bd0716e436ea23ba4bc59533590b3e Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Wed, 13 Mar 2024 22:59:04 -0700 Subject: [PATCH 37/45] Fix documented unsound code in serde impl. --- src/serde/de.rs | 46 ++++++++++++++++++++++- src/serde/ser.rs | 98 ++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 136 insertions(+), 8 deletions(-) diff --git a/src/serde/de.rs b/src/serde/de.rs index 3cb597f..94ba782 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -401,6 +401,7 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { decoders: &'a mut (SerdeDecoder<'de>, SerdeDecoder<'de>), input: &'a mut &'de [u8], len: usize, + key_deserialized: bool, } impl<'de> MapAccess<'de> for Access<'_, 'de> { @@ -414,6 +415,9 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { guard_zst::(self.len)?; if self.len != 0 { self.len -= 1; + // Safety: Make sure next_value_seed is called at most once after each len decrement. + // We don't care if DeserializeSeed fails after this (not critical to safety). + self.key_deserialized = true; Ok(Some(DeserializeSeed::deserialize( seed, DecoderWrapper { @@ -426,12 +430,17 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { } } - // TODO(unsound): could be called more than len times by buggy safe code and go out of bounds. 
#[inline(always)] fn next_value_seed(&mut self, seed: V) -> Result where V: DeserializeSeed<'de>, { + // Safety: Make sure next_value_seed is called at most once after each len decrement + // since only len values exist. + assert!( + std::mem::take(&mut self.key_deserialized), + "next_value_seed before next_key_seed" + ); DeserializeSeed::deserialize( seed, DecoderWrapper { @@ -440,6 +449,7 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { }, ) } + // TODO implement next_entry_seed to avoid checking key_deserialized. #[inline(always)] fn size_hint(&self) -> Option { @@ -451,6 +461,7 @@ impl<'de> Deserializer<'de> for DecoderWrapper<'_, 'de> { decoders, input: self.input, len, + key_deserialized: false, // No keys have been deserialized yet, so next_value_seed can't be called. }) } @@ -561,6 +572,8 @@ impl<'de> VariantAccess<'de> for DecoderWrapper<'_, 'de> { #[cfg(test)] mod tests { + use serde::de::MapAccess; + use serde::Deserializer; use std::collections::BTreeMap; #[test] @@ -621,4 +634,35 @@ mod tests { // Complex. 
test!(vec![(None, 3), (Some(4), 5)], Vec<(Option, u8)>); } + + #[test] + #[should_panic = "next_value_seed before next_key_seed"] + fn map_incorrect_len_values() { + let mut map = BTreeMap::new(); + map.insert(1u8, 2u8); + let input = crate::serialize(&map).unwrap(); + + let w = super::DecoderWrapper { + decoder: &mut super::SerdeDecoder::Unspecified { length: 1 }, + input: &mut input.as_slice(), + }; + + struct Visitor; + impl<'de> serde::de::Visitor<'de> for Visitor { + type Value = (); + fn expecting(&self, _: &mut std::fmt::Formatter) -> std::fmt::Result { + unreachable!() + } + fn visit_map(self, mut map: A) -> Result + where + A: MapAccess<'de>, + { + assert_eq!(map.next_key::().unwrap().unwrap(), 1u8); + assert_eq!(map.next_value::().unwrap(), 2u8); + map.next_value::().unwrap(); + Ok(()) + } + } + w.deserialize_map(Visitor).unwrap(); + } } diff --git a/src/serde/ser.rs b/src/serde/ser.rs index 9af687f..19c7c82 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -262,7 +262,7 @@ macro_rules! impl_ser { impl<'a> Serializer for EncoderWrapper<'a> { type Ok = (); type Error = Error; - type SerializeSeq = EncoderWrapper<'a>; + type SerializeSeq = SeqSerializer<'a>; type SerializeTuple = TupleSerializer<'a>; type SerializeTupleStruct = TupleSerializer<'a>; type SerializeTupleVariant = TupleSerializer<'a>; @@ -357,9 +357,10 @@ impl<'a> Serializer for EncoderWrapper<'a> { let b = specify!(self, Seq); b.0.encode(&len); b.1.reserve_fast(len); - Ok(Self { + Ok(SeqSerializer { lazy: &mut b.1, index_alloc: self.index_alloc, + len, }) } @@ -446,6 +447,8 @@ impl<'a> Serializer for EncoderWrapper<'a> { Ok(MapSerializer { encoders: &mut b.1, index_alloc: self.index_alloc, + len, + key_serialized: false, // No keys have been serialized yet, so serialize_value can't be called. }) } @@ -481,11 +484,18 @@ macro_rules! 
ok_error_end { }; } -impl SerializeSeq for EncoderWrapper<'_> { +struct SeqSerializer<'a> { + lazy: &'a mut LazyEncoder, + index_alloc: &'a mut usize, + len: usize, +} + +impl SerializeSeq for SeqSerializer<'_> { ok_error_end!(); - // TODO(unsound): could be called more than len times by buggy safe code but we only reserved len. #[inline(always)] fn serialize_element(&mut self, value: &T) -> Result<()> { + // Safety: Make sure safe code doesn't lie about len and cause UB since we've only reserved len elements. + self.len = self.len.checked_sub(1).expect("length mismatch"); value.serialize(EncoderWrapper { lazy: &mut *self.lazy, index_alloc: &mut *self.index_alloc, @@ -531,39 +541,50 @@ impl_tuple!(SerializeStructVariant, serialize_field, _key); struct MapSerializer<'a> { encoders: &'a mut (LazyEncoder, LazyEncoder), // (keys, values) index_alloc: &'a mut usize, + len: usize, + key_serialized: bool, } impl SerializeMap for MapSerializer<'_> { ok_error_end!(); - // TODO(unsound): could be called more than len times by buggy safe code but we only reserved len. #[inline(always)] fn serialize_key(&mut self, key: &T) -> Result<()> where T: Serialize, { + // Safety: Make sure safe code doesn't lie about len and cause UB since we've only reserved len keys/values. + self.len = self.len.checked_sub(1).expect("length mismatch"); + // Safety: Make sure serialize_value is called at most once after each serialize_key. + self.key_serialized = true; key.serialize(EncoderWrapper { lazy: &mut self.encoders.0, index_alloc: &mut *self.index_alloc, }) } - // TODO(unsound): could be called more than len times by buggy safe code but we only reserved len. #[inline(always)] fn serialize_value(&mut self, value: &T) -> Result<()> where T: Serialize, { + // Safety: Make sure serialize_value is called at most once after each serialize_key. 
+ assert!( + std::mem::take(&mut self.key_serialized), + "serialize_value before serialize_key" + ); value.serialize(EncoderWrapper { lazy: &mut self.encoders.1, index_alloc: &mut *self.index_alloc, }) } + // TODO implement serialize_entry to avoid checking key_serialized. } #[cfg(test)] mod tests { - use serde::ser::SerializeTuple; + use serde::ser::{SerializeMap, SerializeSeq, SerializeTuple}; use serde::{Serialize, Serializer}; + use std::num::NonZeroUsize; #[test] fn enum_256_variants() { @@ -613,4 +634,67 @@ mod tests { } let _ = crate::serialize(&vec![TupleN(1), TupleN(2)]); } + + // Has to be a macro because it borrows something on the stack and returns it. + macro_rules! new_wrapper { + () => { + super::EncoderWrapper { + lazy: &mut super::LazyEncoder::Unspecified { + reserved: NonZeroUsize::new(1), + }, + index_alloc: &mut 0, + } + }; + } + + #[test] + fn seq_valid() { + let w = new_wrapper!(); + let mut seq = w.serialize_seq(Some(1)).unwrap(); + let _ = seq.serialize_element(&0u8); // serialize_seq 1 == serialize 1. + } + + #[test] + #[should_panic = "length mismatch"] + fn seq_incorrect_len() { + let w = new_wrapper!(); + let mut seq = w.serialize_seq(Some(1)).unwrap(); + let _ = seq.serialize_element(&0u8); // serialize_seq 1 != serialize 2. + let _ = seq.serialize_element(&0u8); + } + + #[test] + fn map_valid() { + let w = new_wrapper!(); + let mut map = w.serialize_map(Some(1)).unwrap(); + let _ = map.serialize_key(&0u8); // serialize_map 1 == (key, value). 
+ let _ = map.serialize_value(&0u8); + } + + #[test] + #[should_panic = "length mismatch"] + fn map_incorrect_len_keys() { + let w = new_wrapper!(); + let mut map = w.serialize_map(Some(1)).unwrap(); + let _ = map.serialize_key(&0u8); // serialize_map 1 != (key, _) (key, _) + let _ = map.serialize_key(&0u8); + } + + #[test] + #[should_panic = "serialize_value before serialize_key"] + fn map_value_before_key() { + let w = new_wrapper!(); + let mut map = w.serialize_map(Some(1)).unwrap(); + let _ = map.serialize_value(&0u8); + } + + #[test] + #[should_panic = "serialize_value before serialize_key"] + fn map_incorrect_len_values() { + let w = new_wrapper!(); + let mut map = w.serialize_map(Some(1)).unwrap(); + let _ = map.serialize_key(&0u8); // serialize_map 1 != (key, value) (_, value). + let _ = map.serialize_value(&0u8); + let _ = map.serialize_value(&0u8); + } } From 290149cc8baf3a3ac02666ab6d1b52260f5d86b6 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Wed, 13 Mar 2024 23:01:15 -0700 Subject: [PATCH 38/45] clippy --- src/serde/ser.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/serde/ser.rs b/src/serde/ser.rs index 19c7c82..6fd3184 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -197,8 +197,8 @@ macro_rules! specify { // Either create the correct encoder if unspecified or panic if we already have an // encoder since it must be a different type. 
#[cold] - fn cold<'a>( - me: &'a mut LazyEncoder, + fn cold( + me: &mut LazyEncoder, index_alloc: &mut usize, ) { let &mut LazyEncoder::Unspecified { reserved } = me else { From d8f869df5bfc69256848aefb917592fa1aa2be22 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Wed, 13 Mar 2024 23:22:21 -0700 Subject: [PATCH 39/45] clippy --- src/derive/array.rs | 2 +- src/derive/smart_ptr.rs | 2 +- src/derive/variant.rs | 2 +- src/derive/vec.rs | 2 +- src/ext/arrayvec.rs | 2 +- src/f32.rs | 2 +- src/fast.rs | 2 +- src/int.rs | 2 +- src/lib.rs | 1 + src/pack.rs | 4 ++-- src/pack_ints.rs | 2 +- src/serde/mod.rs | 2 +- src/serde/ser.rs | 2 +- src/serde/variant.rs | 2 +- src/str.rs | 8 ++++---- src/u8_char.rs | 2 +- 16 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/derive/array.rs b/src/derive/array.rs index 40a1b50..a6de65d 100644 --- a/src/derive/array.rs +++ b/src/derive/array.rs @@ -61,7 +61,7 @@ impl<'a, T: Decode<'a>, const N: usize> Decoder<'a, [T; N]> for ArrayDecoder<'a, // Safety: Equivalent to nightly MaybeUninit::transpose. 
let out = unsafe { &mut *(out.as_mut_ptr() as *mut [MaybeUninit; N]) }; for out in out { - self.0.decode_in_place(out) + self.0.decode_in_place(out); } } } diff --git a/src/derive/smart_ptr.rs b/src/derive/smart_ptr.rs index 323cfd1..93198c5 100644 --- a/src/derive/smart_ptr.rs +++ b/src/derive/smart_ptr.rs @@ -15,7 +15,7 @@ impl Default for DerefEncoder { impl, T: Encode + ?Sized> Encoder for DerefEncoder { #[inline(always)] fn encode(&mut self, t: &D) { - self.0.encode(t) + self.0.encode(t); } } diff --git a/src/derive/variant.rs b/src/derive/variant.rs index f1a12c5..cbd6128 100644 --- a/src/derive/variant.rs +++ b/src/derive/variant.rs @@ -21,7 +21,7 @@ impl Buffer for VariantEncoder { } fn reserve(&mut self, additional: NonZeroUsize) { - self.0.reserve(additional.get()) + self.0.reserve(additional.get()); } } diff --git a/src/derive/vec.rs b/src/derive/vec.rs index 45d619d..b5cfa06 100644 --- a/src/derive/vec.rs +++ b/src/derive/vec.rs @@ -278,7 +278,7 @@ macro_rules! decode_body { impl Encoder> for VecEncoder { #[inline(always)] fn encode(&mut self, v: &Vec) { - self.encode(v.as_slice()) + self.encode(v.as_slice()); } #[inline(always)] diff --git a/src/ext/arrayvec.rs b/src/ext/arrayvec.rs index a2ad2ef..794fffc 100644 --- a/src/ext/arrayvec.rs +++ b/src/ext/arrayvec.rs @@ -119,7 +119,7 @@ fn as_slice_assert_len(t: &ArrayVec) -> &[T] { impl Encoder> for VecEncoder { #[inline(always)] fn encode(&mut self, t: &ArrayVec) { - self.encode(as_slice_assert_len(t)) + self.encode(as_slice_assert_len(t)); } #[inline(always)] fn encode_vectored<'a>(&mut self, i: impl Iterator> + Clone) diff --git a/src/f32.rs b/src/f32.rs index 782edb9..1d4dfc3 100644 --- a/src/f32.rs +++ b/src/f32.rs @@ -71,7 +71,7 @@ impl Buffer for F32Encoder { } fn reserve(&mut self, additional: NonZeroUsize) { - self.0.reserve(additional.get()) + self.0.reserve(additional.get()); } } diff --git a/src/fast.rs b/src/fast.rs index a9ef379..22f27d0 100644 --- a/src/fast.rs +++ b/src/fast.rs @@ -170,7 
+170,7 @@ impl PushUnchecked for Vec { debug_assert!(n < self.capacity()); let end = self.as_mut_ptr().add(n); std::ptr::write(end, t); - self.set_len(n + 1) + self.set_len(n + 1); } } diff --git a/src/int.rs b/src/int.rs index e6ce2a0..7582bb3 100644 --- a/src/int.rs +++ b/src/int.rs @@ -34,7 +34,7 @@ impl Buffer for IntEncoder { } fn reserve(&mut self, additional: NonZeroUsize) { - self.0.reserve(additional.get()) + self.0.reserve(additional.get()); } } diff --git a/src/lib.rs b/src/lib.rs index a8bb1af..5ffbb4f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ #![allow(clippy::items_after_test_module, clippy::blocks_in_if_conditions)] +#![warn(clippy::semicolon_if_nothing_returned)] #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] #![cfg_attr(test, feature(test))] #![doc = include_str!("../README.md")] diff --git a/src/pack.rs b/src/pack.rs index f191f47..11c7dc1 100644 --- a/src/pack.rs +++ b/src/pack.rs @@ -177,7 +177,7 @@ fn unpack_bytes_unsigned<'a>( Packing::_4 => unpack_arithmetic::<4>(input, length, out)?, Packing::_3 => unpack_arithmetic::<3>(input, length, out)?, Packing::_2 => unpack_arithmetic::<2>(input, length, out)?, - _ => unreachable!(), + Packing::_256 => unreachable!(), } if let Some(min) = min { for v in out { @@ -343,7 +343,7 @@ pub fn unpack_bytes_less_than<'a, const N: usize, const HISTOGRAM: usize>( Packing::_4 => unpack_arithmetic_less_than::(input, length, out), Packing::_3 => unpack_arithmetic_less_than::(input, length, out), Packing::_2 => unpack_arithmetic_less_than::(input, length, out), - _ => unreachable!(), + Packing::_256 => unreachable!(), } } diff --git a/src/pack_ints.rs b/src/pack_ints.rs index 0685e50..561c200 100644 --- a/src/pack_ints.rs +++ b/src/pack_ints.rs @@ -468,7 +468,7 @@ fn unpack_ints_sized_unsigned<'a, T: SizedUInt>( for v in out.iter_mut() { *v = min.wrapping_add(*v); } - }) + }); } Ok(()) } diff --git a/src/serde/mod.rs b/src/serde/mod.rs index 50e6e40..16cdefc 100644 --- a/src/serde/mod.rs +++ 
b/src/serde/mod.rs @@ -29,7 +29,7 @@ fn get_mut_or_resize(vec: &mut Vec, index: usize) -> &mut T { #[cold] #[inline(never)] fn cold(vec: &mut Vec, index: usize) { - vec.resize_with(index + 1, Default::default) + vec.resize_with(index + 1, Default::default); } cold(vec, index); } diff --git a/src/serde/ser.rs b/src/serde/ser.rs index 6fd3184..d543344 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -162,7 +162,7 @@ impl LazyEncoder { SpecifiedEncoder::U32(v) => v, SpecifiedEncoder::U64(v) => v, SpecifiedEncoder::U128(v) => v, - }) + }); } Self::Unspecified { .. } => (), } diff --git a/src/serde/variant.rs b/src/serde/variant.rs index 3b4336e..35651c1 100644 --- a/src/serde/variant.rs +++ b/src/serde/variant.rs @@ -23,7 +23,7 @@ impl Buffer for VariantEncoder { } fn reserve(&mut self, additional: NonZeroUsize) { - self.data.reserve(additional.get()) + self.data.reserve(additional.get()); } } diff --git a/src/str.rs b/src/str.rs index f786d7a..71cc039 100644 --- a/src/str.rs +++ b/src/str.rs @@ -34,7 +34,7 @@ impl Encoder for StrEncoder { #[inline(always)] fn encode_vectored<'a>(&mut self, i: impl Iterator + Clone) { - self.0.encode_vectored(i.map(str_as_u8_chars)) + self.0.encode_vectored(i.map(str_as_u8_chars)); } } @@ -50,7 +50,7 @@ impl<'b> Encoder<&'b str> for StrEncoder { where &'b str: 'a, { - self.encode_vectored(i.copied()) + self.encode_vectored(i.copied()); } } @@ -65,7 +65,7 @@ impl Encoder for StrEncoder { where String: 'a, { - self.encode_vectored(i.map(String::as_str)) + self.encode_vectored(i.map(String::as_str)); } } @@ -143,7 +143,7 @@ fn is_ascii_simd(v: &[u8]) -> bool { for chunk in chunks_exact { let mut any = false; for &v in chunk { - any |= v & 0x80 != 0 + any |= v & 0x80 != 0; } if any { debug_assert!(!chunk.is_ascii()); diff --git a/src/u8_char.rs b/src/u8_char.rs index d768b5d..7faaf4b 100644 --- a/src/u8_char.rs +++ b/src/u8_char.rs @@ -38,6 +38,6 @@ impl Buffer for U8CharEncoder { } fn reserve(&mut self, additional: NonZeroUsize) { 
- self.0.reserve(additional.get()) + self.0.reserve(additional.get()); } } From e832667b4395687491c075556dd03255a5518171 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Wed, 13 Mar 2024 23:40:07 -0700 Subject: [PATCH 40/45] Release 0.6.0-beta.1 --- Cargo.toml | 4 ++-- bitcode_derive/Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index eb3abe3..6591dd6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ members = [ [package] name = "bitcode" authors = [ "Cai Bear", "Finn Bear" ] -version = "0.6.0-alpha.2" +version = "0.6.0-beta.1" edition = "2021" license = "MIT OR Apache-2.0" repository = "https://github.com/SoftbearStudios/bitcode" @@ -15,7 +15,7 @@ exclude = ["fuzz/"] [dependencies] arrayvec = { version = "0.7", default-features = false, optional = true } -bitcode_derive = { version = "0.6.0-alpha.1", path = "./bitcode_derive", optional = true } +bitcode_derive = { version = "0.6.0-beta.1", path = "./bitcode_derive", optional = true } bytemuck = { version = "1.14", features = [ "min_const_generics", "must_cast" ] } glam = { version = "0.22", default-features = false, features = [ "std" ], optional = true } serde = { version = "1.0", optional = true } diff --git a/bitcode_derive/Cargo.toml b/bitcode_derive/Cargo.toml index 0c0a414..3996246 100644 --- a/bitcode_derive/Cargo.toml +++ b/bitcode_derive/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "bitcode_derive" authors = [ "Cai Bear", "Finn Bear" ] -version = "0.6.0-alpha.1" +version = "0.6.0-beta.1" edition = "2021" license = "MIT OR Apache-2.0" repository = "https://github.com/SoftbearStudios/bitcode/" From c22dbbf4c1da566ee92f9c3f8ff4345b2514d4ed Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Thu, 14 Mar 2024 13:59:07 -0700 Subject: [PATCH 41/45] Fix panic on empty array encode. 
--- src/derive/array.rs | 3 +++ src/derive/mod.rs | 6 ++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/derive/array.rs b/src/derive/array.rs index a6de65d..49dc6b2 100644 --- a/src/derive/array.rs +++ b/src/derive/array.rs @@ -30,6 +30,9 @@ impl Buffer for ArrayEncoder { } fn reserve(&mut self, additional: NonZeroUsize) { + if N == 0 { + return; // self.0.reserve takes NonZeroUsize and `additional * N == 0`. + } self.0.reserve( additional .checked_mul(NonZeroUsize::new(N).unwrap()) diff --git a/src/derive/mod.rs b/src/derive/mod.rs index 67a05b8..a61bd40 100644 --- a/src/derive/mod.rs +++ b/src/derive/mod.rs @@ -112,9 +112,10 @@ mod tests { fn decode() { macro_rules! test { ($v:expr, $t:ty) => { - let encoded = super::encode::<$t>(&$v); + let v = $v; + let encoded = super::encode::<$t>(&v); println!("{:<24} {encoded:?}", stringify!($t)); - assert_eq!($v, super::decode::<$t>(&encoded).unwrap()); + assert_eq!(v, super::decode::<$t>(&encoded).unwrap()); }; } @@ -126,6 +127,7 @@ mod tests { test!([0, 1], [u8; 2]); test!([0, 1, 2], [u8; 3]); test!([0, -1, 0, -1, 0, -1, 0], [i8; 7]); + test!([], [u8; 0]); } #[derive(Encode, Decode)] From 1328a36ca9cccb37689821207db25b8f5cf37bfa Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Thu, 14 Mar 2024 14:05:00 -0700 Subject: [PATCH 42/45] Add empty array to serde tests. --- src/serde/de.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/serde/de.rs b/src/serde/de.rs index 94ba782..60419f7 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -580,9 +580,10 @@ mod tests { fn deserialize() { macro_rules! 
test { ($v:expr, $t:ty) => { - let ser = crate::serialize::<$t>(&$v).unwrap(); + let v = $v; + let ser = crate::serialize::<$t>(&v).unwrap(); println!("{:<24} {ser:?}", stringify!($t)); - assert_eq!($v, crate::deserialize::<$t>(&ser).unwrap()); + assert_eq!(v, crate::deserialize::<$t>(&ser).unwrap()); }; } // Primitives @@ -630,6 +631,7 @@ mod tests { // Tuples test!((1u8, 2u8, 3u8), (u8, u8, u8)); test!([1u8, 2u8, 3u8], [u8; 3]); + test!([], [u8; 0]); // Complex. test!(vec![(None, 3), (Some(4), 5)], Vec<(Option, u8)>); From e9b92afb0125285b4ef0d482f45b066ec45956cf Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Thu, 14 Mar 2024 14:27:33 -0700 Subject: [PATCH 43/45] Add library example. --- README.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/README.md b/README.md index b725942..00fd753 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,21 @@ let decoded: Foo<'_> = bitcode::decode(&encoded).unwrap(); assert_eq!(original, decoded); ``` +## Library Example + +Add bitcode to libraries without specifying the major version so binary crates can use any version. +This is a minimal stable subset of the bitcode API so avoid using any other functionality. +```toml +bitcode = { version = "0", features = ["derive"], default-features = false, optional = true } +``` +```rust +#[cfg_attr(feature = "bitcode", derive(bitcode::Encode, bitcode::Decode))] +pub struct Vec2 { + x: f32, + y: f32, +} +``` + ## Tuple vs Array If you have multiple values of the same type: - Use a tuple or struct when the values are semantically different: `x: u32, y: u32` From 46febfc11d0cb8a145808ce15bf9b033982be66d Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Thu, 14 Mar 2024 14:29:49 -0700 Subject: [PATCH 44/45] Edit previous. 
--- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 00fd753..0d12042 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ assert_eq!(original, decoded); ## Library Example -Add bitcode to libraries without specifying the major version so binary crates can use any version. +Add bitcode to libraries without specifying the major version so binary crates can pick the version. This is a minimal stable subset of the bitcode API so avoid using any other functionality. ```toml bitcode = { version = "0", features = ["derive"], default-features = false, optional = true } From 63f5ed57e03fe72c611a43505749fd4964ceac11 Mon Sep 17 00:00:00 2001 From: Cai Bear Date: Fri, 15 Mar 2024 21:04:44 -0700 Subject: [PATCH 45/45] Improve safety comments. --- src/coder.rs | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/coder.rs b/src/coder.rs index 0a5045e..aa5ef3e 100644 --- a/src/coder.rs +++ b/src/coder.rs @@ -16,7 +16,7 @@ pub trait Buffer { /// Collects the buffer into a single `Vec`. This clears the buffer. fn collect_into(&mut self, out: &mut Vec); - /// Reserves space for `additional` calls to [`Encoder::encode`]. May be a no-op. Takes a NonZeroUsize to avoid + /// Reserves space for `additional` calls to `self.encode()`. Takes a [`NonZeroUsize`] to avoid /// useless calls. fn reserve(&mut self, additional: NonZeroUsize); } @@ -35,10 +35,15 @@ pub trait Encoder: Buffer + Default { } /// Encodes a single value. Can't error since anything can be encoded. + /// # Safety + /// Can only encode `self.reserve(additional)` items. fn encode(&mut self, t: &T); - /// Calls [`Self::encode`] once for every item in `i`. Only use this with **FAST** iterators. + /// Calls [`Self::encode`] once for every item in `i`. Only use this with **FAST** iterators + /// since the iterator may be iterated multiple times. /// # Safety + /// Can only encode `self.reserve(additional)` items. 
+ /// `i` must have an accurate `i.size_hint().1.unwrap()` that != 0 and is <= `MAX_VECTORED_CHUNK`. /// Currently, the non-map iterators that uphold these requirements are: /// - vec.rs @@ -54,14 +59,15 @@ } pub trait View<'a> { - /// Reads `length` items out of `input` provisioning `length` calls to [`Decoder::decode`]. This overwrites the view. + /// Reads `length` items out of `input`, overwriting the view. If it returns `Ok`, + `self.decode()` can be called `length` times. fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()>; } /// One of [`Decoder::decode`] and [`Decoder::decode_in_place`] must be implemented or calling /// either one will result in infinite recursion and a stack overflow. pub trait Decoder<'a, T>: View<'a> + Default { - /// Returns a `Some(ptr)` to the current element if it can be decoded by copying. + /// Returns a pointer to the current element if it can be decoded by copying. #[inline(always)] fn as_primitive_ptr(&self) -> Option<*const u8> { None @@ -69,7 +75,7 @@ /// Assuming [`Self::as_primitive_ptr`] returns `Some(ptr)`, this advances `ptr` by `n`. /// # Safety - /// All advances and decodes must not pass `self.populate(_, length)`. + /// Can only decode `self.populate(_, length)` items. unsafe fn as_primitive_advance(&mut self, n: usize) { let _ = n; unreachable!(); @@ -77,6 +83,8 @@ /// Decodes a single value. Can't error since `View::populate` has already validated the input. /// Prefer decode for primitives (since it's simpler) and decode_in_place for array/struct/tuple. + /// # Safety + /// Can only decode `self.populate(_, length)` items. #[inline(always)] fn decode(&mut self) -> T { let mut out = MaybeUninit::uninit(); @@ -87,6 +95,8 @@ /// [`Self::decode`] without redundant copies. 
Only downside is panics will leak the value. /// The only panics out of our control are Hash/Ord/PartialEq for BinaryHeap/BTreeMap/HashMap. /// E.g. if a user PartialEq panics we will leak some memory which is an acceptable tradeoff. + /// # Safety + /// Can only decode `self.populate(_, length)` items. #[inline(always)] fn decode_in_place(&mut self, out: &mut MaybeUninit) { out.write(self.decode());