diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
new file mode 100644
index 00000000..f560da70
--- /dev/null
+++ b/.github/workflows/rust.yml
@@ -0,0 +1,72 @@
+name: Rust
+
+on:
+  push:
+    branches: [ master ]
+  pull_request:
+    branches: [ master ]
+
+env:
+  CARGO_TERM_COLOR: always
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions-rust-lang/setup-rust-toolchain@v1
+        with:
+          toolchain: stable
+          components: clippy,rustfmt
+      - name: Print versions
+        run: |
+          cargo --version
+          rustc --version
+          clippy-driver --version
+          rustfmt --version
+      - name: Build
+        run: cargo build --verbose
+      - name: Run tests
+        run: cargo test --verbose
+      - name: Run clippy
+        run: cargo clippy --verbose --all-targets -- -D clippy::all
+      - name: Check code formatting
+        run: cargo fmt --verbose --all -- --check
+
+  miri-test:
+    name: Test with miri
+    runs-on: ubuntu-latest
+    env:
+      MIRIFLAGS: -Zmiri-disable-isolation
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions-rust-lang/setup-rust-toolchain@v1
+        with:
+          toolchain: nightly
+          components: miri
+      - run: cargo miri test --verbose --no-default-features
+      - run: cargo miri test --verbose --all-features
+
+  sanitizer-test:
+    name: Test with -Zsanitizer=${{ matrix.sanitizer }}
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        sanitizer: [address, thread, leak]
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions-rust-lang/setup-rust-toolchain@v1
+        with:
+          toolchain: nightly
+          components: rust-src
+      - name: Test with sanitizer
+        env:
+          RUSTFLAGS: -Zsanitizer=${{ matrix.sanitizer }}
+          RUSTDOCFLAGS: -Zsanitizer=${{ matrix.sanitizer }}
+          # only needed by asan
+          ASAN_OPTIONS: detect_stack_use_after_return=1,detect_leaks=0
+          # Asan's leak detection occasionally complains
+          # about some small leaks if backtraces are captured,
+          # so ensure they're not
+          RUST_BACKTRACE: 0
+        run: cargo test -Zbuild-std --verbose --target=x86_64-unknown-linux-gnu
diff --git a/.travis.yml b/.travis.yml
deleted file mode 100644
index 591417ad..00000000
--- a/.travis.yml
+++ /dev/null
@@ -1,24 +0,0 @@
-language: rust
-cache: cargo
-matrix:
-  include:
-    - name: "Tests and lints"
-      before_install:
-        - rustup component add clippy rustfmt
-      script:
-        - cargo test
-        - cargo clippy --all-targets -- -D warnings
-        - cargo fmt -- --check
-    - name: "C tests on Windows"
-      os: windows
-      script:
-        - cargo test -p wirefilter-ffi-ctests
-    - name: "C tests on OS X"
-      os: osx
-      script:
-        - cargo test -p wirefilter-ffi-ctests
-    - name: "WASM build"
-      before_install:
-        - curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh -s -- -f
-      script:
-        - wasm-pack build wasm --debug --target browser --scope cloudflare
diff --git a/Cargo.lock b/Cargo.lock
index 624a12b9..bc024c51 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1,992 +1,1049 @@
 # This file is automatically @generated by Cargo.
 # It is not intended for manual editing.
-[[package]] -name = "aho-corasick" -version = "0.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)", -] +version = 3 [[package]] -name = "arrayvec" -version = "0.4.10" +name = "addr2line" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" dependencies = [ - "nodrop 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)", + "gimli", ] [[package]] -name = "atty" -version = "0.2.11" +name = "adler2" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "libc 0.2.42 (registry+https://github.com/rust-lang/crates.io-index)", - "termion 1.5.1 (registry+https://github.com/rust-lang/crates.io-index)", - "winapi 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", -] +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" [[package]] -name = "backtrace" -version = "0.3.9" +name = "afl" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c80b57a86234ee3e9238f5f2d33d37f8fd5c7ff168c07f2d5147d410e86db33" dependencies = [ - "backtrace-sys 0.1.23 (registry+https://github.com/rust-lang/crates.io-index)", - "cfg-if 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)", - "libc 0.2.42 (registry+https://github.com/rust-lang/crates.io-index)", - "rustc-demangle 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)", - "winapi 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", + "home", + "libc", + "rustc_version", + "xdg", ] [[package]] -name = "backtrace-sys" -version = "0.1.23" +name = "aho-corasick" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" dependencies = [ - "cc 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)", - "libc 0.2.42 (registry+https://github.com/rust-lang/crates.io-index)", + "memchr", ] [[package]] -name = "bitflags" -version = "1.0.3" +name = "anes" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] -name = "bitstring" -version = "0.1.1" +name = "anstyle" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8365de52b16c035ff4fcafe0092ba9390540e3e352870ac09933bebcaa2c8c56" [[package]] -name = "block-buffer" -version = "0.7.3" +name = "autocfg" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "block-padding 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)", - "byte-tools 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", - "byteorder 1.2.4 (registry+https://github.com/rust-lang/crates.io-index)", - "generic-array 0.12.3 (registry+https://github.com/rust-lang/crates.io-index)", -] +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] -name = "block-padding" -version = "0.1.5" +name = "backtrace" +version = "0.3.74" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" dependencies = [ - "byte-tools 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", + "addr2line", + "cfg-if", + "libc", + 
"miniz_oxide", + "object", + "rustc-demangle", + "windows-targets", ] [[package]] -name = "byte-tools" -version = "0.3.1" +name = "bumpalo" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "byteorder" -version = "1.2.4" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "cast" -version = "0.2.2" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.0.18" +version = "1.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2e7962b54006dcfcc61cb72735f4d89bb97061dd6a7ed882ec6b8ee53714c6f" +dependencies = [ + "shlex", +] [[package]] name = "cfg-if" -version = "0.1.6" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] [[package]] name = "cidr" -version = "0.1.0" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bdf600c45bd958cf2945c445264471cca8b6c8e67bc87b71affd6d7e5682621" dependencies = [ - "bitstring 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", - "serde 1.0.78 (registry+https://github.com/rust-lang/crates.io-index)", + "serde", ] [[package]] name = "clap" -version = "2.32.0" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" dependencies = [ - "bitflags 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)", - "textwrap 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)", - "unicode-width 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)", + "clap_builder", ] [[package]] -name = "cloudabi" -version = "0.0.3" +name = "clap_builder" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" dependencies = [ - "bitflags 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)", + "anstyle", + "clap_lex", ] +[[package]] +name = "clap_lex" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" + [[package]] name = "criterion" -version = "0.2.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "atty 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)", - "cast 0.2.2 
(registry+https://github.com/rust-lang/crates.io-index)", - "clap 2.32.0 (registry+https://github.com/rust-lang/crates.io-index)", - "criterion-plot 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", - "csv 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", - "itertools 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", - "lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", - "libc 0.2.42 (registry+https://github.com/rust-lang/crates.io-index)", - "num-traits 0.2.5 (registry+https://github.com/rust-lang/crates.io-index)", - "rand_core 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", - "rand_os 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)", - "rand_xoshiro 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", - "rayon 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)", - "rayon-core 1.4.1 (registry+https://github.com/rust-lang/crates.io-index)", - "serde 1.0.78 (registry+https://github.com/rust-lang/crates.io-index)", - "serde_derive 1.0.78 (registry+https://github.com/rust-lang/crates.io-index)", - "serde_json 1.0.27 (registry+https://github.com/rust-lang/crates.io-index)", - "tinytemplate 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", - "walkdir 2.2.7 (registry+https://github.com/rust-lang/crates.io-index)", +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", ] [[package]] name = "criterion-plot" -version = "0.3.1" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" dependencies = [ - "byteorder 1.2.4 (registry+https://github.com/rust-lang/crates.io-index)", - "cast 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)", - "itertools 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", + "cast", + "itertools", ] [[package]] name = "crossbeam-deque" -version = "0.2.0" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" dependencies = [ - "crossbeam-epoch 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", - "crossbeam-utils 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)", + "crossbeam-epoch", + "crossbeam-utils", ] [[package]] name = "crossbeam-epoch" -version = "0.3.1" +version = "0.9.18" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" dependencies = [ - "arrayvec 0.4.10 (registry+https://github.com/rust-lang/crates.io-index)", - "cfg-if 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)", - "crossbeam-utils 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)", - "lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", - "memoffset 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", - "nodrop 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)", - "scopeguard 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)", + "crossbeam-utils", ] [[package]] name 
= "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + +[[package]] +name = "crunchy" version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "cfg-if 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)", -] +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" [[package]] -name = "csv" -version = "1.0.0" +name = "either" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "csv-core 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)", - "serde 1.0.78 (registry+https://github.com/rust-lang/crates.io-index)", -] +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] -name = "csv-core" -version = "0.1.4" +name = "equivalent" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)", -] +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] -name = "digest" -version = "0.8.1" +name = "fnv" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "fuzz-bytes" +version = "0.1.0" dependencies = [ - "generic-array 0.12.3 (registry+https://github.com/rust-lang/crates.io-index)", + "afl", + "wirefilter-engine", ] [[package]] -name = "either" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" +name = "fuzz-map-keys" +version = "0.1.0" +dependencies = [ + "afl", + "wirefilter-engine", +] [[package]] -name = "failure" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" +name = "fuzz-raw-string" +version = "0.1.0" dependencies = [ - "backtrace 0.3.9 (registry+https://github.com/rust-lang/crates.io-index)", - "failure_derive 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", + "afl", + "wirefilter-engine", ] [[package]] -name = "failure_derive" -version = "0.1.2" +name = "getrandom" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ - "proc-macro2 0.4.24 (registry+https://github.com/rust-lang/crates.io-index)", - "quote 0.6.10 (registry+https://github.com/rust-lang/crates.io-index)", - "syn 0.14.5 (registry+https://github.com/rust-lang/crates.io-index)", - "synstructure 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)", + "cfg-if", + "js-sys", + "libc", + "wasi", + "wasm-bindgen", ] [[package]] -name = "fake-simd" -version = "0.1.2" +name = "gimli" +version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] -name = "fnv" -version = "1.0.6" +name = "half" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "cfg-if", + "crunchy", +] [[package]] -name = "fuchsia-cprng" -version = "0.1.1" +name = "hashbrown" +version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" [[package]] -name = "generic-array" -version = "0.12.3" +name = "hermit-abi" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" + +[[package]] +name = "home" +version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" dependencies = [ - "typenum 1.12.0 (registry+https://github.com/rust-lang/crates.io-index)", + "windows-sys 0.52.0", ] [[package]] name = "indexmap" -version = "1.0.1" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ - "serde 1.0.78 (registry+https://github.com/rust-lang/crates.io-index)", + "equivalent", + "hashbrown", ] [[package]] name = "indoc" -version = "0.3.5" +version = "2.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "indoc-impl 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", - "proc-macro-hack 0.5.15 (registry+https://github.com/rust-lang/crates.io-index)", -] +checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" [[package]] -name = "indoc-impl" -version = "0.3.5" +name = "is-terminal" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" dependencies = [ - "proc-macro-hack 0.5.15 (registry+https://github.com/rust-lang/crates.io-index)", - "proc-macro2 1.0.10 (registry+https://github.com/rust-lang/crates.io-index)", - "quote 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)", - "syn 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)", - "unindent 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)", + "hermit-abi", + "libc", + "windows-sys 0.52.0", ] [[package]] name = "itertools" -version = "0.8.0" +version = "0.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" dependencies = [ - "either 1.5.0 (registry+https://github.com/rust-lang/crates.io-index)", + "either", ] [[package]] name = "itoa" -version = "0.4.2" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "js-sys" -version = "0.3.5" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" dependencies = [ - "wasm-bindgen 0.2.28 (registry+https://github.com/rust-lang/crates.io-index)", + "wasm-bindgen", ] -[[package]] -name = "lazy_static" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" - [[package]] name = "libc" -version = "0.2.42" +version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" [[package]] name = "log" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "cfg-if 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)", -] - -[[package]] -name = "maplit" -version = "1.0.2" +version = "0.4.22" source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "memchr" -version = "2.3.3" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "libc 0.2.42 (registry+https://github.com/rust-lang/crates.io-index)", -] +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "memmem" version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a64a92489e2744ce060c349162be1c5f33c6969234104dbd99ddb5feb08b8c15" [[package]] -name = "memoffset" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" - -[[package]] -name = "nodrop" -version = "0.1.13" +name = "miniz_oxide" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +dependencies = [ + "adler2", +] [[package]] name = "num-traits" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" - -[[package]] -name = "num_cpus" -version = "1.8.0" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ - "libc 0.2.42 (registry+https://github.com/rust-lang/crates.io-index)", + "autocfg", ] [[package]] -name = "opaque-debug" -version = "0.2.3" +name = "num_enum" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e613fc340b2220f734a8595782c551f1250e969d87d3be1ae0579e8d4065179" +dependencies = [ + "num_enum_derive", +] [[package]] -name = "pest" -version = "2.1.3" +name = "num_enum_derive" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56" dependencies = [ - "ucd-trie 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)", + "proc-macro-crate", + "proc-macro2", + "quote", + "syn", ] [[package]] -name = "pest_consume" -version = "1.0.4" +name = "object" +version = "0.36.5" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" dependencies = [ - "pest 2.1.3 (registry+https://github.com/rust-lang/crates.io-index)", - "pest_consume_macros 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)", - "pest_derive 2.1.0 (registry+https://github.com/rust-lang/crates.io-index)", - "proc-macro-hack 0.5.15 (registry+https://github.com/rust-lang/crates.io-index)", + "memchr", ] [[package]] -name = "pest_consume_macros" -version = "1.0.4" +name = "once_cell" +version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "proc-macro-hack 0.5.15 (registry+https://github.com/rust-lang/crates.io-index)", - "proc-macro2 1.0.10 (registry+https://github.com/rust-lang/crates.io-index)", - "quote 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)", - "syn 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)", -] +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] -name = "pest_derive" -version = "2.1.0" +name = "oorandom" +version = "11.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "pest 2.1.3 (registry+https://github.com/rust-lang/crates.io-index)", - "pest_generator 
2.1.3 (registry+https://github.com/rust-lang/crates.io-index)", -] +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" [[package]] -name = "pest_generator" -version = "2.1.3" +name = "paste" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "pest 2.1.3 (registry+https://github.com/rust-lang/crates.io-index)", - "pest_meta 2.1.3 (registry+https://github.com/rust-lang/crates.io-index)", - "proc-macro2 1.0.10 (registry+https://github.com/rust-lang/crates.io-index)", - "quote 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)", - "syn 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)", -] +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" [[package]] -name = "pest_meta" -version = "2.1.3" +name = "plotters" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" dependencies = [ - "maplit 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)", - "pest 2.1.3 (registry+https://github.com/rust-lang/crates.io-index)", - "sha-1 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)", + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", ] [[package]] -name = "proc-macro-hack" -version = "0.5.15" +name = "plotters-backend" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" [[package]] -name = "proc-macro2" -version = "0.4.24" +name = "plotters-svg" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" dependencies = [ - "unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", + "plotters-backend", ] [[package]] -name = "proc-macro2" -version = "1.0.10" +name = "ppv-lite86" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "unicode-xid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", + "zerocopy", ] [[package]] -name = "quote" -version = "0.6.10" +name = "proc-macro-crate" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" dependencies = [ - "proc-macro2 0.4.24 (registry+https://github.com/rust-lang/crates.io-index)", + "toml_edit", ] [[package]] -name = "quote" -version = "1.0.4" +name = "proc-macro2" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" dependencies = [ - "proc-macro2 1.0.10 (registry+https://github.com/rust-lang/crates.io-index)", + "unicode-ident", ] [[package]] -name = "rand_core" -version = "0.3.1" +name = "quote" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ - "rand_core 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", + "proc-macro2", ] [[package]] -name = "rand_core" -version = "0.4.0" +name = "rand" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] [[package]] -name = "rand_os" -version = "0.1.3" +name = "rand_chacha" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ - "cloudabi 0.0.3 (registry+https://github.com/rust-lang/crates.io-index)", - "fuchsia-cprng 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", - "libc 0.2.42 (registry+https://github.com/rust-lang/crates.io-index)", - "rand_core 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", - "rdrand 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", - "winapi 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", + "ppv-lite86", + "rand_core", ] [[package]] -name = "rand_xoshiro" -version = "0.1.0" +name = "rand_core" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "byteorder 1.2.4 (registry+https://github.com/rust-lang/crates.io-index)", - "rand_core 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", + "getrandom", ] [[package]] name = "rayon" -version = "1.0.3" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" dependencies = [ - "crossbeam-deque 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", - "either 1.5.0 (registry+https://github.com/rust-lang/crates.io-index)", - "rayon-core 1.4.1 (registry+https://github.com/rust-lang/crates.io-index)", + "either", + "rayon-core", ] [[package]] name = "rayon-core" -version = "1.4.1" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" dependencies = [ - "crossbeam-deque 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", - "lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", - "libc 0.2.42 (registry+https://github.com/rust-lang/crates.io-index)", - "num_cpus 1.8.0 (registry+https://github.com/rust-lang/crates.io-index)", + "crossbeam-deque", + "crossbeam-utils", ] [[package]] -name = "rdrand" -version = "0.4.0" +name = "regex" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ - "rand_core 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", ] [[package]] -name = "redox_syscall" -version = "0.1.40" -source = "registry+https://github.com/rust-lang/crates.io-index" - -[[package]] -name = "redox_termios" -version = "0.1.1" +name = "regex-automata" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" dependencies = [ - "redox_syscall 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", + "aho-corasick", + "memchr", + "regex-syntax", ] [[package]] -name = "regex" -version = "1.3.7" +name = "regex-syntax" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "aho-corasick 0.7.10 (registry+https://github.com/rust-lang/crates.io-index)", - "memchr 2.3.3 
(registry+https://github.com/rust-lang/crates.io-index)", - "regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)", - "thread_local 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", -] +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] -name = "regex-syntax" -version = "0.6.17" +name = "rustc-demangle" +version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" [[package]] -name = "rustc-demangle" -version = "0.1.8" +name = "rustc_version" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] [[package]] name = "ryu" -version = "0.2.3" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" [[package]] name = "same-file" -version = "1.0.3" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" dependencies = [ - "winapi-util 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", + "winapi-util", ] [[package]] -name = "scopeguard" -version = "0.3.3" +name = "semver" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" + +[[package]] +name = "seq-macro" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.78" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5" dependencies = [ - "serde_derive 1.0.78 (registry+https://github.com/rust-lang/crates.io-index)", + "serde_derive", +] + +[[package]] +name = "serde-wasm-bindgen" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3b143e2833c57ab9ad3ea280d21fd34e285a42837aeb0ee301f4f41890fa00e" +dependencies = [ + "js-sys", + "serde", + "wasm-bindgen", ] [[package]] name = "serde_derive" -version = "1.0.78" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" dependencies = [ - "proc-macro2 0.4.24 (registry+https://github.com/rust-lang/crates.io-index)", - "quote 0.6.10 (registry+https://github.com/rust-lang/crates.io-index)", - "syn 0.15.22 (registry+https://github.com/rust-lang/crates.io-index)", + "proc-macro2", + "quote", + "syn", ] [[package]] name = "serde_json" -version = "1.0.27" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ - "itoa 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)", - "ryu 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", - "serde 1.0.78 (registry+https://github.com/rust-lang/crates.io-index)", + "itoa", + "memchr", + "ryu", + "serde", ] [[package]] -name = "sha-1" -version = "0.8.2" +name = "shlex" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "sliceslice" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "361b80c452f3f8cc2426bf996059740b4c78ef5a5aeb1c1852d8ac4f561b8b4c" dependencies = [ - "block-buffer 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)", - "digest 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)", - "fake-simd 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", - "opaque-debug 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", + "cfg-if", + "memchr", + "paste", + "seq-macro", ] [[package]] name = "syn" -version = "0.14.5" +version = "2.0.85" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5023162dfcd14ef8f32034d8bcd4cc5ddc61ef7a247c024a33e24e1f24d21b56" dependencies = [ - "proc-macro2 0.4.24 (registry+https://github.com/rust-lang/crates.io-index)", - "quote 0.6.10 (registry+https://github.com/rust-lang/crates.io-index)", - "unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", + "proc-macro2", + "quote", + "unicode-ident", ] [[package]] -name = "syn" -version = "0.15.22" +name = "thiserror" +version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5" dependencies = [ - "proc-macro2 0.4.24 (registry+https://github.com/rust-lang/crates.io-index)", - "quote 0.6.10 (registry+https://github.com/rust-lang/crates.io-index)", - "unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", + "thiserror-impl", ] [[package]] -name = "syn" -version = "1.0.18" +name = "thiserror-impl" +version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602" dependencies = [ - "proc-macro2 1.0.10 (registry+https://github.com/rust-lang/crates.io-index)", - "quote 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)", - "unicode-xid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", + "proc-macro2", + "quote", + "syn", ] [[package]] -name = "synstructure" -version = "0.9.0" +name = "tinytemplate" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" dependencies = [ - "proc-macro2 0.4.24 (registry+https://github.com/rust-lang/crates.io-index)", - "quote 0.6.10 (registry+https://github.com/rust-lang/crates.io-index)", - "syn 0.14.5 (registry+https://github.com/rust-lang/crates.io-index)", - "unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", + "serde", + "serde_json", ] [[package]] -name = "termion" -version = "1.5.1" +name = "toml_datetime" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" + +[[package]] +name = "toml_edit" +version = "0.22.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ - "libc 0.2.42 (registry+https://github.com/rust-lang/crates.io-index)", - "redox_syscall 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", - "redox_termios 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", + "indexmap", + "toml_datetime", + "winnow", ] [[package]] -name = "textwrap" -version = 
"0.10.0" +name = "unicode-ident" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" dependencies = [ - "unicode-width 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)", + "same-file", + "winapi-util", ] [[package]] -name = "thread_local" -version = "1.0.1" +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" dependencies = [ - "lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", + "cfg-if", + "once_cell", + "serde", + "serde_json", + "wasm-bindgen-macro", ] [[package]] -name = "tinytemplate" -version = "1.0.1" +name = "wasm-bindgen-backend" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" dependencies = [ - "serde 1.0.78 (registry+https://github.com/rust-lang/crates.io-index)", - "serde_json 1.0.27 (registry+https://github.com/rust-lang/crates.io-index)", + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", ] [[package]] -name = "typenum" -version = "1.12.0" +name = "wasm-bindgen-macro" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] [[package]] -name = "ucd-trie" -version = "0.1.3" +name = "wasm-bindgen-macro-support" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] [[package]] -name = "unicode-width" -version = "0.1.5" +name = "wasm-bindgen-shared" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" [[package]] -name = "unicode-xid" -version = "0.1.0" +name = "web-sys" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" +dependencies = [ + "js-sys", + "wasm-bindgen", +] [[package]] -name = "unicode-xid" +name = "wildcard" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36241ad0795516b55e3b60e55c7f979d4f324e4aaea4c70d56b548b9164ee4d2" +dependencies = [ + "thiserror", +] [[package]] -name = "unindent" -version = "0.1.3" +name = "winapi-util" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +dependencies = [ + "windows-sys 0.59.0", +] [[package]] -name = "walkdir" -version = "2.2.7" +name = "windows-sys" +version = "0.52.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "same-file 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)", - "winapi 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", - "winapi-util 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", + "windows-targets", ] [[package]] -name = "wasm-bindgen" -version = "0.2.28" +name = "windows-sys" +version = "0.59.0" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" dependencies = [ - "serde 1.0.78 (registry+https://github.com/rust-lang/crates.io-index)", - "serde_json 1.0.27 (registry+https://github.com/rust-lang/crates.io-index)", - "wasm-bindgen-macro 0.2.28 (registry+https://github.com/rust-lang/crates.io-index)", + "windows-targets", ] [[package]] -name = "wasm-bindgen-backend" -version = "0.2.28" +name = "windows-targets" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", - "log 0.4.3 (registry+https://github.com/rust-lang/crates.io-index)", - "proc-macro2 0.4.24 (registry+https://github.com/rust-lang/crates.io-index)", - "quote 0.6.10 (registry+https://github.com/rust-lang/crates.io-index)", - "syn 0.15.22 (registry+https://github.com/rust-lang/crates.io-index)", - "wasm-bindgen-shared 0.2.28 (registry+https://github.com/rust-lang/crates.io-index)", + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", ] [[package]] -name = "wasm-bindgen-macro" -version = "0.2.28" +name = "windows_aarch64_gnullvm" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "quote 0.6.10 (registry+https://github.com/rust-lang/crates.io-index)", - "wasm-bindgen-macro-support 0.2.28 (registry+https://github.com/rust-lang/crates.io-index)", -] +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] -name = "wasm-bindgen-macro-support" -version = "0.2.28" +name = "windows_aarch64_msvc" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "proc-macro2 0.4.24 (registry+https://github.com/rust-lang/crates.io-index)", - "quote 0.6.10 (registry+https://github.com/rust-lang/crates.io-index)", - "syn 0.15.22 (registry+https://github.com/rust-lang/crates.io-index)", - "wasm-bindgen-backend 0.2.28 (registry+https://github.com/rust-lang/crates.io-index)", - "wasm-bindgen-shared 0.2.28 (registry+https://github.com/rust-lang/crates.io-index)", -] +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] -name = "wasm-bindgen-shared" -version = "0.2.28" +name = "windows_i686_gnu" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] -name = "winapi" -version = "0.3.5" +name = "windows_i686_gnullvm" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", - 
"winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", -] +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" +name = "windows_i686_msvc" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] -name = "winapi-util" -version = "0.1.1" +name = "windows_x86_64_gnu" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "winapi 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", -] +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" +name = "windows_x86_64_gnullvm" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "winnow" +version = "0.6.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" +dependencies = [ + "memchr", +] [[package]] name = "wirefilter-engine" version = "0.7.0" dependencies = [ - "cfg-if 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)", - "cidr 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", - "criterion 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)", - "failure 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", - "fnv 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)", - "indexmap 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", - "indoc 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", - "lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", - "memmem 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", - "regex 1.3.7 (registry+https://github.com/rust-lang/crates.io-index)", - "serde 1.0.78 (registry+https://github.com/rust-lang/crates.io-index)", - "serde_json 1.0.27 (registry+https://github.com/rust-lang/crates.io-index)", + "backtrace", + "cfg-if", + "cidr", + "criterion", + "fnv", + "indoc", + "memmem", + "rand", + "regex", + "serde", + "serde_json", + "sliceslice", + "thiserror", + "wildcard", ] [[package]] name = "wirefilter-ffi" version = "0.7.0" dependencies = [ - "fnv 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)", - "indoc 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", - "libc 0.2.42 (registry+https://github.com/rust-lang/crates.io-index)", - "regex 1.3.7 (registry+https://github.com/rust-lang/crates.io-index)", - "serde_json 1.0.27 (registry+https://github.com/rust-lang/crates.io-index)", - "wirefilter-engine 0.7.0", - "wirefilter-ffi-ctests 0.1.0", + "fnv", + "indoc", + "libc", + "num_enum", + "regex", + "serde", + "serde_json", + "wirefilter-engine", + "wirefilter-ffi-ctests", ] [[package]] name = "wirefilter-ffi-ctests" version = "0.1.0" dependencies = [ - "cc 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)", - "wirefilter-ffi 0.7.0", + "cc", + "wirefilter-ffi", ] [[package]] -name = "wirefilter-parser" -version = "0.1.0" +name = 
"wirefilter-wasm" +version = "0.7.0" dependencies = [ - "cidr 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", - "indoc 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", - "pest 2.1.3 (registry+https://github.com/rust-lang/crates.io-index)", - "pest_consume 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)", - "regex 1.3.7 (registry+https://github.com/rust-lang/crates.io-index)", + "getrandom", + "js-sys", + "serde-wasm-bindgen", + "wasm-bindgen", + "wirefilter-engine", ] [[package]] -name = "wirefilter-wasm" -version = "0.7.0" +name = "xdg" +version = "2.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "213b7324336b53d2414b2db8537e56544d981803139155afa84f76eeebb7a546" + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ - "js-sys 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", - "wasm-bindgen 0.2.28 (registry+https://github.com/rust-lang/crates.io-index)", - "wirefilter-engine 0.7.0", -] - -[metadata] -"checksum aho-corasick 0.7.10 (registry+https://github.com/rust-lang/crates.io-index)" = "8716408b8bc624ed7f65d223ddb9ac2d044c0547b6fa4b0d554f3a9540496ada" -"checksum arrayvec 0.4.10 (registry+https://github.com/rust-lang/crates.io-index)" = "92c7fb76bc8826a8b33b4ee5bb07a247a81e76764ab4d55e8f73e3a4d8808c71" -"checksum atty 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)" = "9a7d5b8723950951411ee34d271d99dddcc2035a16ab25310ea2c8cfd4369652" -"checksum backtrace 0.3.9 (registry+https://github.com/rust-lang/crates.io-index)" = "89a47830402e9981c5c41223151efcced65a0510c13097c769cede7efb34782a" -"checksum backtrace-sys 0.1.23 (registry+https://github.com/rust-lang/crates.io-index)" = "bff67d0c06556c0b8e6b5f090f0eac52d950d9dfd1d35ba04e4ca3543eaf6a7e" -"checksum bitflags 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "d0c54bb8f454c567f21197eefcdbf5679d0bd99f2ddbe52e84c77061952e6789" -"checksum bitstring 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "3e54f7b7a46d7b183eb41e2d82965261fa8a1597c68b50aced268ee1fc70272d" -"checksum block-buffer 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)" = "c0940dc441f31689269e10ac70eb1002a3a1d3ad1390e030043662eb7fe4688b" -"checksum block-padding 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "fa79dedbb091f449f1f39e53edf88d5dbe95f895dae6135a8d7b881fb5af73f5" -"checksum byte-tools 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "e3b5ca7a04898ad4bcd41c90c5285445ff5b791899bb1b0abdd2a2aa791211d7" -"checksum byteorder 1.2.4 (registry+https://github.com/rust-lang/crates.io-index)" = "8389c509ec62b9fe8eca58c502a0acaf017737355615243496cde4994f8fa4f9" -"checksum cast 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "926013f2860c46252efceabb19f4a6b308197505082c609025aa6706c011d427" -"checksum cc 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)" = "2119ea4867bd2b8ed3aecab467709720b2d55b1bcfe09f772fd68066eaf15275" -"checksum cfg-if 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "082bb9b28e00d3c9d39cc03e64ce4cea0f1bb9b3fde493f0cbc008472d22bdf4" -"checksum cidr 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "2da1cf0f275bb8dc1867a7f40cdb3b746951db73a183048e6e37fa89ed81bd01" -"checksum clap 2.32.0 (registry+https://github.com/rust-lang/crates.io-index)" = 
"b957d88f4b6a63b9d70d5f454ac8011819c6efa7727858f458ab71c756ce2d3e" -"checksum cloudabi 0.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "ddfc5b9aa5d4507acaf872de71051dfd0e309860e88966e1051e462a077aac4f" -"checksum criterion 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)" = "0363053954f3e679645fc443321ca128b7b950a6fe288cf5f9335cc22ee58394" -"checksum criterion-plot 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "76f9212ddf2f4a9eb2d401635190600656a1f88a932ef53d06e7fa4c7e02fb8e" -"checksum crossbeam-deque 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "f739f8c5363aca78cfb059edf753d8f0d36908c348f3d8d1503f03d8b75d9cf3" -"checksum crossbeam-epoch 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "927121f5407de9956180ff5e936fe3cf4324279280001cd56b669d28ee7e9150" -"checksum crossbeam-utils 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "2760899e32a1d58d5abb31129f8fae5de75220bc2176e77ff7c627ae45c918d9" -"checksum csv 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "71903184af9960c555e7f3b32ff17390d20ecaaf17d4f18c4a0993f2df8a49e3" -"checksum csv-core 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "4dd8e6d86f7ba48b4276ef1317edc8cc36167546d8972feb4a2b5fec0b374105" -"checksum digest 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f3d0c8c8752312f9713efd397ff63acb9f85585afbf179282e720e7704954dd5" -"checksum either 1.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3be565ca5c557d7f59e7cfcf1844f9e3033650c929c6566f511e8005f205c1d0" -"checksum failure 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7efb22686e4a466b1ec1a15c2898f91fa9cb340452496dca654032de20ff95b9" -"checksum failure_derive 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "946d0e98a50d9831f5d589038d2ca7f8f455b1c21028c0db0e84116a12696426" -"checksum fake-simd 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "e88a8acf291dafb59c2d96e8f59828f3838bb1a70398823ade51a84de6a6deed" -"checksum fnv 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)" = "2fad85553e09a6f881f739c29f0b00b0f01357c743266d478b68951ce23285f3" -"checksum fuchsia-cprng 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "a06f77d526c1a601b7c4cdd98f54b5eaabffc14d5f2f0296febdc7f357c6d3ba" -"checksum generic-array 0.12.3 (registry+https://github.com/rust-lang/crates.io-index)" = "c68f0274ae0e023facc3c97b2e00f076be70e254bc851d972503b328db79b2ec" -"checksum indexmap 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "08173ba1e906efb6538785a8844dd496f5d34f0a2d88038e95195172fc667220" -"checksum indoc 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)" = "79255cf29f5711995ddf9ec261b4057b1deb34e66c90656c201e41376872c544" -"checksum indoc-impl 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)" = "54554010aa3d17754e484005ea0022f1c93839aabc627c2c55f3d7b47206134c" -"checksum itertools 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5b8467d9c1cebe26feb08c640139247fac215782d35371ade9a2136ed6085358" -"checksum itoa 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)" = "5adb58558dcd1d786b5f0bd15f3226ee23486e24b7b58304b60f64dc68e62606" -"checksum js-sys 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)" = "9f476e674d55cc43a57dfd1d3986c7c305e41827ead21ff6373652804f728afe" -"checksum lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = 
"bc5729f27f159ddd61f4df6228e827e86643d4d3e7c32183cb30a1c08f604a14" -"checksum libc 0.2.42 (registry+https://github.com/rust-lang/crates.io-index)" = "b685088df2b950fccadf07a7187c8ef846a959c142338a48f9dc0b94517eb5f1" -"checksum log 0.4.3 (registry+https://github.com/rust-lang/crates.io-index)" = "61bd98ae7f7b754bc53dca7d44b604f733c6bba044ea6f41bc8d89272d8161d2" -"checksum maplit 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" -"checksum memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3728d817d99e5ac407411fa471ff9800a778d88a24685968b36824eaf4bee400" -"checksum memmem 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "a64a92489e2744ce060c349162be1c5f33c6969234104dbd99ddb5feb08b8c15" -"checksum memoffset 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "0f9dc261e2b62d7a622bf416ea3c5245cdd5d9a7fcc428c0d06804dfce1775b3" -"checksum nodrop 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "2f9667ddcc6cc8a43afc9b7917599d7216aa09c463919ea32c59ed6cac8bc945" -"checksum num-traits 0.2.5 (registry+https://github.com/rust-lang/crates.io-index)" = "630de1ef5cc79d0cdd78b7e33b81f083cbfe90de0f4b2b2f07f905867c70e9fe" -"checksum num_cpus 1.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "c51a3322e4bca9d212ad9a158a02abc6934d005490c054a2778df73a70aa0a30" -"checksum opaque-debug 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2839e79665f131bdb5782e51f2c6c9599c133c6098982a54c794358bf432529c" -"checksum pest 2.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "10f4872ae94d7b90ae48754df22fd42ad52ce740b8f370b03da4835417403e53" -"checksum pest_consume 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)" = "c89271c86e5c547f8adf3b04babcd8aea6e4398db2252188c7d8e174718958c4" -"checksum pest_consume_macros 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)" = "645ac2f2b122181fb8b20f9e103f57bc0e63880b95e0a80a5f7a54ba712a2cf7" -"checksum pest_derive 2.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "833d1ae558dc601e9a60366421196a8d94bc0ac980476d0b67e1d0988d72b2d0" -"checksum pest_generator 2.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "99b8db626e31e5b81787b9783425769681b347011cc59471e33ea46d2ea0cf55" -"checksum pest_meta 2.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "54be6e404f5317079812fc8f9f5279de376d8856929e21c184ecf6bbd692a11d" -"checksum proc-macro-hack 0.5.15 (registry+https://github.com/rust-lang/crates.io-index)" = "0d659fe7c6d27f25e9d80a1a094c223f5246f6a6596453e09d7229bf42750b63" -"checksum proc-macro2 0.4.24 (registry+https://github.com/rust-lang/crates.io-index)" = "77619697826f31a02ae974457af0b29b723e5619e113e9397b8b82c6bd253f09" -"checksum proc-macro2 1.0.10 (registry+https://github.com/rust-lang/crates.io-index)" = "df246d292ff63439fea9bc8c0a270bed0e390d5ebd4db4ba15aba81111b5abe3" -"checksum quote 0.6.10 (registry+https://github.com/rust-lang/crates.io-index)" = "53fa22a1994bd0f9372d7a816207d8a2677ad0325b073f5c5332760f0fb62b5c" -"checksum quote 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)" = "4c1f4b0efa5fc5e8ceb705136bfee52cfdb6a4e3509f770b478cd6ed434232a7" -"checksum rand_core 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "7a6fdeb83b075e8266dcc8762c22776f6877a63111121f5f8c7411e5be7eed4b" -"checksum rand_core 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = 
"d0e7a549d590831370895ab7ba4ea0c1b6b011d106b5ff2da6eee112615e6dc0" -"checksum rand_os 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "7b75f676a1e053fc562eafbb47838d67c84801e38fc1ba459e8f180deabd5071" -"checksum rand_xoshiro 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "03b418169fb9c46533f326efd6eed2576699c44ca92d3052a066214a8d828929" -"checksum rayon 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "373814f27745b2686b350dd261bfd24576a6fb0e2c5919b3a2b6005f820b0473" -"checksum rayon-core 1.4.1 (registry+https://github.com/rust-lang/crates.io-index)" = "b055d1e92aba6877574d8fe604a63c8b5df60f60e5982bf7ccbb1338ea527356" -"checksum rdrand 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "678054eb77286b51581ba43620cc911abf02758c91f93f479767aed0f90458b2" -"checksum redox_syscall 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)" = "c214e91d3ecf43e9a4e41e578973adeb14b474f2bee858742d127af75a0112b1" -"checksum redox_termios 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "7e891cfe48e9100a70a3b6eb652fef28920c117d366339687bd5576160db0f76" -"checksum regex 1.3.7 (registry+https://github.com/rust-lang/crates.io-index)" = "a6020f034922e3194c711b82a627453881bc4682166cabb07134a10c26ba7692" -"checksum regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)" = "7fe5bd57d1d7414c6b5ed48563a2c855d995ff777729dcd91c369ec7fea395ae" -"checksum rustc-demangle 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "76d7ba1feafada44f2d38eed812bd2489a03c0f5abb975799251518b68848649" -"checksum ryu 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "16aa12da69951804cddf5f74d96abcc414a31b064e610dc81e37c1536082f491" -"checksum same-file 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "10f7794e2fda7f594866840e95f5c5962e886e228e68b6505885811a94dd728c" -"checksum scopeguard 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "94258f53601af11e6a49f722422f6e3425c52b06245a5cf9bc09908b174f5e27" -"checksum serde 1.0.78 (registry+https://github.com/rust-lang/crates.io-index)" = "92ec94e2754699adddbbc4f555791bd3acc2a2f5574cba16c93a4a9cf4a04415" -"checksum serde_derive 1.0.78 (registry+https://github.com/rust-lang/crates.io-index)" = "0fb622d85245add5327d4f08b2d24fd51fa5d35fe1bba19ee79a1f211e9ac0ff" -"checksum serde_json 1.0.27 (registry+https://github.com/rust-lang/crates.io-index)" = "59790990c5115d16027f00913e2e66de23a51f70422e549d2ad68c8c5f268f1c" -"checksum sha-1 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f7d94d0bede923b3cea61f3f1ff57ff8cdfd77b400fb8f9998949e0cf04163df" -"checksum syn 0.14.5 (registry+https://github.com/rust-lang/crates.io-index)" = "4bad7abdf6633f07c7046b90484f1d9dc055eca39f8c991177b1046ce61dba9a" -"checksum syn 0.15.22 (registry+https://github.com/rust-lang/crates.io-index)" = "ae8b29eb5210bc5cf63ed6149cbf9adfc82ac0be023d8735c176ee74a2db4da7" -"checksum syn 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)" = "410a7488c0a728c7ceb4ad59b9567eb4053d02e8cc7f5c0e0eeeb39518369213" -"checksum synstructure 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "85bb9b7550d063ea184027c9b8c20ac167cd36d3e06b3a40bceb9d746dc1a7b7" -"checksum termion 1.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "689a3bdfaab439fd92bc87df5c4c78417d3cbe537487274e9b0b2dce76e92096" -"checksum textwrap 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)" = 
"307686869c93e71f94da64286f9a9524c0f308a9e1c87a583de8e9c9039ad3f6" -"checksum thread_local 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "d40c6d1b69745a6ec6fb1ca717914848da4b44ae29d9b3080cbee91d72a69b14" -"checksum tinytemplate 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "7655088894274afb52b807bd3c87072daa1fedd155068b8705cabfd628956115" -"checksum typenum 1.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "373c8a200f9e67a0c95e62a4f52fbf80c23b4381c05a17845531982fa99e6b33" -"checksum ucd-trie 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "56dee185309b50d1f11bfedef0fe6d036842e3fb77413abef29f8f8d1c5d4c1c" -"checksum unicode-width 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "882386231c45df4700b275c7ff55b6f3698780a650026380e72dabe76fa46526" -"checksum unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "fc72304796d0818e357ead4e000d19c9c174ab23dc11093ac919054d20a6a7fc" -"checksum unicode-xid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "826e7639553986605ec5979c7dd957c7895e93eabed50ab2ffa7f6128a75097c" -"checksum unindent 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "834b4441326c660336850c5c0926cc20548e848967a5f57bc20c2b741c8d41f4" -"checksum walkdir 2.2.7 (registry+https://github.com/rust-lang/crates.io-index)" = "9d9d7ed3431229a144296213105a390676cc49c9b6a72bd19f3176c98e129fa1" -"checksum wasm-bindgen 0.2.28 (registry+https://github.com/rust-lang/crates.io-index)" = "dcefd48aa89f1319c61780595be0d378c5de5ccefd9d4e55c282aad4245a256c" -"checksum wasm-bindgen-backend 0.2.28 (registry+https://github.com/rust-lang/crates.io-index)" = "9b12e4c544f22adf78c7bc06fa605ac084566adc88bcbbe8857f0ceb0a12613a" -"checksum wasm-bindgen-macro 0.2.28 (registry+https://github.com/rust-lang/crates.io-index)" = "d0c71b31dc194560c8696348975abc32ed134742c19a865601db3d7eb01a1c56" -"checksum wasm-bindgen-macro-support 0.2.28 (registry+https://github.com/rust-lang/crates.io-index)" = "890018bb4a49c4c977c66e0c6a5f633d82766f6e78a57e2bc8f611a22f72dc18" -"checksum wasm-bindgen-shared 0.2.28 (registry+https://github.com/rust-lang/crates.io-index)" = "0e38e79a1881e09015aaf8db197b04d97439e06da0f50a213608e93247f5ecfc" -"checksum winapi 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)" = "773ef9dcc5f24b7d850d0ff101e542ff24c3b090a9768e03ff889fdef41f00fd" -"checksum winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" -"checksum winapi-util 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "afc5508759c5bf4285e61feb862b6083c8480aec864fa17a81fdec6f69b461ab" -"checksum winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + "byteorder", + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/Cargo.toml b/Cargo.toml index 087139a1..5e665298 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,10 +2,35 @@ members = [ "engine", "ffi", + "fuzz/bytes", + "fuzz/raw-string", + "fuzz/map-keys", "wasm", - "wirefilter-parser" ] +resolver = "2" + +[workspace.dependencies] +backtrace = "0.3" +cfg-if = "1" +cidr = { 
version = "0.2", features = ["serde"] } +criterion = "0.5" +fnv = "1.0.6" +indoc = "2" +libc = "0.2.42" +memmem = "0.1.1" +num_enum = "0.7" +rand = "0.8" +regex = { version = "1.3.6" } +serde = { version = "1.0.113", features = [ "derive" ] } +serde_json = "1.0.56" +sliceslice = "0.4.3" +thiserror = "1.0" +wildcard = "0.2.0" +wirefilter = { path = "engine", package = "wirefilter-engine" } [profile.release] -panic = "abort" +panic = "unwind" lto = true + +[profile.dev] +panic = "unwind" diff --git a/cfsetup-cargo.sh b/cfsetup-cargo.sh deleted file mode 100755 index f37020b1..00000000 --- a/cfsetup-cargo.sh +++ /dev/null @@ -1,47 +0,0 @@ -#!/bin/bash -e - -export CARGO_HOME=/var/lib/cargo -export CARGO_TARGET_DIR=$CARGO_HOME/target -export RUSTFLAGS="-D warnings" - -CMD=$1 -shift - -set -x - -case $CMD in - prebuild) - # Series of hacks to prebuild dependencies in a cached layer - # (workaround for https://github.com/rust-lang/cargo/issues/2644) - - # Create dummy sources for our library - mkdir -p {engine,ffi,ffi/tests/ctests,wasm}/src - touch {engine,ffi,ffi/tests/ctests,wasm}/src/lib.rs - mkdir engine/benches - echo 'fn main() {}' > engine/benches/bench.rs - - # Build library with Cargo.lock (including all the dependencies) - cargo build --locked --all $@ - - # Clean artifacts of the library itself but keep prebuilt deps - cargo clean --locked -p wirefilter-engine -p wirefilter-ffi -p wirefilter-wasm $@ - ;; - wasm-pack) - # Latest release of wasm-pack can't find target via CARGO_TARGET_DIR nor - # in a workspace root. - # - # This is fixed on git master, but we'd rather not build tools during - # CI build, so let's hack around this limitation for now by creating - # a temporary symlink and cleaning it up afterwards. - # - # TODO: remove following two commands on next wasm-pack upgrade. 
- ln -s $CARGO_TARGET_DIR $1/target - trap "rm $1/target" EXIT - wasm-pack build $@ - wasm-pack pack $1 - ;; - *) - # Execute any other command without special params but in same env - cargo $CMD $@ - ;; -esac diff --git a/cfsetup.yaml b/cfsetup.yaml deleted file mode 100644 index f1d2f58e..00000000 --- a/cfsetup.yaml +++ /dev/null @@ -1,71 +0,0 @@ -everything: &everything - build: &build - builddeps: - - rust - - cargo-deb - pre-cache-copy-paths: &paths - - engine/Cargo.toml - - ffi/Cargo.toml - - ffi/tests/ctests/Cargo.toml - - wasm/Cargo.toml - - Cargo.lock - - Cargo.toml - - cfsetup-cargo.sh - pre-cache: - - ./cfsetup-cargo.sh prebuild --release - post-cache: - - cd ffi - - sudo ../cfsetup-cargo.sh deb -- --release --frozen - artifacts: - - /var/lib/cargo/target/debian/*.deb - build-arm64: - <<: *build - target-arch: arm64 - build-wasm: - builddeps: &wasm-deps - - rust - - wasm-pack - - nodejs - pre-cache-copy-paths: *paths - pre-cache: - - ./cfsetup-cargo.sh prebuild --target wasm32-unknown-unknown - post-cache: - - sudo ./cfsetup-cargo.sh wasm-pack wasm --debug --mode no-install --target browser --scope cloudflare - artifacts: &wasm-artifacts - - wasm/pkg/*.tgz - publish-wasm: - builddeps: *wasm-deps - pre-cache-copy-paths: *paths - pre-cache: - # Cargo doesn't currently allow overriding profile config per target, so use RUSTFLAGS instead: - - export RUSTFLAGS="-C opt-level=z -C codegen-units=1" - - ./cfsetup-cargo.sh prebuild --release --target wasm32-unknown-unknown - post-cache: - - export RUSTFLAGS="-C opt-level=z -C codegen-units=1" - - sudo ./cfsetup-cargo.sh wasm-pack wasm --mode no-install --target browser --scope cloudflare - - echo "//registry.npmjs.org/:_authToken=$NPM_TOKEN" >> ~/.npmrc - - npm publish wasm/pkg/*.tgz - artifacts: *wasm-artifacts - test: - builddeps: - - rust - pre-cache-copy-paths: *paths - pre-cache: &test-pre-cache - - ./cfsetup-cargo.sh prebuild - post-cache: - - sudo ./cfsetup-cargo.sh test --frozen - - sudo ./cfsetup-cargo.sh clippy --all-targets --frozen - - sudo ./cfsetup-cargo.sh fmt -- --check - ci-test: - builddeps: - - rust - - cargo-to-teamcity - pre-cache-copy-paths: *paths - pre-cache: *test-pre-cache - post-cache: - - sudo ./cfsetup-cargo.sh test --frozen | cargo-to-teamcity - - sudo ./cfsetup-cargo.sh clippy --all-targets --frozen -- -D clippy - - sudo ./cfsetup-cargo.sh fmt -- --check - -stretch: *everything -jessie: *everything diff --git a/engine/Cargo.toml b/engine/Cargo.toml index 15f49de7..a5582fdd 100644 --- a/engine/Cargo.toml +++ b/engine/Cargo.toml @@ -1,14 +1,14 @@ [package] -authors = ["Ingvar Stepanyan "] +authors = [ "Ingvar Stepanyan " ] name = "wirefilter-engine" version = "0.7.0" description = "An execution engine for Wireshark-like filters" readme = "README.md" license = "MIT" repository = "https://github.com/cloudflare/wirefilter" -keywords = ["wireshark", "filter", "engine", "parser", "runtime"] -categories = ["config", "parser-implementations"] -edition = "2018" +keywords = [ "wireshark", "filter", "engine", "parser", "runtime" ] +categories = [ "config", "parser-implementations" ] +edition = "2021" [lib] name = "wirefilter" @@ -19,20 +19,22 @@ name = "bench" harness = false [dependencies] -cidr = "0.1.0" -failure = "0.1.1" -fnv = "1.0.6" -indexmap = { version = "1.0.1", features = ["serde-1"] } -regex = { version = "1.1.5", optional = true } -memmem = "0.1.1" -serde = { version = "1.0.78", features = ["derive"] } -cfg-if = "0.1.6" +backtrace.workspace = true +cfg-if.workspace = true +cidr.workspace = true 
+fnv.workspace = true +memmem.workspace = true +rand.workspace = true +regex = { workspace = true, optional = true } +serde.workspace = true +serde_json.workspace = true +sliceslice.workspace = true +thiserror.workspace = true +wildcard.workspace = true [dev-dependencies] -indoc = "0.3.0" -criterion = "0.2.11" -serde_json = "1.0.27" -lazy_static = "1.1.0" +criterion.workspace = true +indoc.workspace = true [features] -default = ["regex"] +default = [ "regex" ] diff --git a/engine/README.md b/engine/README.md index 9e969247..b6ab898b 100644 --- a/engine/README.md +++ b/engine/README.md @@ -14,7 +14,7 @@ an executable IR and, finally, executing filters against provided values. ```rust use wirefilter::{ExecutionContext, Scheme, Type}; -fn main() -> Result<(), failure::Error> { +fn main() -> Result<(), Box> { // Create a map of possible filter fields. let scheme = Scheme! { http.method: Bytes, @@ -37,20 +37,20 @@ fn main() -> Result<(), failure::Error> { // Set runtime field values to test the filter against. let mut ctx = ExecutionContext::new(&scheme); - ctx.set_field_value("http.method", "GET")?; + ctx.set_field_value(scheme.get_field("http.method").unwrap(), "GET")?; ctx.set_field_value( - "http.ua", + scheme.get_field("http.ua").unwrap(), "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:66.0) Gecko/20100101 Firefox/66.0", )?; - ctx.set_field_value("port", 443)?; + ctx.set_field_value(scheme.get_field("port").unwrap(), 443)?; // Execute the filter with given runtime values. println!("Filter matches: {:?}", filter.execute(&ctx)?); // true // Amend one of the runtime values and execute the filter again. - ctx.set_field_value("port", 8080)?; + ctx.set_field_value(scheme.get_field("port").unwrap(), 8080)?; println!("Filter matches: {:?}", filter.execute(&ctx)?); // false @@ -58,6 +58,37 @@ fn main() -> Result<(), failure::Error> { } ``` +## Fuzzing + +There are fuzz tests in the fuzz directory. + +Install afl: + +``` +cargo install afl --force +``` + +Build `bytes` fuzz test: + +``` +cd fuzz/bytes +cargo afl build +``` + +Run fuzz test (from inside `fuzz/bytes` directory): + +``` +cargo afl fuzz -i in -o out ../../target/debug/fuzz-bytes +``` + +If you see an error like: + +``` +Looks like the target binary is not instrumented! +``` + +Try deleting the compiled binary and re-building with `cargo afl build`. + ## Licensing Licensed under the MIT license. See the [LICENSE](LICENSE) file for details. 
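Note on the fuzzing instructions added to engine/README.md above: the new fuzz crates under `fuzz/` are not included in this diff, so their exact contents are unknown here. The sketch below is only an illustration, under assumptions, of how an `afl`-based target such as `fuzz/bytes` is typically wired up with the `afl::fuzz!` macro; the field names and the parse call are illustrative and not the actual target code.

```rust
// Hypothetical sketch of an afl fuzz target entry point (e.g. fuzz/bytes/src/main.rs).
// The scheme fields and the parse step are assumptions for illustration only.
use wirefilter::{Scheme, Type};

fn main() {
    // Build a small scheme once, outside the fuzzing loop.
    let scheme = Scheme! {
        http.host: Bytes,
        tcp.port: Int,
    };

    afl::fuzz!(|data: &[u8]| {
        // Feed arbitrary bytes to the filter parser; invalid input should
        // come back as a parse error, never as a panic.
        if let Ok(input) = std::str::from_utf8(data) {
            let _ = scheme.parse(input);
        }
    });
}
```

Built with `cargo afl build` and run with `cargo afl fuzz -i in -o out ../../target/debug/fuzz-bytes`, as described in the README section above.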
diff --git a/engine/benches/bench.rs b/engine/benches/bench.rs index 72812e9f..558cb33c 100644 --- a/engine/benches/bench.rs +++ b/engine/benches/bench.rs @@ -5,17 +5,15 @@ use std::alloc::System; #[global_allocator] static A: System = System; -use criterion::{ - criterion_group, criterion_main, Bencher, Benchmark, Criterion, ParameterizedBenchmark, -}; +use criterion::{criterion_group, criterion_main, Bencher, Criterion}; use std::{borrow::Cow, clone::Clone, fmt::Debug, net::IpAddr}; use wirefilter::{ - ExecutionContext, FilterAst, Function, FunctionArgKind, FunctionArgs, FunctionImpl, - FunctionParam, GetType, LhsValue, Scheme, Type, + ExecutionContext, FilterAst, FunctionArgKind, FunctionArgs, GetType, LhsValue, Scheme, + SimpleFunctionDefinition, SimpleFunctionImpl, SimpleFunctionParam, Type, }; -fn lowercase<'a>(args: FunctionArgs<'_, 'a>) -> LhsValue<'a> { - let input = args.next().unwrap(); +fn lowercase<'a>(args: FunctionArgs<'_, 'a>) -> Option> { + let input = args.next()?.ok()?; match input { LhsValue::Bytes(mut bytes) => { let make_lowercase = match bytes { @@ -25,14 +23,14 @@ fn lowercase<'a>(args: FunctionArgs<'_, 'a>) -> LhsValue<'a> { if make_lowercase { bytes.to_mut().make_ascii_lowercase(); } - LhsValue::Bytes(bytes) + Some(LhsValue::Bytes(bytes)) } - _ => panic!("Invalid type: expected Bytes, got {:?}", input), + _ => panic!("Invalid type: expected Bytes, got {input:?}"), } } -fn uppercase<'a>(args: FunctionArgs<'_, 'a>) -> LhsValue<'a> { - let input = args.next().unwrap(); +fn uppercase<'a>(args: FunctionArgs<'_, 'a>) -> Option> { + let input = args.next()?.ok()?; match input { LhsValue::Bytes(mut bytes) => { let make_uppercase = match bytes { @@ -42,15 +40,15 @@ fn uppercase<'a>(args: FunctionArgs<'_, 'a>) -> LhsValue<'a> { if make_uppercase { bytes.to_mut().make_ascii_uppercase(); } - LhsValue::Bytes(bytes) + Some(LhsValue::Bytes(bytes)) } - _ => panic!("Invalid type: expected Bytes, got {:?}", input), + _ => panic!("Invalid type: expected Bytes, got {input:?}"), } } struct FieldBench<'a, T: 'static> { field: &'static str, - functions: &'a [(&'static str, Function)], + functions: &'a [(&'static str, SimpleFunctionDefinition)], filters: &'static [&'static str], values: &'a [T], } @@ -78,66 +76,62 @@ impl<'a, T: 'static + Copy + Debug + Into>> FieldBench<'a, T> filter }; - c.bench( - "parsing", - Benchmark::new(name, { - let mut scheme = Scheme::default(); - scheme.add_field(field.to_owned(), ty).unwrap(); - for (name, function) in functions { - scheme - .add_function((*name).into(), function.clone()) - .unwrap(); - } - move |b: &mut Bencher| { - b.iter(|| scheme.parse(filter).unwrap()); - } - }), - ); - - c.bench( - "compilation", - Benchmark::new(name, { - let mut scheme = Scheme::default(); - scheme.add_field(field.to_owned(), ty).unwrap(); - for (name, function) in functions { - scheme - .add_function((*name).into(), function.clone()) + let mut group = c.benchmark_group("parsing"); + + group.bench_function(name, { + let mut scheme = Scheme::default(); + scheme.add_field(field, ty).unwrap(); + for (name, function) in functions { + scheme.add_function(name, function.clone()).unwrap(); + } + move |b: &mut Bencher| { + b.iter(|| scheme.parse(filter).unwrap()); + } + }); + + group.finish(); + + let mut group = c.benchmark_group("compilation"); + + group.bench_function(name, { + let mut scheme = Scheme::default(); + scheme.add_field(field, ty).unwrap(); + for (name, function) in functions { + scheme.add_function(name, function.clone()).unwrap(); + } + move |b: &mut 
Bencher| { + let filter = scheme.parse(filter).unwrap(); + + b.iter_with_setup(move || filter.clone(), FilterAst::compile); + } + }); + + group.finish(); + + let mut group = c.benchmark_group("execution"); + + group.bench_with_input(name, values, { + let mut scheme = Scheme::default(); + scheme.add_field(field, ty).unwrap(); + for (name, function) in functions { + scheme.add_function(name, function.clone()).unwrap(); + } + move |b: &mut Bencher, values: &[T]| { + let filter = scheme.parse(filter).unwrap(); + + let filter = filter.compile(); + + let mut exec_ctx = ExecutionContext::new(&scheme); + for value in values { + exec_ctx + .set_field_value(scheme.get_field(field).unwrap(), *value) .unwrap(); + b.iter(|| filter.execute(&exec_ctx)); } - move |b: &mut Bencher| { - let filter = scheme.parse(filter).unwrap(); - - b.iter_with_setup(move || filter.clone(), FilterAst::compile); - } - }), - ); - - c.bench( - "execution", - ParameterizedBenchmark::new( - name, - { - let mut scheme = Scheme::default(); - scheme.add_field(field.to_owned(), ty).unwrap(); - for (name, function) in functions { - scheme - .add_function((*name).into(), function.clone()) - .unwrap(); - } - move |b: &mut Bencher, value: &T| { - let filter = scheme.parse(filter).unwrap(); - - let filter = filter.compile(); - - let mut exec_ctx = ExecutionContext::new(&scheme); - exec_ctx.set_field_value(field, *value).unwrap(); + } + }); - b.iter(|| filter.execute(&exec_ctx)); - } - }, - values.iter().cloned(), - ), - ); + group.finish(); } } } @@ -210,26 +204,26 @@ fn bench_string_function_comparison(c: &mut Criterion) { functions: &[ ( "lowercase", - Function { - params: vec![FunctionParam { + SimpleFunctionDefinition { + params: vec![SimpleFunctionParam { arg_kind: FunctionArgKind::Field, val_type: Type::Bytes, }], opt_params: vec![], return_type: Type::Bytes, - implementation: FunctionImpl::new(lowercase), + implementation: SimpleFunctionImpl::new(lowercase), }, ), ( "uppercase", - Function { - params: vec![FunctionParam { + SimpleFunctionDefinition { + params: vec![SimpleFunctionParam { arg_kind: FunctionArgKind::Field, val_type: Type::Bytes, }], opt_params: vec![], return_type: Type::Bytes, - implementation: FunctionImpl::new(uppercase), + implementation: SimpleFunctionImpl::new(uppercase), }, ), ], diff --git a/engine/examples/cli.rs b/engine/examples/cli.rs index e1164327..2264bd1a 100644 --- a/engine/examples/cli.rs +++ b/engine/examples/cli.rs @@ -1,10 +1,10 @@ use std::env::args; use wirefilter::{ - Function, FunctionArgKind, FunctionArgs, FunctionImpl, FunctionOptParam, FunctionParam, - LhsValue, Scheme, Type, + FunctionArgKind, FunctionArgs, LhsValue, Scheme, SimpleFunctionDefinition, SimpleFunctionImpl, + SimpleFunctionOptParam, SimpleFunctionParam, Type, }; -fn panic_function<'a>(_: FunctionArgs<'_, 'a>) -> LhsValue<'a> { +fn panic_function<'a>(_: FunctionArgs<'_, 'a>) -> Option> { panic!(); } @@ -18,28 +18,31 @@ fn main() { str: Bytes, int: Int, bool: Bool, + str_arr: Array(Bytes), + str_map: Map(Bytes), + bool_arr: Array(Bool), }; scheme .add_function( - "panic".into(), - Function { - params: vec![FunctionParam { + "panic", + SimpleFunctionDefinition { + params: vec![SimpleFunctionParam { arg_kind: FunctionArgKind::Field, val_type: Type::Bytes, }], - opt_params: vec![FunctionOptParam { + opt_params: vec![SimpleFunctionOptParam { arg_kind: FunctionArgKind::Literal, default_value: "".into(), }], return_type: Type::Bytes, - implementation: FunctionImpl::new(panic_function), + implementation: 
SimpleFunctionImpl::new(panic_function), }, ) .unwrap(); match scheme.parse(&filter) { - Ok(res) => println!("{:#?}", res), - Err(err) => println!("{}", err), + Ok(res) => println!("{res:#?}"), + Err(err) => println!("{err}"), } } diff --git a/engine/src/ast/combined_expr.rs b/engine/src/ast/combined_expr.rs deleted file mode 100644 index ae6b9802..00000000 --- a/engine/src/ast/combined_expr.rs +++ /dev/null @@ -1,327 +0,0 @@ -use super::{simple_expr::SimpleExpr, Expr}; -use crate::{ - filter::CompiledExpr, - lex::{skip_space, Lex, LexResult, LexWith}, - scheme::{Field, Scheme}, -}; -use serde::Serialize; - -lex_enum!(#[derive(PartialOrd, Ord)] CombiningOp { - "or" | "||" => Or, - "xor" | "^^" => Xor, - "and" | "&&" => And, -}); - -#[derive(Debug, PartialEq, Eq, Clone, Serialize)] -#[serde(untagged)] -pub enum CombinedExpr<'s> { - Simple(SimpleExpr<'s>), - Combining { - op: CombiningOp, - items: Vec>, - }, -} - -impl<'s> CombinedExpr<'s> { - fn lex_combining_op(input: &str) -> (Option, &str) { - match CombiningOp::lex(skip_space(input)) { - Ok((op, input)) => (Some(op), skip_space(input)), - Err(_) => (None, input), - } - } - - fn lex_more_with_precedence<'i>( - self, - scheme: &'s Scheme, - min_prec: Option, - mut lookahead: (Option, &'i str), - ) -> LexResult<'i, Self> { - let mut lhs = self; - - while let Some(op) = lookahead.0 { - let mut rhs = SimpleExpr::lex_with(lookahead.1, scheme) - .map(|(op, input)| (CombinedExpr::Simple(op), input))?; - - loop { - lookahead = Self::lex_combining_op(rhs.1); - if lookahead.0 <= Some(op) { - break; - } - rhs = rhs - .0 - .lex_more_with_precedence(scheme, lookahead.0, lookahead)?; - } - - match lhs { - CombinedExpr::Combining { - op: lhs_op, - ref mut items, - } if lhs_op == op => { - items.push(rhs.0); - } - _ => { - lhs = CombinedExpr::Combining { - op, - items: vec![lhs, rhs.0], - }; - } - } - - if lookahead.0 < min_prec { - // pretend we haven't seen an operator if its precedence is - // outside of our limits - lookahead = (None, rhs.1); - } - } - - Ok((lhs, lookahead.1)) - } -} - -impl<'i, 's> LexWith<'i, &'s Scheme> for CombinedExpr<'s> { - fn lex_with(input: &'i str, scheme: &'s Scheme) -> LexResult<'i, Self> { - let (lhs, input) = SimpleExpr::lex_with(input, scheme)?; - let lookahead = Self::lex_combining_op(input); - CombinedExpr::Simple(lhs).lex_more_with_precedence(scheme, None, lookahead) - } -} - -impl<'s> Expr<'s> for CombinedExpr<'s> { - fn uses(&self, field: Field<'s>) -> bool { - match self { - CombinedExpr::Simple(op) => op.uses(field), - CombinedExpr::Combining { items, .. } => items.iter().any(|op| op.uses(field)), - } - } - - fn compile(self) -> CompiledExpr<'s> { - match self { - CombinedExpr::Simple(op) => op.compile(), - CombinedExpr::Combining { op, items } => { - let items = items - .into_iter() - .map(Expr::compile) - .collect::>() - .into_boxed_slice(); - - match op { - CombiningOp::And => { - CompiledExpr::new(move |ctx| items.iter().all(|item| item.execute(ctx))) - } - CombiningOp::Or => { - CompiledExpr::new(move |ctx| items.iter().any(|item| item.execute(ctx))) - } - CombiningOp::Xor => CompiledExpr::new(move |ctx| { - items - .iter() - .fold(false, |acc, item| acc ^ item.execute(ctx)) - }), - } - } - } - } -} - -#[test] -fn test() { - use super::field_expr::FieldExpr; - use crate::{execution_context::ExecutionContext, lex::complete}; - - let scheme = &Scheme! 
{ - t: Bool, - f: Bool, - }; - - let ctx = &mut ExecutionContext::new(scheme); - - let t_expr = CombinedExpr::Simple(SimpleExpr::Field( - complete(FieldExpr::lex_with("t", scheme)).unwrap(), - )); - - let t_expr = || t_expr.clone(); - - let f_expr = CombinedExpr::Simple(SimpleExpr::Field( - complete(FieldExpr::lex_with("f", scheme)).unwrap(), - )); - - let f_expr = || f_expr.clone(); - - assert_ok!(CombinedExpr::lex_with("t", scheme), t_expr()); - - ctx.set_field_value("t", true).unwrap(); - ctx.set_field_value("f", false).unwrap(); - - { - let expr = assert_ok!( - CombinedExpr::lex_with("t and t", scheme), - CombinedExpr::Combining { - op: CombiningOp::And, - items: vec![t_expr(), t_expr()], - } - ); - - let expr = expr.compile(); - - assert_eq!(expr.execute(ctx), true); - } - - { - let expr = assert_ok!( - CombinedExpr::lex_with("t and f", scheme), - CombinedExpr::Combining { - op: CombiningOp::And, - items: vec![t_expr(), f_expr()], - } - ); - - assert_json!( - expr, - { - "op": "And", - "items": [ - { - "lhs": "t", - "op": "IsTrue" - }, - { - "lhs": "f", - "op": "IsTrue" - } - ] - } - ); - - let expr = expr.compile(); - - assert_eq!(expr.execute(ctx), false); - } - - { - let expr = assert_ok!( - CombinedExpr::lex_with("t or f", scheme), - CombinedExpr::Combining { - op: CombiningOp::Or, - items: vec![t_expr(), f_expr()], - } - ); - - assert_json!( - expr, - { - "op": "Or", - "items": [ - { - "lhs": "t", - "op": "IsTrue" - }, - { - "lhs": "f", - "op": "IsTrue" - } - ] - } - ); - - let expr = expr.compile(); - - assert_eq!(expr.execute(ctx), true); - } - - { - let expr = assert_ok!( - CombinedExpr::lex_with("f or f", scheme), - CombinedExpr::Combining { - op: CombiningOp::Or, - items: vec![f_expr(), f_expr()], - } - ); - - let expr = expr.compile(); - - assert_eq!(expr.execute(ctx), false); - } - - { - let expr = assert_ok!( - CombinedExpr::lex_with("t xor f", scheme), - CombinedExpr::Combining { - op: CombiningOp::Xor, - items: vec![t_expr(), f_expr()], - } - ); - - assert_json!( - expr, - { - "op": "Xor", - "items": [ - { - "lhs": "t", - "op": "IsTrue" - }, - { - "lhs": "f", - "op": "IsTrue" - } - ] - } - ); - - let expr = expr.compile(); - - assert_eq!(expr.execute(ctx), true); - } - - { - let expr = assert_ok!( - CombinedExpr::lex_with("f xor f", scheme), - CombinedExpr::Combining { - op: CombiningOp::Xor, - items: vec![f_expr(), f_expr()], - } - ); - - let expr = expr.compile(); - - assert_eq!(expr.execute(ctx), false); - } - - { - let expr = assert_ok!( - CombinedExpr::lex_with("f xor t", scheme), - CombinedExpr::Combining { - op: CombiningOp::Xor, - items: vec![f_expr(), t_expr()], - } - ); - - let expr = expr.compile(); - - assert_eq!(expr.execute(ctx), true); - } - - assert_ok!( - CombinedExpr::lex_with("t or t && t and t or t ^^ t and t || t", scheme), - CombinedExpr::Combining { - op: CombiningOp::Or, - items: vec![ - t_expr(), - CombinedExpr::Combining { - op: CombiningOp::And, - items: vec![t_expr(), t_expr(), t_expr()], - }, - CombinedExpr::Combining { - op: CombiningOp::Xor, - items: vec![ - t_expr(), - CombinedExpr::Combining { - op: CombiningOp::And, - items: vec![t_expr(), t_expr()], - }, - ], - }, - t_expr(), - ], - } - ); -} diff --git a/engine/src/ast/field_expr.rs b/engine/src/ast/field_expr.rs index ed12f9e3..f9804797 100644 --- a/engine/src/ast/field_expr.rs +++ b/engine/src/ast/field_expr.rs @@ -1,35 +1,73 @@ -// use crate::filter::CompiledExpr; -use super::{function_expr::FunctionCallExpr, Expr}; +use super::{ + function_expr::FunctionCallExpr, + 
parse::FilterParser, + visitor::{Visitor, VisitorMut}, + Expr, +}; use crate::{ - filter::CompiledExpr, - heap_searcher::HeapSearcher, - lex::{skip_space, span, Lex, LexErrorKind, LexResult, LexWith}, + ast::index_expr::IndexExpr, + compiler::Compiler, + filter::{CompiledExpr, CompiledValueExpr}, + lex::{expect, skip_space, span, Lex, LexErrorKind, LexResult, LexWith}, range_set::RangeSet, - rhs_types::{Bytes, ExplicitIpRange, Regex}, - scheme::{Field, Scheme}, + rhs_types::{Bytes, ExplicitIpRange, ListName, Regex, Wildcard}, + scheme::{Field, Identifier, List}, + searcher::{EmptySearcher, TwoWaySearcher}, strict_partial_ord::StrictPartialOrd, types::{GetType, LhsValue, RhsValue, RhsValues, Type}, }; -use fnv::FnvBuildHasher; -use indexmap::IndexSet; -use memmem::Searcher; use serde::{Serialize, Serializer}; +use sliceslice::MemchrSearcher; +use std::collections::BTreeSet; +#[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "wasm32"))] +use std::sync::LazyLock; use std::{cmp::Ordering, net::IpAddr}; const LESS: u8 = 0b001; const GREATER: u8 = 0b010; const EQUAL: u8 = 0b100; -lex_enum!(#[repr(u8)] OrderingOp { - "eq" | "==" => Equal = EQUAL, - "ne" | "!=" => NotEqual = LESS | GREATER, - "ge" | ">=" => GreaterThanEqual = GREATER | EQUAL, - "le" | "<=" => LessThanEqual = LESS | EQUAL, - "gt" | ">" => GreaterThan = GREATER, - "lt" | "<" => LessThan = LESS, +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +static USE_AVX2: LazyLock = LazyLock::new(|| { + use std::env; + + const NO_VALUES: &[&str] = &["0", "no", "false"]; + + let use_avx2 = env::var("WIREFILTER_USE_AVX2").unwrap_or_default(); + is_x86_feature_detected!("avx2") && !NO_VALUES.contains(&use_avx2.as_str()) +}); + +#[cfg(target_arch = "wasm32")] +static USE_SIMD128: LazyLock = LazyLock::new(|| { + use std::env; + + const NO_VALUES: &[&str] = &["0", "no", "false"]; + + let use_simd128 = env::var("WIREFILTER_USE_SIMD128").unwrap_or_default(); + !NO_VALUES.contains(&use_simd128.as_str()) }); +lex_enum!( + /// OrderingOp is an operator for an ordering [`ComparisonOpExpr`]. + #[repr(u8)] OrderingOp { + /// `eq` / `==` operator + "eq" | "==" => Equal = EQUAL, + /// `ne` / `!=` operator + "ne" | "!=" => NotEqual = LESS | GREATER, + /// `ge` / `>=` operator + "ge" | ">=" => GreaterThanEqual = GREATER | EQUAL, + /// `le` / `<=` operator + "le" | "<=" => LessThanEqual = LESS | EQUAL, + /// `gt` / `>` operator + "gt" | ">" => GreaterThan = GREATER, + /// `lt` / `<` operator + "lt" | "<" => LessThan = LESS, + } +); + impl OrderingOp { + /// Determines whether the operator matches a given ordering. + #[inline] pub fn matches(self, ordering: Ordering) -> bool { let mask = self as u8; let flag = match ordering { @@ -40,6 +78,9 @@ impl OrderingOp { mask & flag != 0 } + /// Same as `matches` but accepts an optional ordering for incomparable + /// types. 
+ #[inline] pub fn matches_opt(self, ordering: Option) -> bool { match ordering { Some(ordering) => self.matches(ordering), @@ -49,13 +90,19 @@ impl OrderingOp { } } -lex_enum!(IntOp { - "&" | "bitwise_and" => BitwiseAnd, -}); +lex_enum!( + /// An integer operator + IntOp { + /// `bitwise_and` / `&` operator + "&" | "bitwise_and" => BitwiseAnd, + } +); lex_enum!(BytesOp { "contains" => Contains, "~" | "matches" => Matches, + "wildcard" => Wildcard, + "strict wildcard" => StrictWildcard, }); lex_enum!(ComparisonOp { @@ -65,40 +112,80 @@ lex_enum!(ComparisonOp { BytesOp => Bytes, }); -#[derive(Debug, PartialEq, Eq, Clone, Serialize)] +/// Operator and right-hand side expression of a +/// comparison expression. +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize)] #[serde(untagged)] -enum FieldOp { +pub enum ComparisonOpExpr<'s> { + /// Boolean field verification #[serde(serialize_with = "serialize_is_true")] IsTrue, + /// Ordering comparison Ordering { + /// Ordering comparison operator: + /// * "eq" | "==" + /// * "ne" | "!=" + /// * "ge" | ">=" + /// * "le" | "<=" + /// * "gt" | ">" + /// * "lt" | "<" op: OrderingOp, + /// Right-hand side literal rhs: RhsValue, }, + /// Integer comparison Int { + /// Integer comparison operator: + /// * "&" | "bitwise_and" op: IntOp, - rhs: i32, + /// Right-hand side integer value + rhs: i64, }, + /// "contains" comparison #[serde(serialize_with = "serialize_contains")] Contains(Bytes), + /// "matches / ~" comparison #[serde(serialize_with = "serialize_matches")] Matches(Regex), + /// "wildcard" comparison + #[serde(serialize_with = "serialize_wildcard")] + Wildcard(Wildcard), + + /// "strict wildcard" comparison + #[serde(serialize_with = "serialize_strict_wildcard")] + StrictWildcard(Wildcard), + + /// "in {...}" comparison #[serde(serialize_with = "serialize_one_of")] OneOf(RhsValues), + + /// "contains {...}" comparison + #[serde(serialize_with = "serialize_contains_one_of")] + ContainsOneOf(Vec), + + /// "in $..." 
comparison + #[serde(serialize_with = "serialize_list")] + InList { + /// `List` from the `Scheme` + list: List<'s>, + /// List name + name: ListName, + }, } -fn serialize_op_rhs( - op: &'static str, - rhs: &T, - ser: S, -) -> Result { +fn serialize_op_rhs(op: &'static str, rhs: &T, ser: S) -> Result +where + T: Serialize + ?Sized, + S: Serializer, +{ use serde::ser::SerializeStruct; - let mut out = ser.serialize_struct("FieldOp", 2)?; + let mut out = ser.serialize_struct("ComparisonOpExpr", 2)?; out.serialize_field("op", op)?; out.serialize_field("rhs", rhs)?; out.end() @@ -107,7 +194,7 @@ fn serialize_op_rhs( fn serialize_is_true(ser: S) -> Result { use serde::ser::SerializeStruct; - let mut out = ser.serialize_struct("FieldOp", 1)?; + let mut out = ser.serialize_struct("ComparisonOpExpr", 1)?; out.serialize_field("op", "IsTrue")?; out.end() } @@ -120,108 +207,181 @@ fn serialize_matches(rhs: &Regex, ser: S) -> Result(rhs: &Wildcard, ser: S) -> Result { + serialize_op_rhs("Wildcard", rhs, ser) +} + +fn serialize_strict_wildcard( + rhs: &Wildcard, + ser: S, +) -> Result { + serialize_op_rhs("Strict Wildcard", rhs, ser) +} + fn serialize_one_of(rhs: &RhsValues, ser: S) -> Result { serialize_op_rhs("OneOf", rhs, ser) } -#[derive(Debug, PartialEq, Eq, Clone, Serialize)] +fn serialize_contains_one_of(rhs: &[Bytes], ser: S) -> Result { + serialize_op_rhs("ContainsOneOf", rhs, ser) +} + +fn serialize_list(_: &List<'_>, name: &ListName, ser: S) -> Result { + serialize_op_rhs("InList", name, ser) +} + +/// Represents either the access to a field's value or +/// a function call. +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize)] #[serde(untagged)] -pub(crate) enum LhsFieldExpr<'s> { +pub enum IdentifierExpr<'s> { + /// Field access Field(Field<'s>), + /// Function call FunctionCallExpr(FunctionCallExpr<'s>), } -impl<'s> LhsFieldExpr<'s> { - pub fn uses(&self, field: Field<'s>) -> bool { - match self { - LhsFieldExpr::Field(f) => *f == field, - LhsFieldExpr::FunctionCallExpr(call) => call.uses(field), - } - } - - fn compile_with(self, func: F) -> CompiledExpr<'s> - where - F: Fn(LhsValue<'_>) -> bool + Send + Sync, - { +impl<'s> IdentifierExpr<'s> { + pub(crate) fn compile_with_compiler + 's>( + self, + compiler: &mut C, + ) -> CompiledValueExpr<'s, C::U> { match self { - LhsFieldExpr::FunctionCallExpr(call) => { - CompiledExpr::new(move |ctx| func(call.execute(ctx))) - } - LhsFieldExpr::Field(f) => { - CompiledExpr::new(move |ctx| func(ctx.get_field_value_unchecked(f))) + IdentifierExpr::Field(f) => { + CompiledValueExpr::new(move |ctx| Ok(ctx.get_field_value_unchecked(f).as_ref())) } + IdentifierExpr::FunctionCallExpr(call) => compiler.compile_function_call_expr(call), } } } -impl<'i, 's> LexWith<'i, &'s Scheme> for LhsFieldExpr<'s> { - fn lex_with(input: &'i str, scheme: &'s Scheme) -> LexResult<'i, Self> { - Ok(match FunctionCallExpr::lex_with(input, scheme) { - Ok((call, input)) => (LhsFieldExpr::FunctionCallExpr(call), input), - // Fallback to field - Err(_) => { - let (field, input) = Field::lex_with(input, scheme)?; - (LhsFieldExpr::Field(field), input) +impl<'i, 's> LexWith<'i, &FilterParser<'s>> for IdentifierExpr<'s> { + fn lex_with(input: &'i str, parser: &FilterParser<'s>) -> LexResult<'i, Self> { + let (item, input) = Identifier::lex_with(input, parser.scheme)?; + match item { + Identifier::Field(field) => Ok((IdentifierExpr::Field(field), input)), + Identifier::Function(function) => { + FunctionCallExpr::lex_with_function(input, parser, function) + .map(|(call, input)| 
(IdentifierExpr::FunctionCallExpr(call), input)) } - }) + } } } -impl<'s> GetType for LhsFieldExpr<'s> { +impl<'s> GetType for IdentifierExpr<'s> { fn get_type(&self) -> Type { match self { - LhsFieldExpr::Field(field) => field.get_type(), - LhsFieldExpr::FunctionCallExpr(call) => call.function.return_type, + IdentifierExpr::Field(field) => field.get_type(), + IdentifierExpr::FunctionCallExpr(call) => call.get_type(), } } } -#[derive(Debug, PartialEq, Eq, Clone, Serialize)] -pub struct FieldExpr<'s> { - lhs: LhsFieldExpr<'s>, +/// Comparison expression +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize)] +pub struct ComparisonExpr<'s> { + /// Left-hand side of the comparison expression + pub lhs: IndexExpr<'s>, + /// Operator + right-hand side of the comparison expression #[serde(flatten)] - op: FieldOp, + pub op: ComparisonOpExpr<'s>, +} + +impl<'s> GetType for ComparisonExpr<'s> { + fn get_type(&self) -> Type { + if self.lhs.map_each_count() > 0 { + Type::Array(Type::Bool.into()) + } else if self.op == ComparisonOpExpr::IsTrue { + // Bool or Array(Bool) + self.lhs.get_type() + } else { + Type::Bool + } + } } -impl<'i, 's> LexWith<'i, &'s Scheme> for FieldExpr<'s> { - fn lex_with(input: &'i str, scheme: &'s Scheme) -> LexResult<'i, Self> { - let initial_input = input; +impl<'i, 's> LexWith<'i, &FilterParser<'s>> for ComparisonExpr<'s> { + fn lex_with(input: &'i str, parser: &FilterParser<'s>) -> LexResult<'i, Self> { + let (lhs, input) = IndexExpr::lex_with(input, parser)?; - let (lhs, input) = LhsFieldExpr::lex_with(input, scheme)?; + Self::lex_with_lhs(input, parser, lhs) + } +} +impl<'s> ComparisonExpr<'s> { + pub(crate) fn lex_with_lhs<'i>( + input: &'i str, + parser: &FilterParser<'s>, + lhs: IndexExpr<'s>, + ) -> LexResult<'i, Self> { let lhs_type = lhs.get_type(); let (op, input) = if lhs_type == Type::Bool { - (FieldOp::IsTrue, input) + (ComparisonOpExpr::IsTrue, input) + } else if lhs_type.next() == Some(Type::Bool) { + // Invalid because this would produce an Array(Array(Bool)) + // which cannot be coerced to an Array(Bool) + if lhs.map_each_count() > 0 { + return Err(( + LexErrorKind::UnsupportedOp { + lhs_type: Type::Array(Type::Array(Type::Bool.into()).into()), + }, + span(input, input), + )); + } else { + (ComparisonOpExpr::IsTrue, input) + } } else { - let (op, input) = ComparisonOp::lex(skip_space(input))?; + let initial_input = skip_space(input); + let (op, input) = ComparisonOp::lex(initial_input)?; let input_after_op = input; let input = skip_space(input); - match (lhs_type, op) { - (_, ComparisonOp::In) => { - let (rhs, input) = RhsValues::lex_with(input, lhs_type)?; - (FieldOp::OneOf(rhs), input) + match (&lhs_type, op) { + (Type::Ip, ComparisonOp::In) + | (Type::Bytes, ComparisonOp::In) + | (Type::Int, ComparisonOp::In) => { + if expect(input, "$").is_ok() { + let (name, input) = ListName::lex(input)?; + let list = parser.scheme.get_list(&lhs_type).ok_or(( + LexErrorKind::UnsupportedOp { lhs_type }, + span(initial_input, input), + ))?; + (ComparisonOpExpr::InList { name, list }, input) + } else { + let (rhs, input) = RhsValues::lex_with(input, lhs_type)?; + (ComparisonOpExpr::OneOf(rhs), input) + } } - (_, ComparisonOp::Ordering(op)) => { + (Type::Ip, ComparisonOp::Ordering(op)) + | (Type::Bytes, ComparisonOp::Ordering(op)) + | (Type::Int, ComparisonOp::Ordering(op)) => { let (rhs, input) = RhsValue::lex_with(input, lhs_type)?; - (FieldOp::Ordering { op, rhs }, input) + (ComparisonOpExpr::Ordering { op, rhs }, input) } (Type::Int, ComparisonOp::Int(op)) => { - 
let (rhs, input) = i32::lex(input)?; - (FieldOp::Int { op, rhs }, input) + let (rhs, input) = i64::lex(input)?; + (ComparisonOpExpr::Int { op, rhs }, input) } (Type::Bytes, ComparisonOp::Bytes(op)) => match op { BytesOp::Contains => { let (bytes, input) = Bytes::lex(input)?; - (FieldOp::Contains(bytes), input) + (ComparisonOpExpr::Contains(bytes), input) } BytesOp::Matches => { - let (regex, input) = Regex::lex(input)?; - (FieldOp::Matches(regex), input) + let (regex, input) = Regex::lex_with(input, parser)?; + (ComparisonOpExpr::Matches(regex), input) + } + BytesOp::Wildcard => { + let (wildcard, input) = Wildcard::lex_with(input, parser)?; + (ComparisonOpExpr::Wildcard(wildcard), input) + } + BytesOp::StrictWildcard => { + let (wildcard, input) = Wildcard::lex_with(input, parser)?; + (ComparisonOpExpr::StrictWildcard(wildcard), input) } }, _ => { @@ -233,16 +393,35 @@ impl<'i, 's> LexWith<'i, &'s Scheme> for FieldExpr<'s> { } }; - Ok((FieldExpr { lhs, op }, input)) + Ok((ComparisonExpr { lhs, op }, input)) + } + + /// Retrieves the associated left hand side expression. + pub fn lhs_expr(&self) -> &IndexExpr<'s> { + &self.lhs + } + + /// Retrieves the operator applied to the left hand side expression. + pub fn operator(&self) -> &ComparisonOpExpr<'s> { + &self.op } } -impl<'s> Expr<'s> for FieldExpr<'s> { - fn uses(&self, field: Field<'s>) -> bool { - self.lhs.uses(field) +impl<'s> Expr<'s> for ComparisonExpr<'s> { + #[inline] + fn walk<'a, V: Visitor<'s, 'a>>(&'a self, visitor: &mut V) { + visitor.visit_index_expr(&self.lhs) + } + + #[inline] + fn walk_mut<'a, V: VisitorMut<'s, 'a>>(&'a mut self, visitor: &mut V) { + visitor.visit_index_expr(&mut self.lhs) } - fn compile(self) -> CompiledExpr<'s> { + fn compile_with_compiler + 's>( + self, + compiler: &mut C, + ) -> CompiledExpr<'s, C::U> { let lhs = self.lhs; macro_rules! cast_value { @@ -255,23 +434,178 @@ impl<'s> Expr<'s> for FieldExpr<'s> { } match self.op { - FieldOp::IsTrue => lhs.compile_with(move |x| cast_value!(x, Bool)), - FieldOp::Ordering { op, rhs } => { - lhs.compile_with(move |x| op.matches_opt(x.strict_partial_cmp(&rhs))) + ComparisonOpExpr::IsTrue => { + if lhs.get_type() == Type::Bool { + lhs.compile_with(compiler, false, move |x, _ctx| *cast_value!(x, Bool)) + } else if lhs.get_type().next() == Some(Type::Bool) { + // MapEach is impossible in this case, thus call `compile_vec_with` directly + // to coerce LhsValue to Vec + CompiledExpr::Vec( + lhs.compile_vec_with(compiler, move |x, _ctx| *cast_value!(x, Bool)), + ) + } else { + unreachable!() + } + } + ComparisonOpExpr::Ordering { op, rhs } => { + macro_rules! 
gen_ordering { + ($op:tt, $def:literal) => { + match rhs { + RhsValue::Bytes(bytes) => lhs.compile_with(compiler, $def, move |x, _ctx| { + cast_value!(x, Bytes).as_ref() $op bytes.as_ref() + }), + RhsValue::Int(int) => { + lhs.compile_with(compiler, $def, move |x, _ctx| *cast_value!(x, Int) $op int) + } + RhsValue::Ip(ip) => lhs.compile_with(compiler, $def, move |x, _ctx| { + op.matches_opt(cast_value!(x, Ip).strict_partial_cmp(&ip)) + }), + RhsValue::Bool(_) | RhsValue::Array(_) | RhsValue::Map(_) => unreachable!(), + } + }; + } + + match op { + OrderingOp::NotEqual => gen_ordering!(!=, true), + OrderingOp::Equal => gen_ordering!(==, false), + OrderingOp::GreaterThanEqual => gen_ordering!(>=, false), + OrderingOp::LessThanEqual => gen_ordering!(<=, false), + OrderingOp::GreaterThan => gen_ordering!(>, false), + OrderingOp::LessThan => gen_ordering!(<, false), + } } - FieldOp::Int { + ComparisonOpExpr::Int { op: IntOp::BitwiseAnd, rhs, - } => lhs.compile_with(move |x| cast_value!(x, Int) & rhs != 0), - FieldOp::Contains(bytes) => { - let searcher = HeapSearcher::new(bytes); + } => lhs.compile_with(compiler, false, move |x, _ctx| { + cast_value!(x, Int) & rhs != 0 + }), + ComparisonOpExpr::Contains(bytes) => { + macro_rules! search { + ($searcher:expr) => {{ + let searcher = $searcher; + lhs.compile_with(compiler, false, move |x, _ctx| { + searcher.search_in(cast_value!(x, Bytes).as_ref()) + }) + }}; + } + + let bytes: Box<[u8]> = bytes.into(); + + if bytes.is_empty() { + return search!(EmptySearcher); + } + + if let [byte] = *bytes { + return search!(MemchrSearcher::new(byte)); + } - lhs.compile_with(move |x| searcher.search_in(&cast_value!(x, Bytes)).is_some()) + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + if *USE_AVX2 { + use rand::{thread_rng, Rng}; + use sliceslice::x86::*; + + fn slice_to_array(slice: &[u8]) -> [u8; N] { + let mut array = [0u8; N]; + array.copy_from_slice(slice); + array + } + + let position = thread_rng().gen_range(1..bytes.len()); + return unsafe { + match bytes.len() { + 2 => search!(Avx2Searcher::with_position( + slice_to_array::<2>(&bytes), + position + )), + 3 => search!(Avx2Searcher::with_position( + slice_to_array::<3>(&bytes), + position + )), + 4 => search!(Avx2Searcher::with_position( + slice_to_array::<4>(&bytes), + position + )), + 5 => search!(Avx2Searcher::with_position( + slice_to_array::<5>(&bytes), + position + )), + 6 => search!(Avx2Searcher::with_position( + slice_to_array::<6>(&bytes), + position + )), + 7 => search!(Avx2Searcher::with_position( + slice_to_array::<7>(&bytes), + position + )), + 8 => search!(Avx2Searcher::with_position( + slice_to_array::<8>(&bytes), + position + )), + 9 => search!(Avx2Searcher::with_position( + slice_to_array::<9>(&bytes), + position + )), + 10 => search!(Avx2Searcher::with_position( + slice_to_array::<10>(&bytes), + position + )), + 11 => search!(Avx2Searcher::with_position( + slice_to_array::<11>(&bytes), + position + )), + 12 => search!(Avx2Searcher::with_position( + slice_to_array::<12>(&bytes), + position + )), + 13 => search!(Avx2Searcher::with_position( + slice_to_array::<13>(&bytes), + position + )), + 14 => search!(Avx2Searcher::with_position( + slice_to_array::<14>(&bytes), + position + )), + 15 => search!(Avx2Searcher::with_position( + slice_to_array::<15>(&bytes), + position + )), + 16 => search!(Avx2Searcher::with_position( + slice_to_array::<16>(&bytes), + position + )), + _ => search!(Avx2Searcher::with_position(bytes, position)), + } + }; + } + #[cfg(target_arch = "wasm32")] + if 
*USE_SIMD128 { + use rand::{thread_rng, Rng}; + use sliceslice::wasm32::*; + + let position = thread_rng().gen_range(1..bytes.len()); + + return unsafe { search!(Wasm32Searcher::with_position(bytes, position)) }; + } + + search!(TwoWaySearcher::new(bytes)) + } + ComparisonOpExpr::Matches(regex) => { + lhs.compile_with(compiler, false, move |x, _ctx| { + regex.is_match(cast_value!(x, Bytes)) + }) } - FieldOp::Matches(regex) => { - lhs.compile_with(move |x| regex.is_match(&cast_value!(x, Bytes))) + ComparisonOpExpr::Wildcard(wildcard) => { + lhs.compile_with(compiler, false, move |x, _ctx| { + wildcard.is_match(cast_value!(x, Bytes)) + }) } - FieldOp::OneOf(values) => match values { + ComparisonOpExpr::StrictWildcard(wildcard) => { + lhs.compile_with(compiler, false, move |x, _ctx| { + wildcard.is_match(cast_value!(x, Bytes)) + }) + } + ComparisonOpExpr::OneOf(values) => match values { RhsValues::Ip(ranges) => { let mut v4 = Vec::new(); let mut v6 = Vec::new(); @@ -284,142 +618,314 @@ impl<'s> Expr<'s> for FieldExpr<'s> { let v4 = RangeSet::from(v4); let v6 = RangeSet::from(v6); - lhs.compile_with(move |x| match cast_value!(x, Ip) { - IpAddr::V4(addr) => v4.contains(&addr), - IpAddr::V6(addr) => v6.contains(&addr), + lhs.compile_with(compiler, false, move |x, _ctx| match cast_value!(x, Ip) { + IpAddr::V4(addr) => v4.contains(addr), + IpAddr::V6(addr) => v6.contains(addr), }) } RhsValues::Int(values) => { - let values: RangeSet<_> = values.iter().cloned().collect(); + let values: RangeSet<_> = values.into_iter().map(Into::into).collect(); - lhs.compile_with(move |x| values.contains(&cast_value!(x, Int))) + lhs.compile_with(compiler, false, move |x, _ctx| { + values.contains(cast_value!(x, Int)) + }) } RhsValues::Bytes(values) => { - let values: IndexSet, FnvBuildHasher> = - values.into_iter().map(Into::into).collect(); + let values: BTreeSet> = values.into_iter().map(Into::into).collect(); - lhs.compile_with(move |x| values.contains(&cast_value!(x, Bytes) as &[u8])) + lhs.compile_with(compiler, false, move |x, _ctx| { + values.contains(cast_value!(x, Bytes) as &[u8]) + }) } RhsValues::Bool(_) => unreachable!(), + RhsValues::Map(_) => unreachable!(), + RhsValues::Array(_) => unreachable!(), }, + ComparisonOpExpr::ContainsOneOf(_values) => { + unreachable!("Node should not be constructed as there is no syntax to do so") + } + ComparisonOpExpr::InList { name, list } => { + lhs.compile_with(compiler, false, move |val, ctx| { + ctx.get_list_matcher_unchecked(list) + .match_value(name.as_str(), val) + }) + } } } } #[cfg(test)] +#[allow(clippy::bool_assert_comparison)] mod tests { use super::*; use crate::{ ast::function_expr::{FunctionCallArgExpr, FunctionCallExpr}, + ast::logical_expr::LogicalExpr, execution_context::ExecutionContext, functions::{ - Function, FunctionArgKind, FunctionArgs, FunctionImpl, FunctionOptParam, FunctionParam, + FunctionArgKind, FunctionArgs, FunctionDefinition, FunctionDefinitionContext, + FunctionParam, FunctionParamError, SimpleFunctionDefinition, SimpleFunctionImpl, + SimpleFunctionOptParam, SimpleFunctionParam, }, - rhs_types::IpRange, + lhs_types::{Array, Map}, + list_matcher::{ListDefinition, ListMatcher}, + rhs_types::{IpRange, RegexFormat}, + scheme::{FieldIndex, IndexAccessError, Scheme}, + types::ExpectedType, + BytesFormat, }; - use cidr::{Cidr, IpCidr}; - use lazy_static::lazy_static; - use std::net::IpAddr; + use cidr::IpCidr; + use std::sync::LazyLock; + use std::{convert::TryFrom, iter::once, net::IpAddr}; + + fn any_function<'a>(args: FunctionArgs<'_, 
'a>) -> Option> { + match args.next()? { + Ok(v) => Some(LhsValue::Bool( + Array::try_from(v) + .unwrap() + .into_iter() + .any(|lhs| bool::try_from(lhs).unwrap()), + )), + Err(Type::Array(ref arr)) if arr.get_type() == Type::Bool => None, + _ => unreachable!(), + } + } - fn echo_function<'a>(args: FunctionArgs<'_, 'a>) -> LhsValue<'a> { - args.next().unwrap() + fn echo_function<'a>(args: FunctionArgs<'_, 'a>) -> Option> { + args.next()?.ok() } - fn lowercase_function<'a>(args: FunctionArgs<'_, 'a>) -> LhsValue<'a> { - let input = args.next().unwrap(); + fn lowercase_function<'a>(args: FunctionArgs<'_, 'a>) -> Option> { + let input = args.next()?.ok()?; match input { - LhsValue::Bytes(bytes) => LhsValue::Bytes(bytes.to_ascii_lowercase().into()), - _ => panic!("Invalid type: expected Bytes, got {:?}", input), + LhsValue::Bytes(bytes) => Some(LhsValue::Bytes(bytes.to_ascii_lowercase().into())), + _ => panic!("Invalid type: expected Bytes, got {input:?}"), } } - fn concat_function<'a>(args: FunctionArgs<'_, 'a>) -> LhsValue<'a> { + #[allow(clippy::unnecessary_wraps)] + fn concat_function<'a>(args: FunctionArgs<'_, 'a>) -> Option> { let mut output = Vec::new(); for (index, arg) in args.enumerate() { - match arg { + match arg.unwrap() { LhsValue::Bytes(bytes) => { output.extend_from_slice(&bytes); } - arg => panic!( - "Invalid type for argument {:?}: expected Bytes, got {:?}", - index, arg - ), + arg => panic!("Invalid type for argument {index:?}: expected Bytes, got {arg:?}"), } } - LhsValue::Bytes(output.into()) + Some(LhsValue::Bytes(output.into())) } - lazy_static! { - static ref SCHEME: Scheme = { - let mut scheme: Scheme = Scheme! { - http.host: Bytes, - ip.addr: Ip, - ssl: Bool, - tcp.port: Int, - }; - scheme - .add_function( - "echo".into(), - Function { - params: vec![FunctionParam { - arg_kind: FunctionArgKind::Field, - val_type: Type::Bytes, - }], - opt_params: vec![], - return_type: Type::Bytes, - implementation: FunctionImpl::new(echo_function), - }, - ) - .unwrap(); - scheme - .add_function( - "lowercase".into(), - Function { - params: vec![FunctionParam { - arg_kind: FunctionArgKind::Field, - val_type: Type::Bytes, - }], - opt_params: vec![], - return_type: Type::Bytes, - implementation: FunctionImpl::new(lowercase_function), - }, - ) - .unwrap(); - scheme - .add_function( - "concat".into(), - Function { - params: vec![], - opt_params: vec![ - FunctionOptParam { - arg_kind: FunctionArgKind::Field, - default_value: "".into(), - }, - FunctionOptParam { - arg_kind: FunctionArgKind::Literal, - default_value: "".into(), - }, - ], - return_type: Type::Bytes, - implementation: FunctionImpl::new(concat_function), - }, + #[derive(Debug)] + struct FilterFunction {} + + impl FilterFunction { + fn new() -> Self { + Self {} + } + } + + impl FunctionDefinition for FilterFunction { + fn check_param( + &self, + params: &mut dyn ExactSizeIterator>, + next_param: &FunctionParam<'_>, + _: Option<&mut FunctionDefinitionContext>, + ) -> Result<(), FunctionParamError> { + match params.len() { + 0 => { + next_param.expect_arg_kind(FunctionArgKind::Field)?; + next_param.expect_val_type(once(ExpectedType::Array))?; + } + 1 => { + next_param.expect_arg_kind(FunctionArgKind::Field)?; + next_param.expect_val_type(once(ExpectedType::Type(Type::Array( + Type::Bool.into(), + ))))?; + } + _ => unreachable!(), + } + Ok(()) + } + + fn return_type( + &self, + params: &mut dyn ExactSizeIterator>, + _: Option<&FunctionDefinitionContext>, + ) -> Type { + params.next().unwrap().get_type() + } + + /// Number of 
arguments needed by the function. + fn arg_count(&self) -> (usize, Option) { + (2, Some(0)) + } + + fn compile<'s>( + &'s self, + _: &mut dyn ExactSizeIterator>, + _: Option, + ) -> Box Fn(FunctionArgs<'_, 'a>) -> Option> + Sync + Send + 's> + { + Box::new(|args| { + let value_array = Array::try_from(args.next().unwrap().unwrap()).unwrap(); + let keep_array = Array::try_from(args.next().unwrap().unwrap()).unwrap(); + let output = Array::try_from_iter( + value_array.value_type(), + value_array + .into_iter() + .zip(keep_array) + .filter_map(|(value, keep)| { + if bool::try_from(keep).unwrap() { + Some(value) + } else { + None + } + }), ) .unwrap(); - scheme - }; + Some(LhsValue::Array(output)) + }) + } + } + + fn len_function<'a>(args: FunctionArgs<'_, 'a>) -> Option> { + match args.next()? { + Ok(LhsValue::Bytes(bytes)) => Some(LhsValue::Int(i64::try_from(bytes.len()).unwrap())), + Err(Type::Bytes) => None, + _ => unreachable!(), + } + } + + #[derive(Debug, PartialEq, Eq, Serialize, Clone)] + pub struct NumMListDefinition {} + + impl ListDefinition for NumMListDefinition { + fn matcher_from_json_value( + &self, + _: Type, + _: serde_json::Value, + ) -> Result, serde_json::Error> { + Ok(Box::new(NumMatcher {})) + } + + fn new_matcher(&self) -> Box { + Box::new(NumMatcher {}) + } } + static SCHEME: LazyLock = LazyLock::new(|| { + let mut scheme: Scheme = Scheme! { + http.cookies: Array(Bytes), + http.headers: Map(Bytes), + http.host: Bytes, + ip.addr: Ip, + ssl: Bool, + tcp.port: Int, + tcp.ports: Array(Int), + array.of.bool: Array(Bool), + http.parts: Array(Array(Bytes)), + }; + scheme + .add_function( + "any", + SimpleFunctionDefinition { + params: vec![SimpleFunctionParam { + arg_kind: FunctionArgKind::Field, + val_type: Type::Array(Type::Bool.into()), + }], + opt_params: vec![], + return_type: Type::Bool, + implementation: SimpleFunctionImpl::new(any_function), + }, + ) + .unwrap(); + scheme + .add_function( + "echo", + SimpleFunctionDefinition { + params: vec![SimpleFunctionParam { + arg_kind: FunctionArgKind::Field, + val_type: Type::Bytes, + }], + opt_params: vec![], + return_type: Type::Bytes, + implementation: SimpleFunctionImpl::new(echo_function), + }, + ) + .unwrap(); + scheme + .add_function( + "lowercase", + SimpleFunctionDefinition { + params: vec![SimpleFunctionParam { + arg_kind: FunctionArgKind::Field, + val_type: Type::Bytes, + }], + opt_params: vec![], + return_type: Type::Bytes, + implementation: SimpleFunctionImpl::new(lowercase_function), + }, + ) + .unwrap(); + scheme + .add_function( + "concat", + SimpleFunctionDefinition { + params: vec![], + opt_params: vec![ + SimpleFunctionOptParam { + arg_kind: FunctionArgKind::Field, + default_value: "".into(), + }, + SimpleFunctionOptParam { + arg_kind: FunctionArgKind::Literal, + default_value: "".into(), + }, + ], + return_type: Type::Bytes, + implementation: SimpleFunctionImpl::new(concat_function), + }, + ) + .unwrap(); + scheme + .add_function("filter", FilterFunction::new()) + .unwrap(); + scheme + .add_function( + "len", + SimpleFunctionDefinition { + params: vec![SimpleFunctionParam { + arg_kind: FunctionArgKind::Field, + val_type: Type::Bytes, + }], + opt_params: vec![], + return_type: Type::Int, + implementation: SimpleFunctionImpl::new(len_function), + }, + ) + .unwrap(); + scheme + .add_list(Type::Int, Box::new(NumMListDefinition {})) + .unwrap(); + scheme + }); + fn field(name: &'static str) -> Field<'static> { - SCHEME.get_field_index(name).unwrap() + SCHEME.get_field(name).unwrap() } #[test] fn test_is_true() { 
let expr = assert_ok!( - FieldExpr::lex_with("ssl", &SCHEME), - FieldExpr { - lhs: LhsFieldExpr::Field(field("ssl")), - op: FieldOp::IsTrue + FilterParser::new(&SCHEME).lex_as("ssl"), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("ssl")), + indexes: vec![], + }, + op: ComparisonOpExpr::IsTrue } ); @@ -434,20 +940,23 @@ mod tests { let expr = expr.compile(); let ctx = &mut ExecutionContext::new(&SCHEME); - ctx.set_field_value("ssl", true).unwrap(); - assert_eq!(expr.execute(ctx), true); + ctx.set_field_value(field("ssl"), true).unwrap(); + assert_eq!(expr.execute_one(ctx), true); - ctx.set_field_value("ssl", false).unwrap(); - assert_eq!(expr.execute(ctx), false); + ctx.set_field_value(field("ssl"), false).unwrap(); + assert_eq!(expr.execute_one(ctx), false); } #[test] fn test_ip_compare() { let expr = assert_ok!( - FieldExpr::lex_with("ip.addr <= 10:20:30:40:50:60:70:80", &SCHEME), - FieldExpr { - lhs: LhsFieldExpr::Field(field("ip.addr")), - op: FieldOp::Ordering { + FilterParser::new(&SCHEME).lex_as("ip.addr <= 10:20:30:40:50:60:70:80"), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("ip.addr")), + indexes: vec![], + }, + op: ComparisonOpExpr::Ordering { op: OrderingOp::LessThanEqual, rhs: RhsValue::Ip(IpAddr::from([ 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80 @@ -468,27 +977,27 @@ mod tests { let expr = expr.compile(); let ctx = &mut ExecutionContext::new(&SCHEME); - ctx.set_field_value("ip.addr", IpAddr::from([0, 0, 0, 0, 0, 0, 0, 1])) + ctx.set_field_value(field("ip.addr"), IpAddr::from([0, 0, 0, 0, 0, 0, 0, 1])) .unwrap(); - assert_eq!(expr.execute(ctx), true); + assert_eq!(expr.execute_one(ctx), true); ctx.set_field_value( - "ip.addr", + field("ip.addr"), IpAddr::from([0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80]), ) .unwrap(); - assert_eq!(expr.execute(ctx), true); + assert_eq!(expr.execute_one(ctx), true); ctx.set_field_value( - "ip.addr", + field("ip.addr"), IpAddr::from([0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x81]), ) .unwrap(); - assert_eq!(expr.execute(ctx), false); + assert_eq!(expr.execute_one(ctx), false); - ctx.set_field_value("ip.addr", IpAddr::from([127, 0, 0, 1])) + ctx.set_field_value(field("ip.addr"), IpAddr::from([127, 0, 0, 1])) .unwrap(); - assert_eq!(expr.execute(ctx), false); + assert_eq!(expr.execute_one(ctx), false); } #[test] @@ -496,10 +1005,13 @@ mod tests { // just check that parsing doesn't conflict with IPv6 { let expr = assert_ok!( - FieldExpr::lex_with("http.host >= 10:20:30:40:50:60:70:80", &SCHEME), - FieldExpr { - lhs: LhsFieldExpr::Field(field("http.host")), - op: FieldOp::Ordering { + FilterParser::new(&SCHEME).lex_as("http.host >= 10:20:30:40:50:60:70:80"), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("http.host")), + indexes: vec![], + }, + op: ComparisonOpExpr::Ordering { op: OrderingOp::GreaterThanEqual, rhs: RhsValue::Bytes( vec![0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80].into() @@ -520,13 +1032,26 @@ mod tests { // just check that parsing doesn't conflict with regular numbers { + assert_err!( + FilterParser::new(&SCHEME).lex_as::>(r#"http.host < 12"#), + LexErrorKind::CountMismatch { + name: "character", + actual: 0, + expected: 1 + }, + "" + ); + let expr = assert_ok!( - FieldExpr::lex_with(r#"http.host < 12"#, &SCHEME), - FieldExpr { - lhs: LhsFieldExpr::Field(field("http.host")), - op: FieldOp::Ordering { + FilterParser::new(&SCHEME).lex_as(r#"http.host < 12:13"#), + ComparisonExpr { + lhs: IndexExpr { + identifier: 
IdentifierExpr::Field(field("http.host")), + indexes: vec![], + }, + op: ComparisonOpExpr::Ordering { op: OrderingOp::LessThan, - rhs: RhsValue::Bytes(vec![0x12].into()), + rhs: RhsValue::Bytes(vec![0x12, 0x13].into()), }, } ); @@ -536,16 +1061,19 @@ mod tests { { "lhs": "http.host", "op": "LessThan", - "rhs": [0x12] + "rhs": [0x12, 0x13] } ); } let expr = assert_ok!( - FieldExpr::lex_with(r#"http.host == "example.org""#, &SCHEME), - FieldExpr { - lhs: LhsFieldExpr::Field(field("http.host")), - op: FieldOp::Ordering { + FilterParser::new(&SCHEME).lex_as(r#"http.host == "example.org""#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("http.host")), + indexes: vec![], + }, + op: ComparisonOpExpr::Ordering { op: OrderingOp::Equal, rhs: RhsValue::Bytes("example.org".to_owned().into()) } @@ -564,20 +1092,25 @@ mod tests { let expr = expr.compile(); let ctx = &mut ExecutionContext::new(&SCHEME); - ctx.set_field_value("http.host", "example.com").unwrap(); - assert_eq!(expr.execute(ctx), false); + ctx.set_field_value(field("http.host"), "example.com") + .unwrap(); + assert_eq!(expr.execute_one(ctx), false); - ctx.set_field_value("http.host", "example.org").unwrap(); - assert_eq!(expr.execute(ctx), true); + ctx.set_field_value(field("http.host"), "example.org") + .unwrap(); + assert_eq!(expr.execute_one(ctx), true); } #[test] fn test_bitwise_and() { let expr = assert_ok!( - FieldExpr::lex_with("tcp.port & 1", &SCHEME), - FieldExpr { - lhs: LhsFieldExpr::Field(field("tcp.port")), - op: FieldOp::Int { + FilterParser::new(&SCHEME).lex_as("tcp.port & 1"), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("tcp.port")), + indexes: vec![], + }, + op: ComparisonOpExpr::Int { op: IntOp::BitwiseAnd, rhs: 1, } @@ -596,20 +1129,27 @@ mod tests { let expr = expr.compile(); let ctx = &mut ExecutionContext::new(&SCHEME); - ctx.set_field_value("tcp.port", 80).unwrap(); - assert_eq!(expr.execute(ctx), false); + ctx.set_field_value(field("tcp.port"), 80).unwrap(); + assert_eq!(expr.execute_one(ctx), false); - ctx.set_field_value("tcp.port", 443).unwrap(); - assert_eq!(expr.execute(ctx), true); + ctx.set_field_value(field("tcp.port"), 443).unwrap(); + assert_eq!(expr.execute_one(ctx), true); } #[test] fn test_int_in() { let expr = assert_ok!( - FieldExpr::lex_with(r#"tcp.port in { 80 443 2082..2083 }"#, &SCHEME), - FieldExpr { - lhs: LhsFieldExpr::Field(field("tcp.port")), - op: FieldOp::OneOf(RhsValues::Int(vec![80..=80, 443..=443, 2082..=2083])), + FilterParser::new(&SCHEME).lex_as(r#"tcp.port in { 80 443 2082..2083 }"#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("tcp.port")), + indexes: vec![], + }, + op: ComparisonOpExpr::OneOf(RhsValues::Int(vec![ + 80.into(), + 443.into(), + (2082..=2083).into() + ])), } ); @@ -629,38 +1169,41 @@ mod tests { let expr = expr.compile(); let ctx = &mut ExecutionContext::new(&SCHEME); - ctx.set_field_value("tcp.port", 80).unwrap(); - assert_eq!(expr.execute(ctx), true); + ctx.set_field_value(field("tcp.port"), 80).unwrap(); + assert_eq!(expr.execute_one(ctx), true); - ctx.set_field_value("tcp.port", 8080).unwrap(); - assert_eq!(expr.execute(ctx), false); + ctx.set_field_value(field("tcp.port"), 8080).unwrap(); + assert_eq!(expr.execute_one(ctx), false); - ctx.set_field_value("tcp.port", 443).unwrap(); - assert_eq!(expr.execute(ctx), true); + ctx.set_field_value(field("tcp.port"), 443).unwrap(); + assert_eq!(expr.execute_one(ctx), true); - ctx.set_field_value("tcp.port", 
2081).unwrap(); - assert_eq!(expr.execute(ctx), false); + ctx.set_field_value(field("tcp.port"), 2081).unwrap(); + assert_eq!(expr.execute_one(ctx), false); - ctx.set_field_value("tcp.port", 2082).unwrap(); - assert_eq!(expr.execute(ctx), true); + ctx.set_field_value(field("tcp.port"), 2082).unwrap(); + assert_eq!(expr.execute_one(ctx), true); - ctx.set_field_value("tcp.port", 2083).unwrap(); - assert_eq!(expr.execute(ctx), true); + ctx.set_field_value(field("tcp.port"), 2083).unwrap(); + assert_eq!(expr.execute_one(ctx), true); - ctx.set_field_value("tcp.port", 2084).unwrap(); - assert_eq!(expr.execute(ctx), false); + ctx.set_field_value(field("tcp.port"), 2084).unwrap(); + assert_eq!(expr.execute_one(ctx), false); } #[test] fn test_bytes_in() { let expr = assert_ok!( - FieldExpr::lex_with(r#"http.host in { "example.org" "example.com" }"#, &SCHEME), - FieldExpr { - lhs: LhsFieldExpr::Field(field("http.host")), - op: FieldOp::OneOf(RhsValues::Bytes( + FilterParser::new(&SCHEME).lex_as(r#"http.host in { "example.org" "example.com" }"#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("http.host")), + indexes: vec![], + }, + op: ComparisonOpExpr::OneOf(RhsValues::Bytes( ["example.org", "example.com",] .iter() - .map(|&s| s.to_string().into()) + .map(|s| (*s).to_string().into()) .collect() )), } @@ -681,26 +1224,30 @@ mod tests { let expr = expr.compile(); let ctx = &mut ExecutionContext::new(&SCHEME); - ctx.set_field_value("http.host", "example.com").unwrap(); - assert_eq!(expr.execute(ctx), true); + ctx.set_field_value(field("http.host"), "example.com") + .unwrap(); + assert_eq!(expr.execute_one(ctx), true); - ctx.set_field_value("http.host", "example.org").unwrap(); - assert_eq!(expr.execute(ctx), true); + ctx.set_field_value(field("http.host"), "example.org") + .unwrap(); + assert_eq!(expr.execute_one(ctx), true); - ctx.set_field_value("http.host", "example.net").unwrap(); - assert_eq!(expr.execute(ctx), false); + ctx.set_field_value(field("http.host"), "example.net") + .unwrap(); + assert_eq!(expr.execute_one(ctx), false); } #[test] fn test_ip_in() { let expr = assert_ok!( - FieldExpr::lex_with( - r#"ip.addr in { 127.0.0.0/8 ::1 10.0.0.0..10.0.255.255 }"#, - &SCHEME - ), - FieldExpr { - lhs: LhsFieldExpr::Field(field("ip.addr")), - op: FieldOp::OneOf(RhsValues::Ip(vec![ + FilterParser::new(&SCHEME) + .lex_as(r#"ip.addr in { 127.0.0.0/8 ::1 10.0.0.0..10.0.255.255 }"#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("ip.addr")), + indexes: vec![], + }, + op: ComparisonOpExpr::OneOf(RhsValues::Ip(vec![ IpRange::Cidr(IpCidr::new([127, 0, 0, 0].into(), 8).unwrap()), IpRange::Cidr(IpCidr::new_host([0, 0, 0, 0, 0, 0, 0, 1].into())), IpRange::Explicit(ExplicitIpRange::V4( @@ -726,34 +1273,37 @@ mod tests { let expr = expr.compile(); let ctx = &mut ExecutionContext::new(&SCHEME); - ctx.set_field_value("ip.addr", IpAddr::from([127, 0, 0, 1])) + ctx.set_field_value(field("ip.addr"), IpAddr::from([127, 0, 0, 1])) .unwrap(); - assert_eq!(expr.execute(ctx), true); + assert_eq!(expr.execute_one(ctx), true); - ctx.set_field_value("ip.addr", IpAddr::from([127, 0, 0, 3])) + ctx.set_field_value(field("ip.addr"), IpAddr::from([127, 0, 0, 3])) .unwrap(); - assert_eq!(expr.execute(ctx), true); + assert_eq!(expr.execute_one(ctx), true); - ctx.set_field_value("ip.addr", IpAddr::from([255, 255, 255, 255])) + ctx.set_field_value(field("ip.addr"), IpAddr::from([255, 255, 255, 255])) .unwrap(); - assert_eq!(expr.execute(ctx), false); + 
assert_eq!(expr.execute_one(ctx), false); - ctx.set_field_value("ip.addr", IpAddr::from([0, 0, 0, 0, 0, 0, 0, 1])) + ctx.set_field_value(field("ip.addr"), IpAddr::from([0, 0, 0, 0, 0, 0, 0, 1])) .unwrap(); - assert_eq!(expr.execute(ctx), true); + assert_eq!(expr.execute_one(ctx), true); - ctx.set_field_value("ip.addr", IpAddr::from([0, 0, 0, 0, 0, 0, 0, 2])) + ctx.set_field_value(field("ip.addr"), IpAddr::from([0, 0, 0, 0, 0, 0, 0, 2])) .unwrap(); - assert_eq!(expr.execute(ctx), false); + assert_eq!(expr.execute_one(ctx), false); } #[test] fn test_contains_bytes() { let expr = assert_ok!( - FieldExpr::lex_with(r#"http.host contains "abc""#, &SCHEME), - FieldExpr { - lhs: LhsFieldExpr::Field(field("http.host")), - op: FieldOp::Contains("abc".to_owned().into()) + FilterParser::new(&SCHEME).lex_as(r#"http.host contains "abc""#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("http.host")), + indexes: vec![], + }, + op: ComparisonOpExpr::Contains("abc".to_owned().into()) } ); @@ -769,20 +1319,25 @@ mod tests { let expr = expr.compile(); let ctx = &mut ExecutionContext::new(&SCHEME); - ctx.set_field_value("http.host", "example.org").unwrap(); - assert_eq!(expr.execute(ctx), false); + ctx.set_field_value(field("http.host"), "example.org") + .unwrap(); + assert_eq!(expr.execute_one(ctx), false); - ctx.set_field_value("http.host", "abc.net.au").unwrap(); - assert_eq!(expr.execute(ctx), true); + ctx.set_field_value(field("http.host"), "abc.net.au") + .unwrap(); + assert_eq!(expr.execute_one(ctx), true); } #[test] fn test_contains_str() { let expr = assert_ok!( - FieldExpr::lex_with(r#"http.host contains 6F:72:67"#, &SCHEME), - FieldExpr { - lhs: LhsFieldExpr::Field(field("http.host")), - op: FieldOp::Contains(vec![0x6F, 0x72, 0x67].into()), + FilterParser::new(&SCHEME).lex_as(r#"http.host contains 6F:72:67"#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("http.host")), + indexes: vec![], + }, + op: ComparisonOpExpr::Contains(vec![0x6F, 0x72, 0x67].into()), } ); @@ -798,20 +1353,25 @@ mod tests { let expr = expr.compile(); let ctx = &mut ExecutionContext::new(&SCHEME); - ctx.set_field_value("http.host", "example.com").unwrap(); - assert_eq!(expr.execute(ctx), false); + ctx.set_field_value(field("http.host"), "example.com") + .unwrap(); + assert_eq!(expr.execute_one(ctx), false); - ctx.set_field_value("http.host", "example.org").unwrap(); - assert_eq!(expr.execute(ctx), true); + ctx.set_field_value(field("http.host"), "example.org") + .unwrap(); + assert_eq!(expr.execute_one(ctx), true); } #[test] fn test_int_compare() { let expr = assert_ok!( - FieldExpr::lex_with(r#"tcp.port < 8000"#, &SCHEME), - FieldExpr { - lhs: LhsFieldExpr::Field(field("tcp.port")), - op: FieldOp::Ordering { + FilterParser::new(&SCHEME).lex_as(r#"tcp.port < 8000"#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("tcp.port")), + indexes: vec![], + }, + op: ComparisonOpExpr::Ordering { op: OrderingOp::LessThan, rhs: RhsValue::Int(8000) }, @@ -830,32 +1390,114 @@ mod tests { let expr = expr.compile(); let ctx = &mut ExecutionContext::new(&SCHEME); - ctx.set_field_value("tcp.port", 80).unwrap(); - assert_eq!(expr.execute(ctx), true); + ctx.set_field_value(field("tcp.port"), 80).unwrap(); + assert_eq!(expr.execute_one(ctx), true); + + ctx.set_field_value(field("tcp.port"), 8080).unwrap(); + assert_eq!(expr.execute_one(ctx), false); + } + + #[test] + fn test_array_contains_str() { + let expr = assert_ok!( + 
FilterParser::new(&SCHEME).lex_as(r#"http.cookies[0] contains "abc""#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("http.cookies")), + indexes: vec![FieldIndex::ArrayIndex(0)], + }, + op: ComparisonOpExpr::Contains("abc".to_owned().into()), + } + ); + + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + let cookies = Array::from_iter(["abc"]); + + ctx.set_field_value(field("http.cookies"), cookies).unwrap(); + assert_eq!(expr.execute_one(ctx), true); + + let cookies = Array::from_iter(["def"]); + + ctx.set_field_value(field("http.cookies"), cookies).unwrap(); + assert_eq!(expr.execute_one(ctx), false); + } + + #[test] + fn test_map_of_bytes_contains_str() { + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"http.headers["host"] contains "abc""#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("http.headers")), + indexes: vec![FieldIndex::MapKey("host".to_string())], + }, + op: ComparisonOpExpr::Contains("abc".to_owned().into()), + } + ); + + assert_json!( + expr, + { + "lhs": [ + "http.headers", + {"kind": "MapKey", "value": "host"} + ], + "op": "Contains", + "rhs": "abc", + } + ); + + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + let headers = LhsValue::Map({ + let mut map = Map::new(Type::Bytes); + map.insert(b"host", "example.org").unwrap(); + map + }); - ctx.set_field_value("tcp.port", 8080).unwrap(); - assert_eq!(expr.execute(ctx), false); + ctx.set_field_value(field("http.headers"), headers).unwrap(); + assert_eq!(expr.execute_one(ctx), false); + + let headers = LhsValue::Map({ + let mut map = Map::new(Type::Bytes); + map.insert(b"host", "abc.net.au").unwrap(); + map + }); + + ctx.set_field_value(field("http.headers"), headers).unwrap(); + assert_eq!(expr.execute_one(ctx), true); } #[test] fn test_bytes_compare_with_echo_function() { let expr = assert_ok!( - FieldExpr::lex_with(r#"echo(http.host) == "example.org""#, &SCHEME), - FieldExpr { - lhs: LhsFieldExpr::FunctionCallExpr(FunctionCallExpr { - name: String::from("echo"), - function: SCHEME.get_function("echo").unwrap(), - args: vec![FunctionCallArgExpr::LhsFieldExpr(LhsFieldExpr::Field( - field("http.host") - ))], - }), - op: FieldOp::Ordering { + FilterParser::new(&SCHEME).lex_as(r#"echo(http.host) == "example.org""#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::FunctionCallExpr(FunctionCallExpr { + function: SCHEME.get_function("echo").unwrap(), + args: vec![FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(field("http.host")), + indexes: vec![], + })], + context: None, + }), + indexes: vec![], + }, + op: ComparisonOpExpr::Ordering { op: OrderingOp::Equal, rhs: RhsValue::Bytes("example.org".to_owned().into()) } } ); + assert_eq!(expr.lhs.identifier.get_type(), Type::Bytes); + assert_eq!(expr.lhs.get_type(), Type::Bytes); + assert_eq!(expr.get_type(), Type::Bool); + assert_json!( expr, { @@ -863,7 +1505,7 @@ mod tests { "name": "echo", "args": [ { - "kind": "LhsFieldExpr", + "kind": "IndexExpr", "value": "http.host" } ] @@ -876,32 +1518,42 @@ mod tests { let expr = expr.compile(); let ctx = &mut ExecutionContext::new(&SCHEME); - ctx.set_field_value("http.host", "example.com").unwrap(); - assert_eq!(expr.execute(ctx), false); + ctx.set_field_value(field("http.host"), "example.com") + .unwrap(); + assert_eq!(expr.execute_one(ctx), false); - ctx.set_field_value("http.host", "example.org").unwrap(); - assert_eq!(expr.execute(ctx), true); 
+ ctx.set_field_value(field("http.host"), "example.org") + .unwrap(); + assert_eq!(expr.execute_one(ctx), true); } #[test] fn test_bytes_compare_with_lowercase_function() { let expr = assert_ok!( - FieldExpr::lex_with(r#"lowercase(http.host) == "example.org""#, &SCHEME), - FieldExpr { - lhs: LhsFieldExpr::FunctionCallExpr(FunctionCallExpr { - name: String::from("lowercase"), - function: SCHEME.get_function("lowercase").unwrap(), - args: vec![FunctionCallArgExpr::LhsFieldExpr(LhsFieldExpr::Field( - field("http.host") - ))], - }), - op: FieldOp::Ordering { + FilterParser::new(&SCHEME).lex_as(r#"lowercase(http.host) == "example.org""#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::FunctionCallExpr(FunctionCallExpr { + function: SCHEME.get_function("lowercase").unwrap(), + args: vec![FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(field("http.host")), + indexes: vec![], + })], + context: None, + }), + indexes: vec![], + }, + op: ComparisonOpExpr::Ordering { op: OrderingOp::Equal, rhs: RhsValue::Bytes("example.org".to_owned().into()) } } ); + assert_eq!(expr.lhs.identifier.get_type(), Type::Bytes); + assert_eq!(expr.lhs.get_type(), Type::Bytes); + assert_eq!(expr.get_type(), Type::Bool); + assert_json!( expr, { @@ -909,7 +1561,7 @@ mod tests { "name": "lowercase", "args": [ { - "kind": "LhsFieldExpr", + "kind": "IndexExpr", "value": "http.host" } ] @@ -922,26 +1574,25 @@ mod tests { let expr = expr.compile(); let ctx = &mut ExecutionContext::new(&SCHEME); - ctx.set_field_value("http.host", "EXAMPLE.COM").unwrap(); - assert_eq!(expr.execute(ctx), false); + ctx.set_field_value(field("http.host"), "EXAMPLE.COM") + .unwrap(); + assert_eq!(expr.execute_one(ctx), false); - ctx.set_field_value("http.host", "EXAMPLE.ORG").unwrap(); - assert_eq!(expr.execute(ctx), true); + ctx.set_field_value(field("http.host"), "EXAMPLE.ORG") + .unwrap(); + assert_eq!(expr.execute_one(ctx), true); } #[test] - fn test_bytes_compare_with_concat_function() { + fn test_missing_array_value_equal() { let expr = assert_ok!( - FieldExpr::lex_with(r#"concat(http.host) == "example.org""#, &SCHEME), - FieldExpr { - lhs: LhsFieldExpr::FunctionCallExpr(FunctionCallExpr { - name: String::from("concat"), - function: SCHEME.get_function("concat").unwrap(), - args: vec![FunctionCallArgExpr::LhsFieldExpr(LhsFieldExpr::Field( - field("http.host") - ))], - }), - op: FieldOp::Ordering { + FilterParser::new(&SCHEME).lex_as(r#"http.cookies[0] == "example.org""#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("http.cookies")), + indexes: vec![FieldIndex::ArrayIndex(0)], + }, + op: ComparisonOpExpr::Ordering { op: OrderingOp::Equal, rhs: RhsValue::Bytes("example.org".to_owned().into()) } @@ -951,15 +1602,10 @@ mod tests { assert_json!( expr, { - "lhs": { - "name": "concat", - "args": [ - { - "kind": "LhsFieldExpr", - "value": "http.host" - } - ] - }, + "lhs": [ + "http.cookies", + {"kind": "ArrayIndex", "value": 0} + ], "op": "Equal", "rhs": "example.org" } @@ -968,26 +1614,57 @@ mod tests { let expr = expr.compile(); let ctx = &mut ExecutionContext::new(&SCHEME); - ctx.set_field_value("http.host", "example.org").unwrap(); - assert_eq!(expr.execute(ctx), true); + ctx.set_field_value(field("http.cookies"), Array::new(Type::Bytes)) + .unwrap(); + assert_eq!(expr.execute_one(ctx), false); + } + + #[test] + fn test_missing_array_value_not_equal() { + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"http.cookies[0] != "example.org""#), + 
ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("http.cookies")), + indexes: vec![FieldIndex::ArrayIndex(0)], + }, + op: ComparisonOpExpr::Ordering { + op: OrderingOp::NotEqual, + rhs: RhsValue::Bytes("example.org".to_owned().into()) + } + } + ); + + assert_json!( + expr, + { + "lhs": [ + "http.cookies", + {"kind": "ArrayIndex", "value": 0} + ], + "op": "NotEqual", + "rhs": "example.org" + } + ); + + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); - ctx.set_field_value("http.host", "example.co.uk").unwrap(); - assert_eq!(expr.execute(ctx), false); + ctx.set_field_value(field("http.cookies"), Array::new(Type::Bytes)) + .unwrap(); + assert_eq!(expr.execute_one(ctx), true); + } + #[test] + fn test_missing_map_value_equal() { let expr = assert_ok!( - FieldExpr::lex_with(r#"concat(http.host, ".org") == "example.org""#, &SCHEME), - FieldExpr { - lhs: LhsFieldExpr::FunctionCallExpr(FunctionCallExpr { - name: String::from("concat"), - function: SCHEME.get_function("concat").unwrap(), - args: vec![ - FunctionCallArgExpr::LhsFieldExpr(LhsFieldExpr::Field(field("http.host"))), - FunctionCallArgExpr::Literal(RhsValue::Bytes(Bytes::from( - ".org".to_owned() - ))), - ], - }), - op: FieldOp::Ordering { + FilterParser::new(&SCHEME).lex_as(r#"http.headers["missing"] == "example.org""#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("http.headers")), + indexes: vec![FieldIndex::MapKey("missing".into())], + }, + op: ComparisonOpExpr::Ordering { op: OrderingOp::Equal, rhs: RhsValue::Bytes("example.org".to_owned().into()) } @@ -997,19 +1674,10 @@ mod tests { assert_json!( expr, { - "lhs": { - "name": "concat", - "args": [ - { - "kind": "LhsFieldExpr", - "value": "http.host" - }, - { - "kind": "Literal", - "value": ".org" - }, - ] - }, + "lhs": [ + "http.headers", + {"kind": "MapKey", "value": "missing"} + ], "op": "Equal", "rhs": "example.org" } @@ -1018,10 +1686,1209 @@ mod tests { let expr = expr.compile(); let ctx = &mut ExecutionContext::new(&SCHEME); - ctx.set_field_value("http.host", "example").unwrap(); - assert_eq!(expr.execute(ctx), true); + ctx.set_field_value(field("http.headers"), Map::new(Type::Bytes)) + .unwrap(); + assert_eq!(expr.execute_one(ctx), false); + } - ctx.set_field_value("http.host", "cloudflare").unwrap(); - assert_eq!(expr.execute(ctx), false); + #[test] + fn test_missing_map_value_not_equal() { + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"http.headers["missing"] != "example.org""#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("http.headers")), + indexes: vec![FieldIndex::MapKey("missing".into())], + }, + op: ComparisonOpExpr::Ordering { + op: OrderingOp::NotEqual, + rhs: RhsValue::Bytes("example.org".to_owned().into()) + } + } + ); + + assert_json!( + expr, + { + "lhs": [ + "http.headers", + {"kind": "MapKey", "value": "missing"} + ], + "op": "NotEqual", + "rhs": "example.org" + } + ); + + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + ctx.set_field_value(field("http.headers"), Map::new(Type::Bytes)) + .unwrap(); + assert_eq!(expr.execute_one(ctx), true); + } + + #[test] + fn test_bytes_compare_with_concat_function() { + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"concat(http.host) == "example.org""#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::FunctionCallExpr(FunctionCallExpr { + function: SCHEME.get_function("concat").unwrap(), + args: 
vec![FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(field("http.host")), + indexes: vec![], + })], + context: None, + }), + indexes: vec![], + }, + op: ComparisonOpExpr::Ordering { + op: OrderingOp::Equal, + rhs: RhsValue::Bytes("example.org".to_owned().into()) + } + } + ); + + assert_eq!(expr.lhs.identifier.get_type(), Type::Bytes); + assert_eq!(expr.lhs.get_type(), Type::Bytes); + assert_eq!(expr.get_type(), Type::Bool); + + assert_json!( + expr, + { + "lhs": { + "name": "concat", + "args": [ + { + "kind": "IndexExpr", + "value": "http.host" + } + ] + }, + "op": "Equal", + "rhs": "example.org" + } + ); + + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + ctx.set_field_value(field("http.host"), "example.org") + .unwrap(); + assert_eq!(expr.execute_one(ctx), true); + + ctx.set_field_value(field("http.host"), "example.co.uk") + .unwrap(); + assert_eq!(expr.execute_one(ctx), false); + + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"concat(http.host, ".org") == "example.org""#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::FunctionCallExpr(FunctionCallExpr { + function: SCHEME.get_function("concat").unwrap(), + args: vec![ + FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(field("http.host")), + indexes: vec![], + }), + FunctionCallArgExpr::Literal(RhsValue::Bytes(Bytes::from( + ".org".to_owned() + ))), + ], + context: None, + }), + indexes: vec![], + }, + op: ComparisonOpExpr::Ordering { + op: OrderingOp::Equal, + rhs: RhsValue::Bytes("example.org".to_owned().into()) + } + } + ); + + assert_eq!(expr.lhs.identifier.get_type(), Type::Bytes); + assert_eq!(expr.lhs.get_type(), Type::Bytes); + assert_eq!(expr.get_type(), Type::Bool); + + assert_json!( + expr, + { + "lhs": { + "name": "concat", + "args": [ + { + "kind": "IndexExpr", + "value": "http.host" + }, + { + "kind": "Literal", + "value": ".org" + }, + ] + }, + "op": "Equal", + "rhs": "example.org" + } + ); + + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + ctx.set_field_value(field("http.host"), "example").unwrap(); + assert_eq!(expr.execute_one(ctx), true); + + ctx.set_field_value(field("http.host"), "cloudflare") + .unwrap(); + assert_eq!(expr.execute_one(ctx), false); + } + + #[test] + fn test_filter_function() { + let expr = assert_ok!( + FilterParser::new(&SCHEME) + .lex_as(r#"filter(http.cookies, array.of.bool)[0] == "three""#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::FunctionCallExpr(FunctionCallExpr { + function: SCHEME.get_function("filter").unwrap(), + args: vec![ + FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(field("http.cookies")), + indexes: vec![], + }), + FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(field("array.of.bool")), + indexes: vec![], + }), + ], + context: None, + }), + indexes: vec![FieldIndex::ArrayIndex(0)], + }, + op: ComparisonOpExpr::Ordering { + op: OrderingOp::Equal, + rhs: RhsValue::Bytes("three".to_owned().into()) + } + } + ); + + assert_eq!( + expr.lhs.identifier.get_type(), + Type::Array(Type::Bytes.into()) + ); + assert_eq!(expr.lhs.get_type(), Type::Bytes); + assert_eq!(expr.get_type(), Type::Bool); + + assert_json!( + expr, + { + "lhs": [ + { + "name": "filter", + "args": [ + { + "kind": "IndexExpr", + "value": "http.cookies" + }, + { + "kind": "IndexExpr", + "value": "array.of.bool" + } + ] + }, + {"kind": "ArrayIndex", "value": 0}, + ], + "op": 
"Equal", + "rhs": "three" + } + ); + + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + let cookies = Array::from_iter(["one", "two", "three"]); + ctx.set_field_value(field("http.cookies"), cookies).unwrap(); + + let booleans = Array::from_iter([false, false, true]); + ctx.set_field_value(field("array.of.bool"), booleans) + .unwrap(); + + assert_eq!(expr.execute_one(ctx), true); + } + + #[test] + fn test_map_each_on_array_with_function() { + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"concat(http.cookies[*], "-cf")[2] == "three-cf""#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::FunctionCallExpr(FunctionCallExpr { + function: SCHEME.get_function("concat").unwrap(), + args: vec![ + FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(field("http.cookies")), + indexes: vec![FieldIndex::MapEach], + }), + FunctionCallArgExpr::Literal(RhsValue::Bytes(Bytes::from( + "-cf".to_owned() + ))), + ], + context: None, + }), + indexes: vec![FieldIndex::ArrayIndex(2)], + }, + op: ComparisonOpExpr::Ordering { + op: OrderingOp::Equal, + rhs: RhsValue::Bytes("three-cf".to_owned().into()) + } + } + ); + + assert_eq!( + expr.lhs.identifier.get_type(), + Type::Array(Type::Bytes.into()) + ); + assert_eq!(expr.lhs.get_type(), Type::Bytes); + assert_eq!(expr.get_type(), Type::Bool); + + assert_json!( + expr, + { + "lhs": [ + { + "name": "concat", + "args": [ + { + "kind": "IndexExpr", + "value": ["http.cookies", {"kind": "MapEach"}], + }, + { + "kind": "Literal", + "value": "-cf" + } + ] + }, + {"kind": "ArrayIndex", "value": 2}, + ], + "op": "Equal", + "rhs": "three-cf" + } + ); + + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + let cookies = Array::from_iter(["one", "two", "three"]); + ctx.set_field_value(field("http.cookies"), cookies).unwrap(); + + assert_eq!(expr.execute_one(ctx), true); + } + + #[test] + fn test_map_each_on_map_with_function() { + let expr = assert_ok!( + FilterParser::new(&SCHEME) + .lex_as(r#"concat(http.headers[*], "-cf")[2] in {"one-cf" "two-cf" "three-cf"}"#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::FunctionCallExpr(FunctionCallExpr { + function: SCHEME.get_function("concat").unwrap(), + args: vec![ + FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(field("http.headers")), + indexes: vec![FieldIndex::MapEach], + }), + FunctionCallArgExpr::Literal(RhsValue::Bytes(Bytes::from( + "-cf".to_owned() + ))), + ], + context: None, + }), + indexes: vec![FieldIndex::ArrayIndex(2)], + }, + op: ComparisonOpExpr::OneOf(RhsValues::Bytes(vec![ + "one-cf".to_owned().into(), + "two-cf".to_owned().into(), + "three-cf".to_owned().into() + ])) + } + ); + + assert_eq!( + expr.lhs.identifier.get_type(), + Type::Array(Type::Bytes.into()) + ); + assert_eq!(expr.lhs.get_type(), Type::Bytes); + assert_eq!(expr.get_type(), Type::Bool); + + assert_json!( + expr, + { + "lhs": [ + { + "name": "concat", + "args": [ + { + "kind": "IndexExpr", + "value": ["http.headers", {"kind": "MapEach"}], + }, + { + "kind": "Literal", + "value": "-cf" + } + ] + }, + {"kind": "ArrayIndex", "value": 2}, + ], + "op": "OneOf", + "rhs": ["one-cf", "two-cf", "three-cf"], + } + ); + + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + let headers = LhsValue::Map({ + let mut map = Map::new(Type::Bytes); + map.insert(b"0", "one").unwrap(); + map.insert(b"1", "two").unwrap(); + map.insert(b"2", "three").unwrap(); + map + }); + 
ctx.set_field_value(field("http.headers"), headers).unwrap(); + + assert_eq!(expr.execute_one(ctx), true); + } + + #[test] + fn test_map_each_on_array_for_cmp() { + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"http.cookies[*] == "three""#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("http.cookies")), + indexes: vec![FieldIndex::MapEach], + }, + op: ComparisonOpExpr::Ordering { + op: OrderingOp::Equal, + rhs: RhsValue::Bytes("three".to_owned().into()) + } + } + ); + + assert_json!( + expr, + { + "lhs": ["http.cookies", {"kind": "MapEach"}], + "op": "Equal", + "rhs": "three", + } + ); + + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + let cookies = Array::from_iter(["one", "two", "three"]); + ctx.set_field_value(field("http.cookies"), cookies).unwrap(); + + assert_eq!(expr.execute_vec(ctx), [false, false, true]); + } + + #[test] + fn test_map_each_on_map_for_cmp() { + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"http.headers[*] == "three""#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("http.headers")), + indexes: vec![FieldIndex::MapEach], + }, + op: ComparisonOpExpr::Ordering { + op: OrderingOp::Equal, + rhs: RhsValue::Bytes("three".to_owned().into()) + } + } + ); + + assert_eq!( + expr.lhs.identifier.get_type(), + Type::Map(Type::Bytes.into()) + ); + assert_eq!(expr.lhs.get_type(), Type::Bytes); + assert_eq!(expr.get_type(), Type::Array(Type::Bool.into())); + + assert_json!( + expr, + { + "lhs": ["http.headers", {"kind": "MapEach"}], + "op": "Equal", + "rhs": "three", + } + ); + + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + let headers = LhsValue::Map({ + let mut map = Map::new(Type::Bytes); + map.insert(b"0", "one").unwrap(); + map.insert(b"1", "two").unwrap(); + map.insert(b"2", "three").unwrap(); + map + }); + ctx.set_field_value(field("http.headers"), headers).unwrap(); + + let mut true_count = 0; + let mut false_count = 0; + for val in expr.execute_vec(ctx).iter() { + if *val { + true_count += 1; + } else { + false_count += 1; + } + } + assert_eq!(false_count, 2); + assert_eq!(true_count, 1); + } + + #[test] + fn test_map_each_on_array_full() { + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"concat(http.cookies[*], "-cf")[*] == "three-cf""#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::FunctionCallExpr(FunctionCallExpr { + function: SCHEME.get_function("concat").unwrap(), + args: vec![ + FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(field("http.cookies")), + indexes: vec![FieldIndex::MapEach], + }), + FunctionCallArgExpr::Literal(RhsValue::Bytes(Bytes::from( + "-cf".to_owned() + ))), + ], + context: None, + }), + indexes: vec![FieldIndex::MapEach], + }, + op: ComparisonOpExpr::Ordering { + op: OrderingOp::Equal, + rhs: RhsValue::Bytes("three-cf".to_owned().into()) + } + } + ); + + assert_eq!( + expr.lhs.identifier.get_type(), + Type::Array(Type::Bytes.into()) + ); + assert_eq!(expr.lhs.get_type(), Type::Bytes); + assert_eq!(expr.get_type(), Type::Array(Type::Bool.into())); + + assert_json!( + expr, + { + "lhs": [ + { + "name": "concat", + "args": [ + { + "kind": "IndexExpr", + "value": ["http.cookies", {"kind": "MapEach"}], + }, + { + "kind": "Literal", + "value": "-cf" + } + ] + }, + {"kind": "MapEach"}, + ], + "op": "Equal", + "rhs": "three-cf" + } + ); + + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + 
let cookies = Array::from_iter(["one", "two", "three"]); + ctx.set_field_value(field("http.cookies"), cookies).unwrap(); + + assert_eq!(expr.execute_vec(ctx), [false, false, true]); + } + + #[test] + fn test_map_each_on_array_len_function() { + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"len(http.cookies[*])[*] > 3"#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::FunctionCallExpr(FunctionCallExpr { + function: SCHEME.get_function("len").unwrap(), + args: vec![FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(field("http.cookies")), + indexes: vec![FieldIndex::MapEach], + }),], + context: None, + }), + indexes: vec![FieldIndex::MapEach], + }, + op: ComparisonOpExpr::Ordering { + op: OrderingOp::GreaterThan, + rhs: RhsValue::Int(3), + } + } + ); + + assert_eq!( + expr.lhs.identifier.get_type(), + Type::Array(Type::Int.into()) + ); + assert_eq!(expr.lhs.get_type(), Type::Int); + assert_eq!(expr.get_type(), Type::Array(Type::Bool.into())); + + assert_json!( + expr, + { + "lhs": [ + { + "name": "len", + "args": [ + { + "kind": "IndexExpr", + "value": ["http.cookies", {"kind": "MapEach"}], + } + ] + }, + {"kind": "MapEach"}, + ], + "op": "GreaterThan", + "rhs": 3 + } + ); + + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + let cookies = Array::from_iter(["one", "two", "three"]); + ctx.set_field_value(field("http.cookies"), cookies).unwrap(); + + assert_eq!(expr.execute_vec(ctx), [false, false, true]); + } + + #[test] + fn test_map_each_error() { + assert_err!( + FilterParser::new(&SCHEME).lex_as::>(r#"http.host[*] == "three""#), + LexErrorKind::InvalidIndexAccess(IndexAccessError { + index: FieldIndex::MapEach, + actual: Type::Bytes, + }), + "[*]" + ); + + assert_err!( + FilterParser::new(&SCHEME).lex_as::>(r#"ip.addr[*] == 127.0.0.1"#), + LexErrorKind::InvalidIndexAccess(IndexAccessError { + index: FieldIndex::MapEach, + actual: Type::Ip, + }), + "[*]" + ); + + assert_err!( + FilterParser::new(&SCHEME).lex_as::>(r#"ssl[*]"#), + LexErrorKind::InvalidIndexAccess(IndexAccessError { + index: FieldIndex::MapEach, + actual: Type::Bool, + }), + "[*]" + ); + + assert_err!( + FilterParser::new(&SCHEME).lex_as::>(r#"tcp.port[*] == 80"#), + LexErrorKind::InvalidIndexAccess(IndexAccessError { + index: FieldIndex::MapEach, + actual: Type::Int, + }), + "[*]" + ); + } + + #[derive(Debug, PartialEq, Eq, Serialize, Clone)] + pub struct NumMatcher {} + + impl ListMatcher for NumMatcher { + fn match_value(&self, list_name: &str, val: &LhsValue<'_>) -> bool { + // Ideally this would lookup list_name in metadata + let list_id = if list_name == "even" { + [0u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + } else { + [1u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + }; + + match val { + LhsValue::Int(num) => self.num_matches(*num, list_id), + _ => unreachable!(), // TODO: is this unreachable? + } + } + + fn to_json_value(&self) -> serde_json::Value { + serde_json::Value::Null + } + + fn clear(&mut self) {} + } + + /// Match IPs (v4 and v6) in lists. 
+ /// + /// ```text + /// ip.src in $whitelist and not origin.ip in $whitelist + /// ``` + impl NumMatcher { + fn num_matches(&self, num: i64, list_id: [u8; 16]) -> bool { + let remainder = + i64::from(list_id == [1u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); + num % 2 == remainder + } + } + + #[test] + fn test_number_in_list() { + let list = SCHEME.get_list(&Type::Int).unwrap(); + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"tcp.port in $even"#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("tcp.port")), + indexes: vec![], + }, + op: ComparisonOpExpr::InList { + list, + name: ListName::from("even".to_string()) + } + } + ); + + assert_json!( + expr, + { + "lhs": "tcp.port", + "op": "InList", + "rhs": "even" + } + ); + + // EVEN list + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + ctx.set_field_value(field("tcp.port"), 1000).unwrap(); + assert_eq!(expr.execute_one(ctx), true); + + ctx.set_field_value(field("tcp.port"), 1001).unwrap(); + assert_eq!(expr.execute_one(ctx), false); + + // ODD list + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"tcp.port in $odd"#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("tcp.port")), + indexes: vec![], + }, + op: ComparisonOpExpr::InList { + list, + name: ListName::from("odd".to_string()), + } + } + ); + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + ctx.set_field_value(field("tcp.port"), 1000).unwrap(); + assert_eq!(expr.execute_one(ctx), false); + + ctx.set_field_value(field("tcp.port"), 1001).unwrap(); + assert_eq!(expr.execute_one(ctx), true); + + let json = serde_json::to_string(ctx).unwrap(); + assert_eq!(json, "{\"tcp.port\":1001,\"$lists\":[]}"); + } + + #[test] + fn test_any_number_in_list() { + let list = SCHEME.get_list(&Type::Int).unwrap(); + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"any(tcp.ports[*] in $even)"#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::FunctionCallExpr(FunctionCallExpr { + function: SCHEME.get_function("any").unwrap(), + args: vec![FunctionCallArgExpr::Logical(LogicalExpr::Comparison( + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("tcp.ports")), + indexes: vec![FieldIndex::MapEach], + }, + op: ComparisonOpExpr::InList { + list, + name: ListName::from("even".to_string()), + }, + } + ))], + context: None, + }), + indexes: vec![], + }, + op: ComparisonOpExpr::IsTrue + } + ); + + assert_eq!(expr.lhs.identifier.get_type(), Type::Bool); + assert_eq!(expr.lhs.get_type(), Type::Bool); + assert_eq!(expr.get_type(), Type::Bool); + + assert_json!( + expr, + { + "lhs": { + "name": "any", + "args": [ + { + "kind": "SimpleExpr", + "value": { + "lhs": ["tcp.ports", {"kind": "MapEach"}], + "op": "InList", + "rhs": "even" + } + } + ] + }, + "op": "IsTrue" + } + ); + + // EVEN list + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + // 1 odd, 1 even + let arr1 = Array::from_iter([1001, 1000]); + + ctx.set_field_value(field("tcp.ports"), arr1).unwrap(); + assert_eq!(expr.execute_one(ctx), true); + + // All odd numbers + let arr2 = Array::from_iter([1001, 1003]); + + ctx.set_field_value(field("tcp.ports"), arr2).unwrap(); + assert_eq!(expr.execute_one(ctx), false); + } + + #[test] + fn test_map_each_nested() { + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"http.parts[*][*] == "[5][5]""#), + ComparisonExpr { + lhs: IndexExpr { + identifier: 
IdentifierExpr::Field(field("http.parts")), + indexes: vec![FieldIndex::MapEach, FieldIndex::MapEach], + }, + op: ComparisonOpExpr::Ordering { + op: OrderingOp::Equal, + rhs: RhsValue::Bytes("[5][5]".to_owned().into()) + } + } + ); + + assert_eq!(expr.get_type(), Type::Array(Type::Bool.into())); + + assert_json!( + expr, + { + "lhs": ["http.parts", {"kind": "MapEach"}, {"kind": "MapEach"}], + "op": "Equal", + "rhs": "[5][5]", + } + ); + + let expr1 = expr.compile(); + + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"http.parts[5][*] == "[5][5]""#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("http.parts")), + indexes: vec![FieldIndex::ArrayIndex(5), FieldIndex::MapEach], + }, + op: ComparisonOpExpr::Ordering { + op: OrderingOp::Equal, + rhs: RhsValue::Bytes("[5][5]".to_owned().into()) + } + } + ); + + assert_eq!(expr.get_type(), Type::Array(Type::Bool.into())); + + assert_json!( + expr, + { + "lhs": ["http.parts", {"kind": "ArrayIndex", "value": 5}, {"kind": "MapEach"}], + "op": "Equal", + "rhs": "[5][5]", + } + ); + + let expr2 = expr.compile(); + + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"http.parts[*][5] == "[5][5]""#), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("http.parts")), + indexes: vec![FieldIndex::MapEach, FieldIndex::ArrayIndex(5)], + }, + op: ComparisonOpExpr::Ordering { + op: OrderingOp::Equal, + rhs: RhsValue::Bytes("[5][5]".to_owned().into()) + } + } + ); + + assert_eq!(expr.get_type(), Type::Array(Type::Bool.into())); + + assert_json!( + expr, + { + "lhs": ["http.parts", {"kind": "MapEach"}, {"kind": "ArrayIndex", "value": 5}], + "op": "Equal", + "rhs": "[5][5]", + } + ); + + let expr3 = expr.compile(); + + let ctx = &mut ExecutionContext::new(&SCHEME); + + let parts = Array::try_from_iter( + Type::Array(Type::Bytes.into()), + (0..10).map(|i| Array::from_iter((0..10).map(|j| format!("[{i}][{j}]")))), + ) + .unwrap(); + + ctx.set_field_value(field("http.parts"), parts).unwrap(); + + let mut true_count = 0; + let mut false_count = 0; + for val in expr1.execute_vec(ctx).iter() { + if *val { + true_count += 1; + } else { + false_count += 1; + } + } + assert_eq!(false_count, 99); + assert_eq!(true_count, 1); + + let mut true_count = 0; + let mut false_count = 0; + for val in expr2.execute_vec(ctx).iter() { + if *val { + true_count += 1; + } else { + false_count += 1; + } + } + assert_eq!(false_count, 9); + assert_eq!(true_count, 1); + + let mut true_count = 0; + let mut false_count = 0; + for val in expr3.execute_vec(ctx).iter() { + if *val { + true_count += 1; + } else { + false_count += 1; + } + } + assert_eq!(false_count, 9); + assert_eq!(true_count, 1); + } + + #[test] + fn test_raw_string() { + // Equal operator + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as("http.host == r###\"ab\"###"), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("http.host")), + indexes: vec![], + }, + op: ComparisonOpExpr::Ordering { + op: OrderingOp::Equal, + rhs: RhsValue::Bytes(Bytes::new("ab".as_bytes(), BytesFormat::Raw(3))), + }, + } + ); + + assert_json!( + expr, + { + "lhs": "http.host", + "op": "Equal", + "rhs": "ab", + } + ); + + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + ctx.set_field_value(field("http.host"), "ab").unwrap(); + assert_eq!(expr.execute_one(ctx), true); + + ctx.set_field_value(field("http.host"), "cd").unwrap(); + assert_eq!(expr.execute_one(ctx), false); + + // Matches operator + let 
parser = FilterParser::new(&SCHEME); + let r = Regex::new("a.b", RegexFormat::Literal, &parser).unwrap(); + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as("http.host matches r###\"a.b\"###"), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("http.host")), + indexes: vec![], + }, + op: ComparisonOpExpr::Matches(r), + } + ); + + assert_json!( + expr, + { + "lhs": "http.host", + "op": "Matches", + "rhs": "a.b", + } + ); + + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + ctx.set_field_value(field("http.host"), "axb").unwrap(); + assert_eq!(expr.execute_one(ctx), true); + + ctx.set_field_value(field("http.host"), "axc").unwrap(); + assert_eq!(expr.execute_one(ctx), false); + + // Wildcard operator + let wildcard = Wildcard::new( + Bytes::new(r"foo*\*\\".as_bytes(), BytesFormat::Raw(2)), + usize::MAX, + ) + .unwrap(); + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#####"http.host wildcard r##"foo*\*\\"##"#####), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("http.host")), + indexes: vec![], + }, + op: ComparisonOpExpr::Wildcard(wildcard), + } + ); + + assert_json!( + expr, + { + "lhs": "http.host", + "op": "Wildcard", + "rhs": r"foo*\*\\", + } + ); + + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + ctx.set_field_value(field("http.host"), r"foo*\").unwrap(); + assert!(expr.execute_one(ctx)); + ctx.set_field_value(field("http.host"), r"foo*").unwrap(); + assert!(!expr.execute_one(ctx)); + ctx.set_field_value(field("http.host"), r"foo\").unwrap(); + assert!(!expr.execute_one(ctx)); + ctx.set_field_value(field("http.host"), r"Foo*\").unwrap(); + assert!(expr.execute_one(ctx)); + ctx.set_field_value(field("http.host"), r"foobarmumble*\") + .unwrap(); + assert!(expr.execute_one(ctx)); + ctx.set_field_value(field("http.host"), r"foobarmumble\") + .unwrap(); + assert!(!expr.execute_one(ctx)); + + // Strict wildcard operator + let wildcard = Wildcard::new( + Bytes::new(r"foo*\*\\".as_bytes(), BytesFormat::Raw(2)), + usize::MAX, + ) + .unwrap(); + let expr = assert_ok!( + FilterParser::new(&SCHEME) + .lex_as(r#####"http.host strict wildcard r##"foo*\*\\"##"#####), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(field("http.host")), + indexes: vec![], + }, + op: ComparisonOpExpr::StrictWildcard(wildcard), + } + ); + + assert_json!( + expr, + { + "lhs": "http.host", + "op": "Strict Wildcard", + "rhs": r"foo*\*\\", + } + ); + + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + ctx.set_field_value(field("http.host"), r"foo*\").unwrap(); + assert!(expr.execute_one(ctx)); + ctx.set_field_value(field("http.host"), r"foo*").unwrap(); + assert!(!expr.execute_one(ctx)); + ctx.set_field_value(field("http.host"), r"foo\").unwrap(); + assert!(!expr.execute_one(ctx)); + ctx.set_field_value(field("http.host"), r"Foo*\").unwrap(); + assert!(!expr.execute_one(ctx)); + ctx.set_field_value(field("http.host"), r"foobarmumble*\") + .unwrap(); + assert!(expr.execute_one(ctx)); + ctx.set_field_value(field("http.host"), r"foobarmumble\") + .unwrap(); + assert!(!expr.execute_one(ctx)); + + // Function call + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as("concat(http.host, r#\"cd\"#) == r##\"abcd\"##"), + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::FunctionCallExpr(FunctionCallExpr { + function: SCHEME.get_function("concat").unwrap(), + args: vec![ + 
FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(field("http.host")), + indexes: vec![], + }), + FunctionCallArgExpr::Literal(RhsValue::Bytes(Bytes::new( + "cd".as_bytes(), + BytesFormat::Raw(1) + ))) + ], + context: None, + }), + indexes: vec![], + }, + op: ComparisonOpExpr::Ordering { + op: OrderingOp::Equal, + rhs: RhsValue::Bytes(Bytes::new("abcd".as_bytes(), BytesFormat::Raw(2))) + } + } + ); + + assert_eq!(expr.lhs.identifier.get_type(), Type::Bytes); + assert_eq!(expr.lhs.get_type(), Type::Bytes); + assert_eq!(expr.get_type(), Type::Bool); + + assert_json!( + expr, + { + "lhs": { + "name": "concat", + "args": [ + { + "kind": "IndexExpr", + "value": "http.host" + }, + { + "kind": "Literal", + "value": "cd" + } + ] + }, + "op": "Equal", + "rhs": "abcd" + } + ); + + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + ctx.set_field_value(field("http.host"), "xx").unwrap(); + assert_eq!(expr.execute_one(ctx), false); + + ctx.set_field_value(field("http.host"), "ab").unwrap(); + assert_eq!(expr.execute_one(ctx), true); + } + + #[test] + fn expression_evaluation_wildcard() { + let op_normal = |rhs, value, expected| -> Vec<(&str, &str, &[u8], bool)> { + vec![("wildcard", rhs, value, expected)] + }; + let op_strict = |rhs, value, expected| -> Vec<(&str, &str, &[u8], bool)> { + vec![("strict wildcard", rhs, value, expected)] + }; + let op_both = |rhs, value, expected| -> Vec<(&str, &str, &[u8], bool)> { + op_normal(rhs, value, expected) + .into_iter() + .chain(op_strict(rhs, value, expected)) + .collect() + }; + + let testcases = [ + // Escaping at the wildcard-level with raw strings. + op_both(r##"r"foo?*\*\\""##, r"foo?*\".as_bytes(), true), + op_both(r##"r#"foo?*\*\\"#"##, r"foo?*\".as_bytes(), true), + op_both(r##"r"foo?*\*\\""##, r"foo?bar*\".as_bytes(), true), + op_both(r##"r"foo?*\*\\""##, r"foo?x\".as_bytes(), false), + op_both(r##"r"¯\\_(ツ*)_/¯""##, r"¯\_(ツ)_/¯".as_bytes(), true), + // Escaping at the wildcard-level with quoted strings. + op_both(r#""foo?*\\*\\\\""#, r"foo?*\".as_bytes(), true), + op_both(r#""foo?*\\*\\\\""#, r"foo?*\".as_bytes(), true), + op_both(r#""foo?*\\*\\\\""#, r"foo?bar*\".as_bytes(), true), + op_both(r#""foo?*\\*\\\\""#, r"foo?x\".as_bytes(), false), + op_both(r#""fo\x6f""#, r"foo".as_bytes(), true), + op_both(r#""¯\\\\_(ツ*)_/¯""#, r"¯\_(ツ)_/¯".as_bytes(), true), + op_both(r#""\xaa\x22""#, &[0xaa, 0x22], true), + // ? is not special. + op_both(r##""?""##, r"?".as_bytes(), true), + op_both(r##""?""##, r"x".as_bytes(), false), + // Case sensitivity. 
+ op_normal(r##""a""##, r"A".as_bytes(), true), + op_strict(r##""a""##, r"A".as_bytes(), false), + ] + .concat(); + + for t @ (op, rhs, value, expected) in testcases { + let expr: ComparisonExpr<'_> = FilterParser::new(&SCHEME) + .lex_as(&format!("http.host {op} {rhs}")) + .map(|(e, _)| e) + .expect("failed to parse expression"); + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + ctx.set_field_value(field("http.host"), value).unwrap(); + + assert_eq!(expr.execute_one(ctx), expected, "failed test case {t:?}"); + } } } diff --git a/engine/src/ast/function_expr.rs b/engine/src/ast/function_expr.rs index 581c40ff..26b32bab 100644 --- a/engine/src/ast/function_expr.rs +++ b/engine/src/ast/function_expr.rs @@ -1,365 +1,1434 @@ -use super::field_expr::LhsFieldExpr; +use super::{ + parse::FilterParser, + visitor::{Visitor, VisitorMut}, + ValueExpr, +}; use crate::{ - execution_context::ExecutionContext, - functions::{Function, FunctionArgKind, FunctionParam}, - lex::{expect, skip_space, span, take, take_while, LexError, LexErrorKind, LexResult, LexWith}, - scheme::{Field, Scheme}, - types::{GetType, LhsValue, RhsValue, TypeMismatchError}, + ast::{ + field_expr::{ComparisonExpr, ComparisonOp, ComparisonOpExpr}, + index_expr::IndexExpr, + logical_expr::{LogicalExpr, UnaryOp}, + }, + compiler::Compiler, + filter::{CompiledExpr, CompiledValueExpr, CompiledValueResult}, + functions::{ + ExactSizeChain, FunctionArgs, FunctionDefinition, FunctionDefinitionContext, FunctionParam, + FunctionParamError, + }, + lex::{expect, skip_space, span, Lex, LexError, LexErrorKind, LexResult, LexWith}, + lhs_types::Array, + scheme::Function, + types::{GetType, LhsValue, RhsValue, Type}, }; use serde::Serialize; +use std::hash::{Hash, Hasher}; +use std::iter::once; -#[derive(Debug, PartialEq, Eq, Clone, Serialize)] +/// FunctionCallArgExpr is a function argument. It can be a sub-expression with +/// [`LogicalExpr`], a field with [`IndexExpr`] or a literal with [`Literal`]. +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize)] #[serde(tag = "kind", content = "value")] -pub(crate) enum FunctionCallArgExpr<'s> { - LhsFieldExpr(LhsFieldExpr<'s>), +pub enum FunctionCallArgExpr<'s> { + /// IndexExpr is a field that supports the indexing operator. + IndexExpr(IndexExpr<'s>), + /// A Literal. Literal(RhsValue), + /// LogicalExpr is a sub-expression which can evaluate to either true/false + /// or a list of true/false. It compiles to a CompiledExpr and is coerced + /// into a CompiledValueExpr. + // Renaming is necessary for backward compability. 
+ #[serde(rename = "SimpleExpr")] + Logical(LogicalExpr<'s>), } -impl<'s> FunctionCallArgExpr<'s> { - pub fn uses(&self, field: Field<'s>) -> bool { +impl<'s> ValueExpr<'s> for FunctionCallArgExpr<'s> { + #[inline] + fn walk<'a, V: Visitor<'s, 'a>>(&'a self, visitor: &mut V) { match self { - FunctionCallArgExpr::LhsFieldExpr(lhs) => lhs.uses(field), - FunctionCallArgExpr::Literal(_) => false, + FunctionCallArgExpr::IndexExpr(index_expr) => visitor.visit_index_expr(index_expr), + FunctionCallArgExpr::Literal(_) => {} + FunctionCallArgExpr::Logical(logical_expr) => visitor.visit_logical_expr(logical_expr), } } - pub fn execute(&'s self, ctx: &'s ExecutionContext<'s>) -> LhsValue<'s> { + #[inline] + fn walk_mut<'a, V: VisitorMut<'s, 'a>>(&'a mut self, visitor: &mut V) { match self { - FunctionCallArgExpr::LhsFieldExpr(lhs) => match lhs { - LhsFieldExpr::Field(field) => ctx.get_field_value_unchecked(*field), - LhsFieldExpr::FunctionCallExpr(call) => call.execute(ctx), - }, - FunctionCallArgExpr::Literal(literal) => literal.into(), + FunctionCallArgExpr::IndexExpr(index_expr) => visitor.visit_index_expr(index_expr), + FunctionCallArgExpr::Literal(_) => {} + FunctionCallArgExpr::Logical(logical_expr) => visitor.visit_logical_expr(logical_expr), + } + } + + fn compile_with_compiler + 's>( + self, + compiler: &mut C, + ) -> CompiledValueExpr<'s, C::U> { + match self { + FunctionCallArgExpr::IndexExpr(index_expr) => compiler.compile_index_expr(index_expr), + FunctionCallArgExpr::Literal(literal) => { + CompiledValueExpr::new(move |_| LhsValue::from(literal.clone()).into()) + } + // The function argument is an expression compiled as either an + // CompiledExpr::One or CompiledExpr::Vec. + // Here we execute the expression to get the actual argument + // for the function and forward the result in a CompiledValueExpr. 
+ FunctionCallArgExpr::Logical(logical_expr) => { + let compiled_expr = compiler.compile_logical_expr(logical_expr); + match compiled_expr { + CompiledExpr::One(expr) => { + CompiledValueExpr::new(move |ctx| LhsValue::from(expr.execute(ctx)).into()) + } + CompiledExpr::Vec(expr) => CompiledValueExpr::new(move |ctx| { + let result = expr.execute(ctx); + LhsValue::Array(result.into()).into() + }), + } + } } } } -struct SchemeFunctionParam<'s, 'a> { - scheme: &'s Scheme, - param: &'a FunctionParam, - index: usize, +impl<'s> FunctionCallArgExpr<'s> { + pub(crate) fn map_each_count(&self) -> usize { + match self { + FunctionCallArgExpr::IndexExpr(index_expr) => index_expr.map_each_count(), + FunctionCallArgExpr::Literal(_) => 0, + FunctionCallArgExpr::Logical(_) => 0, + } + } + + #[allow(dead_code)] + pub(crate) fn simplify(self) -> Self { + match self { + FunctionCallArgExpr::Logical(LogicalExpr::Comparison(ComparisonExpr { + lhs, + op: ComparisonOpExpr::IsTrue, + })) => FunctionCallArgExpr::IndexExpr(lhs), + _ => self, + } + } } -impl<'i, 's, 'a> LexWith<'i, SchemeFunctionParam<'s, 'a>> for FunctionCallArgExpr<'s> { - fn lex_with(input: &'i str, ctx: SchemeFunctionParam<'s, 'a>) -> LexResult<'i, Self> { - let initial_input = input; +impl<'i, 's> LexWith<'i, &FilterParser<'s>> for FunctionCallArgExpr<'s> { + fn lex_with(input: &'i str, parser: &FilterParser<'s>) -> LexResult<'i, Self> { + let _initial_input = input; - match ctx.param.arg_kind { - FunctionArgKind::Field => { - let (lhs, input) = LhsFieldExpr::lex_with(input, ctx.scheme)?; - if lhs.get_type() != ctx.param.val_type { - Err(( - LexErrorKind::InvalidArgumentType { - index: ctx.index, - mismatch: TypeMismatchError { - actual: lhs.get_type(), - expected: ctx.param.val_type, - }, - }, - span(initial_input, input), - )) + macro_rules! c_is_field { + // characters above F/f in the alphabet mean it can't be a decimal or hex int + ($c:expr) => { + (($c.is_ascii_alphanumeric() && !$c.is_ascii_hexdigit()) || $c == '_') + }; + } + + macro_rules! 
c_is_field_or_int { + ($c:expr) => { + ($c.is_ascii_alphanumeric() || $c == '_') + }; + } + + // Grammar is ambiguous but lets try to parse the tokens we can be sure of + // This will provide better error reporting in most cases + let mut chars = input.chars(); + if let Some(c) = chars.next() { + // check up to 3 next chars because third char of a hex-string is either ':' + // or '-' + let c2 = chars.next(); + let c3 = chars.next(); + if c == '"' || (c == 'r' && (c2 == Some('#') || c2 == Some('"'))) { + return RhsValue::lex_with(input, Type::Bytes) + .map(|(literal, input)| (FunctionCallArgExpr::Literal(literal), input)); + } else if c == '(' || UnaryOp::lex(input).is_ok() { + return LogicalExpr::lex_with(input, parser) + .map(|(lhs, input)| (FunctionCallArgExpr::Logical(lhs), input)); + } else if c_is_field!(c) + || (c_is_field_or_int!(c) && c2.is_some() && c_is_field!(c2.unwrap())) + || (c_is_field_or_int!(c) + && c2.is_some() + && c_is_field_or_int!(c2.unwrap()) + && c3.is_some() + && c_is_field!(c3.unwrap())) + { + let (lhs, input) = IndexExpr::lex_with(input, parser)?; + let lookahead = skip_space(input); + if ComparisonOp::lex(lookahead).is_ok() { + return ComparisonExpr::lex_with_lhs(input, parser, lhs).map(|(op, input)| { + ( + FunctionCallArgExpr::Logical(LogicalExpr::Comparison(op)), + input, + ) + }); } else { - Ok((FunctionCallArgExpr::LhsFieldExpr(lhs), input)) + return Ok((FunctionCallArgExpr::IndexExpr(lhs), input)); } } - FunctionArgKind::Literal => { - let (rhs_value, input) = RhsValue::lex_with(input, ctx.param.val_type)?; - Ok((FunctionCallArgExpr::Literal(rhs_value), input)) + } + + // Fallback to blind parsing next argument + if let Ok((lhs, input)) = IndexExpr::lex_with(input, parser) { + let lookahead = skip_space(input); + if ComparisonOp::lex(lookahead).is_ok() { + return ComparisonExpr::lex_with_lhs(input, parser, lhs).map(|(op, input)| { + ( + FunctionCallArgExpr::Logical(LogicalExpr::Comparison(op)), + input, + ) + }); + } else { + return Ok((FunctionCallArgExpr::IndexExpr(lhs), input)); } } - } -} -#[derive(Debug, PartialEq, Eq, Clone, Serialize)] -pub(crate) struct FunctionCallExpr<'s> { - pub name: String, - #[serde(skip)] - pub function: &'s Function, - pub args: Vec>, + RhsValue::lex_with(input, Type::Ip) + .map(|(literal, input)| (FunctionCallArgExpr::Literal(literal), input)) + .or_else(|_| { + RhsValue::lex_with(input, Type::Int) + .map(|(literal, input)| (FunctionCallArgExpr::Literal(literal), input)) + }) + // try to parse Bytes after Int because digit literals < 255 are wrongly + // interpreted as Bytes + .or_else(|_| { + RhsValue::lex_with(input, Type::Bytes) + .map(|(literal, input)| (FunctionCallArgExpr::Literal(literal), input)) + }) + .map_err(|_| (LexErrorKind::EOF, _initial_input)) + } } -impl<'s> FunctionCallExpr<'s> { - pub fn new(name: &str, function: &'s Function) -> Self { - Self { - name: name.into(), - function, - args: Vec::default(), +impl<'s> GetType for FunctionCallArgExpr<'s> { + fn get_type(&self) -> Type { + match self { + FunctionCallArgExpr::IndexExpr(index_expr) => index_expr.get_type(), + FunctionCallArgExpr::Literal(literal) => literal.get_type(), + FunctionCallArgExpr::Logical(logical_expr) => logical_expr.get_type(), } } +} - pub fn uses(&self, field: Field<'s>) -> bool { - self.args.iter().any(|arg| arg.uses(field)) +impl<'a, 's> From<&'a FunctionCallArgExpr<'s>> for FunctionParam<'a> { + fn from(arg_expr: &'a FunctionCallArgExpr<'s>) -> Self { + match arg_expr { + FunctionCallArgExpr::IndexExpr(expr) => 
FunctionParam::Variable(expr.get_type()),
+            FunctionCallArgExpr::Logical(expr) => FunctionParam::Variable(expr.get_type()),
+            FunctionCallArgExpr::Literal(value) => FunctionParam::Constant(value),
+        }
     }
+}
 
-    pub fn execute(&self, ctx: &'s ExecutionContext<'s>) -> LhsValue<'_> {
-        self.function.implementation.execute(
-            self.args.iter().map(|arg| arg.execute(ctx)).chain(
-                self.function.opt_params[self.args.len() - self.function.params.len()..]
-                    .iter()
-                    .map(|opt_arg| opt_arg.default_value.as_ref()),
-            ),
-        )
-    }
+/// FunctionCallExpr represents a function call expression.
+#[derive(Clone, Debug, Serialize)]
+pub struct FunctionCallExpr<'s> {
+    #[serde(rename = "name")]
+    pub(crate) function: Function<'s>,
+    pub(crate) args: Vec<FunctionCallArgExpr<'s>>,
+    #[serde(skip)]
+    pub(crate) context: Option<FunctionDefinitionContext>,
 }
 
-fn invalid_args_count<'i>(function: &Function, input: &'i str) -> LexError<'i> {
-    (
-        LexErrorKind::InvalidArgumentsCount {
-            expected_min: function.params.len(),
-            expected_max: function.params.len() + function.opt_params.len(),
-        },
-        input,
-    )
+impl<'s> PartialEq for FunctionCallExpr<'s> {
+    #[inline]
+    fn eq(&self, other: &Self) -> bool {
+        self.function == other.function && self.args == other.args
+    }
 }
 
-impl<'i, 's> LexWith<'i, &'s Scheme> for FunctionCallExpr<'s> {
-    fn lex_with(input: &'i str, scheme: &'s Scheme) -> LexResult<'i, Self> {
-        let initial_input = input;
+impl<'s> Eq for FunctionCallExpr<'s> {}
 
-        let (name, mut input) = take_while(input, "function character", |c| {
-            c.is_ascii_alphanumeric() || c == '_'
-        })?;
+impl<'s> Hash for FunctionCallExpr<'s> {
+    #[inline]
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.function.hash(state);
+        self.args.hash(state);
+    }
+}
 
-        input = skip_space(input);
+impl<'s> ValueExpr<'s> for FunctionCallExpr<'s> {
+    #[inline]
+    fn walk<'a, V: Visitor<'s, 'a>>(&'a self, visitor: &mut V) {
+        self.args
+            .iter()
+            .for_each(|arg| visitor.visit_function_call_arg_expr(arg));
+        visitor.visit_function(&self.function)
+    }
 
-        input = expect(input, "(")?;
+    #[inline]
+    fn walk_mut<'a, V: VisitorMut<'s, 'a>>(&'a mut self, visitor: &mut V) {
+        self.args
+            .iter_mut()
+            .for_each(|arg| visitor.visit_function_call_arg_expr(arg));
+        visitor.visit_function(&self.function)
+    }
 
-        input = skip_space(input);
+    fn compile_with_compiler<C: Compiler<'s> + 's>(
+        self,
+        compiler: &mut C,
+    ) -> CompiledValueExpr<'s, C::U> {
+        let return_type = self.return_type();
 
-        let function = scheme
-            .get_function(name)
-            .map_err(|err| (LexErrorKind::UnknownFunction(err), initial_input))?;
+        let Self {
+            function,
+            args,
+            context,
+            ..
+        } = self;
+        let map_each_count = args.first().map_or(0, |arg| arg.map_each_count());
+        let call = function
+            .as_definition()
+            .compile(&mut args.iter().map(|arg| arg.into()), context);
+        let mut args = args
+            .into_iter()
+            .map(|arg| compiler.compile_function_call_arg_expr(arg))
+            .collect::<Vec<_>>();
 
-        let mut function_call = FunctionCallExpr::new(name, function);
+        if map_each_count > 0 {
+            let first = args.remove(0);
 
-        for i in 0..function.params.len() {
-            if i == 0 {
-                if take(input, 1)?.0 == ")" {
-                    break;
+            #[inline(always)]
+            fn compute<'s, 'a, I: ExactSizeIterator<Item = CompiledValueResult<'a>>>(
+                first: CompiledValueResult<'a>,
+                call: &(dyn for<'b> Fn(FunctionArgs<'_, 'b>) -> Option<LhsValue<'b>>
+                      + Sync
+                      + Send
+                      + 's),
+                return_type: Type,
+                f: impl Fn(LhsValue<'a>) -> I,
+            ) -> CompiledValueResult<'a> {
+                let mut first = match first {
+                    Ok(first) => first,
+                    Err(_) => {
+                        return Err(Type::Array(return_type.into()));
+                    }
+                };
+                // Extract the values of the map
+                if let LhsValue::Map(map) = first {
+                    first = LhsValue::Array(
+                        Array::try_from_iter(map.value_type(), map.values_into_iter()).unwrap(),
+                    );
+                }
+                // Retrieve the underlying `Array`
+                let mut first = match first {
+                    LhsValue::Array(arr) => arr,
+                    _ => unreachable!(),
+                };
+                if !first.is_empty() {
+                    first = first.filter_map_to(return_type, |elem| call(&mut f(elem)));
                 }
+                Ok(LhsValue::Array(first))
+            }
+
+            if args.is_empty() {
+                CompiledValueExpr::new(move |ctx| {
+                    compute(
+                        first.execute(ctx),
+                        &call,
+                        return_type,
+                        #[inline]
+                        |elem| once(Ok(elem)),
+                    )
+                })
             } else {
-                input = expect(input, ",")
-                    .map_err(|(_, input)| invalid_args_count(&function, input))?;
+                CompiledValueExpr::new(move |ctx| {
+                    compute(
+                        first.execute(ctx),
+                        &call,
+                        return_type,
+                        #[inline]
+                        |elem| {
+                            ExactSizeChain::new(
+                                once(Ok(elem)),
+                                args.iter().map(|arg| arg.execute(ctx)),
+                            )
+                        },
+                    )
+                })
             }
+        } else {
+            CompiledValueExpr::new(move |ctx| {
+                if let Some(value) = call(&mut args.iter().map(|arg| arg.execute(ctx))) {
+                    debug_assert!(value.get_type() == return_type);
+                    Ok(value)
+                } else {
+                    Err(return_type)
+                }
+            })
+        }
+    }
+}
 
-            input = skip_space(input);
+impl<'s> FunctionCallExpr<'s> {
+    pub(crate) fn new(
+        function: Function<'s>,
+        args: Vec<FunctionCallArgExpr<'s>>,
+        context: Option<FunctionDefinitionContext>,
+    ) -> Self {
+        Self {
+            function,
+            args,
+            context,
+        }
+    }
 
-            let arg = FunctionCallArgExpr::lex_with(
-                input,
-                SchemeFunctionParam {
-                    scheme,
-                    param: &function.params[i],
-                    index: i,
-                },
-            )?;
+    pub(crate) fn lex_with_function<'i>(
+        input: &'i str,
+        parser: &FilterParser<'s>,
+        function: Function<'s>,
+    ) -> LexResult<'i, Self> {
+        let definition = function.as_definition();
 
-            function_call.args.push(arg.0);
+        let mut input = skip_space(input);
 
-            input = skip_space(arg.1);
-        }
+        input = expect(input, "(")?;
 
-        if function_call.args.len() != function.params.len() {
-            return Err(invalid_args_count(&function, input));
-        }
+        input = skip_space(input);
+
+        let (mandatory_arg_count, optional_arg_count) = definition.arg_count();
+
+        let mut args: Vec<FunctionCallArgExpr<'s>> = Vec::new();
 
         let mut index = 0;
 
+        let mut ctx = definition.context();
+
         while let Some(c) = input.chars().next() {
             if c == ')' {
                 break;
             }
 
-            // ',' is expected only if the current optional argument
+            // ',' is expected only if the current argument
             // is not the first one in the list of specified arguments.
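+            // Each parsed argument is then validated against the function
+            // definition via `check_param` below, so kind/type mismatches are
+            // reported with the span of the offending argument.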
- if !function_call.args.is_empty() { + if index != 0 { input = expect(input, ",")?; } input = skip_space(input); - let opt_param = function - .opt_params - .get(index) - .ok_or_else(|| invalid_args_count(&function, input))?; + let (arg, rest) = FunctionCallArgExpr::lex_with(input, parser)?; - let param = FunctionParam { - arg_kind: opt_param.arg_kind.clone(), - val_type: opt_param.default_value.get_type(), - }; + // Mapping is only accepted for the first argument + // of a function call for code simplicity + if arg.map_each_count() > 0 && index != 0 { + return Err((LexErrorKind::InvalidMapEachAccess, span(input, rest))); + } - let (arg, rest) = FunctionCallArgExpr::lex_with( - input, - SchemeFunctionParam { - scheme, - param: ¶m, - index: function.params.len() + index, - }, - )?; + let next_param = (&arg).into(); + + if optional_arg_count.is_some() + && index >= (mandatory_arg_count + optional_arg_count.unwrap()) + { + return Err(invalid_args_count(definition, input)); + } - function_call.args.push(arg); + definition + .check_param( + &mut args.iter().map(|arg| arg.into()), + &next_param, + ctx.as_mut(), + ) + .map_err(|err| match err { + FunctionParamError::KindMismatch(err) => ( + LexErrorKind::InvalidArgumentKind { + index, + mismatch: err, + }, + span(input, rest), + ), + FunctionParamError::TypeMismatch(err) => ( + LexErrorKind::InvalidArgumentType { + index, + mismatch: err, + }, + span(input, rest), + ), + FunctionParamError::InvalidConstant(err) => ( + LexErrorKind::InvalidArgumentValue { + index, + invalid: err, + }, + span(input, rest), + ), + })?; + + args.push(arg); input = skip_space(rest); index += 1; } + if args.len() < mandatory_arg_count { + return Err(invalid_args_count(definition, input)); + } + input = expect(input, ")")?; + let function_call = FunctionCallExpr::new(function, args, ctx); + Ok((function_call, input)) } + + /// Returns the function being called. + #[inline] + pub fn function(&self) -> Function<'s> { + self.function + } + + /// Returns the arguments being passed to the function. + #[inline] + pub fn args(&self) -> &[FunctionCallArgExpr<'s>] { + &self.args[..] + } + + /// Returns the return type of the function call expression. 
+ #[inline] + pub fn return_type(&self) -> Type { + self.function.as_definition().return_type( + &mut self.args.iter().map(|arg| arg.into()), + self.context.as_ref(), + ) + } +} + +fn invalid_args_count<'i>(function: &dyn FunctionDefinition, input: &'i str) -> LexError<'i> { + let (mandatory, optional) = function.arg_count(); + ( + LexErrorKind::InvalidArgumentsCount { + expected_min: mandatory, + expected_max: optional.map(|v| mandatory + v), + }, + input, + ) +} + +impl<'s> GetType for FunctionCallExpr<'s> { + fn get_type(&self) -> Type { + if !self.args.is_empty() && self.args[0].map_each_count() > 0 { + Type::Array(self.return_type().into()) + } else { + self.return_type() + } + } } -#[test] -fn test_function() { +impl<'i, 's> LexWith<'i, &FilterParser<'s>> for FunctionCallExpr<'s> { + fn lex_with(input: &'i str, parser: &FilterParser<'s>) -> LexResult<'i, Self> { + let (function, rest) = Function::lex_with(input, parser.scheme)?; + + Self::lex_with_function(rest, parser, function) + } +} + +#[cfg(test)] +mod tests { + use super::*; use crate::{ - functions::{FunctionArgs, FunctionImpl, FunctionOptParam}, - scheme::UnknownFieldError, - types::Type, + ast::{ + field_expr::{ComparisonExpr, ComparisonOpExpr, IdentifierExpr, OrderingOp}, + logical_expr::{LogicalExpr, LogicalOp, ParenthesizedExpr}, + parse::FilterParser, + }, + functions::{ + FunctionArgKind, FunctionArgKindMismatchError, FunctionArgs, SimpleFunctionDefinition, + SimpleFunctionImpl, SimpleFunctionOptParam, SimpleFunctionParam, + }, + rhs_types::{Bytes, BytesFormat}, + scheme::{FieldIndex, IndexAccessError, Scheme}, + types::{RhsValues, Type, TypeMismatchError}, }; - use lazy_static::lazy_static; + use std::convert::TryFrom; + use std::sync::LazyLock; - fn echo_function<'a>(args: FunctionArgs<'_, 'a>) -> LhsValue<'a> { - args.next().unwrap() + fn any_function<'a>(args: FunctionArgs<'_, 'a>) -> Option> { + match args.next()? { + Ok(v) => Some(LhsValue::Bool( + Array::try_from(v) + .unwrap() + .into_iter() + .any(|lhs| bool::try_from(lhs).unwrap()), + )), + Err(Type::Array(ref arr)) if arr.get_type() == Type::Bool => None, + _ => unreachable!(), + } } - lazy_static! { - static ref SCHEME: Scheme = { - let mut scheme = Scheme! { - http.host: Bytes, - ip.addr: Ip, - ssl: Bool, - tcp.port: Int, - }; - scheme - .add_function( - "echo".into(), - Function { - params: vec![FunctionParam { + fn regex_replace<'a>(args: FunctionArgs<'_, 'a>) -> Option> { + args.next()?.ok() + } + + fn lower_function<'a>(args: FunctionArgs<'_, 'a>) -> Option> { + use std::borrow::Cow; + + match args.next()? { + Ok(LhsValue::Bytes(mut b)) => { + let mut text: Vec = b.to_mut().to_vec(); + text.make_ascii_lowercase(); + Some(LhsValue::Bytes(Cow::Owned(text))) + } + Err(Type::Bytes) => None, + _ => unreachable!(), + } + } + + fn echo_function<'a>(args: FunctionArgs<'_, 'a>) -> Option> { + args.next()?.ok() + } + + fn len_function<'a>(args: FunctionArgs<'_, 'a>) -> Option> { + match args.next()? { + Ok(LhsValue::Bytes(bytes)) => Some(LhsValue::Int(i64::try_from(bytes.len()).unwrap())), + Err(Type::Bytes) => None, + _ => unreachable!(), + } + } + + static SCHEME: LazyLock = LazyLock::new(|| { + let mut scheme = Scheme! 
{ + http.headers: Map(Bytes), + http.host: Bytes, + http.request.headers.names: Array(Bytes), + http.request.headers.values: Array(Bytes), + http.request.headers.is_empty: Array(Bool), + ip.addr: Ip, + ssl: Bool, + tcp.port: Int, + }; + scheme + .add_function( + "any", + SimpleFunctionDefinition { + params: vec![SimpleFunctionParam { + arg_kind: FunctionArgKind::Field, + val_type: Type::Array(Type::Bool.into()), + }], + opt_params: vec![], + return_type: Type::Bool, + implementation: SimpleFunctionImpl::new(any_function), + }, + ) + .unwrap(); + scheme + .add_function( + "echo", + SimpleFunctionDefinition { + params: vec![SimpleFunctionParam { + arg_kind: FunctionArgKind::Field, + val_type: Type::Bytes, + }], + opt_params: vec![ + SimpleFunctionOptParam { + arg_kind: FunctionArgKind::Literal, + default_value: LhsValue::Int(10), + }, + SimpleFunctionOptParam { + arg_kind: FunctionArgKind::Literal, + default_value: LhsValue::Int(1), + }, + ], + return_type: Type::Bytes, + implementation: SimpleFunctionImpl::new(echo_function), + }, + ) + .unwrap(); + scheme + .add_function( + "lower", + SimpleFunctionDefinition { + params: vec![SimpleFunctionParam { + arg_kind: FunctionArgKind::Field, + val_type: Type::Bytes, + }], + opt_params: vec![], + return_type: Type::Bytes, + implementation: SimpleFunctionImpl::new(lower_function), + }, + ) + .unwrap(); + scheme + .add_function( + "regex_replace", + SimpleFunctionDefinition { + params: vec![ + SimpleFunctionParam { arg_kind: FunctionArgKind::Field, val_type: Type::Bytes, - }], - opt_params: vec![FunctionOptParam { + }, + SimpleFunctionParam { arg_kind: FunctionArgKind::Literal, - default_value: LhsValue::Int(10), - }], - return_type: Type::Bytes, - implementation: FunctionImpl::new(echo_function), + val_type: Type::Bytes, + }, + SimpleFunctionParam { + arg_kind: FunctionArgKind::Literal, + val_type: Type::Bytes, + }, + ], + opt_params: vec![], + return_type: Type::Bool, + implementation: SimpleFunctionImpl::new(regex_replace), + }, + ) + .unwrap(); + scheme + .add_function( + "len", + SimpleFunctionDefinition { + params: vec![SimpleFunctionParam { + arg_kind: FunctionArgKind::Field, + val_type: Type::Bytes, + }], + opt_params: vec![], + return_type: Type::Int, + implementation: SimpleFunctionImpl::new(len_function), + }, + ) + .unwrap(); + scheme + }); + + #[test] + fn test_lex_function_call_expr() { + // test that adjacent single digit int literals are parsed properly + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"echo ( http.host, 1, 2 );"#), + FunctionCallExpr { + function: SCHEME.get_function("echo").unwrap(), + args: vec![ + FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(SCHEME.get_field("http.host").unwrap()), + indexes: vec![], + }), + FunctionCallArgExpr::Literal(RhsValue::Int(1)), + FunctionCallArgExpr::Literal(RhsValue::Int(2)), + ], + context: None, + }, + ";" + ); + + assert_eq!(expr.return_type(), Type::Bytes); + assert_eq!(expr.get_type(), Type::Bytes); + + assert_json!( + expr, + { + "name": "echo", + "args": [ + { + "kind": "IndexExpr", + "value": "http.host" }, - ) - .unwrap(); - scheme - }; - } + { + "kind": "Literal", + "value": 1 + }, + { + "kind": "Literal", + "value": 2 + } + ] + } + ); - let expr = assert_ok!( - FunctionCallExpr::lex_with("echo ( http.host );", &SCHEME), - FunctionCallExpr { - name: String::from("echo"), - function: SCHEME.get_function("echo").unwrap(), - args: vec![FunctionCallArgExpr::LhsFieldExpr(LhsFieldExpr::Field( - 
SCHEME.get_field_index("http.host").unwrap() - ))], - }, - ";" - ); - - assert_json!( - expr, - { - "name": "echo", - "args": [ - { - "kind": "LhsFieldExpr", - "value": "http.host" + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as("echo ( http.host );"), + FunctionCallExpr { + function: SCHEME.get_function("echo").unwrap(), + args: vec![FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(SCHEME.get_field("http.host").unwrap()), + indexes: vec![], + })], + context: None, + }, + ";" + ); + + assert_eq!(expr.return_type(), Type::Bytes); + assert_eq!(expr.get_type(), Type::Bytes); + + assert_json!( + expr, + { + "name": "echo", + "args": [ + { + "kind": "IndexExpr", + "value": "http.host" + } + ] + } + ); + + // test that adjacent single digit int literals are parsed properly (without spaces) + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"echo (http.host,1,2);"#), + FunctionCallExpr { + function: SCHEME.get_function("echo").unwrap(), + args: vec![ + FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(SCHEME.get_field("http.host").unwrap()), + indexes: vec![], + }), + FunctionCallArgExpr::Literal(RhsValue::Int(1)), + FunctionCallArgExpr::Literal(RhsValue::Int(2)), + ], + context: None, + }, + ";" + ); + + assert_eq!(expr.return_type(), Type::Bytes); + assert_eq!(expr.get_type(), Type::Bytes); + + assert_json!( + expr, + { + "name": "echo", + "args": [ + { + "kind": "IndexExpr", + "value": "http.host" + }, + { + "kind": "Literal", + "value": 1 + }, + { + "kind": "Literal", + "value": 2 + } + ] + } + ); + + assert_err!( + FilterParser::new(&SCHEME).lex_as::>("echo ( );"), + LexErrorKind::InvalidArgumentsCount { + expected_min: 1, + expected_max: Some(3), + }, + ");" + ); + + assert_err!( + FilterParser::new(&SCHEME) + .lex_as::>("echo ( http.host , http.host );"), + LexErrorKind::InvalidArgumentKind { + index: 1, + mismatch: FunctionArgKindMismatchError { + actual: FunctionArgKind::Field, + expected: FunctionArgKind::Literal, } - ] - } - ); + }, + "http.host" + ); - assert_err!( - FunctionCallExpr::lex_with("echo ( );", &SCHEME), - LexErrorKind::InvalidArgumentsCount { - expected_min: 1, - expected_max: 2 - }, - ");" - ); - - assert_err!( - FunctionCallExpr::lex_with("echo ( http.host , http.host );", &SCHEME), - LexErrorKind::ExpectedName("digit"), - "http.host );" - ); - - let expr = assert_ok!( - FunctionCallExpr::lex_with("echo ( echo ( http.host ) );", &SCHEME), - FunctionCallExpr { - name: String::from("echo"), - function: SCHEME.get_function("echo").unwrap(), - args: [FunctionCallArgExpr::LhsFieldExpr( - LhsFieldExpr::FunctionCallExpr(FunctionCallExpr { - name: String::from("echo"), - function: SCHEME.get_function("echo").unwrap(), - args: vec![FunctionCallArgExpr::LhsFieldExpr(LhsFieldExpr::Field( - SCHEME.get_field_index("http.host").unwrap() - ))], - }) - )] - .to_vec(), - }, - ";" - ); - - assert_json!( - expr, - { - "name": "echo", - "args": [ - { - "kind": "LhsFieldExpr", - "value": { - "name": "echo", - "args": [ - { - "kind": "LhsFieldExpr", - "value": "http.host" - } - ] + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as("echo ( echo ( http.host ) );"), + FunctionCallExpr { + function: SCHEME.get_function("echo").unwrap(), + args: [FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::FunctionCallExpr(FunctionCallExpr { + function: SCHEME.get_function("echo").unwrap(), + args: vec![FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field( + 
SCHEME.get_field("http.host").unwrap() + ), + indexes: vec![], + })], + context: None, + }), + indexes: vec![], + })] + .to_vec(), + context: None, + }, + ";" + ); + + assert_eq!(expr.return_type(), Type::Bytes); + assert_eq!(expr.get_type(), Type::Bytes); + + assert_json!( + expr, + { + "name": "echo", + "args": [ + { + "kind": "IndexExpr", + "value": { + "name": "echo", + "args": [ + { + "kind": "IndexExpr", + "value": "http.host" + } + ] + } + } + ] + } + ); + + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as( + r#"any ( ( http.request.headers.is_empty or http.request.headers.is_empty ) )"# + ), + FunctionCallExpr { + function: SCHEME.get_function("any").unwrap(), + args: vec![FunctionCallArgExpr::Logical(LogicalExpr::Parenthesized( + Box::new(ParenthesizedExpr { + expr: LogicalExpr::Combining { + op: LogicalOp::Or, + items: vec![ + LogicalExpr::Comparison(ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field( + SCHEME + .get_field("http.request.headers.is_empty") + .unwrap() + ), + indexes: vec![], + }, + op: ComparisonOpExpr::IsTrue, + }), + LogicalExpr::Comparison(ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field( + SCHEME + .get_field("http.request.headers.is_empty") + .unwrap() + ), + indexes: vec![], + }, + op: ComparisonOpExpr::IsTrue, + }) + ] + } + }) + ))], + context: None, + }, + "" + ); + + assert_eq!(expr.return_type(), Type::Bool); + assert_eq!(expr.get_type(), Type::Bool); + + assert_json!( + expr, + { + "name": "any", + "args": [ + { + "kind": "SimpleExpr", + "value": { + "items": [ + { + "lhs": "http.request.headers.is_empty", + "op": "IsTrue", + }, + { + "lhs": "http.request.headers.is_empty", + "op": "IsTrue", + } + ], + "op": "Or", + } + } + ] + } + ); + + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as("echo ( http.request.headers.names[*] );"), + FunctionCallExpr { + function: SCHEME.get_function("echo").unwrap(), + args: vec![FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field( + SCHEME.get_field("http.request.headers.names").unwrap() + ), + indexes: vec![FieldIndex::MapEach], + })], + context: None, + }, + ";" + ); + + assert_eq!(expr.return_type(), Type::Bytes); + assert_eq!(expr.get_type(), Type::Array(Type::Bytes.into())); + + assert_json!( + expr, + { + "name": "echo", + "args": [ + { + "kind": "IndexExpr", + "value": ["http.request.headers.names", {"kind": "MapEach"}], + } + ] + } + ); + + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as("echo ( http.headers[*] );"), + FunctionCallExpr { + function: SCHEME.get_function("echo").unwrap(), + args: vec![FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(SCHEME.get_field("http.headers").unwrap()), + indexes: vec![FieldIndex::MapEach], + })], + context: None, + }, + ";" + ); + + assert_eq!(expr.return_type(), Type::Bytes); + assert_eq!(expr.get_type(), Type::Array(Type::Bytes.into())); + + assert_json!( + expr, + { + "name": "echo", + "args": [ + { + "kind": "IndexExpr", + "value": ["http.headers", {"kind": "MapEach"}], } + ] + } + ); + + assert_ok!( + FilterParser::new(&SCHEME).lex_as("http.request.headers.names[*] == \"test\""), + FunctionCallArgExpr::Logical(LogicalExpr::Comparison(ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field( + SCHEME.get_field("http.request.headers.names").unwrap() + ), + indexes: vec![FieldIndex::MapEach], + }, + op: ComparisonOpExpr::Ordering { + op: OrderingOp::Equal, + rhs: RhsValue::Bytes("test".to_owned().into()) } - ] - } - ); - - 
assert_err!( - FunctionCallExpr::lex_with("echo ( \"test\" );", &SCHEME), - LexErrorKind::ExpectedName("identifier character"), - "\"test\" );" - ); - - assert_err!( - FunctionCallExpr::lex_with("echo ( 10 );", &SCHEME), - LexErrorKind::UnknownField(UnknownFieldError), - "10" - ); - - assert_err!( - FunctionCallExpr::lex_with("echo ( ip.addr );", &SCHEME), - LexErrorKind::InvalidArgumentType { - index: 0, - mismatch: TypeMismatchError { - actual: Type::Ip, - expected: Type::Bytes, + })), + "" + ); + + let expr = assert_ok!( + FilterParser::new(&SCHEME) + .lex_as("any(lower(http.request.headers.names[*])[*] contains \"c\")"), + FunctionCallExpr { + function: SCHEME.get_function("any").unwrap(), + args: vec![FunctionCallArgExpr::Logical(LogicalExpr::Comparison( + ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::FunctionCallExpr(FunctionCallExpr { + function: SCHEME.get_function("lower").unwrap(), + args: vec![FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field( + SCHEME.get_field("http.request.headers.names").unwrap() + ), + indexes: vec![FieldIndex::MapEach], + })], + context: None, + }), + indexes: vec![FieldIndex::MapEach], + }, + op: ComparisonOpExpr::Contains("c".to_string().into(),) + } + ))], + context: None, + }, + "" + ); + + assert_eq!(expr.return_type(), Type::Bool); + assert_eq!(expr.get_type(), Type::Bool); + + assert_json!( + expr, + { + "args": [ + { + "kind": "SimpleExpr", + "value": { + "lhs": [ + { + "args": [ + { + "kind": "IndexExpr", + "value": ["http.request.headers.names", {"kind": "MapEach"}] + } + ], + "name": "lower" + },{ + "kind": "MapEach" + } + ], + "op": "Contains", + "rhs": "c" + } + } + ], + "name": "any" } - }, - "ip.addr" - ); + ); - assert_err!( - FunctionCallExpr::lex_with("echo ( http.host, 10, \"test\" );", &SCHEME), - LexErrorKind::InvalidArgumentsCount { - expected_min: 1, - expected_max: 2, - }, - "\"test\" );" - ); + let expr = FunctionCallArgExpr::lex_with("lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(lower(http.host)))))))))))))))))))))))))))))))) contains \"c\"", &FilterParser::new(&SCHEME)); + assert!(expr.is_ok()); + + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as("len(http.request.headers.names[*])"), + FunctionCallExpr { + function: SCHEME.get_function("len").unwrap(), + args: vec![FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field( + SCHEME.get_field("http.request.headers.names").unwrap() + ), + indexes: vec![FieldIndex::MapEach], + })], + context: None, + }, + "" + ); + + assert_eq!(expr.args[0].map_each_count(), 1); + assert_eq!(expr.return_type(), Type::Int); + assert_eq!(expr.get_type(), Type::Array(Type::Int.into())); + } + + #[test] + fn test_lex_function_with_unary_expression_as_argument() { + let expr = assert_ok!( + FilterParser::new(&SCHEME) + .lex_as("any(not(http.request.headers.names[*] in {\"Cookie\" \"Cookies\"}))"), + FunctionCallExpr { + function: SCHEME.get_function("any").unwrap(), + args: vec![FunctionCallArgExpr::Logical(LogicalExpr::Unary { + op: UnaryOp::Not, + arg: Box::new(LogicalExpr::Parenthesized(Box::new(ParenthesizedExpr { + expr: LogicalExpr::Comparison(ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field( + SCHEME.get_field("http.request.headers.names").unwrap() + ), + indexes: vec![FieldIndex::MapEach], + }, + op: ComparisonOpExpr::OneOf(RhsValues::Bytes(vec![ + 
"Cookie".to_owned().into(), + "Cookies".to_owned().into(), + ])), + }) + },))) + })], + context: None, + }, + "" + ); + + assert_eq!(expr.return_type(), Type::Bool); + assert_eq!(expr.get_type(), Type::Bool); + + assert_json!( + expr, + { + "name": "any", + "args": [ + { + "kind": "SimpleExpr", + "value": { + "op": "Not", + "arg": { + "lhs": [ + "http.request.headers.names", + { + "kind": "MapEach" + } + ], + "op": "OneOf", + "rhs": [ + "Cookie", + "Cookies" + ] + } + } + } + ] + } + ); + + let expr = assert_ok!( + FilterParser::new(&SCHEME) + .lex_as("any(!(http.request.headers.names[*] in {\"Cookie\" \"Cookies\"}))"), + FunctionCallExpr { + function: SCHEME.get_function("any").unwrap(), + args: vec![FunctionCallArgExpr::Logical(LogicalExpr::Unary { + op: UnaryOp::Not, + arg: Box::new(LogicalExpr::Parenthesized(Box::new(ParenthesizedExpr { + expr: LogicalExpr::Comparison(ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field( + SCHEME.get_field("http.request.headers.names").unwrap() + ), + indexes: vec![FieldIndex::MapEach], + }, + op: ComparisonOpExpr::OneOf(RhsValues::Bytes(vec![ + "Cookie".to_owned().into(), + "Cookies".to_owned().into(), + ])), + }) + },))) + })], + context: None, + }, + "" + ); + + assert_eq!(expr.return_type(), Type::Bool); + assert_eq!(expr.get_type(), Type::Bool); + + assert_json!( + expr, + { + "name": "any", + "args": [ + { + "kind": "SimpleExpr", + "value": { + "op": "Not", + "arg": { + "lhs": [ + "http.request.headers.names", + { + "kind": "MapEach" + } + ], + "op": "OneOf", + "rhs": [ + "Cookie", + "Cookies" + ] + } + } + } + ] + } + ); + } + + #[test] + fn test_lex_function_call_raw_string() { + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as("regex_replace(http.host, r\"this is a r##raw## string\", r\"this is a new r##raw## string\") eq \"test\""), + FunctionCallExpr { + function: SCHEME.get_function("regex_replace").unwrap(), + args: vec![ + FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(SCHEME.get_field("http.host").unwrap()), + indexes: vec![], + }), + FunctionCallArgExpr::Literal(RhsValue::Bytes(Bytes::new("this is a r##raw## string".as_bytes(), BytesFormat::Raw(0)))), + FunctionCallArgExpr::Literal(RhsValue::Bytes(Bytes::new("this is a new r##raw## string".as_bytes(), BytesFormat::Raw(0)))) + ], + context: None, + }, + " eq \"test\"" + ); + + assert_eq!(expr.return_type(), Type::Bool); + assert_eq!(expr.get_type(), Type::Bool); + + assert_json!( + expr, + { + "name": "regex_replace", + "args": [ + { + "kind": "IndexExpr", + "value": "http.host" + }, + { + "kind": "Literal", + "value": "this is a r##raw## string" + }, + { + "kind": "Literal", + "value": "this is a new r##raw## string" + } + ] + } + ); + + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as("regex_replace(http.host, r###\"this is a r##\"raw\"## string\"###, r###\"this is a new r##\"raw\"## string\"###) eq \"test\""), + FunctionCallExpr { + function: SCHEME.get_function("regex_replace").unwrap(), + args: vec![ + FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(SCHEME.get_field("http.host").unwrap()), + indexes: vec![], + }), + FunctionCallArgExpr::Literal(RhsValue::Bytes(Bytes::new("this is a r##\"raw\"## string".as_bytes(), BytesFormat::Raw(3)))), + FunctionCallArgExpr::Literal(RhsValue::Bytes(Bytes::new("this is a new r##\"raw\"## string".as_bytes(), BytesFormat::Raw(3)))) + ], + context: None, + }, + " eq \"test\"" + ); + + assert_eq!(expr.return_type(), Type::Bool); + 
assert_eq!(expr.get_type(), Type::Bool); + + assert_json!( + expr, + { + "name": "regex_replace", + "args": [ + { + "kind": "IndexExpr", + "value": "http.host" + }, + { + "kind": "Literal", + "value": "this is a r##\"raw\"## string" + }, + { + "kind": "Literal", + "value": "this is a new r##\"raw\"## string" + } + ] + } + ); + + assert_err!( + FilterParser::new(&SCHEME).lex_as::>( + "regex_replace(http.host, r#\"a\", \"b\") eq \"c\"" + ), + LexErrorKind::MissingEndingQuote {}, + "#\"a\", \"b\") eq \"c\"" + ); + + assert_err!( + FilterParser::new(&SCHEME).lex_as::>( + "regex_replace(http.host, r\"a\"#, \"b\") eq \"c\"" + ), + LexErrorKind::ExpectedLiteral(","), + "#, \"b\") eq \"c\"" + ); + + assert_err!( + FilterParser::new(&SCHEME).lex_as::>( + "regex_replace(http.host, r##\"a\"#, \"b\") eq \"c\"" + ), + LexErrorKind::MissingEndingQuote {}, + "##\"a\"#, \"b\") eq \"c\"" + ); + } + + #[test] + fn test_lex_function_call_expr_failure() { + assert_err!( + FilterParser::new(&SCHEME).lex_as::>("echo ( \"test\" );"), + LexErrorKind::InvalidArgumentKind { + index: 0, + mismatch: FunctionArgKindMismatchError { + actual: FunctionArgKind::Literal, + expected: FunctionArgKind::Field, + } + }, + "\"test\"" + ); + + assert_err!( + FilterParser::new(&SCHEME).lex_as::>("echo ( 10 );"), + LexErrorKind::InvalidArgumentKind { + index: 0, + mismatch: FunctionArgKindMismatchError { + actual: FunctionArgKind::Literal, + expected: FunctionArgKind::Field, + } + }, + "10" + ); + + assert_err!( + FilterParser::new(&SCHEME).lex_as::>("echo ( ip.addr );"), + LexErrorKind::InvalidArgumentType { + index: 0, + mismatch: TypeMismatchError { + actual: Type::Ip, + expected: Type::Bytes.into(), + } + }, + "ip.addr" + ); + + assert_err!( + FilterParser::new(&SCHEME) + .lex_as::>("echo ( http.host, 10, 2, \"test\" );"), + LexErrorKind::InvalidArgumentsCount { + expected_min: 1, + expected_max: Some(3), + }, + "\"test\" );" + ); + + assert_err!( + FilterParser::new(&SCHEME).lex_as::>("echo ( http.test );"), + LexErrorKind::UnknownIdentifier, + "http.test" + ); + + assert_err!( + FilterParser::new(&SCHEME) + .lex_as::>("echo ( echo ( http.test ) );"), + LexErrorKind::UnknownIdentifier, + "http.test" + ); + + assert_err!( + FilterParser::new(&SCHEME).lex_as::>("echo ( http.host[*] );"), + LexErrorKind::InvalidIndexAccess(IndexAccessError { + index: FieldIndex::MapEach, + actual: Type::Bytes, + }), + "[*]" + ); + + assert_err!( + FilterParser::new(&SCHEME) + .lex_as::>("echo ( http.request.headers.names[0][*] );"), + LexErrorKind::InvalidIndexAccess(IndexAccessError { + index: FieldIndex::MapEach, + actual: Type::Bytes, + }), + "[*]" + ); + + assert_err!( + FilterParser::new(&SCHEME) + .lex_as::>("echo ( http.request.headers.names[*][0] );"), + LexErrorKind::InvalidIndexAccess(IndexAccessError { + index: FieldIndex::ArrayIndex(0), + actual: Type::Bytes, + }), + "[0]" + ); + + assert_err!( + FilterParser::new(&SCHEME) + .lex_as::>("echo ( http.headers[*][\"host\"] );"), + LexErrorKind::InvalidIndexAccess(IndexAccessError { + index: FieldIndex::MapKey("host".to_string()), + actual: Type::Bytes, + }), + "[\"host\"]" + ); + + assert_err!( + FilterParser::new(&SCHEME) + .lex_as::>("echo ( http.host, http.headers[*] );"), + LexErrorKind::InvalidMapEachAccess, + "http.headers[*]" + ); + } } diff --git a/engine/src/ast/index_expr.rs b/engine/src/ast/index_expr.rs new file mode 100644 index 00000000..0521c0ec --- /dev/null +++ b/engine/src/ast/index_expr.rs @@ -0,0 +1,949 @@ +use super::{ + field_expr::IdentifierExpr, + 
parse::FilterParser, + visitor::{Visitor, VisitorMut}, + ValueExpr, +}; +use crate::{ + compiler::Compiler, + execution_context::ExecutionContext, + filter::{CompiledExpr, CompiledOneExpr, CompiledValueExpr, CompiledVecExpr}, + lex::{expect, skip_space, span, Lex, LexErrorKind, LexResult, LexWith}, + lhs_types::TypedArray, + lhs_types::{Array, Map}, + scheme::{FieldIndex, IndexAccessError}, + types::{GetType, IntoIter, LhsValue, Type}, +}; +use serde::{ser::SerializeSeq, Serialize, Serializer}; + +/// IndexExpr is an expr that destructures an index into an IdentifierExpr. +/// +/// For example, given a scheme which declares a field, http.request.headers, +/// as a map of string to list of strings, then the expression +/// http.request.headers["Cookie"][0] would have an IdentifierExpr +/// http.request.headers and indexes ["Cookie", 0]. +#[derive(Debug, PartialEq, Eq, Clone, Hash)] +pub struct IndexExpr<'s> { + /// The accessed identifier. + pub identifier: IdentifierExpr<'s>, + /// The list of indexes access. + pub indexes: Vec, +} + +macro_rules! index_access_one { + ($indexes:ident, $first:expr, $default:expr, $ctx:ident, $func:expr) => { + $indexes + .iter() + .fold($first, |value, idx| { + value.and_then(|val| val.get(idx).unwrap()) + }) + .map_or_else( + || $default, + #[allow(clippy::redundant_closure_call)] + |val| $func(val, $ctx), + ) + }; +} + +macro_rules! index_access_vec { + ($indexes:ident, $first:expr, $ctx:ident, $func:ident) => { + index_access_one!( + $indexes, + $first, + TypedArray::default(), + $ctx, + |val: &LhsValue<'_>, ctx| { + TypedArray::from_iter(val.iter().unwrap().map(|item| $func(item, ctx))) + } + ) + }; +} + +impl<'s> ValueExpr<'s> for IndexExpr<'s> { + #[inline] + fn walk<'a, V: Visitor<'s, 'a>>(&'a self, visitor: &mut V) { + match self.identifier { + IdentifierExpr::Field(ref field) => visitor.visit_field(field), + IdentifierExpr::FunctionCallExpr(ref call) => visitor.visit_function_call_expr(call), + } + } + + #[inline] + fn walk_mut<'a, V: VisitorMut<'s, 'a>>(&'a mut self, visitor: &mut V) { + match self.identifier { + IdentifierExpr::Field(ref field) => visitor.visit_field(field), + IdentifierExpr::FunctionCallExpr(ref mut call) => { + visitor.visit_function_call_expr(call) + } + } + } + + fn compile_with_compiler + 's>( + self, + compiler: &mut C, + ) -> CompiledValueExpr<'s, C::U> { + let mut ty = self.get_type(); + let map_each_count = self.map_each_count(); + let Self { + identifier, + indexes, + } = self; + + let last = match map_each_count { + 0 => Some(indexes.len()), + 1 if indexes.last() == Some(&FieldIndex::MapEach) => { + ty = Type::Array(ty.into()); + Some(indexes.len() - 1) + } + _ => None, + }; + if last == Some(0) { + // Fast path + identifier.compile_with_compiler(compiler) + } else if let Some(last) = last { + // Average path + match identifier { + IdentifierExpr::Field(f) => CompiledValueExpr::new(move |ctx| { + indexes[..last] + .iter() + .try_fold(ctx.get_field_value_unchecked(f), |value, index| { + value.get(index).unwrap() + }) + .map(LhsValue::as_ref) + .ok_or(ty) + }), + IdentifierExpr::FunctionCallExpr(call) => { + let call = compiler.compile_function_call_expr(call); + CompiledValueExpr::new(move |ctx| { + let result = call.execute(ctx).ok(); + indexes[..last] + .iter() + .fold(result, |value, index| { + value.and_then(|val| val.extract(index).unwrap()) + }) + .ok_or(ty) + }) + } + } + } else { + // Slow path + match identifier { + IdentifierExpr::Field(f) => CompiledValueExpr::new(move |ctx| { + let mut iter = 
MapEachIterator::from_indexes(&indexes[..]); + iter.reset(ctx.get_field_value_unchecked(f).as_ref()); + Ok(LhsValue::Array(Array::try_from_iter(ty, iter).unwrap())) + }), + IdentifierExpr::FunctionCallExpr(call) => { + let call = compiler.compile_function_call_expr(call); + CompiledValueExpr::new(move |ctx| { + let mut iter = MapEachIterator::from_indexes(&indexes[..]); + iter.reset(call.execute(ctx).map_err(|_| Type::Array(ty.into()))?); + Ok(LhsValue::Array(Array::try_from_iter(ty, iter).unwrap())) + }) + } + } + } + } +} + +fn simplify_indexes(mut indexes: Vec) -> Box<[FieldIndex]> { + if Some(&FieldIndex::MapEach) == indexes.last() { + indexes.pop(); + } + indexes.into_boxed_slice() +} + +impl<'s> IndexExpr<'s> { + fn compile_one_with + 's>( + self, + compiler: &mut C, + default: bool, + func: F, + ) -> CompiledOneExpr<'s, C::U> + where + F: Fn(&LhsValue<'_>, &ExecutionContext<'_, C::U>) -> bool + Sync + Send + 's, + { + let Self { + identifier, + indexes, + } = self; + let indexes = simplify_indexes(indexes); + match identifier { + IdentifierExpr::FunctionCallExpr(call) => { + let call = compiler.compile_function_call_expr(call); + if indexes.is_empty() { + CompiledOneExpr::new(move |ctx| { + call.execute(ctx).map_or(default, |val| func(&val, ctx)) + }) + } else { + CompiledOneExpr::new(move |ctx| { + index_access_one!( + indexes, + call.execute(ctx).as_ref().ok(), + default, + ctx, + func + ) + }) + } + } + IdentifierExpr::Field(f) => { + if indexes.is_empty() { + CompiledOneExpr::new(move |ctx| func(ctx.get_field_value_unchecked(f), ctx)) + } else { + CompiledOneExpr::new(move |ctx| { + index_access_one!( + indexes, + Some(ctx.get_field_value_unchecked(f)), + default, + ctx, + func + ) + }) + } + } + } + } + + pub(crate) fn compile_vec_with + 's>( + self, + compiler: &mut C, + func: F, + ) -> CompiledVecExpr<'s, C::U> + where + F: Fn(&LhsValue<'_>, &ExecutionContext<'_, C::U>) -> bool + Sync + Send + 's, + { + let Self { + identifier, + indexes, + } = self; + let indexes = simplify_indexes(indexes); + match identifier { + IdentifierExpr::FunctionCallExpr(call) => { + let call = compiler.compile_function_call_expr(call); + CompiledVecExpr::new(move |ctx| { + index_access_vec!(indexes, call.execute(ctx).as_ref().ok(), ctx, func) + }) + } + IdentifierExpr::Field(f) => CompiledVecExpr::new(move |ctx| { + index_access_vec!(indexes, Some(ctx.get_field_value_unchecked(f)), ctx, func) + }), + } + } + + pub(crate) fn compile_iter_with + 's>( + self, + compiler: &mut C, + func: F, + ) -> CompiledVecExpr<'s, C::U> + where + F: Fn(&LhsValue<'_>, &ExecutionContext<'_, C::U>) -> bool + Sync + Send + 's, + { + let Self { + identifier, + indexes, + } = self; + match identifier { + IdentifierExpr::Field(f) => CompiledVecExpr::new(move |ctx| { + let mut iter = MapEachIterator::from_indexes(&indexes[..]); + iter.reset(ctx.get_field_value_unchecked(f).as_ref()); + TypedArray::from_iter(iter.map(|item| func(&item, ctx))) + }), + IdentifierExpr::FunctionCallExpr(call) => { + let call = compiler.compile_function_call_expr(call); + CompiledVecExpr::new(move |ctx| { + let mut iter = MapEachIterator::from_indexes(&indexes[..]); + if let Ok(val) = call.execute(ctx) { + iter.reset(val); + } else { + return TypedArray::default(); + } + + TypedArray::from_iter(iter.map(|item| func(&item, ctx))) + }) + } + } + } + + /// Compiles an [`IndexExpr`] node into a [`CompiledExpr`] (boxed closure) using the + /// provided comparison function that returns a boolean. 
+ pub fn compile_with + 's>( + self, + compiler: &mut C, + default: bool, + func: F, + ) -> CompiledExpr<'s, C::U> + where + F: Fn(&LhsValue<'_>, &ExecutionContext<'_, C::U>) -> bool + Sync + Send + 's, + { + match self.map_each_count() { + 0 => CompiledExpr::One(self.compile_one_with(compiler, default, func)), + 1 if self.indexes.last() == Some(&FieldIndex::MapEach) => { + CompiledExpr::Vec(self.compile_vec_with(compiler, func)) + } + _ => CompiledExpr::Vec(self.compile_iter_with(compiler, func)), + } + } + + pub(crate) fn map_each_count(&self) -> usize { + self.indexes + .iter() + .filter(|&index| index == &FieldIndex::MapEach) + .count() + } + + /// Returns the associated identifier (field or function call). + pub fn identifier(&self) -> &IdentifierExpr<'s> { + &self.identifier + } + + /// Returns the index accesses as a list of [`FieldIndex`]. + pub fn indexes(&self) -> &[FieldIndex] { + &self.indexes + } +} + +impl<'i, 's> LexWith<'i, &FilterParser<'s>> for IndexExpr<'s> { + fn lex_with(mut input: &'i str, parser: &FilterParser<'s>) -> LexResult<'i, Self> { + let (identifier, rest) = IdentifierExpr::lex_with(input, parser)?; + + let mut current_type = identifier.get_type(); + + let mut indexes = Vec::new(); + + input = rest; + + while let Ok(rest) = expect(input, "[") { + let rest = skip_space(rest); + + let (idx, rest) = FieldIndex::lex(rest)?; + + let mut rest = skip_space(rest); + + rest = expect(rest, "]")?; + + match &idx { + FieldIndex::ArrayIndex(_) => match current_type { + Type::Array(array_type) => { + current_type = array_type.into(); + } + _ => { + return Err(( + LexErrorKind::InvalidIndexAccess(IndexAccessError { + index: idx, + actual: current_type, + }), + span(input, rest), + )) + } + }, + FieldIndex::MapKey(_) => match current_type { + Type::Map(map_type) => { + current_type = map_type.into(); + } + _ => { + return Err(( + LexErrorKind::InvalidIndexAccess(IndexAccessError { + index: idx, + actual: current_type, + }), + span(input, rest), + )) + } + }, + FieldIndex::MapEach => match current_type { + Type::Array(array_type) => { + current_type = array_type.into(); + } + Type::Map(map_type) => { + current_type = map_type.into(); + } + _ => { + return Err(( + LexErrorKind::InvalidIndexAccess(IndexAccessError { + index: idx, + actual: current_type, + }), + span(input, rest), + )) + } + }, + }; + + input = rest; + + indexes.push(idx); + } + + Ok(( + IndexExpr { + identifier, + indexes, + }, + input, + )) + } +} + +impl<'s> GetType for IndexExpr<'s> { + fn get_type(&self) -> Type { + let mut ty = self.identifier.get_type(); + for index in &self.indexes { + ty = match (ty, index) { + (Type::Array(sub_ty), FieldIndex::ArrayIndex(_)) => sub_ty.into(), + (Type::Array(sub_ty), FieldIndex::MapEach) => sub_ty.into(), + (Type::Map(sub_ty), FieldIndex::MapKey(_)) => sub_ty.into(), + (Type::Map(sub_ty), FieldIndex::MapEach) => sub_ty.into(), + (_, _) => unreachable!(), + } + } + ty + } +} + +impl<'s> Serialize for IndexExpr<'s> { + fn serialize(&self, ser: S) -> Result { + if self.indexes.is_empty() { + self.identifier.serialize(ser) + } else { + let mut seq = ser.serialize_seq(Some(self.indexes.len() + 1))?; + match &self.identifier { + IdentifierExpr::Field(field) => seq.serialize_element(field)?, + IdentifierExpr::FunctionCallExpr(call) => seq.serialize_element(call)?, + }; + for index in &self.indexes { + seq.serialize_element(index)?; + } + seq.end() + } + } +} + +enum FieldIndexIterator<'a, 'b> { + ArrayIndex(Option<(Array<'a>, u32)>), + MapKey(Option<(Map<'a>, &'b [u8])>), + 
MapEach(IntoIter<'a>), +} + +impl<'a, 'b> FieldIndexIterator<'a, 'b> { + fn new(val: LhsValue<'a>, idx: &'b FieldIndex) -> Result { + match idx { + FieldIndex::ArrayIndex(idx) => match val { + LhsValue::Array(arr) => Ok(Self::ArrayIndex(Some((arr, *idx)))), + _ => Err(IndexAccessError { + index: FieldIndex::ArrayIndex(*idx), + actual: val.get_type(), + }), + }, + FieldIndex::MapKey(key) => match val { + LhsValue::Map(map) => Ok(Self::MapKey(Some((map, key.as_bytes())))), + _ => Err(IndexAccessError { + index: FieldIndex::MapKey(key.clone()), + actual: val.get_type(), + }), + }, + FieldIndex::MapEach => match val { + LhsValue::Array(_) | LhsValue::Map(_) => Ok(Self::MapEach(val.into_iter())), + _ => Err(IndexAccessError { + index: FieldIndex::MapEach, + actual: val.get_type(), + }), + }, + } + } +} + +impl<'a, 'b> Iterator for FieldIndexIterator<'a, 'b> { + type Item = LhsValue<'a>; + + fn next(&mut self) -> Option { + match self { + Self::ArrayIndex(opt) => opt.take().and_then(|(arr, idx)| arr.extract(idx as usize)), + Self::MapKey(opt) => opt.take().and_then(|(map, key)| map.extract(key)), + Self::MapEach(iter) => iter.next(), + } + } +} + +struct MapEachIterator<'a, 'b> { + indexes: &'b [FieldIndex], + stack: Vec>, +} + +impl<'a, 'b> MapEachIterator<'a, 'b> { + fn from_indexes(indexes: &'b [FieldIndex]) -> Self { + Self { + indexes, + stack: Vec::with_capacity(indexes.len()), + } + } + + fn reset(&mut self, val: LhsValue<'a>) { + self.stack.clear(); + let first = self.indexes.first().unwrap(); + self.stack + .push(FieldIndexIterator::new(val, first).unwrap()); + } +} + +impl<'a, 'b> Iterator for MapEachIterator<'a, 'b> { + type Item = LhsValue<'a>; + + fn next(&mut self) -> Option> { + while !self.stack.is_empty() { + assert!(self.stack.len() <= self.indexes.len()); + if let Some(nxt) = self.stack.last_mut().unwrap().next() { + // Check that current iterator is a leaf iterator + if self.stack.len() == self.indexes.len() { + // Return a value if a leaf iterator returned a value + return Some(nxt); + } else { + self.stack.push( + FieldIndexIterator::new(nxt, &self.indexes[self.stack.len()]).unwrap(), + ); + } + } else { + // Last iterator is finished, remove it + self.stack.pop(); + } + } + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + ast::field_expr::IdentifierExpr, Array, FieldIndex, FilterParser, FunctionArgKind, + FunctionArgs, FunctionCallArgExpr, FunctionCallExpr, Scheme, SimpleFunctionDefinition, + SimpleFunctionImpl, SimpleFunctionParam, + }; + use std::sync::LazyLock; + + fn array_function<'a>(args: FunctionArgs<'_, 'a>) -> Option> { + match args.next()? { + Ok(LhsValue::Bytes(bytes)) => Some(Array::from_iter([bytes]).into()), + Err(Type::Bytes) => None, + _ => unreachable!(), + } + } + + fn array2_function<'a>(args: FunctionArgs<'_, 'a>) -> Option> { + match args.next()? 
{ + Ok(LhsValue::Bytes(bytes)) => Some({ + let inner = Array::from_iter([bytes]); + let outer = Array::try_from_iter(Type::Array(Type::Bytes.into()), [inner]).unwrap(); + outer.into() + }), + Err(Type::Bytes) => None, + _ => unreachable!(), + } + } + + static SCHEME: LazyLock = LazyLock::new(|| { + let mut scheme = Scheme::new(); + scheme + .add_field("test", Type::Array(Type::Bytes.into())) + .unwrap(); + scheme + .add_field("test2", Type::Array(Type::Array(Type::Bytes.into()).into())) + .unwrap(); + scheme + .add_field("map", Type::Map(Type::Bytes.into())) + .unwrap(); + scheme + .add_function( + "array", + SimpleFunctionDefinition { + params: vec![SimpleFunctionParam { + arg_kind: FunctionArgKind::Field, + val_type: Type::Bytes, + }], + opt_params: vec![], + return_type: Type::Array(Type::Bytes.into()), + implementation: SimpleFunctionImpl::new(array_function), + }, + ) + .unwrap(); + scheme + .add_function( + "array2", + SimpleFunctionDefinition { + params: vec![SimpleFunctionParam { + arg_kind: FunctionArgKind::Field, + val_type: Type::Bytes, + }], + opt_params: vec![], + return_type: Type::Array(Type::Array(Type::Bytes.into()).into()), + implementation: SimpleFunctionImpl::new(array2_function), + }, + ) + .unwrap(); + scheme + }); + + #[test] + fn test_array_indices() { + fn run(i: u32) { + let filter = format!("test[{i}]"); + assert_ok!( + FilterParser::new(&SCHEME).lex_as(&filter), + IndexExpr { + identifier: IdentifierExpr::Field(SCHEME.get_field("test").unwrap()), + indexes: vec![FieldIndex::ArrayIndex(i)], + } + ); + } + + run(0); + run(1); + run(99); + run(999); + run(9999); + run(99999); + + assert_err!( + FilterParser::new(&SCHEME).lex_as::>("test[-1]"), + LexErrorKind::ExpectedLiteral("expected positive integer as index"), + "-1]" + ); + } + + #[test] + fn test_map_access() { + assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"map["a"]"#), + IndexExpr { + identifier: IdentifierExpr::Field(SCHEME.get_field("map").unwrap()), + indexes: vec![FieldIndex::MapKey("a".to_string())], + } + ); + + assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"map["😍"]"#), + IndexExpr { + identifier: IdentifierExpr::Field(SCHEME.get_field("map").unwrap()), + indexes: vec![FieldIndex::MapKey("😍".to_string())], + } + ); + } + + #[test] + fn test_access_with_non_string() { + assert_err!( + FilterParser::new(&SCHEME).lex_as::>(r#"test[a]"#), + LexErrorKind::ExpectedLiteral("expected quoted utf8 string or positive integer"), + "a]" + ); + + assert_err!( + FilterParser::new(&SCHEME).lex_as::>(r#"map[a]"#), + LexErrorKind::ExpectedLiteral("expected quoted utf8 string or positive integer"), + "a]" + ); + } + + #[test] + fn test_function_call_with_missing_argument_then_index_access() { + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"array(test[0])[0]"#), + IndexExpr { + identifier: IdentifierExpr::FunctionCallExpr(FunctionCallExpr { + function: SCHEME.get_function("array").unwrap(), + args: vec![FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(SCHEME.get_field("test").unwrap()), + indexes: vec![FieldIndex::ArrayIndex(0)], + })], + context: None + }), + indexes: vec![FieldIndex::ArrayIndex(0)], + } + ); + + assert_eq!(expr.identifier.get_type(), Type::Array(Type::Bytes.into())); + assert_eq!(expr.get_type(), Type::Bytes); + + let value = expr.compile(); + + let mut exec_ctx = ExecutionContext::new(&SCHEME); + + exec_ctx + .set_field_value(SCHEME.get_field("test").unwrap(), Array::new(Type::Bytes)) + .unwrap(); + + assert_eq!(value.execute(&exec_ctx), 
Err(Type::Bytes)); + + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"array(test[0])[*]"#), + IndexExpr { + identifier: IdentifierExpr::FunctionCallExpr(FunctionCallExpr { + function: SCHEME.get_function("array").unwrap(), + args: vec![FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(SCHEME.get_field("test").unwrap()), + indexes: vec![FieldIndex::ArrayIndex(0)], + })], + context: None + }), + indexes: vec![FieldIndex::MapEach], + } + ); + + assert_eq!(expr.identifier.get_type(), Type::Array(Type::Bytes.into())); + assert_eq!(expr.get_type(), Type::Bytes); + + let value = expr.compile(); + + let mut exec_ctx = ExecutionContext::new(&SCHEME); + + exec_ctx + .set_field_value(SCHEME.get_field("test").unwrap(), Array::new(Type::Bytes)) + .unwrap(); + + assert_eq!( + value.execute(&exec_ctx), + Err(Type::Array(Type::Bytes.into())) + ); + + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"array2(test[0])[*][*]"#), + IndexExpr { + identifier: IdentifierExpr::FunctionCallExpr(FunctionCallExpr { + function: SCHEME.get_function("array2").unwrap(), + args: vec![FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(SCHEME.get_field("test").unwrap()), + indexes: vec![FieldIndex::ArrayIndex(0)], + })], + context: None + }), + indexes: vec![FieldIndex::MapEach, FieldIndex::MapEach], + } + ); + + assert_eq!( + expr.identifier.get_type(), + Type::Array(Type::Array(Type::Bytes.into()).into()) + ); + assert_eq!(expr.get_type(), Type::Bytes); + + let value = expr.compile(); + + let mut exec_ctx = ExecutionContext::new(&SCHEME); + + exec_ctx + .set_field_value(SCHEME.get_field("test").unwrap(), Array::new(Type::Bytes)) + .unwrap(); + + assert_eq!( + value.execute(&exec_ctx), + Err(Type::Array(Type::Bytes.into())) + ); + + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"array2(test[0])[*][0]"#), + IndexExpr { + identifier: IdentifierExpr::FunctionCallExpr(FunctionCallExpr { + function: SCHEME.get_function("array2").unwrap(), + args: vec![FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(SCHEME.get_field("test").unwrap()), + indexes: vec![FieldIndex::ArrayIndex(0)], + })], + context: None + }), + indexes: vec![FieldIndex::MapEach, FieldIndex::ArrayIndex(0)], + } + ); + + assert_eq!( + expr.identifier.get_type(), + Type::Array(Type::Array(Type::Bytes.into()).into()) + ); + assert_eq!(expr.get_type(), Type::Bytes); + + let value = expr.compile(); + + let mut exec_ctx = ExecutionContext::new(&SCHEME); + + exec_ctx + .set_field_value(SCHEME.get_field("test").unwrap(), Array::new(Type::Bytes)) + .unwrap(); + + assert_eq!( + value.execute(&exec_ctx), + Err(Type::Array(Type::Bytes.into())) + ); + + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(r#"array2(test[0])[0][*]"#), + IndexExpr { + identifier: IdentifierExpr::FunctionCallExpr(FunctionCallExpr { + function: SCHEME.get_function("array2").unwrap(), + args: vec![FunctionCallArgExpr::IndexExpr(IndexExpr { + identifier: IdentifierExpr::Field(SCHEME.get_field("test").unwrap()), + indexes: vec![FieldIndex::ArrayIndex(0)], + })], + context: None + }), + indexes: vec![FieldIndex::ArrayIndex(0), FieldIndex::MapEach], + } + ); + + assert_eq!( + expr.identifier.get_type(), + Type::Array(Type::Array(Type::Bytes.into()).into()) + ); + assert_eq!(expr.get_type(), Type::Bytes); + + let value = expr.compile(); + + let mut exec_ctx = ExecutionContext::new(&SCHEME); + + exec_ctx + .set_field_value(SCHEME.get_field("test").unwrap(), 
Array::new(Type::Bytes)) + .unwrap(); + + assert_eq!( + value.execute(&exec_ctx), + Err(Type::Array(Type::Bytes.into())) + ); + } + + #[test] + fn test_mapeach() { + let filter = "test2[0][*]".to_string(); + + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(&filter), + IndexExpr { + identifier: IdentifierExpr::Field(SCHEME.get_field("test2").unwrap()), + indexes: vec![FieldIndex::ArrayIndex(0), FieldIndex::MapEach], + } + ); + + assert_eq!(expr.map_each_count(), 1); + assert_eq!(expr.get_type(), Type::Bytes); + + let filter = "test2[*][0]".to_string(); + + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(&filter), + IndexExpr { + identifier: IdentifierExpr::Field(SCHEME.get_field("test2").unwrap()), + indexes: vec![FieldIndex::MapEach, FieldIndex::ArrayIndex(0)], + } + ); + + assert_eq!(expr.map_each_count(), 1); + assert_eq!(expr.get_type(), Type::Bytes); + + let filter = "test2[*][*]".to_string(); + + let expr = assert_ok!( + FilterParser::new(&SCHEME).lex_as(&filter), + IndexExpr { + identifier: IdentifierExpr::Field(SCHEME.get_field("test2").unwrap()), + indexes: vec![FieldIndex::MapEach, FieldIndex::MapEach], + } + ); + + assert_eq!(expr.map_each_count(), 2); + assert_eq!(expr.get_type(), Type::Bytes); + + let filter = "test2[0][*][*]".to_string(); + + assert_err!( + FilterParser::new(&SCHEME).lex_as::>(&filter), + LexErrorKind::InvalidIndexAccess(IndexAccessError { + index: FieldIndex::MapEach, + actual: Type::Bytes + }), + "[*]" + ); + + let filter = "test2[*][0][*]".to_string(); + + assert_err!( + FilterParser::new(&SCHEME).lex_as::>(&filter), + LexErrorKind::InvalidIndexAccess(IndexAccessError { + index: FieldIndex::MapEach, + actual: Type::Bytes + }), + "[*]" + ); + + let filter = "test2[*][*][0]".to_string(); + + assert_err!( + FilterParser::new(&SCHEME).lex_as::>(&filter), + LexErrorKind::InvalidIndexAccess(IndexAccessError { + index: FieldIndex::ArrayIndex(0), + actual: Type::Bytes + }), + "[0]" + ); + + let filter = "test2[*][*][*]".to_string(); + + assert_err!( + FilterParser::new(&SCHEME).lex_as::>(&filter), + LexErrorKind::InvalidIndexAccess(IndexAccessError { + index: FieldIndex::MapEach, + actual: Type::Bytes + }), + "[*]" + ); + } + + #[test] + fn test_flatten() { + let arr = LhsValue::Array( + Array::try_from_iter( + Type::Array(Type::Bytes.into()), + (0..10).map(|i| Array::from_iter((0..10).map(|j| format!("[{i}][{j}]")))), + ) + .unwrap(), + ); + + for i in 0..10 { + let indexes = [FieldIndex::ArrayIndex(i), FieldIndex::MapEach]; + let mut iter = MapEachIterator::from_indexes(&indexes[..]); + + iter.reset(arr.clone()); + for (j, elem) in iter.enumerate() { + let bytes = match elem { + LhsValue::Bytes(bytes) => bytes, + _ => unreachable!(), + }; + assert_eq!(std::str::from_utf8(&bytes).unwrap(), format!("[{i}][{j}]")); + } + + let indexes = [FieldIndex::MapEach, FieldIndex::ArrayIndex(i)]; + let mut iter = MapEachIterator::from_indexes(&indexes[..]); + + iter.reset(arr.clone()); + for (j, elem) in iter.enumerate() { + let bytes = match elem { + LhsValue::Bytes(bytes) => bytes, + _ => unreachable!(), + }; + assert_eq!(std::str::from_utf8(&bytes).unwrap(), format!("[{j}][{i}]")); + } + } + + let indexes = [FieldIndex::MapEach, FieldIndex::MapEach]; + let mut iter = MapEachIterator::from_indexes(&indexes[..]); + let mut i = 0; + let mut j = 0; + + iter.reset(arr.clone()); + for elem in iter { + let bytes = match elem { + LhsValue::Bytes(bytes) => bytes, + _ => unreachable!(), + }; + assert_eq!(std::str::from_utf8(&bytes).unwrap(), 
format!("[{i}][{j}]")); + j = (j + 1) % 10; + i += (j == 0) as u32; + } + } +} diff --git a/engine/src/ast/logical_expr.rs b/engine/src/ast/logical_expr.rs new file mode 100644 index 00000000..4f8b993a --- /dev/null +++ b/engine/src/ast/logical_expr.rs @@ -0,0 +1,923 @@ +use super::{ + field_expr::ComparisonExpr, + parse::FilterParser, + visitor::{Visitor, VisitorMut}, + Expr, +}; +use crate::{ + compiler::Compiler, + filter::{CompiledExpr, CompiledOneExpr, CompiledVecExpr}, + lex::{expect, skip_space, Lex, LexErrorKind, LexResult, LexWith}, + types::{GetType, Type, TypeMismatchError}, +}; +use serde::Serialize; + +lex_enum!( + /// LogicalOp is an operator for a [`LogicalExpr`]. Its ordering is defined + /// by the operators' precedences in ascending order. + #[derive(PartialOrd, Ord)] LogicalOp { + /// `or` / `||` operator + "or" | "||" => Or, + /// `xor` / `^^` operator + "xor" | "^^" => Xor, + /// `and` / `&&` operator + "and" | "&&" => And, + } +); + +lex_enum!( + /// An operator that takes a single argument + UnaryOp { + /// `not` / `!` operator + "not" | "!" => Not, + } +); + +/// A parenthesized expression. +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize)] +#[serde(transparent)] +pub struct ParenthesizedExpr<'s> { + /// The inner expression. + pub expr: LogicalExpr<'s>, +} + +/// LogicalExpr is a either a generic sub-expression +/// or a logical conjunction expression. +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize)] +#[serde(untagged)] +pub enum LogicalExpr<'s> { + /// Logical conjunction expression + Combining { + /// Logical operator + op: LogicalOp, + /// List of sub-expressions + items: Vec>, + }, + /// A comparison expression. + Comparison(ComparisonExpr<'s>), + /// A parenthesized expression. + Parenthesized(Box>), + /// A unary expression. + Unary { + /// Unary operator. + op: UnaryOp, + /// Sub-expression. + arg: Box>, + }, +} + +impl<'s> GetType for LogicalExpr<'s> { + fn get_type(&self) -> Type { + match &self { + LogicalExpr::Combining { ref items, .. } => items[0].get_type(), + LogicalExpr::Comparison(comparison) => comparison.get_type(), + LogicalExpr::Parenthesized(parenthesized) => parenthesized.expr.get_type(), + LogicalExpr::Unary { arg, .. 
} => arg.get_type(), + } + } +} + +impl<'s> LogicalExpr<'s> { + fn lex_combining_op(input: &str) -> (Option, &str) { + match LogicalOp::lex(skip_space(input)) { + Ok((op, input)) => (Some(op), skip_space(input)), + Err(_) => (None, input), + } + } + + fn lex_simple_expr<'i>(input: &'i str, parser: &FilterParser<'s>) -> LexResult<'i, Self> { + Ok(if let Ok(input) = expect(input, "(") { + let input = skip_space(input); + let (expr, input) = LogicalExpr::lex_with(input, parser)?; + let input = skip_space(input); + let input = expect(input, ")")?; + ( + LogicalExpr::Parenthesized(Box::new(ParenthesizedExpr { expr })), + input, + ) + } else if let Ok((op, input)) = UnaryOp::lex(input) { + let input = skip_space(input); + let (arg, input) = Self::lex_simple_expr(input, parser)?; + ( + LogicalExpr::Unary { + op, + arg: Box::new(arg), + }, + input, + ) + } else { + let (op, input) = ComparisonExpr::lex_with(input, parser)?; + (LogicalExpr::Comparison(op), input) + }) + } + + fn lex_more_with_precedence<'i>( + self, + parser: &FilterParser<'s>, + min_prec: Option, + mut lookahead: (Option, &'i str), + ) -> LexResult<'i, Self> { + let mut lhs = self; + + while let Some(op) = lookahead.0 { + let mut rhs = Self::lex_simple_expr(lookahead.1, parser)?; + + loop { + lookahead = Self::lex_combining_op(rhs.1); + if lookahead.0 <= Some(op) { + break; + } + rhs = rhs + .0 + .lex_more_with_precedence(parser, lookahead.0, lookahead)?; + } + + // check that the LogicalExpr is valid by ensuring both the left + // hand side and right hand side of the operator are comparable. + // For example, it doesn't make sense to do a logical operator on + // a Bool and Bytes, or an Array(Bool) with Bool. + let (lhsty, rhsty) = (lhs.get_type(), rhs.0.get_type()); + match (&lhsty, &rhsty) { + (Type::Bool, Type::Bool) => {} + (Type::Array(_), Type::Array(_)) => {} + _ => { + return Err(( + LexErrorKind::TypeMismatch(TypeMismatchError { + expected: lhsty.into(), + actual: rhsty, + }), + lookahead.1, + )) + } + } + + match lhs { + LogicalExpr::Combining { + op: lhs_op, + ref mut items, + } if lhs_op == op => { + items.push(rhs.0); + } + _ => { + lhs = LogicalExpr::Combining { + op, + items: vec![lhs, rhs.0], + }; + } + } + + if lookahead.0 < min_prec { + // pretend we haven't seen an operator if its precedence is + // outside of our limits + lookahead = (None, rhs.1); + } + } + + Ok((lhs, lookahead.1)) + } +} + +impl<'i, 's> LexWith<'i, &FilterParser<'s>> for LogicalExpr<'s> { + fn lex_with(input: &'i str, parser: &FilterParser<'s>) -> LexResult<'i, Self> { + let (lhs, input) = Self::lex_simple_expr(input, parser)?; + let lookahead = Self::lex_combining_op(input); + lhs.lex_more_with_precedence(parser, None, lookahead) + } +} + +impl<'s> Expr<'s> for LogicalExpr<'s> { + #[inline] + fn walk<'a, V: Visitor<'s, 'a>>(&'a self, visitor: &mut V) { + match self { + LogicalExpr::Comparison(node) => visitor.visit_comparison_expr(node), + LogicalExpr::Parenthesized(node) => visitor.visit_logical_expr(&node.expr), + LogicalExpr::Unary { arg, .. } => visitor.visit_logical_expr(arg), + LogicalExpr::Combining { items, .. } => { + items + .iter() + .for_each(|node| visitor.visit_logical_expr(node)); + } + } + } + + #[inline] + fn walk_mut<'a, V: VisitorMut<'s, 'a>>(&'a mut self, visitor: &mut V) { + match self { + LogicalExpr::Comparison(node) => visitor.visit_comparison_expr(node), + LogicalExpr::Parenthesized(node) => visitor.visit_logical_expr(&mut node.expr), + LogicalExpr::Unary { arg, .. 
} => visitor.visit_logical_expr(arg), + LogicalExpr::Combining { items, .. } => { + items + .iter_mut() + .for_each(|node| visitor.visit_logical_expr(node)); + } + } + } + + fn compile_with_compiler + 's>( + self, + compiler: &mut C, + ) -> CompiledExpr<'s, C::U> { + match self { + LogicalExpr::Comparison(op) => compiler.compile_comparison_expr(op), + LogicalExpr::Parenthesized(node) => compiler.compile_logical_expr(node.expr), + LogicalExpr::Unary { + op: UnaryOp::Not, + arg, + } => { + let arg = compiler.compile_logical_expr(*arg); + match arg { + CompiledExpr::One(one) => { + CompiledExpr::One(CompiledOneExpr::new(move |ctx| !one.execute(ctx))) + } + CompiledExpr::Vec(vec) => CompiledExpr::Vec(CompiledVecExpr::new(move |ctx| { + vec.execute(ctx).iter().map(|item| !item).collect() + })), + } + } + LogicalExpr::Combining { op, items } => { + let items = items.into_iter(); + let mut items = items.map(|item| compiler.compile_logical_expr(item)); + let first = items.next().unwrap(); + match first { + CompiledExpr::One(first) => { + let items = items + .map(|item| match item { + CompiledExpr::One(one) => one, + CompiledExpr::Vec(_) => unreachable!(), + }) + .collect::>() + .into_boxed_slice(); + match op { + LogicalOp::And => CompiledExpr::One(CompiledOneExpr::new(move |ctx| { + first.execute(ctx) && items.iter().all(|item| item.execute(ctx)) + })), + LogicalOp::Or => CompiledExpr::One(CompiledOneExpr::new(move |ctx| { + first.execute(ctx) || items.iter().any(|item| item.execute(ctx)) + })), + LogicalOp::Xor => CompiledExpr::One(CompiledOneExpr::new(move |ctx| { + items + .iter() + .fold(first.execute(ctx), |acc, item| acc ^ item.execute(ctx)) + })), + } + } + CompiledExpr::Vec(first) => { + let items = items + .map(|item| match item { + CompiledExpr::One(_) => unreachable!(), + CompiledExpr::Vec(vec) => vec, + }) + .collect::>() + .into_boxed_slice(); + match op { + LogicalOp::And => CompiledExpr::Vec(CompiledVecExpr::new(move |ctx| { + let items = items.iter().map(|item| item.execute(ctx)); + let mut output = first.execute(ctx); + for values in items { + output.iter_mut().zip(values.iter()).for_each( + |(left, right)| { + *left = *left && *right; + }, + ); + if values.len() < output.len() { + output.truncate(values.len()); + } + } + output + })), + LogicalOp::Or => CompiledExpr::Vec(CompiledVecExpr::new(move |ctx| { + let items = items.iter().map(|item| item.execute(ctx)); + let mut output = first.execute(ctx); + for values in items { + output.iter_mut().zip(values.iter()).for_each( + |(left, right)| { + *left = *left || *right; + }, + ); + if values.len() < output.len() { + output.truncate(values.len()); + } + } + output + })), + LogicalOp::Xor => CompiledExpr::Vec(CompiledVecExpr::new(move |ctx| { + let items = items.iter().map(|item| item.execute(ctx)); + let mut output = first.execute(ctx); + for values in items { + output.iter_mut().zip(values.iter()).for_each( + |(left, right)| { + *left ^= *right; + }, + ); + if values.len() < output.len() { + output.truncate(values.len()); + } + } + output + })), + } + } + } + } + } + } +} + +#[test] +#[allow(clippy::bool_assert_comparison)] +#[allow(clippy::cognitive_complexity)] +fn test() { + use super::field_expr::ComparisonExpr; + use crate::{ + ast::field_expr::{ComparisonOpExpr, IdentifierExpr}, + ast::index_expr::IndexExpr, + execution_context::ExecutionContext, + lex::complete, + lhs_types::Array, + scheme::FieldIndex, + types::Type, + }; + + let scheme = &Scheme! 
{ + t: Bool, + f: Bool, + at: Array(Bool), + af: Array(Bool), + aat: Array(Array(Bool)), + }; + + let ctx = &mut ExecutionContext::new(scheme); + + let t_expr = LogicalExpr::Comparison(complete(FilterParser::new(scheme).lex_as("t")).unwrap()); + + let t_expr = || t_expr.clone(); + + let f_expr = LogicalExpr::Comparison(complete(FilterParser::new(scheme).lex_as("f")).unwrap()); + + let f_expr = || f_expr.clone(); + + assert_ok!(FilterParser::new(scheme).lex_as("t"), t_expr()); + + let at_expr = + LogicalExpr::Comparison(complete(FilterParser::new(scheme).lex_as("at")).unwrap()); + + let at_expr = || at_expr.clone(); + + let af_expr = + LogicalExpr::Comparison(complete(FilterParser::new(scheme).lex_as("af")).unwrap()); + + let af_expr = || af_expr.clone(); + + assert_ok!(FilterParser::new(scheme).lex_as("at"), at_expr()); + + ctx.set_field_value(scheme.get_field("t").unwrap(), true) + .unwrap(); + ctx.set_field_value(scheme.get_field("f").unwrap(), false) + .unwrap(); + ctx.set_field_value(scheme.get_field("at").unwrap(), { + Array::from_iter([true, false, true]) + }) + .unwrap(); + ctx.set_field_value(scheme.get_field("af").unwrap(), { + Array::from_iter([false, false, true]) + }) + .unwrap(); + + { + let expr = assert_ok!( + FilterParser::new(scheme).lex_as("t and t"), + LogicalExpr::Combining { + op: LogicalOp::And, + items: vec![t_expr(), t_expr()], + } + ); + + let expr = expr.compile(); + + assert_eq!(expr.execute_one(ctx), true); + } + + { + let expr = assert_ok!( + FilterParser::new(scheme).lex_as("t and f"), + LogicalExpr::Combining { + op: LogicalOp::And, + items: vec![t_expr(), f_expr()], + } + ); + + assert_json!( + expr, + { + "op": "And", + "items": [ + { + "lhs": "t", + "op": "IsTrue" + }, + { + "lhs": "f", + "op": "IsTrue" + } + ] + } + ); + + let expr = expr.compile(); + + assert_eq!(expr.execute_one(ctx), false); + } + + { + let expr = assert_ok!( + FilterParser::new(scheme).lex_as("t or f"), + LogicalExpr::Combining { + op: LogicalOp::Or, + items: vec![t_expr(), f_expr()], + } + ); + + assert_json!( + expr, + { + "op": "Or", + "items": [ + { + "lhs": "t", + "op": "IsTrue" + }, + { + "lhs": "f", + "op": "IsTrue" + } + ] + } + ); + + let expr = expr.compile(); + + assert_eq!(expr.execute_one(ctx), true); + } + + { + let expr = assert_ok!( + FilterParser::new(scheme).lex_as("f or f"), + LogicalExpr::Combining { + op: LogicalOp::Or, + items: vec![f_expr(), f_expr()], + } + ); + + let expr = expr.compile(); + + assert_eq!(expr.execute_one(ctx), false); + } + + { + let expr = assert_ok!( + FilterParser::new(scheme).lex_as("t xor f"), + LogicalExpr::Combining { + op: LogicalOp::Xor, + items: vec![t_expr(), f_expr()], + } + ); + + assert_json!( + expr, + { + "op": "Xor", + "items": [ + { + "lhs": "t", + "op": "IsTrue" + }, + { + "lhs": "f", + "op": "IsTrue" + } + ] + } + ); + + let expr = expr.compile(); + + assert_eq!(expr.execute_one(ctx), true); + } + + { + let expr = assert_ok!( + FilterParser::new(scheme).lex_as("f xor f"), + LogicalExpr::Combining { + op: LogicalOp::Xor, + items: vec![f_expr(), f_expr()], + } + ); + + let expr = expr.compile(); + + assert_eq!(expr.execute_one(ctx), false); + } + + { + let expr = assert_ok!( + FilterParser::new(scheme).lex_as("f xor t"), + LogicalExpr::Combining { + op: LogicalOp::Xor, + items: vec![f_expr(), t_expr()], + } + ); + + let expr = expr.compile(); + + assert_eq!(expr.execute_one(ctx), true); + } + + assert_ok!( + FilterParser::new(scheme).lex_as("t or t && t and t or t ^^ t and t || t"), + LogicalExpr::Combining { + op: 
LogicalOp::Or, + items: vec![ + t_expr(), + LogicalExpr::Combining { + op: LogicalOp::And, + items: vec![t_expr(), t_expr(), t_expr()], + }, + LogicalExpr::Combining { + op: LogicalOp::Xor, + items: vec![ + t_expr(), + LogicalExpr::Combining { + op: LogicalOp::And, + items: vec![t_expr(), t_expr()], + }, + ], + }, + t_expr(), + ], + } + ); + + { + let expr = assert_ok!( + FilterParser::new(scheme).lex_as("at and af"), + LogicalExpr::Combining { + op: LogicalOp::And, + items: vec![at_expr(), af_expr()], + } + ); + + let expr = expr.compile(); + + assert_eq!(expr.execute_vec(ctx), [false, false, true]); + } + + { + let expr = assert_ok!( + FilterParser::new(scheme).lex_as("at or af"), + LogicalExpr::Combining { + op: LogicalOp::Or, + items: vec![at_expr(), af_expr()], + } + ); + + let expr = expr.compile(); + + assert_eq!(expr.execute_vec(ctx), [true, false, true]); + } + + { + let expr = assert_ok!( + FilterParser::new(scheme).lex_as("at xor af"), + LogicalExpr::Combining { + op: LogicalOp::Xor, + items: vec![at_expr(), af_expr()], + } + ); + + let expr = expr.compile(); + + assert_eq!(expr.execute_vec(ctx), [true, false, false]); + } + + { + assert_err!( + FilterParser::new(scheme).lex_as::>("t and af"), + LexErrorKind::TypeMismatch(TypeMismatchError { + expected: Type::Bool.into(), + actual: Type::Array(Type::Bool.into()), + }), + "" + ); + + assert_err!( + FilterParser::new(scheme).lex_as::>("at and f"), + LexErrorKind::TypeMismatch(TypeMismatchError { + expected: Type::Array(Type::Bool.into()).into(), + actual: Type::Bool, + }), + "" + ); + } + + { + assert_err!( + FilterParser::new(scheme).lex_as::>("t or af"), + LexErrorKind::TypeMismatch(TypeMismatchError { + expected: Type::Bool.into(), + actual: Type::Array(Type::Bool.into()), + }), + "" + ); + + assert_err!( + FilterParser::new(scheme).lex_as::>("at or f"), + LexErrorKind::TypeMismatch(TypeMismatchError { + expected: Type::Array(Type::Bool.into()).into(), + actual: Type::Bool, + }), + "" + ); + } + + { + assert_err!( + FilterParser::new(scheme).lex_as::>("t xor af"), + LexErrorKind::TypeMismatch(TypeMismatchError { + expected: Type::Bool.into(), + actual: Type::Array(Type::Bool.into()), + }), + "" + ); + + assert_err!( + FilterParser::new(scheme).lex_as::>("at xor f"), + LexErrorKind::TypeMismatch(TypeMismatchError { + expected: Type::Array(Type::Bool.into()).into(), + actual: Type::Bool, + }), + "" + ); + } + + { + let expr = assert_ok!(FilterParser::new(scheme).lex_as("t"), t_expr()); + + assert_eq!(expr.get_type(), Type::Bool); + + assert_json!( + expr, + { + "lhs": "t", + "op": "IsTrue" + } + ); + + let expr = expr.compile(); + + assert_eq!(expr.execute_one(ctx), true); + } + + { + let expr = assert_ok!(FilterParser::new(scheme).lex_as("at"), at_expr()); + + assert_eq!(expr.get_type(), Type::Array(Type::Bool.into())); + + assert_json!( + expr, + { + "lhs": "at", + "op": "IsTrue" + } + ); + + let expr = expr.compile(); + + assert_eq!(expr.execute_vec(ctx), [true, false, true]); + } + + { + let expr = assert_ok!( + FilterParser::new(scheme).lex_as("at[*]"), + LogicalExpr::Comparison(ComparisonExpr { + lhs: IndexExpr { + identifier: IdentifierExpr::Field(scheme.get_field("at").unwrap()), + indexes: vec![FieldIndex::MapEach], + }, + op: ComparisonOpExpr::IsTrue + }) + ); + + assert_eq!(expr.get_type(), Type::Array(Type::Bool.into())); + + assert_json!( + expr, + { + "lhs": ["at", {"kind": "MapEach"}], + "op": "IsTrue" + } + ); + + let expr = expr.compile(); + + assert_eq!(expr.execute_vec(ctx), [true, false, true]); + } + + { + 
assert_err!( + FilterParser::new(scheme).lex_as::>("aat[*]"), + LexErrorKind::UnsupportedOp { + lhs_type: Type::Array(Type::Array(Type::Bool.into()).into()) + }, + "" + ); + } + + let parenthesized_expr = + |expr| LogicalExpr::Parenthesized(Box::new(ParenthesizedExpr { expr })); + + { + let expr = assert_ok!( + FilterParser::new(scheme).lex_as("((t))"), + parenthesized_expr(parenthesized_expr(t_expr())) + ); + + assert_json!( + expr, + { + "lhs": "t", + "op": "IsTrue" + } + ); + + let expr = expr.compile(); + + assert_eq!(expr.execute_one(ctx), true); + } + + { + let expr = assert_ok!( + FilterParser::new(scheme).lex_as("((at))"), + parenthesized_expr(parenthesized_expr(at_expr())) + ); + + assert_json!( + expr, + { + "lhs": "at", + "op": "IsTrue" + } + ); + + let expr = expr.compile(); + + assert_eq!(expr.execute_vec(ctx), [true, false, true]); + } + + let not_expr = |expr| LogicalExpr::Unary { + op: UnaryOp::Not, + arg: Box::new(expr), + }; + + { + let expr = assert_ok!( + FilterParser::new(scheme).lex_as("not t"), + not_expr(t_expr()) + ); + + assert_json!( + expr, + { + "op": "Not", + "arg": { + "lhs": "t", + "op": "IsTrue" + } + } + ); + + let expr = expr.compile(); + + assert_eq!(expr.execute_one(ctx), false); + } + + assert_ok!(FilterParser::new(scheme).lex_as("!t"), not_expr(t_expr())); + + { + let expr = assert_ok!( + FilterParser::new(scheme).lex_as("not at"), + not_expr(at_expr()) + ); + + assert_json!( + expr, + { + "op": "Not", + "arg": { + "lhs": "at", + "op": "IsTrue" + } + } + ); + + let expr = expr.compile(); + + assert_eq!(expr.execute_vec(ctx), [false, true, false]); + } + + assert_ok!(FilterParser::new(scheme).lex_as("!at"), not_expr(at_expr())); + + { + let expr = assert_ok!( + FilterParser::new(scheme).lex_as("!!t"), + not_expr(not_expr(t_expr())) + ); + + assert_json!( + expr, + { + "op": "Not", + "arg": { + "op": "Not", + "arg": { + "lhs": "t", + "op": "IsTrue" + } + } + } + ); + + let expr = expr.compile(); + + assert_eq!(expr.execute_one(ctx), true); + } + + assert_ok!( + FilterParser::new(scheme).lex_as("! (not !t)"), + not_expr(parenthesized_expr(not_expr(not_expr(t_expr())))) + ); + + { + let expr = assert_ok!( + FilterParser::new(scheme).lex_as("!!at"), + not_expr(not_expr(at_expr())) + ); + + assert_json!( + expr, + { + "op": "Not", + "arg": { + "op": "Not", + "arg": { + "lhs": "at", + "op": "IsTrue" + } + } + } + ); + + let expr = expr.compile(); + + assert_eq!(expr.execute_vec(ctx), [true, false, true]); + } + + assert_ok!( + FilterParser::new(scheme).lex_as("! 
(not !at)"), + not_expr(parenthesized_expr(not_expr(not_expr(at_expr())))) + ); + + { + let expr = assert_ok!( + FilterParser::new(scheme).lex_as("not t && f"), + LogicalExpr::Combining { + op: LogicalOp::And, + items: vec![not_expr(t_expr()), f_expr()], + } + ); + + assert_json!( + expr, + { + "op": "And", + "items": [ + { + "op": "Not", + "arg": { + "lhs": "t", + "op": "IsTrue" + } + }, + { + "lhs": "f", + "op": "IsTrue" + } + ] + } + ); + + let expr = expr.compile(); + + assert_eq!(expr.execute_one(ctx), false); + } +} diff --git a/engine/src/ast/mod.rs b/engine/src/ast/mod.rs index 6faa4f0f..276297dc 100644 --- a/engine/src/ast/mod.rs +++ b/engine/src/ast/mod.rs @@ -1,20 +1,62 @@ -mod combined_expr; -mod field_expr; -mod function_expr; -mod simple_expr; +pub mod field_expr; +pub mod function_expr; +pub mod index_expr; +pub mod logical_expr; +pub mod parse; +pub mod visitor; -use self::combined_expr::CombinedExpr; +use self::index_expr::IndexExpr; +use self::logical_expr::LogicalExpr; +use self::parse::FilterParser; use crate::{ - filter::{CompiledExpr, Filter}, - lex::{LexResult, LexWith}, - scheme::{Field, Scheme, UnknownFieldError}, + compiler::{Compiler, DefaultCompiler}, + filter::{CompiledExpr, CompiledValueExpr, Filter, FilterValue}, + lex::{LexErrorKind, LexResult, LexWith}, + scheme::{Scheme, UnknownFieldError}, + types::{GetType, Type, TypeMismatchError}, }; use serde::Serialize; use std::fmt::{self, Debug}; +use visitor::{UsesListVisitor, UsesVisitor, Visitor, VisitorMut}; -trait Expr<'s>: Sized + Eq + Debug + for<'i> LexWith<'i, &'s Scheme> + Serialize { - fn uses(&self, field: Field<'s>) -> bool; - fn compile(self) -> CompiledExpr<'s>; +/// Trait used to represent node that evaluates to a [`bool`] (or a [`Vec`]). +pub trait Expr<'s>: + Sized + Eq + Debug + for<'i, 'p> LexWith<'i, &'p FilterParser<'s>> + Serialize +{ + /// Recursively visit all nodes in the AST using a [`Visitor`]. + fn walk<'a, V: Visitor<'s, 'a>>(&'a self, visitor: &mut V); + /// Recursively visit all nodes in the AST using a [`VisitorMut`]. + fn walk_mut<'a, V: VisitorMut<'s, 'a>>(&'a mut self, visitor: &mut V); + /// Compiles current node into a [`CompiledExpr`] using [`Compiler`]. + fn compile_with_compiler + 's>( + self, + compiler: &mut C, + ) -> CompiledExpr<'s, C::U>; + /// Compiles current node into a [`CompiledExpr`] using [`DefaultCompiler`]. + fn compile(self) -> CompiledExpr<'s> { + let mut compiler = DefaultCompiler::new(); + self.compile_with_compiler(&mut compiler) + } +} + +/// Trait used to represent node that evaluates to an [`LhsValue`]. +pub trait ValueExpr<'s>: + Sized + Eq + Debug + for<'i, 'p> LexWith<'i, &'p FilterParser<'s>> + Serialize +{ + /// Recursively visit all nodes in the AST using a [`Visitor`]. + fn walk<'a, V: Visitor<'s, 'a>>(&'a self, visitor: &mut V); + /// Recursively visit all nodes in the AST using a [`VisitorMut`]. + fn walk_mut<'a, V: VisitorMut<'s, 'a>>(&'a mut self, visitor: &mut V); + /// Compiles current node into a [`CompiledValueExpr`] using [`Compiler`]. + fn compile_with_compiler + 's>( + self, + compiler: &mut C, + ) -> CompiledValueExpr<'s, C::U>; + /// Compiles current node into a [`CompiledValueExpr`] using [`DefaultCompiler`]. + fn compile(self) -> CompiledValueExpr<'s> { + let mut compiler = DefaultCompiler::new(); + self.compile_with_compiler(&mut compiler) + } } /// A parsed filter AST. 
@@ -22,13 +64,13 @@ trait Expr<'s>: Sized + Eq + Debug + for<'i> LexWith<'i, &'s Scheme> + Serialize /// It's attached to its corresponding [`Scheme`](struct@Scheme) because all /// parsed fields are represented as indices and are valid only when /// [`ExecutionContext`](::ExecutionContext) is created from the same scheme. -#[derive(PartialEq, Eq, Serialize, Clone)] +#[derive(PartialEq, Eq, Serialize, Clone, Hash)] #[serde(transparent)] pub struct FilterAst<'s> { #[serde(skip)] scheme: &'s Scheme, - op: CombinedExpr<'s>, + op: LogicalExpr<'s>, } impl<'s> Debug for FilterAst<'s> { @@ -37,25 +79,205 @@ impl<'s> Debug for FilterAst<'s> { } } -impl<'i, 's> LexWith<'i, &'s Scheme> for FilterAst<'s> { - fn lex_with(input: &'i str, scheme: &'s Scheme) -> LexResult<'i, Self> { - let (op, input) = CombinedExpr::lex_with(input, scheme)?; - Ok((FilterAst { scheme, op }, input)) +impl<'i, 's> LexWith<'i, &FilterParser<'s>> for FilterAst<'s> { + fn lex_with(input: &'i str, parser: &FilterParser<'s>) -> LexResult<'i, Self> { + let (op, input) = LogicalExpr::lex_with(input, parser)?; + // LogicalExpr::lex_with can return an AST where the root is an + // LogicalExpr::Combining of type [`Array(Bool)`]. + // + // It must do this because we need to be able to use + // LogicalExpr::Combining of type [`Array(Bool)`] + // as arguments to functions, however it should not be valid as a + // filter expression itself. + // + // Here we enforce the constraint that the root of the AST, a + // LogicalExpr, must evaluate to type [`Bool`]. + let ty = op.get_type(); + match ty { + Type::Bool => Ok(( + FilterAst { + scheme: parser.scheme, + op, + }, + input, + )), + _ => Err(( + LexErrorKind::TypeMismatch(TypeMismatchError { + expected: Type::Bool.into(), + actual: ty, + }), + input, + )), + } } } impl<'s> FilterAst<'s> { + /// Returns the associated scheme. + #[inline] + pub fn scheme(&self) -> &'s Scheme { + self.scheme + } + + /// Returns the associated expression. + #[inline] + pub fn expression(&self) -> &LogicalExpr<'s> { + &self.op + } + + /// Recursively visit all nodes in the AST using a [`Visitor`]. + #[inline] + pub fn walk<'a, V: Visitor<'s, 'a>>(&'a self, visitor: &mut V) { + visitor.visit_logical_expr(&self.op) + } + + /// Recursively visit all nodes in the AST using a [`VisitorMut`]. + #[inline] + pub fn walk_mut<'a, V: VisitorMut<'s, 'a>>(&'a mut self, visitor: &mut V) { + visitor.visit_logical_expr(&mut self.op) + } + /// Recursively checks whether a [`FilterAst`] uses a given field name. /// /// This is useful to lazily initialise expensive fields only if necessary. pub fn uses(&self, field_name: &str) -> Result { - self.scheme - .get_field_index(field_name) - .map(|field| self.op.uses(field)) + self.scheme.get_field(field_name).map(|field| { + let mut visitor = UsesVisitor::new(field); + self.walk(&mut visitor); + visitor.uses() + }) + } + + /// Recursively checks whether a [`FilterAst`] uses a list. + pub fn uses_list(&self, field_name: &str) -> Result { + self.scheme.get_field(field_name).map(|field| { + let mut visitor = UsesListVisitor::new(field); + self.walk(&mut visitor); + visitor.uses() + }) } - /// Compiles a [`FilterAst`] into a [`Filter`]. + /// Compiles a [`FilterAst`] into a [`Filter`] using a specific [`Compiler`]. 
+ pub fn compile_with_compiler + 's>(self, compiler: &mut C) -> Filter<'s, C::U> { + match compiler.compile_logical_expr(self.op) { + CompiledExpr::One(one) => Filter::new(one, self.scheme), + CompiledExpr::Vec(_) => unreachable!(), + } + } + + /// Compiles a [`FilterAst`] into a [`Filter`] using the [`DefaultCompiler`]. pub fn compile(self) -> Filter<'s> { - Filter::new(self.op.compile(), self.scheme) + let mut compiler = DefaultCompiler::new(); + self.compile_with_compiler(&mut compiler) + } +} + +/// A parsed value AST. +/// +/// It's attached to its corresponding [`Scheme`](struct@Scheme) because all +/// parsed fields are represented as indices and are valid only when +/// [`ExecutionContext`](::ExecutionContext) is created from the same scheme. +#[derive(PartialEq, Eq, Serialize, Clone, Hash)] +#[serde(transparent)] +pub struct FilterValueAst<'s> { + #[serde(skip)] + scheme: &'s Scheme, + + op: IndexExpr<'s>, +} + +impl<'s> Debug for FilterValueAst<'s> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.op.fmt(f) + } +} + +impl<'i, 's> LexWith<'i, &FilterParser<'s>> for FilterValueAst<'s> { + fn lex_with(input: &'i str, parser: &FilterParser<'s>) -> LexResult<'i, Self> { + let (op, rest) = IndexExpr::lex_with(input.trim(), parser)?; + if op.map_each_count() > 0 { + Err(( + LexErrorKind::TypeMismatch(TypeMismatchError { + expected: op.get_type().into(), + actual: Type::Array(op.get_type().into()), + }), + input, + )) + } else { + Ok(( + FilterValueAst { + scheme: parser.scheme(), + op, + }, + rest, + )) + } + } +} + +impl<'s> FilterValueAst<'s> { + /// Returns the associated scheme. + #[inline] + pub fn scheme(&self) -> &'s Scheme { + self.scheme + } + + /// Returns the associated expression. + #[inline] + pub fn expression(&self) -> &IndexExpr<'s> { + &self.op + } + + /// Recursively visit all nodes in the AST using a [`Visitor`]. + #[inline] + pub fn walk<'a, V: Visitor<'s, 'a>>(&'a self, visitor: &mut V) { + visitor.visit_index_expr(&self.op) + } + + /// Recursively visit all nodes in the AST using a [`VisitorMut`]. + #[inline] + pub fn walk_mut<'a, V: VisitorMut<'s, 'a>>(&'a mut self, visitor: &mut V) { + visitor.visit_index_expr(&mut self.op) + } + + /// Recursively checks whether a [`FilterAst`] uses a given field name. + /// + /// This is useful to lazily initialise expensive fields only if necessary. + pub fn uses(&self, field_name: &str) -> Result { + self.scheme.get_field(field_name).map(|field| { + let mut visitor = UsesVisitor::new(field); + self.walk(&mut visitor); + visitor.uses() + }) + } + + /// Recursively checks whether a [`FilterAst`] uses a list. + pub fn uses_list(&self, field_name: &str) -> Result { + self.scheme.get_field(field_name).map(|field| { + let mut visitor = UsesListVisitor::new(field); + self.walk(&mut visitor); + visitor.uses() + }) + } + + /// Compiles a [`FilterValueAst`] into a [`FilterValue`] using a specific [`Compiler`]. + pub fn compile_with_compiler + 's>( + self, + compiler: &mut C, + ) -> FilterValue<'s, C::U> { + FilterValue::new(compiler.compile_index_expr(self.op), self.scheme) + } + + /// Compiles a [`FilterValueAst`] into a [`FilterValue`] using the [`DefaultCompiler`]. 
+ pub fn compile(self) -> FilterValue<'s> { + let mut compiler = DefaultCompiler::new(); + self.compile_with_compiler(&mut compiler) + } +} + +impl<'s> GetType for FilterValueAst<'s> { + #[inline] + fn get_type(&self) -> Type { + self.op.get_type() } } diff --git a/engine/src/ast/parse.rs b/engine/src/ast/parse.rs new file mode 100644 index 00000000..94c0e3cf --- /dev/null +++ b/engine/src/ast/parse.rs @@ -0,0 +1,181 @@ +use super::{FilterAst, FilterValueAst}; +use crate::{ + lex::{complete, LexErrorKind, LexResult, LexWith}, + scheme::Scheme, +}; +use std::cmp::{max, min}; +use std::error::Error; +use std::fmt::{self, Debug, Display, Formatter}; + +/// An opaque filter parsing error associated with the original input. +/// +/// For now, you can just print it in a debug or a human-readable fashion. +#[derive(Debug, PartialEq)] +pub struct ParseError<'i> { + /// The error that occurred when parsing the input + pub(crate) kind: LexErrorKind, + + /// The input that caused the parse error + pub(crate) input: &'i str, + + /// The line number on the input where the error occurred + pub(crate) line_number: usize, + + /// The start of the bad input + pub(crate) span_start: usize, + + /// The number of characters that span the bad input + pub(crate) span_len: usize, +} + +impl<'i> Error for ParseError<'i> {} + +impl<'i> ParseError<'i> { + /// Create a new ParseError for the input, LexErrorKind and span in the + /// input. + pub fn new(mut input: &'i str, (kind, span): (LexErrorKind, &'i str)) -> Self { + let input_range = input.as_ptr() as usize..=input.as_ptr() as usize + input.len(); + assert!( + input_range.contains(&(span.as_ptr() as usize)) + && input_range.contains(&(span.as_ptr() as usize + span.len())) + ); + let mut span_start = span.as_ptr() as usize - input.as_ptr() as usize; + + let (line_number, line_start) = input[..span_start] + .match_indices('\n') + .map(|(pos, _)| pos + 1) + .scan(0, |line_number, line_start| { + *line_number += 1; + Some((*line_number, line_start)) + }) + .last() + .unwrap_or_default(); + + input = &input[line_start..]; + + span_start -= line_start; + let mut span_len = span.len(); + + if let Some(line_end) = input.find('\n') { + input = &input[..line_end]; + span_len = min(span_len, line_end - span_start); + } + + ParseError { + kind, + input, + line_number, + span_start, + span_len, + } + } +} + +impl<'i> Display for ParseError<'i> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + writeln!( + f, + "Filter parsing error ({}:{}):", + self.line_number + 1, + self.span_start + 1 + )?; + + writeln!(f, "{}", self.input)?; + + for _ in 0..self.span_start { + write!(f, " ")?; + } + + for _ in 0..max(1, self.span_len) { + write!(f, "^")?; + } + + writeln!(f, " {}", self.kind)?; + + Ok(()) + } +} + +/// A structure used to drive parsing of an expression into a [`FilterAst`]. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct FilterParser<'s> { + pub(crate) scheme: &'s Scheme, + pub(crate) regex_dfa_size_limit: usize, + pub(crate) regex_compiled_size_limit: usize, + pub(crate) wildcard_star_limit: usize, +} + +impl<'s> FilterParser<'s> { + /// Creates a new parser with default configuration. + #[inline] + pub fn new(scheme: &'s Scheme) -> Self { + Self { + scheme, + // Default value extracted from the regex crate. + regex_compiled_size_limit: 10 * (1 << 20), + // Default value extracted from the regex crate. 
+ regex_dfa_size_limit: 2 * (1 << 20), + wildcard_star_limit: usize::MAX, + } + } + + /// Returns the [`Scheme`](struct@Scheme) for which this parser has been constructor for. + #[inline] + pub fn scheme(&self) -> &'s Scheme { + self.scheme + } + + #[inline] + pub(crate) fn lex_as<'i, L: for<'p> LexWith<'i, &'p FilterParser<'s>>>( + &self, + input: &'i str, + ) -> LexResult<'i, L> { + L::lex_with(input, self) + } + + /// Parses a filter expression into an AST form. + pub fn parse<'i>(&self, input: &'i str) -> Result, ParseError<'i>> { + complete(self.lex_as(input.trim())).map_err(|err| ParseError::new(input, err)) + } + + /// Parses a value expression into an AST form. + pub fn parse_value<'i>(&self, input: &'i str) -> Result, ParseError<'i>> { + complete(self.lex_as(input.trim())).map_err(|err| ParseError::new(input, err)) + } + + /// Set the approximate size limit of the compiled regular expression. + #[inline] + pub fn regex_set_compiled_size_limit(&mut self, regex_compiled_size_limit: usize) { + self.regex_compiled_size_limit = regex_compiled_size_limit; + } + + /// Get the approximate size limit of the compiled regular expression. + #[inline] + pub fn regex_get_compiled_size_limit(&self) -> usize { + self.regex_compiled_size_limit + } + + /// Set the approximate size of the cache used by the DFA of a regex. + #[inline] + pub fn regex_set_dfa_size_limit(&mut self, regex_dfa_size_limit: usize) { + self.regex_dfa_size_limit = regex_dfa_size_limit; + } + + /// Get the approximate size of the cache used by the DFA of a regex. + #[inline] + pub fn regex_get_dfa_size_limit(&self) -> usize { + self.regex_dfa_size_limit + } + + /// Set the maximum number of star metacharacters allowed in a wildcard. + #[inline] + pub fn wildcard_set_star_limit(&mut self, wildcard_star_limit: usize) { + self.wildcard_star_limit = wildcard_star_limit; + } + + /// Get the maximum number of star metacharacters allowed in a wildcard. + #[inline] + pub fn wildcard_get_star_limit(&self) -> usize { + self.wildcard_star_limit + } +} diff --git a/engine/src/ast/simple_expr.rs b/engine/src/ast/simple_expr.rs deleted file mode 100644 index d36fb6c4..00000000 --- a/engine/src/ast/simple_expr.rs +++ /dev/null @@ -1,176 +0,0 @@ -use super::{combined_expr::CombinedExpr, field_expr::FieldExpr, CompiledExpr, Expr}; -use crate::{ - lex::{expect, skip_space, Lex, LexResult, LexWith}, - scheme::{Field, Scheme}, -}; -use serde::Serialize; - -lex_enum!(UnaryOp { - "not" | "!" 
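// ---------------------------------------------------------------------------
// Illustrative usage (not part of the patch): configuring the parser limits
// defined above and reporting a `ParseError`. It assumes the same crate-root
// imports as the sketch earlier; the limit values, field name and function
// name are arbitrary. `Display` for `ParseError` prints the offending line
// followed by a `^` span and the error kind, as implemented in this file.
fn parser_limits_sketch() {
    let mut scheme = Scheme::new();
    scheme.add_field("http.host", Type::Bytes).unwrap();

    let mut parser = FilterParser::new(&scheme);
    // Tighten the regex budgets (the defaults come from the regex crate).
    parser.regex_set_compiled_size_limit(1 << 20);
    parser.regex_set_dfa_size_limit(1 << 20);
    // Allow at most two `*` metacharacters per wildcard pattern.
    parser.wildcard_set_star_limit(2);

    // An incomplete filter: printing the error shows where lexing stopped.
    if let Err(err) = parser.parse(r#"http.host == "#) {
        eprintln!("{err}");
    }

    // `parse_value` parses a value (non-boolean) expression such as a bare
    // field, yielding a `FilterValueAst` that compiles to a `FilterValue`.
    let value_ast = parser.parse_value("http.host").unwrap();
    let _compiled = value_ast.compile();
}
// ---------------------------------------------------------------------------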
=> Not, -}); - -#[derive(Debug, PartialEq, Eq, Clone, Serialize)] -#[serde(untagged)] -pub enum SimpleExpr<'s> { - Field(FieldExpr<'s>), - Parenthesized(Box>), - Unary { - op: UnaryOp, - arg: Box>, - }, -} - -impl<'i, 's> LexWith<'i, &'s Scheme> for SimpleExpr<'s> { - fn lex_with(input: &'i str, scheme: &'s Scheme) -> LexResult<'i, Self> { - Ok(if let Ok(input) = expect(input, "(") { - let input = skip_space(input); - let (op, input) = CombinedExpr::lex_with(input, scheme)?; - let input = skip_space(input); - let input = expect(input, ")")?; - (SimpleExpr::Parenthesized(Box::new(op)), input) - } else if let Ok((op, input)) = UnaryOp::lex(input) { - let input = skip_space(input); - let (arg, input) = SimpleExpr::lex_with(input, scheme)?; - ( - SimpleExpr::Unary { - op, - arg: Box::new(arg), - }, - input, - ) - } else { - let (op, input) = FieldExpr::lex_with(input, scheme)?; - (SimpleExpr::Field(op), input) - }) - } -} - -impl<'s> Expr<'s> for SimpleExpr<'s> { - fn uses(&self, field: Field<'s>) -> bool { - match self { - SimpleExpr::Field(op) => op.uses(field), - SimpleExpr::Parenthesized(op) => op.uses(field), - SimpleExpr::Unary { arg, .. } => arg.uses(field), - } - } - - fn compile(self) -> CompiledExpr<'s> { - match self { - SimpleExpr::Field(op) => op.compile(), - SimpleExpr::Parenthesized(op) => op.compile(), - SimpleExpr::Unary { - op: UnaryOp::Not, - arg, - } => { - let arg = arg.compile(); - CompiledExpr::new(move |ctx| !arg.execute(ctx)) - } - } - } -} - -#[test] -fn test() { - use crate::{execution_context::ExecutionContext, lex::complete}; - - let scheme = &Scheme! { t: Bool }; - - let ctx = &mut ExecutionContext::new(scheme); - ctx.set_field_value("t", true).unwrap(); - - let t_expr = SimpleExpr::Field(complete(FieldExpr::lex_with("t", scheme)).unwrap()); - let t_expr = || t_expr.clone(); - - { - let expr = assert_ok!(SimpleExpr::lex_with("t", scheme), t_expr()); - - assert_json!( - expr, - { - "lhs": "t", - "op": "IsTrue" - } - ); - - let expr = expr.compile(); - - assert_eq!(expr.execute(ctx), true); - } - - let parenthesized_expr = |expr| SimpleExpr::Parenthesized(Box::new(CombinedExpr::Simple(expr))); - - { - let expr = assert_ok!( - SimpleExpr::lex_with("((t))", scheme), - parenthesized_expr(parenthesized_expr(t_expr())) - ); - - assert_json!( - expr, - { - "lhs": "t", - "op": "IsTrue" - } - ); - - let expr = expr.compile(); - - assert_eq!(expr.execute(ctx), true); - } - - let not_expr = |expr| SimpleExpr::Unary { - op: UnaryOp::Not, - arg: Box::new(expr), - }; - - { - let expr = assert_ok!(SimpleExpr::lex_with("not t", scheme), not_expr(t_expr())); - - assert_json!( - expr, - { - "op": "Not", - "arg": { - "lhs": "t", - "op": "IsTrue" - } - } - ); - - let expr = expr.compile(); - - assert_eq!(expr.execute(ctx), false); - } - - assert_ok!(SimpleExpr::lex_with("!t", scheme), not_expr(t_expr())); - - { - let expr = assert_ok!( - SimpleExpr::lex_with("!!t", scheme), - not_expr(not_expr(t_expr())) - ); - - assert_json!( - expr, - { - "op": "Not", - "arg": { - "op": "Not", - "arg": { - "lhs": "t", - "op": "IsTrue" - } - } - } - ); - - let expr = expr.compile(); - - assert_eq!(expr.execute(ctx), true); - } - - assert_ok!( - SimpleExpr::lex_with("! 
(not !t)", scheme), - not_expr(parenthesized_expr(not_expr(not_expr(t_expr())))) - ); -} diff --git a/engine/src/ast/visitor.rs b/engine/src/ast/visitor.rs new file mode 100644 index 00000000..c9b9d6a9 --- /dev/null +++ b/engine/src/ast/visitor.rs @@ -0,0 +1,308 @@ +use super::{ + field_expr::{ComparisonExpr, ComparisonOpExpr}, + function_expr::{FunctionCallArgExpr, FunctionCallExpr}, + index_expr::IndexExpr, + logical_expr::LogicalExpr, + Expr, ValueExpr, +}; +use crate::scheme::{Field, Function}; + +/// Trait used to immutably visit all nodes in the AST. +pub trait Visitor<'s, 'a>: Sized { + // `Expr` node visitor methods + + /// Visit [`Expr`] node. + #[inline] + fn visit_expr(&mut self, node: &'a impl Expr<'s>) { + node.walk(self) + } + + /// Visit [`LogicalExpr`] node. + #[inline] + fn visit_logical_expr(&mut self, node: &'a LogicalExpr<'s>) { + self.visit_expr(node) + } + + /// Visit [`ComparisonExpr`] node. + #[inline] + fn visit_comparison_expr(&mut self, node: &'a ComparisonExpr<'s>) { + self.visit_expr(node) + } + + // `ValueExpr` node visitor methods + + /// Visit [`ValueExpr`] node. + #[inline] + fn visit_value_expr(&mut self, node: &'a impl ValueExpr<'s>) { + node.walk(self) + } + + /// Visit [`IndexExpr`] node. + #[inline] + fn visit_index_expr(&mut self, node: &'a IndexExpr<'s>) { + self.visit_value_expr(node) + } + + /// Visit [`FunctionCallExpr`] node. + #[inline] + fn visit_function_call_expr(&mut self, node: &'a FunctionCallExpr<'s>) { + self.visit_value_expr(node) + } + + /// Visit [`FunctionCallArgExpr`] node. + #[inline] + fn visit_function_call_arg_expr(&mut self, node: &'a FunctionCallArgExpr<'s>) { + self.visit_value_expr(node) + } + + // Leaf node visitor methods + + /// Visit [`Field`] node. + #[inline] + fn visit_field(&mut self, _: &'a Field<'s>) {} + + /// Visit [`Function`] node. + #[inline] + fn visit_function(&mut self, _: &'a Function<'s>) {} + + // TODO: add visitor methods for literals? +} + +/// Trait used to mutably visit all nodes in the AST. +/// +/// Note that this trait is dangerous and any modification +/// to the AST should be done with cautions and respect +/// some invariants such as keeping type coherency. +pub trait VisitorMut<'s, 'a>: Sized { + // `Expr` node visitor methods + + /// Visit [`Expr`] node. + #[inline] + fn visit_expr(&mut self, node: &'a mut impl Expr<'s>) { + node.walk_mut(self) + } + + /// Visit [`LogicalExpr`] node. + #[inline] + fn visit_logical_expr(&mut self, node: &'a mut LogicalExpr<'s>) { + self.visit_expr(node) + } + + /// Visit [`ComparisonExpr`] node. + #[inline] + fn visit_comparison_expr(&mut self, node: &'a mut ComparisonExpr<'s>) { + self.visit_expr(node) + } + + // `ValueExpr` node visitor methods + + /// Visit [`ValueExpr`] node. + #[inline] + fn visit_value_expr(&mut self, node: &'a mut impl ValueExpr<'s>) { + node.walk_mut(self) + } + + /// Visit [`IndexExpr`] node. + #[inline] + fn visit_index_expr(&mut self, node: &'a mut IndexExpr<'s>) { + self.visit_value_expr(node) + } + + /// Visit [`FunctionCallExpr`] node. + #[inline] + fn visit_function_call_expr(&mut self, node: &'a mut FunctionCallExpr<'s>) { + self.visit_value_expr(node) + } + + /// Visit [`FunctionCallArgExpr`] node. + #[inline] + fn visit_function_call_arg_expr(&mut self, node: &'a mut FunctionCallArgExpr<'s>) { + self.visit_value_expr(node) + } + + // Leaf node visitor methods + + /// Visit [`Field`] node. + #[inline] + fn visit_field(&mut self, _: &'a Field<'s>) {} + + /// Visit [`Function`] node. 
+ #[inline] + fn visit_function(&mut self, _: &'a Function<'s>) {} + + // TODO: add visitor methods for literals? +} + +/// Recursively check if a [`Field`] is being used. +pub(crate) struct UsesVisitor<'s> { + field: Field<'s>, + uses: bool, +} + +impl<'s> UsesVisitor<'s> { + pub fn new(field: Field<'s>) -> Self { + Self { field, uses: false } + } + + pub fn uses(&self) -> bool { + self.uses + } +} + +impl<'s> Visitor<'s, '_> for UsesVisitor<'s> { + fn visit_expr(&mut self, node: &impl Expr<'s>) { + // Stop visiting the AST once we have found one occurence of the field + if !self.uses { + node.walk(self) + } + } + + fn visit_value_expr(&mut self, node: &impl ValueExpr<'s>) { + // Stop visiting the AST once we have found one occurence of the field + if !self.uses { + node.walk(self) + } + } + + fn visit_field(&mut self, f: &Field<'s>) { + if self.field == *f { + self.uses = true; + } + } +} + +/// Recursively check if a [`Field`] is being used in a list comparison. +pub(crate) struct UsesListVisitor<'s> { + field: Field<'s>, + uses: bool, +} + +impl<'s> UsesListVisitor<'s> { + pub fn new(field: Field<'s>) -> Self { + Self { field, uses: false } + } + + pub fn uses(&self) -> bool { + self.uses + } +} + +impl<'s> Visitor<'s, '_> for UsesListVisitor<'s> { + fn visit_expr(&mut self, node: &impl Expr<'s>) { + // Stop visiting the AST once we have found one occurence of the field + if !self.uses { + node.walk(self) + } + } + + fn visit_value_expr(&mut self, node: &impl ValueExpr<'s>) { + // Stop visiting the AST once we have found one occurence of the field + if !self.uses { + node.walk(self) + } + } + + fn visit_comparison_expr(&mut self, comparison_expr: &ComparisonExpr<'s>) { + if let ComparisonOpExpr::InList { .. } = comparison_expr.op { + let mut visitor = UsesVisitor::new(self.field); + visitor.visit_comparison_expr(comparison_expr); + if visitor.uses { + self.uses = true; + } + } + if !self.uses { + comparison_expr.walk(self) + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + AlwaysList, FunctionArgKind, Scheme, SimpleFunctionDefinition, SimpleFunctionImpl, + SimpleFunctionParam, Type, + }; + use std::sync::LazyLock; + + static SCHEME: LazyLock = LazyLock::new(|| { + let mut scheme = Scheme! 
{ + http.headers: Map(Bytes), + http.request.headers.names: Array(Bytes), + http.request.headers.values: Array(Bytes), + http.host: Bytes, + ip.addr: Ip, + ssl: Bool, + tcp.port: Int, + }; + scheme + .add_function( + "echo", + SimpleFunctionDefinition { + params: vec![SimpleFunctionParam { + arg_kind: FunctionArgKind::Field, + val_type: Type::Bytes, + }], + opt_params: vec![], + return_type: Type::Bytes, + implementation: SimpleFunctionImpl::new(|args| args.next()?.ok()), + }, + ) + .unwrap(); + scheme + .add_list(Type::Bytes, Box::new(AlwaysList {})) + .unwrap(); + scheme + }); + + #[test] + fn test_uses_visitor_simple() { + let ast = SCHEME.parse(r#"http.host == "test""#).unwrap(); + for field in SCHEME.fields() { + assert_eq!(ast.uses(field.name()), Ok(field.name() == "http.host")); + } + } + + #[test] + fn test_uses_list_visitor_simple() { + let ast = SCHEME.parse(r#"http.host in $test"#).unwrap(); + for field in SCHEME.fields() { + assert_eq!(ast.uses(field.name()), Ok(field.name() == "http.host")); + } + } + + #[test] + fn test_uses_visitor_function() { + let ast = SCHEME.parse(r#"echo(http.host) == "test""#).unwrap(); + for field in SCHEME.fields() { + assert_eq!(ast.uses(field.name()), Ok(field.name() == "http.host")); + } + } + + #[test] + fn test_uses_list_visitor_function() { + let ast = SCHEME.parse(r#"echo(http.host) in $test"#).unwrap(); + for field in SCHEME.fields() { + assert_eq!(ast.uses(field.name()), Ok(field.name() == "http.host")); + } + } + + #[test] + fn test_uses_visitor_mapeach() { + let ast = SCHEME + .parse(r#"echo(echo(http.headers[*])[*])[0] == "test""#) + .unwrap(); + for field in SCHEME.fields() { + assert_eq!(ast.uses(field.name()), Ok(field.name() == "http.headers")); + } + } + + #[test] + fn test_uses_list_visitor_mapeach() { + let ast = SCHEME + .parse(r#"echo(echo(http.headers[*])[*])[0] in $test"#) + .unwrap(); + for field in SCHEME.fields() { + assert_eq!(ast.uses(field.name()), Ok(field.name() == "http.headers")); + } + } +} diff --git a/engine/src/compiler.rs b/engine/src/compiler.rs new file mode 100644 index 00000000..45f418ae --- /dev/null +++ b/engine/src/compiler.rs @@ -0,0 +1,85 @@ +use crate::{ + ComparisonExpr, CompiledExpr, CompiledValueExpr, Expr, FunctionCallArgExpr, FunctionCallExpr, + IndexExpr, LogicalExpr, ValueExpr, +}; + +/// Trait used to drive the compilation of a [`FilterAst`] into a [`Filter`]. +pub trait Compiler<'s>: Sized + 's { + /// The user data type passed in the [`ExecutionContext`]. + type U; + + /// Compiles a [`Expr`] node into a [`CompiledExpr`] (boxed closure). + #[inline] + fn compile_expr(&mut self, node: impl Expr<'s>) -> CompiledExpr<'s, Self::U> { + node.compile_with_compiler(self) + } + + /// Compiles a [`LogicalExpr`] node into a [`CompiledExpr`] (boxed closure). + #[inline] + fn compile_logical_expr(&mut self, node: LogicalExpr<'s>) -> CompiledExpr<'s, Self::U> { + self.compile_expr(node) + } + + /// Compiles a [`ComparisonExpr`] node into a [`CompiledExpr`] (boxed closure). + #[inline] + fn compile_comparison_expr(&mut self, node: ComparisonExpr<'s>) -> CompiledExpr<'s, Self::U> { + self.compile_expr(node) + } + + /// Compiles a [`ValueExpr`] node into a [`CompiledValueExpr`] (boxed closure). + #[inline] + fn compile_value_expr(&mut self, node: impl ValueExpr<'s>) -> CompiledValueExpr<'s, Self::U> { + node.compile_with_compiler(self) + } + + /// Compiles a [`FunctionCallExpr`] node into a [`CompiledValueExpr`] (boxed closure). 
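// ---------------------------------------------------------------------------
// Illustrative usage (not part of the patch): a custom `Compiler` that keeps
// the default lowering but counts how many comparison nodes it compiles.
// `CountingCompiler` is a hypothetical name; `type U = ()` matches the
// default user-data type of the execution context.
struct CountingCompiler {
    comparisons: usize,
}

impl<'s> Compiler<'s> for CountingCompiler {
    type U = ();

    fn compile_comparison_expr(&mut self, node: ComparisonExpr<'s>) -> CompiledExpr<'s, Self::U> {
        self.comparisons += 1;
        // Delegate the actual code generation to the node itself,
        // exactly like the default implementation does.
        self.compile_expr(node)
    }
}

// Usage with a parsed `FilterAst`:
// let mut compiler = CountingCompiler { comparisons: 0 };
// let _filter = ast.compile_with_compiler(&mut compiler);
// assert_eq!(compiler.comparisons, 2); // e.g. for "tcp.port == 443 && ssl"
// ---------------------------------------------------------------------------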
+ #[inline] + fn compile_function_call_expr( + &mut self, + node: FunctionCallExpr<'s>, + ) -> CompiledValueExpr<'s, Self::U> { + self.compile_value_expr(node) + } + + /// Compiles a [`FunctionCallArgExpr`] node into a [`CompiledValueExpr`] (boxed closure). + #[inline] + fn compile_function_call_arg_expr( + &mut self, + node: FunctionCallArgExpr<'s>, + ) -> CompiledValueExpr<'s, Self::U> { + self.compile_value_expr(node) + } + + /// Compiles a [`IndexExpr`] node into a [`CompiledValueExpr`] (boxed closure). + #[inline] + fn compile_index_expr(&mut self, node: IndexExpr<'s>) -> CompiledValueExpr<'s, Self::U> { + self.compile_value_expr(node) + } +} + +/// Default compiler +#[derive(Clone, Copy, Debug)] +pub struct DefaultCompiler { + _marker: std::marker::PhantomData, +} + +impl Default for DefaultCompiler { + #[inline] + fn default() -> Self { + Self { + _marker: std::marker::PhantomData, + } + } +} + +impl DefaultCompiler { + /// Creates a new [`DefaultCompiler`]. + #[inline] + pub fn new() -> Self { + Self::default() + } +} + +impl<'s, U: 's> Compiler<'s> for DefaultCompiler { + type U = U; +} diff --git a/engine/src/execution_context.rs b/engine/src/execution_context.rs index 13bd4fd0..c47b4601 100644 --- a/engine/src/execution_context.rs +++ b/engine/src/execution_context.rs @@ -1,35 +1,106 @@ use crate::{ - scheme::{Field, Scheme}, - types::{GetType, LhsValue, TypeMismatchError}, + scheme::{Field, List, Scheme, SchemeMismatchError}, + types::{GetType, LhsValue, LhsValueSeed, Type, TypeMismatchError}, + ListMatcher, }; +use serde::de::{self, DeserializeSeed, Deserializer, MapAccess, Visitor}; +use serde::ser::{SerializeMap, SerializeSeq, Serializer}; +use serde::{Deserialize, Serialize}; +use std::borrow::Cow; +use std::fmt; +use std::fmt::Debug; +use thiserror::Error; + +/// An error that occurs when setting the field value in the [`ExecutionContext`](struct@ExecutionContext) +#[derive(Debug, PartialEq, Eq, Error)] +pub enum SetFieldValueError { + /// An error that occurs when trying to assign a value of the wrong type to a field. + #[error("{0}")] + TypeMismatchError(#[source] TypeMismatchError), + + /// An error that occurs when trying to set the value of a field from a different scheme. + #[error("{0}")] + SchemeMismatchError(#[source] SchemeMismatchError), +} + +/// An error that occurs when previously defined list gets redefined. +#[derive(Debug, PartialEq, Eq, Error)] +#[error("Invalid list matcher {matcher} for list {list}")] +pub struct InvalidListMatcherError { + matcher: String, + list: String, +} /// An execution context stores an associated [`Scheme`](struct@Scheme) and a /// set of runtime values to execute [`Filter`](::Filter) against. /// /// It acts as a map in terms of public API, but provides a constant-time /// index-based access to values for a filter during execution. -pub struct ExecutionContext<'e> { +#[derive(Debug, PartialEq)] +pub struct ExecutionContext<'e, U = ()> { scheme: &'e Scheme, values: Box<[Option>]>, + list_matchers: Box<[Box]>, + user_data: U, } -impl<'e> ExecutionContext<'e> { +impl<'e, U> ExecutionContext<'e, U> { /// Creates an execution context associated with a given scheme. /// /// This scheme will be used for resolving any field names and indices. - pub fn new<'s: 'e>(scheme: &'s Scheme) -> Self { + pub fn new<'s: 'e>(scheme: &'s Scheme) -> Self + where + U: Default, + { + Self::new_with(scheme, Default::default) + } + + /// Creates an execution context associated with a given scheme. 
+ /// + /// This scheme will be used for resolving any field names and indices. + pub fn new_with<'s: 'e>(scheme: &'s Scheme, f: impl FnOnce() -> U) -> Self { ExecutionContext { scheme, - values: vec![None; scheme.get_field_count()].into(), + values: vec![None; scheme.field_count()].into(), + list_matchers: scheme + .lists() + .map(|list| list.definition().new_matcher()) + .collect(), + user_data: f(), } } - /// Returns an associated scheme. + /// Returns the associated scheme. pub fn scheme(&self) -> &'e Scheme { self.scheme } - pub(crate) fn get_field_value_unchecked(&'e self, field: Field<'e>) -> LhsValue<'e> { + /// Sets a runtime value for a given field name. + pub fn set_field_value<'v: 'e, V: Into>>( + &mut self, + field: Field<'e>, + value: V, + ) -> Result>, SetFieldValueError> { + if !std::ptr::eq(self.scheme, field.scheme()) { + return Err(SetFieldValueError::SchemeMismatchError(SchemeMismatchError)); + } + let value = value.into(); + + let field_type = field.get_type(); + let value_type = value.get_type(); + + if field_type == value_type { + Ok(self.values[field.index()].replace(value)) + } else { + Err(SetFieldValueError::TypeMismatchError(TypeMismatchError { + expected: field_type.into(), + actual: value_type, + })) + } + } + + #[inline] + pub(crate) fn get_field_value_unchecked(&self, field: Field<'_>) -> &LhsValue<'_> { // This is safe because this code is reachable only from Filter::execute // which already performs the scheme compatibility check, but check that // invariant holds in the future at least in the debug mode. @@ -38,37 +109,246 @@ impl<'e> ExecutionContext<'e> { // For now we panic in this, but later we are going to align behaviour // with wireshark: resolve all subexpressions that don't have RHS value // to `false`. - let lhs_value = self.values[field.index()].as_ref().unwrap_or_else(|| { + self.values[field.index()].as_ref().unwrap_or_else(|| { panic!( "Field {} was registered but not given a value", field.name() ); - }); - lhs_value.as_ref() + }) } - /// Sets a runtime value for a given field name. - pub fn set_field_value<'v: 'e, V: Into>>( - &mut self, - name: &str, - value: V, - ) -> Result<(), TypeMismatchError> { - let field = self.scheme.get_field_index(name).unwrap(); - let value = value.into(); + /// Get the value of a field. + pub fn get_field_value(&self, field: Field<'_>) -> Option<&LhsValue<'_>> { + assert!(self.scheme() == field.scheme()); - let field_type = field.get_type(); - let value_type = value.get_type(); + self.values[field.index()].as_ref() + } - if field_type == value_type { - self.values[field.index()] = Some(value); - Ok(()) - } else { - Err(TypeMismatchError { - expected: field_type, - actual: value_type, - }) + #[inline] + pub(crate) fn get_list_matcher_unchecked(&self, list: List<'_>) -> &dyn ListMatcher { + debug_assert!(self.scheme() == list.scheme()); + + &*self.list_matchers[list.index()] + } + + /// Get the list matcher object for the specified list type. + pub fn get_list_matcher(&self, list: List<'_>) -> &dyn ListMatcher { + assert!(self.scheme() == list.scheme()); + + &*self.list_matchers[list.index()] + } + + /// Get the list matcher object for the specified list type. + pub fn get_list_matcher_mut(&mut self, list: List<'_>) -> &mut dyn ListMatcher { + assert!(self.scheme() == list.scheme()); + + &mut *self.list_matchers[list.index()] + } + + /// Get immutable reference to user data stored in + /// this execution context with [`ExecutionContext::new_with`]. 
+ #[inline] + pub fn get_user_data(&self) -> &U { + &self.user_data + } + + /// Get mutable reference to user data stored in + /// this execution context with [`ExecutionContext::new_with`]. + #[inline] + pub fn get_user_data_mut(&mut self) -> &mut U { + &mut self.user_data + } + + /// Extract all values and list data into a new [`ExecutionContext`]. + #[inline] + pub fn take_with(self, default: impl Fn(U) -> T) -> ExecutionContext<'e, T> { + ExecutionContext { + scheme: self.scheme, + values: self.values, + list_matchers: self.list_matchers, + user_data: default(self.user_data), } } + + /// Temporarily borrow all values and list data into a new [`ExecutionContext`]. + #[inline] + pub fn borrow_with(&mut self, user_data: T) -> ExecutionContextGuard<'_, 'e, U, T> { + ExecutionContextGuard::new(self, user_data) + } + + /// Clears the execution context, removing all values and lists + /// while retaining the allocated memory. + #[inline] + pub fn clear(&mut self) { + self.values.iter_mut().for_each(|value| *value = None); + self.list_matchers + .iter_mut() + .for_each(|list_matcher| list_matcher.clear()); + } +} + +/// Guard over a temporarily borrowed [`ExecutionContext`]. +/// When the guard is dropped, the original [`ExecutionContext`] +/// is restored. +pub struct ExecutionContextGuard<'a, 'e, U, T> { + old: &'a mut ExecutionContext<'e, U>, + new: ExecutionContext<'e, T>, +} + +impl<'a, 'e, U, T> ExecutionContextGuard<'a, 'e, U, T> { + fn new(old: &'a mut ExecutionContext<'e, U>, user_data: T) -> Self { + let scheme = old.scheme(); + let values = std::mem::take(&mut old.values); + let list_matchers = std::mem::take(&mut old.list_matchers); + + let new = ExecutionContext { + scheme, + values, + list_matchers, + user_data, + }; + + Self { old, new } + } +} + +impl<'a, 'e, U, T> std::ops::Deref for ExecutionContextGuard<'a, 'e, U, T> { + type Target = ExecutionContext<'e, T>; + + fn deref(&self) -> &Self::Target { + &self.new + } +} + +impl<'a, 'e, U, T> std::ops::DerefMut for ExecutionContextGuard<'a, 'e, U, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.new + } +} + +impl<'a, 'e, U, T> Drop for ExecutionContextGuard<'a, 'e, U, T> { + fn drop(&mut self) { + self.old.values = std::mem::take(&mut self.new.values); + self.old.list_matchers = std::mem::take(&mut self.new.list_matchers); + } +} + +#[derive(Serialize, Deserialize)] +struct ListData { + #[serde(rename = "type")] + ty: Type, + data: serde_json::Value, +} + +impl<'de, 'a, U> DeserializeSeed<'de> for &'a mut ExecutionContext<'de, U> { + type Value = (); + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct ExecutionContextVisitor<'de, 'a, U>(&'a mut ExecutionContext<'de, U>); + + impl<'de, 'a, U> Visitor<'de> for ExecutionContextVisitor<'de, 'a, U> { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(formatter, "a map of lhs value") + } + + fn visit_map(self, mut access: M) -> Result<(), M::Error> + where + M: MapAccess<'de>, + { + while let Some(key) = access.next_key::>()? 
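// ---------------------------------------------------------------------------
// Illustrative usage (not part of the patch): the new user-data plumbing on
// `ExecutionContext`. `new_with` seeds the context with arbitrary user data,
// `get_user_data`/`get_user_data_mut` expose it, and `borrow_with` temporarily
// swaps in data of another type, restoring the field values and list matchers
// when the guard drops. The function name is hypothetical.
fn user_data_demo(scheme: &Scheme) {
    // A context carrying a counter as user data.
    let mut ctx: ExecutionContext<'_, u64> = ExecutionContext::new_with(scheme, || 0);
    *ctx.get_user_data_mut() += 1;
    assert_eq!(*ctx.get_user_data(), 1);

    // Temporarily view the same field values with `&str` user data; the
    // guard dereferences to an `ExecutionContext<'_, &str>`.
    {
        let guard = ctx.borrow_with("temporary");
        assert_eq!(*guard.get_user_data(), "temporary");
    }

    // The original context (and its user data) is intact once the guard drops.
    assert_eq!(*ctx.get_user_data(), 1);

    // `clear()` drops all field values and resets list matchers while keeping
    // the allocations around for reuse.
    ctx.clear();
}
// ---------------------------------------------------------------------------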
{ + if key == "$lists" { + // Deserialize lists + let vec = access.next_value::>()?; + for ListData { ty, data } in vec.into_iter() { + let list = self.0.scheme.get_list(&ty).ok_or_else(|| { + de::Error::custom(format!("unknown list for type: {ty:?}")) + })?; + self.0.list_matchers[list.index()] = list + .definition() + .matcher_from_json_value(ty, data) + .map_err(|err| { + de::Error::custom(format!( + "failed to deserialize list matcher: {err:?}" + )) + })?; + } + } else { + let field = self + .0 + .scheme + .get_field(&key) + .map_err(|_| de::Error::custom(format!("unknown field: {key}")))?; + let value = access + .next_value_seed::>(LhsValueSeed(&field.get_type()))?; + let field = self + .0 + .scheme() + .get_field(&key) + .map_err(|_| de::Error::custom(format!("unknown field: {key}")))?; + self.0.set_field_value(field, value).map_err(|e| match e { + SetFieldValueError::TypeMismatchError(e) => de::Error::custom(format!( + "invalid type: {:?}, expected {:?}", + e.actual, e.expected + )), + SetFieldValueError::SchemeMismatchError(_) => unreachable!(), + })?; + } + } + + Ok(()) + } + } + + deserializer.deserialize_map(ExecutionContextVisitor(self)) + } +} + +impl<'e> Serialize for ExecutionContext<'e> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut map = serializer.serialize_map(Some(self.values.len()))?; + for field in self.scheme().fields() { + if let Some(Some(value)) = self.values.get(field.index()) { + map.serialize_entry(field.name(), value)?; + } + } + + struct ListMatcherSlice<'a>(&'a Scheme, &'a [Box]); + + impl<'a> Serialize for ListMatcherSlice<'a> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut seq = serializer.serialize_seq(Some(self.1.len()))?; + for list in self.0.lists() { + let data = self.1[list.index()].to_json_value(); + if data != serde_json::Value::Null { + seq.serialize_element(&ListData { + ty: list.get_type(), + data, + })?; + } + } + seq.end() + } + } + + if !self.list_matchers.is_empty() { + map.serialize_entry( + "$lists", + &ListMatcherSlice(self.scheme, &self.list_matchers), + )?; + } + map.end() + } } #[test] @@ -77,13 +357,236 @@ fn test_field_value_type_mismatch() { let scheme = Scheme! { foo: Int }; + let mut ctx = ExecutionContext::<()>::new(&scheme); + + assert_eq!( + ctx.set_field_value(scheme.get_field("foo").unwrap(), LhsValue::Bool(false)), + Err(SetFieldValueError::TypeMismatchError(TypeMismatchError { + expected: Type::Int.into(), + actual: Type::Bool, + })) + ); +} + +#[test] +fn test_scheme_mismatch() { + let scheme = Scheme! { foo: Bool }; + + let mut ctx = ExecutionContext::<()>::new(&scheme); + + let scheme2 = Scheme! 
{ foo: Bool }; + + assert_eq!( + ctx.set_field_value(scheme2.get_field("foo").unwrap(), LhsValue::Bool(false)), + Err(SetFieldValueError::SchemeMismatchError( + SchemeMismatchError {} + )) + ); +} + +#[test] +fn test_serde() { + use crate::lhs_types::{Array, Map}; + use crate::types::Type; + use std::net::IpAddr; + use std::str::FromStr; + + let mut scheme = Scheme::new(); + scheme.add_field("bool", Type::Bool).unwrap(); + scheme.add_field("ip", Type::Ip).unwrap(); + scheme.add_field("str", Type::Bytes).unwrap(); + scheme.add_field("bytes", Type::Bytes).unwrap(); + scheme.add_field("num", Type::Int).unwrap(); + scheme.add_field("min_num", Type::Int).unwrap(); + scheme.add_field("max_num", Type::Int).unwrap(); + scheme + .add_field("arr", Type::Array(Type::Bool.into())) + .unwrap(); + scheme + .add_field("map", Type::Map(Type::Int.into())) + .unwrap(); + let mut ctx = ExecutionContext::new(&scheme); assert_eq!( - ctx.set_field_value("foo", LhsValue::Bool(false)), - Err(TypeMismatchError { - expected: Type::Int, - actual: Type::Bool - }) + ctx.set_field_value(scheme.get_field("bool").unwrap(), LhsValue::Bool(false)), + Ok(None), + ); + + assert_eq!( + ctx.set_field_value( + scheme.get_field("ip").unwrap(), + LhsValue::Ip(IpAddr::from_str("127.0.0.1").unwrap()) + ), + Ok(None), + ); + + assert_eq!( + ctx.set_field_value(scheme.get_field("str").unwrap(), "a string"), + Ok(None), + ); + assert_eq!( + ctx.set_field_value(scheme.get_field("bytes").unwrap(), &b"a\xFF\xFFb"[..]), + Ok(None), + ); + + assert_eq!( + ctx.set_field_value(scheme.get_field("num").unwrap(), 42), + Ok(None), + ); + + assert_eq!( + ctx.set_field_value(scheme.get_field("min_num").unwrap(), i64::MIN), + Ok(None), ); + + assert_eq!( + ctx.set_field_value(scheme.get_field("max_num").unwrap(), i64::MAX), + Ok(None), + ); + + assert_eq!( + ctx.set_field_value(scheme.get_field("arr").unwrap(), { + Array::from_iter([false, true]) + }), + Ok(None), + ); + + assert_eq!( + ctx.set_field_value(scheme.get_field("map").unwrap(), { + let mut map = Map::new(Type::Int); + map.insert(b"leet", 1337).unwrap(); + map.insert(b"tabs", 25).unwrap(); + map + }), + Ok(None), + ); + + let json = assert_json!( + ctx, + { + "bool": false, + "ip": "127.0.0.1", + "str": "a string", + "bytes": [97, 255, 255, 98], + "num": 42, + "min_num": i64::MIN, + "max_num": i64::MAX, + "arr": [false, true], + "map": { + "leet": 1337, + "tabs": 25, + } + } + ) + .to_string(); + + let mut ctx2 = ExecutionContext::new(&scheme); + let mut deserializer = serde_json::Deserializer::from_str(&json); + ctx2.deserialize(&mut deserializer).unwrap(); + assert_eq!(ctx, ctx2); + + let mut ctx2 = ExecutionContext::new(&scheme); + let mut deserializer = serde_json::Deserializer::from_slice(json.as_bytes()); + ctx2.deserialize(&mut deserializer).unwrap(); + assert_eq!(ctx, ctx2); + + let mut ctx3 = ExecutionContext::new(&scheme); + let mut deserializer = serde_json::Deserializer::from_reader(json.as_bytes()); + ctx3.deserialize(&mut deserializer).unwrap(); + assert_eq!(ctx, ctx3); + + assert_eq!( + ctx.set_field_value(scheme.get_field("map").unwrap(), { + let mut map = Map::new(Type::Int); + map.insert(b"leet", 1337).unwrap(); + map.insert(b"tabs", 25).unwrap(); + map.insert(b"a\xFF\xFFb", 17).unwrap(); + map + }), + Ok(Some({ + let mut map = Map::new(Type::Int); + map.insert(b"leet", 1337).unwrap(); + map.insert(b"tabs", 25).unwrap(); + map.into() + })), + ); + + let json = assert_json!( + ctx, + { + "bool": false, + "ip": "127.0.0.1", + "str": "a string", + "bytes": [97, 255, 255, 
98], + "num": 42, + "min_num": i64::MIN, + "max_num": i64::MAX, + "arr": [false, true], + "map": [ + [[97, 255, 255, 98], 17], + ["leet", 1337], + ["tabs", 25] + ] + } + ) + .to_string(); + + let mut ctx2 = ExecutionContext::new(&scheme); + let mut deserializer = serde_json::Deserializer::from_str(&json); + ctx2.deserialize(&mut deserializer).unwrap(); + assert_eq!(ctx, ctx2); + + let mut ctx2 = ExecutionContext::new(&scheme); + let mut deserializer = serde_json::Deserializer::from_slice(json.as_bytes()); + ctx2.deserialize(&mut deserializer).unwrap(); + assert_eq!(ctx, ctx2); + + let mut ctx3 = ExecutionContext::new(&scheme); + let mut deserializer = serde_json::Deserializer::from_reader(json.as_bytes()); + ctx3.deserialize(&mut deserializer).unwrap(); + assert_eq!(ctx, ctx3); +} + +#[test] +fn test_clear() { + use crate::types::Type; + use std::net::IpAddr; + use std::str::FromStr; + + let mut scheme = Scheme::new(); + scheme.add_field("bool", Type::Bool).unwrap(); + scheme.add_field("ip", Type::Ip).unwrap(); + + let bool_field = scheme.get_field("bool").unwrap(); + let ip_field = scheme.get_field("ip").unwrap(); + + let mut ctx = ExecutionContext::<'_, ()>::new(&scheme); + + assert_eq!( + ctx.set_field_value(bool_field, LhsValue::Bool(false)), + Ok(None), + ); + + assert_eq!( + ctx.set_field_value( + ip_field, + LhsValue::Ip(IpAddr::from_str("127.0.0.1").unwrap()) + ), + Ok(None), + ); + + assert_eq!( + ctx.get_field_value(bool_field), + Some(&LhsValue::Bool(false)) + ); + assert_eq!( + ctx.get_field_value(ip_field), + Some(&LhsValue::Ip(IpAddr::from_str("127.0.0.1").unwrap())) + ); + + ctx.clear(); + + assert_eq!(ctx.get_field_value(bool_field), None); + assert_eq!(ctx.get_field_value(ip_field), None); } diff --git a/engine/src/filter.rs b/engine/src/filter.rs index 5532678f..4a60e992 100644 --- a/engine/src/filter.rs +++ b/engine/src/filter.rs @@ -1,30 +1,167 @@ -use crate::{execution_context::ExecutionContext, scheme::Scheme}; -use failure::Fail; - -/// An error that occurs if filter and provided [`ExecutionContext`] have -/// different [schemes](struct@Scheme). -#[derive(Debug, PartialEq, Fail)] -#[fail(display = "execution context doesn't match the scheme with which filter was parsed")] -pub struct SchemeMismatchError; - -// Each AST expression node gets compiled into CompiledExpr. Therefore, Filter -// essentialy is a public API facade for a tree of CompiledExprs. When filter -// gets executed it calls `execute` method on its root expression which then -// under the hood propagates field values to its leafs by recursively calling -// their `execute` methods and aggregating results into a single boolean value -// as recursion unwinds. -pub(crate) struct CompiledExpr<'s>(Box bool + Sync + Send>); - -impl<'s> CompiledExpr<'s> { +//! Each AST expression node gets compiled into a CompiledExpr or a CompiledValueExpr. +//! Therefore, Filter essentialy is a public API facade for a tree of Compiled(Value)Exprs. +//! When filter gets executed it calls `execute` method on its root expression which then +//! under the hood propagates field values to its leafs by recursively calling +//! their `execute` methods and aggregating results into a single boolean value +//! as recursion unwinds. 
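+//!
+//! As a rough illustration of this model (a sketch, not code from this change;
+//! `compile_and` is a made-up helper), a logical "and" node can be lowered into
+//! a closure that captures its two already-compiled children and simply combines
+//! their results when executed:
+//!
+//! ```ignore
+//! fn compile_and<'s, U: 's>(
+//!     lhs: CompiledOneExpr<'s, U>,
+//!     rhs: CompiledOneExpr<'s, U>,
+//! ) -> CompiledOneExpr<'s, U> {
+//!     // Each child is itself a boxed closure; the parent node just calls
+//!     // them against the same execution context and short-circuits.
+//!     CompiledOneExpr::new(move |ctx| lhs.execute(ctx) && rhs.execute(ctx))
+//! }
+//! ```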
+ +use crate::{ + execution_context::ExecutionContext, + lhs_types::TypedArray, + scheme::{Scheme, SchemeMismatchError}, + types::{LhsValue, Type}, +}; +use std::fmt; + +type BoxedClosureToOneBool<'s, U> = + Box Fn(&'e ExecutionContext<'e, U>) -> bool + Sync + Send + 's>; + +/// Boxed closure for [`Expr`] AST node that evaluates to a simple [`bool`]. +pub struct CompiledOneExpr<'s, U = ()>(BoxedClosureToOneBool<'s, U>); + +impl<'s, U> fmt::Debug for CompiledOneExpr<'s, U> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_tuple("CompiledOneExpr") + .field(&((&*self.0) as *const _)) + .finish() + } +} + +impl<'s, U> CompiledOneExpr<'s, U> { + /// Creates a compiled expression IR from a generic closure. + pub fn new( + closure: impl for<'e> Fn(&'e ExecutionContext<'e, U>) -> bool + Sync + Send + 's, + ) -> Self { + CompiledOneExpr(Box::new(closure)) + } + + /// Executes the closure against a provided context with values. + pub fn execute<'e>(&self, ctx: &'e ExecutionContext<'e, U>) -> bool { + self.0(ctx) + } + + /// Extracts the underlying boxed closure. + pub fn into_boxed_closure(self) -> BoxedClosureToOneBool<'s, U> { + self.0 + } +} + +pub(crate) type CompiledVecExprResult = TypedArray<'static, bool>; + +type BoxedClosureToVecBool<'s, U> = + Box Fn(&'e ExecutionContext<'e, U>) -> CompiledVecExprResult + Sync + Send + 's>; + +/// Boxed closure for [`Expr`] AST node that evaluates to a list of [`bool`]. +pub struct CompiledVecExpr<'s, U = ()>(BoxedClosureToVecBool<'s, U>); + +impl<'s, U> fmt::Debug for CompiledVecExpr<'s, U> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_tuple("CompiledVecExpr") + .field(&((&*self.0) as *const _)) + .finish() + } +} + +impl<'s, U> CompiledVecExpr<'s, U> { /// Creates a compiled expression IR from a generic closure. - pub(crate) fn new(closure: impl 's + Fn(&ExecutionContext) -> bool + Sync + Send) -> Self { - CompiledExpr(Box::new(closure)) + pub fn new( + closure: impl for<'e> Fn(&'e ExecutionContext<'e, U>) -> CompiledVecExprResult + + Sync + + Send + + 's, + ) -> Self { + CompiledVecExpr(Box::new(closure)) } - /// Executes a filter against a provided context with values. - pub fn execute(&self, ctx: &ExecutionContext) -> bool { + /// Executes the closure against a provided context with values. + pub fn execute<'e>(&self, ctx: &'e ExecutionContext<'e, U>) -> CompiledVecExprResult { self.0(ctx) } + + /// Extracts the underlying boxed closure. + pub fn into_boxed_closure(self) -> BoxedClosureToVecBool<'s, U> { + self.0 + } +} + +/// Enum of boxed closure for [`Expr`] AST nodes. +#[derive(Debug)] +pub enum CompiledExpr<'s, U = ()> { + /// Variant for [`Expr`] AST node that evaluates to a simple [`bool`]. + One(CompiledOneExpr<'s, U>), + /// Variant for [`Expr`] AST node that evaluates to a list of [`bool`]. 
+ Vec(CompiledVecExpr<'s, U>), +} + +impl<'s, U> CompiledExpr<'s, U> { + #[cfg(test)] + pub(crate) fn execute_one<'e>(&self, ctx: &'e ExecutionContext<'e, U>) -> bool { + match self { + CompiledExpr::One(one) => one.execute(ctx), + CompiledExpr::Vec(_) => unreachable!(), + } + } + + #[cfg(test)] + pub(crate) fn execute_vec<'e>( + &self, + ctx: &'e ExecutionContext<'e, U>, + ) -> CompiledVecExprResult { + match self { + CompiledExpr::One(_) => unreachable!(), + CompiledExpr::Vec(vec) => vec.execute(ctx), + } + } +} + +pub type CompiledValueResult<'a> = Result, Type>; + +impl<'a> From> for CompiledValueResult<'a> { + fn from(value: LhsValue<'a>) -> Self { + Ok(value) + } +} + +impl<'a> From for CompiledValueResult<'a> { + fn from(ty: Type) -> Self { + Err(ty) + } +} + +type BoxedClosureToValue<'s, U> = + Box Fn(&'e ExecutionContext<'e, U>) -> CompiledValueResult<'e> + Sync + Send + 's>; + +/// Boxed closure for [`ValueExpr`] AST node that evaluates to an [`LhsValue`]. +pub struct CompiledValueExpr<'s, U = ()>(BoxedClosureToValue<'s, U>); + +impl<'s, U> fmt::Debug for CompiledValueExpr<'s, U> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_tuple("CompiledValueExpr") + .field(&((&*self.0) as *const _)) + .finish() + } +} + +impl<'s, U> CompiledValueExpr<'s, U> { + /// Creates a compiled expression IR from a generic closure. + pub fn new( + closure: impl for<'e> Fn(&'e ExecutionContext<'e, U>) -> CompiledValueResult<'e> + + Sync + + Send + + 's, + ) -> Self { + CompiledValueExpr(Box::new(closure)) + } + + /// Executes the closure against a provided context with values. + pub fn execute<'e>(&self, ctx: &'e ExecutionContext<'e, U>) -> CompiledValueResult<'e> { + self.0(ctx) + } + + /// Extracts the underlying boxed closure. + pub fn into_boxed_closure(self) -> BoxedClosureToValue<'s, U> { + self.0 + } } /// An IR for a compiled filter expression. @@ -42,20 +179,48 @@ impl<'s> CompiledExpr<'s> { /// In the future the underlying representation might change, but for now it /// provides the best trade-off between safety and performance of compilation /// and execution. -pub struct Filter<'s> { - root_expr: CompiledExpr<'s>, +pub struct Filter<'s, U = ()> { + root_expr: CompiledOneExpr<'s, U>, scheme: &'s Scheme, } -impl<'s> Filter<'s> { +impl<'s, U> Filter<'s, U> { /// Creates a compiled expression IR from a generic closure. - pub(crate) fn new(root_expr: CompiledExpr<'s>, scheme: &'s Scheme) -> Self { + pub(crate) fn new(root_expr: CompiledOneExpr<'s, U>, scheme: &'s Scheme) -> Self { Filter { root_expr, scheme } } - /// Executes a filter against a provided context with values. - pub fn execute(&self, ctx: &ExecutionContext<'s>) -> Result { - if self.scheme == ctx.scheme() { + /// Executes a compiled filter expression against a provided context with values. + pub fn execute<'e>( + &self, + ctx: &'e ExecutionContext<'e, U>, + ) -> Result { + if ctx.scheme() == self.scheme { + Ok(self.root_expr.execute(ctx)) + } else { + Err(SchemeMismatchError) + } + } +} + +/// An IR for a compiled value expression. +pub struct FilterValue<'s, U = ()> { + root_expr: CompiledValueExpr<'s, U>, + scheme: &'s Scheme, +} + +impl<'s, U> FilterValue<'s, U> { + /// Creates a compiled expression IR from a generic closure. + pub(crate) fn new(root_expr: CompiledValueExpr<'s, U>, scheme: &'s Scheme) -> Self { + FilterValue { root_expr, scheme } + } + + /// Executes a compiled value expression against a provided context with values. 
+ pub fn execute<'e>( + &self, + ctx: &'e ExecutionContext<'e, U>, + ) -> Result, Type>, SchemeMismatchError> { + if ctx.scheme() == self.scheme { Ok(self.root_expr.execute(ctx)) } else { Err(SchemeMismatchError) @@ -83,7 +248,7 @@ mod tests { fn is_send() {} fn is_sync() {} - is_send::(); - is_sync::(); + is_send::>>(); + is_sync::>>(); } } diff --git a/engine/src/functions.rs b/engine/src/functions.rs index e95fbff6..bdd0f659 100644 --- a/engine/src/functions.rs +++ b/engine/src/functions.rs @@ -1,55 +1,444 @@ -use crate::types::{LhsValue, Type}; -use std::fmt; +use crate::{ + filter::CompiledValueResult, + types::{ExpectedType, ExpectedTypeList, GetType, LhsValue, RhsValue, Type, TypeMismatchError}, +}; +use std::any::Any; +use std::convert::TryFrom; +use std::{ + fmt::{self, Debug}, + iter::once, + sync::Arc, +}; +use thiserror::Error; + +pub(crate) struct ExactSizeChain +where + A: ExactSizeIterator, + B: ExactSizeIterator::Item>, +{ + chain: std::iter::Chain, + len_a: usize, + len_b: usize, +} + +impl ExactSizeChain +where + A: ExactSizeIterator, + B: ExactSizeIterator::Item>, +{ + #[inline] + pub(crate) fn new(a: A, b: B) -> ExactSizeChain { + let len_a = a.len(); + let len_b = b.len(); + ExactSizeChain { + chain: a.chain(b), + len_a, + len_b, + } + } +} + +impl Iterator for ExactSizeChain +where + A: ExactSizeIterator, + B: ExactSizeIterator::Item>, +{ + type Item = A::Item; + + #[inline] + fn next(&mut self) -> Option { + match self.chain.next() { + None => None, + Some(elem) => { + if self.len_a > 0 { + self.len_a -= 1; + } else if self.len_b > 0 { + self.len_b -= 1; + } + Some(elem) + } + } + } +} + +impl ExactSizeIterator for ExactSizeChain +where + A: ExactSizeIterator, + B: ExactSizeIterator::Item>, +{ + #[inline] + fn len(&self) -> usize { + self.len_a + self.len_b + } +} /// An iterator over function arguments as [`LhsValue`]s. -pub type FunctionArgs<'i, 'a> = &'i mut dyn Iterator>; +pub type FunctionArgs<'i, 'a> = &'i mut dyn ExactSizeIterator>; + +/// Defines what kind of argument a function expects. +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum FunctionArgKind { + /// Allow only literal as argument. + Literal, + /// Allow only field as argument. + Field, +} + +/// An error that occurs on a kind mismatch. +#[derive(Debug, PartialEq, Eq, Error)] +#[error("expected argument of kind {expected:?}, but got {actual:?}")] +pub struct FunctionArgKindMismatchError { + /// Expected value type. + pub expected: FunctionArgKind, + /// Provided value type. + pub actual: FunctionArgKind, +} + +/// An error that occurs on a kind mismatch. +#[derive(Debug, PartialEq, Eq, Error)] +#[error("invalid argument: {msg:?}")] +pub struct FunctionArgInvalidConstantError { + msg: String, +} + +impl FunctionArgInvalidConstantError { + /// Returns a new invalid constant error. 
+ #[inline] + pub fn new(msg: String) -> Self { + Self { msg } + } +} + +impl From for FunctionArgInvalidConstantError { + #[inline] + fn from(msg: String) -> Self { + Self::new(msg) + } +} + +/// An error that occurs for a bad function parameter +#[derive(Debug, PartialEq, Eq, Error)] +pub enum FunctionParamError { + /// Function paramater value type has a different type than expected + #[error("expected {0}")] + TypeMismatch(#[source] TypeMismatchError), + /// Function parameter argument kind has a different kind than expected + #[error("expected {0}")] + KindMismatch(#[source] FunctionArgKindMismatchError), + /// Function parameter constant value is invalid + #[error("{0}")] + InvalidConstant(#[source] FunctionArgInvalidConstantError), +} + +impl From for FunctionParamError { + #[inline] + fn from(err: TypeMismatchError) -> Self { + Self::TypeMismatch(err) + } +} + +impl From for FunctionParamError { + #[inline] + fn from(err: FunctionArgKindMismatchError) -> Self { + Self::KindMismatch(err) + } +} + +impl From for FunctionParamError { + #[inline] + fn from(err: FunctionArgInvalidConstantError) -> Self { + Self::InvalidConstant(err) + } +} + +/// Function parameter +#[derive(Clone, Debug)] +pub enum FunctionParam<'a> { + /// Contant function parameter (literal value) + Constant(&'a RhsValue), + /// Variable function parameter (field, or complex expressions) + Variable(Type), +} -type FunctionPtr = for<'a> fn(FunctionArgs<'_, 'a>) -> LhsValue<'a>; +impl From<&FunctionParam<'_>> for FunctionArgKind { + fn from(arg: &FunctionParam<'_>) -> Self { + match arg { + FunctionParam::Constant(_) => FunctionArgKind::Literal, + FunctionParam::Variable(_) => FunctionArgKind::Field, + } + } +} -/// Wrapper around a function pointer providing the runtime implemetation. +impl<'a> GetType for FunctionParam<'a> { + fn get_type(&self) -> Type { + match self { + FunctionParam::Constant(value) => value.get_type(), + FunctionParam::Variable(ty) => *ty, + } + } +} + +impl<'a> FunctionParam<'a> { + /// Returns the underlying value if the current parameter is a constant, otherwise an error. + pub fn as_constant(&self) -> Result<&'a RhsValue, FunctionArgKindMismatchError> { + match self { + Self::Constant(value) => Ok(value), + Self::Variable(_) => Err(FunctionArgKindMismatchError { + expected: FunctionArgKind::Literal, + actual: FunctionArgKind::Field, + }), + } + } + + /// Returns the underlying type if the current parameter is a variable, otherwise an error. 
+ pub fn as_variable(&self) -> Result<&Type, FunctionArgKindMismatchError> { + match self { + Self::Variable(ref ty) => Ok(ty), + Self::Constant(_) => Err(FunctionArgKindMismatchError { + expected: FunctionArgKind::Field, + actual: FunctionArgKind::Literal, + }), + } + } + + /// Check if the arg_kind of current paramater matches the expected_arg_kind + pub fn expect_arg_kind( + &self, + expected_arg_kind: FunctionArgKind, + ) -> Result<(), FunctionParamError> { + let kind = self.into(); + if kind == expected_arg_kind { + Ok(()) + } else { + Err(FunctionParamError::KindMismatch( + FunctionArgKindMismatchError { + expected: expected_arg_kind, + actual: kind, + }, + )) + } + } + + /// Checks if the val_type of current parameter matches the expected_type + pub fn expect_val_type( + &self, + expected_types: impl Iterator, + ) -> Result<(), FunctionParamError> { + let ty = self.get_type(); + let mut types = ExpectedTypeList::default(); + for expected_type in expected_types { + match (&expected_type, &ty) { + (ExpectedType::Array, Type::Array(_)) => return Ok(()), + (ExpectedType::Array, _) => {} + (ExpectedType::Map, Type::Map(_)) => return Ok(()), + (ExpectedType::Map, _) => {} + (ExpectedType::Type(val_type), _) => { + if ty == *val_type { + return Ok(()); + } + } + } + types.insert(expected_type); + } + Err(FunctionParamError::TypeMismatch(TypeMismatchError { + expected: types, + actual: ty, + })) + } + + /// Checks that the parameter is a constant of a certain type + /// and call the closure `op` to verify its value + pub fn expect_const_value< + U: TryFrom<&'a RhsValue, Error = TypeMismatchError>, + F: FnOnce(U) -> Result<(), String>, + >( + &self, + op: F, + ) -> Result<(), FunctionParamError> { + match self { + Self::Constant(value) => { + op(U::try_from(value).map_err(FunctionParamError::TypeMismatch)?).map_err(|msg| { + FunctionParamError::InvalidConstant(FunctionArgInvalidConstantError { msg }) + }) + } + Self::Variable(_) => Err(FunctionParamError::KindMismatch( + FunctionArgKindMismatchError { + expected: FunctionArgKind::Literal, + actual: FunctionArgKind::Field, + }, + )), + } + } +} + +/// Context that can be created and used +/// when parsing a function call +pub struct FunctionDefinitionContext { + inner: Arc, + clone_cb: fn(&(dyn Any + Send + Sync)) -> Arc, + fmt_cb: fn(&(dyn Any + Send + Sync), &mut std::fmt::Formatter<'_>) -> std::fmt::Result, +} + +impl FunctionDefinitionContext { + fn clone_any( + t: &(dyn Any + Send + Sync), + ) -> Arc { + Arc::new(t.downcast_ref::().unwrap().clone()) + } + + fn fmt_any( + t: &(dyn Any + Send + Sync), + f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + t.downcast_ref::().unwrap().fmt(f) + } + + /// Creates a new FunctionDefinitionContext object containing user-defined + /// object of type `T` + pub fn new(t: T) -> Self { + Self { + inner: Arc::new(t), + clone_cb: Self::clone_any::, + fmt_cb: Self::fmt_any::, + } + } + /// Returns a reference to the underlying Any object + pub fn as_any_ref(&self) -> &(dyn Any + Send + Sync) { + &*self.inner + } + /// Returns a mutable reference to the underlying Any object + pub fn as_any_mut(&mut self) -> &mut (dyn Any + Send + Sync) { + Arc::get_mut(&mut self.inner).unwrap() + } + /// Converts current `FunctionDefinitionContext` to `Box` + pub fn into_any(self) -> Arc { + let Self { inner, .. } = self; + inner + } + /// Attempt to downcast the context to a concrete type. 
+ pub fn downcast(self) -> Result, Self> { + let Self { + inner, + clone_cb, + fmt_cb, + } = self; + inner.downcast::().map_err(|inner| Self { + inner, + clone_cb, + fmt_cb, + }) + } + + /// Attempt to extract the concrete value stored in the context. + pub fn try_unwrap(self) -> Result { + self.downcast::().map(|val| match Arc::try_unwrap(val) { + Ok(val) => val, + Err(_) => unreachable!(), + }) + } +} + +impl std::convert::AsRef for FunctionDefinitionContext { + fn as_ref(&self) -> &T { + self.inner.downcast_ref::().unwrap() + } +} + +impl std::convert::AsMut for FunctionDefinitionContext { + fn as_mut(&mut self) -> &mut T { + Arc::get_mut(&mut self.inner) + .unwrap() + .downcast_mut::() + .unwrap() + } +} + +impl Clone for FunctionDefinitionContext { + fn clone(&self) -> Self { + Self { + inner: (self.clone_cb)(&*self.inner), + clone_cb: self.clone_cb, + fmt_cb: self.fmt_cb, + } + } +} + +impl std::fmt::Debug for FunctionDefinitionContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "FunctionDefinitionContext(")?; + (self.fmt_cb)(&*self.inner, f)?; + write!(f, ")")?; + Ok(()) + } +} + +/// Trait to implement function +pub trait FunctionDefinition: Debug + Send + Sync { + /// Custom context to store information during parsing + fn context(&self) -> Option { + None + } + /// Given a slice of already checked parameters, checks that next_param is + /// correct. Return the expected the parameter definition. + fn check_param( + &self, + params: &mut dyn ExactSizeIterator>, + next_param: &FunctionParam<'_>, + ctx: Option<&mut FunctionDefinitionContext>, + ) -> Result<(), FunctionParamError>; + /// Function return type. + fn return_type( + &self, + params: &mut dyn ExactSizeIterator>, + ctx: Option<&FunctionDefinitionContext>, + ) -> Type; + /// Number of mandatory arguments and number of optional arguments + /// (N, Some(0)) means N mandatory arguments and no optional arguments + /// (N, None) means N mandatory arguments and unlimited optional arguments + fn arg_count(&self) -> (usize, Option); + /// Compile the function definition down to a closure that is going to be called + /// during filter execution. + fn compile<'s>( + &'s self, + params: &mut dyn ExactSizeIterator>, + ctx: Option, + ) -> Box Fn(FunctionArgs<'_, 'a>) -> Option> + Sync + Send + 's>; +} + +/// Simple function API + +type FunctionPtr = for<'a> fn(FunctionArgs<'_, 'a>) -> Option>; + +/// Wrapper around a function pointer providing the runtime implementation. #[derive(Clone)] -pub struct FunctionImpl(FunctionPtr); +pub struct SimpleFunctionImpl(FunctionPtr); -impl FunctionImpl { +impl SimpleFunctionImpl { /// Creates a new wrapper around a function pointer. pub fn new(func: FunctionPtr) -> Self { Self(func) } - - /// Calls the wrapped function pointer. - pub fn execute<'a>(&self, args: impl IntoIterator>) -> LhsValue<'a> { - (self.0)(&mut args.into_iter()) - } } -impl fmt::Debug for FunctionImpl { +impl fmt::Debug for SimpleFunctionImpl { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_tuple("FunctionImpl") + fmt.debug_tuple("SimpleFunctionImpl") .field(&(self.0 as *const ())) .finish() } } -impl PartialEq for FunctionImpl { - fn eq(&self, other: &FunctionImpl) -> bool { - self.0 as *const () == other.0 as *const () +impl PartialEq for SimpleFunctionImpl { + fn eq(&self, other: &SimpleFunctionImpl) -> bool { + std::ptr::eq(self.0 as *const (), other.0 as *const ()) } } -impl Eq for FunctionImpl {} - -/// Defines what kind of argument a function expects. 
-#[derive(Debug, PartialEq, Eq, Clone)] -pub enum FunctionArgKind { - /// Allow only literal as argument. - Literal, - /// Allow only field as argument. - Field, -} +impl Eq for SimpleFunctionImpl {} /// Defines a mandatory function argument. #[derive(Debug, PartialEq, Eq, Clone)] -pub struct FunctionParam { +pub struct SimpleFunctionParam { /// How the argument can be specified when calling a function. pub arg_kind: FunctionArgKind, /// The type of its associated value. @@ -58,22 +447,115 @@ pub struct FunctionParam { /// Defines an optional function argument. #[derive(Debug, PartialEq, Eq, Clone)] -pub struct FunctionOptParam { +pub struct SimpleFunctionOptParam { /// How the argument can be specified when calling a function. pub arg_kind: FunctionArgKind, /// The default value if the argument is missing. pub default_value: LhsValue<'static>, } -/// Defines a function. +/// Simple interface to define a function. #[derive(Debug, PartialEq, Eq, Clone)] -pub struct Function { +pub struct SimpleFunctionDefinition { /// List of mandatory arguments. - pub params: Vec, + pub params: Vec, /// List of optional arguments that can be specified after manatory ones. - pub opt_params: Vec, + pub opt_params: Vec, /// Function return type. pub return_type: Type, /// Actual implementation that will be called at runtime. - pub implementation: FunctionImpl, + pub implementation: SimpleFunctionImpl, +} + +impl FunctionDefinition for SimpleFunctionDefinition { + fn check_param( + &self, + params: &mut dyn ExactSizeIterator>, + next_param: &FunctionParam<'_>, + _: Option<&mut FunctionDefinitionContext>, + ) -> Result<(), FunctionParamError> { + let index = params.len(); + if index < self.params.len() { + let param = &self.params[index]; + next_param.expect_arg_kind(param.arg_kind)?; + next_param.expect_val_type(once(ExpectedType::Type(param.val_type)))?; + } else if index < self.params.len() + self.opt_params.len() { + let opt_param = &self.opt_params[index - self.params.len()]; + next_param.expect_arg_kind(opt_param.arg_kind)?; + next_param + .expect_val_type(once(ExpectedType::Type(opt_param.default_value.get_type())))?; + } else { + unreachable!(); + } + Ok(()) + } + + fn return_type( + &self, + _: &mut dyn ExactSizeIterator>, + _: Option<&FunctionDefinitionContext>, + ) -> Type { + self.return_type + } + + fn arg_count(&self) -> (usize, Option) { + (self.params.len(), Some(self.opt_params.len())) + } + + fn compile<'s>( + &'s self, + params: &mut dyn ExactSizeIterator>, + _: Option, + ) -> Box Fn(FunctionArgs<'_, 'a>) -> Option> + Sync + Send + 's> { + let params_count = params.len(); + let opt_params = &self.opt_params[(params_count - self.params.len())..]; + if opt_params.is_empty() { + Box::new(move |args| { + assert_eq!(params_count, args.len()); + (self.implementation.0)(args) + }) + } else { + let opt_args: Vec, Type>> = opt_params + .iter() + .map(|opt_param| Ok(opt_param.default_value.clone())) + .collect(); + Box::new(move |args| { + assert_eq!(params_count, args.len()); + (self.implementation.0)(&mut ExactSizeChain::new(args, opt_args.iter().cloned())) + }) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_function_definition_context() { + let ctx1 = FunctionDefinitionContext::new(Some(42u8)); + + assert_eq!( + "FunctionDefinitionContext(Some(42))".to_owned(), + format!("{ctx1:?}") + ); + + assert_eq!( + ctx1.as_any_ref().downcast_ref::>().unwrap(), + &Some(42u8) + ); + + let ctx2 = ctx1.clone(); + + let value = ctx1.downcast::>().unwrap(); + + assert_eq!(value, 
Arc::new(Some(42u8))); + + assert_eq!( + ctx2.as_any_ref().downcast_ref::>().unwrap(), + &*value + ); + + assert_eq!(ctx2.try_unwrap::>().unwrap(), Some(42u8)); + } } diff --git a/engine/src/heap_searcher.rs b/engine/src/heap_searcher.rs deleted file mode 100644 index 763bea34..00000000 --- a/engine/src/heap_searcher.rs +++ /dev/null @@ -1,43 +0,0 @@ -use memmem::{Searcher, TwoWaySearcher}; -use std::marker::PhantomPinned; -use std::pin::Pin; -use std::ptr::NonNull; - -/// A version of [`TwoWaySearcher`] that owns the needle data. -pub struct HeapSearcher { - bytes: Box<[u8]>, - inner: Option>, - _pin: PhantomPinned, -} - -impl HeapSearcher { - pub fn new(bytes: impl Into>) -> Pin> { - // NOTE: first put bytes into the structure and pin them all together - // on the heap. - let mut heap_searcher = Box::pin(HeapSearcher { - bytes: bytes.into(), - inner: None, - _pin: PhantomPinned, - }); - - // NOTE: obtain a pointer for pinned bytes and create a pin - // for the mutable reference of the searcher. This can be later - // used in `Pin::get_unchecked_mut` which consumes the pin. - let bytes = NonNull::from(&heap_searcher.bytes); - let mut_pin = heap_searcher.as_mut(); - - unsafe { - let inner = TwoWaySearcher::new(&*bytes.as_ptr()); - - Pin::get_unchecked_mut(mut_pin).inner = Some(inner); - } - - heap_searcher - } -} - -impl Searcher for HeapSearcher { - fn search_in(&self, haystack: &[u8]) -> Option { - self.inner.as_ref().unwrap().search_in(haystack) - } -} diff --git a/engine/src/lex.rs b/engine/src/lex.rs index 2f4e7b54..ec5ab33a 100644 --- a/engine/src/lex.rs +++ b/engine/src/lex.rs @@ -1,73 +1,157 @@ use crate::{ - rhs_types::RegexError, - scheme::{UnknownFieldError, UnknownFunctionError}, + functions::{FunctionArgInvalidConstantError, FunctionArgKindMismatchError}, + rhs_types::{RegexError, WildcardError}, + scheme::{IndexAccessError, UnknownFieldError, UnknownFunctionError}, types::{Type, TypeMismatchError}, }; -use cidr::NetworkParseError; -use failure::Fail; +use cidr::errors::NetworkParseError; use std::num::ParseIntError; +use thiserror::Error; -#[derive(Debug, PartialEq, Fail)] +#[derive(Debug, PartialEq, Error)] +/// LexErrorKind occurs when there is an invalid or unexpected token. 
pub enum LexErrorKind { - #[fail(display = "expected {}", _0)] + /// Expected the next token to be a Field + #[error("expected {0}")] ExpectedName(&'static str), - #[fail(display = "expected literal {:?}", _0)] + /// Expected the next token to be a Literal + #[error("expected literal {0:?}")] ExpectedLiteral(&'static str), - #[fail(display = "{} while parsing with radix {}", err, radix)] + /// Expected the next token to be an int + #[error("{err} while parsing with radix {radix}")] ParseInt { - #[cause] + /// The error that occurred parsing the token as an int + #[source] err: ParseIntError, + /// The base of the number radix: u32, }, - #[fail(display = "{}", _0)] - ParseNetwork(#[cause] NetworkParseError), + /// Expected the next token to be a network address such a CIDR, IPv4 or + /// IPv6 address + #[error("{0}")] + ParseNetwork(#[source] NetworkParseError), - #[fail(display = "{}", _0)] - ParseRegex(#[cause] RegexError), + /// Expected the next token to be a regular expression + #[error("{0}")] + ParseRegex(#[source] RegexError), - #[fail(display = "expected \", xHH or OOO after \\")] + /// Expected the next token to be a wildcard expression + #[error("{0}")] + ParseWildcard(#[source] WildcardError), + + /// Expected the next token to be an escape character + #[error("expected \", xHH or OOO after \\")] InvalidCharacterEscape, - #[fail(display = "could not find an ending quote")] + /// Invalid raw string hash count + #[error("invalid raw string hash count, there can't be more than 255 #s")] + InvalidRawStringHashCount, + + /// Expected the next token to be an ending quote + #[error("could not find an ending quote")] MissingEndingQuote, - #[fail(display = "expected {} {}s, but found {}", expected, name, actual)] + /// Expected to take some number of characters from the input but the + /// input was too short + #[error("expected {expected} {name}s, but found {actual}")] CountMismatch { + /// This is set to "character" for all occurences of this error name: &'static str, + /// The actual number of characters actual: usize, + /// The expected number of characters expected: usize, }, - #[fail(display = "{}", _0)] - UnknownField(#[cause] UnknownFieldError), + /// The next token refers to a Field that is not present in the Scheme + #[error("{0}")] + UnknownField(#[source] UnknownFieldError), - #[fail(display = "{}", _0)] - UnknownFunction(#[cause] UnknownFunctionError), + /// The next token refers to a Function that is not present in the Scheme + #[error("{0}")] + UnknownFunction(#[source] UnknownFunctionError), - #[fail(display = "cannot use this operation type {:?}", lhs_type)] - UnsupportedOp { lhs_type: Type }, + /// The next token refers to an Identifier that is not present in the Scheme + /// ie: neither as a Field or as a Function + #[error("unknown identifier")] + UnknownIdentifier, - #[fail(display = "incompatible range bounds")] + /// The operation cannot be performed on this Field + #[error("cannot perform this operation on type {lhs_type:?}")] + UnsupportedOp { + /// The type of the Field + lhs_type: Type, + }, + + /// This variant is not in use + #[error("incompatible range bounds")] IncompatibleRangeBounds, - #[fail(display = "unrecognised input")] + /// End Of File + #[error("unrecognised input")] EOF, - #[fail(display = "invalid number of arguments")] + /// Invalid number of arguments for the function + #[error("invalid number of arguments")] InvalidArgumentsCount { + /// The minimum number of arguments for the function expected_min: usize, - expected_max: usize, + /// The 
maximum number of arguments for the function or None if the + /// function takes an unlimited number of arguments + expected_max: Option, }, - #[fail(display = "invalid type of argument #{}: {}", index, mismatch)] + /// Invalid argument kind for the function + #[error("invalid kind of argument #{index}: {mismatch}")] + InvalidArgumentKind { + /// The position of the argument in the function call + index: usize, + /// The expected and the actual kind for the argument + #[source] + mismatch: FunctionArgKindMismatchError, + }, + + /// Invalid argument type for the function + #[error("invalid type of argument #{index}: {mismatch}")] InvalidArgumentType { + /// The position of the argument in the function call index: usize, - #[cause] + /// The expected and actual type for the argument + #[source] mismatch: TypeMismatchError, }, + + /// Invalid argument value for the function + #[error("invalid value of argument #{index}: {invalid}")] + InvalidArgumentValue { + /// The position of the argument in the function call + index: usize, + /// The error message that explains why the value is invalid + #[source] + invalid: FunctionArgInvalidConstantError, + }, + + /// The index is invalid + #[error("{0}")] + InvalidIndexAccess(#[source] IndexAccessError), + + /// Invalid type + #[error("{0}")] + TypeMismatch(#[source] TypeMismatchError), + + /// Invalid usage of map each access operator + #[error("invalid use of map each access operator")] + InvalidMapEachAccess, + + /// Invalid list name + #[error("invalid list name {name:?}")] + InvalidListName { + /// Name of the list + name: String, + }, } pub type LexError<'i> = (LexErrorKind, &'i str); @@ -89,8 +173,8 @@ impl<'i, T: Lex<'i>, E> LexWith<'i, E> for T { } pub fn expect<'i>(input: &'i str, s: &'static str) -> Result<&'i str, LexError<'i>> { - if input.starts_with(s) { - Ok(&input[s.len()..]) + if let Some(index) = input.strip_prefix(s) { + Ok(index) } else { Err((LexErrorKind::ExpectedLiteral(s), input)) } @@ -120,11 +204,12 @@ macro_rules! lex_enum { // On the parser side, tries to parse `SomeType` and wraps into the variant // on success. (@decl $preamble:tt $name:ident $input:ident { $($decl:tt)* } { $($expr:tt)* } { - $ty:ty => $item:ident, + $(#[$meta:meta])* $ty:ty => $item:ident, $($rest:tt)* }) => { lex_enum!(@decl $preamble $name $input { $($decl)* + $(#[$meta])* $item($ty), } { $($expr)* @@ -142,11 +227,12 @@ macro_rules! lex_enum { // On the parser side, tries to parse either of the given string values, // and returns the variant if any of them succeeded. (@decl $preamble:tt $name:ident $input:ident { $($decl:tt)* } { $($expr:tt)* } { - $($s:tt)|+ => $item:ident $(= $value:expr)*, + $(#[$meta:meta])* $($s:literal)|+ => $item:ident $(= $value:expr)*, $($rest:tt)* }) => { lex_enum!(@decl $preamble $name $input { $($decl)* + $(#[$meta])* $item $(= $value)*, } { $($expr)* @@ -161,12 +247,12 @@ macro_rules! lex_enum { // This is invoked when no more variants are left to process. // At this point declaration and lexer body are considered complete. 
(@decl { $($preamble:tt)* } $name:ident $input:ident $decl:tt { $($expr:stmt)* } {}) => { - #[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize)] + #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize)] $($preamble)* pub enum $name $decl impl<'i> $crate::lex::Lex<'i> for $name { - fn lex($input: &'i str) -> $crate::lex::LexResult<'_, Self> { + fn lex($input: &'i str) -> $crate::lex::LexResult<'i, Self> { $($expr)* Err(( $crate::lex::LexErrorKind::ExpectedName(stringify!($name)), @@ -177,9 +263,9 @@ macro_rules! lex_enum { }; // The public entry point to the macro. - ($(# $attrs:tt)* $name:ident $items:tt) => { + ($(#[$meta:meta])* $name:ident $items:tt) => { lex_enum!(@decl { - $(# $attrs)* + $(#[$meta])* } $name input {} {} $items); }; } @@ -212,7 +298,7 @@ pub fn take_while<'i, F: Fn(char) -> bool>( pub fn take(input: &str, expected: usize) -> LexResult<'_, &str> { let mut chars = input.chars(); for i in 0..expected { - chars.next().ok_or_else(|| { + chars.next().ok_or({ ( LexErrorKind::CountMismatch { name: "character", @@ -258,10 +344,9 @@ macro_rules! assert_err { #[cfg(test)] macro_rules! assert_json { - ($expr:expr, $json:tt) => { - assert_eq!( - ::serde_json::to_value(&$expr).unwrap(), - ::serde_json::json!($json) - ); - }; + ($expr:expr, $json:tt) => {{ + let json = ::serde_json::to_value(&$expr).unwrap(); + assert_eq!(json, ::serde_json::json!($json)); + json + }}; } diff --git a/engine/src/lhs_types/array.rs b/engine/src/lhs_types/array.rs new file mode 100644 index 00000000..c0a0cad6 --- /dev/null +++ b/engine/src/lhs_types/array.rs @@ -0,0 +1,622 @@ +use crate::{ + lhs_types::AsRefIterator, + types::{ + CompoundType, GetType, IntoValue, LhsValue, LhsValueMut, LhsValueSeed, Type, + TypeMismatchError, + }, +}; +use serde::{ + de::{self, DeserializeSeed, Deserializer, SeqAccess, Visitor}, + ser::SerializeSeq, + Serialize, Serializer, +}; +use std::{ + fmt, + hash::{Hash, Hasher}, + hint::unreachable_unchecked, + ops::Deref, +}; + +// Ideally, we would want to use Cow<'a, LhsValue<'a>> here +// but it doesnt work for unknown reasons +// See https://github.com/rust-lang/rust/issues/23707#issuecomment-557312736 +#[derive(Debug, Clone)] +enum InnerArray<'a> { + Owned(Vec>), + Borrowed(&'a [LhsValue<'a>]), +} + +impl<'a> InnerArray<'a> { + #[inline] + fn as_vec(&mut self) -> &mut Vec> { + match self { + InnerArray::Owned(vec) => vec, + InnerArray::Borrowed(slice) => { + *self = InnerArray::Owned(slice.to_vec()); + match self { + InnerArray::Owned(vec) => vec, + _ => unsafe { unreachable_unchecked() }, + } + } + } + } + + #[inline] + fn get_mut(&mut self, idx: usize) -> Option<&mut LhsValue<'a>> { + self.as_vec().get_mut(idx) + } + + #[inline] + fn insert(&mut self, idx: usize, value: LhsValue<'a>) { + self.as_vec().insert(idx, value) + } + + #[inline] + fn push(&mut self, value: LhsValue<'a>) { + self.as_vec().push(value) + } + + #[inline] + fn truncate(&mut self, len: usize) { + match self { + InnerArray::Owned(vec) => vec.truncate(len), + InnerArray::Borrowed(slice) => { + *slice = &slice[..len]; + } + } + } +} + +impl<'a> Deref for InnerArray<'a> { + type Target = [LhsValue<'a>]; + + #[inline] + fn deref(&self) -> &Self::Target { + match self { + InnerArray::Owned(vec) => &vec[..], + InnerArray::Borrowed(slice) => slice, + } + } +} + +/// An array of [`Type`]. 
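+///
+/// A minimal usage sketch (illustrative only; the element type chosen here is
+/// arbitrary):
+///
+/// ```ignore
+/// let mut arr = Array::new(Type::Bool);
+/// // Pushing a value of the declared element type succeeds...
+/// arr.push(LhsValue::Bool(true)).unwrap();
+/// assert_eq!(arr.len(), 1);
+/// // ...while a mismatched type would return a TypeMismatchError.
+/// ```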
+#[derive(Debug, Clone)] +pub struct Array<'a> { + val_type: CompoundType, + data: InnerArray<'a>, +} + +impl<'a> Array<'a> { + /// Creates a new array + pub fn new(val_type: impl Into) -> Self { + Self { + val_type: val_type.into(), + data: InnerArray::Owned(Vec::new()), + } + } + + /// Creates a new array with the specified capacity + pub fn with_capacity(val_type: impl Into, capacity: usize) -> Self { + Self { + val_type: val_type.into(), + data: InnerArray::Owned(Vec::with_capacity(capacity)), + } + } + + /// Get a reference to an element if it exists + pub fn get(&self, idx: usize) -> Option<&LhsValue<'a>> { + self.data.get(idx) + } + + /// Get a mutable reference to an element if it exists + pub fn get_mut(&mut self, idx: usize) -> Option> { + self.data.get_mut(idx).map(LhsValueMut::from) + } + + /// Inserts an element at index `idx` + pub fn insert( + &mut self, + idx: usize, + value: impl Into>, + ) -> Result<(), TypeMismatchError> { + let value = value.into(); + let value_type = value.get_type(); + if value_type != self.val_type.into() { + return Err(TypeMismatchError { + expected: Type::from(self.val_type).into(), + actual: value_type, + }); + } + self.data.insert(idx, value); + Ok(()) + } + + /// Push an element to the back of the array + pub fn push(&mut self, value: impl Into>) -> Result<(), TypeMismatchError> { + let value = value.into(); + let value_type = value.get_type(); + if value_type != self.val_type.into() { + return Err(TypeMismatchError { + expected: Type::from(self.val_type).into(), + actual: value_type, + }); + } + self.data.push(value); + Ok(()) + } + + pub(crate) fn as_ref(&'a self) -> Array<'a> { + Array { + val_type: self.val_type, + data: match self.data { + InnerArray::Owned(ref vec) => InnerArray::Borrowed(&vec[..]), + InnerArray::Borrowed(slice) => InnerArray::Borrowed(slice), + }, + } + } + + /// Converts an `Array` with borrowed data to a fully owned `Array`. + pub fn into_owned(self) -> Array<'static> { + Array { + val_type: self.val_type, + data: match self.data { + InnerArray::Owned(vec) => { + InnerArray::Owned(vec.into_iter().map(LhsValue::into_owned).collect()) + } + InnerArray::Borrowed(slice) => { + InnerArray::Owned(slice.iter().cloned().map(LhsValue::into_owned).collect()) + } + }, + } + } + + /// Returns the type of the contained values. + #[inline] + pub fn value_type(&self) -> Type { + self.val_type.into() + } + + /// Returns the number of elements in the array + #[inline] + pub fn len(&self) -> usize { + self.data.len() + } + + /// Returns true if the array contains no elements. + #[inline] + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + pub(crate) fn extract(self, idx: usize) -> Option> { + let Self { data, .. } = self; + if idx >= data.len() { + None + } else { + match data { + InnerArray::Owned(mut vec) => Some(vec.swap_remove(idx)), + InnerArray::Borrowed(slice) => Some(unsafe { slice.get_unchecked(idx) }.as_ref()), + } + } + } + + pub(crate) fn as_slice(&self) -> &[LhsValue<'a>] { + &self.data + } + + pub(crate) fn filter_map_to(self, value_type: impl Into, func: F) -> Self + where + F: Fn(LhsValue<'a>) -> Option>, + { + let Self { mut data, .. 
} = self; + let mut vec = std::mem::take(data.as_vec()); + let val_type = value_type.into(); + let mut write = 0; + for read in 0..vec.len() { + let elem = &mut vec[read]; + if let Some(elem) = func(std::mem::replace(elem, LhsValue::Bool(false))) { + assert!(elem.get_type() == val_type.into()); + vec[write] = elem; + write += 1; + } + } + vec.truncate(write); + Array { + val_type, + data: InnerArray::Owned(vec), + } + } + + /// Creates a new array from the specified iterator. + pub fn try_from_iter>>( + val_type: impl Into, + iter: impl IntoIterator, + ) -> Result { + let val_type = val_type.into(); + iter.into_iter() + .map(|elem| { + let elem = elem.into(); + let elem_type = elem.get_type(); + if val_type != elem_type.into() { + Err(TypeMismatchError { + expected: Type::from(val_type).into(), + actual: elem_type, + }) + } else { + Ok(elem) + } + }) + .collect::, _>>() + .map(|vec| Array { + val_type, + data: InnerArray::Owned(vec), + }) + } + + /// Creates a new array form the specified vector. + pub fn try_from_vec( + val_type: impl Into, + vec: Vec>, + ) -> Result { + let val_type = val_type.into(); + for elem in &vec { + let elem_type = elem.get_type(); + if val_type != elem_type.into() { + return Err(TypeMismatchError { + expected: Type::from(val_type).into(), + actual: elem_type, + }); + } + } + Ok(Array { + val_type, + data: InnerArray::Owned(vec), + }) + } + + /// Try extending the array with elements provided by the specified iterator. + pub fn try_extend(&mut self, iter: I) -> Result<(), TypeMismatchError> + where + V: Into>, + I: IntoIterator, + { + let value_type = self.value_type(); + let vec = self.data.as_vec(); + for elem in iter { + let elem = elem.into(); + let elem_type = elem.get_type(); + if value_type != elem_type { + return Err(TypeMismatchError { + expected: value_type.into(), + actual: elem_type, + }); + }; + vec.push(elem); + } + Ok(()) + } +} + +impl<'a> PartialEq for Array<'a> { + #[inline] + fn eq(&self, other: &Array<'a>) -> bool { + self.val_type == other.val_type && self.data.deref() == other.data.deref() + } +} + +impl<'a> Eq for Array<'a> {} + +impl<'a> GetType for Array<'a> { + fn get_type(&self) -> Type { + Type::Array(self.val_type) + } +} + +impl<'a> Hash for Array<'a> { + fn hash(&self, state: &mut H) { + self.get_type().hash(state); + self.data.deref().hash(state); + } +} + +impl<'a, V: IntoValue<'a>> FromIterator for Array<'a> { + fn from_iter(iter: T) -> Self + where + T: IntoIterator, + { + let vec = iter.into_iter().map(IntoValue::into_value).collect(); + Self { + val_type: V::TYPE.into(), + data: InnerArray::Owned(vec), + } + } +} + +pub enum ArrayIterator<'a> { + Owned(std::vec::IntoIter>), + Borrowed(AsRefIterator<'a, std::slice::Iter<'a, LhsValue<'a>>>), +} + +impl<'a> Iterator for ArrayIterator<'a> { + type Item = LhsValue<'a>; + + fn next(&mut self) -> Option { + match self { + ArrayIterator::Owned(vec_iter) => vec_iter.next(), + ArrayIterator::Borrowed(slice_iter) => slice_iter.next(), + } + } + + fn size_hint(&self) -> (usize, Option) { + (self.len(), Some(self.len())) + } +} + +impl<'a> ExactSizeIterator for ArrayIterator<'a> { + fn len(&self) -> usize { + match self { + ArrayIterator::Owned(vec_iter) => vec_iter.len(), + ArrayIterator::Borrowed(slice_iter) => slice_iter.len(), + } + } +} + +impl<'a> IntoIterator for Array<'a> { + type Item = LhsValue<'a>; + type IntoIter = ArrayIterator<'a>; + fn into_iter(self) -> Self::IntoIter { + match self.data { + InnerArray::Owned(vec) => ArrayIterator::Owned(vec.into_iter()), + 
InnerArray::Borrowed(slice) => ArrayIterator::Borrowed(AsRefIterator(slice.iter())), + } + } +} + +impl<'a, 'b> IntoIterator for &'b Array<'a> { + type Item = &'b LhsValue<'a>; + type IntoIter = std::slice::Iter<'b, LhsValue<'a>>; + fn into_iter(self) -> Self::IntoIter { + self.data.iter() + } +} + +impl<'a> Serialize for Array<'a> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut seq = serializer.serialize_seq(Some(self.len()))?; + for element in self.data.iter() { + seq.serialize_element(element)?; + } + seq.end() + } +} + +impl<'de, 'a> DeserializeSeed<'de> for &'a mut Array<'de> { + type Value = (); + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct ArrayVisitor<'de, 'a>(&'a mut Array<'de>); + + impl<'de, 'a> Visitor<'de> for ArrayVisitor<'de, 'a> { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(formatter, "an array of lhs value") + } + + fn visit_seq(self, mut seq: A) -> Result<(), A::Error> + where + A: SeqAccess<'de>, + { + let value_type = self.0.value_type(); + let vec = self.0.data.as_vec(); + while let Some(elem) = seq.next_element_seed(LhsValueSeed(&value_type))? { + let elem_type = elem.get_type(); + if value_type != elem_type { + return Err(de::Error::custom(format!( + "invalid type: {elem_type:?}, expected {value_type:?}" + ))); + } + vec.push(elem); + } + Ok(()) + } + } + + deserializer.deserialize_seq(ArrayVisitor(self)) + } +} + +/// Wrapper type around mutable `Array` to prevent +/// illegal operations like changing the type of +/// its values. +pub struct ArrayMut<'a, 'b>(&'a mut Array<'b>); + +impl<'a, 'b> ArrayMut<'a, 'b> { + /// Push an element to the back of the array + #[inline] + pub fn push(&mut self, value: impl Into>) -> Result<(), TypeMismatchError> { + self.0.push(value) + } + + /// Inserts an element at index `idx` + #[inline] + pub fn insert( + &mut self, + idx: usize, + value: impl Into>, + ) -> Result<(), TypeMismatchError> { + self.0.insert(idx, value) + } + + /// Get a mutable reference to an element if it exists + #[inline] + pub fn get_mut(&'a mut self, idx: usize) -> Option> { + self.0.get_mut(idx).map(LhsValueMut::from) + } +} + +impl<'a, 'b> Deref for ArrayMut<'a, 'b> { + type Target = Array<'b>; + + #[inline] + fn deref(&self) -> &Self::Target { + self.0 + } +} + +impl<'a, 'b> From<&'a mut Array<'b>> for ArrayMut<'a, 'b> { + #[inline] + fn from(arr: &'a mut Array<'b>) -> Self { + Self(arr) + } +} + +/// Typed wrapper over an `Array` which provides +/// infaillible operations. +#[derive(Debug)] +pub struct TypedArray<'a, V> +where + V: IntoValue<'a>, +{ + array: Array<'a>, + _marker: std::marker::PhantomData<[V]>, +} + +impl<'a, V: IntoValue<'a>> TypedArray<'a, V> { + /// Push an element to the back of the array + #[inline] + pub fn push(&mut self, value: V) { + self.array.data.push(value.into_value()) + } + + /// Returns the number of elements in the array + #[inline] + pub fn len(&self) -> usize { + self.array.len() + } + + /// Returns true if the array contains no elements. + #[inline] + pub fn is_empty(&self) -> bool { + self.array.is_empty() + } + + /// Shortens the array, keeping the first `len` elements and dropping the rest. 
+ #[inline] + pub fn truncate(&mut self, len: usize) { + self.array.data.truncate(len); + } +} + +impl TypedArray<'static, bool> { + #[inline] + pub(crate) fn iter(&self) -> impl ExactSizeIterator + '_ { + self.array.data.iter().map(|value| match value { + LhsValue::Bool(b) => b, + _ => unsafe { unreachable_unchecked() }, + }) + } + + #[inline] + pub(crate) fn iter_mut(&mut self) -> impl ExactSizeIterator + '_ { + self.array + .data + .as_vec() + .iter_mut() + .map(|value| match value { + LhsValue::Bool(b) => b, + _ => unsafe { unreachable_unchecked() }, + }) + } +} + +impl> PartialEq for TypedArray<'static, bool> { + fn eq(&self, other: &T) -> bool { + self.iter().eq(other.as_ref()) + } +} + +impl<'a, V: IntoValue<'a>> From> for Array<'a> { + #[inline] + fn from(value: TypedArray<'a, V>) -> Self { + value.array + } +} + +impl<'a, V: IntoValue<'a>> Default for TypedArray<'a, V> { + #[inline] + fn default() -> Self { + Self { + array: Array::new(V::TYPE), + _marker: std::marker::PhantomData, + } + } +} + +impl<'a, V: IntoValue<'a>> Extend for TypedArray<'a, V> { + #[inline] + fn extend>(&mut self, iter: T) { + self.array + .data + .as_vec() + .extend(iter.into_iter().map(IntoValue::into_value)) + } +} + +impl<'a, V: IntoValue<'a>> FromIterator for TypedArray<'a, V> { + fn from_iter(iter: T) -> Self + where + T: IntoIterator, + { + Self { + array: Array::from_iter(iter), + _marker: std::marker::PhantomData, + } + } +} + +const fn compound_from_type(ty: Type) -> CompoundType { + match CompoundType::from_type(ty) { + Some(ty) => ty, + None => panic!("Could not convert type to compound type"), + } +} + +impl<'a, V: IntoValue<'a>> IntoValue<'a> for TypedArray<'a, V> { + const TYPE: Type = Type::Array(compound_from_type(V::TYPE)); + + #[inline] + fn into_value(self) -> LhsValue<'a> { + LhsValue::Array(self.array) + } +} + +#[test] +fn test_size_of_array() { + assert_eq!(std::mem::size_of::>(), 32); +} + +#[test] +fn test_borrowed_eq_owned() { + let mut owned = Array::new(Type::Bytes); + + owned + .push(LhsValue::Bytes("borrowed".as_bytes().into())) + .unwrap(); + + let borrowed = owned.as_ref(); + + assert!(matches!(owned.data, InnerArray::Owned(_))); + + assert!(matches!(borrowed.data, InnerArray::Borrowed(_))); + + assert_eq!(owned, borrowed); + + assert_eq!(borrowed, borrowed.to_owned()); +} diff --git a/engine/src/lhs_types/map.rs b/engine/src/lhs_types/map.rs new file mode 100644 index 00000000..c0e8ee5b --- /dev/null +++ b/engine/src/lhs_types/map.rs @@ -0,0 +1,496 @@ +use crate::{ + lhs_types::AsRefIterator, + types::{ + BytesOrString, CompoundType, GetType, LhsValue, LhsValueMut, LhsValueSeed, Type, + TypeMismatchError, + }, +}; +use serde::{ + de::{self, DeserializeSeed, Deserializer, MapAccess, SeqAccess, Visitor}, + ser::{SerializeMap, SerializeSeq}, + Serialize, Serializer, +}; +use std::{ + borrow::Cow, + collections::BTreeMap, + fmt, + hash::{Hash, Hasher}, + hint::unreachable_unchecked, + ops::Deref, +}; + +#[derive(Debug, Clone)] +enum InnerMap<'a> { + Owned(BTreeMap, LhsValue<'a>>), + Borrowed(&'a BTreeMap, LhsValue<'a>>), +} + +impl<'a> InnerMap<'a> { + #[inline] + fn as_map(&mut self) -> &mut BTreeMap, LhsValue<'a>> { + match self { + InnerMap::Owned(map) => map, + InnerMap::Borrowed(map) => { + *self = InnerMap::Owned(map.clone()); + match self { + InnerMap::Owned(map) => map, + _ => unsafe { unreachable_unchecked() }, + } + } + } + } + + #[inline] + fn get_mut(&mut self, key: &[u8]) -> Option<&mut LhsValue<'a>> { + self.as_map().get_mut(key) + } + + #[inline] + fn insert(&mut 
self, key: &[u8], value: LhsValue<'a>) { + self.as_map().insert(key.to_vec().into_boxed_slice(), value); + } + + #[inline] + fn get_or_insert(&mut self, key: Box<[u8]>, value: LhsValue<'a>) -> &mut LhsValue<'a> { + self.as_map().entry(key).or_insert(value) + } +} + +impl<'a> Deref for InnerMap<'a> { + type Target = BTreeMap, LhsValue<'a>>; + + #[inline] + fn deref(&self) -> &Self::Target { + match self { + InnerMap::Owned(map) => map, + InnerMap::Borrowed(ref_map) => ref_map, + } + } +} + +/// A map of string to [`Type`]. +#[derive(Debug, Clone)] +pub struct Map<'a> { + val_type: CompoundType, + data: InnerMap<'a>, +} + +impl<'a> Map<'a> { + /// Creates a new map + pub fn new(val_type: impl Into) -> Self { + Self { + val_type: val_type.into(), + data: InnerMap::Owned(BTreeMap::new()), + } + } + + /// Get a reference to an element if it exists + pub fn get(&self, key: &[u8]) -> Option<&LhsValue<'a>> { + self.data.get(key) + } + + /// Get a mutable reference to an element if it exists + pub fn get_mut(&mut self, key: &[u8]) -> Option> { + self.data.get_mut(key).map(LhsValueMut::from) + } + + /// Inserts an element, overwriting if one already exists + pub fn insert( + &mut self, + key: &[u8], + value: impl Into>, + ) -> Result<(), TypeMismatchError> { + let value = value.into(); + let value_type = value.get_type(); + if value_type != self.val_type.into() { + return Err(TypeMismatchError { + expected: Type::from(self.val_type).into(), + actual: value_type, + }); + } + self.data.insert(key, value); + Ok(()) + } + + /// Inserts `value` if `key` is missing, then returns a mutable reference to the contained value. + pub fn get_or_insert( + &mut self, + key: Box<[u8]>, + value: impl Into>, + ) -> Result, TypeMismatchError> { + let value = value.into(); + let value_type = value.get_type(); + if value_type != self.val_type.into() { + return Err(TypeMismatchError { + expected: Type::from(self.val_type).into(), + actual: value_type, + }); + } + Ok(LhsValueMut::from(self.data.get_or_insert(key, value))) + } + + pub(crate) fn as_ref(&'a self) -> Map<'a> { + Map { + val_type: self.val_type, + data: match self.data { + InnerMap::Owned(ref map) => InnerMap::Borrowed(map), + InnerMap::Borrowed(ref_map) => InnerMap::Borrowed(ref_map), + }, + } + } + + /// Converts a `Map` with borrowed data to a fully owned `Map`. + pub fn into_owned(self) -> Map<'static> { + Map { + val_type: self.val_type, + data: match self.data { + InnerMap::Owned(map) => InnerMap::Owned( + map.into_iter() + .map(|(key, val)| (key, val.into_owned())) + .collect(), + ), + InnerMap::Borrowed(map) => InnerMap::Owned( + map.iter() + .map(|(key, value)| (key.clone(), value.clone().into_owned())) + .collect(), + ), + }, + } + } + + /// Returns the type of the contained values. + pub fn value_type(&self) -> Type { + self.val_type.into() + } + + /// Returns the number of elements in the map + pub fn len(&self) -> usize { + self.data.len() + } + + /// Returns true if the map contains no elements. + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Convert current map into an iterator over contained values + pub fn values_into_iter(self) -> MapValuesIntoIter<'a> { + let Map { data, .. } = self; + match data { + InnerMap::Owned(map) => MapValuesIntoIter::Owned(map.into_iter()), + InnerMap::Borrowed(map) => { + MapValuesIntoIter::Borrowed(AsRefIterator::new(map.values())) + } + } + } + + pub(crate) fn extract(self, key: &[u8]) -> Option> { + let Self { data, .. 
} = self; + match data { + InnerMap::Owned(mut map) => map.remove(key), + InnerMap::Borrowed(map) => map.get(key).map(LhsValue::as_ref), + } + } + + /// Creates an iterator visiting all key-value pairs in arbitrary order. + #[inline] + pub fn iter(&self) -> MapIter<'a, '_> { + MapIter(self.data.iter()) + } +} + +impl<'a> PartialEq for Map<'a> { + #[inline] + fn eq(&self, other: &Map<'a>) -> bool { + self.val_type == other.val_type && self.data.deref() == other.data.deref() + } +} + +impl<'a> Eq for Map<'a> {} + +impl<'a> GetType for Map<'a> { + #[inline] + fn get_type(&self) -> Type { + Type::Map(self.val_type) + } +} + +impl<'a> Hash for Map<'a> { + fn hash(&self, state: &mut H) { + self.get_type().hash(state); + self.data.deref().hash(state); + } +} + +/// An iterator over the entries of a Map. +pub struct MapIter<'a, 'b>(std::collections::btree_map::Iter<'b, Box<[u8]>, LhsValue<'a>>); + +impl<'a, 'b> Iterator for MapIter<'a, 'b> { + type Item = (&'b [u8], &'b LhsValue<'a>); + + #[inline] + fn next(&mut self) -> Option { + self.0.next().map(|(k, v)| (&**k, v)) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.len(), Some(self.len())) + } +} + +impl<'a, 'b> ExactSizeIterator for MapIter<'a, 'b> { + #[inline] + fn len(&self) -> usize { + self.0.len() + } +} + +pub enum MapValuesIntoIter<'a> { + Owned(std::collections::btree_map::IntoIter, LhsValue<'a>>), + Borrowed(AsRefIterator<'a, std::collections::btree_map::Values<'a, Box<[u8]>, LhsValue<'a>>>), +} + +impl<'a> Iterator for MapValuesIntoIter<'a> { + type Item = LhsValue<'a>; + + #[inline] + fn next(&mut self) -> Option { + match self { + MapValuesIntoIter::Owned(iter) => iter.next().map(|(_, v)| v), + MapValuesIntoIter::Borrowed(iter) => iter.next(), + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.len(), Some(self.len())) + } +} + +impl<'a> ExactSizeIterator for MapValuesIntoIter<'a> { + fn len(&self) -> usize { + match self { + MapValuesIntoIter::Owned(iter) => iter.len(), + MapValuesIntoIter::Borrowed(iter) => iter.len(), + } + } +} + +impl<'a> IntoIterator for Map<'a> { + type Item = (Box<[u8]>, LhsValue<'a>); + type IntoIter = std::collections::btree_map::IntoIter, LhsValue<'a>>; + fn into_iter(self) -> Self::IntoIter { + match self.data { + InnerMap::Owned(map) => map.into_iter(), + InnerMap::Borrowed(ref_map) => ref_map.clone().into_iter(), + } + } +} + +impl<'a, 'b> IntoIterator for &'b Map<'a> { + type Item = (&'b [u8], &'b LhsValue<'a>); + type IntoIter = MapIter<'a, 'b>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + MapIter(self.data.deref().iter()) + } +} + +impl<'a> Serialize for Map<'a> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let to_map = self.data.keys().all(|key| std::str::from_utf8(key).is_ok()); + + if to_map { + let mut map = serializer.serialize_map(Some(self.len()))?; + for (k, v) in self.data.iter() { + map.serialize_entry(std::str::from_utf8(k).unwrap(), v)?; + } + map.end() + } else { + // Keys have to be sorted in order to have reproducible output + let mut keys = Vec::new(); + for key in self.data.keys() { + keys.push(key) + } + keys.sort(); + let mut seq = serializer.serialize_seq(Some(self.len()))?; + for key in keys { + seq.serialize_element(&[ + &LhsValue::Bytes((&**key).into()), + self.data.get(key).unwrap(), + ])?; + } + seq.end() + } + } +} + +struct MapEntrySeed<'a>(&'a Type); + +impl<'de, 'a> DeserializeSeed<'de> for MapEntrySeed<'a> { + type Value = (Cow<'de, [u8]>, LhsValue<'de>); + + fn deserialize(self, 
deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct MapEntryVisitor<'a>(&'a Type); + + impl<'de, 'a> Visitor<'de> for MapEntryVisitor<'a> { + type Value = (Cow<'de, [u8]>, LhsValue<'de>); + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(formatter, "a [key, lhs value] pair") + } + + fn visit_seq(self, mut seq: V) -> Result + where + V: SeqAccess<'de>, + { + let key = seq + .next_element::>()? + .ok_or_else(|| de::Error::invalid_length(0, &self))?; + let value = seq + .next_element_seed(LhsValueSeed(self.0))? + .ok_or_else(|| de::Error::invalid_length(1, &self))?; + Ok((key.into_bytes(), value)) + } + } + + deserializer.deserialize_seq(MapEntryVisitor(self.0)) + } +} + +impl<'de, 'a> DeserializeSeed<'de> for &'a mut Map<'de> { + type Value = (); + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct MapVisitor<'de, 'a>(&'a mut Map<'de>); + + impl<'de, 'a> Visitor<'de> for MapVisitor<'de, 'a> { + type Value = (); + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + formatter, + "a map of lhs value or an array of pair of lhs value" + ) + } + + fn visit_map(self, mut access: M) -> Result + where + M: MapAccess<'de>, + { + while let Some(key) = access.next_key::>()? { + let value = access.next_value_seed(LhsValueSeed(&self.0.value_type()))?; + self.0.insert(key.as_bytes(), value).map_err(|e| { + de::Error::custom(format!( + "invalid type: {:?}, expected {:?}", + e.actual, e.expected + )) + })?; + } + + Ok(()) + } + + fn visit_seq(self, mut seq: V) -> Result + where + V: SeqAccess<'de>, + { + while let Some(entry) = seq.next_element_seed(MapEntrySeed(&self.0.value_type()))? { + self.0.insert(&entry.0, entry.1).map_err(|e| { + de::Error::custom(format!( + "invalid type: {:?}, expected {:?}", + e.actual, e.expected + )) + })?; + } + Ok(()) + } + } + + deserializer.deserialize_struct("", &[], MapVisitor(self)) + } +} + +/// Wrapper type around mutable `Map` to prevent +/// illegal operations like changing the type of +/// its values. +pub struct MapMut<'a, 'b>(&'a mut Map<'b>); + +impl<'a, 'b> MapMut<'a, 'b> { + /// Get a mutable reference to an element if it exists + #[inline] + pub fn get_mut(&'a mut self, key: &[u8]) -> Option> { + self.0.get_mut(key).map(LhsValueMut::from) + } + + /// Inserts an element, overwriting if one already exists + #[inline] + pub fn insert( + &mut self, + key: &[u8], + value: impl Into>, + ) -> Result<(), TypeMismatchError> { + self.0.insert(key, value) + } + + /// Inserts `value` if `key` is missing, then returns a mutable reference to the contained value. 
+ #[inline] + pub fn get_or_insert( + &'a mut self, + key: Box<[u8]>, + value: impl Into>, + ) -> Result, TypeMismatchError> { + self.0.get_or_insert(key, value) + } +} + +impl<'a, 'b> Deref for MapMut<'a, 'b> { + type Target = Map<'b>; + + #[inline] + fn deref(&self) -> &Self::Target { + self.0 + } +} + +impl<'a, 'b> From<&'a mut Map<'b>> for MapMut<'a, 'b> { + #[inline] + fn from(map: &'a mut Map<'b>) -> Self { + Self(map) + } +} + +#[test] +fn test_size_of_map() { + assert_eq!(std::mem::size_of::>(), 40); +} + +#[test] +fn test_borrowed_eq_owned() { + let mut owned = Map::new(Type::Bytes); + + owned + .insert(b"key", LhsValue::Bytes("borrowed".as_bytes().into())) + .unwrap(); + + let borrowed = owned.as_ref(); + + assert!(matches!(owned.data, InnerMap::Owned(_))); + + assert!(matches!(borrowed.data, InnerMap::Borrowed(_))); + + assert_eq!(owned, borrowed); + + assert_eq!(borrowed, borrowed.to_owned()); +} diff --git a/engine/src/lhs_types/mod.rs b/engine/src/lhs_types/mod.rs new file mode 100644 index 00000000..a021ab1a --- /dev/null +++ b/engine/src/lhs_types/mod.rs @@ -0,0 +1,31 @@ +mod array; +mod map; + +use crate::types::LhsValue; + +pub use self::{ + array::{Array, ArrayIterator, ArrayMut, TypedArray}, + map::{Map, MapIter, MapMut, MapValuesIntoIter}, +}; + +pub struct AsRefIterator<'a, T: Iterator>>(T); + +impl<'a, T: Iterator>> AsRefIterator<'a, T> { + pub fn new(iter: T) -> Self { + Self(iter) + } +} + +impl<'a, T: Iterator>> Iterator for AsRefIterator<'a, T> { + type Item = LhsValue<'a>; + + fn next(&mut self) -> Option { + self.0.next().map(LhsValue::as_ref) + } +} + +impl<'a, T: ExactSizeIterator>> ExactSizeIterator for AsRefIterator<'a, T> { + fn len(&self) -> usize { + self.0.len() + } +} diff --git a/engine/src/lib.rs b/engine/src/lib.rs index 69f1a8d0..ea98accb 100644 --- a/engine/src/lib.rs +++ b/engine/src/lib.rs @@ -8,7 +8,7 @@ //! ``` //! use wirefilter::{ExecutionContext, Scheme, Type}; //! -//! fn main() -> Result<(), failure::Error> { +//! fn main() -> Result<(), Box> { //! // Create a map of possible filter fields. //! let scheme = Scheme! { //! http.method: Bytes, @@ -33,20 +33,20 @@ //! // Set runtime field values to test the filter against. //! let mut ctx = ExecutionContext::new(&scheme); //! -//! ctx.set_field_value("http.method", "GET")?; +//! ctx.set_field_value(scheme.get_field("http.method").unwrap(), "GET")?; //! //! ctx.set_field_value( -//! "http.ua", +//! scheme.get_field("http.ua").unwrap(), //! "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:66.0) Gecko/20100101 Firefox/66.0", //! )?; //! -//! ctx.set_field_value("port", 443)?; +//! ctx.set_field_value(scheme.get_field("port").unwrap(), 443)?; //! //! // Execute the filter with given runtime values. //! println!("Filter matches: {:?}", filter.execute(&ctx)?); // true //! //! // Amend one of the runtime values and execute the filter again. -//! ctx.set_field_value("port", 8080)?; +//! ctx.set_field_value(scheme.get_field("port").unwrap(), 8080)?; //! //! println!("Filter matches: {:?}", filter.execute(&ctx)?); // false //! @@ -54,6 +54,9 @@ //! } //! 
``` #![warn(missing_docs)] +#![warn(rust_2018_idioms)] +#![allow(clippy::upper_case_acronyms)] +#![allow(clippy::needless_raw_string_hashes)] #[macro_use] mod lex; @@ -62,22 +65,62 @@ mod lex; mod scheme; mod ast; +mod compiler; mod execution_context; mod filter; mod functions; -mod heap_searcher; +mod lhs_types; +mod list_matcher; +mod panic; mod range_set; mod rhs_types; +mod searcher; mod strict_partial_ord; mod types; pub use self::{ - ast::FilterAst, - execution_context::ExecutionContext, - filter::{Filter, SchemeMismatchError}, + ast::{ + field_expr::{ComparisonExpr, ComparisonOpExpr, IdentifierExpr, IntOp, OrderingOp}, + function_expr::{FunctionCallArgExpr, FunctionCallExpr}, + index_expr::IndexExpr, + logical_expr::{LogicalExpr, LogicalOp, ParenthesizedExpr, UnaryOp}, + parse::{FilterParser, ParseError}, + visitor::{Visitor, VisitorMut}, + Expr, FilterAst, FilterValueAst, ValueExpr, + }, + compiler::{Compiler, DefaultCompiler}, + execution_context::{ + ExecutionContext, ExecutionContextGuard, InvalidListMatcherError, SetFieldValueError, + }, + filter::{ + CompiledExpr, CompiledOneExpr, CompiledValueExpr, CompiledVecExpr, Filter, FilterValue, + }, functions::{ - Function, FunctionArgKind, FunctionArgs, FunctionImpl, FunctionOptParam, FunctionParam, + FunctionArgInvalidConstantError, FunctionArgKind, FunctionArgKindMismatchError, + FunctionArgs, FunctionDefinition, FunctionDefinitionContext, FunctionParam, + FunctionParamError, SimpleFunctionDefinition, SimpleFunctionImpl, SimpleFunctionOptParam, + SimpleFunctionParam, + }, + lex::LexErrorKind, + lhs_types::{Array, ArrayMut, Map, MapIter, MapMut, TypedArray}, + list_matcher::{ + AlwaysList, AlwaysListMatcher, ListDefinition, ListMatcher, NeverList, NeverListMatcher, + }, + panic::{ + catch_panic, panic_catcher_disable, panic_catcher_enable, panic_catcher_get_backtrace, + panic_catcher_set_fallback_mode, panic_catcher_set_hook, PanicCatcherFallbackMode, + }, + rhs_types::{ + Bytes, BytesFormat, ExplicitIpRange, IntRange, IpCidr, IpRange, Regex, RegexError, + RegexFormat, + }, + scheme::{ + Field, FieldIndex, FieldRedefinitionError, Function, FunctionRedefinitionError, Identifier, + IdentifierRedefinitionError, IndexAccessError, List, Scheme, SchemeMismatchError, + UnknownFieldError, + }, + types::{ + ExpectedType, ExpectedTypeList, GetType, LhsValue, LhsValueMut, RhsValue, RhsValues, Type, + TypeMismatchError, }, - scheme::{FieldRedefinitionError, ParseError, Scheme, UnknownFieldError}, - types::{GetType, LhsValue, Type, TypeMismatchError}, }; diff --git a/engine/src/list_matcher.rs b/engine/src/list_matcher.rs new file mode 100644 index 00000000..cbe53d1d --- /dev/null +++ b/engine/src/list_matcher.rs @@ -0,0 +1,165 @@ +use crate::LhsValue; +use crate::Type; +use serde_json::Value; +use std::any::Any; +use std::fmt::Debug; + +/// Defines a new list to match against. +/// +/// `ListDefinition` needs to be registered in the `Scheme` for a given `Type`. +/// See `Scheme::add_list`. +pub trait ListDefinition: Debug + Sync + Send { + /// Converts a deserialized `serde_json::Value` into a `ListMatcher`. + /// + /// This method is necessary to support deserialization of lists during the + /// the deserialization of an `ExecutionContext`. + fn matcher_from_json_value( + &self, + ty: Type, + value: Value, + ) -> Result, serde_json::Error>; + + /// Creates a new matcher object for this list. 
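+    ///
+    /// A minimal sketch of an implementation (the concrete matcher type here is
+    /// hypothetical and not part of this crate):
+    ///
+    /// ```ignore
+    /// fn new_matcher(&self) -> Box<dyn ListMatcher> {
+    ///     Box::new(MyEmptyMatcher::default())
+    /// }
+    /// ```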
+    fn new_matcher(&self) -> Box<dyn ListMatcher>;
+}
+
+pub trait AsAny {
+    fn as_any(&self) -> &dyn Any;
+
+    fn as_any_mut(&mut self) -> &mut dyn Any;
+}
+
+impl<T: Any> AsAny for T {
+    #[inline]
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    #[inline]
+    fn as_any_mut(&mut self) -> &mut dyn Any {
+        self
+    }
+}
+
+/// Object-safe version of the `PartialEq` trait used for equality comparison.
+pub trait DynPartialEq {
+    fn dyn_eq(&self, other: &dyn Any) -> bool;
+}
+
+impl<T: PartialEq + Any> DynPartialEq for T {
+    #[inline]
+    fn dyn_eq(&self, other: &dyn Any) -> bool {
+        if let Some(other) = other.downcast_ref::<Self>() {
+            self == other
+        } else {
+            false
+        }
+    }
+}
+
+/// Implement this trait to match a given `LhsValue` against a list.
+pub trait ListMatcher: AsAny + Debug + DynPartialEq + Send + Sync + 'static {
+    /// Returns true if `val` is in the given list.
+    fn match_value(&self, list_name: &str, val: &LhsValue<'_>) -> bool;
+
+    /// Converts the list matcher to a `serde_json::Value` in order to serialize it.
+    fn to_json_value(&self) -> Value;
+
+    /// Clears the list matcher, removing all its content.
+    fn clear(&mut self);
+}
+
+impl PartialEq for dyn ListMatcher {
+    #[inline]
+    fn eq(&self, other: &dyn ListMatcher) -> bool {
+        DynPartialEq::dyn_eq(self, other.as_any())
+    }
+}
+
+/// List that always matches.
+#[derive(Debug, Default)]
+pub struct AlwaysList {}
+
+/// Matcher for `AlwaysList`.
+#[derive(Clone, Debug, Default, PartialEq, Eq)]
+pub struct AlwaysListMatcher {}
+
+impl ListDefinition for AlwaysList {
+    fn matcher_from_json_value(
+        &self,
+        _: Type,
+        _: serde_json::Value,
+    ) -> Result<Box<dyn ListMatcher>, serde_json::Error> {
+        Ok(Box::new(AlwaysListMatcher {}))
+    }
+
+    fn new_matcher(&self) -> Box<dyn ListMatcher> {
+        Box::new(AlwaysListMatcher {})
+    }
+}
+
+impl ListMatcher for AlwaysListMatcher {
+    fn match_value(&self, _: &str, _: &LhsValue<'_>) -> bool {
+        true
+    }
+
+    fn to_json_value(&self) -> serde_json::Value {
+        serde_json::Value::Null
+    }
+
+    fn clear(&mut self) {}
+}
+
+/// List that never matches.
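+///
+/// Illustrative use of the matcher API defined above (a sketch rather than a
+/// runnable doctest):
+///
+/// ```ignore
+/// let matcher = NeverList::default().new_matcher();
+/// assert!(!matcher.match_value("my_list", &LhsValue::Int(1)));
+/// ```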
+#[derive(Debug, Default)] +pub struct NeverList {} + +/// Matcher for `NeverList` +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct NeverListMatcher {} + +impl ListDefinition for NeverList { + fn matcher_from_json_value( + &self, + _: Type, + _: serde_json::Value, + ) -> Result, serde_json::Error> { + Ok(Box::new(NeverListMatcher {})) + } + + fn new_matcher(&self) -> Box { + Box::new(NeverListMatcher {}) + } +} + +impl ListMatcher for NeverListMatcher { + fn match_value(&self, _: &str, _: &LhsValue<'_>) -> bool { + false + } + + fn to_json_value(&self) -> serde_json::Value { + serde_json::Value::Null + } + + fn clear(&mut self) {} +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_list_matcher_wrapper_comparison() { + let always_list_matcher: Box = Box::new(AlwaysListMatcher {}); + + let always_list_matcher_2: Box = Box::new(AlwaysListMatcher {}); + + assert_eq!(&always_list_matcher, &always_list_matcher_2); + + let never_list_matcher: Box = Box::new(NeverListMatcher {}); + + assert_ne!(&always_list_matcher, &never_list_matcher); + + assert_ne!(&always_list_matcher_2, &never_list_matcher); + } +} diff --git a/engine/src/panic.rs b/engine/src/panic.rs new file mode 100644 index 00000000..03c7ff49 --- /dev/null +++ b/engine/src/panic.rs @@ -0,0 +1,218 @@ +use backtrace::Backtrace; +use std::cell::{Cell, RefCell}; +use std::io::{self, Write}; +use std::panic::UnwindSafe; +use std::process::abort; +use std::sync::atomic::{AtomicBool, Ordering}; + +/// Describes the fallback behavior when +/// a panic occurs outside of `catch_panic`. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum PanicCatcherFallbackMode { + /// Continue running the subsequent panic hooks. + Continue, + /// Abort the program immediatly. + Abort, +} + +thread_local! { + // String to store last backtrace that is recorded in the panic catcher hook + static PANIC_CATCHER_BACKTRACE : RefCell = const { RefCell::new(String::new()) }; + + // Integer to track the number of nested `catch_panic` calls. + static PANIC_CATCHER_LEVEL: Cell = const { Cell::new(0) }; + + // Fallback mode for when the hook is set but a panic occurs outside of a catch block + static PANIC_CATCHER_FALLBACK_MODE: Cell = const { Cell::new(PanicCatcherFallbackMode::Continue) }; + + // Status of the panic catcher + static PANIC_CATCHER_ENABLED: Cell = const { Cell::new(false) }; +} +static PANIC_CATCHER_HOOK_SET: AtomicBool = AtomicBool::new(false); + +#[inline] +fn panic_catcher_start_catching() -> bool { + if PANIC_CATCHER_ENABLED.with(|b| b.get()) { + PANIC_CATCHER_LEVEL.with(|b| { + let Some(level) = b.get().checked_add(1) else { + abort() + }; + b.set(level) + }); + true + } else { + false + } +} + +#[inline] +fn panic_catcher_stop_catching() { + PANIC_CATCHER_LEVEL.with(|b| { + let Some(level) = b.get().checked_sub(1) else { + abort() + }; + b.set(level) + }); +} + +/// Retrieves the backtrace stored during the last panic +/// for the current thread. +pub fn panic_catcher_get_backtrace() -> Option { + PANIC_CATCHER_BACKTRACE.with(|bt| { + let bt = bt.borrow(); + if bt.is_empty() { + None + } else { + Some(bt.to_string()) + } + }) +} + +/// Configures the fallback behavior when +/// a panic occurs outside of `catch_panic`. +pub fn panic_catcher_set_fallback_mode( + fallback_mode: PanicCatcherFallbackMode, +) -> PanicCatcherFallbackMode { + PANIC_CATCHER_FALLBACK_MODE.with(|b| b.replace(fallback_mode)) +} + +/// Catch a panic. 
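+///
+/// Illustrative usage, mirroring the tests at the bottom of this module: the
+/// panic hook must be installed and the catcher enabled, otherwise the closure
+/// runs without an unwind boundary.
+///
+/// ```ignore
+/// panic_catcher_set_hook();
+/// panic_catcher_enable();
+/// let res = catch_panic::<_, ()>(|| panic!("boom"));
+/// assert!(res.unwrap_err().contains("boom"));
+/// ```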
+#[inline(always)] +pub fn catch_panic(f: F) -> Result +where + F: FnOnce() -> Result + UnwindSafe, +{ + if panic_catcher_start_catching() { + let result = std::panic::catch_unwind(f); + panic_catcher_stop_catching(); + match result { + Ok(res) => res, + Err(_) => Err(panic_catcher_get_backtrace().unwrap_or_else(|| { + "thread '' panicked at '' in file '' at line 0" + .to_string() + })), + } + } else { + f() + } +} + +fn record_backtrace(info: &std::panic::PanicHookInfo<'_>, bt: &mut String) { + let (file, line) = if let Some(location) = info.location() { + (location.file(), location.line()) + } else { + ("", 0) + }; + let payload = if let Some(payload) = info.payload().downcast_ref::<&str>() { + payload + } else if let Some(payload) = info.payload().downcast_ref::() { + payload + } else { + "" + }; + bt.truncate(0); + let _ = std::fmt::write( + &mut *bt, + format_args!( + "thread '{}' panicked at '{}' in file '{}' at line {}\n{:?}\n", + std::thread::current().name().unwrap_or(""), + payload, + file, + line, + Backtrace::new() + ), + ); +} + +/// Registers panic catcher panic hook. +pub fn panic_catcher_set_hook() { + if PANIC_CATCHER_HOOK_SET.load(Ordering::SeqCst) { + return; + } + let next = std::panic::take_hook(); + std::panic::set_hook(Box::new(move |info| { + if PANIC_CATCHER_LEVEL.with(|enabled| enabled.get() > 0) { + PANIC_CATCHER_BACKTRACE.with(|bt| { + let mut bt = bt.borrow_mut(); + record_backtrace(info, &mut bt); + }); + return; + } + match PANIC_CATCHER_FALLBACK_MODE.with(|b| b.get()) { + PanicCatcherFallbackMode::Continue => next(info), + PanicCatcherFallbackMode::Abort => { + let mut bt = String::new(); + record_backtrace(info, &mut bt); + let _ = io::stderr().write_all(bt.as_bytes()); + abort(); + } + } + })); + PANIC_CATCHER_HOOK_SET.store(true, Ordering::SeqCst); +} + +/// Enables the panic catcher. +pub fn panic_catcher_enable() { + PANIC_CATCHER_ENABLED.with(|b| b.set(true)); +} + +/// Disables the panic catcher. 
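+///
+/// While the catcher is disabled, `catch_panic` simply invokes the provided
+/// closure without installing an unwind boundary.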
+pub fn panic_catcher_disable() { + PANIC_CATCHER_ENABLED.with(|b| b.set(false)); +} + +#[cfg(test)] +mod panic_test { + use super::*; + + #[test] + #[cfg_attr(miri, ignore)] + #[should_panic(expected = r#"Hello World!"#)] + fn test_panic_catcher_set_panic_hook_can_still_panic() { + panic_catcher_set_hook(); + panic!("Hello World!"); + } + + #[test] + #[cfg_attr(miri, ignore)] + #[should_panic(expected = r#"Hello World!"#)] + fn test_panic_catcher_enabled_disabled_can_still_panic() { + panic_catcher_set_hook(); + panic_catcher_enable(); + panic_catcher_disable(); + panic!("Hello World!"); + } + + #[test] + fn test_panic_catcher_can_catch_panic() { + panic_catcher_set_hook(); + assert_eq!( + panic_catcher_set_fallback_mode(PanicCatcherFallbackMode::Abort), + PanicCatcherFallbackMode::Continue + ); + panic_catcher_enable(); + match catch_panic::<_, ()>(|| panic!("Halt and Catch Panic")) { + Ok(_) => unreachable!(), + Err(msg) => assert!(msg.contains("Halt and Catch Panic")), + } + panic_catcher_disable(); + } + + #[test] + fn test_panic_catcher_can_catch_panic_nested() { + panic_catcher_set_hook(); + assert_eq!( + panic_catcher_set_fallback_mode(PanicCatcherFallbackMode::Abort), + PanicCatcherFallbackMode::Continue + ); + panic_catcher_enable(); + match catch_panic::<_, ()>(|| { + catch_panic::<_, ()>(|| panic!("Nested Panic is Caught")).unwrap_err(); + panic!("Halt and Catch Panic") + }) { + Ok(_) => unreachable!(), + Err(msg) => assert!(msg.contains("Halt and Catch Panic")), + } + panic_catcher_disable(); + } +} diff --git a/engine/src/rhs_types/array.rs b/engine/src/rhs_types/array.rs new file mode 100644 index 00000000..5129e5e8 --- /dev/null +++ b/engine/src/rhs_types/array.rs @@ -0,0 +1,51 @@ +use crate::{ + lex::{Lex, LexResult}, + lhs_types::Array, + strict_partial_ord::StrictPartialOrd, + types::{GetType, Type}, +}; +use serde::Serialize; +use std::{borrow::Borrow, cmp::Ordering}; + +/// [Uninhabited / empty type](https://doc.rust-lang.org/nomicon/exotic-sizes.html#empty-types) +/// for `array` with traits we need for RHS values. 
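+///
+/// Because this enum has no variants, the trait impls below can never be
+/// reached at runtime; each body just matches on the uninhabited value.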
+#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize)] +pub enum UninhabitedArray {} + +impl<'a> Borrow> for UninhabitedArray { + fn borrow(&self) -> &Array<'a> { + match *self {} + } +} + +impl<'a> PartialEq for Array<'a> { + fn eq(&self, other: &UninhabitedArray) -> bool { + match *other {} + } +} + +impl<'a> PartialOrd for Array<'a> { + fn partial_cmp(&self, other: &UninhabitedArray) -> Option { + match *other {} + } +} + +impl<'a> StrictPartialOrd for Array<'a> {} + +impl<'i> Lex<'i> for UninhabitedArray { + fn lex(_input: &str) -> LexResult<'_, Self> { + unreachable!() + } +} + +impl GetType for UninhabitedArray { + fn get_type(&self) -> Type { + unreachable!() + } +} + +impl GetType for Vec { + fn get_type(&self) -> Type { + unreachable!() + } +} diff --git a/engine/src/rhs_types/bytes.rs b/engine/src/rhs_types/bytes.rs index 4e896f97..048deb1a 100644 --- a/engine/src/rhs_types/bytes.rs +++ b/engine/src/rhs_types/bytes.rs @@ -1,68 +1,131 @@ use crate::{ - lex::{expect, take, Lex, LexErrorKind, LexResult}, + lex::{take, Lex, LexErrorKind, LexResult}, strict_partial_ord::StrictPartialOrd, }; -use serde::Serialize; +use serde::{Serialize, Serializer}; use std::{ - borrow::Borrow, fmt::{self, Debug, Formatter}, hash::{Hash, Hasher}, ops::Deref, str, }; -#[derive(PartialEq, Eq, Clone, Serialize)] -#[serde(untagged)] -pub enum Bytes { - Str(Box), - Raw(Box<[u8]>), +/// BytesFormat describes the format in which the string was expressed +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum BytesFormat { + /// Quoted string + Quoted, + /// Raw string, similar to rust raw strings + Raw(u8), // For the hash count + /// Raw byte string literal + Byte, +} + +/// Bytes literal represented either by a string, raw string or raw bytes. +#[derive(PartialEq, Eq, Clone)] +pub struct Bytes { + format: BytesFormat, + data: Box<[u8]>, +} + +impl Bytes { + /// Creates a new bytes literal. + #[inline] + pub fn new(data: impl Into>, format: BytesFormat) -> Self { + Self { + format, + data: data.into(), + } + } + + /// Returns the Bytes format based on the BytesFormat enum + #[inline] + pub fn format(&self) -> BytesFormat { + self.format + } +} + +impl Serialize for Bytes { + #[inline] + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self.format() { + BytesFormat::Quoted | BytesFormat::Raw(_) => match std::str::from_utf8(&self.data) { + Ok(s) => s.serialize(serializer), + Err(_) => self.data.serialize(serializer), + }, + BytesFormat::Byte => self.data.serialize(serializer), + } + } } // We need custom `Hash` consistent with `Borrow` invariants. // We can get away with `Eq` invariant though because we do want // `Bytes == Bytes` to check enum tags but `Bytes == &[u8]` to ignore them, and // consistency of the latter is all that matters for `Borrow` consumers. 
-#[allow(clippy::derive_hash_xor_eq)] +#[allow(clippy::derived_hash_with_manual_eq)] impl Hash for Bytes { + #[inline] fn hash(&self, h: &mut H) { - (self as &[u8]).hash(h) + (self as &[u8]).hash(h); } } impl From> for Bytes { + #[inline] fn from(src: Vec) -> Self { - Bytes::Raw(src.into_boxed_slice()) + Bytes { + format: BytesFormat::Byte, + data: src.into_boxed_slice(), + } } } impl From for Bytes { + #[inline] fn from(src: String) -> Self { - Bytes::Str(src.into_boxed_str()) + Bytes { + format: BytesFormat::Quoted, + data: src.into_boxed_str().into_boxed_bytes(), + } } } impl From for Box<[u8]> { + #[inline] fn from(bytes: Bytes) -> Self { - match bytes { - Bytes::Str(s) => s.into_boxed_bytes(), - Bytes::Raw(b) => b, - } + bytes.data + } +} + +impl From for Vec { + #[inline] + fn from(bytes: Bytes) -> Self { + bytes.data.into_vec() } } impl Debug for Bytes { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - Bytes::Str(s) => s.fmt(f), - Bytes::Raw(b) => { - for (i, b) in b.iter().cloned().enumerate() { - if i != 0 { - write!(f, ":")?; - } - write!(f, "{:02X}", b)?; + fn fmt_raw(data: &[u8], f: &mut Formatter<'_>) -> fmt::Result { + let mut iter = data.iter(); + if let Some(&first) = iter.next() { + write!(f, "{first:02X}")?; + for &b in iter { + write!(f, ":{b:02X}")?; } - Ok(()) } + Ok(()) + } + + match self.format { + BytesFormat::Quoted | BytesFormat::Raw(_) => match std::str::from_utf8(&self.data) { + Ok(s) => s.fmt(f), + Err(_) => fmt_raw(&self.data, f), + }, + BytesFormat::Byte => fmt_raw(&self.data, f), } } } @@ -70,20 +133,29 @@ impl Debug for Bytes { impl Deref for Bytes { type Target = [u8]; + #[inline] fn deref(&self) -> &[u8] { - match self { - Bytes::Str(s) => s.as_bytes(), - Bytes::Raw(b) => b, - } + &self.data } } -impl Borrow<[u8]> for Bytes { - fn borrow(&self) -> &[u8] { +impl AsRef<[u8]> for Bytes { + #[inline] + fn as_ref(&self) -> &[u8] { self } } +impl<'a> IntoIterator for &'a Bytes { + type Item = &'a u8; + type IntoIter = std::slice::Iter<'a, u8>; + + #[inline] + fn into_iter(self) -> std::slice::Iter<'a, u8> { + self.iter() + } +} + fn fixed_byte(input: &str, digits: usize, radix: u32) -> LexResult<'_, u8> { let (digits, rest) = take(input, digits)?; match u8::from_str_radix(digits, radix) { @@ -100,136 +172,432 @@ fn oct_byte(input: &str) -> LexResult<'_, u8> { fixed_byte(input, 3, 8) } -lex_enum!(ByteSeparator { - ":" => Colon, - "-" => Dash, - "." => Dot, -}); +fn write_char(vec: &mut Vec, c: char) { + let mut bytes = [0; 4]; + let len = c.encode_utf8(&mut bytes).len(); + vec.extend_from_slice(&bytes[..len]); +} -impl<'i> Lex<'i> for Bytes { - fn lex(mut input: &str) -> LexResult<'_, Self> { - if let Ok(input) = expect(input, "\"") { - let full_input = input; - let mut res = String::new(); - let mut iter = input.chars(); - loop { - match iter +#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize)] +enum ByteSeparator { + Colon, + Dash, + Dot, +} + +impl<'i> Lex<'i> for ByteSeparator { + fn lex(input: &str) -> LexResult<'_, Self> { + let (sep, rest) = take(input, 1)?; + match sep { + ":" => Ok((ByteSeparator::Colon, rest)), + "-" => Ok((ByteSeparator::Dash, rest)), + "." => Ok((ByteSeparator::Dot, rest)), + _ => Err((LexErrorKind::ExpectedName("byte separator"), sep)), + } + } +} + +pub(crate) fn lex_quoted_string_as_vec(input: &str) -> LexResult<'_, Vec> { + let full_input = input; + let mut res = Vec::new(); + let mut iter = input.chars(); + loop { + match iter + .next() + .ok_or((LexErrorKind::MissingEndingQuote, full_input))? 
+ { + '\\' => { + let input = iter.as_str(); + + let c = iter .next() - .ok_or_else(|| (LexErrorKind::MissingEndingQuote, full_input))? - { - '\\' => { - let input = iter.as_str(); - - let c = iter - .next() - .ok_or_else(|| (LexErrorKind::MissingEndingQuote, full_input))?; - - res.push(match c { - '"' | '\\' => c, - 'x' => { - let (b, input) = hex_byte(iter.as_str())?; - iter = input.chars(); - b as char - } - '0'..='7' => { - let (b, input) = oct_byte(input)?; - iter = input.chars(); - b as char - } - _ => { - return Err(( - LexErrorKind::InvalidCharacterEscape, - &input[..c.len_utf8()], - )); - } - }); + .ok_or((LexErrorKind::MissingEndingQuote, full_input))?; + + match c { + '"' | '\\' => write_char(&mut res, c), + 'x' => { + let (b, rest) = hex_byte(iter.as_str())?; + iter = rest.chars(); + res.push(b); + } + '0'..='7' => { + let (b, rest) = oct_byte(input)?; + iter = rest.chars(); + res.push(b); + } + _ => { + return Err((LexErrorKind::InvalidCharacterEscape, &input[..c.len_utf8()])); } - '"' => return Ok((res.into(), iter.as_str())), - c => res.push(c), - }; + } } + '"' => return Ok((res, iter.as_str())), + c => write_char(&mut res, c), + }; + } +} + +fn lex_quoted_string(input: &str) -> LexResult<'_, Bytes> { + lex_quoted_string_as_vec(input).map(|(vec, rest)| { + let bytes = Bytes { + format: BytesFormat::Quoted, + data: vec.into_boxed_slice(), + }; + + (bytes, rest) + }) +} + +fn lex_byte_string(mut input: &str) -> LexResult<'_, Bytes> { + let mut res = Vec::new(); + let (b, rest) = hex_byte(input)?; + res.push(b); + input = rest; + let (_, rest) = ByteSeparator::lex(input)?; + input = rest; + loop { + let (b, rest) = hex_byte(input)?; + res.push(b); + input = rest; + if let Ok((_, rest)) = ByteSeparator::lex(input) { + input = rest; } else { - let mut res = Vec::new(); - loop { - let (b, rest) = hex_byte(input)?; - res.push(b); - input = rest; - if let Ok((_, rest)) = ByteSeparator::lex(input) { - input = rest; - } else { - return Ok((res.into(), input)); - } + return Ok((res.into(), input)); + } + } +} + +pub(crate) fn lex_raw_string_as_str(input: &str) -> LexResult<'_, (&str, u8)> { + let full_input = input; + + let start_hash_count = input.chars().take_while(|&c| c == '#').count(); + let hash_count: u8 = start_hash_count + .try_into() + .map_err(|_| (LexErrorKind::InvalidRawStringHashCount, full_input))?; + + // consume '"'` + if input.as_bytes().get(start_hash_count) != Some(&b'"') { + return Err(( + LexErrorKind::ExpectedName("\" or #"), + &full_input[start_hash_count..], + )); + } + + let mut iter = input[start_hash_count + 1..].char_indices().peekable(); + + // look for final sequence or fail + loop { + let (i, c) = iter + .next() + .ok_or((LexErrorKind::MissingEndingQuote, full_input))?; + if c == '"' { + // count end hashes + let mut end_hash_count = 0; + while iter.by_ref().next_if(|(_, x)| x == &'#').is_some() { + end_hash_count += 1; + } + + // return if this is a final sequence + if end_hash_count >= start_hash_count { + let full_prefix = start_hash_count + 1; + return Ok(( + (&full_input[full_prefix..i + full_prefix], hash_count), + &full_input[2 * full_prefix + i..], + )); } } } } +#[inline] +fn lex_raw_string(input: &str) -> LexResult<'_, Bytes> { + let ((lexed, hash_count), rest) = lex_raw_string_as_str(input)?; + Ok(( + Bytes { + format: BytesFormat::Raw(hash_count), + data: Box::from(lexed.as_bytes()), + }, + rest, + )) +} + +pub(crate) fn lex_quoted_or_raw_string(input: &str) -> LexResult<'_, Bytes> { + match input.as_bytes().first() { + Some(b'"') => 
lex_quoted_string(&input[1..]), + Some(b'r') => lex_raw_string(&input[1..]), + Some(_) => Err((LexErrorKind::ExpectedName("\" or r"), input)), + None => Err((LexErrorKind::EOF, "")), + } +} + +impl<'i> Lex<'i> for Bytes { + #[inline] + fn lex(input: &str) -> LexResult<'_, Self> { + match input.as_bytes().first() { + Some(b'"' | b'r') => lex_quoted_or_raw_string(input), + Some(_) => lex_byte_string(input), + None => Err((LexErrorKind::EOF, "")), + } + } +} + impl StrictPartialOrd for [u8] {} -#[test] -fn test() { - assert_ok!( - Bytes::lex("01:2e:f3-77.12;"), - Bytes::from(vec![0x01, 0x2E, 0xF3, 0x77, 0x12]), - ";" - ); - - assert_ok!( - Bytes::lex(r#""s\\t\"r\x0A\000t""#), - Bytes::from("s\\t\"r\n\0t".to_owned()) - ); - - assert_err!( - Bytes::lex("01:4x;"), - LexErrorKind::ParseInt { - err: u8::from_str_radix("4x", 16).unwrap_err(), - radix: 16, - }, - "4x" - ); +#[cfg(test)] +mod test { + use super::*; - assert_ok!(Bytes::lex("01;"), Bytes::from(vec![0x01]), ";"); + #[test] + fn test() { + assert_ok!( + Bytes::lex("01:2e:f3-77.12;"), + Bytes::from(vec![0x01, 0x2E, 0xF3, 0x77, 0x12]), + ";" + ); - assert_ok!(Bytes::lex("01:2f-34"), Bytes::from(vec![0x01, 0x2F, 0x34])); + assert_ok!( + Bytes::lex(r#""s\\t\"r\x0A\000t""#), + Bytes::from("s\\t\"r\n\0t".to_owned()) + ); - assert_err!(Bytes::lex("\"1"), LexErrorKind::MissingEndingQuote, "1"); + assert_err!( + Bytes::lex("01:4x;"), + LexErrorKind::ParseInt { + err: u8::from_str_radix("4x", 16).unwrap_err(), + radix: 16, + }, + "4x" + ); - assert_err!( - Bytes::lex(r#""\n""#), - LexErrorKind::InvalidCharacterEscape, - "n" - ); + assert_err!( + Bytes::lex("01;"), + LexErrorKind::ExpectedName("byte separator"), + ";" + ); - assert_err!( - Bytes::lex(r#""abcd\"#), - LexErrorKind::MissingEndingQuote, - "abcd\\" - ); + assert_err!( + Bytes::lex("01:;"), + LexErrorKind::CountMismatch { + name: "character", + actual: 1, + expected: 2 + }, + ";" + ); - assert_err!( - Bytes::lex(r#""\01😢""#), - LexErrorKind::ParseInt { - err: u8::from_str_radix("01😢", 8).unwrap_err(), - radix: 8, - }, - "01😢" - ); - - assert_err!( - Bytes::lex(r#""\x3😢""#), - LexErrorKind::ParseInt { - err: u8::from_str_radix("3😢", 16).unwrap_err(), - radix: 16, - }, - "3😢" - ); - - assert_err!( - Bytes::lex("12:3😢"), - LexErrorKind::ParseInt { - err: u8::from_str_radix("3😢", 16).unwrap_err(), - radix: 16, - }, - "3😢" - ); + assert_ok!(Bytes::lex("01:2f-34"), Bytes::from(vec![0x01, 0x2F, 0x34])); + + assert_err!(Bytes::lex("\"1"), LexErrorKind::MissingEndingQuote, "1"); + + assert_err!( + Bytes::lex(r#""\n""#), + LexErrorKind::InvalidCharacterEscape, + "n" + ); + + assert_err!( + Bytes::lex(r#""abcd\"#), + LexErrorKind::MissingEndingQuote, + "abcd\\" + ); + + assert_err!( + Bytes::lex(r#""\01😢""#), + LexErrorKind::ParseInt { + err: u8::from_str_radix("01😢", 8).unwrap_err(), + radix: 8, + }, + "01😢" + ); + + assert_err!( + Bytes::lex(r#""\x3😢""#), + LexErrorKind::ParseInt { + err: u8::from_str_radix("3😢", 16).unwrap_err(), + radix: 16, + }, + "3😢" + ); + + assert_err!( + Bytes::lex("12:3😢"), + LexErrorKind::ParseInt { + err: u8::from_str_radix("3😢", 16).unwrap_err(), + radix: 16, + }, + "3😢" + ); + + assert_ok!(Bytes::lex(r#""\x7F""#), Bytes::from("\x7F".to_owned())); + + assert_ok!( + Bytes::lex(r#""\x80""#), + Bytes::new(vec![0x80], BytesFormat::Quoted) + ); + + assert_ok!( + Bytes::lex(r#""\xFF""#), + Bytes::new(vec![0xFF], BytesFormat::Quoted) + ); + + assert_ok!(Bytes::lex(r#""\177""#), Bytes::from("\x7F".to_owned())); + + assert_ok!( + Bytes::lex(r#""\200""#), + 
Bytes::new(vec![0x80], BytesFormat::Quoted) + ); + + assert_ok!( + Bytes::lex(r#""\377""#), + Bytes::new(vec![0xFF], BytesFormat::Quoted) + ); + + assert_ok!( + Bytes::lex("c2:b4710c6888a5d47befe865c8e6fb19"), + Bytes::from(vec![0xC2, 0xb4]), + "710c6888a5d47befe865c8e6fb19" + ); + } + + #[test] + fn test_raw_string() { + // Valid empty strings + assert_ok!( + Bytes::lex("r\"\""), + Bytes::new("".as_bytes(), BytesFormat::Raw(0)) + ); + assert_ok!( + Bytes::lex("r#\"\"#"), + Bytes::new("".as_bytes(), BytesFormat::Raw(1)) + ); + assert_ok!( + Bytes::lex("r##\"\"##"), + Bytes::new("".as_bytes(), BytesFormat::Raw(2)) + ); + assert_ok!( + Bytes::lex("r###\"\"###"), + Bytes::new("".as_bytes(), BytesFormat::Raw(3)) + ); + + // Valid raw strings + assert_ok!( + Bytes::lex("r\"a\""), + Bytes::new("a".as_bytes(), BytesFormat::Raw(0)) + ); + assert_ok!( + Bytes::lex("r#\"a\"#"), + Bytes::new("a".as_bytes(), BytesFormat::Raw(1)) + ); + assert_ok!( + Bytes::lex("r##\"a\"##"), + Bytes::new("a".as_bytes(), BytesFormat::Raw(2)) + ); + assert_ok!( + Bytes::lex("r###\"a\"###"), + Bytes::new("a".as_bytes(), BytesFormat::Raw(3)) + ); + + // Quotes and hashes can be used inside the raw string + assert_ok!( + Bytes::lex("r\"#\""), + Bytes::new("#".as_bytes(), BytesFormat::Raw(0)) + ); + assert_ok!( + Bytes::lex("r\"a#\""), + Bytes::new("a#".as_bytes(), BytesFormat::Raw(0)) + ); + assert_ok!( + Bytes::lex("r#\"\"a\"\"\"#"), + Bytes::new("\"a\"\"".as_bytes(), BytesFormat::Raw(1)) + ); + assert_ok!( + Bytes::lex("r##\"\"a\"#b\"##"), + Bytes::new("\"a\"#b".as_bytes(), BytesFormat::Raw(2)) + ); + assert_ok!( + Bytes::lex("r###\"a###\"##\"\"###"), + Bytes::new("a###\"##\"".as_bytes(), BytesFormat::Raw(3)) + ); + assert_ok!( + Bytes::lex("r#\"a\"\"\"#"), + Bytes::new("a\"\"".as_bytes(), BytesFormat::Raw(1)) + ); + assert_ok!( + Bytes::lex("r##\"a\"#\"##"), + Bytes::new("a\"#".as_bytes(), BytesFormat::Raw(2)) + ); + assert_ok!( + Bytes::lex("r###\"a###\"##\"###"), + Bytes::new("a###\"##".as_bytes(), BytesFormat::Raw(3)) + ); + + // Expect an error if the number of '#' doesn't match + assert_err!( + Bytes::lex("r#\"a\""), + LexErrorKind::MissingEndingQuote, + "#\"a\"" + ); + assert_err!( + Bytes::lex("r##\"a\"#"), + LexErrorKind::MissingEndingQuote, + "##\"a\"#" + ); + assert_err!( + Bytes::lex("r###\"a\"##"), + LexErrorKind::MissingEndingQuote, + "###\"a\"##" + ); + + // Expect an error when there are too many hashes being used + let hashes = format!("r{}\"abc\"{}", "#".repeat(255), "#".repeat(255)); + assert_ok!( + Bytes::lex(hashes.as_str()), + Bytes::new("abc".as_bytes(), BytesFormat::Raw(255)) + ); + let hashes = format!("r{}\"abc\"{}", "#".repeat(256), "#".repeat(256)); + assert_err!( + Bytes::lex(hashes.as_str()), + LexErrorKind::InvalidRawStringHashCount, + &hashes.as_str()[1..] 
+ ); + + // Test regex escapes remain the same + assert_ok!( + Bytes::lex(r#"r".\d\D\pA\p{Greek}\PA\P{Greek}[xyz][^xyz][a-z][[:alpha:]][[:^alpha:]][x[^xyz]][a-y&&xyz][0-9&&[^4]][0-9--4][a-g~~b-h][\[\]]""#), + Bytes::new(r#".\d\D\pA\p{Greek}\PA\P{Greek}[xyz][^xyz][a-z][[:alpha:]][[:^alpha:]][x[^xyz]][a-y&&xyz][0-9&&[^4]][0-9--4][a-g~~b-h][\[\]]"#.as_bytes(), BytesFormat::Raw(0)) + ); + assert_ok!( + Bytes::lex(r##"r#"\*\a\f\t\n\r\v\123\x7F\x{10FFFF}\u007F\u{7F}\U0000007F\U{7F}"#"##), + Bytes::new( + r#"\*\a\f\t\n\r\v\123\x7F\x{10FFFF}\u007F\u{7F}\U0000007F\U{7F}"#.as_bytes(), + BytesFormat::Raw(1) + ) + ); + + // Invalid character after 'r' or '#' + assert_err!(Bytes::lex("r"), LexErrorKind::ExpectedName("\" or #"), ""); + assert_err!( + Bytes::lex("r#ab"), + LexErrorKind::ExpectedName("\" or #"), + "ab" + ); + assert_err!( + Bytes::lex("r##ab"), + LexErrorKind::ExpectedName("\" or #"), + "ab" + ); + + // Any characters after a raw string should get returned + assert_eq!( + Bytes::lex("r#\"ab\"##"), + Ok((Bytes::new("ab".as_bytes(), BytesFormat::Raw(1)), "#")) + ); + assert_eq!( + Bytes::lex("r#\"ab\"#\""), + Ok((Bytes::new("ab".as_bytes(), BytesFormat::Raw(1)), "\"")) + ); + assert_eq!( + Bytes::lex("r#\"ab\"#a"), + Ok((Bytes::new("ab".as_bytes(), BytesFormat::Raw(1)), "a")) + ); + } } diff --git a/engine/src/rhs_types/int.rs b/engine/src/rhs_types/int.rs index 80b9a51e..cf28c8cc 100644 --- a/engine/src/rhs_types/int.rs +++ b/engine/src/rhs_types/int.rs @@ -2,21 +2,22 @@ use crate::{ lex::{expect, span, take_while, Lex, LexErrorKind, LexResult}, strict_partial_ord::StrictPartialOrd, }; +use serde::Serialize; use std::ops::RangeInclusive; fn lex_digits(input: &str) -> LexResult<'_, &str> { // Lex any supported digits (up to radix 16) for better error locations. - take_while(input, "digit", |c| c.is_digit(16)) + take_while(input, "digit", |c| c.is_ascii_hexdigit()) } -fn parse_number<'i>((input, rest): (&'i str, &'i str), radix: u32) -> LexResult<'_, i32> { - match i32::from_str_radix(input, radix) { +fn parse_number<'i>((input, rest): (&'i str, &'i str), radix: u32) -> LexResult<'i, i64> { + match i64::from_str_radix(input, radix) { Ok(res) => Ok((res, rest)), Err(err) => Err((LexErrorKind::ParseInt { err, radix }, input)), } } -impl<'i> Lex<'i> for i32 { +impl<'i> Lex<'i> for i64 { fn lex(input: &str) -> LexResult<'_, Self> { if let Ok(input) = expect(input, "0x") { parse_number(lex_digits(input)?, 16) @@ -36,12 +37,29 @@ impl<'i> Lex<'i> for i32 { } } -impl<'i> Lex<'i> for RangeInclusive { +/// A range of integers defined by start and end. +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize)] +#[serde(transparent)] +pub struct IntRange(RangeInclusive); + +impl From for IntRange { + fn from(i: i64) -> Self { + IntRange(i..=i) + } +} + +impl From> for IntRange { + fn from(r: RangeInclusive) -> Self { + IntRange(r) + } +} + +impl<'i> Lex<'i> for IntRange { fn lex(input: &str) -> LexResult<'_, Self> { let initial_input = input; - let (first, input) = i32::lex(input)?; + let (first, input) = i64::lex(input)?; let (last, input) = if let Ok(input) = expect(input, "..") { - i32::lex(input)? + i64::lex(input)? 
} else { (first, input) }; @@ -51,53 +69,65 @@ impl<'i> Lex<'i> for RangeInclusive { span(initial_input, input), )); } - Ok((first..=last, input)) + Ok(((first..=last).into(), input)) + } +} + +impl From for RangeInclusive { + fn from(range: IntRange) -> Self { + range.0 + } +} + +impl<'a> From<&'a IntRange> for RangeInclusive { + fn from(range: &'a IntRange) -> Self { + RangeInclusive::new(*range.0.start(), *range.0.end()) } } -impl StrictPartialOrd for i32 {} +impl StrictPartialOrd for i64 {} #[test] fn test() { use std::str::FromStr; - assert_ok!(i32::lex("0"), 0i32, ""); - assert_ok!(i32::lex("0-"), 0i32, "-"); - assert_ok!(i32::lex("0x1f5+"), 501i32, "+"); - assert_ok!(i32::lex("0123;"), 83i32, ";"); - assert_ok!(i32::lex("78!"), 78i32, "!"); - assert_ok!(i32::lex("0xefg"), 239i32, "g"); - assert_ok!(i32::lex("-12-"), -12i32, "-"); + assert_ok!(i64::lex("0"), 0i64, ""); + assert_ok!(i64::lex("0-"), 0i64, "-"); + assert_ok!(i64::lex("0x1f5+"), 501i64, "+"); + assert_ok!(i64::lex("0123;"), 83i64, ";"); + assert_ok!(i64::lex("78!"), 78i64, "!"); + assert_ok!(i64::lex("0xefg"), 239i64, "g"); + assert_ok!(i64::lex("-12-"), -12i64, "-"); assert_err!( - i32::lex("-2147483649!"), + i64::lex("-9223372036854775809!"), LexErrorKind::ParseInt { - err: i32::from_str("-2147483649").unwrap_err(), + err: i64::from_str("-9223372036854775809").unwrap_err(), radix: 10 }, - "-2147483649" + "-9223372036854775809" ); assert_err!( - i32::lex("2147483648!"), + i64::lex("9223372036854775808!"), LexErrorKind::ParseInt { - err: i32::from_str("2147483648").unwrap_err(), + err: i64::from_str("9223372036854775808").unwrap_err(), radix: 10 }, - "2147483648" + "9223372036854775808" ); assert_err!( - i32::lex("10fex"), + i64::lex("10fex"), LexErrorKind::ParseInt { - err: i32::from_str("10fe").unwrap_err(), + err: i64::from_str("10fe").unwrap_err(), radix: 10 }, "10fe" ); - assert_ok!(RangeInclusive::lex("78!"), 78i32..=78i32, "!"); - assert_ok!(RangeInclusive::lex("0..10"), 0i32..=10i32); - assert_ok!(RangeInclusive::lex("0123..0xefg"), 83i32..=239i32, "g"); - assert_ok!(RangeInclusive::lex("-20..-10"), -20i32..=-10i32); + assert_ok!(IntRange::lex("78!"), 78i64.into(), "!"); + assert_ok!(IntRange::lex("0..10"), (0i64..=10i64).into()); + assert_ok!(IntRange::lex("0123..0xefg"), (83i64..=239i64).into(), "g"); + assert_ok!(IntRange::lex("-20..-10"), (-20i64..=-10i64).into()); assert_err!( - >::lex("10..0"), + IntRange::lex("10..0"), LexErrorKind::IncompatibleRangeBounds, "10..0" ); diff --git a/engine/src/rhs_types/ip.rs b/engine/src/rhs_types/ip.rs index 32905ba1..7f4cf72f 100644 --- a/engine/src/rhs_types/ip.rs +++ b/engine/src/rhs_types/ip.rs @@ -1,8 +1,10 @@ +pub use cidr::IpCidr; + use crate::{ lex::{take_while, Lex, LexError, LexErrorKind, LexResult}, strict_partial_ord::StrictPartialOrd, }; -use cidr::{Cidr, IpCidr, Ipv4Cidr, Ipv6Cidr, NetworkParseError}; +use cidr::{errors::NetworkParseError, Ipv4Cidr, Ipv6Cidr}; use serde::Serialize; use std::{ cmp::Ordering, @@ -12,10 +14,11 @@ use std::{ }; fn match_addr_or_cidr(input: &str) -> LexResult<'_, &str> { - take_while(input, "IP address character", |c| match c { - '0'..='9' | 'a'..='f' | 'A'..='F' | ':' | '.' | '/' => true, - _ => false, - }) + take_while( + input, + "IP address character", + |c| matches!(c, '0'..='9' | 'a'..='f' | 'A'..='F' | ':' | '.' 
| '/'), + ) } fn parse_addr(input: &str) -> Result> { @@ -34,20 +37,35 @@ impl<'i> Lex<'i> for IpAddr { } } -#[derive(PartialEq, Eq, Clone, Serialize, Debug)] +/// An IP range defined explicitly by start and end +#[derive(PartialEq, Eq, Clone, Hash, Serialize, Debug)] #[serde(untagged)] pub enum ExplicitIpRange { + /// An explicit range of IPv4 addresses V4(RangeInclusive), + /// An explicit range of IPv6 addresses V6(RangeInclusive), } -#[derive(PartialEq, Eq, Clone, Serialize, Debug)] +/// A range of IP addresses +#[derive(PartialEq, Eq, Clone, Hash, Serialize, Debug)] #[serde(untagged)] pub enum IpRange { + /// An IP range defined explicitly by start and end Explicit(ExplicitIpRange), + /// A CIDR IP range Cidr(IpCidr), } +impl From for IpRange { + fn from(ip: IpAddr) -> Self { + match ip { + IpAddr::V4(ip) => IpRange::Explicit(ip.into()), + IpAddr::V6(ip) => IpRange::Explicit(ip.into()), + } + } +} + impl<'i> Lex<'i> for IpRange { fn lex(input: &str) -> LexResult<'_, Self> { let (chunk, rest) = match_addr_or_cidr(input)?; @@ -70,7 +88,7 @@ impl<'i> Lex<'i> for IpRange { }) } else { IpRange::Cidr(cidr::IpCidr::from_str(chunk).map_err(|err| { - let split_pos = chunk.find('/').unwrap_or_else(|| chunk.len()); + let split_pos = chunk.find('/').unwrap_or(chunk.len()); let err_span = match err { NetworkParseError::AddrParseError(_) | NetworkParseError::InvalidHostPart => { &chunk[..split_pos] @@ -124,6 +142,7 @@ impl From for ExplicitIpRange { } impl StrictPartialOrd for IpAddr { + #[inline] fn strict_partial_cmp(&self, other: &Self) -> Option { match (self, other) { (IpAddr::V4(lhs), IpAddr::V4(rhs)) => Some(lhs.cmp(rhs)), @@ -190,12 +209,13 @@ fn test_lex() { range([0, 0, 0, 0, 0, 0, 0, 1]..=[0, 0, 0, 0, 0, 0, 0, 2]), "||" ); + assert_ok!(IpRange::lex("1.1.1.01"), cidr([1, 1, 1, 1], 32), ""); match IpRange::lex("10.0.0.0/100") { Err(( LexErrorKind::ParseNetwork(NetworkParseError::NetworkLengthTooLongError(_)), "10.0.0.0/100", )) => {} - err => panic!("Expected NetworkLengthTooLongError, got {:?}", err), + err => panic!("Expected NetworkLengthTooLongError, got {err:?}"), } assert_err!( IpRange::lex("::/.1"), diff --git a/engine/src/rhs_types/list.rs b/engine/src/rhs_types/list.rs new file mode 100644 index 00000000..f56e42f2 --- /dev/null +++ b/engine/src/rhs_types/list.rs @@ -0,0 +1,167 @@ +use crate::lex::{expect, Lex, LexErrorKind, LexResult}; +use serde::Serialize; +use std::str; + +#[derive(PartialEq, Eq, Clone, Serialize, Hash, Debug)] +pub struct ListName(Box); + +impl From for ListName { + fn from(src: String) -> Self { + ListName(src.into_boxed_str()) + } +} + +impl<'i> Lex<'i> for ListName { + fn lex(input: &str) -> LexResult<'_, Self> { + let mut res = String::new(); + let mut rest; + let input = expect(input, "$")?; + let mut iter = input.chars(); + loop { + rest = iter.as_str(); + match iter.next() { + Some(c) => match c { + 'a'..='z' | '0'..='9' | '_' | '.' 
=> res.push(c), + _ => { + if res.is_empty() { + return Err(( + LexErrorKind::InvalidListName { + name: c.to_string(), + }, + input, + )); + } else { + break; + } + } + }, + None => { + if res.is_empty() { + return Err((LexErrorKind::InvalidListName { name: res }, input)); + } else { + break; + } + } + } + } + + if res.as_bytes().first() == Some(&b'.') || res.as_bytes().last() == Some(&b'.') { + return Err((LexErrorKind::InvalidListName { name: res }, input)); + } + + Ok((res.into(), rest)) + } +} + +impl ListName { + pub fn as_str(&self) -> &str { + &self.0 + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn valid() { + assert_ok!( + ListName::lex("$hello;"), + ListName::from("hello".to_string()), + ";" + ); + + assert_ok!( + ListName::lex("$hello_world;"), + ListName::from("hello_world".to_string()), + ";" + ); + + assert_ok!( + ListName::lex("$hello.world;"), + ListName::from("hello.world".to_string()), + ";" + ); + + assert_ok!( + ListName::lex("$hello1234567890;"), + ListName::from("hello1234567890".to_string()), + ";" + ); + + assert_ok!( + ListName::lex("$hello"), + ListName::from("hello".to_string()), + "" + ); + } + + #[test] + fn invalid_char() { + assert_err!( + ListName::lex("$;"), + LexErrorKind::InvalidListName { + name: ";".to_string(), + }, + ";" + ); + } + + #[test] + fn eof_after_dollar() { + assert_err!( + ListName::lex("$"), + LexErrorKind::InvalidListName { + name: "".to_string(), + }, + "" + ); + } + + #[test] + fn no_dollar() { + assert_err!( + ListName::lex("abc"), + LexErrorKind::ExpectedLiteral("$"), + "abc" + ); + } + + #[test] + fn special_char_at_start() { + assert_err!( + ListName::lex("$."), + LexErrorKind::InvalidListName { + name: ".".to_string(), + }, + "." + ); + + assert_err!( + ListName::lex("$.abc"), + LexErrorKind::InvalidListName { + name: ".abc".to_string(), + }, + ".abc" + ); + } + + #[test] + fn special_char_at_end() { + assert_err!( + ListName::lex("$."), + LexErrorKind::InvalidListName { + name: ".".to_string(), + }, + "." + ); + + assert_err!( + ListName::lex("$abc."), + LexErrorKind::InvalidListName { + name: "abc.".to_string(), + }, + "abc." + ); + } +} diff --git a/engine/src/rhs_types/map.rs b/engine/src/rhs_types/map.rs new file mode 100644 index 00000000..dd041c3c --- /dev/null +++ b/engine/src/rhs_types/map.rs @@ -0,0 +1,51 @@ +use crate::{ + lex::{Lex, LexResult}, + lhs_types::Map, + strict_partial_ord::StrictPartialOrd, + types::{GetType, Type}, +}; +use serde::Serialize; +use std::{borrow::Borrow, cmp::Ordering}; + +/// [Uninhabited / empty type](https://doc.rust-lang.org/nomicon/exotic-sizes.html#empty-types) +/// for `map` with traits we need for RHS values. 
+#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize)] +pub enum UninhabitedMap {} + +impl<'a> Borrow> for UninhabitedMap { + fn borrow(&self) -> &Map<'a> { + match *self {} + } +} + +impl<'a> PartialEq for Map<'a> { + fn eq(&self, other: &UninhabitedMap) -> bool { + match *other {} + } +} + +impl<'a> PartialOrd for Map<'a> { + fn partial_cmp(&self, other: &UninhabitedMap) -> Option { + match *other {} + } +} + +impl<'a> StrictPartialOrd for Map<'a> {} + +impl<'i> Lex<'i> for UninhabitedMap { + fn lex(_input: &str) -> LexResult<'_, Self> { + unreachable!() + } +} + +impl GetType for UninhabitedMap { + fn get_type(&self) -> Type { + unreachable!() + } +} + +impl GetType for Vec { + fn get_type(&self) -> Type { + unreachable!() + } +} diff --git a/engine/src/rhs_types/mod.rs b/engine/src/rhs_types/mod.rs index 3165cb11..18d805af 100644 --- a/engine/src/rhs_types/mod.rs +++ b/engine/src/rhs_types/mod.rs @@ -1,12 +1,21 @@ +mod array; mod bool; mod bytes; mod int; mod ip; +mod list; +mod map; mod regex; +mod wildcard; pub use self::{ + array::UninhabitedArray, bool::UninhabitedBool, - bytes::Bytes, - ip::{ExplicitIpRange, IpRange}, - regex::{Error as RegexError, Regex}, + bytes::{Bytes, BytesFormat}, + int::IntRange, + ip::{ExplicitIpRange, IpCidr, IpRange}, + list::ListName, + map::UninhabitedMap, + regex::{Error as RegexError, Regex, RegexFormat}, + wildcard::{Wildcard, WildcardError}, }; diff --git a/engine/src/rhs_types/regex/imp_real.rs b/engine/src/rhs_types/regex/imp_real.rs index 34e0b253..8fce3d47 100644 --- a/engine/src/rhs_types/regex/imp_real.rs +++ b/engine/src/rhs_types/regex/imp_real.rs @@ -1,27 +1,65 @@ -use std::str::FromStr; +use crate::{FilterParser, RegexFormat}; pub use regex::Error; +/// Wrapper around [`regex::bytes::Regex`] #[derive(Clone)] -pub struct Regex(regex::bytes::Regex); - -impl FromStr for Regex { - type Err = Error; +pub struct Regex { + compiled_regex: regex::bytes::Regex, + format: RegexFormat, +} - fn from_str(s: &str) -> Result { - ::regex::bytes::RegexBuilder::new(s) +impl Regex { + /// Compiles a regular expression. + pub fn new( + pattern: &str, + format: RegexFormat, + parser: &FilterParser<'_>, + ) -> Result { + ::regex::bytes::RegexBuilder::new(pattern) .unicode(false) + .size_limit(parser.regex_compiled_size_limit) + .dfa_size_limit(parser.regex_dfa_size_limit) .build() - .map(Regex) + .map(|r| Regex { + compiled_regex: r, + format, + }) } -} -impl Regex { + /// Returns true if and only if the regex matches the string given. pub fn is_match(&self, text: &[u8]) -> bool { - self.0.is_match(text) + self.compiled_regex.is_match(text) } + /// Returns the original string of this regex. 
pub fn as_str(&self) -> &str { - self.0.as_str() + self.compiled_regex.as_str() + } + + /// Returns the format behind the regex + pub fn format(&self) -> RegexFormat { + self.format } } + +impl From for regex::bytes::Regex { + fn from(regex: Regex) -> Self { + regex.compiled_regex + } +} + +#[test] +fn test_compiled_size_limit() { + use crate::Scheme; + + let scheme = Scheme::default(); + + const COMPILED_SIZE_LIMIT: usize = 1024 * 1024; + let mut parser = FilterParser::new(&scheme); + parser.regex_set_compiled_size_limit(COMPILED_SIZE_LIMIT); + assert_eq!( + Regex::new(".{4079,65535}", RegexFormat::Literal, &parser), + Err(Error::CompiledTooBig(COMPILED_SIZE_LIMIT)) + ); +} diff --git a/engine/src/rhs_types/regex/imp_stub.rs b/engine/src/rhs_types/regex/imp_stub.rs index 5905df4d..0f5a1081 100644 --- a/engine/src/rhs_types/regex/imp_stub.rs +++ b/engine/src/rhs_types/regex/imp_stub.rs @@ -1,33 +1,40 @@ -use failure::Fail; -use std::fmt; -use std::str::FromStr; +use thiserror::Error; -#[derive(Debug, PartialEq, Fail)] -pub enum Error {} +use crate::{FilterParser, RegexFormat}; -impl fmt::Display for Error { - fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self {} - } -} +/// Dummy regex error. +#[derive(Debug, PartialEq, Error)] +pub enum Error {} +/// Dummy regex wrapper that can only store a pattern +/// but not actually be used for matching. #[derive(Clone)] -pub struct Regex(String); - -impl FromStr for Regex { - type Err = Error; - - fn from_str(s: &str) -> Result { - Ok(Regex(s.to_owned())) - } +pub struct Regex { + pattern: String, + format: RegexFormat, } impl Regex { + /// Creates a new dummy regex. + pub fn new(pattern: &str, format: RegexFormat, _: &FilterParser<'_>) -> Result { + Ok(Self { + pattern: pattern.to_string(), + format, + }) + } + + /// Not implemented and will panic if called. pub fn is_match(&self, _text: &[u8]) -> bool { unimplemented!("Engine was built without regex support") } + /// Returns the original string of this dummy regex wrapper. pub fn as_str(&self) -> &str { - self.0.as_str() + self.pattern.as_str() + } + + /// Returns the format behind the regex + pub fn format(&self) -> RegexFormat { + self.format } } diff --git a/engine/src/rhs_types/regex/mod.rs b/engine/src/rhs_types/regex/mod.rs index 5c7ba58a..425cdde3 100644 --- a/engine/src/rhs_types/regex/mod.rs +++ b/engine/src/rhs_types/regex/mod.rs @@ -1,9 +1,11 @@ -use crate::lex::{expect, span, Lex, LexErrorKind, LexResult}; +use crate::lex::{span, LexErrorKind, LexResult, LexWith}; +use crate::rhs_types::bytes::lex_raw_string_as_str; +use crate::FilterParser; use cfg_if::cfg_if; use serde::{Serialize, Serializer}; use std::{ fmt::{self, Debug, Formatter}, - str::FromStr, + hash::{Hash, Hasher}, }; cfg_if! { @@ -16,6 +18,15 @@ cfg_if! 
{ } } +/// RegexFormat describes the format behind the regex +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum RegexFormat { + /// Literal string was used to define the expression + Literal, + /// Raw string was used to define the expression + Raw(u8), +} + impl PartialEq for Regex { fn eq(&self, other: &Regex) -> bool { self.as_str() == other.as_str() @@ -24,53 +35,81 @@ impl PartialEq for Regex { impl Eq for Regex {} +impl Hash for Regex { + fn hash(&self, state: &mut H) { + self.as_str().hash(state); + } +} + impl Debug for Regex { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.write_str(self.as_str()) } } -impl<'i> Lex<'i> for Regex { - fn lex(input: &str) -> LexResult<'_, Self> { - let input = expect(input, "\"")?; - let mut regex_buf = String::new(); - let mut in_char_class = false; - let (regex_str, input) = { - let mut iter = input.chars(); - loop { - let before_char = iter.as_str(); - match iter - .next() - .ok_or_else(|| (LexErrorKind::MissingEndingQuote, input))? - { - '\\' => { - if let Some(c) = iter.next() { - if in_char_class || c != '"' { - regex_buf.push('\\'); - } - regex_buf.push(c); +fn lex_regex_from_raw_string<'i>( + input: &'i str, + parser: &FilterParser<'_>, +) -> LexResult<'i, Regex> { + let ((lexed, hashes), input) = lex_raw_string_as_str(input)?; + match Regex::new(lexed, RegexFormat::Raw(hashes), parser) { + Ok(regex) => Ok((regex, input)), + Err(err) => Err((LexErrorKind::ParseRegex(err), input)), + } +} + +fn lex_regex_from_literal<'i>(input: &'i str, parser: &FilterParser<'_>) -> LexResult<'i, Regex> { + let mut regex_buf = String::new(); + let mut in_char_class = false; + let (regex_str, input) = { + let mut iter = input.chars(); + loop { + let before_char = iter.as_str(); + match iter + .next() + .ok_or((LexErrorKind::MissingEndingQuote, input))? 
+ { + '\\' => { + if let Some(c) = iter.next() { + if in_char_class || c != '"' { + regex_buf.push('\\'); } - } - '"' if !in_char_class => { - break (span(input, before_char), iter.as_str()); - } - '[' if !in_char_class => { - in_char_class = true; - regex_buf.push('['); - } - ']' if in_char_class => { - in_char_class = false; - regex_buf.push(']'); - } - c => { regex_buf.push(c); } - }; + } + '"' if !in_char_class => { + break (span(input, before_char), iter.as_str()); + } + '[' if !in_char_class => { + in_char_class = true; + regex_buf.push('['); + } + ']' if in_char_class => { + in_char_class = false; + regex_buf.push(']'); + } + c => { + regex_buf.push(c); + } + }; + } + }; + match Regex::new(®ex_buf, RegexFormat::Literal, parser) { + Ok(regex) => Ok((regex, input)), + Err(err) => Err((LexErrorKind::ParseRegex(err), regex_str)), + } +} + +impl<'i, 's> LexWith<'i, &FilterParser<'s>> for Regex { + fn lex_with(input: &'i str, parser: &FilterParser<'s>) -> LexResult<'i, Self> { + if let Some(c) = input.as_bytes().first() { + match c { + b'"' => lex_regex_from_literal(&input[1..], parser), + b'r' => lex_regex_from_raw_string(&input[1..], parser), + _ => Err((LexErrorKind::ExpectedName("\" or r"), input)), } - }; - match Regex::from_str(®ex_buf) { - Ok(regex) => Ok((regex, input)), - Err(err) => Err((LexErrorKind::ParseRegex(err), regex_str)), + } else { + Err((LexErrorKind::EOF, input)) } } } @@ -81,19 +120,76 @@ impl Serialize for Regex { } } -#[test] -fn test() { - let expr = assert_ok!( - Regex::lex(r#""[a-z"\]]+\d{1,10}\"";"#), - Regex::from_str(r#"[a-z"\]]+\d{1,10}""#).unwrap(), - ";" - ); - - assert_json!(expr, r#"[a-z"\]]+\d{1,10}""#); - - assert_err!( - Regex::lex(r#""abcd\"#), - LexErrorKind::MissingEndingQuote, - "abcd\\" - ); +#[cfg(test)] +mod test { + use super::*; + use crate::Scheme; + + #[test] + fn test() { + let scheme = Scheme::new(); + let expr = assert_ok!( + Regex::lex_with(r#""[a-z"\]]+\d{1,10}\"";"#, &FilterParser::new(&scheme)), + Regex::new( + r#"[a-z"\]]+\d{1,10}""#, + RegexFormat::Literal, + &FilterParser::new(&scheme) + ) + .unwrap(), + ";" + ); + + assert_json!(expr, r#"[a-z"\]]+\d{1,10}""#); + + assert_err!( + Regex::lex_with(r#""abcd\"#, &FilterParser::new(&scheme)), + LexErrorKind::MissingEndingQuote, + "abcd\\" + ); + } + + #[test] + fn test_raw_string() { + let scheme = Scheme::new(); + let expr = assert_ok!( + Regex::lex_with( + r###"r#"[a-z"\]]+\d{1,10}""#;"###, + &FilterParser::new(&scheme) + ), + Regex::new( + r#"[a-z"\]]+\d{1,10}""#, + RegexFormat::Raw(1), + &FilterParser::new(&scheme) + ) + .unwrap(), + ";" + ); + + assert_json!(expr, r#"[a-z"\]]+\d{1,10}""#); + + let expr = assert_ok!( + Regex::lex_with( + r##"r#"(?u)\*\a\f\t\n\r\v\x7F\x{10FFFF}\u007F\u{7F}\U0000007F\U{7F}"#"##, + &FilterParser::new(&scheme) + ), + Regex::new( + r#"(?u)\*\a\f\t\n\r\v\x7F\x{10FFFF}\u007F\u{7F}\U0000007F\U{7F}"#, + RegexFormat::Raw(1), + &FilterParser::new(&scheme) + ) + .unwrap(), + "" + ); + + assert_json!( + expr, + r#"(?u)\*\a\f\t\n\r\v\x7F\x{10FFFF}\u007F\u{7F}\U0000007F\U{7F}"# + ); + + assert_err!( + Regex::lex_with("x", &FilterParser::new(&scheme)), + LexErrorKind::ExpectedName("\" or r"), + "x" + ); + } } diff --git a/engine/src/rhs_types/wildcard.rs b/engine/src/rhs_types/wildcard.rs new file mode 100644 index 00000000..dd14d83c --- /dev/null +++ b/engine/src/rhs_types/wildcard.rs @@ -0,0 +1,356 @@ +use crate::lex::{LexResult, LexWith}; +use crate::rhs_types::bytes::lex_quoted_or_raw_string; +use crate::{Bytes, FilterParser, LexErrorKind}; +use 
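Two regex changes meet here: `Regex::new` now takes the `FilterParser` so that `size_limit`/`dfa_size_limit` come from parser configuration, and the lexer accepts both the classic quoted literal and a Rust-style raw string (`RegexFormat::Raw(n)` records the number of `#` characters). A hedged end-to-end sketch; the field name and the `matches` operator are assumptions, while `FilterParser::new`, `regex_set_compiled_size_limit` and `parse` come from this diff:

```rust
use wirefilter::{FilterParser, Scheme};

fn main() {
    let scheme = Scheme! { http.user_agent: Bytes };

    let mut parser = FilterParser::new(&scheme);
    // Reject patterns whose compiled program would exceed 1 MiB.
    parser.regex_set_compiled_size_limit(1024 * 1024);

    // Raw strings avoid double escaping: `\d` instead of `\\d`.
    let _ast = parser
        .parse(r###"http.user_agent matches r#"curl/\d+\.\d+"#"###)
        .expect("the filter should parse");
}
```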
serde::{Serialize, Serializer}; +use std::{ + fmt::{self, Debug, Formatter}, + hash::{Hash, Hasher}, +}; +use thiserror::Error; +use wildcard::WildcardToken; + +#[derive(Eq, PartialEq, Error, Debug)] +pub enum WildcardError { + #[error("invalid wildcard: {0}")] + InvalidWildcard( + #[source] + #[from] + wildcard::WildcardError, + ), + + #[error("wildcard has {count} star metacharacters, but the limit is {limit}")] + TooManyStarMetacharacters { count: usize, limit: usize }, + + #[error("wildcard contains a double star")] + DoubleStar, +} + +fn has_double_star(wildcard: &wildcard::Wildcard<'_>) -> bool { + let mut iter = wildcard.parsed(); + let Some(mut prev) = iter.next() else { + return false; + }; + for next in iter { + if prev == WildcardToken::MetasymbolAny && next == WildcardToken::MetasymbolAny { + return true; + } + prev = next; + } + false +} + +fn validate_wildcard( + wildcard: &wildcard::Wildcard<'_>, + wildcard_star_limit: usize, +) -> Result<(), WildcardError> { + // We can count all metasymbols because we disabled `?`: + let star_count = wildcard.metasymbol_count(); + + if star_count > wildcard_star_limit { + return Err(WildcardError::TooManyStarMetacharacters { + count: star_count, + limit: wildcard_star_limit, + }); + } + + if has_double_star(wildcard) { + return Err(WildcardError::DoubleStar); + } + + Ok(()) +} + +#[derive(Clone)] +pub struct Wildcard { + compiled_wildcard: wildcard::Wildcard<'static>, + /// The original pattern. We keep this to allow correct serialization of the wildcard pattern, + /// since bytes are encoded differently depending on whether they are a valid UTF-8 sequence. + pattern: Bytes, +} + +impl Wildcard { + pub fn new( + pattern: Bytes, + wildcard_star_limit: usize, + ) -> Result, WildcardError> { + let wildcard = wildcard::WildcardBuilder::from_owned(pattern.to_vec()) + .without_one_metasymbol() + .case_insensitive(!STRICT) + .build()?; + + validate_wildcard(&wildcard, wildcard_star_limit)?; + + Ok(Wildcard { + compiled_wildcard: wildcard, + pattern, + }) + } + + /// Returns true if and only if the wildcard matches the input given. + pub fn is_match(&self, input: &[u8]) -> bool { + self.compiled_wildcard.is_match(input) + } + + /// Returns the pattern. 
+ pub fn pattern(&self) -> &Bytes { + &self.pattern + } +} + +impl PartialEq for Wildcard { + fn eq(&self, other: &Wildcard) -> bool { + self.pattern == other.pattern + } +} + +impl Eq for Wildcard {} + +impl Hash for Wildcard { + fn hash(&self, state: &mut H) { + self.pattern.hash(state); + } +} + +impl Debug for Wildcard { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Debug::fmt(&self.pattern, f) + } +} + +impl Serialize for Wildcard { + fn serialize(&self, ser: S) -> Result { + self.pattern.serialize(ser) + } +} + +impl<'i, 's, const STRICT: bool> LexWith<'i, &FilterParser<'s>> for Wildcard { + fn lex_with(input: &'i str, parser: &FilterParser<'s>) -> LexResult<'i, Wildcard> { + lex_quoted_or_raw_string(input).and_then(|(pattern, rest)| { + match Wildcard::new(pattern, parser.wildcard_star_limit) { + Ok(wildcard) => Ok((wildcard, rest)), + Err(err) => Err((LexErrorKind::ParseWildcard(err), input)), + } + }) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{BytesFormat, Scheme}; + + #[test] + fn test_wildcard_eq() { + fn t() { + assert_eq!( + Wildcard::::new( + Bytes::new("a quoted string".as_bytes(), BytesFormat::Quoted), + usize::MAX + ) + .unwrap(), + Wildcard::::new( + Bytes::new("a quoted string".as_bytes(), BytesFormat::Quoted), + usize::MAX + ) + .unwrap(), + ); + + // Even though they are equivalent as far as evaluation goes, they do not have the same + // visual representation: + assert_ne!( + Wildcard::::new( + Bytes::new("a quoted string".as_bytes(), BytesFormat::Quoted), + usize::MAX + ) + .unwrap(), + Wildcard::::new( + Bytes::new("a quoted string".as_bytes(), BytesFormat::Raw(0)), + usize::MAX + ) + .unwrap(), + ); + } + + t::(); + t::(); + } + + #[test] + fn test_wildcard_lex_quoted_string() { + fn t() { + let scheme = Scheme::new(); + + let expr = assert_ok!( + Wildcard::::lex_with(r#""a quoted string";"#, &FilterParser::new(&scheme)), + Wildcard::::new( + Bytes::new("a quoted string".as_bytes(), BytesFormat::Quoted), + usize::MAX + ) + .unwrap(), + ";" + ); + + assert_json!(expr, "a quoted string"); + + assert_err!( + Wildcard::::lex_with(r#""abcd\"#, &FilterParser::new(&scheme)), + LexErrorKind::MissingEndingQuote, + "abcd\\" + ); + } + + t::(); + t::(); + } + + #[test] + fn test_wildcard_lex_raw_string() { + fn t() { + let scheme = Scheme::new(); + + // Note that the `\\xaa` is escaping the `\` at the wildcard-language level, not at the + // wirefilter-language level. 
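The `Wildcard` wrapper is parameterized by a boolean `STRICT` const generic: strict wildcards match case-sensitively, non-strict ones are compiled with `.case_insensitive(true)`; `?` is disabled so `*` is the only metacharacter, and the star limit plus the double-star check bound matching cost. A hedged construction sketch, assuming `Wildcard`, `Bytes` and `BytesFormat` are reachable from the crate root as the `pub use` in `rhs_types/mod.rs` suggests, and that `Bytes::new` is public:

```rust
use wirefilter::{Bytes, BytesFormat, Wildcard};

fn main() {
    // Non-strict (`STRICT = false`) wildcards match case-insensitively.
    let w = Wildcard::<false>::new(
        Bytes::new("curl/*".as_bytes(), BytesFormat::Quoted),
        2, // star limit applied to this pattern
    )
    .expect("one star is within the limit and there is no double star");

    assert!(w.is_match(b"Curl/8.5.0"));
    assert!(!w.is_match(b"wget/1.21"));
}
```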
+ + let expr = assert_ok!( + Wildcard::::lex_with( + r#####"r##"a raw\\xaa r#""# string"##;"#####, + &FilterParser::new(&scheme) + ), + Wildcard::::new( + Bytes::new( + r#####"a raw\\xaa r#""# string"#####.as_bytes(), + BytesFormat::Raw(2), + ), + usize::MAX + ) + .unwrap(), + ";" + ); + + assert_json!(expr, r#####"a raw\\xaa r#""# string"#####); + + assert_err!( + Wildcard::::lex_with(r#####"r#"abc"#####, &FilterParser::new(&scheme)), + LexErrorKind::MissingEndingQuote, + "#\"abc" + ); + } + + t::(); + t::(); + } + + #[test] + fn test_wildcard_lex_escape_quoted_string_invalid_utf8() { + fn t() { + let scheme = Scheme::new(); + + let bytes = [ + "a quoted ".as_bytes().to_vec(), + vec![0xaa, 0x22], + " string".as_bytes().to_vec(), + ] + .concat(); + + let expr = assert_ok!( + Wildcard::::lex_with( + r#####""a quoted \xaa\x22 string";"#####, + &FilterParser::new(&scheme) + ), + Wildcard::::new( + Bytes::new(bytes.into_boxed_slice(), BytesFormat::Quoted), + usize::MAX + ) + .unwrap(), + ";" + ); + + assert_json!( + expr, + [ + 97, 32, 113, 117, 111, 116, 101, 100, 32, 170, 34, 32, 115, 116, 114, 105, 110, + 103 + ] + ); + } + + t::(); + t::(); + } + + #[test] + fn test_wildcard_lex_reject_bytes_syntax() { + fn t() { + let scheme = Scheme::new(); + + assert_err!( + Wildcard::::lex_with("61:20:71:75:6F:74", &FilterParser::new(&scheme)), + LexErrorKind::ExpectedName("\" or r"), + "61:20:71:75:6F:74" + ); + } + + t::(); + t::(); + } + + #[test] + fn test_wildcard_reject_invalid_wildcard() { + fn t() { + let scheme = Scheme::new(); + + assert!(matches!( + Wildcard::::lex_with(r#"r"*foo\bar*""#, &FilterParser::new(&scheme)) + .map_err(|e| e.0), + Err(LexErrorKind::ParseWildcard(WildcardError::InvalidWildcard( + _ + ))) + )); + } + + t::(); + t::(); + } + + #[test] + fn test_wildcard_star_limit() { + fn t() { + let scheme = Scheme::new(); + let mut parser = FilterParser::new(&scheme); + + parser.wildcard_set_star_limit(3); + + assert!(Wildcard::::lex_with("\"*_*_*\"", &parser).is_ok()); + + assert_eq!( + Wildcard::::lex_with("\"*_*_*_*\"", &parser).map_err(|e| e.0), + Err(LexErrorKind::ParseWildcard( + WildcardError::TooManyStarMetacharacters { count: 4, limit: 3 } + )), + ); + } + + t::(); + t::(); + } + + #[test] + fn test_wildcard_reject_double_star() { + fn t() { + let scheme = Scheme::new(); + + assert!( + Wildcard::::lex_with("\"*foo*bar*\"", &FilterParser::new(&scheme)).is_ok() + ); + + assert_eq!( + Wildcard::::lex_with("\"*foo**bar*\"", &FilterParser::new(&scheme)) + .map_err(|e| e.0), + Err(LexErrorKind::ParseWildcard(WildcardError::DoubleStar)), + ); + } + + t::(); + t::(); + } +} diff --git a/engine/src/scheme.rs b/engine/src/scheme.rs index 753fffff..b89fd057 100644 --- a/engine/src/scheme.rs +++ b/engine/src/scheme.rs @@ -1,22 +1,107 @@ use crate::{ - ast::FilterAst, - functions::Function, - lex::{complete, expect, span, take_while, LexErrorKind, LexResult, LexWith}, - types::{GetType, Type}, + ast::parse::{FilterParser, ParseError}, + ast::{FilterAst, FilterValueAst}, + functions::FunctionDefinition, + lex::{expect, span, take_while, Lex, LexErrorKind, LexResult, LexWith}, + list_matcher::ListDefinition, + types::{GetType, RhsValue, Type}, }; -use failure::Fail; use fnv::FnvBuildHasher; -use indexmap::map::{Entry, IndexMap}; -use serde::{Deserialize, Serialize, Serializer}; +use serde::ser::SerializeMap; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::collections::hash_map::Entry; +use std::sync::Arc; use std::{ - cmp::{max, min}, - error::Error, - 
fmt::{self, Debug, Display, Formatter}, + collections::HashMap, + convert::TryFrom, + fmt::{self, Debug, Formatter}, + hash::{Hash, Hasher}, + iter::Iterator, ptr, }; +use thiserror::Error; -#[derive(PartialEq, Eq, Clone, Copy)] -pub(crate) struct Field<'s> { +/// An error that occurs if two underlying [schemes](struct@Scheme) +/// don't match. +#[derive(Debug, PartialEq, Eq, Error)] +#[error("underlying schemes do not match")] +pub struct SchemeMismatchError; + +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize)] +#[serde(tag = "kind", content = "value")] +/// FieldIndex is an enum with variants [`ArrayIndex(usize)`], +/// representing an index into an Array, or `[MapKey(String)`], +/// representing a key into a Map. +/// +/// ``` +/// #[allow(dead_code)] +/// enum FieldIndex { +/// ArrayIndex(u32), +/// MapKey(String), +/// } +/// ``` +pub enum FieldIndex { + /// Index into an Array + ArrayIndex(u32), + + /// Key into a Map + MapKey(String), + + /// Map each element by applying a function or a comparison + MapEach, +} + +impl<'i> Lex<'i> for FieldIndex { + fn lex(input: &'i str) -> LexResult<'i, Self> { + if let Ok(input) = expect(input, "*") { + return Ok((FieldIndex::MapEach, input)); + } + + // The token inside an [] can be either an integer index into an Array + // or a string key into a Map. The token is a key into a Map if it + // starts and ends with "\"", otherwise an integer index or an error. + let (rhs, rest) = match expect(input, "\"") { + Ok(_) => RhsValue::lex_with(input, Type::Bytes), + Err(_) => RhsValue::lex_with(input, Type::Int).map_err(|_| { + ( + LexErrorKind::ExpectedLiteral( + "expected quoted utf8 string or positive integer", + ), + input, + ) + }), + }?; + + match rhs { + RhsValue::Int(i) => match u32::try_from(i) { + Ok(u) => Ok((FieldIndex::ArrayIndex(u), rest)), + Err(_) => Err(( + LexErrorKind::ExpectedLiteral("expected positive integer as index"), + input, + )), + }, + RhsValue::Bytes(b) => match String::from_utf8(b.to_vec()) { + Ok(s) => Ok((FieldIndex::MapKey(s), rest)), + Err(_) => Err((LexErrorKind::ExpectedLiteral("expected utf8 string"), input)), + }, + _ => unreachable!(), + } + } +} + +/// An error when an index is invalid for a type. +#[derive(Debug, PartialEq, Eq, Error)] +#[error("cannot access index {index:?} for type {actual:?}")] +pub struct IndexAccessError { + /// Index that could not be accessed. + pub index: FieldIndex, + /// Provided value type. + pub actual: Type, +} + +#[derive(PartialEq, Eq, Clone, Copy, Hash)] +/// A structure to represent a field inside a [`Scheme`](struct@Scheme). +pub struct Field<'s> { scheme: &'s Scheme, index: usize, } @@ -34,166 +119,264 @@ impl<'s> Debug for Field<'s> { } impl<'i, 's> LexWith<'i, &'s Scheme> for Field<'s> { - fn lex_with(mut input: &'i str, scheme: &'s Scheme) -> LexResult<'i, Self> { - let initial_input = input; + fn lex_with(input: &'i str, scheme: &'s Scheme) -> LexResult<'i, Self> { + match Identifier::lex_with(input, scheme) { + Ok((Identifier::Field(f), rest)) => Ok((f, rest)), + Ok((Identifier::Function(_), rest)) => Err(( + LexErrorKind::UnknownField(UnknownFieldError), + span(input, rest), + )), + Err((LexErrorKind::UnknownIdentifier, s)) => { + Err((LexErrorKind::UnknownField(UnknownFieldError), s)) + } + Err(err) => Err(err), + } + } +} - loop { - input = take_while(input, "identifier character", |c| { - c.is_ascii_alphanumeric() || c == '_' - })? - .1; +impl<'s> Field<'s> { + /// Returns the field's name as recorded in the [`Scheme`](struct@Scheme). 
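`FieldIndex` is what backs the bracket syntax: a non-negative integer indexes an `Array`, a quoted UTF-8 string keys into a `Map`, and `*` means "map each". A hedged sketch, assuming index expressions are accepted directly on fields in filter expressions (the field names are invented; the `Scheme!` macro and `Scheme::parse` are from this diff):

```rust
use wirefilter::Scheme;

fn main() {
    let scheme = Scheme! {
        http.headers: Map(Bytes),
        http.parts: Array(Bytes),
    };

    // A quoted UTF-8 string keys into the map, a non-negative integer indexes
    // the array; `[*]` would instead map over every element.
    assert!(scheme
        .parse(r#"http.headers["host"] == "example.org""#)
        .is_ok());
    assert!(scheme.parse(r#"http.parts[0] == "GET""#).is_ok());
}
```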
+ #[inline] + pub fn name(&self) -> &'s str { + &self.scheme.fields[self.index].0 + } - match expect(input, ".") { - Ok(rest) => input = rest, - Err(_) => break, - }; - } + /// Get the field's index in the [`Scheme`](struct@Scheme) identifier's list. + #[inline] + pub fn index(&self) -> usize { + self.index + } - let name = span(initial_input, input); + /// Returns the [`Scheme`](struct@Scheme) to which this field belongs to. + #[inline] + pub fn scheme(&self) -> &'s Scheme { + self.scheme + } +} - let field = scheme - .get_field_index(name) - .map_err(|err| (LexErrorKind::UnknownField(err), name))?; +impl<'s> GetType for Field<'s> { + #[inline] + fn get_type(&self) -> Type { + self.scheme.fields[self.index].1 + } +} - Ok((field, input)) +#[derive(PartialEq, Eq, Clone, Copy, Hash)] +/// A structure to represent a function inside a [`Scheme`](struct@Scheme). +pub struct Function<'s> { + scheme: &'s Scheme, + index: usize, +} + +impl<'s> Serialize for Function<'s> { + fn serialize(&self, ser: S) -> Result { + self.name().serialize(ser) } } -impl<'s> Field<'s> { +impl<'s> Debug for Function<'s> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.name()) + } +} + +impl<'i, 's> LexWith<'i, &'s Scheme> for Function<'s> { + fn lex_with(input: &'i str, scheme: &'s Scheme) -> LexResult<'i, Self> { + match Identifier::lex_with(input, scheme) { + Ok((Identifier::Function(f), rest)) => Ok((f, rest)), + Ok((Identifier::Field(_), rest)) => Err(( + LexErrorKind::UnknownFunction(UnknownFunctionError), + span(input, rest), + )), + Err((LexErrorKind::UnknownIdentifier, s)) => { + Err((LexErrorKind::UnknownFunction(UnknownFunctionError), s)) + } + Err(err) => Err(err), + } + } +} + +impl<'s> Function<'s> { + /// Returns the function's name as recorded in the [`Scheme`](struct@Scheme). + #[inline] pub fn name(&self) -> &'s str { - self.scheme.fields.get_index(self.index).unwrap().0 + &self.scheme.functions[self.index].0 } + /// Get the function's index in the [`Scheme`](struct@Scheme) identifier's list. + #[inline] pub fn index(&self) -> usize { self.index } + /// Returns the [`Scheme`](struct@Scheme) to which this function belongs to. + #[inline] pub fn scheme(&self) -> &'s Scheme { self.scheme } + + #[inline] + pub(crate) fn as_definition(&self) -> &'s dyn FunctionDefinition { + &*self.scheme.functions[self.index].1 + } } -impl<'s> GetType for Field<'s> { - fn get_type(&self) -> Type { - *self.scheme.fields.get_index(self.index).unwrap().1 +/// An enum to represent an entry inside a [`Scheme`](struct@Scheme). +/// It can be either a [`Field`](struct@Field) or a [`Function`](struct@Function). +#[derive(Debug)] +pub enum Identifier<'s> { + /// Identifier is a [`Field`](struct@Field) + Field(Field<'s>), + /// Identifier is a [`Function`](struct@Function) + Function(Function<'s>), +} + +impl<'s> Identifier<'s> { + /// Converts the identifier into a [`Field`](struct@Field) if possible. + pub fn into_field(self) -> Option> { + match self { + Self::Field(f) => Some(f), + _ => None, + } + } + + /// Converts the identifier into a [`Function`](struct@Function) if possible. + pub fn into_function(self) -> Option> { + match self { + Self::Function(f) => Some(f), + _ => None, + } + } +} + +impl<'i, 's> LexWith<'i, &'s Scheme> for Identifier<'s> { + fn lex_with(mut input: &'i str, scheme: &'s Scheme) -> LexResult<'i, Self> { + let initial_input = input; + + loop { + input = take_while(input, "identifier character", |c| { + c.is_ascii_alphanumeric() || c == '_' + })? 
+ .1; + + match expect(input, ".") { + Ok(rest) => input = rest, + Err(_) => break, + }; + } + + let name = span(initial_input, input); + + let field = scheme + .get(name) + .ok_or((LexErrorKind::UnknownIdentifier, name))?; + + Ok((field, input)) } } /// An error that occurs if an unregistered field name was queried from a /// [`Scheme`](struct@Scheme). -#[derive(Debug, PartialEq, Fail)] -#[fail(display = "unknown field")] +#[derive(Debug, PartialEq, Eq, Error)] +#[error("unknown field")] pub struct UnknownFieldError; /// An error that occurs if an unregistered function name was queried from a /// [`Scheme`](struct@Scheme). -#[derive(Debug, PartialEq, Fail)] -#[fail(display = "unknown function")] +#[derive(Debug, PartialEq, Eq, Error)] +#[error("unknown function")] pub struct UnknownFunctionError; /// An error that occurs when previously defined field gets redefined. -#[derive(Debug, PartialEq, Fail)] -#[fail(display = "attempt to redefine field {}", _0)] +#[derive(Debug, PartialEq, Eq, Error)] +#[error("attempt to redefine field {0}")] pub struct FieldRedefinitionError(String); /// An error that occurs when previously defined function gets redefined. -#[derive(Debug, PartialEq, Fail)] -#[fail(display = "attempt to redefine function {}", _0)] +#[derive(Debug, PartialEq, Eq, Error)] +#[error("attempt to redefine function {0}")] pub struct FunctionRedefinitionError(String); -#[derive(Debug, PartialEq, Fail)] -pub enum ItemRedefinitionError { - #[fail(display = "{}", _0)] - Field(#[cause] FieldRedefinitionError), +/// An error that occurs when trying to redefine a field or function. +#[derive(Debug, PartialEq, Eq, Error)] +pub enum IdentifierRedefinitionError { + /// An error that occurs when previously defined field gets redefined. + #[error("{0}")] + Field(#[source] FieldRedefinitionError), - #[fail(display = "{}", _0)] - Function(#[cause] FunctionRedefinitionError), + /// An error that occurs when previously defined function gets redefined. + #[error("{0}")] + Function(#[source] FunctionRedefinitionError), } -/// An opaque filter parsing error associated with the original input. -/// -/// For now, you can just print it in a debug or a human-readable fashion. 
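With fields and functions stored in a single `items` map, lookup goes through `Scheme::get`, which returns an `Identifier` that is either a `Field` or a `Function`; `get_field` and `get_function` remain as typed shortcuts with their dedicated error types. A small hedged sketch (assuming `GetType` is re-exported at the crate root):

```rust
use wirefilter::{GetType, Scheme, Type};

fn main() {
    let mut scheme = Scheme::new();
    scheme.add_field("http.host", Type::Bytes).unwrap();

    // Generic lookup: the caller decides whether a field or a function is wanted.
    let field = scheme
        .get("http.host")
        .and_then(|ident| ident.into_field())
        .expect("registered above as a field");

    assert_eq!(field.name(), "http.host");
    assert_eq!(field.get_type(), Type::Bytes);

    // Typed lookups report dedicated errors for unknown names.
    assert!(scheme.get_field("nope").is_err());
    assert!(scheme.get_function("http.host").is_err());
}
```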
-#[derive(Debug, PartialEq)] -pub struct ParseError<'i> { - kind: LexErrorKind, - input: &'i str, - line_number: usize, - span_start: usize, - span_len: usize, -} - -impl<'i> Error for ParseError<'i> {} - -impl<'i> ParseError<'i> { - pub(crate) fn new(mut input: &'i str, (kind, span): (LexErrorKind, &'i str)) -> Self { - let mut span_start = span.as_ptr() as usize - input.as_ptr() as usize; - - let (line_number, line_start) = input[..span_start] - .match_indices('\n') - .map(|(pos, _)| pos + 1) - .scan(0, |line_number, line_start| { - *line_number += 1; - Some((*line_number, line_start)) - }) - .last() - .unwrap_or_default(); - - input = &input[line_start..]; - - span_start -= line_start; - let mut span_len = span.len(); - - if let Some(line_end) = input.find('\n') { - input = &input[..line_end]; - span_len = min(span_len, line_end - span_start); - } +#[derive(Clone, Copy, Debug)] +enum SchemeItem { + Field(usize), + Function(usize), +} - ParseError { - kind, - input, - line_number, - span_start, - span_len, - } +impl From for Box { + fn from(func: T) -> Box { + Box::new(func) } } -impl<'i> Display for ParseError<'i> { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - writeln!( - f, - "Filter parsing error ({}:{}):", - self.line_number + 1, - self.span_start + 1 - )?; +/// A structure to represent a list inside a [`scheme`](struct.Scheme.html). +/// +/// See [`Scheme::get_list`](struct.Scheme.html#method.get_list). +#[derive(PartialEq, Eq, Clone, Copy, Hash)] +pub struct List<'s> { + scheme: &'s Scheme, + index: usize, +} - writeln!(f, "{}", self.input)?; +impl<'s> List<'s> { + pub(crate) fn index(&self) -> usize { + self.index + } - for _ in 0..self.span_start { - write!(f, " ")?; - } + pub(crate) fn scheme(&self) -> &'s Scheme { + self.scheme + } - for _ in 0..max(1, self.span_len) { - write!(f, "^")?; - } + pub(crate) fn definition(&self) -> &'s dyn ListDefinition { + &*self.scheme.lists[self.index].1 + } +} - writeln!(f, " {}", self.kind)?; +impl<'s> Debug for List<'s> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self.scheme.lists[self.index]) + } +} - Ok(()) +impl<'s> GetType for List<'s> { + #[inline] + fn get_type(&self) -> Type { + self.scheme.lists[self.index].0 } } +/// An error that occurs when previously defined list gets redefined. +#[derive(Debug, PartialEq, Eq, Error)] +#[error("attempt to redefine list for type {0:?}")] +pub struct ListRedefinitionError(Type); + +type IdentifierName = Arc; + /// The main registry for fields and their associated types. /// /// This is necessary to provide typechecking for runtime values provided /// to the [execution context](::ExecutionContext) and also to aid parser /// in ambiguous contexts. 
-#[derive(Default, Deserialize)] -#[serde(transparent)] +#[derive(Default, Debug)] pub struct Scheme { - fields: IndexMap, - #[serde(skip)] - functions: IndexMap, + fields: Vec<(IdentifierName, Type)>, + functions: Vec<(IdentifierName, Box)>, + items: HashMap, + + list_types: HashMap, + lists: Vec<(Type, Box)>, } impl PartialEq for Scheme { @@ -204,103 +387,215 @@ impl PartialEq for Scheme { impl Eq for Scheme {} +impl Hash for Scheme { + fn hash(&self, state: &mut H) { + (self as *const Scheme).hash(state); + } +} + +impl Serialize for Scheme { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut map = serializer.serialize_map(Some(self.field_count()))?; + for f in self.fields() { + map.serialize_entry(f.name(), &f.get_type())?; + } + map.end() + } +} + +impl<'de> Deserialize<'de> for Scheme { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + use serde::de::Error; + + let mut scheme = Scheme::new(); + let map: HashMap = HashMap::::deserialize(deserializer)?; + for (name, ty) in map { + scheme.add_field(&name, ty).map_err(D::Error::custom)?; + } + Ok(scheme) + } +} + impl<'s> Scheme { /// Creates a new scheme. pub fn new() -> Self { Default::default() } - /// Creates a new scheme with capacity for `n` fields. - pub fn with_capacity(n: usize) -> Self { - Scheme { - fields: IndexMap::with_capacity_and_hasher(n, FnvBuildHasher::default()), - functions: Default::default(), - } + /// Returns the [`identifier`](struct@Identifier) with name [`name`] + pub fn get(&'s self, name: &str) -> Option> { + self.items.get(name).map(move |item| match *item { + SchemeItem::Field(index) => Identifier::Field(Field { + scheme: self, + index, + }), + SchemeItem::Function(index) => Identifier::Function(Function { + scheme: self, + index, + }), + }) } /// Registers a field and its corresponding type. - pub fn add_field(&mut self, name: String, ty: Type) -> Result<(), ItemRedefinitionError> { - if self.functions.contains_key(&name) { - return Err(ItemRedefinitionError::Function(FunctionRedefinitionError( - name, - ))); - }; - match self.fields.entry(name) { - Entry::Occupied(entry) => Err(ItemRedefinitionError::Field(FieldRedefinitionError( - entry.key().to_string(), - ))), + pub fn add_field>( + &mut self, + name: N, + ty: Type, + ) -> Result<(), IdentifierRedefinitionError> { + match self.items.entry(name.as_ref().into()) { + Entry::Occupied(entry) => match entry.get() { + SchemeItem::Field(_) => Err(IdentifierRedefinitionError::Field( + FieldRedefinitionError(entry.key().to_string()), + )), + SchemeItem::Function(_) => Err(IdentifierRedefinitionError::Function( + FunctionRedefinitionError(entry.key().to_string()), + )), + }, Entry::Vacant(entry) => { - entry.insert(ty); + let index = self.fields.len(); + self.fields.push((entry.key().clone(), ty)); + entry.insert(SchemeItem::Field(index)); Ok(()) } } } - /// Registers a series of fields from an iterable, reporting any conflicts. 
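`Scheme` now (de)serializes by hand: serialization writes a flat map from field name to `Type`, and deserialization replays `add_field`, so functions and lists deliberately do not round-trip. A hedged sketch using `serde_json` (the exact JSON shape is not asserted anywhere in this diff):

```rust
use wirefilter::{Scheme, Type};

fn main() -> Result<(), serde_json::Error> {
    let mut scheme = Scheme::new();
    scheme.add_field("port", Type::Int).unwrap();
    scheme.add_field("host", Type::Bytes).unwrap();

    // Roughly: {"port":"Int","host":"Bytes"}
    let json = serde_json::to_string(&scheme)?;
    let restored: Scheme = serde_json::from_str(&json)?;

    // Field definitions survive; note that `PartialEq` for `Scheme` is
    // identity-based, so `restored` is not equal to `scheme` even though
    // they define the same fields.
    assert_eq!(restored.field_count(), 2);
    assert!(restored.get_field("host").is_ok());
    Ok(())
}
```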
- pub fn try_from_iter( - iter: impl IntoIterator, - ) -> Result { - let iter = iter.into_iter(); - let (low, _) = iter.size_hint(); - let mut scheme = Scheme::with_capacity(low); - for (name, value) in iter { - scheme.add_field(name, value)?; + /// Returns the [`field`](struct@Field) with name [`name`] + pub fn get_field(&'s self, name: &str) -> Result, UnknownFieldError> { + match self.get(name) { + Some(Identifier::Field(f)) => Ok(f), + _ => Err(UnknownFieldError), } - Ok(scheme) } - pub(crate) fn get_field_index(&'s self, name: &str) -> Result, UnknownFieldError> { - match self.fields.get_full(name) { - Some((index, ..)) => Ok(Field { - scheme: self, - index, - }), - None => Err(UnknownFieldError), - } + /// Iterates over fields registered in the [`scheme`](struct@Scheme) + #[inline] + pub fn fields(&'s self) -> impl ExactSizeIterator> + 's { + (0..self.fields.len()).map(|index| Field { + scheme: self, + index, + }) } - pub(crate) fn get_field_count(&self) -> usize { + /// Returns the number of fields in the [`scheme`](struct@Scheme) + #[inline] + pub fn field_count(&self) -> usize { self.fields.len() } + /// Returns the number of functions in the [`scheme`](struct@Scheme) + #[inline] + pub fn function_count(&self) -> usize { + self.functions.len() + } + /// Registers a function - pub fn add_function( + pub fn add_function>( &mut self, - name: String, - function: Function, - ) -> Result<(), ItemRedefinitionError> { - if self.fields.contains_key(&name) { - return Err(ItemRedefinitionError::Field(FieldRedefinitionError(name))); - }; - match self.functions.entry(name) { - Entry::Occupied(entry) => Err(ItemRedefinitionError::Function( - FunctionRedefinitionError(entry.key().to_string()), - )), + name: N, + function: impl Into>, + ) -> Result<(), IdentifierRedefinitionError> { + match self.items.entry(name.as_ref().into()) { + Entry::Occupied(entry) => match entry.get() { + SchemeItem::Field(_) => Err(IdentifierRedefinitionError::Field( + FieldRedefinitionError(entry.key().to_string()), + )), + SchemeItem::Function(_) => Err(IdentifierRedefinitionError::Function( + FunctionRedefinitionError(entry.key().to_string()), + )), + }, Entry::Vacant(entry) => { - entry.insert(function); + let index = self.functions.len(); + self.functions.push((entry.key().clone(), function.into())); + entry.insert(SchemeItem::Function(index)); Ok(()) } } } - /// Registers a list of functions - pub fn add_functions(&mut self, functions: I) -> Result<(), ItemRedefinitionError> - where - I: IntoIterator, - { - for (name, func) in functions { - self.add_function(name, func)?; + /// Returns the [`function`](struct@Function) with name [`name`] + pub fn get_function(&'s self, name: &str) -> Result, UnknownFunctionError> { + match self.get(name) { + Some(Identifier::Function(f)) => Ok(f), + _ => Err(UnknownFunctionError), } - Ok(()) } - pub(crate) fn get_function(&'s self, name: &str) -> Result<&'s Function, UnknownFunctionError> { - self.functions.get(name).ok_or(UnknownFunctionError) + /// Iterates over functions registered in the [`scheme`](struct@Scheme) + #[inline] + pub fn functions(&'s self) -> impl ExactSizeIterator> + 's { + (0..self.functions.len()).map(|index| Function { + scheme: self, + index, + }) } - /// Parses a filter into an AST form. + /// Parses a filter expression into an AST form. 
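`Scheme::parse` keeps its signature but now delegates to `FilterParser`, and unknown names surface as `LexErrorKind::UnknownIdentifier`, as the updated tests further below show. A hedged sketch of what the rendered error looks like; the exact caret placement follows the `ParseError` display logic:

```rust
use wirefilter::Scheme;

fn main() {
    let scheme = Scheme! { num: Int };

    let err = match scheme.parse("num == 1 && nosuchfield") {
        Ok(_) => unreachable!("the identifier is not in the scheme"),
        Err(err) => err,
    };

    println!("{err}");
    // Expected output, roughly:
    //
    // Filter parsing error (1:13):
    // num == 1 && nosuchfield
    //             ^^^^^^^^^^^ unknown identifier
}
```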
pub fn parse<'i>(&'s self, input: &'i str) -> Result, ParseError<'i>> { - complete(FilterAst::lex_with(input.trim(), self)).map_err(|err| ParseError::new(input, err)) + FilterParser::new(self).parse(input) + } + + /// Parses a value expression into an AST form. + pub fn parse_value<'i>(&'s self, input: &'i str) -> Result, ParseError<'i>> { + FilterParser::new(self).parse_value(input) + } + + /// Returns the number of lists in the [`scheme`](struct@Scheme) + #[inline] + pub fn list_count(&self) -> usize { + self.lists.len() + } + + /// Registers a new [`list`](trait.ListDefinition.html) for a given [`type`](enum.Type.html). + pub fn add_list( + &mut self, + ty: Type, + definition: Box, + ) -> Result<(), ListRedefinitionError> { + match self.list_types.entry(ty) { + Entry::Occupied(entry) => Err(ListRedefinitionError(*entry.key())), + Entry::Vacant(entry) => { + let index = self.lists.len(); + self.lists.push((ty, definition)); + entry.insert(index); + Ok(()) + } + } + } + + /// Returns the [`list`](struct.List.html) for a given [`type`](enum.Type.html). + pub fn get_list(&self, ty: &Type) -> Option> { + self.list_types.get(ty).map(move |index| List { + scheme: self, + index: *index, + }) + } + + /// Iterates over all registered [`lists`](trait.ListDefinition.html). + pub fn lists(&self) -> impl ExactSizeIterator> { + (0..self.lists.len()).map(|index| List { + scheme: self, + index, + }) + } +} + +impl> FromIterator<(N, Type)> for Scheme { + fn from_iter>(iter: T) -> Self { + let mut scheme = Scheme::new(); + for (name, ty) in iter { + scheme + .add_field(name.as_ref(), ty) + .map_err(|err| err.to_string()) + .unwrap(); + } + scheme } } @@ -308,34 +603,33 @@ impl<'s> Scheme { /// contents. #[macro_export] macro_rules! Scheme { - ($($ns:ident $(. $field:ident)*: $ty:ident),* $(,)*) => { - $crate::Scheme::try_from_iter( - [$( - ( - concat!(stringify!($ns) $(, ".", stringify!($field))*), - $crate::Type::$ty - ) - ),*] - .iter() - .map(|&(k, v)| (k.to_owned(), v)), - ) - // Treat duplciations in static schemes as a developer's mistake. - .unwrap_or_else(|err| panic!("{}", err)) + ($($ns:ident $(. $field:ident)*: $ty:ident $(($subty:tt $($rest:tt)*))?),* $(,)*) => { + $crate::Scheme::from_iter([$( + ( + concat!(stringify!($ns) $(, ".", stringify!($field))*), + Scheme!($ty $(($subty $($rest)*))?), + ) + ),*]) }; + ($ty:ident $(($subty:tt $($rest:tt)*))?) => {$crate::Type::$ty$(((Scheme!($subty $($rest)*)).into()))?}; } #[test] fn test_parse_error() { + use crate::types::TypeMismatchError; use indoc::indoc; - let scheme = &Scheme! { num: Int }; + let scheme = &Scheme! 
{ + num: Int, + arr: Array(Bool), + }; { let err = scheme.parse("xyz").unwrap_err(); assert_eq!( err, ParseError { - kind: LexErrorKind::UnknownField(UnknownFieldError), + kind: LexErrorKind::UnknownIdentifier, input: "xyz", line_number: 0, span_start: 0, @@ -348,7 +642,7 @@ fn test_parse_error() { r#" Filter parsing error (1:1): xyz - ^^^ unknown field + ^^^ unknown identifier "# ) ); @@ -359,7 +653,7 @@ fn test_parse_error() { assert_eq!( err, ParseError { - kind: LexErrorKind::UnknownField(UnknownFieldError), + kind: LexErrorKind::UnknownIdentifier, input: "xyz", line_number: 0, span_start: 0, @@ -372,7 +666,7 @@ fn test_parse_error() { r#" Filter parsing error (1:1): xyz - ^^^ unknown field + ^^^ unknown identifier "# ) ); @@ -383,7 +677,7 @@ fn test_parse_error() { assert_eq!( err, ParseError { - kind: LexErrorKind::UnknownField(UnknownFieldError), + kind: LexErrorKind::UnknownIdentifier, input: " xyz", line_number: 2, span_start: 4, @@ -396,7 +690,7 @@ fn test_parse_error() { r#" Filter parsing error (3:5): xyz - ^^^ unknown field + ^^^ unknown identifier "# ) ); @@ -433,6 +727,533 @@ fn test_parse_error() { ) ); } + + { + let err = scheme + .parse(indoc!( + r#" + arr and arr + "# + )) + .unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::TypeMismatch(TypeMismatchError { + expected: Type::Bool.into(), + actual: Type::Array(Type::Bool.into()), + }), + input: "arr and arr", + line_number: 0, + span_start: 11, + span_len: 0, + } + ); + assert_eq!( + err.to_string(), + indoc!( + r#" + Filter parsing error (1:12): + arr and arr + ^ expected value of type {Type(Bool)}, but got Array(Bool) + "# + ) + ); + } + + { + let err = scheme.parse_value(indoc!(r" arr[*] ")).unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::TypeMismatch(TypeMismatchError { + expected: Type::Bool.into(), + actual: Type::Array(Type::Bool.into()), + }), + input: " arr[*] ", + line_number: 0, + span_start: 1, + span_len: 6, + } + ); + assert_eq!( + err.to_string(), + indoc!( + r#" + Filter parsing error (1:2): + arr[*] + ^^^^^^ expected value of type {Type(Bool)}, but got Array(Bool) + "# + ) + ); + } +} + +#[test] +fn test_parse_error_in_op() { + use cidr::errors::NetworkParseError; + use indoc::indoc; + use std::{net::IpAddr, str::FromStr}; + + let scheme = &Scheme! 
{ + num: Int, + bool: Bool, + str: Bytes, + ip: Ip, + str_arr: Array(Bytes), + str_map: Map(Bytes), + }; + + { + let err = scheme.parse("bool in {0}").unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::EOF, + input: "bool in {0}", + line_number: 0, + span_start: 4, + span_len: 7 + } + ); + assert_eq!( + err.to_string(), + indoc!( + r#" + Filter parsing error (1:5): + bool in {0} + ^^^^^^^ unrecognised input + "# + ) + ); + } + + { + let err = scheme.parse("bool in {127.0.0.1}").unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::EOF, + input: "bool in {127.0.0.1}", + line_number: 0, + span_start: 4, + span_len: 15 + } + ); + assert_eq!( + err.to_string(), + indoc!( + r#" + Filter parsing error (1:5): + bool in {127.0.0.1} + ^^^^^^^^^^^^^^^ unrecognised input + "# + ) + ); + } + + { + let err = scheme.parse("bool in {\"test\"}").unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::EOF, + input: "bool in {\"test\"}", + line_number: 0, + span_start: 4, + span_len: 12 + } + ); + assert_eq!( + err.to_string(), + indoc!( + r#" + Filter parsing error (1:5): + bool in {"test"} + ^^^^^^^^^^^^ unrecognised input + "# + ) + ); + } + + { + let err = scheme.parse("num in {127.0.0.1}").unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::ExpectedName("digit"), + input: "num in {127.0.0.1}", + line_number: 0, + span_start: 11, + span_len: 7 + } + ); + assert_eq!( + err.to_string(), + indoc!( + r#" + Filter parsing error (1:12): + num in {127.0.0.1} + ^^^^^^^ expected digit + "# + ) + ); + } + + { + let err = scheme.parse("num in {\"test\"}").unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::ExpectedName("digit"), + input: "num in {\"test\"}", + line_number: 0, + span_start: 8, + span_len: 7 + } + ); + assert_eq!( + err.to_string(), + indoc!( + r#" + Filter parsing error (1:9): + num in {"test"} + ^^^^^^^ expected digit + "# + ) + ); + } + { + let err = scheme.parse("ip in {666}").unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::ParseNetwork( + IpAddr::from_str("666") + .map_err(NetworkParseError::AddrParseError) + .unwrap_err() + ), + input: "ip in {666}", + line_number: 0, + span_start: 7, + span_len: 3 + } + ); + assert_eq!( + err.to_string(), + indoc!( + r#" + Filter parsing error (1:8): + ip in {666} + ^^^ couldn't parse address in network: invalid IP address syntax + "# + ) + ); + } + { + let err = scheme.parse("ip in {\"test\"}").unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::ExpectedName("IP address character"), + input: "ip in {\"test\"}", + line_number: 0, + span_start: 7, + span_len: 7 + } + ); + assert_eq!( + err.to_string(), + indoc!( + r#" + Filter parsing error (1:8): + ip in {"test"} + ^^^^^^^ expected IP address character + "# + ) + ); + } + + { + let err = scheme.parse("str in {0}").unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::ParseInt { + err: u8::from_str_radix("0}", 16).unwrap_err(), + radix: 16, + }, + input: "str in {0}", + line_number: 0, + span_start: 8, + span_len: 2 + } + ); + assert_eq!( + err.to_string(), + indoc!( + r#" + Filter parsing error (1:9): + str in {0} + ^^ invalid digit found in string while parsing with radix 16 + "# + ) + ); + } + + { + let err = scheme.parse("str in {127.0.0.1}").unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::ExpectedName("byte separator"), + input: "str in {127.0.0.1}", + line_number: 0, + span_start: 10, + span_len: 1 + } + ); + assert_eq!( + 
err.to_string(), + indoc!( + r#" + Filter parsing error (1:11): + str in {127.0.0.1} + ^ expected byte separator + "# + ) + ); + } + + for pattern in &["0", "127.0.0.1", "\"test\""] { + { + let filter = format!("str_arr in {{{pattern}}}"); + let err = scheme.parse(&filter).unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::UnsupportedOp { + lhs_type: Type::Array(Type::Bytes.into()) + }, + input: &filter, + line_number: 0, + span_start: 8, + span_len: 2 + } + ); + } + + { + let filter = format!("str_map in {{{pattern}}}"); + let err = scheme.parse(&filter).unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::UnsupportedOp { + lhs_type: Type::Map(Type::Bytes.into()) + }, + input: &filter, + line_number: 0, + span_start: 8, + span_len: 2 + } + ); + } + } +} + +#[test] +fn test_parse_error_ordering_op() { + let scheme = &Scheme! { + num: Int, + bool: Bool, + str: Bytes, + ip: Ip, + str_arr: Array(Bytes), + str_map: Map(Bytes), + }; + + for op in &["eq", "ne", "ge", "le", "gt", "lt"] { + { + let filter = format!("num {op} 127.0.0.1"); + let err = scheme.parse(&filter).unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::EOF, + input: &filter, + line_number: 0, + span_start: 10, + span_len: 6 + } + ); + } + + { + let filter = format!("num {op} \"test\""); + let err = scheme.parse(&filter).unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::ExpectedName("digit"), + input: &filter, + line_number: 0, + span_start: 7, + span_len: 6 + } + ); + } + { + let filter = format!("str {op} 0"); + let err = scheme.parse(&filter).unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::CountMismatch { + name: "character", + actual: 1, + expected: 2, + }, + input: &filter, + line_number: 0, + span_start: 7, + span_len: 1 + } + ); + } + + { + let filter = format!("str {op} 256"); + let err = scheme.parse(&filter).unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::ExpectedName("byte separator"), + input: &filter, + line_number: 0, + span_start: 9, + span_len: 1 + } + ); + } + + { + let filter = format!("str {op} 127.0.0.1"); + let err = scheme.parse(&filter).unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::ExpectedName("byte separator"), + input: &filter, + line_number: 0, + span_start: 9, + span_len: 1, + } + ); + } + + { + let filter = format!("str_arr {op} 0"); + let err = scheme.parse(&filter).unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::UnsupportedOp { + lhs_type: Type::Array(Type::Bytes.into()) + }, + input: &filter, + line_number: 0, + span_start: 8, + span_len: 2 + } + ); + } + + { + let filter = format!("str_arr {op} \"test\""); + let err = scheme.parse(&filter).unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::UnsupportedOp { + lhs_type: Type::Array(Type::Bytes.into()) + }, + input: &filter, + line_number: 0, + span_start: 8, + span_len: 2 + } + ); + } + + { + let filter = format!("str_arr {op} 127.0.0.1"); + let err = scheme.parse(&filter).unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::UnsupportedOp { + lhs_type: Type::Array(Type::Bytes.into()) + }, + input: &filter, + line_number: 0, + span_start: 8, + span_len: 2 + } + ); + } + + { + let filter = format!("str_map {op} 0"); + let err = scheme.parse(&filter).unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::UnsupportedOp { + lhs_type: Type::Map(Type::Bytes.into()) + }, + input: &filter, + line_number: 0, + span_start: 8, + 
span_len: 2 + } + ); + } + + { + let filter = format!("str_map {op} \"test\""); + let err = scheme.parse(&filter).unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::UnsupportedOp { + lhs_type: Type::Map(Type::Bytes.into()) + }, + input: &filter, + line_number: 0, + span_start: 8, + span_len: 2 + } + ); + } + + { + let filter = format!("str_map {op} 127.0.0.1"); + let err = scheme.parse(&filter).unwrap_err(); + assert_eq!( + err, + ParseError { + kind: LexErrorKind::UnsupportedOp { + lhs_type: Type::Map(Type::Bytes.into()) + }, + input: &filter, + line_number: 0, + span_start: 8, + span_len: 2 + } + ); + } + } } #[test] @@ -441,23 +1262,24 @@ fn test_field() { x: Bytes, x.y.z0: Int, is_TCP: Bool, + map: Map(Bytes) }; assert_ok!( Field::lex_with("x;", scheme), - scheme.get_field_index("x").unwrap(), + scheme.get_field("x").unwrap(), ";" ); assert_ok!( Field::lex_with("x.y.z0-", scheme), - scheme.get_field_index("x.y.z0").unwrap(), + scheme.get_field("x.y.z0").unwrap(), "-" ); assert_ok!( Field::lex_with("is_TCP", scheme), - scheme.get_field_index("is_TCP").unwrap(), + scheme.get_field("is_TCP").unwrap(), "" ); @@ -491,7 +1313,45 @@ fn test_field_type_override() { let mut scheme = Scheme! { foo: Int }; assert_eq!( - scheme.add_field("foo".into(), Type::Bytes).unwrap_err(), - ItemRedefinitionError::Field(FieldRedefinitionError("foo".into())) + scheme.add_field("foo", Type::Bytes).unwrap_err(), + IdentifierRedefinitionError::Field(FieldRedefinitionError("foo".into())) ) } + +#[test] +fn test_field_lex_indexes() { + assert_ok!(FieldIndex::lex("0"), FieldIndex::ArrayIndex(0)); + assert_err!( + FieldIndex::lex("-1"), + LexErrorKind::ExpectedLiteral("expected positive integer as index"), + "-1" + ); + + assert_ok!( + FieldIndex::lex("\"cookies\""), + FieldIndex::MapKey("cookies".into()) + ); +} + +#[test] +fn test_scheme_iter_fields() { + let scheme = &Scheme! { + x: Bytes, + x.y.z0: Int, + is_TCP: Bool, + map: Map(Bytes) + }; + + let mut fields = scheme.fields().collect::>(); + fields.sort_by(|f1, f2| f1.name().partial_cmp(f2.name()).unwrap()); + + assert_eq!( + fields, + vec![ + scheme.get_field("is_TCP").unwrap(), + scheme.get_field("map").unwrap(), + scheme.get_field("x").unwrap(), + scheme.get_field("x.y.z0").unwrap(), + ] + ); +} diff --git a/engine/src/searcher.rs b/engine/src/searcher.rs new file mode 100644 index 00000000..6457471b --- /dev/null +++ b/engine/src/searcher.rs @@ -0,0 +1,59 @@ +use memmem::Searcher; +use std::mem::ManuallyDrop; + +pub struct EmptySearcher; + +impl EmptySearcher { + #[inline] + pub fn search_in(&self, _haystack: &[u8]) -> bool { + true + } +} + +pub struct TwoWaySearcher { + // This is an `Box` whose lifetime must exceed `searcher`. + needle: *mut [u8], + + // We need this because `memmem::TwoWaySearcher` wants a lifetime for the data it refers to, but + // we don't want to tie it to the lifetime of `TwoWaySearcher`, since our data is heap-allocated + // and is guaranteed to deref to the same address across moves of the container. Hence, we use + // `static` as a substitute lifetime and it points to the same the data as `needle`. + searcher: ManuallyDrop>, +} + +// This is safe because we are only ever accessing `needle` mutably during `Drop::drop` +// which is statically enforced by the compiler to be called once when the searcher is +// not in used anymore. 
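The new `searcher.rs` stores the needle and a `memmem::TwoWaySearcher` that borrows from it in the same struct by laundering the borrow to `'static` and controlling drop order manually; the `Send`/`Sync` impls below lean on the fact that the raw pointer is only touched mutably in `Drop`. A stripped-down sketch of the same ownership trick, independent of `memmem`:

```rust
use std::mem::ManuallyDrop;

/// Owns a heap buffer together with a view that (secretly) borrows from it.
struct OwnedView {
    buf: *mut [u8],
    // Derived from `buf`; must be dropped before the buffer is freed.
    view: ManuallyDrop<&'static [u8]>,
}

impl OwnedView {
    fn new(data: Box<[u8]>) -> Self {
        let buf = Box::into_raw(data);
        // The allocation outlives `view` because our Drop frees it last.
        let view = unsafe { &*buf };
        OwnedView { buf, view: ManuallyDrop::new(view) }
    }

    fn len(&self) -> usize {
        self.view.len()
    }
}

impl Drop for OwnedView {
    fn drop(&mut self) {
        unsafe {
            // Drop the borrowing half first, then free the buffer it pointed at.
            ManuallyDrop::drop(&mut self.view);
            drop(Box::from_raw(self.buf));
        }
    }
}

fn main() {
    let v = OwnedView::new(b"needle".to_vec().into_boxed_slice());
    assert_eq!(v.len(), 6);
}
```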
+unsafe impl Send for TwoWaySearcher {} +// This is safe because we are only ever accessing `needle` mutably during `Drop::drop` +// which is statically enforced by the compiler to be called once when the searcher is +// not in used anymore. +unsafe impl Sync for TwoWaySearcher {} + +impl TwoWaySearcher { + pub fn new(needle: Box<[u8]>) -> Self { + let needle = Box::into_raw(needle); + // Convert needle's contents to the static lifetime. + let needle_static = unsafe { &*needle }; + + TwoWaySearcher { + needle, + searcher: ManuallyDrop::new(memmem::TwoWaySearcher::new(needle_static)), + } + } + + #[inline] + pub fn search_in(&self, haystack: &[u8]) -> bool { + self.searcher.search_in(haystack).is_some() + } +} + +impl Drop for TwoWaySearcher { + fn drop(&mut self) { + unsafe { + // Explicitly drop `searcher` first in case it needs `needle` to be alive. + ManuallyDrop::drop(&mut self.searcher); + drop(Box::from_raw(self.needle)); + } + } +} diff --git a/engine/src/strict_partial_ord.rs b/engine/src/strict_partial_ord.rs index ff6ffefe..c449e993 100644 --- a/engine/src/strict_partial_ord.rs +++ b/engine/src/strict_partial_ord.rs @@ -3,6 +3,7 @@ use std::cmp::Ordering; /// Strict version of PartialOrd that can define different enum items as /// incomparable. pub trait StrictPartialOrd: PartialOrd { + #[inline] fn strict_partial_cmp(&self, other: &Rhs) -> Option { self.partial_cmp(other) } diff --git a/engine/src/types.rs b/engine/src/types.rs index ee9ad085..d73ae167 100644 --- a/engine/src/types.rs +++ b/engine/src/types.rs @@ -1,18 +1,22 @@ use crate::{ lex::{expect, skip_space, Lex, LexResult, LexWith}, - rhs_types::{Bytes, IpRange, UninhabitedBool}, + lhs_types::{Array, ArrayIterator, ArrayMut, Map, MapIter, MapMut, MapValuesIntoIter}, + rhs_types::{Bytes, IntRange, IpRange, UninhabitedArray, UninhabitedBool, UninhabitedMap}, + scheme::{FieldIndex, IndexAccessError}, strict_partial_ord::StrictPartialOrd, }; -use failure::Fail; -use serde::{Deserialize, Serialize}; +use serde::de::{DeserializeSeed, Deserializer}; +use serde::{Deserialize, Serialize, Serializer}; use std::{ borrow::Cow, cmp::Ordering, + collections::BTreeSet, convert::TryFrom, fmt::{self, Debug, Formatter}, - net::IpAddr, - ops::RangeInclusive, + iter::once, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, }; +use thiserror::Error; fn lex_rhs_values<'i, T: Lex<'i>>(input: &'i str) -> LexResult<'i, Vec> { let mut input = expect(input, "{")?; @@ -30,20 +34,138 @@ fn lex_rhs_values<'i, T: Lex<'i>>(input: &'i str) -> LexResult<'i, Vec> { } } +/// An enum describing the expected type when a +/// TypeMismatchError occurs +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum ExpectedType { + /// Fully identified expected type + Type(Type), + /// Loosely identified array type + /// Usefull when expecting an array without + /// knowing of which specific value type + Array, + /// Loosely identified map type + /// Usefull when expecting a map without + /// knowing of which specific value type + Map, +} + +impl From for ExpectedType { + fn from(ty: Type) -> Self { + ExpectedType::Type(ty) + } +} + +/// A list of expected types. +#[derive(Default, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct ExpectedTypeList(BTreeSet); + +impl ExpectedTypeList { + /// Insert an expected type in the list. 
+ pub fn insert(&mut self, ty: impl Into) { + self.0.insert(ty.into()); + } +} + +impl fmt::Debug for ExpectedTypeList { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&self.0, f) + } +} + +impl From for ExpectedTypeList { + #[inline] + fn from(ty: Type) -> Self { + Self(once(ExpectedType::Type(ty)).collect()) + } +} + +impl From for ExpectedTypeList { + #[inline] + fn from(ty: ExpectedType) -> Self { + Self(once(ty).collect()) + } +} + +impl> From for ExpectedTypeList { + #[inline] + fn from(tys: T) -> Self { + Self(tys.collect()) + } +} + +impl std::fmt::Display for ExpectedTypeList { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + /// An error that occurs on a type mismatch. -#[derive(Debug, PartialEq, Fail)] -#[fail( - display = "expected value of type {:?}, but got {:?}", - expected, actual -)] +#[derive(Debug, PartialEq, Eq, Error)] +#[error("expected value of type {expected}, but got {actual}")] pub struct TypeMismatchError { /// Expected value type. - pub expected: Type, + pub expected: ExpectedTypeList, /// Provided value type. pub actual: Type, } +/// An error that occurs on a type mismatch. +#[derive(Debug, PartialEq, Eq, Error)] +pub enum SetValueError { + #[error("{0}")] + TypeMismatch(#[source] TypeMismatchError), + #[error("{0}")] + IndexAccess(#[source] IndexAccessError), +} + +macro_rules! replace_underscore { + ($name:ident ($val_ty:ty)) => { + Type::$name(_) + }; + ($name:ident) => { + Type::$name + }; +} + +macro_rules! specialized_get_type { + (Array, $value:ident) => { + $value.get_type() + }; + (Map, $value:ident) => { + $value.get_type() + }; + ($name:ident, $value:ident) => { + Type::$name + }; +} + +macro_rules! specialized_try_from { + (Array) => { + ExpectedType::Array + }; + (Map) => { + ExpectedType::Map + }; + ($name:ident) => { + ExpectedType::Type(Type::$name) + }; +} + +// This macro generates `Type`, `LhsValue`, `RhsValue`, `RhsValues`. +// +// Before the parenthesis is the variant for the `Type` enum (`Type::Ip`). +// First argument is the corresponding `LhsValue` variant (`LhsValue::Ip(IpAddr)`). +// Second argument is the corresponding `RhsValue` variant (`RhsValue::Ip(IpAddr)`). +// Third argument is the corresponding `RhsValues` variant (`RhsValues::Ip(Vec)`) for the curly bracket syntax. eg `num in {1, 5}` +// +// ``` +// declare_types! { +// Ip(IpAddr | IpAddr | IpRange), +// } +// ``` macro_rules! declare_types { + // This is just to be used by the other arm. ($(# $attrs:tt)* enum $name:ident $(<$lt:tt>)* { $($(# $vattrs:tt)* $variant:ident ( $ty:ty ) , )* }) => { $(# $attrs)* #[repr(u8)] @@ -54,7 +176,7 @@ macro_rules! declare_types { impl $(<$lt>)* GetType for $name $(<$lt>)* { fn get_type(&self) -> Type { match self { - $($name::$variant(_) => Type::$variant,)* + $($name::$variant(_value) => specialized_get_type!($variant, _value),)* } } } @@ -68,24 +190,12 @@ macro_rules! declare_types { } }; - ($($(# $attrs:tt)* $name:ident ( $(# $lhs_attrs:tt)* $lhs_ty:ty | $rhs_ty:ty | $multi_rhs_ty:ty ) , )*) => { + // This is the entry point for the macro. + ($($(# $attrs:tt)* $name:ident $([$val_ty:ty])? ( $(# $lhs_attrs:tt)* $lhs_ty:ty | $rhs_ty:ty | $multi_rhs_ty:ty ) , )*) => { /// Enumeration of supported types for field values. 
- #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] - #[repr(C)] + #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize, Hash, PartialOrd, Ord)] pub enum Type { - $($(# $attrs)* $name,)* - } - - /// Provides a way to get a [`Type`] of the implementor. - pub trait GetType { - /// Returns a type. - fn get_type(&self) -> Type; - } - - impl GetType for Type { - fn get_type(&self) -> Type { - *self - } + $($(# $attrs)* $name$(($val_ty))?,)* } declare_types! { @@ -94,27 +204,35 @@ macro_rules! declare_types { /// These are passed to the [execution context](::ExecutionContext) /// and are used by [filters](::Filter) /// for execution and comparisons. - #[derive(PartialEq, Eq, Clone, Deserialize)] + #[derive(PartialEq, Eq, Clone, Deserialize, Hash)] #[serde(untagged)] enum LhsValue<'a> { $($(# $attrs)* $(# $lhs_attrs)* $name($lhs_ty),)* } } - $(impl<'a> From<$lhs_ty> for LhsValue<'a> { - fn from(value: $lhs_ty) -> Self { - LhsValue::$name(value) + $(impl<'a> TryFrom> for $lhs_ty { + type Error = TypeMismatchError; + + fn try_from(value: LhsValue<'a>) -> Result<$lhs_ty, TypeMismatchError> { + match value { + LhsValue::$name(value) => Ok(value), + _ => Err(TypeMismatchError { + expected: specialized_try_from!($name).into(), + actual: value.get_type(), + }), + } } })* - $(impl<'a> TryFrom> for $lhs_ty { + $(impl<'a> TryFrom<&'a LhsValue<'a>> for &'a $lhs_ty { type Error = TypeMismatchError; - fn try_from(value: LhsValue<'a>) -> Result<$lhs_ty, TypeMismatchError> { + fn try_from(value: &'a LhsValue<'a>) -> Result<&'a $lhs_ty, TypeMismatchError> { match value { LhsValue::$name(value) => Ok(value), _ => Err(TypeMismatchError { - expected: Type::$name, + expected: specialized_try_from!($name).into(), actual: value.get_type(), }), } @@ -123,7 +241,7 @@ macro_rules! declare_types { declare_types! { /// An RHS value parsed from a filter string. - #[derive(PartialEq, Eq, Clone, Serialize)] + #[derive(PartialEq, Eq, Clone, Hash, Serialize)] #[serde(untagged)] enum RhsValue { $($(# $attrs)* $name($rhs_ty),)* @@ -133,7 +251,7 @@ macro_rules! declare_types { impl<'i> LexWith<'i, Type> for RhsValue { fn lex_with(input: &str, ty: Type) -> LexResult<'_, Self> { Ok(match ty { - $(Type::$name => { + $(replace_underscore!($name $(($val_ty))?) => { let (value, input) = <$rhs_ty>::lex(input)?; (RhsValue::$name(value), input) })* @@ -152,30 +270,99 @@ macro_rules! declare_types { } } - impl<'a> StrictPartialOrd for LhsValue<'a> {} + $(impl<'a> TryFrom for $rhs_ty { + type Error = TypeMismatchError; - impl<'a> PartialEq for LhsValue<'a> { - fn eq(&self, other: &RhsValue) -> bool { - self.strict_partial_cmp(other) == Some(Ordering::Equal) + fn try_from(value: RhsValue) -> Result<$rhs_ty, TypeMismatchError> { + match value { + RhsValue::$name(value) => Ok(value), + _ => Err(TypeMismatchError { + expected: specialized_try_from!($name).into(), + actual: value.get_type(), + }), + } } - } + })* + + $(impl<'a> TryFrom<&'a RhsValue> for &'a $rhs_ty { + type Error = TypeMismatchError; + + fn try_from(value: &'a RhsValue) -> Result<&'a $rhs_ty, TypeMismatchError> { + match value { + RhsValue::$name(value) => Ok(value), + _ => Err(TypeMismatchError { + expected: specialized_try_from!($name).into(), + actual: value.get_type(), + }), + } + } + })* declare_types! { /// A typed group of a list of values. /// /// This is used for `field in { ... }` operation that allows /// only same-typed values in a list. 
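As the macro comment explains, `declare_types!` generates `Type`, `LhsValue`, `RhsValue` and `RhsValues` in one go, including fallible conversions from a value back to its underlying Rust type. A hedged sketch of the generated `TryFrom` impls (assuming `LhsValue` and `GetType` are re-exported at the crate root and that the `Bool` variant wraps a plain `bool`):

```rust
use std::convert::TryFrom;
use wirefilter::{GetType, LhsValue, Type};

fn main() {
    let value = LhsValue::Bool(true);
    assert_eq!(value.get_type(), Type::Bool);

    // The generated impls recover the underlying Rust value...
    assert!(bool::try_from(value).unwrap());

    // ...and report a TypeMismatchError when the variant doesn't match.
    assert!(bool::try_from(LhsValue::Int(1)).is_err());
}
```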
- #[derive(PartialEq, Eq, Clone, Serialize)] + #[derive(PartialEq, Eq, Clone, Hash, Serialize)] #[serde(untagged)] enum RhsValues { $($(# $attrs)* $name(Vec<$multi_rhs_ty>),)* } } + impl From for RhsValues { + fn from(rhs: RhsValue) -> Self { + match rhs { + $(RhsValue::$name(rhs) => RhsValues::$name(vec![rhs.into()]),)* + } + } + } + + impl RhsValues { + /// Appends a value to the back of the collection. + pub fn push(&mut self, rhs: RhsValue) -> Result<(), TypeMismatchError> { + match self { + $(RhsValues::$name(vec) => match rhs { + RhsValue::$name(rhs) => Ok(vec.push(rhs.into())), + _ => Err(TypeMismatchError { + expected: self.get_type().into(), + actual: rhs.get_type(), + }), + },)* + } + } + + /// Moves all the values of `other` into `self`, leaving `other` empty. + pub fn append(&mut self, other: &mut Self) -> Result<(), TypeMismatchError> { + match self { + $(RhsValues::$name(vec) => match other { + RhsValues::$name(other) => Ok(vec.append(other)), + _ => Err(TypeMismatchError { + expected: self.get_type().into(), + actual: other.get_type(), + }), + },)* + } + } + + /// Extends the collection with the values of another collection. + pub fn extend(&mut self, other: Self) -> Result<(), TypeMismatchError> { + match self { + $(RhsValues::$name(vec) => match other { + RhsValues::$name(other) => Ok(vec.extend(other)), + _ => Err(TypeMismatchError { + expected: self.get_type().into(), + actual: other.get_type(), + }), + },)* + } + } + } + impl<'i> LexWith<'i, Type> for RhsValues { fn lex_with(input: &str, ty: Type) -> LexResult<'_, Self> { Ok(match ty { - $(Type::$name => { + $(replace_underscore!($name $(($val_ty))?) => { let (value, input) = lex_rhs_values(input)?; (RhsValues::$name(value), input) })* @@ -185,33 +372,333 @@ macro_rules! declare_types { }; } -// special cases for simply passing owned and borrowed bytes -impl<'a> From<&'a [u8]> for LhsValue<'a> { +impl Type { + /// Returns the inner type when available (e.g: for a Map) + pub fn next(&self) -> Option { + match self { + Type::Array(ty) => Some((*ty).into()), + Type::Map(ty) => Some((*ty).into()), + _ => None, + } + } + + /// Creates a new array type. + pub fn array(ty: impl Into) -> Self { + Self::Array(ty.into()) + } + + /// Creates a new map type. + pub fn map(ty: impl Into) -> Self { + Self::Map(ty.into()) + } +} + +impl std::fmt::Display for Type { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Bool => write!(f, "Bool"), + Self::Bytes => write!(f, "Bytes"), + Self::Int => write!(f, "Int"), + Self::Ip => write!(f, "Ip"), + Self::Array(ty) => write!(f, "Array({})", Type::from(*ty)), + Self::Map(ty) => write!(f, "Map({})", Type::from(*ty)), + } + } +} + +/// Provides a way to get a [`Type`] of the implementor. +pub trait GetType { + /// Returns a type. 
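// The `array`/`map` constructors above compose with `From<Type> for CompoundType`, and
// `Display` spells out the same nesting that the serde tests below round-trip as JSON.
// A sketch:
//
//   let ty = Type::map(Type::array(Type::Bytes));
//   assert_eq!(ty.to_string(), "Map(Array(Bytes))");
//   assert_eq!(ty.next(), Some(Type::array(Type::Bytes)));   // assuming `next` yields a `Type`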
+ fn get_type(&self) -> Type; +} + +impl GetType for Type { + fn get_type(&self) -> Type { + *self + } +} + +impl GetType for CompoundType { + fn get_type(&self) -> Type { + (*self).into() + } +} + +impl<'a> StrictPartialOrd for LhsValue<'a> {} + +impl<'a> PartialEq for LhsValue<'a> { + fn eq(&self, other: &RhsValue) -> bool { + self.strict_partial_cmp(other) == Some(Ordering::Equal) + } +} + +#[derive(Deserialize)] +#[serde(untagged)] +pub enum BytesOrString<'a> { + BorrowedBytes(#[serde(borrow)] &'a [u8]), + OwnedBytes(Vec), + BorrowedString(#[serde(borrow)] &'a str), + OwnedString(String), +} + +impl<'a> BytesOrString<'a> { + pub fn into_bytes(self) -> Cow<'a, [u8]> { + match self { + BytesOrString::BorrowedBytes(slice) => (*slice).into(), + BytesOrString::OwnedBytes(vec) => vec.into(), + BytesOrString::BorrowedString(str) => str.as_bytes().into(), + BytesOrString::OwnedString(str) => str.into_bytes().into(), + } + } +} + +mod private { + use super::IntoValue; + use crate::TypedArray; + use std::borrow::Cow; + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + + pub trait SealedIntoValue {} + + impl SealedIntoValue for bool {} + + impl SealedIntoValue for i8 {} + impl SealedIntoValue for u8 {} + impl SealedIntoValue for i16 {} + impl SealedIntoValue for u16 {} + impl SealedIntoValue for i32 {} + impl SealedIntoValue for u32 {} + impl SealedIntoValue for i64 {} + + impl SealedIntoValue for &[u8] {} + impl SealedIntoValue for Box<[u8]> {} + impl SealedIntoValue for Vec {} + impl SealedIntoValue for Cow<'_, [u8]> {} + impl SealedIntoValue for &str {} + impl SealedIntoValue for Box {} + impl SealedIntoValue for String {} + impl SealedIntoValue for Cow<'_, str> {} + + impl SealedIntoValue for IpAddr {} + impl SealedIntoValue for Ipv4Addr {} + impl SealedIntoValue for Ipv6Addr {} + + impl<'a, V: IntoValue<'a>> SealedIntoValue for TypedArray<'a, V> {} +} + +/// Converts a value into an `LhsValue`. +/// It is a stronger version of `Into>` in that +/// any value of the input type will *always* convert to the +/// same statically known `LhsValue` variant. 
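// `BytesOrString` above is what lets a `Bytes` field deserialize from either a JSON
// string or a JSON array of numbers; both collapse into the same byte representation.
// A sketch, assuming serde_json and the untagged variant order above:
//
//   let s: BytesOrString<'_> = serde_json::from_str(r#""hi""#).unwrap();
//   assert_eq!(&*s.into_bytes(), b"hi");
//   let b: BytesOrString<'_> = serde_json::from_str("[104, 105]").unwrap();
//   assert_eq!(&*b.into_bytes(), b"hi");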
+pub trait IntoValue<'a>: private::SealedIntoValue { + const TYPE: Type; + + fn into_value(self) -> LhsValue<'a>; +} + +impl<'a> IntoValue<'a> for bool { + const TYPE: Type = Type::Bool; + + #[inline] + fn into_value(self) -> LhsValue<'a> { + LhsValue::Bool(self) + } +} + +impl<'a> IntoValue<'a> for i64 { + const TYPE: Type = Type::Int; + + #[inline] + fn into_value(self) -> LhsValue<'a> { + LhsValue::Int(self) + } +} + +impl<'a> IntoValue<'a> for i32 { + const TYPE: Type = Type::Int; + + #[inline] + fn into_value(self) -> LhsValue<'a> { + LhsValue::Int(i64::from(self)) + } +} + +impl<'a> IntoValue<'a> for i16 { + const TYPE: Type = Type::Int; + + #[inline] + fn into_value(self) -> LhsValue<'a> { + LhsValue::Int(i64::from(self)) + } +} + +impl<'a> IntoValue<'a> for u16 { + const TYPE: Type = Type::Int; + + #[inline] + fn into_value(self) -> LhsValue<'a> { + LhsValue::Int(i64::from(self)) + } +} + +impl<'a> IntoValue<'a> for i8 { + const TYPE: Type = Type::Int; + #[inline] - fn from(b: &'a [u8]) -> Self { - LhsValue::Bytes(Cow::Borrowed(b)) + fn into_value(self) -> LhsValue<'a> { + LhsValue::Int(i64::from(self)) } } -impl From> for LhsValue<'static> { +impl<'a> IntoValue<'a> for u8 { + const TYPE: Type = Type::Int; + #[inline] - fn from(b: Vec) -> Self { - LhsValue::Bytes(Cow::Owned(b)) + fn into_value(self) -> LhsValue<'a> { + LhsValue::Int(i64::from(self)) } } -// special cases for simply passing strings and string slices -impl<'a> From<&'a str> for LhsValue<'a> { +impl<'a> IntoValue<'a> for &'a [u8] { + const TYPE: Type = Type::Bytes; + #[inline] - fn from(s: &'a str) -> Self { - s.as_bytes().into() + fn into_value(self) -> LhsValue<'a> { + LhsValue::Bytes(Cow::Borrowed(self)) } } -impl From for LhsValue<'static> { +impl<'a> IntoValue<'a> for Box<[u8]> { + const TYPE: Type = Type::Bytes; + #[inline] - fn from(s: String) -> Self { - s.into_bytes().into() + fn into_value(self) -> LhsValue<'a> { + LhsValue::Bytes(Cow::Owned(Vec::from(self))) + } +} + +impl<'a> IntoValue<'a> for Vec { + const TYPE: Type = Type::Bytes; + + #[inline] + fn into_value(self) -> LhsValue<'a> { + LhsValue::Bytes(Cow::Owned(self)) + } +} + +impl<'a> IntoValue<'a> for Cow<'a, [u8]> { + const TYPE: Type = Type::Bytes; + + #[inline] + fn into_value(self) -> LhsValue<'a> { + LhsValue::Bytes(self) + } +} + +impl<'a> IntoValue<'a> for &'a str { + const TYPE: Type = Type::Bytes; + + #[inline] + fn into_value(self) -> LhsValue<'a> { + LhsValue::Bytes(Cow::Borrowed(self.as_bytes())) + } +} + +impl<'a> IntoValue<'a> for Box { + const TYPE: Type = Type::Bytes; + + #[inline] + fn into_value(self) -> LhsValue<'a> { + LhsValue::Bytes(Cow::Owned(Vec::from(Box::<[u8]>::from(self)))) + } +} + +impl<'a> IntoValue<'a> for String { + const TYPE: Type = Type::Bytes; + + #[inline] + fn into_value(self) -> LhsValue<'a> { + LhsValue::Bytes(Cow::Owned(self.into_bytes())) + } +} + +impl<'a> IntoValue<'a> for Cow<'a, str> { + const TYPE: Type = Type::Bytes; + + #[inline] + fn into_value(self) -> LhsValue<'a> { + LhsValue::Bytes(match self { + Cow::Borrowed(slice) => Cow::Borrowed(slice.as_bytes()), + Cow::Owned(vec) => Cow::Owned(vec.into()), + }) + } +} + +impl<'a> IntoValue<'a> for IpAddr { + const TYPE: Type = Type::Ip; + + #[inline] + fn into_value(self) -> LhsValue<'a> { + LhsValue::Ip(self) + } +} + +impl<'a> IntoValue<'a> for Ipv4Addr { + const TYPE: Type = Type::Ip; + + #[inline] + fn into_value(self) -> LhsValue<'a> { + LhsValue::Ip(IpAddr::V4(self)) + } +} + +impl<'a> IntoValue<'a> for Ipv6Addr { + const TYPE: Type = Type::Ip; + + 
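// Every `IntoValue` impl pins down the resulting variant at compile time via `TYPE`,
// so conversions never need a runtime type check. A sketch using the blanket
// `From<T: IntoValue>` impl that follows below:
//
//   assert_eq!(<u16 as IntoValue<'_>>::TYPE, Type::Int);
//   assert_eq!(LhsValue::from(80u16), LhsValue::Int(80));
//   assert_eq!(LhsValue::from("GET"), LhsValue::Bytes(Cow::Borrowed(&b"GET"[..])));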
#[inline] + fn into_value(self) -> LhsValue<'a> { + LhsValue::Ip(IpAddr::V6(self)) + } +} + +impl<'a, T: IntoValue<'a>> From for LhsValue<'a> { + #[inline] + fn from(value: T) -> Self { + value.into_value() + } +} + +// Array cannot implement `IntoValue` as the +// underlying element type is not statically +// known. +impl<'a> From> for LhsValue<'a> { + #[inline] + fn from(value: Array<'a>) -> LhsValue<'a> { + LhsValue::Array(value) + } +} + +// Map cannot implement `IntoValue` as the +// underlying element type is not statically +// known. +impl<'a> From> for LhsValue<'a> { + #[inline] + fn from(value: Map<'a>) -> LhsValue<'a> { + LhsValue::Map(value) + } +} + +impl<'a> TryFrom<&'a LhsValue<'a>> for &'a [u8] { + type Error = TypeMismatchError; + + fn try_from(value: &'a LhsValue<'_>) -> Result { + match value { + LhsValue::Bytes(value) => Ok(value), + _ => Err(TypeMismatchError { + expected: Type::Bytes.into(), + actual: value.get_type(), + }), + } } } @@ -222,6 +709,21 @@ impl<'a> From<&'a RhsValue> for LhsValue<'a> { RhsValue::Bytes(bytes) => LhsValue::Bytes(Cow::Borrowed(bytes)), RhsValue::Int(integer) => LhsValue::Int(*integer), RhsValue::Bool(b) => match *b {}, + RhsValue::Array(a) => match *a {}, + RhsValue::Map(m) => match *m {}, + } + } +} + +impl<'a> From for LhsValue<'a> { + fn from(rhs_value: RhsValue) -> Self { + match rhs_value { + RhsValue::Ip(ip) => LhsValue::Ip(ip), + RhsValue::Bytes(bytes) => LhsValue::Bytes(Cow::Owned(bytes.into())), + RhsValue::Int(integer) => LhsValue::Int(integer), + RhsValue::Bool(b) => match b {}, + RhsValue::Array(a) => match a {}, + RhsValue::Map(m) => match m {}, } } } @@ -235,12 +737,377 @@ impl<'a> LhsValue<'a> { LhsValue::Bytes(bytes) => LhsValue::Bytes(Cow::Borrowed(bytes)), LhsValue::Int(integer) => LhsValue::Int(*integer), LhsValue::Bool(b) => LhsValue::Bool(*b), + LhsValue::Array(a) => LhsValue::Array(a.as_ref()), + LhsValue::Map(m) => LhsValue::Map(m.as_ref()), + } + } + + /// Converts an `LhsValue` with borrowed data to a fully owned `LhsValue`. + pub fn into_owned(self) -> LhsValue<'static> { + match self { + LhsValue::Ip(ip) => LhsValue::Ip(ip), + LhsValue::Bytes(bytes) => LhsValue::Bytes(Cow::Owned(bytes.into_owned())), + LhsValue::Int(i) => LhsValue::Int(i), + LhsValue::Bool(b) => LhsValue::Bool(b), + LhsValue::Array(arr) => LhsValue::Array(arr.into_owned()), + LhsValue::Map(map) => LhsValue::Map(map.into_owned()), + } + } + + /// Retrieve an element from an LhsValue given a path item and a specified + /// type. + /// Returns a TypeMismatchError error if current type does not support it + /// nested element. + /// + /// Both LhsValue::Array and LhsValue::Map support nested elements. 
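// A sketch of indexed access, assuming `Map::new` takes the element type and
// `Map::insert` accepts a byte-slice key (the exact signatures are in the `Map` type):
//
//   let mut map = Map::new(Type::Bytes);
//   map.insert(b"host", "example.org")?;
//   let value = LhsValue::from(map);
//   let inner = value.get(&FieldIndex::MapKey("host".into()))?;   // Some(&LhsValue::Bytes(..))
//   value.get(&FieldIndex::ArrayIndex(0)).unwrap_err();           // IndexAccessError: not an array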
+ pub fn get(&'a self, item: &FieldIndex) -> Result>, IndexAccessError> { + match (self, item) { + (LhsValue::Array(arr), FieldIndex::ArrayIndex(ref idx)) => Ok(arr.get(*idx as usize)), + (_, FieldIndex::ArrayIndex(_)) => Err(IndexAccessError { + index: item.clone(), + actual: self.get_type(), + }), + (LhsValue::Map(map), FieldIndex::MapKey(ref key)) => Ok(map.get(key.as_bytes())), + (_, FieldIndex::MapKey(_)) => Err(IndexAccessError { + index: item.clone(), + actual: self.get_type(), + }), + (_, FieldIndex::MapEach) => Err(IndexAccessError { + index: item.clone(), + actual: self.get_type(), + }), + } + } + + pub(crate) fn extract( + self, + item: &FieldIndex, + ) -> Result>, IndexAccessError> { + match item { + FieldIndex::ArrayIndex(idx) => match self { + LhsValue::Array(arr) => Ok(arr.extract(*idx as usize)), + _ => Err(IndexAccessError { + index: item.clone(), + actual: self.get_type(), + }), + }, + FieldIndex::MapKey(key) => match self { + LhsValue::Map(map) => Ok(map.extract(key.as_bytes())), + _ => Err(IndexAccessError { + index: item.clone(), + actual: self.get_type(), + }), + }, + FieldIndex::MapEach => Err(IndexAccessError { + index: item.clone(), + actual: self.get_type(), + }), + } + } + + /// Set an element in an LhsValue given a path item and a specified value. + /// Returns a TypeMismatchError error if current type does not support + /// nested element or if value type is invalid. + /// Only LhsValyue::Map supports nested elements for now. + pub fn set>>( + &mut self, + item: FieldIndex, + value: V, + ) -> Result<(), SetValueError> { + let value = value.into(); + match item { + FieldIndex::ArrayIndex(idx) => match self { + LhsValue::Array(ref mut arr) => arr + .insert(idx as usize, value) + .map_err(SetValueError::TypeMismatch), + _ => Err(SetValueError::IndexAccess(IndexAccessError { + index: item, + actual: self.get_type(), + })), + }, + FieldIndex::MapKey(name) => match self { + LhsValue::Map(ref mut map) => map + .insert(name.as_bytes(), value) + .map_err(SetValueError::TypeMismatch), + _ => Err(SetValueError::IndexAccess(IndexAccessError { + index: FieldIndex::MapKey(name), + actual: self.get_type(), + })), + }, + FieldIndex::MapEach => Err(SetValueError::IndexAccess(IndexAccessError { + index: item, + actual: self.get_type(), + })), + } + } + + /// Returns an iterator over the Map or Array + pub fn iter(&'a self) -> Option> { + match self { + LhsValue::Array(array) => Some(Iter::IterArray(array.as_slice().iter())), + LhsValue::Map(map) => Some(Iter::IterMap(map.iter())), + _ => None, + } + } +} + +impl<'a> Serialize for LhsValue<'a> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + LhsValue::Ip(ip) => ip.serialize(serializer), + LhsValue::Bytes(bytes) => { + if let Ok(s) = std::str::from_utf8(bytes) { + s.serialize(serializer) + } else { + bytes.serialize(serializer) + } + } + LhsValue::Int(num) => num.serialize(serializer), + LhsValue::Bool(b) => b.serialize(serializer), + LhsValue::Array(arr) => arr.serialize(serializer), + LhsValue::Map(map) => map.serialize(serializer), + } + } +} + +pub(crate) struct LhsValueSeed<'a>(pub &'a Type); + +impl<'de, 'a> DeserializeSeed<'de> for LhsValueSeed<'a> { + type Value = LhsValue<'de>; + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + match self.0 { + Type::Ip => Ok(LhsValue::Ip(std::net::IpAddr::deserialize(deserializer)?)), + Type::Int => Ok(LhsValue::Int(i64::deserialize(deserializer)?)), + Type::Bool => 
Ok(LhsValue::Bool(bool::deserialize(deserializer)?)), + Type::Bytes => Ok(LhsValue::Bytes( + BytesOrString::deserialize(deserializer)?.into_bytes(), + )), + Type::Array(ty) => Ok(LhsValue::Array({ + let mut arr = Array::new(*ty); + arr.deserialize(deserializer)?; + arr + })), + Type::Map(ty) => Ok(LhsValue::Map({ + let mut map = Map::new(*ty); + map.deserialize(deserializer)?; + map + })), + } + } +} + +pub enum IntoIter<'a> { + IntoArray(ArrayIterator<'a>), + IntoMap(MapValuesIntoIter<'a>), +} + +impl<'a> Iterator for IntoIter<'a> { + type Item = LhsValue<'a>; + + fn next(&mut self) -> Option> { + match self { + IntoIter::IntoArray(array) => array.next(), + IntoIter::IntoMap(map) => map.next(), + } + } + + fn size_hint(&self) -> (usize, Option) { + (self.len(), Some(self.len())) + } +} + +impl<'a> ExactSizeIterator for IntoIter<'a> { + fn len(&self) -> usize { + match self { + IntoIter::IntoArray(array) => array.len(), + IntoIter::IntoMap(map) => map.len(), + } + } +} + +impl<'a> IntoIterator for LhsValue<'a> { + type Item = LhsValue<'a>; + type IntoIter = IntoIter<'a>; + fn into_iter(self) -> Self::IntoIter { + match self { + LhsValue::Array(array) => IntoIter::IntoArray(array.into_iter()), + LhsValue::Map(map) => IntoIter::IntoMap(map.values_into_iter()), + _ => unreachable!(), + } + } +} + +pub enum Iter<'a> { + IterArray(std::slice::Iter<'a, LhsValue<'a>>), + IterMap(MapIter<'a, 'a>), +} + +impl<'a> Iterator for Iter<'a> { + type Item = &'a LhsValue<'a>; + + fn next(&mut self) -> Option<&'a LhsValue<'a>> { + match self { + Iter::IterArray(array) => array.next(), + Iter::IterMap(map) => map.next().map(|(_, v)| v), + } + } + + fn size_hint(&self) -> (usize, Option) { + (self.len(), Some(self.len())) + } +} + +impl<'a> ExactSizeIterator for Iter<'a> { + fn len(&self) -> usize { + match self { + Iter::IterArray(array) => array.len(), + Iter::IterMap(map) => map.len(), + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize, Hash, PartialOrd, Ord)] +enum PrimitiveType { + Bool, + Bytes, + Int, + Ip, +} + +#[derive(Clone, Copy, Debug)] +enum Layer { + Array, + Map, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct CompoundType { + layers: u32, + len: u8, + primitive: PrimitiveType, +} + +impl Serialize for CompoundType { + #[inline] + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + Type::from(*self).serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for CompoundType { + #[inline] + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + Type::deserialize(deserializer).map(Self::from) + } +} + +impl CompoundType { + #[inline] + const fn new(ty: PrimitiveType) -> Self { + Self { + layers: 0, + len: 0, + primitive: ty, + } + } + + #[inline] + pub(crate) const fn from_type(ty: Type) -> Option { + match ty { + Type::Bool => Some(Self::new(PrimitiveType::Bool)), + Type::Bytes => Some(Self::new(PrimitiveType::Bytes)), + Type::Int => Some(Self::new(PrimitiveType::Int)), + Type::Ip => Some(Self::new(PrimitiveType::Ip)), + Type::Array(ty) => ty.push(Layer::Array), + Type::Map(ty) => ty.push(Layer::Map), + } + } + + #[inline] + const fn pop(mut self) -> (Self, Option) { + if self.len > 0 { + // Maybe use (trailing/leading)_(ones/zeros) instead + let is_array = (self.layers & 1) == 0; + self.layers >>= 1; + self.len -= 1; + if is_array { + (self, Some(Layer::Array)) + } else { + (self, Some(Layer::Map)) + } + } else { + (self, None) + } + } + + #[inline] + const fn push(mut self, 
layer: Layer) -> Option { + if self.len >= 32 { + None + } else { + let layer = match layer { + Layer::Array => 0, + Layer::Map => 1, + }; + self.layers = (self.layers << 1) | layer; + self.len += 1; + Some(self) + } + } +} + +impl From for CompoundType { + #[inline] + fn from(ty: PrimitiveType) -> Self { + Self::new(ty) + } +} + +impl From for CompoundType { + #[inline] + fn from(ty: Type) -> Self { + Self::from_type(ty).unwrap() + } +} + +impl From for Type { + #[inline] + fn from(ty: CompoundType) -> Self { + let (ty, layer) = ty.pop(); + match layer { + Some(Layer::Array) => Type::Array(ty), + Some(Layer::Map) => Type::Map(ty), + None => match ty.primitive { + PrimitiveType::Bool => Type::Bool, + PrimitiveType::Bytes => Type::Bytes, + PrimitiveType::Int => Type::Int, + PrimitiveType::Ip => Type::Ip, + }, } } } declare_types!( - /// An IPv4 or IPv6 field. + /// A boolean. + Bool(bool | UninhabitedBool | UninhabitedBool), + + /// A 64-bit integer number. + Int(i64 | i64 | IntRange), + + /// An IPv4 or IPv6 address. /// /// These are represented as a single type to allow interop comparisons. Ip(IpAddr | IpAddr | IpRange), @@ -251,13 +1118,45 @@ declare_types!( /// syntax representation, so we represent them as a single type. Bytes(#[serde(borrow)] Cow<'a, [u8]> | Bytes | Bytes), - /// A 32-bit integer number. - Int(i32 | i32 | RangeInclusive), + /// An Array of [`Type`]. + Array[CompoundType](#[serde(skip_deserializing)] Array<'a> | UninhabitedArray | UninhabitedArray), - /// A boolean. - Bool(bool | UninhabitedBool | UninhabitedBool), + /// A Map of string to [`Type`]. + Map[CompoundType](#[serde(skip_deserializing)] Map<'a> | UninhabitedMap | UninhabitedMap), ); +/// Wrapper type around mutable `LhsValue` to prevent +/// illegal operations like changing the type of values +/// in an `Array` or a `Map`. +pub enum LhsValueMut<'a, 'b> { + /// A mutable boolean. + Bool(&'a mut bool), + /// A mutable 32-bit integer number. + Int(&'a mut i64), + /// A mutable IPv4 or IPv6 address. + Ip(&'a mut IpAddr), + /// A mutable byte string. + Bytes(&'a mut Cow<'b, [u8]>), + /// A mutable array. + Array(ArrayMut<'a, 'b>), + /// A mutable map. 
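// Worked example of the `CompoundType` packing used above: each `Array`/`Map` layer is a
// single bit in `layers` (Array = 0, Map = 1), the outermost layer sits in the least
// significant bit, and `len` counts the layers:
//
//   CompoundType::from(Type::Bytes)                          // layers: 0b00, len: 0
//   CompoundType::from(Type::array(Type::Bytes))             // layers: 0b00, len: 1
//   CompoundType::from(Type::map(Type::array(Type::Bytes)))  // layers: 0b01, len: 2
//
// `pop()` peels the low bit off again, so converting back to `Type` is lossless up to the
// 32-layer limit enforced by `push()`.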
+ Map(MapMut<'a, 'b>), +} + +impl<'a, 'b> From<&'a mut LhsValue<'b>> for LhsValueMut<'a, 'b> { + #[inline] + fn from(value: &'a mut LhsValue<'b>) -> Self { + match value { + LhsValue::Bool(b) => LhsValueMut::Bool(b), + LhsValue::Int(i) => LhsValueMut::Int(i), + LhsValue::Ip(ip) => LhsValueMut::Ip(ip), + LhsValue::Bytes(b) => LhsValueMut::Bytes(b), + LhsValue::Array(arr) => LhsValueMut::Array(arr.into()), + LhsValue::Map(map) => LhsValueMut::Map(map.into()), + } + } +} + #[test] fn test_lhs_value_deserialize() { use std::str::FromStr; @@ -291,3 +1190,84 @@ fn test_lhs_value_deserialize() { let b: LhsValue<'_> = serde_json::from_str("false").unwrap(); assert_eq!(b, LhsValue::Bool(false)); } + +#[test] +fn test_type_serialize() { + let ty = Type::Bool; + assert_eq!(serde_json::to_string(&ty).unwrap(), "\"Bool\""); + + let ty = Type::Bytes; + assert_eq!(serde_json::to_string(&ty).unwrap(), "\"Bytes\""); + + let ty = Type::Int; + assert_eq!(serde_json::to_string(&ty).unwrap(), "\"Int\""); + + let ty = Type::Ip; + assert_eq!(serde_json::to_string(&ty).unwrap(), "\"Ip\""); + + let ty = Type::Array(Type::Bytes.into()); + assert_eq!(serde_json::to_string(&ty).unwrap(), "{\"Array\":\"Bytes\"}"); + + let ty = Type::Map(Type::Bytes.into()); + assert_eq!(serde_json::to_string(&ty).unwrap(), "{\"Map\":\"Bytes\"}"); + + let ty = Type::Map(Type::Array(Type::Bytes.into()).into()); + assert_eq!( + serde_json::to_string(&ty).unwrap(), + "{\"Map\":{\"Array\":\"Bytes\"}}" + ); + + let ty = Type::Array(Type::Map(Type::Bytes.into()).into()); + assert_eq!( + serde_json::to_string(&ty).unwrap(), + "{\"Array\":{\"Map\":\"Bytes\"}}" + ); +} + +#[test] +fn test_type_deserialize() { + assert_eq!( + serde_json::from_str::<'_, Type>("\"Bool\"").unwrap(), + Type::Bool, + ); + + assert_eq!( + serde_json::from_str::<'_, Type>("\"Bytes\"").unwrap(), + Type::Bytes, + ); + + assert_eq!( + serde_json::from_str::<'_, Type>("\"Int\"").unwrap(), + Type::Int, + ); + + assert_eq!( + serde_json::from_str::<'_, Type>("\"Ip\"").unwrap(), + Type::Ip, + ); + + assert_eq!( + serde_json::from_str::<'_, Type>("{\"Array\":\"Bytes\"}").unwrap(), + Type::Array(Type::Bytes.into()), + ); + + assert_eq!( + serde_json::from_str::<'_, Type>("{\"Map\":\"Bytes\"}").unwrap(), + Type::Map(Type::Bytes.into()), + ); + + assert_eq!( + serde_json::from_str::<'_, Type>("{\"Map\":{\"Array\":\"Bytes\"}}").unwrap(), + Type::Map(Type::Array(Type::Bytes.into()).into()), + ); + + assert_eq!( + serde_json::from_str::<'_, Type>("{\"Array\":{\"Map\":\"Bytes\"}}").unwrap(), + Type::Array(Type::Map(Type::Bytes.into()).into()), + ); +} + +#[test] +fn test_size_of_lhs_value() { + assert_eq!(std::mem::size_of::>(), 48); +} diff --git a/ffi/Cargo.toml b/ffi/Cargo.toml index 5dec5bc5..3f2209cc 100644 --- a/ffi/Cargo.toml +++ b/ffi/Cargo.toml @@ -1,29 +1,31 @@ [package] -authors = ["Ingvar Stepanyan "] +authors = [ "Ingvar Stepanyan " ] name = "wirefilter-ffi" version = "0.7.0" description = "FFI bindings for the Wirefilter engine" publish = false -edition = "2018" +edition = "2021" [package.metadata.deb] -assets = [["target/release/libwirefilter_ffi.so", "usr/local/lib/libwirefilter.so", "644"]] +assets = [ [ "target/release/libwirefilter_ffi.so", "usr/local/lib/libwirefilter.so", "644" ] ] [lib] -crate-type = ["cdylib", "rlib"] +crate-type = [ "cdylib", "rlib" ] # Avoid duplicate compilation error messages as we don't have doctests anyway doctest = false bench = false [dependencies] -fnv = "1.0.6" -libc = "0.2.42" -serde_json = "1.0.27" - 
-[dependencies.wirefilter-engine] -path = "../engine" +fnv.workspace = true +libc.workspace = true +num_enum.workspace = true +serde.workspace = true +serde_json.workspace = true +wirefilter.workspace = true [dev-dependencies] -regex = "1.1.5" -indoc = "0.3.0" -wirefilter-ffi-ctests = {path = "tests/ctests"} +indoc.workspace = true +regex.workspace = true + +[target.'cfg(unix)'.dev-dependencies] +wirefilter-ffi-ctests = { path = "tests/ctests" } diff --git a/ffi/include/wirefilter.h b/ffi/include/wirefilter.h index a54d7a4e..7ffea460 100644 --- a/ffi/include/wirefilter.h +++ b/ffi/include/wirefilter.h @@ -12,6 +12,9 @@ typedef struct wirefilter_scheme wirefilter_scheme_t; typedef struct wirefilter_execution_context wirefilter_execution_context_t; typedef struct wirefilter_filter_ast wirefilter_filter_ast_t; typedef struct wirefilter_filter wirefilter_filter_t; +typedef struct wirefilter_map wirefilter_map_t; +typedef struct wirefilter_array wirefilter_array_t; +typedef struct wirefilter_list wirefilter_list_t; typedef struct { const char *data; @@ -45,30 +48,125 @@ typedef union { } ok; } wirefilter_parsing_result_t; +typedef union { + uint8_t success; + struct { + uint8_t _res1; + wirefilter_rust_allocated_str_t msg; + } err; + struct { + uint8_t _res2; + bool value; + } ok; +} wirefilter_boolean_result_t; + +typedef wirefilter_boolean_result_t wirefilter_using_result_t; + +typedef wirefilter_boolean_result_t wirefilter_deserializing_result_t; + +typedef union { + uint8_t success; + struct { + uint8_t _res1; + wirefilter_rust_allocated_str_t msg; + } err; + struct { + uint8_t _res2; + wirefilter_filter_t *filter; + } ok; +} wirefilter_compiling_result_t; + +typedef wirefilter_boolean_result_t wirefilter_matching_result_t; + +typedef union { + uint8_t success; + struct { + uint8_t _res1; + wirefilter_rust_allocated_str_t msg; + } err; + struct { + uint8_t _res2; + wirefilter_rust_allocated_str_t json; + } ok; +} wirefilter_serializing_result_t; + +typedef union { + uint8_t success; + struct { + uint8_t _res1; + wirefilter_rust_allocated_str_t msg; + } err; + struct { + uint8_t _res2; + uint64_t hash; + } ok; +} wirefilter_hashing_result_t; + typedef enum { - WIREFILTER_TYPE_IP, - WIREFILTER_TYPE_BYTES, - WIREFILTER_TYPE_INT, - WIREFILTER_TYPE_BOOL, + WIREFILTER_PRIMITIVE_TYPE_UNKNOWN = 0, + WIREFILTER_PRIMITIVE_TYPE_IP = 1, + WIREFILTER_PRIMITIVE_TYPE_BYTES = 2, + WIREFILTER_PRIMITIVE_TYPE_INT = 3, + WIREFILTER_PRIMITIVE_TYPE_BOOL = 4, +} wirefilter_primitive_type_t; + +typedef struct { + uint32_t layers; + uint8_t len; + uint8_t primitive; } wirefilter_type_t; +static const wirefilter_type_t WIREFILTER_TYPE_IP = {.layers = 0, .len = 0, .primitive = WIREFILTER_PRIMITIVE_TYPE_IP}; +static const wirefilter_type_t WIREFILTER_TYPE_BYTES = {.layers = 0, .len = 0, .primitive = WIREFILTER_PRIMITIVE_TYPE_BYTES}; +static const wirefilter_type_t WIREFILTER_TYPE_INT = {.layers = 0, .len = 0, .primitive = WIREFILTER_PRIMITIVE_TYPE_INT}; +static const wirefilter_type_t WIREFILTER_TYPE_BOOL = {.layers = 0, .len = 0, .primitive = WIREFILTER_PRIMITIVE_TYPE_BOOL}; + +typedef enum { + WIREFILTER_PANIC_CATCHER_FALLBACK_MODE_CONTINUE = 0, + WIREFILTER_PANIC_CATCHER_FALLBACK_MODE_ABORT = 1, +} wirefilter_panic_catcher_fallback_mode_t; + +void wirefilter_set_panic_catcher_hook(); +wirefilter_boolean_result_t wirefilter_set_panic_catcher_fallback_mode(uint8_t mode); +void wirefilter_enable_panic_catcher(); +void wirefilter_disable_panic_catcher(); + wirefilter_scheme_t *wirefilter_create_scheme(); void 
wirefilter_free_scheme(wirefilter_scheme_t *scheme); -void wirefilter_add_type_field_to_scheme( +wirefilter_type_t wirefilter_create_map_type(wirefilter_type_t type); + +wirefilter_type_t wirefilter_create_array_type(wirefilter_type_t type); + +bool wirefilter_add_type_field_to_scheme( wirefilter_scheme_t *scheme, wirefilter_externally_allocated_str_t name, wirefilter_type_t type ); +wirefilter_list_t *wirefilter_create_always_list(); + +wirefilter_list_t *wirefilter_create_never_list(); + +bool wirefilter_add_type_list_to_scheme( + wirefilter_scheme_t *scheme, + wirefilter_type_t type, + wirefilter_list_t *list +); + wirefilter_parsing_result_t wirefilter_parse_filter( const wirefilter_scheme_t *scheme, wirefilter_externally_allocated_str_t input ); +void wirefilter_free_filter_ast(wirefilter_filter_ast_t *ast); + void wirefilter_free_parsing_result(wirefilter_parsing_result_t result); -wirefilter_filter_t *wirefilter_compile_filter(wirefilter_filter_ast_t *ast); +wirefilter_compiling_result_t wirefilter_compile_filter(wirefilter_filter_ast_t *ast); + +void wirefilter_free_compiling_result(wirefilter_compiling_result_t result); + void wirefilter_free_compiled_filter(wirefilter_filter_t *filter); wirefilter_execution_context_t *wirefilter_create_execution_context( @@ -78,52 +176,184 @@ void wirefilter_free_execution_context( wirefilter_execution_context_t *exec_ctx ); -void wirefilter_add_int_value_to_execution_context( +bool wirefilter_add_int_value_to_execution_context( wirefilter_execution_context_t *exec_ctx, wirefilter_externally_allocated_str_t name, - int32_t value + int64_t value ); -void wirefilter_add_bytes_value_to_execution_context( +bool wirefilter_add_bytes_value_to_execution_context( wirefilter_execution_context_t *exec_ctx, wirefilter_externally_allocated_str_t name, wirefilter_externally_allocated_byte_arr_t value ); -void wirefilter_add_ipv6_value_to_execution_context( +bool wirefilter_add_ipv6_value_to_execution_context( wirefilter_execution_context_t *exec_ctx, wirefilter_externally_allocated_str_t name, uint8_t value[16] ); -void wirefilter_add_ipv4_value_to_execution_context( +bool wirefilter_add_ipv4_value_to_execution_context( wirefilter_execution_context_t *exec_ctx, wirefilter_externally_allocated_str_t name, uint8_t value[4] ); -void wirefilter_add_bool_value_to_execution_context( +bool wirefilter_add_bool_value_to_execution_context( wirefilter_execution_context_t *exec_ctx, wirefilter_externally_allocated_str_t name, bool value ); -bool wirefilter_match( +bool wirefilter_add_map_value_to_execution_context( + wirefilter_execution_context_t *exec_ctx, + wirefilter_externally_allocated_str_t name, + wirefilter_map_t *map +); + +bool wirefilter_add_array_value_to_execution_context( + wirefilter_execution_context_t *exec_ctx, + wirefilter_externally_allocated_str_t name, + wirefilter_array_t *array +); + +wirefilter_map_t *wirefilter_create_map(wirefilter_type_t type); + +bool wirefilter_add_int_value_to_map( + wirefilter_map_t *map, + wirefilter_externally_allocated_str_t name, + int64_t value +); + +bool wirefilter_add_bytes_value_to_map( + wirefilter_map_t *map, + wirefilter_externally_allocated_str_t name, + wirefilter_externally_allocated_byte_arr_t value +); + +bool wirefilter_add_ipv6_value_to_map( + wirefilter_map_t *map, + wirefilter_externally_allocated_str_t name, + uint8_t value[16] +); + +bool wirefilter_add_ipv4_value_to_map( + wirefilter_map_t *map, + wirefilter_externally_allocated_str_t name, + uint8_t value[4] +); + +bool 
wirefilter_add_bool_value_to_map( + wirefilter_map_t *map, + wirefilter_externally_allocated_str_t name, + bool value +); + +bool wirefilter_add_map_value_to_map( + wirefilter_map_t *map, + wirefilter_externally_allocated_str_t name, + wirefilter_map_t *value +); + +bool wirefilter_add_array_value_to_map( + wirefilter_map_t *map, + wirefilter_externally_allocated_str_t name, + wirefilter_array_t *value +); + +void wirefilter_free_map(wirefilter_map_t *map); + +wirefilter_array_t *wirefilter_create_array(wirefilter_type_t type); + +bool wirefilter_add_int_value_to_array( + wirefilter_array_t *array, + uint32_t index, + int64_t value +); + +bool wirefilter_add_bytes_value_to_array( + wirefilter_array_t *array, + uint32_t index, + wirefilter_externally_allocated_byte_arr_t value +); + +bool wirefilter_add_ipv6_value_to_array( + wirefilter_array_t *array, + uint32_t index, + uint8_t value[16] +); + +bool wirefilter_add_ipv4_value_to_array( + wirefilter_array_t *array, + uint32_t index, + uint8_t value[4] +); + +bool wirefilter_add_bool_value_to_array( + wirefilter_array_t *array, + uint32_t index, + bool value +); + +bool wirefilter_add_map_value_to_array( + wirefilter_array_t *array, + uint32_t index, + wirefilter_map_t *value +); + +bool wirefilter_add_array_value_to_array( + wirefilter_array_t *array, + uint32_t index, + wirefilter_array_t *value +); + +void wirefilter_free_array(wirefilter_array_t *array); + +wirefilter_matching_result_t wirefilter_match( const wirefilter_filter_t *filter, const wirefilter_execution_context_t *exec_ctx ); -bool wirefilter_filter_uses( +void wirefilter_free_matching_result(wirefilter_matching_result_t result); + +wirefilter_using_result_t wirefilter_filter_uses( + const wirefilter_filter_ast_t *ast, + wirefilter_externally_allocated_str_t field_name +); + +wirefilter_using_result_t wirefilter_filter_uses_list( const wirefilter_filter_ast_t *ast, wirefilter_externally_allocated_str_t field_name ); -uint64_t wirefilter_get_filter_hash(const wirefilter_filter_ast_t *ast); +wirefilter_hashing_result_t wirefilter_get_filter_hash(const wirefilter_filter_ast_t *ast); -wirefilter_rust_allocated_str_t wirefilter_serialize_filter_to_json( +void wirefilter_free_hashing_result(wirefilter_hashing_result_t result); + +wirefilter_serializing_result_t wirefilter_serialize_filter_to_json( const wirefilter_filter_ast_t *ast ); +wirefilter_serializing_result_t wirefilter_serialize_scheme_to_json( + const wirefilter_scheme_t *scheme +); + +wirefilter_serializing_result_t wirefilter_serialize_type_to_json( + const wirefilter_type_t *type +); + +wirefilter_deserializing_result_t wirefilter_deserialize_json_to_execution_context( + wirefilter_execution_context_t *ctx, + wirefilter_externally_allocated_byte_arr_t serialized_exec_context +); + +wirefilter_serializing_result_t wirefilter_serialize_execution_context_to_json( + const wirefilter_execution_context_t *exec_ctx +); + +void wirefilter_free_serializing_result(wirefilter_serializing_result_t result); + void wirefilter_free_string(wirefilter_rust_allocated_str_t str); wirefilter_static_rust_allocated_str_t wirefilter_get_version(); diff --git a/ffi/src/lib.rs b/ffi/src/lib.rs index 84b5307a..824dff54 100644 --- a/ffi/src/lib.rs +++ b/ffi/src/lib.rs @@ -1,25 +1,148 @@ +#![warn(rust_2018_idioms)] + +pub mod panic; pub mod transfer_types; +use crate::panic::catch_panic; use crate::transfer_types::{ - ExternallyAllocatedByteArr, ExternallyAllocatedStr, RustAllocatedString, RustBox, + ExternallyAllocatedByteArr, 
ExternallyAllocatedStr, RustAllocatedString, StaticRustAllocatedString, }; use fnv::FnvHasher; +use num_enum::{IntoPrimitive, TryFromPrimitive}; +use serde::de::DeserializeSeed; use std::{ + convert::TryFrom, hash::Hasher, io::{self, Write}, net::IpAddr, }; -use wirefilter::{ExecutionContext, Filter, FilterAst, ParseError, Scheme, Type}; +use wirefilter::{ + AlwaysList, Array, ExecutionContext, FieldIndex, Filter, FilterAst, LhsValue, ListDefinition, + Map, NeverList, ParseError, Scheme, Type, +}; const VERSION: &str = env!("CARGO_PKG_VERSION"); +#[derive(Debug, Eq, PartialEq, IntoPrimitive, TryFromPrimitive)] +#[repr(u8)] +pub enum CPrimitiveType { + Ip = 1u8, + Bytes = 2u8, + Int = 3u8, + Bool = 4u8, +} + +enum Layer { + Array, + Map, +} + +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +#[repr(C)] +pub struct CType { + pub layers: u32, + pub len: u8, + pub primitive: u8, +} + +impl CType { + const fn push(mut self, layer: Layer) -> CType { + let layer = match layer { + Layer::Array => 0, + Layer::Map => 1, + }; + self.layers = (self.layers << 1) | layer; + self.len += 1; + self + } + + const fn pop(mut self) -> (Self, Option) { + if self.len > 0 { + // Maybe use (trailing/leading)_(ones/zeros) instead + let layer = (self.layers & 1) == 0; + self.layers >>= 1; + self.len -= 1; + if layer { + (self, Some(Layer::Array)) + } else { + (self, Some(Layer::Map)) + } + } else { + (self, None) + } + } +} + +impl From for Type { + fn from(cty: CType) -> Self { + let (ty, layer) = cty.pop(); + match layer { + Some(Layer::Array) => Type::Array(Type::from(ty).into()), + Some(Layer::Map) => Type::Map(Type::from(ty).into()), + None => match CPrimitiveType::try_from(cty.primitive).unwrap() { + CPrimitiveType::Bool => Type::Bool, + CPrimitiveType::Bytes => Type::Bytes, + CPrimitiveType::Int => Type::Int, + CPrimitiveType::Ip => Type::Ip, + }, + } + } +} + +impl From for CType { + fn from(ty: Type) -> Self { + match ty { + Type::Ip => CType { + len: 0, + layers: 0, + primitive: CPrimitiveType::Ip.into(), + }, + Type::Bytes => CType { + len: 0, + layers: 0, + primitive: CPrimitiveType::Bytes.into(), + }, + Type::Int => CType { + len: 0, + layers: 0, + primitive: CPrimitiveType::Int.into(), + }, + Type::Bool => CType { + len: 0, + layers: 0, + primitive: CPrimitiveType::Bool.into(), + }, + Type::Array(arr) => Self::from(Type::from(arr)).push(Layer::Array), + Type::Map(map) => Self::from(Type::from(map)).push(Layer::Map), + } + } +} + #[repr(u8)] -pub enum ParsingResult<'s> { +pub enum CResult { Err(RustAllocatedString), - Ok(RustBox>), + Ok(T), +} + +impl CResult { + pub fn unwrap(self) -> T { + match self { + CResult::Err(err) => panic!("{}", &err as &str), + CResult::Ok(ok) => ok, + } + } + + pub fn into_result(self) -> Result { + match self { + CResult::Ok(ok) => Ok(ok), + CResult::Err(err) => Err(err), + } + } } +pub type ParsingResult<'s> = CResult>>; + impl<'s> From> for ParsingResult<'s> { fn from(filter_ast: FilterAst<'s>) -> Self { ParsingResult::Ok(filter_ast.into()) @@ -32,36 +155,71 @@ impl<'s, 'a> From> for ParsingResult<'s> { } } -impl<'s> ParsingResult<'s> { - pub fn unwrap(self) -> RustBox> { - match self { - ParsingResult::Err(err) => panic!("{}", &err as &str), - ParsingResult::Ok(filter) => filter, - } - } -} +pub type UsingResult = CResult; + +pub type CompilingResult<'s, 'e> = CResult>>; + +pub type MatchingResult = CResult; + +pub type SerializingResult = CResult; + +pub type DeserializingResult = CResult; + +pub type HashingResult = CResult; #[no_mangle] -pub extern "C" fn 
wirefilter_create_scheme() -> RustBox { +pub extern "C" fn wirefilter_create_scheme() -> Box { Default::default() } #[no_mangle] -pub extern "C" fn wirefilter_free_scheme(scheme: RustBox) { +pub extern "C" fn wirefilter_free_scheme(scheme: Box) { drop(scheme); } +#[no_mangle] +pub extern "C" fn wirefilter_create_map_type(ty: CType) -> CType { + ty.push(Layer::Map) +} + +#[no_mangle] +pub extern "C" fn wirefilter_create_array_type(ty: CType) -> CType { + ty.push(Layer::Array) +} + #[no_mangle] pub extern "C" fn wirefilter_add_type_field_to_scheme( scheme: &mut Scheme, name: ExternallyAllocatedStr<'_>, - ty: Type, -) { - scheme.add_field(name.into_ref().to_owned(), ty).unwrap(); + ty: CType, +) -> bool { + scheme.add_field(name.into_ref(), ty.into()).is_ok() +} + +#[no_mangle] +pub extern "C" fn wirefilter_create_always_list() -> *mut Box { + Box::into_raw(Box::new(Box::new(AlwaysList {}))) +} + +#[no_mangle] +pub extern "C" fn wirefilter_create_never_list() -> *mut Box { + Box::into_raw(Box::new(Box::new(NeverList {}))) +} + +#[allow(clippy::not_unsafe_ptr_arg_deref)] +#[no_mangle] +pub extern "C" fn wirefilter_add_type_list_to_scheme( + scheme: &mut Scheme, + ty: CType, + list: *mut Box, +) -> bool { + scheme + .add_list(ty.into(), *unsafe { Box::from_raw(list) }) + .is_ok() } #[no_mangle] -pub extern "C" fn wirefilter_free_parsed_filter(filter_ast: RustBox>) { +pub extern "C" fn wirefilter_free_parsed_filter(filter_ast: Box>) { drop(filter_ast); } @@ -71,14 +229,21 @@ pub extern "C" fn wirefilter_free_string(s: RustAllocatedString) { } #[no_mangle] -pub extern "C" fn wirefilter_parse_filter<'s, 'i>( +pub extern "C" fn wirefilter_parse_filter<'s>( scheme: &'s Scheme, - input: ExternallyAllocatedStr<'i>, + input: ExternallyAllocatedStr<'_>, ) -> ParsingResult<'s> { - match scheme.parse(input.into_ref()) { - Ok(filter) => ParsingResult::from(filter), - Err(err) => ParsingResult::from(err), - } + catch_panic(std::panic::AssertUnwindSafe(|| { + scheme + .parse(input.into_ref()) + .map(Box::new) + .map_err(|err| err.to_string()) + })) +} + +#[no_mangle] +pub extern "C" fn wirefilter_free_filter_ast(ast: Box>) { + drop(ast); } #[no_mangle] @@ -106,52 +271,110 @@ impl Write for HasherWrite { } } -fn unwrap_json_result(filter_ast: &FilterAst<'_>, result: serde_json::Result) -> T { - // Filter serialisation must never fail. - result.unwrap_or_else(|err| panic!("{} while serializing filter {:#?}", err, filter_ast)) -} - #[no_mangle] -pub extern "C" fn wirefilter_get_filter_hash(filter_ast: &FilterAst<'_>) -> u64 { +pub extern "C" fn wirefilter_get_filter_hash(filter_ast: &FilterAst<'_>) -> HashingResult { let mut hasher = FnvHasher::default(); // Serialize JSON to our Write-compatible wrapper around FnvHasher, // effectively calculating a hash for our filter in a streaming fashion // that is as stable as the JSON representation itself // (instead of relying on #[derive(Hash)] which would be tied to impl details). 
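// On the Rust side the wrapped results are consumed through `unwrap()`/`into_result()`;
// C callers check the leading `success` byte of the result union instead. A sketch of
// the caller side for the hash computed below:
//
//   match wirefilter_get_filter_hash(&filter_ast).into_result() {
//       Ok(hash) => println!("filter hash: {hash:x}"),             // u64
//       Err(msg) => eprintln!("hashing failed: {}", &msg as &str), // RustAllocatedString
//   }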
- let result = serde_json::to_writer(HasherWrite(&mut hasher), filter_ast); - unwrap_json_result(filter_ast, result); - hasher.finish() + match serde_json::to_writer(HasherWrite(&mut hasher), filter_ast) { + Ok(_) => HashingResult::Ok(hasher.finish()), + Err(err) => { + HashingResult::Err(format!("{err} while serializing filter {filter_ast:#?}").into()) + } + } +} + +#[no_mangle] +pub extern "C" fn wirefilter_free_hashing_result(r: HashingResult) { + drop(r); } #[no_mangle] pub extern "C" fn wirefilter_serialize_filter_to_json( filter_ast: &FilterAst<'_>, -) -> RustAllocatedString { - let result = serde_json::to_string(filter_ast); - unwrap_json_result(filter_ast, result).into() +) -> SerializingResult { + match serde_json::to_string(filter_ast) { + Ok(ok) => SerializingResult::Ok(ok.into()), + Err(err) => { + SerializingResult::Err(format!("{err} while serializing filter {filter_ast:#?}").into()) + } + } +} + +#[no_mangle] +pub extern "C" fn wirefilter_serialize_scheme_to_json(scheme: &Scheme) -> SerializingResult { + match serde_json::to_string(scheme) { + Ok(ok) => SerializingResult::Ok(ok.into()), + Err(err) => { + SerializingResult::Err(format!("{err} while serializing scheme {scheme:#?}").into()) + } + } +} + +#[no_mangle] +pub extern "C" fn wirefilter_serialize_type_to_json(ty: &CType) -> SerializingResult { + match serde_json::to_string(&Type::from(*ty)) { + Ok(ok) => SerializingResult::Ok(ok.into()), + Err(err) => SerializingResult::Err(format!("{err} while serializing type {ty:#?}").into()), + } +} + +#[no_mangle] +pub extern "C" fn wirefilter_free_serializing_result(r: SerializingResult) { + drop(r); } #[no_mangle] pub extern "C" fn wirefilter_create_execution_context<'e, 's: 'e>( scheme: &'s Scheme, -) -> RustBox> { +) -> Box> { ExecutionContext::new(scheme).into() } #[no_mangle] -pub extern "C" fn wirefilter_free_execution_context(exec_context: RustBox>) { +pub extern "C" fn wirefilter_serialize_execution_context_to_json( + exec_context: &mut ExecutionContext<'_>, +) -> SerializingResult { + match serde_json::to_string(exec_context) { + Ok(ok) => SerializingResult::Ok(ok.into()), + Err(err) => SerializingResult::Err( + format!("{err} while serializing execution context {exec_context:#?}").into(), + ), + } +} + +#[no_mangle] +pub extern "C" fn wirefilter_deserialize_json_to_execution_context<'e>( + exec_context: &mut ExecutionContext<'e>, + serialized_context: ExternallyAllocatedByteArr<'e>, +) -> DeserializingResult { + let mut deserializer = serde_json::Deserializer::from_slice(serialized_context.into_ref()); + match exec_context.deserialize(&mut deserializer) { + Ok(_) => DeserializingResult::Ok(true), + Err(err) => { + DeserializingResult::Err(format!("{err} while deserializing execution context").into()) + } + } +} + +#[no_mangle] +pub extern "C" fn wirefilter_free_execution_context(exec_context: Box>) { drop(exec_context); } #[no_mangle] -pub extern "C" fn wirefilter_add_int_value_to_execution_context<'a>( - exec_context: &mut ExecutionContext<'a>, +pub extern "C" fn wirefilter_add_int_value_to_execution_context( + exec_context: &mut ExecutionContext<'_>, name: ExternallyAllocatedStr<'_>, - value: i32, -) { - exec_context - .set_field_value(name.into_ref(), value) - .unwrap(); + value: i64, +) -> bool { + let field = match exec_context.scheme().get_field(name.into_ref()) { + Ok(f) => f, + Err(_) => return false, + }; + exec_context.set_field_value(field, value).is_ok() } #[no_mangle] @@ -159,11 +382,13 @@ pub extern "C" fn 
wirefilter_add_bytes_value_to_execution_context<'a>( exec_context: &mut ExecutionContext<'a>, name: ExternallyAllocatedStr<'_>, value: ExternallyAllocatedByteArr<'a>, -) { +) -> bool { let slice: &[u8] = value.into_ref(); - exec_context - .set_field_value(name.into_ref(), slice) - .unwrap(); + let field = match exec_context.scheme().get_field(name.into_ref()) { + Ok(f) => f, + Err(_) => return false, + }; + exec_context.set_field_value(field, slice).is_ok() } #[no_mangle] @@ -171,10 +396,14 @@ pub extern "C" fn wirefilter_add_ipv6_value_to_execution_context( exec_context: &mut ExecutionContext<'_>, name: ExternallyAllocatedStr<'_>, value: &[u8; 16], -) { +) -> bool { + let field = match exec_context.scheme().get_field(name.into_ref()) { + Ok(f) => f, + Err(_) => return false, + }; exec_context - .set_field_value(name.into_ref(), IpAddr::from(*value)) - .unwrap(); + .set_field_value(field, IpAddr::from(*value)) + .is_ok() } #[no_mangle] @@ -182,10 +411,14 @@ pub extern "C" fn wirefilter_add_ipv4_value_to_execution_context( exec_context: &mut ExecutionContext<'_>, name: ExternallyAllocatedStr<'_>, value: &[u8; 4], -) { +) -> bool { + let field = match exec_context.scheme().get_field(name.into_ref()) { + Ok(f) => f, + Err(_) => return false, + }; exec_context - .set_field_value(name.into_ref(), IpAddr::from(*value)) - .unwrap(); + .set_field_value(field, IpAddr::from(*value)) + .is_ok() } #[no_mangle] @@ -193,30 +426,238 @@ pub extern "C" fn wirefilter_add_bool_value_to_execution_context( exec_context: &mut ExecutionContext<'_>, name: ExternallyAllocatedStr<'_>, value: bool, -) { - exec_context - .set_field_value(name.into_ref(), value) - .unwrap(); +) -> bool { + let field = match exec_context.scheme().get_field(name.into_ref()) { + Ok(f) => f, + Err(_) => return false, + }; + exec_context.set_field_value(field, value).is_ok() } #[no_mangle] -pub extern "C" fn wirefilter_compile_filter<'s>( - filter_ast: RustBox>, -) -> RustBox> { - let filter_ast = filter_ast.into_real_box(); - filter_ast.compile().into() +pub extern "C" fn wirefilter_add_map_value_to_execution_context<'a>( + exec_context: &mut ExecutionContext<'a>, + name: ExternallyAllocatedStr<'_>, + value: Box>, +) -> bool { + let field = match exec_context.scheme().get_field(name.into_ref()) { + Ok(f) => f, + Err(_) => return false, + }; + exec_context.set_field_value(field, *value).is_ok() } #[no_mangle] -pub extern "C" fn wirefilter_match<'s>( - filter: &Filter<'s>, - exec_context: &ExecutionContext<'s>, +pub extern "C" fn wirefilter_add_array_value_to_execution_context<'a>( + exec_context: &mut ExecutionContext<'a>, + name: ExternallyAllocatedStr<'_>, + value: Box>, ) -> bool { - filter.execute(exec_context).unwrap() + let field = match exec_context.scheme().get_field(name.into_ref()) { + Ok(f) => f, + Err(_) => return false, + }; + exec_context.set_field_value(field, *value).is_ok() } #[no_mangle] -pub extern "C" fn wirefilter_free_compiled_filter(filter: RustBox>) { +pub extern "C" fn wirefilter_create_map<'a>(ty: CType) -> Box> { + Box::new(LhsValue::Map(Map::new(Type::from(ty)))) +} + +// TODO: store a Box<[u8] inside FieldIndex::MapKey instead of String +// and call map.set(FieldIndex::MapKey(key), value.into()) directly +macro_rules! 
map_insert { + ($map:ident, $name:ident, $value:expr) => { + match $map { + LhsValue::Map(map) => map.insert($name.into_ref(), $value).is_ok(), + _ => unreachable!(), + } + }; +} + +#[no_mangle] +pub extern "C" fn wirefilter_add_int_value_to_map( + map: &mut LhsValue<'_>, + name: ExternallyAllocatedByteArr<'_>, + value: i64, +) -> bool { + map_insert!(map, name, value) +} + +#[no_mangle] +pub extern "C" fn wirefilter_add_bytes_value_to_map<'a>( + map: &mut LhsValue<'a>, + name: ExternallyAllocatedByteArr<'_>, + value: ExternallyAllocatedByteArr<'a>, +) -> bool { + let slice: &[u8] = value.into_ref(); + map_insert!(map, name, slice) +} + +#[no_mangle] +pub extern "C" fn wirefilter_add_ipv6_value_to_map( + map: &mut LhsValue<'_>, + name: ExternallyAllocatedByteArr<'_>, + value: &[u8; 16], +) -> bool { + let value = IpAddr::from(*value); + map_insert!(map, name, value) +} + +#[no_mangle] +pub extern "C" fn wirefilter_add_ipv4_value_to_map( + map: &mut LhsValue<'_>, + name: ExternallyAllocatedByteArr<'_>, + value: &[u8; 4], +) -> bool { + let value = IpAddr::from(*value); + map_insert!(map, name, value) +} + +#[no_mangle] +pub extern "C" fn wirefilter_add_bool_value_to_map( + map: &mut LhsValue<'_>, + name: ExternallyAllocatedByteArr<'_>, + value: bool, +) -> bool { + map_insert!(map, name, value) +} + +#[no_mangle] +pub extern "C" fn wirefilter_add_map_value_to_map<'a>( + map: &mut LhsValue<'a>, + name: ExternallyAllocatedByteArr<'_>, + value: Box>, +) -> bool { + let value = value; + map_insert!(map, name, *value) +} + +#[no_mangle] +pub extern "C" fn wirefilter_add_array_value_to_map<'a>( + map: &mut LhsValue<'a>, + name: ExternallyAllocatedByteArr<'_>, + value: Box>, +) -> bool { + let value = value; + map_insert!(map, name, *value) +} + +#[no_mangle] +pub extern "C" fn wirefilter_free_map(map: Box>) { + drop(map) +} + +#[no_mangle] +pub extern "C" fn wirefilter_create_array<'a>(ty: CType) -> Box> { + Box::new(LhsValue::Array(Array::new(Type::from(ty)))) +} + +#[no_mangle] +pub extern "C" fn wirefilter_add_int_value_to_array( + array: &mut LhsValue<'_>, + index: u32, + value: i64, +) -> bool { + array.set(FieldIndex::ArrayIndex(index), value).is_ok() +} + +#[no_mangle] +pub extern "C" fn wirefilter_add_bytes_value_to_array<'a>( + array: &mut LhsValue<'a>, + index: u32, + value: ExternallyAllocatedByteArr<'a>, +) -> bool { + let slice: &[u8] = value.into_ref(); + array.set(FieldIndex::ArrayIndex(index), slice).is_ok() +} + +#[no_mangle] +pub extern "C" fn wirefilter_add_ipv6_value_to_array( + array: &mut LhsValue<'_>, + index: u32, + value: &[u8; 16], +) -> bool { + array + .set(FieldIndex::ArrayIndex(index), IpAddr::from(*value)) + .is_ok() +} + +#[no_mangle] +pub extern "C" fn wirefilter_add_ipv4_value_to_array( + array: &mut LhsValue<'_>, + index: u32, + value: &[u8; 4], +) -> bool { + array + .set(FieldIndex::ArrayIndex(index), IpAddr::from(*value)) + .is_ok() +} + +#[no_mangle] +pub extern "C" fn wirefilter_add_bool_value_to_array( + array: &mut LhsValue<'_>, + index: u32, + value: bool, +) -> bool { + array.set(FieldIndex::ArrayIndex(index), value).is_ok() +} + +#[no_mangle] +pub extern "C" fn wirefilter_add_map_value_to_array<'a>( + array: &mut LhsValue<'a>, + index: u32, + value: Box>, +) -> bool { + array.set(FieldIndex::ArrayIndex(index), *value).is_ok() +} + +#[no_mangle] +pub extern "C" fn wirefilter_add_array_value_to_array<'a>( + array: &mut LhsValue<'a>, + index: u32, + value: Box>, +) -> bool { + array.set(FieldIndex::ArrayIndex(index), *value).is_ok() +} + +#[no_mangle] 
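// Arrays are populated the same way as maps, just keyed by index. A sketch mirroring the
// map setup in the tests below (the "arr1" field name is illustrative):
//
//   let mut arr = wirefilter_create_array(Type::Bytes.into());
//   wirefilter_add_bytes_value_to_array(&mut arr, 0, ExternallyAllocatedByteArr::from("first"));
//   wirefilter_add_bytes_value_to_array(&mut arr, 1, ExternallyAllocatedByteArr::from("second"));
//   wirefilter_add_array_value_to_execution_context(
//       &mut exec_context,
//       ExternallyAllocatedStr::from("arr1"),
//       arr,
//   );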
+pub extern "C" fn wirefilter_free_array(array: Box>) { + drop(array) +} + +#[no_mangle] +pub extern "C" fn wirefilter_compile_filter( + filter_ast: Box>, +) -> CompilingResult<'_, '_> { + catch_panic(std::panic::AssertUnwindSafe(|| { + Ok(Box::new(filter_ast.compile())) + })) +} + +#[no_mangle] +pub extern "C" fn wirefilter_free_compiling_result(r: CompilingResult<'_, '_>) { + drop(r); +} + +#[no_mangle] +pub extern "C" fn wirefilter_match<'e, 's: 'e>( + filter: &Filter<'s>, + exec_context: &ExecutionContext<'e>, +) -> MatchingResult { + catch_panic(std::panic::AssertUnwindSafe(|| { + filter.execute(exec_context).map_err(|err| err.to_string()) + })) +} + +#[no_mangle] +pub extern "C" fn wirefilter_free_matching_result(r: MatchingResult) { + drop(r); +} + +#[no_mangle] +pub extern "C" fn wirefilter_free_compiled_filter(filter: Box>) { drop(filter); } @@ -224,8 +665,29 @@ pub extern "C" fn wirefilter_free_compiled_filter(filter: RustBox>) { pub extern "C" fn wirefilter_filter_uses( filter_ast: &FilterAst<'_>, field_name: ExternallyAllocatedStr<'_>, -) -> bool { - filter_ast.uses(field_name.into_ref()).unwrap() +) -> UsingResult { + catch_panic(std::panic::AssertUnwindSafe(|| { + filter_ast + .uses(field_name.into_ref()) + .map_err(|err| err.to_string()) + })) +} + +#[no_mangle] +pub extern "C" fn wirefilter_filter_uses_list( + filter_ast: &FilterAst<'_>, + field_name: ExternallyAllocatedStr<'_>, +) -> UsingResult { + catch_panic(std::panic::AssertUnwindSafe(|| { + filter_ast + .uses_list(field_name.into_ref()) + .map_err(|err| err.to_string()) + })) +} + +#[no_mangle] +pub extern "C" fn wirefilter_free_using_result(r: UsingResult) { + drop(r); } #[no_mangle] @@ -234,51 +696,71 @@ pub extern "C" fn wirefilter_get_version() -> StaticRustAllocatedString { } #[cfg(test)] +#[allow(clippy::bool_assert_comparison)] mod ffi_test { use super::*; use regex::Regex; - fn create_scheme() -> RustBox { + fn create_scheme() -> Box { let mut scheme = wirefilter_create_scheme(); wirefilter_add_type_field_to_scheme( &mut scheme, ExternallyAllocatedStr::from("ip1"), - Type::Ip, + Type::Ip.into(), ); wirefilter_add_type_field_to_scheme( &mut scheme, ExternallyAllocatedStr::from("ip2"), - Type::Ip, + Type::Ip.into(), ); wirefilter_add_type_field_to_scheme( &mut scheme, ExternallyAllocatedStr::from("str1"), - Type::Bytes, + Type::Bytes.into(), ); wirefilter_add_type_field_to_scheme( &mut scheme, ExternallyAllocatedStr::from("str2"), - Type::Bytes, + Type::Bytes.into(), ); wirefilter_add_type_field_to_scheme( &mut scheme, ExternallyAllocatedStr::from("num1"), - Type::Int, + Type::Int.into(), ); wirefilter_add_type_field_to_scheme( &mut scheme, ExternallyAllocatedStr::from("num2"), - Type::Int, + Type::Int.into(), + ); + wirefilter_add_type_field_to_scheme( + &mut scheme, + ExternallyAllocatedStr::from("map1"), + wirefilter_create_map_type(Type::Int.into()), + ); + wirefilter_add_type_field_to_scheme( + &mut scheme, + ExternallyAllocatedStr::from("map2"), + wirefilter_create_map_type(Type::Bytes.into()), + ); + + wirefilter_add_type_list_to_scheme( + &mut scheme, + Type::Int.into(), + wirefilter_create_always_list(), ); scheme } - fn create_execution_context<'e, 's: 'e>(scheme: &'s Scheme) -> RustBox> { + fn create_execution_context<'e, 's: 'e>(scheme: &'s Scheme) -> Box> { let mut exec_context = wirefilter_create_execution_context(scheme); + let invalid_key = &b"\xc3\x28"[..]; + + assert!(std::str::from_utf8(invalid_key).is_err()); wirefilter_add_ipv4_value_to_execution_context( &mut exec_context, @@ -316,6 
+798,42 @@ mod ffi_test { 1337, ); + let mut map1 = wirefilter_create_map(Type::Int.into()); + + wirefilter_add_int_value_to_map(&mut map1, ExternallyAllocatedByteArr::from("key"), 42); + + wirefilter_add_int_value_to_map( + &mut map1, + ExternallyAllocatedByteArr::from(invalid_key), + 42, + ); + + wirefilter_add_map_value_to_execution_context( + &mut exec_context, + ExternallyAllocatedStr::from("map1"), + map1, + ); + + let mut map2 = wirefilter_create_map(Type::Bytes.into()); + + wirefilter_add_bytes_value_to_map( + &mut map2, + ExternallyAllocatedByteArr::from("key"), + ExternallyAllocatedByteArr::from("value"), + ); + + wirefilter_add_bytes_value_to_map( + &mut map2, + ExternallyAllocatedByteArr::from(invalid_key), + ExternallyAllocatedByteArr::from("value"), + ); + + wirefilter_add_map_value_to_execution_context( + &mut exec_context, + ExternallyAllocatedStr::from("map2"), + map2, + ); + exec_context } @@ -329,13 +847,13 @@ mod ffi_test { exec_context: &ExecutionContext<'_>, ) -> bool { let filter = parse_filter(scheme, input).unwrap(); - let filter = wirefilter_compile_filter(filter); + let filter = wirefilter_compile_filter(filter).unwrap(); let result = wirefilter_match(&filter, exec_context); wirefilter_free_compiled_filter(filter); - result + result.unwrap() } #[test] @@ -385,7 +903,7 @@ mod ffi_test { { let filter = parse_filter(&scheme, r#"num1 > 3 && str2 == "abc""#).unwrap(); - let json = wirefilter_serialize_filter_to_json(&filter); + let json = wirefilter_serialize_filter_to_json(&filter).unwrap(); assert_eq!( &json as &str, @@ -400,6 +918,19 @@ mod ffi_test { wirefilter_free_scheme(scheme); } + #[test] + fn scheme_serialize() { + let scheme = create_scheme(); + let json = wirefilter_serialize_scheme_to_json(&scheme).unwrap(); + + let expected: String = serde_json::to_string(&*scheme).unwrap(); + assert_eq!(&json as &str, expected); + + wirefilter_free_string(json); + + wirefilter_free_scheme(scheme); + } + #[test] fn filter_matching() { let scheme = create_scheme(); @@ -408,13 +939,13 @@ mod ffi_test { let exec_context = create_execution_context(&scheme); assert!(match_filter( - r#"num1 > 41 && num2 == 1337 && ip1 != 192.168.0.1 && str2 ~ "yo\d+""#, + r#"num1 > 41 && num2 == 1337 && ip1 != 192.168.0.1 && str2 ~ "yo\d+" && map2["key"] == "value""#, &scheme, &exec_context )); assert!(match_filter( - r#"ip2 == 0:0:0:0:0:ffff:c0a8:1 && (str1 == "Hey" || str2 == "ya")"#, + r#"ip2 == 0:0:0:0:0:ffff:c0a8:1 && (str1 == "Hey" || str2 == "ya") && (map1["key"] == 42 || map2["key2"] == "value")"#, &scheme, &exec_context )); @@ -438,21 +969,21 @@ mod ffi_test { { let filter1 = parse_filter( &scheme, - r#"num1 > 41 && num2 == 1337 && ip1 != 192.168.0.1 && str2 ~ "yo\d+""#, + r#"num1 > 41 && num2 == 1337 && ip1 != 192.168.0.1 && str2 ~ "yo\d+" && map1["key"] == 42"#, ) .unwrap(); let filter2 = parse_filter( &scheme, - r#"num1 > 41 && num2 == 1337 && ip1 != 192.168.0.1 and str2 ~ "yo\d+""#, + r#"num1 > 41 && num2 == 1337 && ip1 != 192.168.0.1 and str2 ~ "yo\d+" && map1["key"] == 42 "#, ) .unwrap(); let filter3 = parse_filter(&scheme, r#"num1 > 41 && num2 == 1337"#).unwrap(); - let hash1 = wirefilter_get_filter_hash(&filter1); - let hash2 = wirefilter_get_filter_hash(&filter2); - let hash3 = wirefilter_get_filter_hash(&filter3); + let hash1 = wirefilter_get_filter_hash(&filter1).unwrap(); + let hash2 = wirefilter_get_filter_hash(&filter2).unwrap(); + let hash3 = wirefilter_get_filter_hash(&filter3).unwrap(); assert_eq!(hash1, hash2); assert_ne!(hash2, hash3); @@ -480,38 +1011,122 @@ mod 
ffi_test { { let filter = parse_filter( &scheme, - r#"num1 > 41 && num2 == 1337 && ip1 != 192.168.0.1 && str2 ~ "yo\d+""#, + r#"num1 > 41 && num2 == 1337 && ip1 != 192.168.0.1 && str2 ~ "yo\d+" && map1["key"] == 42"#, ) .unwrap(); - assert!(wirefilter_filter_uses( - &filter, - ExternallyAllocatedStr::from("num1") - )); + assert!(wirefilter_filter_uses(&filter, ExternallyAllocatedStr::from("num1")).unwrap()); - assert!(wirefilter_filter_uses( - &filter, - ExternallyAllocatedStr::from("ip1") - )); + assert!(wirefilter_filter_uses(&filter, ExternallyAllocatedStr::from("ip1")).unwrap()); - assert!(wirefilter_filter_uses( - &filter, - ExternallyAllocatedStr::from("str2") - )); + assert!(wirefilter_filter_uses(&filter, ExternallyAllocatedStr::from("str2")).unwrap()); - assert!(!wirefilter_filter_uses( - &filter, - ExternallyAllocatedStr::from("str1") - )); + assert!( + !wirefilter_filter_uses(&filter, ExternallyAllocatedStr::from("str1")).unwrap() + ); - assert!(!wirefilter_filter_uses( - &filter, - ExternallyAllocatedStr::from("ip2") - )); + assert!(!wirefilter_filter_uses(&filter, ExternallyAllocatedStr::from("ip2")).unwrap()); + + assert!(wirefilter_filter_uses(&filter, ExternallyAllocatedStr::from("map1")).unwrap()); + + assert!( + !wirefilter_filter_uses(&filter, ExternallyAllocatedStr::from("map2")).unwrap() + ); wirefilter_free_parsed_filter(filter); } wirefilter_free_scheme(scheme); } + + #[test] + fn filter_uses_list() { + let scheme = create_scheme(); + + { + let filter = parse_filter( + &scheme, + r#"num1 in $numbers && num2 == 1337 && str2 != "hi" && ip2 == 10.10.10.10"#, + ) + .unwrap(); + + assert_eq!( + wirefilter_filter_uses_list(&filter, ExternallyAllocatedStr::from("num1")).unwrap(), + true, + ); + + assert_eq!( + wirefilter_filter_uses_list(&filter, ExternallyAllocatedStr::from("num2")).unwrap(), + false, + ); + + assert_eq!( + wirefilter_filter_uses_list(&filter, ExternallyAllocatedStr::from("str1")).unwrap(), + false + ); + + assert_eq!( + wirefilter_filter_uses_list(&filter, ExternallyAllocatedStr::from("str2")).unwrap(), + false, + ); + + assert_eq!( + wirefilter_filter_uses_list(&filter, ExternallyAllocatedStr::from("ip1")).unwrap(), + false, + ); + + assert_eq!( + wirefilter_filter_uses_list(&filter, ExternallyAllocatedStr::from("ip2")).unwrap(), + false, + ); + + wirefilter_free_parsed_filter(filter); + } + + wirefilter_free_scheme(scheme); + } + + #[test] + fn execution_context_deserialize() { + let scheme = create_scheme(); + let exec_context = create_execution_context(&scheme); + + let expected: String = serde_json::to_string(&*exec_context).unwrap(); + assert!(expected.len() > 3); + + let mut exec_context_c = wirefilter_create_execution_context(&scheme); + let res = wirefilter_deserialize_json_to_execution_context( + &mut exec_context_c, + ExternallyAllocatedByteArr::from(&expected[..]), + ); + assert_eq!(res.unwrap(), true); + + let expected_c: String = serde_json::to_string(&*exec_context_c).unwrap(); + assert_eq!(expected, expected_c); + } + + #[test] + fn ctype_convertion() { + let cty = CType::from(Type::Bytes); + + assert_eq!(Type::from(cty), Type::Bytes); + + let cty = wirefilter_create_array_type(cty); + + assert_eq!(cty, CType::from(Type::Array(Type::Bytes.into()))); + + assert_eq!(Type::from(cty), Type::Array(Type::Bytes.into())); + + let cty = wirefilter_create_map_type(cty); + + assert_eq!( + cty, + CType::from(Type::Map(Type::Array(Type::Bytes.into()).into())) + ); + + assert_eq!( + Type::from(cty), + Type::Map(Type::Array(Type::Bytes.into()).into()) + 
); + } } diff --git a/ffi/src/panic.rs b/ffi/src/panic.rs new file mode 100644 index 00000000..c3554d1a --- /dev/null +++ b/ffi/src/panic.rs @@ -0,0 +1,89 @@ +use crate::CResult; +use num_enum::{IntoPrimitive, TryFromPrimitive}; +use std::panic::UnwindSafe; +use wirefilter::{ + panic_catcher_disable, panic_catcher_enable, panic_catcher_set_fallback_mode, + panic_catcher_set_hook, PanicCatcherFallbackMode, +}; + +#[repr(u8)] +#[derive(Clone, Copy, IntoPrimitive, TryFromPrimitive)] +pub enum CPanicCatcherFallbackMode { + Continue = 0u8, + Abort = 1u8, +} + +#[inline(always)] +pub(crate) fn catch_panic(f: F) -> CResult +where + F: FnOnce() -> Result + UnwindSafe, +{ + match wirefilter::catch_panic(f) { + Ok(ok) => CResult::Ok(ok), + Err(msg) => CResult::Err(msg.into()), + } +} + +#[no_mangle] +pub extern "C" fn wirefilter_set_panic_catcher_hook() { + panic_catcher_set_hook() +} + +#[no_mangle] +pub extern "C" fn wirefilter_set_panic_catcher_fallback_mode(fallback_mode: u8) -> CResult { + let fallback_mode = match fallback_mode { + 0 => PanicCatcherFallbackMode::Continue, + 1 => PanicCatcherFallbackMode::Abort, + _ => return CResult::Err(format!("Invalid fallback mode {fallback_mode}").into()), + }; + + panic_catcher_set_fallback_mode(fallback_mode); + CResult::Ok(true) +} + +#[no_mangle] +pub extern "C" fn wirefilter_enable_panic_catcher() { + panic_catcher_enable() +} + +#[no_mangle] +pub extern "C" fn wirefilter_disable_panic_catcher() { + panic_catcher_disable() +} + +#[cfg(test)] +mod panic_test { + use super::*; + use crate::CResult; + + #[test] + #[cfg_attr(miri, ignore)] + #[should_panic(expected = r#"Hello World!"#)] + fn test_panic_catcher_set_panic_hook_can_still_panic() { + wirefilter_set_panic_catcher_hook(); + panic!("Hello World!"); + } + + #[test] + #[cfg_attr(miri, ignore)] + #[should_panic(expected = r#"Hello World!"#)] + fn test_panic_catcher_enabled_disabled_can_still_panic() { + wirefilter_set_panic_catcher_hook(); + wirefilter_enable_panic_catcher(); + wirefilter_disable_panic_catcher(); + panic!("Hello World!"); + } + + #[test] + fn test_panic_catcher_can_catch_panic() { + wirefilter_set_panic_catcher_hook(); + wirefilter_set_panic_catcher_fallback_mode(1).unwrap(); + wirefilter_enable_panic_catcher(); + let result: CResult<()> = catch_panic(|| panic!("Halt and Catch Panic")); + match result { + CResult::Ok(_) => unreachable!(), + CResult::Err(msg) => assert!(msg.contains("Halt and Catch Panic")), + } + wirefilter_disable_panic_catcher(); + } +} diff --git a/ffi/src/transfer_types/ownership_repr/reference.rs b/ffi/src/transfer_types/ownership_repr/reference.rs index cd656bd4..e544a15d 100644 --- a/ffi/src/transfer_types/ownership_repr/reference.rs +++ b/ffi/src/transfer_types/ownership_repr/reference.rs @@ -11,7 +11,7 @@ pub struct Ref<'a, T: ?Sized + ExternPtrRepr> { // potentially expensive checks and consume unsafe wrapper once. 
impl<'a, T: ?Sized + ExternPtrRepr> Ref<'a, T> { pub fn into_ref(self) -> &'a T { - let slice: *mut T = ExternPtrRepr::from_extern_repr(self.ptr); + let slice: *const T = ExternPtrRepr::from_extern_repr(self.ptr); unsafe { &*slice } } } @@ -19,7 +19,7 @@ impl<'a, T: ?Sized + ExternPtrRepr> Ref<'a, T> { impl<'a, T: ?Sized + ExternPtrRepr> From<&'a T> for Ref<'a, T> { fn from(ptr: &'a T) -> Self { Ref { - ptr: (ptr as *const T as *mut T).into(), + ptr: (ptr as *const T).into(), ownership_marker: PhantomData, } } diff --git a/ffi/src/transfer_types/ownership_repr/rust_box.rs b/ffi/src/transfer_types/ownership_repr/rust_box.rs index 9c759a62..5c5b0282 100644 --- a/ffi/src/transfer_types/ownership_repr/rust_box.rs +++ b/ffi/src/transfer_types/ownership_repr/rust_box.rs @@ -1,12 +1,13 @@ -use crate::transfer_types::raw_ptr_repr::ExternPtrRepr; +use crate::transfer_types::raw_ptr_repr::ExternPtrReprMut; use std::{ marker::PhantomData, mem, ops::{Deref, DerefMut}, }; +#[derive(Debug)] #[repr(transparent)] -pub struct RustBox { +pub struct RustBox { ptr: T::Repr, ownership_marker: PhantomData, } @@ -15,7 +16,7 @@ pub struct RustBox { // Rust compiler doesn't allow for custom types. However, it does allow this // for real `Box` by treating it in a special manner, so we want to provide // conversion to that real `Box` to unlock these features. -impl RustBox { +impl RustBox { // This needs to accept a reference not an owned version in order to work // inside of `Drop` implementation (and is highly unsafe otherwise). unsafe fn to_real_box_impl(&self) -> Box { @@ -31,21 +32,21 @@ impl RustBox { } } -impl Deref for RustBox { +impl Deref for RustBox { type Target = T; fn deref(&self) -> &T { - unsafe { &*ExternPtrRepr::from_extern_repr_unchecked(self.ptr) } + unsafe { &*ExternPtrReprMut::from_extern_repr_unchecked(self.ptr) } } } -impl DerefMut for RustBox { +impl DerefMut for RustBox { fn deref_mut(&mut self) -> &mut T { - unsafe { &mut *ExternPtrRepr::from_extern_repr_unchecked(self.ptr) } + unsafe { &mut *ExternPtrReprMut::from_extern_repr_unchecked(self.ptr) } } } -impl From> for RustBox { +impl From> for RustBox { fn from(b: Box) -> Self { RustBox { ptr: Box::into_raw(b).into(), @@ -54,19 +55,13 @@ impl From> for RustBox { } } -impl From for RustBox { - fn from(value: T) -> Self { - Box::new(value).into() - } -} - -impl Drop for RustBox { +impl Drop for RustBox { fn drop(&mut self) { drop(unsafe { self.to_real_box_impl() }); } } -impl Default for RustBox +impl Default for RustBox where Box: Default, { diff --git a/ffi/src/transfer_types/raw_ptr_repr/mod.rs b/ffi/src/transfer_types/raw_ptr_repr/mod.rs index 42e3fd71..91e570c9 100644 --- a/ffi/src/transfer_types/raw_ptr_repr/mod.rs +++ b/ffi/src/transfer_types/raw_ptr_repr/mod.rs @@ -1,7 +1,10 @@ mod slice; mod str; -pub use self::{slice::ExternSliceRepr, str::ExternStrRepr}; +pub use self::{ + slice::{ExternSliceRepr, ExternSliceReprMut}, + str::{ExternStrRepr, ExternStrReprMut}, +}; /// This trait allows to define FFI-safe representation for fat pointers /// with corresponding conversions. @@ -10,8 +13,10 @@ pub use self::{slice::ExternSliceRepr, str::ExternStrRepr}; /// [`::transfer_types::RustBox`] and [`::transfer_types::Ref`] to add required /// ownership semantics while preserving FFI compatibility. 
pub trait ExternPtrRepr { - type Repr: Copy + From<*mut Self>; + type Repr: Copy + From<*const Self>; + /// # Safety + /// /// This method will be used in places where data behind the pointer /// was allocated by Rust side, so implementors may omit potentially /// expensive safety checks. @@ -19,25 +24,33 @@ pub trait ExternPtrRepr { /// # Safety /// /// This function should not be called for objects allocated outside of Rust code. - unsafe fn from_extern_repr_unchecked(repr: Self::Repr) -> *mut Self; + unsafe fn from_extern_repr_unchecked(repr: Self::Repr) -> *const Self; /// This method will be used for pointers to data allocated by the FFI /// caller, and, when converting to a Rust pointer, must make sure that /// such conversion won't lead to Undefined Behaviour (e.g. check that /// slices don't have nullable data part and strings are valid UTF-8). - fn from_extern_repr(repr: Self::Repr) -> *mut Self; + fn from_extern_repr(repr: Self::Repr) -> *const Self; } -/// This is a blanket implementation for pointer to regular sized types. -/// They are already guaranteed to be thin and FFI-safe. -impl ExternPtrRepr for T { - type Repr = *mut T; +/// Mutable equivalent of `ExternPtrRepr`. +pub trait ExternPtrReprMut { + type Repr: Copy + From<*mut Self>; - unsafe fn from_extern_repr_unchecked(repr: Self::Repr) -> *mut Self { - repr - } + /// # Safety + /// + /// This method will be used in places where data behind the pointer + /// was allocated by Rust side, so implementors may omit potentially + /// expensive safety checks. + /// + /// # Safety + /// + /// This function should not be called for objects allocated outside of Rust code. + unsafe fn from_extern_repr_unchecked(repr: Self::Repr) -> *mut Self; - fn from_extern_repr(repr: Self::Repr) -> *mut Self { - repr - } + /// This method will be used for pointers to data allocated by the FFI + /// caller, and, when converting to a Rust pointer, must make sure that + /// such conversion won't lead to Undefined Behaviour (e.g. check that + /// slices don't have nullable data part and strings are valid UTF-8). + fn from_extern_repr(repr: Self::Repr) -> *mut Self; } diff --git a/ffi/src/transfer_types/raw_ptr_repr/slice.rs b/ffi/src/transfer_types/raw_ptr_repr/slice.rs index 41cc3490..d45af640 100644 --- a/ffi/src/transfer_types/raw_ptr_repr/slice.rs +++ b/ffi/src/transfer_types/raw_ptr_repr/slice.rs @@ -1,4 +1,4 @@ -use super::ExternPtrRepr; +use super::{ExternPtrRepr, ExternPtrReprMut}; use libc::size_t; use std::slice; @@ -14,28 +14,25 @@ use std::slice; /// ``` #[repr(C)] pub struct ExternSliceRepr { - data: *mut T, + data: *const T, length: size_t, } // Can't be derived without bound on `T: Clone`. impl Clone for ExternSliceRepr { fn clone(&self) -> Self { - ExternSliceRepr { - data: self.data, - length: self.length, - } + *self } } impl Copy for ExternSliceRepr {} -impl From<*mut [T]> for ExternSliceRepr { +impl From<*const [T]> for ExternSliceRepr { #[allow(clippy::not_unsafe_ptr_arg_deref)] - fn from(ptr: *mut [T]) -> Self { + fn from(ptr: *const [T]) -> Self { unsafe { ExternSliceRepr { - data: (*ptr).as_mut_ptr(), + data: (*ptr).as_ptr(), length: (*ptr).len(), } } @@ -45,6 +42,50 @@ impl From<*mut [T]> for ExternSliceRepr { impl ExternPtrRepr for [T] { type Repr = ExternSliceRepr; + unsafe fn from_extern_repr_unchecked(repr: Self::Repr) -> *const [T] { + slice::from_raw_parts(repr.data, repr.length) + } + + fn from_extern_repr(repr: Self::Repr) -> *const [T] { + // `slice::from_raw_parts{_mut}` require data part to be non-null. 
+ if repr.data.is_null() { + &[] + } else { + unsafe { ExternPtrRepr::from_extern_repr_unchecked(repr) } + } + } +} + +#[repr(C)] +pub struct ExternSliceReprMut { + data: *mut T, + length: size_t, +} + +// Can't be derived without bound on `T: Clone`. +impl Clone for ExternSliceReprMut { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for ExternSliceReprMut {} + +impl From<*mut [T]> for ExternSliceReprMut { + #[allow(clippy::not_unsafe_ptr_arg_deref)] + fn from(ptr: *mut [T]) -> Self { + unsafe { + ExternSliceReprMut { + data: (*ptr).as_mut_ptr(), + length: (*ptr).len(), + } + } + } +} + +impl ExternPtrReprMut for [T] { + type Repr = ExternSliceReprMut; + unsafe fn from_extern_repr_unchecked(repr: Self::Repr) -> *mut [T] { slice::from_raw_parts_mut(repr.data, repr.length) } @@ -54,7 +95,7 @@ impl ExternPtrRepr for [T] { if repr.data.is_null() { &mut [] } else { - unsafe { Self::from_extern_repr_unchecked(repr) } + unsafe { ExternPtrReprMut::from_extern_repr_unchecked(repr) } } } } diff --git a/ffi/src/transfer_types/raw_ptr_repr/str.rs b/ffi/src/transfer_types/raw_ptr_repr/str.rs index a8570123..278d1f2e 100644 --- a/ffi/src/transfer_types/raw_ptr_repr/str.rs +++ b/ffi/src/transfer_types/raw_ptr_repr/str.rs @@ -1,4 +1,4 @@ -use super::{ExternPtrRepr, ExternSliceRepr}; +use super::{ExternPtrRepr, ExternPtrReprMut, ExternSliceRepr, ExternSliceReprMut}; use std::str; /// This structure provides FFI-safe representation for Rust string slice @@ -18,10 +18,10 @@ use std::str; #[derive(Clone, Copy)] pub struct ExternStrRepr(ExternSliceRepr); -impl From<*mut str> for ExternStrRepr { +impl From<*const str> for ExternStrRepr { #[allow(clippy::not_unsafe_ptr_arg_deref)] - fn from(ptr: *mut str) -> Self { - let bytes: *mut [u8] = unsafe { (*ptr).as_bytes_mut() }; + fn from(ptr: *const str) -> Self { + let bytes: *const [u8] = unsafe { (*ptr).as_bytes() }; ExternStrRepr(bytes.into()) } } @@ -29,13 +29,40 @@ impl From<*mut str> for ExternStrRepr { impl ExternPtrRepr for str { type Repr = ExternStrRepr; - unsafe fn from_extern_repr_unchecked(repr: Self::Repr) -> *mut str { + unsafe fn from_extern_repr_unchecked(repr: Self::Repr) -> *const str { let bytes = ExternPtrRepr::from_extern_repr_unchecked(repr.0); + str::from_utf8_unchecked(&*bytes) + } + + fn from_extern_repr(repr: Self::Repr) -> *const str { + let bytes = ExternPtrRepr::from_extern_repr(repr.0); + // Make sure that strings coming via FFI are UTF-8 compatible. + str::from_utf8(unsafe { &*bytes }).unwrap() + } +} + +#[repr(transparent)] +#[derive(Clone, Copy)] +pub struct ExternStrReprMut(ExternSliceReprMut); + +impl From<*mut str> for ExternStrReprMut { + #[allow(clippy::not_unsafe_ptr_arg_deref)] + fn from(ptr: *mut str) -> Self { + let bytes: *mut [u8] = unsafe { (*ptr).as_bytes_mut() }; + ExternStrReprMut(bytes.into()) + } +} + +impl ExternPtrReprMut for str { + type Repr = ExternStrReprMut; + + unsafe fn from_extern_repr_unchecked(repr: Self::Repr) -> *mut str { + let bytes = ExternPtrReprMut::from_extern_repr_unchecked(repr.0); str::from_utf8_unchecked_mut(&mut *bytes) } fn from_extern_repr(repr: Self::Repr) -> *mut str { - let bytes = ExternPtrRepr::from_extern_repr(repr.0); + let bytes = ExternPtrReprMut::from_extern_repr(repr.0); // Make sure that strings coming via FFI are UTF-8 compatible. 
str::from_utf8_mut(unsafe { &mut *bytes }).unwrap() } diff --git a/ffi/tests/ctests/Cargo.toml b/ffi/tests/ctests/Cargo.toml index 6f07963d..d5b526dd 100644 --- a/ffi/tests/ctests/Cargo.toml +++ b/ffi/tests/ctests/Cargo.toml @@ -4,7 +4,7 @@ name = "wirefilter-ffi-ctests" version = "0.1.0" description = "C based tests for FFI bindings of the Wirefilter engine" publish = false -edition = "2018" +edition = "2021" [dependencies] wirefilter-ffi = {path = "../.."} diff --git a/ffi/tests/ctests/build.rs b/ffi/tests/ctests/build.rs index f80e5dbb..0f12ab38 100644 --- a/ffi/tests/ctests/build.rs +++ b/ffi/tests/ctests/build.rs @@ -1,6 +1,6 @@ -use cc; - fn main() { + println!("cargo::rerun-if-changed=src/tests.c"); + #[cfg(unix)] cc::Build::new() .include("../../include") .file("src/tests.c") diff --git a/ffi/tests/ctests/src/lib.rs b/ffi/tests/ctests/src/lib.rs index ea0fdbbf..26c09c64 100644 --- a/ffi/tests/ctests/src/lib.rs +++ b/ffi/tests/ctests/src/lib.rs @@ -9,6 +9,7 @@ macro_rules! ffi_ctest { (@inner $($name:ident => $link_name:expr,)*) => { $( #[test] + #[cfg_attr(miri, ignore)] pub fn $name() { extern "C" { #[link_name = $link_name] @@ -29,16 +30,28 @@ macro_rules! ffi_ctest { mod ffi_ctest { ffi_ctest!( + create_array_type, + create_map_type, + create_complex_type, create_scheme, add_fields_to_scheme, + add_malloced_type_field_to_scheme, parse_good_filter, parse_bad_filter, filter_uses_field, + filter_uses_list_field, filter_hash, filter_serialize, + scheme_serialize, + type_serialize, compile_filter, create_execution_context, add_values_to_execution_context, + add_values_to_execution_context_errors, + execution_context_serialize, + execution_context_deserialize, match_filter, + match_map, + match_array, ); } diff --git a/ffi/tests/ctests/src/tests.c b/ffi/tests/ctests/src/tests.c index 9cc16d8f..e192c3ed 100644 --- a/ffi/tests/ctests/src/tests.c +++ b/ffi/tests/ctests/src/tests.c @@ -14,26 +14,71 @@ static wirefilter_externally_allocated_str_t wirefilter_string(const char *s) { } void initialize_scheme(wirefilter_scheme_t *scheme) { - wirefilter_add_type_field_to_scheme( + rust_assert(wirefilter_add_type_field_to_scheme( scheme, wirefilter_string("http.host"), WIREFILTER_TYPE_BYTES - ); - wirefilter_add_type_field_to_scheme( + ), "could not add field http.host of type \"Bytes\" to scheme"); + rust_assert(wirefilter_add_type_field_to_scheme( scheme, wirefilter_string("ip.addr"), WIREFILTER_TYPE_IP - ); - wirefilter_add_type_field_to_scheme( + ), "could not add field ip.addr of type \"Ip\" to scheme"); + rust_assert(wirefilter_add_type_field_to_scheme( scheme, wirefilter_string("ssl"), WIREFILTER_TYPE_BOOL - ); - wirefilter_add_type_field_to_scheme( + ), "could not add field ssl of type \"Bool\" to scheme"); + rust_assert(wirefilter_add_type_field_to_scheme( scheme, wirefilter_string("tcp.port"), WIREFILTER_TYPE_INT + ), "could not add field tcp.port of type \"Int\" to scheme"); + wirefilter_add_type_field_to_scheme( + scheme, + wirefilter_string("http.headers"), + wirefilter_create_map_type(WIREFILTER_TYPE_BYTES) ); + rust_assert(wirefilter_add_type_field_to_scheme( + scheme, + wirefilter_string("http.cookies"), + wirefilter_create_array_type(WIREFILTER_TYPE_BYTES) + ), "could not add field http.cookies of type \"Array\" to scheme"); + rust_assert(wirefilter_add_type_list_to_scheme( + scheme, + WIREFILTER_TYPE_IP, + wirefilter_create_always_list() + ), "could not add list for type \"Ip\" to scheme"); +} + +void wirefilter_ffi_ctest_create_array_type() { + wirefilter_type_t array_type = 
wirefilter_create_array_type(WIREFILTER_TYPE_BYTES); + rust_assert(array_type.layers == 0, "could not create valid array type"); + rust_assert(array_type.len == 1, "could not create valid array type"); + rust_assert(array_type.primitive == WIREFILTER_PRIMITIVE_TYPE_BYTES, "could not create valid array type"); +} + +void wirefilter_ffi_ctest_create_map_type() { + wirefilter_type_t map_type = wirefilter_create_map_type(WIREFILTER_TYPE_BYTES); + rust_assert(map_type.layers == 1, "could not create valid map type"); + rust_assert(map_type.len == 1, "could not create valid map type"); + rust_assert(map_type.primitive == WIREFILTER_PRIMITIVE_TYPE_BYTES, "could not create valid map type"); +} + +void wirefilter_ffi_ctest_create_complex_type() { + wirefilter_type_t type = WIREFILTER_TYPE_BYTES; + type = wirefilter_create_map_type(type); + type = wirefilter_create_array_type(type); + rust_assert(type.layers == 2, "could not create valid type"); + rust_assert(type.len == 2, "could not create valid type"); + rust_assert(type.primitive == WIREFILTER_PRIMITIVE_TYPE_BYTES, "could not create valid type"); + + type = WIREFILTER_TYPE_BYTES; + type = wirefilter_create_array_type(type); + type = wirefilter_create_map_type(type); + rust_assert(type.layers == 1, "could not create valid type"); + rust_assert(type.len == 2, "could not create valid type"); + rust_assert(type.primitive == WIREFILTER_PRIMITIVE_TYPE_BYTES, "could not create valid type"); } void wirefilter_ffi_ctest_create_scheme() { @@ -51,6 +96,25 @@ void wirefilter_ffi_ctest_add_fields_to_scheme() { wirefilter_free_scheme(scheme); } +void wirefilter_ffi_ctest_add_malloced_type_field_to_scheme() { + wirefilter_scheme_t *scheme = wirefilter_create_scheme(); + rust_assert(scheme != NULL, "could not create scheme"); + + wirefilter_type_t *byte_type = (wirefilter_type_t *)malloc(sizeof(wirefilter_type_t)); + rust_assert(byte_type != NULL, "could not allocate type"); + *byte_type = WIREFILTER_TYPE_BYTES; + + rust_assert(wirefilter_add_type_field_to_scheme( + scheme, + wirefilter_string("http.host"), + *byte_type + ), "could not add field http.host of type \"Bytes\" to scheme"); + + free(byte_type); + + wirefilter_free_scheme(scheme); +} + void wirefilter_ffi_ctest_parse_good_filter() { wirefilter_scheme_t *scheme = wirefilter_create_scheme(); rust_assert(scheme != NULL, "could not create scheme"); @@ -93,24 +157,68 @@ void wirefilter_ffi_ctest_filter_uses_field() { initialize_scheme(scheme); - wirefilter_parsing_result_t result = wirefilter_parse_filter( + wirefilter_parsing_result_t parsing_result = wirefilter_parse_filter( scheme, wirefilter_string("tcp.port == 80") ); - rust_assert(result.success == 1, "could not parse good filter"); - rust_assert(result.ok.ast != NULL, "could not parse good filter"); + rust_assert(parsing_result.success == 1, "could not parse good filter"); + rust_assert(parsing_result.ok.ast != NULL, "could not parse good filter"); - rust_assert( - wirefilter_filter_uses(result.ok.ast, wirefilter_string("tcp.port")) == true, - "filter should be using field tcp.port" + wirefilter_using_result_t using_result; + + using_result = wirefilter_filter_uses( + parsing_result.ok.ast, + wirefilter_string("tcp.port") ); - rust_assert( - wirefilter_filter_uses(result.ok.ast, wirefilter_string("ip.addr")) == false, - "filter should not be using field ip.addr" + rust_assert(using_result.success == 1, "could not check if filter uses tcp.port field"); + rust_assert(using_result.ok.value == true, "filter should be using field tcp.port"); + + 
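+ /* A field that the filter expression does not reference should still report success, but with ok.value == false. */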
using_result = wirefilter_filter_uses( + parsing_result.ok.ast, + wirefilter_string("ip.addr") ); - wirefilter_free_parsing_result(result); + rust_assert(using_result.success == 1, "could not check if filter uses ip.addr field"); + rust_assert(using_result.ok.value == false, "filter should not be using field ip.addr"); + + wirefilter_free_parsing_result(parsing_result); + + wirefilter_free_scheme(scheme); +} + +void wirefilter_ffi_ctest_filter_uses_list_field() { + wirefilter_scheme_t *scheme = wirefilter_create_scheme(); + rust_assert(scheme != NULL, "could not create scheme"); + + initialize_scheme(scheme); + + wirefilter_parsing_result_t parsing_result = wirefilter_parse_filter( + scheme, + wirefilter_string("ip.addr in $bad") + ); + rust_assert(parsing_result.success == 1, "could not parse good filter"); + rust_assert(parsing_result.ok.ast != NULL, "could not parse good filter"); + + wirefilter_using_result_t using_result; + + using_result = wirefilter_filter_uses_list( + parsing_result.ok.ast, + wirefilter_string("ip.addr") + ); + + rust_assert(using_result.success == 1, "could not check if filter uses tcp.port field"); + rust_assert(using_result.ok.value == true, "filter should be using field ip.addr"); + + using_result = wirefilter_filter_uses_list( + parsing_result.ok.ast, + wirefilter_string("tcp.port") + ); + + rust_assert(using_result.success == 1, "could not check if filter uses tcp.port field"); + rust_assert(using_result.ok.value == false, "filter should not be using field tcp.port"); + + wirefilter_free_parsing_result(parsing_result); wirefilter_free_scheme(scheme); } @@ -135,11 +243,15 @@ void wirefilter_ffi_ctest_filter_hash() { rust_assert(result2.success == 1, "could not parse good filter"); rust_assert(result2.ok.ast != NULL, "could not parse good filter"); - uint64_t hash1 = wirefilter_get_filter_hash(result1.ok.ast); + wirefilter_hashing_result_t hashing_result; - uint64_t hash2 = wirefilter_get_filter_hash(result2.ok.ast); + hashing_result = wirefilter_get_filter_hash(result1.ok.ast); + rust_assert(hashing_result.success == 1, "could not compute hash"); + uint64_t hash1 = hashing_result.ok.hash; - rust_assert(hash1 != 0, "could not compute hash"); + hashing_result = wirefilter_get_filter_hash(result2.ok.ast); + rust_assert(hashing_result.success == 1, "could not compute hash"); + uint64_t hash2 = hashing_result.ok.hash; rust_assert(hash1 == hash2, "both filters should have the same hash"); @@ -163,8 +275,10 @@ void wirefilter_ffi_ctest_filter_serialize() { rust_assert(result.success == 1, "could not parse good filter"); rust_assert(result.ok.ast != NULL, "could not parse good filter"); - wirefilter_rust_allocated_str_t json = wirefilter_serialize_filter_to_json(result.ok.ast); + wirefilter_serializing_result_t serializing_result = wirefilter_serialize_filter_to_json(result.ok.ast); + rust_assert(serializing_result.success == 1, "could not serialize filter to JSON"); + wirefilter_rust_allocated_str_t json = serializing_result.ok.json; rust_assert(json.data != NULL && json.length > 0, "could not serialize filter to JSON"); rust_assert( @@ -179,6 +293,60 @@ void wirefilter_ffi_ctest_filter_serialize() { wirefilter_free_scheme(scheme); } +void wirefilter_ffi_ctest_scheme_serialize() { + wirefilter_scheme_t *scheme = wirefilter_create_scheme(); + rust_assert(scheme != NULL, "could not create scheme"); + + initialize_scheme(scheme); + + wirefilter_serializing_result_t serializing_result = wirefilter_serialize_scheme_to_json(scheme); + rust_assert(serializing_result.success 
== 1, "could not serialize scheme to JSON"); + + wirefilter_rust_allocated_str_t json = serializing_result.ok.json; + rust_assert(json.data != NULL && json.length > 0, "could not serialize scheme to JSON"); + + rust_assert( + strncmp(json.data, "{\"http.host\":\"Bytes\",\"ip.addr\":\"Ip\",\"ssl\":\"Bool\",\"tcp.port\":\"Int\",\"http.headers\":{\"Map\":\"Bytes\"},\"http.cookies\":{\"Array\":\"Bytes\"}}", json.length) == 0, + "invalid JSON serialization" + ); + + wirefilter_free_string(json); + + wirefilter_free_scheme(scheme); +} + +void wirefilter_ffi_ctest_type_serialize() { + wirefilter_serializing_result_t serializing_result = wirefilter_serialize_type_to_json(&WIREFILTER_TYPE_BYTES); + rust_assert(serializing_result.success == 1, "could not serialize type to JSON"); + + wirefilter_rust_allocated_str_t json = serializing_result.ok.json; + rust_assert(json.data != NULL && json.length > 0, "could not serialize type to JSON"); + + rust_assert( + strncmp(json.data, "\"Bytes\"", json.length) == 0, + "invalid JSON serialization" + ); + + wirefilter_free_string(json); + + wirefilter_type_t type = wirefilter_create_map_type( + wirefilter_create_array_type(WIREFILTER_TYPE_BYTES) + ); + + serializing_result = wirefilter_serialize_type_to_json(&type); + rust_assert(serializing_result.success == 1, "could not serialize type to JSON"); + + json = serializing_result.ok.json; + rust_assert(json.data != NULL && json.length > 0, "could not serialize type to JSON"); + + rust_assert( + strncmp(json.data, "{\"Map\":{\"Array\":\"Bytes\"}}", json.length) == 0, + "invalid JSON serialization" + ); + + wirefilter_free_string(json); +} + void wirefilter_ffi_ctest_compile_filter() { wirefilter_scheme_t *scheme = wirefilter_create_scheme(); rust_assert(scheme != NULL, "could not create scheme"); @@ -192,10 +360,11 @@ void wirefilter_ffi_ctest_compile_filter() { rust_assert(result.success == true, "could not parse good filter"); rust_assert(result.ok.ast != NULL, "could not parse good filter"); - wirefilter_filter_t *filter = wirefilter_compile_filter(result.ok.ast); - rust_assert(filter != NULL, "could not compile filter"); + wirefilter_compiling_result_t compiling_result = wirefilter_compile_filter(result.ok.ast); + rust_assert(compiling_result.success == true, "could not compile filter"); + rust_assert(compiling_result.ok.filter != NULL, "could not compile filter"); - wirefilter_free_compiled_filter(filter); + wirefilter_free_compiled_filter(compiling_result.ok.filter); wirefilter_free_scheme(scheme); } @@ -221,6 +390,260 @@ void wirefilter_ffi_ctest_add_values_to_execution_context() { wirefilter_execution_context_t *exec_ctx = wirefilter_create_execution_context(scheme); rust_assert(exec_ctx != NULL, "could not create execution context"); + wirefilter_externally_allocated_byte_arr_t http_host; + http_host.data = (unsigned char *)"www.cloudflare.com"; + http_host.length = strlen((char *)http_host.data); + rust_assert(wirefilter_add_bytes_value_to_execution_context( + exec_ctx, + wirefilter_string("http.host"), + http_host + ) == true, "could not set value for field http.host"); + + uint8_t ipv4_addr[4] = {192, 168, 0, 1}; + rust_assert(wirefilter_add_ipv4_value_to_execution_context( + exec_ctx, + wirefilter_string("ip.addr"), + ipv4_addr + ) == true, "could not set value for field ip.addr"); + + uint8_t ipv6_addr[16] = {20, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + rust_assert(wirefilter_add_ipv4_value_to_execution_context( + exec_ctx, + wirefilter_string("ip.addr"), + ipv6_addr + ) == true, "could not 
set value for field ip.addr"); + + rust_assert(wirefilter_add_bool_value_to_execution_context( + exec_ctx, + wirefilter_string("ssl"), + false + ) == true, "could not set value for field ssl"); + + rust_assert(wirefilter_add_int_value_to_execution_context( + exec_ctx, + wirefilter_string("tcp.port"), + 80 + ) == true, "could not set value for field tcp.port"); + + wirefilter_free_execution_context(exec_ctx); + + wirefilter_free_scheme(scheme); +} + +void wirefilter_ffi_ctest_add_values_to_execution_context_errors() { + wirefilter_scheme_t *scheme = wirefilter_create_scheme(); + rust_assert(scheme != NULL, "could not create scheme"); + + initialize_scheme(scheme); + + wirefilter_execution_context_t *exec_ctx = wirefilter_create_execution_context(scheme); + rust_assert(exec_ctx != NULL, "could not create execution context"); + + wirefilter_externally_allocated_byte_arr_t http_host; + http_host.data = (unsigned char *)"www.cloudflare.com"; + http_host.length = strlen((char *)http_host.data); + rust_assert(wirefilter_add_bytes_value_to_execution_context( + exec_ctx, + wirefilter_string("doesnotexist"), + http_host + ) == false, "managed to set value for non-existent bytes field"); + + uint8_t ipv4_addr[4] = {192, 168, 0, 1}; + rust_assert(wirefilter_add_ipv4_value_to_execution_context( + exec_ctx, + wirefilter_string("doesnotexist"), + ipv4_addr + ) == false, "managed to set value for non-existent ipv4 field"); + + uint8_t ipv6_addr[16] = {20, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + rust_assert(wirefilter_add_ipv6_value_to_execution_context( + exec_ctx, + wirefilter_string("doesnotexist"), + ipv6_addr + ) == false, "managed to set value for non-existent ipv6 field"); + + rust_assert(wirefilter_add_bool_value_to_execution_context( + exec_ctx, + wirefilter_string("doesnotexist"), + false + ) == false, "managed to set value for non-existent bool field"); + + rust_assert(wirefilter_add_int_value_to_execution_context( + exec_ctx, + wirefilter_string("doesnotexist"), + 80 + ) == false, "managed to set value for non-existent int field"); + + wirefilter_map_t *more_http_headers = wirefilter_create_map( + WIREFILTER_TYPE_BYTES + ); + rust_assert(wirefilter_add_map_value_to_execution_context( + exec_ctx, + wirefilter_string("doesnotexist"), + more_http_headers + ) == false, "managed to set value for non-existent map field"); + + wirefilter_array_t *http_cookies = wirefilter_create_array( + WIREFILTER_TYPE_BYTES + ); + rust_assert(wirefilter_add_array_value_to_execution_context( + exec_ctx, + wirefilter_string("doesnotexist"), + http_cookies + ) == false, "managed to set value for non-existent array field"); + + wirefilter_free_execution_context(exec_ctx); + + wirefilter_free_scheme(scheme); +} + +void wirefilter_ffi_ctest_execution_context_serialize() { + wirefilter_scheme_t *scheme = wirefilter_create_scheme(); + rust_assert(scheme != NULL, "could not create scheme"); + + initialize_scheme(scheme); + + wirefilter_execution_context_t *exec_ctx = wirefilter_create_execution_context(scheme); + rust_assert(exec_ctx != NULL, "could not create execution context"); + + wirefilter_externally_allocated_byte_arr_t http_host; + http_host.data = (unsigned char *)"www.cloudflare.com"; + http_host.length = strlen((char *)http_host.data); + rust_assert(wirefilter_add_bytes_value_to_execution_context( + exec_ctx, + wirefilter_string("http.host"), + http_host + ) == true, "could not set value for field http.host"); + + uint8_t ipv4_addr[4] = {192, 168, 0, 1}; + 
rust_assert(wirefilter_add_ipv4_value_to_execution_context( + exec_ctx, + wirefilter_string("ip.addr"), + ipv4_addr + ) == true, "could not set value for field ip.addr"); + + uint8_t ipv6_addr[16] = {20, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + rust_assert(wirefilter_add_ipv4_value_to_execution_context( + exec_ctx, + wirefilter_string("ip.addr"), + ipv6_addr + ) == true, "could not set value for field ip.addr"); + + rust_assert(wirefilter_add_bool_value_to_execution_context( + exec_ctx, + wirefilter_string("ssl"), + false + ) == true, "could not set value for field ssl"); + + rust_assert(wirefilter_add_int_value_to_execution_context( + exec_ctx, + wirefilter_string("tcp.port"), + 80 + ) == true, "could not set value for field tcp.port"); + + wirefilter_serializing_result_t serializing_result = wirefilter_serialize_execution_context_to_json(exec_ctx); + rust_assert(serializing_result.success == 1, "could not serialize execution context to JSON"); + + wirefilter_rust_allocated_str_t json = serializing_result.ok.json; + rust_assert(json.data != NULL && json.length > 0, "could not serialize execution context to JSON"); + + rust_assert( + strncmp(json.data, "{\"http.host\":\"www.cloudflare.com\",\"ip.addr\":\"20.20.0.0\",\"ssl\":false,\"tcp.port\":80,\"$lists\":[]}", json.length) == 0, + "invalid JSON serialization" + ); + + wirefilter_free_string(json); + + wirefilter_free_execution_context(exec_ctx); + + wirefilter_free_scheme(scheme); +} + +void wirefilter_ffi_ctest_execution_context_deserialize() { + wirefilter_scheme_t *scheme = wirefilter_create_scheme(); + rust_assert(scheme != NULL, "could not create scheme"); + + initialize_scheme(scheme); + + wirefilter_execution_context_t *exec_ctx = wirefilter_create_execution_context(scheme); + rust_assert(exec_ctx != NULL, "could not create execution context"); + + wirefilter_externally_allocated_byte_arr_t http_host; + http_host.data = (unsigned char *)"www.cloudflare.com"; + http_host.length = strlen((char *)http_host.data); + rust_assert(wirefilter_add_bytes_value_to_execution_context( + exec_ctx, + wirefilter_string("http.host"), + http_host + ) == true, "could not set value for field http.host"); + + wirefilter_serializing_result_t serializing_result = wirefilter_serialize_execution_context_to_json(exec_ctx); + rust_assert(serializing_result.success == 1, "could not serialize execution context to JSON"); + + wirefilter_rust_allocated_str_t json = serializing_result.ok.json; + rust_assert(json.data != NULL && json.length > 0, "could not serialize execution context to JSON"); + + rust_assert( + strncmp(json.data, "{\"http.host\":\"www.cloudflare.com\",\"$lists\":[]}", json.length) == 0, + "invalid JSON serialization" + ); + + wirefilter_execution_context_t *conv_exec_ctx = wirefilter_create_execution_context(scheme); + rust_assert(conv_exec_ctx != NULL, "could not create execution context"); + + wirefilter_externally_allocated_byte_arr_t serialized_exec_ctx; + serialized_exec_ctx.data = (const unsigned char*)json.data; + serialized_exec_ctx.length = json.length; + + wirefilter_boolean_result_t deserialize_result = wirefilter_deserialize_json_to_execution_context( + conv_exec_ctx, serialized_exec_ctx + ); + rust_assert(deserialize_result.success == 1, "could not deserialize execution context from JSON"); + + wirefilter_serializing_result_t conv_serializing_result = wirefilter_serialize_execution_context_to_json(conv_exec_ctx); + rust_assert(conv_serializing_result.success == 1, "could not serialize execution context to JSON"); + + 
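+ /* Round-trip check: re-serializing the deserialized execution context must produce exactly the same JSON as the original. */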
wirefilter_rust_allocated_str_t conv_json = conv_serializing_result.ok.json; + rust_assert(conv_json.data != NULL && conv_json.length > 0, "could not serialize execution context to JSON"); + + rust_assert( + strncmp(conv_json.data, "{\"http.host\":\"www.cloudflare.com\",\"$lists\":[]}", conv_json.length) == 0, + "invalid JSON serialization" + ); + + wirefilter_free_serializing_result(conv_serializing_result); + + wirefilter_free_string(json); + + wirefilter_free_execution_context(conv_exec_ctx); + + wirefilter_free_execution_context(exec_ctx); + + wirefilter_free_scheme(scheme); +} + + +void wirefilter_ffi_ctest_match_filter() { + wirefilter_scheme_t *scheme = wirefilter_create_scheme(); + rust_assert(scheme != NULL, "could not create scheme"); + + initialize_scheme(scheme); + + wirefilter_parsing_result_t result = wirefilter_parse_filter( + scheme, + wirefilter_string("tcp.port == 80") + ); + rust_assert(result.success == true, "could not parse good filter"); + rust_assert(result.ok.ast != NULL, "could not parse good filter"); + + wirefilter_compiling_result_t compiling_result = wirefilter_compile_filter(result.ok.ast); + rust_assert(compiling_result.success == true, "could not compile filter"); + rust_assert(compiling_result.ok.filter != NULL, "could not compile filter"); + wirefilter_filter_t *filter = compiling_result.ok.filter; + + wirefilter_execution_context_t *exec_ctx = wirefilter_create_execution_context(scheme); + rust_assert(exec_ctx != NULL, "could not create execution context"); + wirefilter_externally_allocated_byte_arr_t http_host; http_host.data = (unsigned char *)"www.cloudflare.com"; http_host.length = strlen((char *)http_host.data); @@ -249,12 +672,19 @@ void wirefilter_ffi_ctest_add_values_to_execution_context() { 80 ); + wirefilter_matching_result_t matching_result = wirefilter_match(filter, exec_ctx); + rust_assert(matching_result.success == 1, "could not match filter"); + + rust_assert(matching_result.ok.value == true, "filter should match"); + wirefilter_free_execution_context(exec_ctx); + wirefilter_free_compiled_filter(filter); + wirefilter_free_scheme(scheme); } -void wirefilter_ffi_ctest_match_filter() { +void wirefilter_ffi_ctest_match_map() { wirefilter_scheme_t *scheme = wirefilter_create_scheme(); rust_assert(scheme != NULL, "could not create scheme"); @@ -262,13 +692,15 @@ void wirefilter_ffi_ctest_match_filter() { wirefilter_parsing_result_t result = wirefilter_parse_filter( scheme, - wirefilter_string("tcp.port == 80") + wirefilter_string("http.headers[\"host\"] == \"www.cloudflare.com\"") ); rust_assert(result.success == true, "could not parse good filter"); rust_assert(result.ok.ast != NULL, "could not parse good filter"); - wirefilter_filter_t *filter = wirefilter_compile_filter(result.ok.ast); - rust_assert(filter != NULL, "could not compile filter"); + wirefilter_compiling_result_t compiling_result = wirefilter_compile_filter(result.ok.ast); + rust_assert(compiling_result.success == true, "could not compile filter"); + rust_assert(compiling_result.ok.filter != NULL, "could not compile filter"); + wirefilter_filter_t *filter = compiling_result.ok.filter; wirefilter_execution_context_t *exec_ctx = wirefilter_create_execution_context(scheme); rust_assert(exec_ctx != NULL, "could not create execution context"); @@ -301,7 +733,121 @@ void wirefilter_ffi_ctest_match_filter() { 80 ); - rust_assert(wirefilter_match(filter, exec_ctx) == true, "could not match filter"); + wirefilter_map_t *http_headers = wirefilter_create_map( + WIREFILTER_TYPE_BYTES + ); + + 
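+ /* Fill in the "host" entry and attach the map to the execution context so that http.headers["host"] == "www.cloudflare.com" can match. */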
rust_assert(wirefilter_add_bytes_value_to_map( + http_headers, + wirefilter_string("host"), + http_host + ), "could not add bytes value to map"); + + rust_assert(wirefilter_add_map_value_to_execution_context( + exec_ctx, + wirefilter_string("http.headers"), + http_headers + ) == true, "could not set value for map field http.headers"); + + wirefilter_matching_result_t matching_result = wirefilter_match(filter, exec_ctx); + rust_assert(matching_result.success == 1, "could not match filter"); + + rust_assert(matching_result.ok.value == true, "filter should match"); + + wirefilter_free_execution_context(exec_ctx); + + wirefilter_free_compiled_filter(filter); + + wirefilter_free_scheme(scheme); +} + +void wirefilter_ffi_ctest_match_array() { + wirefilter_scheme_t *scheme = wirefilter_create_scheme(); + rust_assert(scheme != NULL, "could not create scheme"); + + initialize_scheme(scheme); + + wirefilter_parsing_result_t result = wirefilter_parse_filter( + scheme, + wirefilter_string("http.cookies[2] == \"www.cloudflare.com\"") + ); + rust_assert(result.success == true, "could not parse good filter"); + rust_assert(result.ok.ast != NULL, "could not parse good filter"); + + wirefilter_compiling_result_t compiling_result = wirefilter_compile_filter(result.ok.ast); + rust_assert(compiling_result.success == true, "could not compile filter"); + rust_assert(compiling_result.ok.filter != NULL, "could not compile filter"); + wirefilter_filter_t *filter = compiling_result.ok.filter; + + wirefilter_execution_context_t *exec_ctx = wirefilter_create_execution_context(scheme); + rust_assert(exec_ctx != NULL, "could not create execution context"); + + wirefilter_externally_allocated_byte_arr_t http_host; + http_host.data = (unsigned char *)"www.cloudflare.com"; + http_host.length = strlen((char *)http_host.data); + wirefilter_add_bytes_value_to_execution_context( + exec_ctx, + wirefilter_string("http.host"), + http_host + ); + + uint8_t ip_addr[4] = {192, 168, 0, 1}; + wirefilter_add_ipv4_value_to_execution_context( + exec_ctx, + wirefilter_string("ip.addr"), + ip_addr + ); + + wirefilter_add_bool_value_to_execution_context( + exec_ctx, + wirefilter_string("ssl"), + false + ); + + wirefilter_add_int_value_to_execution_context( + exec_ctx, + wirefilter_string("tcp.port"), + 80 + ); + + wirefilter_array_t *http_cookies = wirefilter_create_array( + WIREFILTER_TYPE_BYTES + ); + + wirefilter_externally_allocated_byte_arr_t http_cookie_one; + http_cookie_one.data = (unsigned char *)"one"; + http_cookie_one.length = strlen((char *)http_cookie_one.data); + rust_assert(wirefilter_add_bytes_value_to_array( + http_cookies, + 0, + http_cookie_one + ), "could not add bytes value to array"); + + wirefilter_externally_allocated_byte_arr_t http_cookie_two; + http_cookie_two.data = (unsigned char *)"two"; + http_cookie_two.length = strlen((char *)http_cookie_two.data); + rust_assert(wirefilter_add_bytes_value_to_array( + http_cookies, + 1, + http_cookie_two + ), "could not add bytes value to array"); + + rust_assert(wirefilter_add_bytes_value_to_array( + http_cookies, + 2, + http_host + ), "could not add bytes value to array"); + + rust_assert(wirefilter_add_array_value_to_execution_context( + exec_ctx, + wirefilter_string("http.cookies"), + http_cookies + ) == true, "could not set value for map field http.cookies"); + + wirefilter_matching_result_t matching_result = wirefilter_match(filter, exec_ctx); + rust_assert(matching_result.success == 1, "could not match filter"); + + rust_assert(matching_result.ok.value == true, 
"filter should match"); wirefilter_free_execution_context(exec_ctx); diff --git a/fuzz/bytes/Cargo.toml b/fuzz/bytes/Cargo.toml new file mode 100644 index 00000000..f4cb3641 --- /dev/null +++ b/fuzz/bytes/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "fuzz-bytes" +version = "0.1.0" +edition = "2021" + +[dependencies] +afl = "0.14" + +[dependencies.wirefilter-engine] +path = "../../engine" + +[lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(fuzzing)'] } diff --git a/fuzz/bytes/in/10px-Malapera_monstro.svg.png b/fuzz/bytes/in/10px-Malapera_monstro.svg.png new file mode 100644 index 00000000..9ad9cd8d Binary files /dev/null and b/fuzz/bytes/in/10px-Malapera_monstro.svg.png differ diff --git a/fuzz/bytes/in/aaaa b/fuzz/bytes/in/aaaa new file mode 100644 index 00000000..5d308e1d --- /dev/null +++ b/fuzz/bytes/in/aaaa @@ -0,0 +1 @@ +aaaa diff --git a/fuzz/bytes/in/empty b/fuzz/bytes/in/empty new file mode 100644 index 00000000..e69de29b diff --git a/fuzz/bytes/in/invalid-utf8 b/fuzz/bytes/in/invalid-utf8 new file mode 100644 index 00000000..12b20f5f --- /dev/null +++ b/fuzz/bytes/in/invalid-utf8 @@ -0,0 +1 @@ +invalid_(_utf8 \ No newline at end of file diff --git a/fuzz/bytes/in/multi-byte b/fuzz/bytes/in/multi-byte new file mode 100644 index 00000000..49685946 --- /dev/null +++ b/fuzz/bytes/in/multi-byte @@ -0,0 +1 @@ +針盗肉前子全静属教歴的闘所殿新厚夢縮上。賞局解長暮高本側版暮断話曲。心地非法傷図物足演得芸社必西動時嶋飛。唆新四助載真宿日商辺洋巨経十。口禁般便題果供供座訂武全国助載信初。北町新坊連提遂三首無動設全情橋少部。学南結第育手内内現細盗仮死奈索等手向安彩。人精本辞打央救日計毎世裏約面無住。寄明角手携意供知要応晴権役秀男動。 \ No newline at end of file diff --git a/fuzz/bytes/src/main.rs b/fuzz/bytes/src/main.rs new file mode 100644 index 00000000..23a9d7e0 --- /dev/null +++ b/fuzz/bytes/src/main.rs @@ -0,0 +1,21 @@ +// This is up here to make the Scheme macro happy +#[cfg(fuzzing)] +use wirefilter::{ExecutionContext, Scheme}; + +#[cfg(fuzzing)] +fn main() { + fuzz!(|data: &[u8]| { + let scheme = Scheme! 
{ foo: Bytes }; + let filter = scheme.parse("foo == \"\"").unwrap().compile(); + let mut ctx = ExecutionContext::new(&scheme); + ctx.set_field_value(scheme.get_field("foo").unwrap(), data) + .unwrap(); + + filter.execute(&ctx).unwrap(); + }); +} + +#[cfg(not(fuzzing))] +fn main() { + panic!("must compile with `cargo afl build`, not `cargo build`") +} diff --git a/fuzz/map-keys/Cargo.toml b/fuzz/map-keys/Cargo.toml new file mode 100644 index 00000000..1e2db6ac --- /dev/null +++ b/fuzz/map-keys/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "fuzz-map-keys" +version = "0.1.0" +edition = "2021" + +[dependencies] +afl = "0.14" + +[dependencies.wirefilter-engine] +path = "../../engine" + +[lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(fuzzing)'] } diff --git a/fuzz/map-keys/in/10px-Malapera_monstro.svg.png b/fuzz/map-keys/in/10px-Malapera_monstro.svg.png new file mode 100644 index 00000000..9ad9cd8d Binary files /dev/null and b/fuzz/map-keys/in/10px-Malapera_monstro.svg.png differ diff --git a/fuzz/map-keys/in/aaaa b/fuzz/map-keys/in/aaaa new file mode 100644 index 00000000..5d308e1d --- /dev/null +++ b/fuzz/map-keys/in/aaaa @@ -0,0 +1 @@ +aaaa diff --git a/fuzz/map-keys/in/empty b/fuzz/map-keys/in/empty new file mode 100644 index 00000000..e69de29b diff --git a/fuzz/map-keys/in/invalid-utf8 b/fuzz/map-keys/in/invalid-utf8 new file mode 100644 index 00000000..12b20f5f --- /dev/null +++ b/fuzz/map-keys/in/invalid-utf8 @@ -0,0 +1 @@ +invalid_(_utf8 \ No newline at end of file diff --git a/fuzz/map-keys/in/multi-byte b/fuzz/map-keys/in/multi-byte new file mode 100644 index 00000000..49685946 --- /dev/null +++ b/fuzz/map-keys/in/multi-byte @@ -0,0 +1 @@ +針盗肉前子全静属教歴的闘所殿新厚夢縮上。賞局解長暮高本側版暮断話曲。心地非法傷図物足演得芸社必西動時嶋飛。唆新四助載真宿日商辺洋巨経十。口禁般便題果供供座訂武全国助載信初。北町新坊連提遂三首無動設全情橋少部。学南結第育手内内現細盗仮死奈索等手向安彩。人精本辞打央救日計毎世裏約面無住。寄明角手携意供知要応晴権役秀男動。 \ No newline at end of file diff --git a/fuzz/map-keys/src/main.rs b/fuzz/map-keys/src/main.rs new file mode 100644 index 00000000..29000333 --- /dev/null +++ b/fuzz/map-keys/src/main.rs @@ -0,0 +1,67 @@ +use std::sync::LazyLock; + +use wirefilter::{ + FunctionArgKind, FunctionArgs, LhsValue, SimpleFunctionDefinition, SimpleFunctionImpl, + SimpleFunctionParam, Type, +}; + +#[cfg(fuzzing)] +fn main() { + use wirefilter::{ExecutionContext, Map, Scheme}; + + fuzz!(|key: &[u8]| { + let mut scheme = Scheme! { foo: Map(Bytes) }; + scheme + .add_function("first".to_string(), FIRST_FN.clone()) + .unwrap(); + + let filter = scheme.parse("first(foo) == \"abc\"").unwrap().compile(); + + let value: &[u8] = b"abc"; + let mut map = Map::new(Type::Bytes); + map.insert(key, LhsValue::Bytes(value.into())).unwrap(); + + let mut ctx = ExecutionContext::new(&scheme); + ctx.set_field_value("foo", map).unwrap(); + + assert!(filter.execute(&ctx).unwrap()); + }); +} + +#[cfg(not(fuzzing))] +fn main() { + panic!("must compile with `cargo afl build`, not `cargo build`") +} + +/// A function which, given a map of bytes, returns the value of the +/// first entry in the map. +/// +/// It expects one argument and will panic if given an incorrect number of +/// arguments or an incorrect LhsValue.
+fn first_impl<'a>(args: FunctionArgs<'_, 'a>) -> Option<LhsValue<'a>> { + let arg = args.next().expect("expected 1 argument, got 0"); + if args.next().is_some() { + panic!("expected 1 argument, got {}", 2 + args.count()); + } + match arg { + Ok(LhsValue::Map(m)) => { + let bytes = m.into_iter().next().unwrap().1; + + Some(bytes) + } + _ => unreachable!(), + } +} + +// FIRST_FN is a function which returns the value of the first entry of the +// map passed to the function. +pub static FIRST_FN: LazyLock<SimpleFunctionDefinition> = + LazyLock::new(|| SimpleFunctionDefinition { + params: vec![SimpleFunctionParam { + arg_kind: FunctionArgKind::Field, + val_type: Type::Map(Type::Bytes.into()), + }], + opt_params: vec![], + return_type: Type::Bytes, + implementation: SimpleFunctionImpl::new(first_impl), + }); diff --git a/fuzz/raw-string/Cargo.toml b/fuzz/raw-string/Cargo.toml new file mode 100644 index 00000000..cf039a3c --- /dev/null +++ b/fuzz/raw-string/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "fuzz-raw-string" +version = "0.1.0" +edition = "2021" + +[dependencies] +afl = "0.14" + +[dependencies.wirefilter-engine] +path = "../../engine" + +[lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(fuzzing)'] } diff --git a/fuzz/raw-string/in/10px-Malapera_monstro.svg.png b/fuzz/raw-string/in/10px-Malapera_monstro.svg.png new file mode 100644 index 00000000..9ad9cd8d Binary files /dev/null and b/fuzz/raw-string/in/10px-Malapera_monstro.svg.png differ diff --git a/fuzz/raw-string/in/aaaa b/fuzz/raw-string/in/aaaa new file mode 100644 index 00000000..5d308e1d --- /dev/null +++ b/fuzz/raw-string/in/aaaa @@ -0,0 +1 @@ +aaaa diff --git a/fuzz/raw-string/in/empty b/fuzz/raw-string/in/empty new file mode 100644 index 00000000..e69de29b diff --git a/fuzz/raw-string/in/invalid-utf8 b/fuzz/raw-string/in/invalid-utf8 new file mode 100644 index 00000000..12b20f5f --- /dev/null +++ b/fuzz/raw-string/in/invalid-utf8 @@ -0,0 +1 @@ +invalid_(_utf8 \ No newline at end of file diff --git a/fuzz/raw-string/in/multi-byte b/fuzz/raw-string/in/multi-byte new file mode 100644 index 00000000..49685946 --- /dev/null +++ b/fuzz/raw-string/in/multi-byte @@ -0,0 +1 @@ +針盗肉前子全静属教歴的闘所殿新厚夢縮上。賞局解長暮高本側版暮断話曲。心地非法傷図物足演得芸社必西動時嶋飛。唆新四助載真宿日商辺洋巨経十。口禁般便題果供供座訂武全国助載信初。北町新坊連提遂三首無動設全情橋少部。学南結第育手内内現細盗仮死奈索等手向安彩。人精本辞打央救日計毎世裏約面無住。寄明角手携意供知要応晴権役秀男動。 \ No newline at end of file diff --git a/fuzz/raw-string/src/main.rs b/fuzz/raw-string/src/main.rs new file mode 100644 index 00000000..ed229ece --- /dev/null +++ b/fuzz/raw-string/src/main.rs @@ -0,0 +1,32 @@ +// This is up here to make the Scheme macro happy +#[cfg(fuzzing)] +use wirefilter::Scheme; + +#[cfg(fuzzing)] +fn main() { + fuzz!(|data: &[u8]| { + let scheme = Scheme!
{ foo: Bytes }; + let mut input = String::from(String::from_utf8_lossy(data)); + + input = input.replace("\"###", "x"); + let f = format!("foo == r###\"{}\"###", &input); + scheme.parse(&f).unwrap().compile(); + + input = input.replace("\"##", "x"); + let f = format!("foo == r##\"{}\"##", &input); + scheme.parse(&f).unwrap().compile(); + + input = input.replace("\"#", "x"); + let f = format!("foo == r#\"{}\"#", &input); + scheme.parse(&f).unwrap().compile(); + + input = input.replace("\"", "x"); + let f = format!("foo == r\"{}\"", &input); + scheme.parse(&f).unwrap().compile(); + }); +} + +#[cfg(not(fuzzing))] +fn main() { + panic!("must compile with `cargo afl build`, not `cargo build`") +} diff --git a/wasm/Cargo.toml b/wasm/Cargo.toml index 2c4c9f0e..146c391d 100644 --- a/wasm/Cargo.toml +++ b/wasm/Cargo.toml @@ -4,7 +4,7 @@ name = "wirefilter-wasm" version = "0.7.0" description = "WebAssembly bindings for the Wirefilter engine" publish = false -edition = "2018" +edition = "2021" [lib] crate-type = ["cdylib"] @@ -12,6 +12,8 @@ crate-type = ["cdylib"] doctest = false [dependencies] -js-sys = "0.3.5" -wasm-bindgen = { version = "0.2.28", features = ["serde-serialize"] } +getrandom = { version = "0.2", features = ["js"] } +js-sys = "0.3.41" +wasm-bindgen = { version = "0.2", features = ["serde-serialize"] } wirefilter-engine = { path = "../engine", default-features = false } +serde-wasm-bindgen = "0.5.0" diff --git a/wasm/src/lib.rs b/wasm/src/lib.rs index f764e5c4..283bbd83 100644 --- a/wasm/src/lib.rs +++ b/wasm/src/lib.rs @@ -11,12 +11,14 @@ fn into_js_error(err: impl std::error::Error) -> JsValue { #[wasm_bindgen] impl Scheme { #[wasm_bindgen(constructor)] - pub fn try_from(fields: &JsValue) -> Result { - fields.into_serde().map(Scheme).map_err(into_js_error) + pub fn try_from(fields: JsValue) -> Result { + serde_wasm_bindgen::from_value(fields) + .map(Scheme) + .map_err(into_js_error) } pub fn parse(&self, s: &str) -> Result { let filter = self.0.parse(s).map_err(into_js_error)?; - JsValue::from_serde(&filter).map_err(into_js_error) + serde_wasm_bindgen::to_value(&filter).map_err(into_js_error) } } diff --git a/wirefilter-parser/Cargo.toml b/wirefilter-parser/Cargo.toml deleted file mode 100644 index fa58faea..00000000 --- a/wirefilter-parser/Cargo.toml +++ /dev/null @@ -1,14 +0,0 @@ -[package] -name = "wirefilter-parser" -version = "0.1.0" -authors = ["Ivan Nikulin "] -edition = "2018" - -[dependencies] -cidr = "0.1.0" -pest_consume = "1.0.4" -pest = "2.1.3" -regex = { version = "1.3.7", default-features = false, features = ["std", "perf"] } - -[dev-dependencies] -indoc = "0.3.5" \ No newline at end of file diff --git a/wirefilter-parser/src/ast.rs b/wirefilter-parser/src/ast.rs deleted file mode 100644 index fc11cfff..00000000 --- a/wirefilter-parser/src/ast.rs +++ /dev/null @@ -1,74 +0,0 @@ -use cidr::{Ipv4Cidr, Ipv6Cidr}; -use regex::bytes::RegexBuilder; -use std::borrow::Cow; -use std::net::{Ipv4Addr, Ipv6Addr}; -use std::ops::{Deref, RangeInclusive}; -use std::str::FromStr; - -#[derive(Debug, PartialEq)] -pub struct Var<'i>(pub Cow<'i, str>); - -#[derive(Debug)] -pub struct Regex(regex::bytes::Regex); - -#[derive(Debug, PartialEq)] -pub enum Rhs<'i> { - Int(i32), - IntRange(RangeInclusive), - String(Cow<'i, [u8]>), - Ipv4(Ipv4Addr), - Ipv6(Ipv6Addr), - Ipv4Range(RangeInclusive), - Ipv6Range(RangeInclusive), - Ipv4Cidr(Ipv4Cidr), - Ipv6Cidr(Ipv6Cidr), - Regex(Regex), -} - -#[derive(Debug, PartialEq)] -pub enum BinOp { - Eq, - NotEq, - GreaterOrEq, - LessOrEq, - Greater, - Less, - 
BitwiseAnd, - Contains, - Matches, - In, -} - -#[derive(Debug, PartialEq)] -pub enum Expr<'i> { - Unary(Var<'i>), - Binary { - lhs: Var<'i>, - op: BinOp, - rhs: Rhs<'i>, - }, -} - -impl PartialEq for Regex { - #[inline] - fn eq(&self, other: &Self) -> bool { - self.0.as_str() == other.0.as_str() - } -} - -impl Deref for Regex { - type Target = regex::bytes::Regex; - - #[inline] - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl FromStr for Regex { - type Err = regex::Error; - - fn from_str(s: &str) -> Result { - RegexBuilder::new(s).unicode(false).build().map(Regex) - } -} diff --git a/wirefilter-parser/src/grammar.pest b/wirefilter-parser/src/grammar.pest deleted file mode 100644 index 1d9b8811..00000000 --- a/wirefilter-parser/src/grammar.pest +++ /dev/null @@ -1,96 +0,0 @@ -// Identifiers -//============================================================ -ident = _{ ASCII_ALPHA ~ ASCII_ALPHANUMERIC* } -var = @{ ident ~ ("." ~ ident)* } - - -// Rhs -//============================================================ -// NOTE: unfortunately there is an ambiguity between IP literals and int literal. -// Though, in the worst case we'll backtrack only 4 characters. -rhs = { - ipv4_range | ipv4_cidr | ipv4_lit | ipv6_range | ipv6_cidr | ipv6_lit | - int_range | int_lit | str_lit | re_lit -} - -// Int literal -int_lit = ${ "-"? ~ digits } -digits = _{ oct_digits | ( "0x" ~ hex_digits ) | dec_digits } -hex_digits = { ASCII_HEX_DIGIT+ } -// NOTE: we need to include 0, so i32::from_str_radix can parse it properly -oct_digits = { "0" ~ ASCII_OCT_DIGIT+ } -dec_digits = { ASCII_DIGIT+ } - -// Int range -int_range = ${ int_lit ~ ".." ~ int_lit } - -// String -str_lit = ${ "\"" ~ str_content ~ "\"" } -str_content = _{ ( text | ( esc ~ text? ) )* } -text = { (!("\"" | "\\") ~ ANY)+ } -esc = _{ "\\" ~ ( esc_alias | ( "x" ~ esc_hex_byte ) ) } -esc_alias = { "\"" | "\\" | "n" | "r" | "t" } -esc_hex_byte = { ASCII_HEX_DIGIT{2} } - -// IP -ipv4_lit = @{ ASCII_DIGIT{1,3} ~ ( "." ~ ASCII_DIGIT{1,3} ){3} } -// NOTE: this is far from being precise IPv6 grammar, but it's not ambigious with IPv4 and -// int literal. Actual parsing is performed by Rust's std lib. -ipv6_lit = @{ - ( ":" | ASCII_ALPHANUMERIC{1,4} ) ~ ":" ~ ( ipv4_lit | ASCII_ALPHANUMERIC{1,4} | ":" )* -} - -// IP ranges -ipv4_cidr = @{ ipv4_lit ~ "/" ~ ASCII_DIGIT{1,2} } -ipv6_cidr = @{ ipv6_lit ~ "/" ~ ASCII_DIGIT{1,3} } -ipv4_range = ${ ipv4_lit ~ ".." ~ ipv4_lit } -ipv6_range = ${ ipv6_lit ~ ".." 
~ ipv6_lit } - -// Regex -re_lit = ${ "/" ~ re_content ~ "/" } -re_content = { ( re_ch_gr | re_esc | re_unesc )+ } -re_unesc = _{ ( !( "/" | "\\" | "[" ) ~ ANY )+ } -re_esc = _{ "\\" ~ ANY } -re_ch_gr = _{ "[" ~ ( re_esc | re_ch_gr_unesc )* ~ "]" } -re_ch_gr_unesc = _{ ( !( "]" | "\\" ) ~ ANY )+ } - - -// Logical operators -//============================================================ -logical_op = { op_or | op_and | op_xor } - -op_or = { "||" | "or" } -op_and = { "&&" | "and" } -op_xor = { "^^" | "xor" } - - -// Binary operators -//============================================================ -bin_op = { - eq_op | ne_op | ge_op | le_op | gt_op | lt_op | band_op | contains_op | matches_op | in_op -} - -eq_op = { "==" | "eq" } -ne_op = { "!=" | "ne" } -ge_op = { ">=" | "ge" } -le_op = { "<=" | "le" } -gt_op = { ">" | "gt" } -lt_op = { "<" | "lt" } -band_op = { "&" | "bitwise_and" } -contains_op = { "contains" } -matches_op = { "~" | "matches" } -in_op = { "in" } - - -// Expression -//============================================================ -filter = { SOI ~ compound_expr ~ EOI } - -compound_expr = { term ~ ( logical_op ~ term )* } -term = _{ expr | "(" ~ compound_expr ~ ")" } -expr = { var ~ (bin_op ~ rhs)? } - - -// Trivia -//============================================================ -WHITESPACE = _{ " " | NEWLINE } diff --git a/wirefilter-parser/src/lib.rs b/wirefilter-parser/src/lib.rs deleted file mode 100644 index 0c5604a2..00000000 --- a/wirefilter-parser/src/lib.rs +++ /dev/null @@ -1,627 +0,0 @@ -pub mod ast; -mod semantics; - -use cidr::{Ipv4Cidr, Ipv6Cidr}; -use pest::error::ErrorVariant; -use pest_consume::{match_nodes, Error as ParseError, Parser as PestParser}; -use semantics::ValidateSemantics; -use std::borrow::Cow; -use std::net::{Ipv4Addr, Ipv6Addr}; -use std::ops::RangeInclusive; - -#[derive(PestParser)] -#[grammar = "./grammar.pest"] -pub struct Parser; - -pub type ParseResult<T> = Result<T, ParseError<Rule>>; -pub type Node<'i> = pest_consume::Node<'i, Rule, ()>; - -trait IntoParseResult<T> { - fn into_parse_result(self, node: &Node) -> ParseResult<T>; -} - -impl<T, E> IntoParseResult<T> for Result<T, E> -where - E: ToString, -{ - fn into_parse_result(self, node: &Node) -> ParseResult<T> { - self.map_err(|e| { - let span = node.as_span(); - - let err_var = ErrorVariant::CustomError { - message: e.to_string(), - }; - - ParseError::new_from_span(err_var, span) - }) - } -} - -macro_rules! parse_num { - ($node:expr, $ty:ident, $radix:expr) => { - $ty::from_str_radix($node.as_str(), $radix).into_parse_result(&$node) - }; -} - -macro_rules! parse_type { - ($node:expr, $ty:ident) => { - $node.as_str().parse::<$ty>().into_parse_result(&$node) - }; -} - -macro_rules! parse_range { - ($node:expr, $lit_ty:ident) => { - match_nodes!
{ - $node.children(); - [$lit_ty(l1), $lit_ty(l2)] => (l1..=l2).validate_semantics().into_parse_result(&$node) - } - }; -} - -#[pest_consume::parser] -impl Parser { - fn var(node: Node) -> ParseResult<ast::Var> { - // TODO check in scheme - Ok(ast::Var(node.as_str().into())) - } - - fn int_lit(node: Node) -> ParseResult<i32> { - use Rule::*; - - let digits_node = node.children().single().unwrap(); - - let radix = match digits_node.as_rule() { - hex_digits => 16, - oct_digits => 8, - dec_digits => 10, - _ => unreachable!(), - }; - - let mut num = parse_num!(digits_node, i32, radix)?; - - if let Some('-') = node.as_str().chars().next() { - num = -num; - } - - Ok(num) - } - - fn esc_alias(node: Node) -> ParseResult<u8> { - Ok(match node.as_str() { - "\"" => b'"', - "\\" => b'\\', - "n" => b'\n', - "r" => b'\r', - "t" => b'\t', - _ => unreachable!(), - }) - } - - fn str_lit(node: Node) -> ParseResult<Cow<[u8]>> { - use Rule::*; - - let content = node.into_children().collect::<Vec<_>>(); - - // NOTE: if there are no escapes then we can avoid allocating. - if content.len() == 1 && matches!(content[0].as_rule(), Rule::text) { - return Ok(content[0].as_str().as_bytes().into()); - } - - let mut s = Vec::new(); - - for node in content { - match node.as_rule() { - text => s.extend_from_slice(node.as_str().as_bytes()), - esc_alias => s.push(Parser::esc_alias(node)?), - esc_hex_byte => s.push(parse_num!(node, u8, 16)?), - _ => unreachable!(), - } - } - - Ok(s.into()) - } - - fn int_range(node: Node) -> ParseResult<RangeInclusive<i32>> { - parse_range!(node, int_lit) - } - - #[inline] - fn ipv4_lit(node: Node) -> ParseResult<Ipv4Addr> { - parse_type!(node, Ipv4Addr) - } - - #[inline] - fn ipv6_lit(node: Node) -> ParseResult<Ipv6Addr> { - parse_type!(node, Ipv6Addr) - } - - #[inline] - fn ipv4_cidr(node: Node) -> ParseResult<Ipv4Cidr> { - parse_type!(node, Ipv4Cidr) - } - - #[inline] - fn ipv6_cidr(node: Node) -> ParseResult<Ipv6Cidr> { - parse_type!(node, Ipv6Cidr) - } - - #[inline] - fn ipv4_range(node: Node) -> ParseResult<RangeInclusive<Ipv4Addr>> { - parse_range!(node, ipv4_lit) - } - - #[inline] - fn ipv6_range(node: Node) -> ParseResult<RangeInclusive<Ipv6Addr>> { - parse_range!(node, ipv6_lit) - } - - fn re_lit(node: Node) -> ParseResult<ast::Regex> { - node.children() - .single() - .unwrap() - .as_str() - .parse() - .into_parse_result(&node) - } - - fn rhs(node: Node) -> ParseResult<ast::Rhs> { - Ok(match_nodes! { - node.children(); - [int_lit(i)] => ast::Rhs::Int(i), - [int_range(r)] => ast::Rhs::IntRange(r), - [str_lit(s)] => ast::Rhs::String(s), - [ipv4_lit(i)] => ast::Rhs::Ipv4(i), - [ipv6_lit(i)] => ast::Rhs::Ipv6(i), - [ipv4_cidr(c)] => ast::Rhs::Ipv4Cidr(c), - [ipv6_cidr(c)] => ast::Rhs::Ipv6Cidr(c), - [ipv4_range(r)] => ast::Rhs::Ipv4Range(r), - [ipv6_range(r)] => ast::Rhs::Ipv6Range(r), - [re_lit(r)] => ast::Rhs::Regex(r) - }) - } - - fn bin_op(node: Node) -> ParseResult<ast::BinOp> { - use ast::BinOp::*; - use Rule::*; - - let op = node.children().single().unwrap().as_rule(); - - Ok(match op { - eq_op => Eq, - ne_op => NotEq, - ge_op => GreaterOrEq, - le_op => LessOrEq, - gt_op => Greater, - lt_op => Less, - band_op => BitwiseAnd, - contains_op => Contains, - matches_op => Matches, - in_op => In, - _ => unreachable!(), - }) - } - - fn expr(node: Node) -> ParseResult<ast::Expr> { - // TODO type checks - Ok(match_nodes! { - node.children(); - [var(var), bin_op(op), rhs(rhs)] => ast::Expr::Binary {lhs: var, op, rhs}, - [var(var)] => ast::Expr::Unary(var) - }) - } -} - -#[cfg(test)] -#[allow(clippy::string_lit_as_bytes)] -mod tests { - use super::*; - use cidr::Cidr as _; - use indoc::indoc; - - macro_rules!
parse { - ($rule:ident, $input:expr) => { - Parser::parse(Rule::$rule, $input) - .and_then(|p| p.single()) - .and_then(Parser::$rule) - }; - } - - macro_rules! ok { - ($rule:ident $input:expr => $expected:expr) => { - assert_eq!(parse!($rule, $input), Ok($expected)); - }; - } - - macro_rules! err { - ($rule:ident $input:expr => $expected:expr) => { - assert_eq!( - parse!($rule, $input).unwrap_err().to_string(), - indoc!($expected) - ); - }; - } - - #[test] - fn parse_var() { - ok! { var "foo" => ast::Var("foo".into()) } - ok! { var "f1o2o3" => ast::Var("f1o2o3".into()) } - ok! { var "f1o2o3.bar321" => ast::Var("f1o2o3.bar321".into()) } - ok! { var "foo.bar.baz" => ast::Var("foo.bar.baz".into()) } - - err! { var "123foo" => - " --> 1:1 - | - 1 | 123foo - | ^--- - | - = expected var" - } - } - - #[test] - fn parse_int_lit() { - ok! { int_lit "42" => 42 } - ok! { int_lit "-42" => -42 } - ok! { int_lit "0x2A" => 42 } - ok! { int_lit "-0x2a" => -42 } - ok! { int_lit "052" => 42 } - ok! { int_lit "-052" => -42 } - - err! { int_lit "-abc" => - " --> 1:2 - | - 1 | -abc - | ^--- - | - = expected oct_digits or dec_digits" - } - - err! { int_lit "99999999999999999999999999999" => - " --> 1:1 - | - 1 | 99999999999999999999999999999 - | ^---------------------------^ - | - = number too large to fit in target type" - } - } - - #[test] - fn parse_int_range() { - ok! { int_range "42..0x2b" => 42..=43 } - ok! { int_range "-0x2a..0x2A" => -42..=42 } - ok! { int_range "42..42" => 42..=42 } - - err! { int_range "42.. 43" => - " --> 1:5 - | - 1 | 42.. 43 - | ^--- - | - = expected int_lit" - } - - err! { int_range "45..42" => - " --> 1:1 - | - 1 | 45..42 - | ^----^ - | - = start of the range is greater than the end" - } - - err! { int_range "42..z" => - " --> 1:5 - | - 1 | 42..z - | ^--- - | - = expected int_lit" - } - } - - #[test] - fn parse_str_lit() { - ok! { str_lit r#""""# => "".as_bytes().into() } - ok! { str_lit r#""foobar baz qux""# => "foobar baz qux".as_bytes().into() } - ok! { str_lit r#""foo \x41\x42 bar\x43""# => "foo AB barC".as_bytes().into() } - - ok! { - str_lit r#""\n foo \t\r \\ baz \" bar ""# => - "\n foo \t\r \\ baz \" bar ".as_bytes().into() - } - - err! { str_lit r#""foobar \i""# => - r#" --> 1:10 - | - 1 | "foobar \i" - | ^--- - | - = expected esc_alias"# - } - - err! { str_lit r#""foobar \x3z""# => - r#" --> 1:11 - | - 1 | "foobar \x3z" - | ^--- - | - = expected esc_hex_byte"# - } - } - - #[test] - fn parse_bin_op() { - ok! { bin_op "==" => ast::BinOp::Eq } - ok! { bin_op "eq" => ast::BinOp::Eq } - ok! { bin_op "!=" => ast::BinOp::NotEq } - ok! { bin_op "ne" => ast::BinOp::NotEq } - ok! { bin_op ">=" => ast::BinOp::GreaterOrEq } - ok! { bin_op "ge" => ast::BinOp::GreaterOrEq } - ok! { bin_op "<=" => ast::BinOp::LessOrEq } - ok! { bin_op "le" => ast::BinOp::LessOrEq } - ok! { bin_op ">" => ast::BinOp::Greater } - ok! { bin_op "gt" => ast::BinOp::Greater } - ok! { bin_op "<" => ast::BinOp::Less } - ok! { bin_op "lt" => ast::BinOp::Less } - ok! { bin_op "&" => ast::BinOp::BitwiseAnd } - ok! { bin_op "bitwise_and" => ast::BinOp::BitwiseAnd } - ok! { bin_op "contains" => ast::BinOp::Contains } - ok! { bin_op "~" => ast::BinOp::Matches } - ok! { bin_op "matches" => ast::BinOp::Matches } - ok! { bin_op "in" => ast::BinOp::In } - } - - #[test] - fn pare_expr() { - ok! { expr "foo.bar.baz" => ast::Expr::Unary(ast::Var("foo.bar.baz".into())) } - - ok! 
{ - expr "foo.bar.baz > 42" => - ast::Expr::Binary { - lhs: ast::Var("foo.bar.baz".into()), - op: ast::BinOp::Greater, - rhs: ast::Rhs::Int(42) - } - } - - ok! { - expr "foo.bar.baz in 32..42" => - ast::Expr::Binary { - lhs: ast::Var("foo.bar.baz".into()), - op: ast::BinOp::In, - rhs: ast::Rhs::IntRange(32..=42) - } - } - - ok! { - expr "foo == 220.12.13.1" => - ast::Expr::Binary { - lhs: ast::Var("foo".into()), - op: ast::BinOp::Eq, - rhs: ast::Rhs::Ipv4(Ipv4Addr::new(220, 12, 13, 1)) - } - } - - ok! { - expr "foo in 220.12.13.1..220.12.13.2" => - ast::Expr::Binary { - lhs: ast::Var("foo".into()), - op: ast::BinOp::In, - rhs: ast::Rhs::Ipv4Range( - Ipv4Addr::new(220, 12, 13, 1)..=Ipv4Addr::new(220, 12, 13, 2) - ) - } - } - - ok! { - expr "foo in 192.0.0.0/16" => - ast::Expr::Binary { - lhs: ast::Var("foo".into()), - op: ast::BinOp::In, - rhs: ast::Rhs::Ipv4Cidr( - Ipv4Cidr::new(Ipv4Addr::new(192, 0, 0, 0), 16).unwrap() - ) - } - } - - ok! { - expr "foo in ::1/128" => - ast::Expr::Binary { - lhs: ast::Var("foo".into()), - op: ast::BinOp::In, - rhs: ast::Rhs::Ipv6Cidr( - Ipv6Cidr::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 128).unwrap() - ) - } - } - - ok! { - expr "foo == 2001:db8::1" => - ast::Expr::Binary { - lhs: ast::Var("foo".into()), - op: ast::BinOp::Eq, - rhs: ast::Rhs::Ipv6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)) - } - } - - ok! { - expr r#"foo.bar == "test\n""# => - ast::Expr::Binary { - lhs: ast::Var("foo.bar".into()), - op: ast::BinOp::Eq, - rhs: ast::Rhs::String("test\n".as_bytes().into()) - } - } - } - - #[test] - fn parse_ipv4_lit() { - ok! { ipv4_lit "127.0.0.1" => Ipv4Addr::new(127, 0, 0, 1) } - ok! { ipv4_lit "192.0.2.235" => Ipv4Addr::new(192, 0, 2, 235) } - - err! { ipv4_lit "127.0.0.a" => - " --> 1:1 - | - 1 | 127.0.0.a - | ^--- - | - = expected ipv4_lit" - } - - err! { ipv4_lit "300.0.0.1" => - " --> 1:1 - | - 1 | 300.0.0.1 - | ^-------^ - | - = invalid IP address syntax" - } - } - - #[test] - fn parse_ipv6_lit() { - ok! { ipv6_lit "::" => Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0) } - ok! { ipv6_lit "::1" => Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1) } - ok! { ipv6_lit "2001:db8::1" => Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1) } - ok! { ipv6_lit "2001:db8::1" => Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1) } - - ok! { - ipv6_lit "::ffff:255.255.255.255" => - Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xffff, 0xffff) - } - - err! { ipv6_lit "2001:dz8::1" => - " --> 1:1 - | - 1 | 2001:dz8::1 - | ^---------^ - | - = invalid IP address syntax" - } - } - - #[test] - fn parse_ipv4_cidr() { - ok! { - ipv4_cidr "127.0.0.1/32" => - Ipv4Cidr::new(Ipv4Addr::new(127, 0, 0, 1), 32).unwrap() - } - - ok! { - ipv4_cidr "192.0.0.0/16" => - Ipv4Cidr::new(Ipv4Addr::new(192, 0, 0, 0), 16).unwrap() - } - - err! { ipv4_cidr "192.0.0.0/99" => - " --> 1:1 - | - 1 | 192.0.0.0/99 - | ^----------^ - | - = invalid length for network: Network length 99 is too long for Ipv4 (maximum: 32)" - } - - err! { ipv4_cidr "192.0.0.1/8" => - " --> 1:1 - | - 1 | 192.0.0.1/8 - | ^---------^ - | - = host part of address was not zero" - } - } - - #[test] - fn parse_ipv6_cidr() { - ok! { - ipv6_cidr "::1/128" => - Ipv6Cidr::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 128).unwrap() - } - - ok! { - ipv6_cidr "::/10" => - Ipv6Cidr::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0), 10).unwrap() - } - - err! { ipv6_cidr "::1/560" => - " --> 1:1 - | - 1 | ::1/560 - | ^-----^ - | - = couldn't parse length in network: number too large to fit in target type" - } - } - - #[test] - fn parse_ipv4_range() { - ok! 
{ - ipv4_range "127.0.0.1..127.0.0.128" => - Ipv4Addr::new(127, 0, 0, 1)..=Ipv4Addr::new(127, 0, 0, 128) - } - - ok! { - ipv4_range "192.0.2.235..192.1.2.235" => - Ipv4Addr::new(192, 0, 2, 235)..=Ipv4Addr::new(192, 1, 2, 235) - } - - err! { ipv4_range "192.0.2.235..192.1.2.a" => - " --> 1:14 - | - 1 | 192.0.2.235..192.1.2.a - | ^--- - | - = expected ipv4_lit" - } - - err! { ipv4_range "192.0.2.235..192.0.2.128" => - " --> 1:1 - | - 1 | 192.0.2.235..192.0.2.128 - | ^----------------------^ - | - = start of the range is greater than the end" - } - } - - #[test] - fn parse_ipv6_range() { - ok! { - ipv6_range "::1..::2" => - Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)..=Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2) - } - - ok! { - ipv6_range "2001:db8::1..2001:db8::ff" => - Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1) - ..=Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0xff) - } - - err! { ipv6_range "2001:db8::1..2001:dz8::ff" => - " --> 1:14 - | - 1 | 2001:db8::1..2001:dz8::ff - | ^----------^ - | - = invalid IP address syntax" - } - - err! { ipv6_range "2001:db8::ff..2001:db8::11" => - " --> 1:1 - | - 1 | 2001:db8::ff..2001:db8::11 - | ^------------------------^ - | - = start of the range is greater than the end" - } - } - - #[test] - fn parse_re_lit() { - ok! { - re_lit r#"/[-]?[0-9]+[,.]?[0-9]*([\/][0-9]+[,.]?[0-9]*)*/"# => - r#"[-]?[0-9]+[,.]?[0-9]*([\/][0-9]+[,.]?[0-9]*)*"#.parse().unwrap() - } - } -} diff --git a/wirefilter-parser/src/semantics.rs b/wirefilter-parser/src/semantics.rs deleted file mode 100644 index 47c43bc0..00000000 --- a/wirefilter-parser/src/semantics.rs +++ /dev/null @@ -1,15 +0,0 @@ -use std::ops::RangeInclusive; - -pub trait ValidateSemantics: Sized { - fn validate_semantics(self) -> Result<Self, &'static str>; -} - -impl<T: PartialOrd> ValidateSemantics for RangeInclusive<T> { - fn validate_semantics(self) -> Result<Self, &'static str> { - if self.start() > self.end() { - Err("start of the range is greater than the end") - } else { - Ok(self) - } - } -}