diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 9c4cda5d034d..ffde5378da93 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -6,10 +6,17 @@ updates: interval: daily open-pull-requests-limit: 10 target-branch: master - labels: [auto-dependencies] + labels: [ auto-dependencies, arrow ] + - package-ecosystem: cargo + directory: "/object_store" + schedule: + interval: daily + open-pull-requests-limit: 10 + target-branch: master + labels: [ auto-dependencies, object_store ] - package-ecosystem: "github-actions" directory: "/" schedule: interval: "daily" open-pull-requests-limit: 10 - labels: [auto-dependencies] + labels: [ auto-dependencies ] diff --git a/.github/workflows/arrow.yml b/.github/workflows/arrow.yml index 279e276a7912..da56c23b5cd9 100644 --- a/.github/workflows/arrow.yml +++ b/.github/workflows/arrow.yml @@ -39,6 +39,7 @@ on: - arrow-integration-test/** - arrow-ipc/** - arrow-json/** + - arrow-avro/** - arrow-ord/** - arrow-row/** - arrow-schema/** @@ -55,7 +56,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain @@ -78,10 +79,12 @@ jobs: run: cargo test -p arrow-csv --all-features - name: Test arrow-json with all features run: cargo test -p arrow-json --all-features + - name: Test arrow-avro with all features + run: cargo test -p arrow-avro --all-features - name: Test arrow-string with all features run: cargo test -p arrow-string --all-features - - name: Test arrow-ord with all features except SIMD - run: cargo test -p arrow-ord --features dyn_cmp_dict + - name: Test arrow-ord with all features + run: cargo test -p arrow-ord --all-features - name: Test arrow-arith with all features except SIMD run: cargo test -p arrow-arith - name: Test arrow-row with all features @@ -91,7 +94,7 @@ jobs: - name: Test arrow with default features run: cargo test -p arrow - name: Test arrow with all features apart from simd - run: cargo test -p arrow --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict,chrono-tz + run: cargo test -p arrow --features=force_validate,prettyprint,ipc_compression,ffi,chrono-tz - name: Run examples run: | # Test arrow examples @@ -109,7 +112,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain @@ -136,7 +139,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain @@ -145,8 +148,6 @@ jobs: rust-version: nightly - name: Test arrow-array with SIMD run: cargo test -p arrow-array --features simd - - name: Test arrow-ord with SIMD - run: cargo test -p arrow-ord --features simd - name: Test arrow-arith with SIMD run: cargo test -p arrow-arith --features simd - name: Test arrow with SIMD @@ -162,7 +163,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain @@ -181,7 +182,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Setup Clippy @@ -204,16 +205,18 @@ jobs: run: cargo clippy -p arrow-csv --all-targets --all-features -- -D warnings - name: Clippy arrow-json with all features run: cargo clippy -p arrow-json --all-targets --all-features -- -D warnings + - name: Clippy arrow-avro with 
all features + run: cargo clippy -p arrow-avro --all-targets --all-features -- -D warnings - name: Clippy arrow-string with all features run: cargo clippy -p arrow-string --all-targets --all-features -- -D warnings - - name: Clippy arrow-ord with all features except SIMD - run: cargo clippy -p arrow-ord --all-targets --features dyn_cmp_dict -- -D warnings + - name: Clippy arrow-ord with all features + run: cargo clippy -p arrow-ord --all-targets --all-features -- -D warnings - name: Clippy arrow-arith with all features except SIMD run: cargo clippy -p arrow-arith --all-targets -- -D warnings - name: Clippy arrow-row with all features run: cargo clippy -p arrow-row --all-targets --all-features -- -D warnings - name: Clippy arrow with all features except SIMD - run: cargo clippy -p arrow --features=prettyprint,csv,ipc,test_utils,ffi,ipc_compression,dyn_cmp_dict,chrono-tz --all-targets -- -D warnings + run: cargo clippy -p arrow --features=prettyprint,csv,ipc,test_utils,ffi,ipc_compression,chrono-tz --all-targets -- -D warnings - name: Clippy arrow-integration-test with all features run: cargo clippy -p arrow-integration-test --all-targets --all-features -- -D warnings - name: Clippy arrow-integration-testing with all features diff --git a/.github/workflows/arrow_flight.yml b/.github/workflows/arrow_flight.yml index 5301a3f8563f..242e0f2a3b0d 100644 --- a/.github/workflows/arrow_flight.yml +++ b/.github/workflows/arrow_flight.yml @@ -47,7 +47,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain @@ -68,7 +68,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Run gen @@ -82,7 +82,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Setup Clippy diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 3fa254142dbe..64b2ca437067 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -36,7 +36,7 @@ jobs: # otherwise we get this error: # Failed to run tests: ASLR disable failed: EPERM: Operation not permitted steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index 0eb2d024f352..9871f8b7d295 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -38,7 +38,7 @@ jobs: name: Release Audit Tool (RAT) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Python uses: actions/setup-python@v4 with: @@ -50,7 +50,7 @@ jobs: name: Markdown format runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-node@v3 with: node-version: "14" diff --git a/.github/workflows/dev_pr.yml b/.github/workflows/dev_pr.yml index daa5d6a76c52..5f3d9e54c8db 100644 --- a/.github/workflows/dev_pr.yml +++ b/.github/workflows/dev_pr.yml @@ -37,14 +37,14 @@ jobs: contents: read pull-requests: write steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Assign GitHub labels if: | github.event_name == 'pull_request_target' && (github.event.action == 'opened' || github.event.action == 'synchronize') - uses: actions/labeler@v4.2.0 + uses: 
actions/labeler@v4.3.0 with: repo-token: ${{ secrets.GITHUB_TOKEN }} configuration-path: .github/workflows/dev_pr/labeler.yml diff --git a/.github/workflows/dev_pr/labeler.yml b/.github/workflows/dev_pr/labeler.yml index e5b86e8bcdf0..ea5873081f18 100644 --- a/.github/workflows/dev_pr/labeler.yml +++ b/.github/workflows/dev_pr/labeler.yml @@ -27,6 +27,7 @@ arrow: - arrow-integration-testing/**/* - arrow-ipc/**/* - arrow-json/**/* + - arrow-avro/**/* - arrow-ord/**/* - arrow-row/**/* - arrow-schema/**/* diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 7e80aea6b978..721260892402 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -43,13 +43,13 @@ jobs: env: RUSTDOCFLAGS: "-Dwarnings --enable-index-page -Zunstable-options" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Install python dev run: | apt update - apt install -y libpython3.9-dev + apt install -y libpython3.11-dev - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -64,7 +64,7 @@ jobs: echo "::warning title=Invalid file permissions automatically fixed::$line" done - name: Upload artifacts - uses: actions/upload-pages-artifact@v1 + uses: actions/upload-pages-artifact@v2 with: name: crate-docs path: target/doc @@ -77,7 +77,7 @@ jobs: contents: write runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Download crate docs uses: actions/download-artifact@v3 with: diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 9b2e7797d5ff..62d2d2cb1a06 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -38,6 +38,7 @@ on: - arrow-integration-testing/** - arrow-ipc/** - arrow-json/** + - arrow-avro/** - arrow-ord/** - arrow-pyarrow-integration-testing/** - arrow-schema/** @@ -56,6 +57,7 @@ jobs: env: ARROW_USE_CCACHE: OFF ARROW_CPP_EXE_PATH: /build/cpp/debug + ARROW_GO_INTEGRATION: 1 BUILD_DOCS_CPP: OFF # These are necessary because the github runner overrides $HOME # https://github.com/actions/runner/issues/863 @@ -76,16 +78,20 @@ jobs: - name: Check cmake run: which cmake - name: Checkout Arrow - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: repository: apache/arrow submodules: true fetch-depth: 0 - name: Checkout Arrow Rust - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: path: rust fetch-depth: 0 + - name: Install pythonnet + run: conda run --no-capture-output pip install pythonnet + - name: Install archery + run: conda run --no-capture-output pip install -e dev/archery[integration] - name: Make build directory run: mkdir /build - name: Build Rust @@ -100,12 +106,12 @@ jobs: run: conda run --no-capture-output ci/scripts/java_build.sh $PWD /build - name: Build JS run: conda run --no-capture-output ci/scripts/js_build.sh $PWD /build - - name: Install archery - run: conda run --no-capture-output pip install -e dev/archery - name: Run integration tests run: | conda run --no-capture-output archery integration \ --run-flight \ + --run-c-data \ + --run-ipc \ --with-cpp=1 \ --with-csharp=1 \ --with-java=1 \ @@ -127,7 +133,7 @@ jobs: matrix: rust: [ stable ] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain diff --git a/.github/workflows/miri.sh b/.github/workflows/miri.sh index 3323bd0996bf..ec8712660c74 100755 --- a/.github/workflows/miri.sh +++ b/.github/workflows/miri.sh @@ -5,11 +5,7 @@ # Must be run with nightly 
rust for example # rustup default nightly - -# stacked borrows checking uses too much memory to run successfully in github actions -# re-enable if the CI is migrated to something more powerful (https://github.com/apache/arrow-rs/issues/1833) -# see also https://github.com/rust-lang/miri/issues/1367 -export MIRIFLAGS="-Zmiri-disable-isolation -Zmiri-disable-stacked-borrows" +export MIRIFLAGS="-Zmiri-disable-isolation" cargo miri setup cargo clean @@ -18,3 +14,5 @@ cargo miri test -p arrow-buffer cargo miri test -p arrow-data --features ffi cargo miri test -p arrow-schema --features ffi cargo miri test -p arrow-array +cargo miri test -p arrow-arith --features simd +cargo miri test -p arrow-ord diff --git a/.github/workflows/miri.yaml b/.github/workflows/miri.yaml index 0c1f8069cd40..19b432121b6f 100644 --- a/.github/workflows/miri.yaml +++ b/.github/workflows/miri.yaml @@ -36,6 +36,7 @@ on: - arrow-data/** - arrow-ipc/** - arrow-json/** + - arrow-avro/** - arrow-schema/** - arrow-select/** - arrow-string/** @@ -46,7 +47,7 @@ jobs: name: MIRI runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain diff --git a/.github/workflows/object_store.yml b/.github/workflows/object_store.yml index 5ae9d2d9c83f..1b991e33c097 100644 --- a/.github/workflows/object_store.yml +++ b/.github/workflows/object_store.yml @@ -43,7 +43,7 @@ jobs: run: working-directory: object_store steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Setup Clippy @@ -60,11 +60,31 @@ jobs: run: cargo clippy --features gcp -- -D warnings - name: Run clippy with azure feature run: cargo clippy --features azure -- -D warnings + - name: Run clippy with http feature + run: cargo clippy --features http -- -D warnings - name: Run clippy with all features run: cargo clippy --all-features -- -D warnings - name: Run clippy with all features and all targets run: cargo clippy --all-features --all-targets -- -D warnings + # test doc links still work + # + # Note that since object_store is not part of the main workspace, + # this needs a separate docs job as it is not covered by + # `cargo doc --workspace` + docs: + name: Rustdocs + runs-on: ubuntu-latest + defaults: + run: + working-directory: object_store + env: + RUSTDOCFLAGS: "-Dwarnings" + steps: + - uses: actions/checkout@v4 + - name: Run cargo doc + run: cargo doc --document-private-items --no-deps --workspace --all-features + # test the crate # This runs outside a container to workaround lack of support for passing arguments # to service containers - https://github.com/orgs/community/discussions/26688 @@ -82,15 +102,22 @@ jobs: # Run integration tests TEST_INTEGRATION: 1 EC2_METADATA_ENDPOINT: http://localhost:1338 - AZURE_USE_EMULATOR: "1" + AZURE_CONTAINER_NAME: test-bucket + AZURE_STORAGE_USE_EMULATOR: "1" AZURITE_BLOB_STORAGE_URL: "http://localhost:10000" AZURITE_QUEUE_STORAGE_URL: "http://localhost:10001" + AWS_BUCKET: test-bucket + AWS_DEFAULT_REGION: "us-east-1" + AWS_ACCESS_KEY_ID: test + AWS_SECRET_ACCESS_KEY: test + AWS_ENDPOINT: http://localhost:4566 + AWS_ALLOW_HTTP: true HTTP_URL: "http://localhost:8080" + GOOGLE_BUCKET: test-bucket GOOGLE_SERVICE_ACCOUNT: "/tmp/gcs.json" - OBJECT_STORE_BUCKET: test-bucket steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Configure Fake GCS Server (GCP emulation) # Custom image - see fsouza/fake-gcs-server#1164 @@ -99,17 +126,12 @@ jobs: # Give 
the container a moment to start up prior to configuring it sleep 1 curl -v -X POST --data-binary '{"name":"test-bucket"}' -H "Content-Type: application/json" "http://localhost:4443/storage/v1/b" - echo '{"gcs_base_url": "http://localhost:4443", "disable_oauth": true, "client_email": "", "private_key": ""}' > "$GOOGLE_SERVICE_ACCOUNT" + echo '{"gcs_base_url": "http://localhost:4443", "disable_oauth": true, "client_email": "", "private_key": "", "private_key_id": ""}' > "$GOOGLE_SERVICE_ACCOUNT" - name: Setup WebDav run: docker run -d -p 8080:80 rclone/rclone serve webdav /data --addr :80 - name: Setup LocalStack (AWS emulation) - env: - AWS_DEFAULT_REGION: "us-east-1" - AWS_ACCESS_KEY_ID: test - AWS_SECRET_ACCESS_KEY: test - AWS_ENDPOINT: http://localhost:4566 run: | docker run -d -p 4566:4566 localstack/localstack:2.0 docker run -d -p 1338:1338 amazon/amazon-ec2-metadata-mock:v1.9.2 --imdsv2 @@ -128,11 +150,6 @@ jobs: rustup default stable - name: Run object_store tests - env: - OBJECT_STORE_AWS_DEFAULT_REGION: "us-east-1" - OBJECT_STORE_AWS_ACCESS_KEY_ID: test - OBJECT_STORE_AWS_SECRET_ACCESS_KEY: test - OBJECT_STORE_AWS_ENDPOINT: http://localhost:4566 run: cargo test --features=aws,azure,gcp,http # test the object_store crate builds against wasm32 in stable rust @@ -145,7 +162,7 @@ jobs: run: working-directory: object_store steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain @@ -155,4 +172,4 @@ jobs: - name: Build wasm32-unknown-unknown run: cargo build --target wasm32-unknown-unknown - name: Build wasm32-wasi - run: cargo build --target wasm32-wasi \ No newline at end of file + run: cargo build --target wasm32-wasi diff --git a/.github/workflows/parquet.yml b/.github/workflows/parquet.yml index 55599b776c32..d664a0dc0730 100644 --- a/.github/workflows/parquet.yml +++ b/.github/workflows/parquet.yml @@ -40,6 +40,7 @@ on: - arrow-ipc/** - arrow-csv/** - arrow-json/** + - arrow-avro/** - parquet/** - .github/** @@ -51,7 +52,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain @@ -74,7 +75,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain @@ -116,17 +117,19 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: target: wasm32-unknown-unknown,wasm32-wasi + - name: Install clang # Needed for zlib compilation + run: apt-get update && apt-get install -y clang gcc-multilib - name: Build wasm32-unknown-unknown - run: cargo build -p parquet --no-default-features --features cli,snap,flate2,brotli --target wasm32-unknown-unknown + run: cargo build -p parquet --target wasm32-unknown-unknown - name: Build wasm32-wasi - run: cargo build -p parquet --no-default-features --features cli,snap,flate2,brotli --target wasm32-wasi + run: cargo build -p parquet --target wasm32-wasi pyspark-integration-test: name: PySpark Integration Test @@ -135,7 +138,7 @@ jobs: matrix: rust: [ stable ] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Python uses: actions/setup-python@v4 with: @@ -168,7 +171,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Rust toolchain uses: 
./.github/actions/setup-builder - name: Setup Clippy diff --git a/.github/workflows/parquet_derive.yml b/.github/workflows/parquet_derive.yml index 72b90ecfd81a..d8b02f73a8aa 100644 --- a/.github/workflows/parquet_derive.yml +++ b/.github/workflows/parquet_derive.yml @@ -43,7 +43,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain @@ -57,7 +57,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Setup Clippy diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index e09e898fe160..9c4b28b691b7 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -37,7 +37,7 @@ jobs: name: Test on Mac runs-on: macos-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Install protoc with brew @@ -60,7 +60,7 @@ jobs: name: Test on Windows runs-on: windows-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Install protobuf compiler in /d/protoc @@ -93,10 +93,41 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Setup rustfmt run: rustup component add rustfmt - - name: Run + - name: Format arrow run: cargo fmt --all -- --check + - name: Format object_store + working-directory: object_store + run: cargo fmt --all -- --check + + msrv: + name: Verify MSRV + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v4 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + - name: Install cargo-msrv + run: cargo install cargo-msrv + - name: Check arrow + working-directory: arrow + run: cargo msrv verify + - name: Check parquet + working-directory: parquet + run: cargo msrv verify + - name: Check arrow-flight + working-directory: arrow-flight + run: cargo msrv verify + - name: Downgrade object_store dependencies + working-directory: object_store + # Necessary because 1.30.0 updates MSRV to 1.63 + run: cargo update -p tokio --precise 1.29.1 + - name: Check object_store + working-directory: object_store + run: cargo msrv verify diff --git a/CHANGELOG-old.md b/CHANGELOG-old.md index 295728a67d3a..cde9b8f3b521 100644 --- a/CHANGELOG-old.md +++ b/CHANGELOG-old.md @@ -19,6 +19,324 @@ # Historical Changelog +## [47.0.0](https://github.com/apache/arrow-rs/tree/47.0.0) (2023-09-19) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/46.0.0...47.0.0) + +**Breaking changes:** + +- Make FixedSizeBinaryArray value\_data return a reference [\#4820](https://github.com/apache/arrow-rs/issues/4820) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Update prost to v0.12.1 [\#4825](https://github.com/apache/arrow-rs/pull/4825) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- feat: FixedSizeBinaryArray::value\_data return reference [\#4821](https://github.com/apache/arrow-rs/pull/4821) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Stateless Row Encoding / Don't Preserve Dictionaries in `RowConverter` \(\#4811\) [\#4819](https://github.com/apache/arrow-rs/pull/4819) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- fix: entries field is non-nullable [\#4808](https://github.com/apache/arrow-rs/pull/4808) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Fix flight sql do put handling, add bind parameter support to FlightSQL cli client [\#4797](https://github.com/apache/arrow-rs/pull/4797) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([suremarc](https://github.com/suremarc)) +- Remove unused dyn\_cmp\_dict feature [\#4766](https://github.com/apache/arrow-rs/pull/4766) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add underlying `std::io::Error` to `IoError` and add `IpcError` variant [\#4726](https://github.com/apache/arrow-rs/pull/4726) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alexandreyc](https://github.com/alexandreyc)) + +**Implemented enhancements:** + +- Row Format Adapative Block Size [\#4812](https://github.com/apache/arrow-rs/issues/4812) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Stateless Row Conversion [\#4811](https://github.com/apache/arrow-rs/issues/4811) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Add option to specify custom null values for CSV reader [\#4794](https://github.com/apache/arrow-rs/issues/4794) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet::record::RowIter cannot be customized with batch\_size and defaults to 1024 [\#4782](https://github.com/apache/arrow-rs/issues/4782) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `DynScalar` abstraction \(something that makes it easy to create scalar `Datum`s\) [\#4781](https://github.com/apache/arrow-rs/issues/4781) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `Datum` is not exported as part of `arrow` \(it is only exported in `arrow_array`\) [\#4780](https://github.com/apache/arrow-rs/issues/4780) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `Scalar` is not exported as part of `arrow` \(it is only exported in `arrow_array`\) [\#4779](https://github.com/apache/arrow-rs/issues/4779) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support IntoPyArrow for impl RecordBatchReader [\#4730](https://github.com/apache/arrow-rs/issues/4730) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Datum Based String Kernels [\#4595](https://github.com/apache/arrow-rs/issues/4595) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Fixed bugs:** + +- MapArray::new\_from\_strings creates nullable entries field [\#4807](https://github.com/apache/arrow-rs/issues/4807) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- pyarrow module can't roundtrip tensor arrays [\#4805](https://github.com/apache/arrow-rs/issues/4805) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `concat_batches` errors with "schema mismatch" error when only metadata differs [\#4799](https://github.com/apache/arrow-rs/issues/4799) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- panic in `cmp` 
kernels with DictionaryArrays: `Option::unwrap()` on a `None` value' [\#4788](https://github.com/apache/arrow-rs/issues/4788) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- stream ffi panics if schema metadata values aren't valid utf8 [\#4750](https://github.com/apache/arrow-rs/issues/4750) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Regression: Incorrect Sorting of `*ListArray` in 46.0.0 [\#4746](https://github.com/apache/arrow-rs/issues/4746) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Row is no longer comparable after reuse [\#4741](https://github.com/apache/arrow-rs/issues/4741) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- DoPut FlightSQL handler inadvertently consumes schema at start of Request\\> [\#4658](https://github.com/apache/arrow-rs/issues/4658) +- Return error when converting schema [\#4752](https://github.com/apache/arrow-rs/pull/4752) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Implement PyArrowType for `Box` [\#4751](https://github.com/apache/arrow-rs/pull/4751) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) + +**Closed issues:** + +- Building arrow-rust for target wasm32-wasi falied to compile packed\_simd\_2 [\#4717](https://github.com/apache/arrow-rs/issues/4717) + +**Merged pull requests:** + +- Respect FormatOption::nulls for NullArray [\#4836](https://github.com/apache/arrow-rs/pull/4836) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix merge\_dictionary\_values in selection kernels [\#4833](https://github.com/apache/arrow-rs/pull/4833) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix like scalar null [\#4832](https://github.com/apache/arrow-rs/pull/4832) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- More chrono deprecations [\#4822](https://github.com/apache/arrow-rs/pull/4822) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Adaptive Row Block Size \(\#4812\) [\#4818](https://github.com/apache/arrow-rs/pull/4818) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update proc-macro2 requirement from =1.0.66 to =1.0.67 [\#4816](https://github.com/apache/arrow-rs/pull/4816) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Do not check schema for equality in concat\_batches [\#4815](https://github.com/apache/arrow-rs/pull/4815) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- fix: export record batch through stream [\#4806](https://github.com/apache/arrow-rs/pull/4806) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Improve CSV Reader Benchmark Coverage of Small Primitives [\#4803](https://github.com/apache/arrow-rs/pull/4803) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- csv: Add option to specify custom null values [\#4795](https://github.com/apache/arrow-rs/pull/4795) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([vrongmeal](https://github.com/vrongmeal)) +- Expand docstring and add example 
to `Scalar` [\#4793](https://github.com/apache/arrow-rs/pull/4793) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Re-export array crate root \(\#4780\) \(\#4779\) [\#4791](https://github.com/apache/arrow-rs/pull/4791) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix DictionaryArray::normalized\_keys \(\#4788\) [\#4789](https://github.com/apache/arrow-rs/pull/4789) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Allow custom tree builder for parquet::record::RowIter [\#4783](https://github.com/apache/arrow-rs/pull/4783) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([YuraKotov](https://github.com/YuraKotov)) +- Bump actions/checkout from 3 to 4 [\#4767](https://github.com/apache/arrow-rs/pull/4767) ([dependabot[bot]](https://github.com/apps/dependabot)) +- fix: avoid panic if offset index not exists. [\#4761](https://github.com/apache/arrow-rs/pull/4761) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([RinChanNOWWW](https://github.com/RinChanNOWWW)) +- Relax constraints on PyArrowType [\#4757](https://github.com/apache/arrow-rs/pull/4757) ([tustvold](https://github.com/tustvold)) +- Chrono deprecations [\#4748](https://github.com/apache/arrow-rs/pull/4748) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix List Sorting, Revert Removal of Rank Kernels [\#4747](https://github.com/apache/arrow-rs/pull/4747) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Clear row buffer before reuse [\#4742](https://github.com/apache/arrow-rs/pull/4742) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([yjshen](https://github.com/yjshen)) +- Datum based like kernels \(\#4595\) [\#4732](https://github.com/apache/arrow-rs/pull/4732) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- feat: expose DoGet response headers & trailers [\#4727](https://github.com/apache/arrow-rs/pull/4727) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- Cleanup length and bit\_length kernels [\#4718](https://github.com/apache/arrow-rs/pull/4718) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +## [46.0.0](https://github.com/apache/arrow-rs/tree/46.0.0) (2023-08-21) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/45.0.0...46.0.0) + +**Breaking changes:** + +- API improvement: `batches_to_flight_data` forces clone [\#4656](https://github.com/apache/arrow-rs/issues/4656) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add AnyDictionary Abstraction and Take ArrayRef in DictionaryArray::with\_values [\#4707](https://github.com/apache/arrow-rs/pull/4707) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Cleanup parquet type builders [\#4706](https://github.com/apache/arrow-rs/pull/4706) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Take kernel dyn Array [\#4705](https://github.com/apache/arrow-rs/pull/4705) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([tustvold](https://github.com/tustvold)) +- Improve ergonomics of Scalar [\#4704](https://github.com/apache/arrow-rs/pull/4704) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Datum based comparison kernels \(\#4596\) [\#4701](https://github.com/apache/arrow-rs/pull/4701) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Improve `Array` Logical Nullability [\#4691](https://github.com/apache/arrow-rs/pull/4691) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Validate ArrayData Buffer Alignment and Automatically Align IPC buffers \(\#4255\) [\#4681](https://github.com/apache/arrow-rs/pull/4681) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- More intuitive bool-to-string casting [\#4666](https://github.com/apache/arrow-rs/pull/4666) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([fsdvh](https://github.com/fsdvh)) +- enhancement: batches\_to\_flight\_data use a schema ref as param. [\#4665](https://github.com/apache/arrow-rs/pull/4665) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([jackwener](https://github.com/jackwener)) +- fix: from\_thrift avoid panic when stats in invalid. [\#4642](https://github.com/apache/arrow-rs/pull/4642) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jackwener](https://github.com/jackwener)) +- bug: Add some missing field in row group metadata: ordinal, total co… [\#4636](https://github.com/apache/arrow-rs/pull/4636) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([liurenjie1024](https://github.com/liurenjie1024)) +- Remove deprecated limit kernel [\#4597](https://github.com/apache/arrow-rs/pull/4597) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- parquet: support setting the field\_id with an ArrowWriter [\#4702](https://github.com/apache/arrow-rs/issues/4702) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support references in i256 arithmetic ops [\#4694](https://github.com/apache/arrow-rs/issues/4694) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Precision-Loss Decimal Arithmetic [\#4664](https://github.com/apache/arrow-rs/issues/4664) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Faster i256 Division [\#4663](https://github.com/apache/arrow-rs/issues/4663) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `concat_batches` for 0 columns [\#4661](https://github.com/apache/arrow-rs/issues/4661) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `filter_record_batch` should support filtering record batch without columns [\#4647](https://github.com/apache/arrow-rs/issues/4647) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve speed of `lexicographical_partition_ranges` [\#4614](https://github.com/apache/arrow-rs/issues/4614) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- object\_store: multipart ranges for HTTP [\#4612](https://github.com/apache/arrow-rs/issues/4612) +- Add Rank Function 
[\#4606](https://github.com/apache/arrow-rs/issues/4606) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Datum Based Comparison Kernels [\#4596](https://github.com/apache/arrow-rs/issues/4596) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Convenience method to create `DataType::List` correctly [\#4544](https://github.com/apache/arrow-rs/issues/4544) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Remove Deprecated Arithmetic Kernels [\#4481](https://github.com/apache/arrow-rs/issues/4481) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Equality kernel where null==null gives true [\#4438](https://github.com/apache/arrow-rs/issues/4438) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Parquet ArrowWriter Ignores Nulls in Dictionary Values [\#4690](https://github.com/apache/arrow-rs/issues/4690) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Schema Nullability Validation Fails to Account for Dictionary Nulls [\#4689](https://github.com/apache/arrow-rs/issues/4689) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Comparison Kernels Ignore Nulls in Dictionary Values [\#4688](https://github.com/apache/arrow-rs/issues/4688) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Casting List to String Ignores Format Options [\#4669](https://github.com/apache/arrow-rs/issues/4669) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Double free in C Stream Interface [\#4659](https://github.com/apache/arrow-rs/issues/4659) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- CI Failing On Packed SIMD [\#4651](https://github.com/apache/arrow-rs/issues/4651) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `RowInterner::size()` much too low for high cardinality dictionary columns [\#4645](https://github.com/apache/arrow-rs/issues/4645) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Decimal PrimitiveArray change datatype after try\_unary [\#4644](https://github.com/apache/arrow-rs/issues/4644) +- Better explanation in docs for Dictionary field encoding using RowConverter [\#4639](https://github.com/apache/arrow-rs/issues/4639) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `List(FixedSizeBinary)` array equality check may return wrong result [\#4637](https://github.com/apache/arrow-rs/issues/4637) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `arrow::compute::nullif` panics if `NullArray` is provided [\#4634](https://github.com/apache/arrow-rs/issues/4634) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Empty lists in FixedSizeListArray::try\_new is not handled [\#4623](https://github.com/apache/arrow-rs/issues/4623) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Bounds checking in `MutableBuffer::set_null_bits` can be bypassed [\#4620](https://github.com/apache/arrow-rs/issues/4620) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- TypedDictionaryArray Misleading Null Behaviour [\#4616](https://github.com/apache/arrow-rs/issues/4616) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- bug: Parquet writer missing row group metadata fields such as `compressed_size`, `file offset`. [\#4610](https://github.com/apache/arrow-rs/issues/4610) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `new_null_array` generates an invalid union array [\#4600](https://github.com/apache/arrow-rs/issues/4600) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Footer parsing fails for very large parquet file. [\#4592](https://github.com/apache/arrow-rs/issues/4592) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- bug\(parquet\): Disabling global statistics but enabling for particular column breaks reading [\#4587](https://github.com/apache/arrow-rs/issues/4587) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `arrow::compute::concat` panics for dense union arrays with non-trivial type IDs [\#4578](https://github.com/apache/arrow-rs/issues/4578) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- \[object\_store\] when Create a AmazonS3 instance work with MinIO without set endpoint got error MissingRegion [\#4617](https://github.com/apache/arrow-rs/issues/4617) + +**Merged pull requests:** + +- Add distinct kernels \(\#960\) \(\#4438\) [\#4716](https://github.com/apache/arrow-rs/pull/4716) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update parquet object\_store 0.7 [\#4715](https://github.com/apache/arrow-rs/pull/4715) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Support Field ID in ArrowWriter \(\#4702\) [\#4710](https://github.com/apache/arrow-rs/pull/4710) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Remove rank kernels [\#4703](https://github.com/apache/arrow-rs/pull/4703) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support references in i256 arithmetic ops [\#4692](https://github.com/apache/arrow-rs/pull/4692) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Cleanup DynComparator \(\#2654\) [\#4687](https://github.com/apache/arrow-rs/pull/4687) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Separate metadata fetch from `ArrowReaderBuilder` construction \(\#4674\) [\#4676](https://github.com/apache/arrow-rs/pull/4676) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- cleanup some assert\(\) with error propagation [\#4673](https://github.com/apache/arrow-rs/pull/4673) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zeevm](https://github.com/zeevm)) +- Faster i256 Division \(2-100x\) \(\#4663\) [\#4672](https://github.com/apache/arrow-rs/pull/4672) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix MSRV CI [\#4671](https://github.com/apache/arrow-rs/pull/4671) ([tustvold](https://github.com/tustvold)) +- Fix equality of nested nullable FixedSizeBinary \(\#4637\) [\#4670](https://github.com/apache/arrow-rs/pull/4670) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use ArrayFormatter in cast kernel [\#4668](https://github.com/apache/arrow-rs/pull/4668) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Minor: Improve API docs for FlightSQL metadata builders [\#4667](https://github.com/apache/arrow-rs/pull/4667) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Support `concat_batches` for 0 columns [\#4662](https://github.com/apache/arrow-rs/pull/4662) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- fix ownership of c stream error [\#4660](https://github.com/apache/arrow-rs/pull/4660) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Minor: Fix illustration for dict encoding [\#4657](https://github.com/apache/arrow-rs/pull/4657) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([JayjeetAtGithub](https://github.com/JayjeetAtGithub)) +- minor: move comment to the correct location [\#4655](https://github.com/apache/arrow-rs/pull/4655) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- Update packed\_simd and run miri tests on simd code [\#4654](https://github.com/apache/arrow-rs/pull/4654) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- impl `From>` for `BufferBuilder` and `MutableBuffer` [\#4650](https://github.com/apache/arrow-rs/pull/4650) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Filter record batch with 0 columns [\#4648](https://github.com/apache/arrow-rs/pull/4648) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Account for child `Bucket` size in OrderPreservingInterner [\#4646](https://github.com/apache/arrow-rs/pull/4646) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Implement `Default`,`Extend` and `FromIterator` for `BufferBuilder` [\#4638](https://github.com/apache/arrow-rs/pull/4638) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- fix\(select\): handle `NullArray` in `nullif` [\#4635](https://github.com/apache/arrow-rs/pull/4635) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kawadakk](https://github.com/kawadakk)) +- Move `BufferBuilder` to `arrow-buffer` [\#4630](https://github.com/apache/arrow-rs/pull/4630) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- allow zero sized empty fixed [\#4626](https://github.com/apache/arrow-rs/pull/4626) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([smiklos](https://github.com/smiklos)) +- fix: compute\_dictionary\_mapping use wrong offsetSize [\#4625](https://github.com/apache/arrow-rs/pull/4625) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- impl `FromIterator` for `MutableBuffer` [\#4624](https://github.com/apache/arrow-rs/pull/4624) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- expand docs for FixedSizeListArray [\#4622](https://github.com/apache/arrow-rs/pull/4622) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([smiklos](https://github.com/smiklos)) +- fix\(buffer\): panic on end index overflow in `MutableBuffer::set_null_bits` 
[\#4621](https://github.com/apache/arrow-rs/pull/4621) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kawadakk](https://github.com/kawadakk)) +- impl `Default` for `arrow_buffer::buffer::MutableBuffer` [\#4619](https://github.com/apache/arrow-rs/pull/4619) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Minor: improve docs and add example for lexicographical\_partition\_ranges [\#4615](https://github.com/apache/arrow-rs/pull/4615) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Cleanup sort [\#4613](https://github.com/apache/arrow-rs/pull/4613) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add rank function \(\#4606\) [\#4609](https://github.com/apache/arrow-rs/pull/4609) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add more docs and examples for ListArray and OffsetsBuffer [\#4607](https://github.com/apache/arrow-rs/pull/4607) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Simplify dictionary sort [\#4605](https://github.com/apache/arrow-rs/pull/4605) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Consolidate sort benchmarks [\#4604](https://github.com/apache/arrow-rs/pull/4604) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Don't Reorder Nulls in sort\_to\_indices \(\#4545\) [\#4603](https://github.com/apache/arrow-rs/pull/4603) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- fix\(data\): create child arrays of correct length when building a sparse union null array [\#4601](https://github.com/apache/arrow-rs/pull/4601) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kawadakk](https://github.com/kawadakk)) +- Use u32 metadata\_len when parsing footer of parquet. 
[\#4599](https://github.com/apache/arrow-rs/pull/4599) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Berrysoft](https://github.com/Berrysoft)) +- fix\(data\): map type ID to child index before indexing a union child array [\#4598](https://github.com/apache/arrow-rs/pull/4598) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kawadakk](https://github.com/kawadakk)) +- Remove deprecated arithmetic kernels \(\#4481\) [\#4594](https://github.com/apache/arrow-rs/pull/4594) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Test Disabled Page Statistics \(\#4587\) [\#4589](https://github.com/apache/arrow-rs/pull/4589) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Cleanup ArrayData::buffers [\#4583](https://github.com/apache/arrow-rs/pull/4583) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use contains\_nulls in ArrayData equality of byte arrays [\#4582](https://github.com/apache/arrow-rs/pull/4582) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Vectorized lexicographical\_partition\_ranges \(~80% faster\) [\#4575](https://github.com/apache/arrow-rs/pull/4575) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- chore: add datatype new\_list [\#4561](https://github.com/apache/arrow-rs/pull/4561) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([fansehep](https://github.com/fansehep)) +## [45.0.0](https://github.com/apache/arrow-rs/tree/45.0.0) (2023-07-30) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/44.0.0...45.0.0) + +**Breaking changes:** + +- Fix timezoned timestamp arithmetic [\#4546](https://github.com/apache/arrow-rs/pull/4546) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alexandreyc](https://github.com/alexandreyc)) + +**Implemented enhancements:** + +- Use FormatOptions in Const Contexts [\#4580](https://github.com/apache/arrow-rs/issues/4580) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Human Readable Duration Display [\#4554](https://github.com/apache/arrow-rs/issues/4554) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `BooleanBuilder`: Add `validity_slice` method for accessing validity bits [\#4535](https://github.com/apache/arrow-rs/issues/4535) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `FixedSizedListArray` for `length` kernel [\#4517](https://github.com/apache/arrow-rs/issues/4517) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `RowCoverter::convert` that targets an existing `Rows` [\#4479](https://github.com/apache/arrow-rs/issues/4479) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Panic `assertion failed: idx < self.len` when casting DictionaryArrays with nulls [\#4576](https://github.com/apache/arrow-rs/issues/4576) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- arrow-arith is\_null is buggy with NullArray [\#4565](https://github.com/apache/arrow-rs/issues/4565) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Incorrect Interval to Duration Casting [\#4553](https://github.com/apache/arrow-rs/issues/4553) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Too large validity buffer pre-allocation in `FixedSizeListBuilder::new` [\#4549](https://github.com/apache/arrow-rs/issues/4549) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Like with wildcards fail to match fields with new lines. [\#4547](https://github.com/apache/arrow-rs/issues/4547) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Timestamp Interval Arithmetic Ignores Timezone [\#4457](https://github.com/apache/arrow-rs/issues/4457) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- refactor: simplify hour\_dyn\(\) with time\_fraction\_dyn\(\) [\#4588](https://github.com/apache/arrow-rs/pull/4588) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- Move from\_iter\_values to GenericByteArray [\#4586](https://github.com/apache/arrow-rs/pull/4586) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Mark GenericByteArray::new\_unchecked unsafe [\#4584](https://github.com/apache/arrow-rs/pull/4584) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Configurable Duration Display [\#4581](https://github.com/apache/arrow-rs/pull/4581) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix take\_bytes Null and Overflow Handling \(\#4576\) [\#4579](https://github.com/apache/arrow-rs/pull/4579) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Move chrono-tz arithmetic tests to integration [\#4571](https://github.com/apache/arrow-rs/pull/4571) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Write Page Offset Index For All-Nan Pages [\#4567](https://github.com/apache/arrow-rs/pull/4567) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([MachaelLee](https://github.com/MachaelLee)) +- support NullArray un arith/boolean kernel [\#4566](https://github.com/apache/arrow-rs/pull/4566) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([smiklos](https://github.com/smiklos)) +- Remove Sync from arrow-flight example [\#4564](https://github.com/apache/arrow-rs/pull/4564) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Fix interval to duration casting \(\#4553\) [\#4562](https://github.com/apache/arrow-rs/pull/4562) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- docs: fix wrong parameter name [\#4559](https://github.com/apache/arrow-rs/pull/4559) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([SteveLauC](https://github.com/SteveLauC)) +- Fix FixedSizeListBuilder capacity \(\#4549\) [\#4552](https://github.com/apache/arrow-rs/pull/4552) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- docs: fix wrong inline code snippet in parquet document [\#4550](https://github.com/apache/arrow-rs/pull/4550) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([SteveLauC](https://github.com/SteveLauC)) +- fix multiline wildcard likes \(fixes \#4547\) [\#4548](https://github.com/apache/arrow-rs/pull/4548) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nl5887](https://github.com/nl5887)) +- Provide default `is_empty` impl for `arrow::array::ArrayBuilder` [\#4543](https://github.com/apache/arrow-rs/pull/4543) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([mbrobbel](https://github.com/mbrobbel)) +- Add RowConverter::append \(\#4479\) [\#4541](https://github.com/apache/arrow-rs/pull/4541) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Clarify GenericColumnReader::read\_records [\#4540](https://github.com/apache/arrow-rs/pull/4540) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Initial loongarch port [\#4538](https://github.com/apache/arrow-rs/pull/4538) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xiangzhai](https://github.com/xiangzhai)) +- Update proc-macro2 requirement from =1.0.64 to =1.0.66 [\#4537](https://github.com/apache/arrow-rs/pull/4537) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- add a validity slice access for boolean array builders [\#4536](https://github.com/apache/arrow-rs/pull/4536) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ChristianBeilschmidt](https://github.com/ChristianBeilschmidt)) +- use new num version instead of explicit num-complex dependency [\#4532](https://github.com/apache/arrow-rs/pull/4532) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mwlon](https://github.com/mwlon)) +- feat: Support `FixedSizedListArray` for `length` kernel [\#4520](https://github.com/apache/arrow-rs/pull/4520) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +## [44.0.0](https://github.com/apache/arrow-rs/tree/44.0.0) (2023-07-14) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/43.0.0...44.0.0) + +**Breaking changes:** + +- Use Parser for cast kernel \(\#4512\) [\#4513](https://github.com/apache/arrow-rs/pull/4513) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add Datum based arithmetic kernels \(\#3999\) [\#4465](https://github.com/apache/arrow-rs/pull/4465) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- eq\_dyn\_binary\_scalar should support FixedSizeBinary types [\#4491](https://github.com/apache/arrow-rs/issues/4491) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Port Tests from Deprecated Arithmetic Kernels [\#4480](https://github.com/apache/arrow-rs/issues/4480) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement RecordBatchReader for Boxed trait object [\#4474](https://github.com/apache/arrow-rs/issues/4474) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `Date` - `Date` kernel [\#4383](https://github.com/apache/arrow-rs/issues/4383) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Default FlightSqlService Implementations [\#4372](https://github.com/apache/arrow-rs/issues/4372) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Fixed bugs:** + +- Parquet: `AsyncArrowWriter` to a file corrupts the footer for large columns [\#4526](https://github.com/apache/arrow-rs/issues/4526) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[object\_store\] Failure to send bytes to azure [\#4522](https://github.com/apache/arrow-rs/issues/4522) +- Cannot cast string '2021-01-02' to value of Date64 type 
[\#4512](https://github.com/apache/arrow-rs/issues/4512) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Incorrect Interval Subtraction [\#4489](https://github.com/apache/arrow-rs/issues/4489) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Interval Negation Incorrect [\#4488](https://github.com/apache/arrow-rs/issues/4488) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet: AsyncArrowWriter inner buffer is not correctly limited and causes OOM [\#4477](https://github.com/apache/arrow-rs/issues/4477) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Merged pull requests:** + +- Fix AsyncArrowWriter flush for large buffer sizes \(\#4526\) [\#4527](https://github.com/apache/arrow-rs/pull/4527) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Cleanup cast\_primitive\_to\_list [\#4511](https://github.com/apache/arrow-rs/pull/4511) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Bump actions/upload-pages-artifact from 1 to 2 [\#4508](https://github.com/apache/arrow-rs/pull/4508) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Support Date - Date \(\#4383\) [\#4504](https://github.com/apache/arrow-rs/pull/4504) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Bump actions/labeler from 4.2.0 to 4.3.0 [\#4501](https://github.com/apache/arrow-rs/pull/4501) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update proc-macro2 requirement from =1.0.63 to =1.0.64 [\#4500](https://github.com/apache/arrow-rs/pull/4500) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add negate kernels \(\#4488\) [\#4494](https://github.com/apache/arrow-rs/pull/4494) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add Datum Arithmetic tests, Fix Interval Substraction \(\#4480\) [\#4493](https://github.com/apache/arrow-rs/pull/4493) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- support FixedSizeBinary types in eq\_dyn\_binary\_scalar/neq\_dyn\_binary\_scalar [\#4492](https://github.com/apache/arrow-rs/pull/4492) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([maxburke](https://github.com/maxburke)) +- Add default implementations to the FlightSqlService trait [\#4485](https://github.com/apache/arrow-rs/pull/4485) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([rossjones](https://github.com/rossjones)) +- add num-complex requirement [\#4482](https://github.com/apache/arrow-rs/pull/4482) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mwlon](https://github.com/mwlon)) +- fix incorrect buffer size limiting in parquet async writer [\#4478](https://github.com/apache/arrow-rs/pull/4478) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([richox](https://github.com/richox)) +- feat: support RecordBatchReader on boxed trait objects [\#4475](https://github.com/apache/arrow-rs/pull/4475) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Improve in-place primitive sorts by 13-67% [\#4473](https://github.com/apache/arrow-rs/pull/4473) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Add Scalar/Datum abstraction \(\#1047\) [\#4393](https://github.com/apache/arrow-rs/pull/4393) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +## [43.0.0](https://github.com/apache/arrow-rs/tree/43.0.0) (2023-06-30) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/42.0.0...43.0.0) + +**Breaking changes:** + +- Simplify ffi import/export [\#4447](https://github.com/apache/arrow-rs/pull/4447) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Virgiel](https://github.com/Virgiel)) +- Return Result from Parquet Row APIs [\#4428](https://github.com/apache/arrow-rs/pull/4428) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zeevm](https://github.com/zeevm)) +- Remove Binary Dictionary Arithmetic Support [\#4407](https://github.com/apache/arrow-rs/pull/4407) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Request: a way to copy a `Row` to `Rows` [\#4466](https://github.com/apache/arrow-rs/issues/4466) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Reuse schema when importing from FFI [\#4444](https://github.com/apache/arrow-rs/issues/4444) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[FlightSQL\] Allow implementations of `FlightSqlService` to handle custom actions and commands [\#4439](https://github.com/apache/arrow-rs/issues/4439) +- Support `NullBuilder` [\#4429](https://github.com/apache/arrow-rs/issues/4429) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Regression in in parquet `42.0.0` : Bad parquet column indexes for All Null Columns, resulting in `Parquet error: StructArrayReader out of sync` on read [\#4459](https://github.com/apache/arrow-rs/issues/4459) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Regression in 42.0.0: Parsing fractional intervals without leading 0 is not supported [\#4424](https://github.com/apache/arrow-rs/issues/4424) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- doc: deploy crate docs to GitHub pages [\#4436](https://github.com/apache/arrow-rs/pull/4436) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xxchan](https://github.com/xxchan)) + +**Merged pull requests:** + +- Append Row to Rows \(\#4466\) [\#4470](https://github.com/apache/arrow-rs/pull/4470) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat\(flight-sql\): Allow implementations of FlightSqlService to handle custom actions and commands [\#4463](https://github.com/apache/arrow-rs/pull/4463) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([amartins23](https://github.com/amartins23)) +- Docs: Add clearer API doc links [\#4461](https://github.com/apache/arrow-rs/pull/4461) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Fix empty offset index for all null columns \(\#4459\) [\#4460](https://github.com/apache/arrow-rs/pull/4460) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] 
([tustvold](https://github.com/tustvold)) +- Bump peaceiris/actions-gh-pages from 3.9.2 to 3.9.3 [\#4455](https://github.com/apache/arrow-rs/pull/4455) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Convince the compiler to auto-vectorize the range check in parquet DictionaryBuffer [\#4453](https://github.com/apache/arrow-rs/pull/4453) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jhorstmann](https://github.com/jhorstmann)) +- fix docs deployment [\#4452](https://github.com/apache/arrow-rs/pull/4452) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xxchan](https://github.com/xxchan)) +- Update indexmap requirement from 1.9 to 2.0 [\#4451](https://github.com/apache/arrow-rs/pull/4451) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update proc-macro2 requirement from =1.0.60 to =1.0.63 [\#4450](https://github.com/apache/arrow-rs/pull/4450) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Bump actions/deploy-pages from 1 to 2 [\#4449](https://github.com/apache/arrow-rs/pull/4449) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Revise error message in From\ for ScalarBuffer [\#4446](https://github.com/apache/arrow-rs/pull/4446) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- minor: remove useless mut [\#4443](https://github.com/apache/arrow-rs/pull/4443) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- unify substring for binary&utf8 [\#4442](https://github.com/apache/arrow-rs/pull/4442) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- Casting fixedsizelist to list/largelist [\#4433](https://github.com/apache/arrow-rs/pull/4433) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jayzhan211](https://github.com/jayzhan211)) +- feat: support `NullBuilder` [\#4430](https://github.com/apache/arrow-rs/pull/4430) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Remove Float64 -\> Float32 cast in IPC Reader [\#4427](https://github.com/apache/arrow-rs/pull/4427) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ming08108](https://github.com/ming08108)) +- Parse intervals like `.5` the same as `0.5` [\#4425](https://github.com/apache/arrow-rs/pull/4425) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- feat: add strict mode to json reader [\#4421](https://github.com/apache/arrow-rs/pull/4421) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([blinkseb](https://github.com/blinkseb)) +- Add DictionaryArray::occupancy [\#4415](https://github.com/apache/arrow-rs/pull/4415) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) ## [42.0.0](https://github.com/apache/arrow-rs/tree/42.0.0) (2023-06-16) [Full Changelog](https://github.com/apache/arrow-rs/compare/41.0.0...42.0.0) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ed2f1420684..8c5351708c0b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,53 +19,86 @@ # Changelog -## [43.0.0](https://github.com/apache/arrow-rs/tree/43.0.0) 
(2023-06-30) +## [48.0.0](https://github.com/apache/arrow-rs/tree/48.0.0) (2023-10-18) -[Full Changelog](https://github.com/apache/arrow-rs/compare/42.0.0...43.0.0) +[Full Changelog](https://github.com/apache/arrow-rs/compare/47.0.0...48.0.0) **Breaking changes:** -- Simplify ffi import/export [\#4447](https://github.com/apache/arrow-rs/pull/4447) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Virgiel](https://github.com/Virgiel)) -- Return Result from Parquet Row APIs [\#4428](https://github.com/apache/arrow-rs/pull/4428) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zeevm](https://github.com/zeevm)) -- Remove Binary Dictionary Arithmetic Support [\#4407](https://github.com/apache/arrow-rs/pull/4407) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Evaluate null\_regex for string type in csv \(now such values will be parsed as `Null` rather than `""`\) [\#4942](https://github.com/apache/arrow-rs/pull/4942) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([haohuaijin](https://github.com/haohuaijin)) +- fix\(csv\)!: infer null for empty column. [\#4910](https://github.com/apache/arrow-rs/pull/4910) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kskalski](https://github.com/kskalski)) +- feat: log headers/trailers in flight CLI \(+ minor fixes\) [\#4898](https://github.com/apache/arrow-rs/pull/4898) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- fix\(arrow-json\)!: include null fields in schema inference with a type of Null [\#4894](https://github.com/apache/arrow-rs/pull/4894) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kskalski](https://github.com/kskalski)) +- Mark OnCloseRowGroup Send [\#4893](https://github.com/apache/arrow-rs/pull/4893) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([devinjdangelo](https://github.com/devinjdangelo)) +- Specialize Thrift Decoding \(~40% Faster\) \(\#4891\) [\#4892](https://github.com/apache/arrow-rs/pull/4892) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Make ArrowRowGroupWriter Public and SerializedRowGroupWriter Send [\#4850](https://github.com/apache/arrow-rs/pull/4850) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([devinjdangelo](https://github.com/devinjdangelo)) **Implemented enhancements:** -- Request: a way to copy a `Row` to `Rows` [\#4466](https://github.com/apache/arrow-rs/issues/4466) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Reuse schema when importing from FFI [\#4444](https://github.com/apache/arrow-rs/issues/4444) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[FlightSQL\] Allow implementations of `FlightSqlService` to handle custom actions and commands [\#4439](https://github.com/apache/arrow-rs/issues/4439) -- Support `NullBuilder` [\#4429](https://github.com/apache/arrow-rs/issues/4429) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Allow schema fields to merge with `Null` datatype [\#4901](https://github.com/apache/arrow-rs/issues/4901) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add option to FlightDataEncoder to always send dictionaries [\#4895](https://github.com/apache/arrow-rs/issues/4895) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
[[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Rework Thrift Encoding / Decoding of Parquet Metadata [\#4891](https://github.com/apache/arrow-rs/issues/4891) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Plans for supporting Extension Array to support Fixed shape tensor Array [\#4890](https://github.com/apache/arrow-rs/issues/4890) +- Implement Take for UnionArray [\#4882](https://github.com/apache/arrow-rs/issues/4882) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Check precision overflow for casting floating to decimal [\#4865](https://github.com/apache/arrow-rs/issues/4865) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Replace lexical [\#4774](https://github.com/apache/arrow-rs/issues/4774) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add read access to settings in `csv::WriterBuilder` [\#4735](https://github.com/apache/arrow-rs/issues/4735) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve the performance of "DictionaryValue" row encoding [\#4712](https://github.com/apache/arrow-rs/issues/4712) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] **Fixed bugs:** -- Regression in in parquet `42.0.0` : Bad parquet column indexes for All Null Columns, resulting in `Parquet error: StructArrayReader out of sync` on read [\#4459](https://github.com/apache/arrow-rs/issues/4459) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Regression in 42.0.0: Parsing fractional intervals without leading 0 is not supported [\#4424](https://github.com/apache/arrow-rs/issues/4424) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Should we make blank values and empty string to `None` in csv? 
[\#4939](https://github.com/apache/arrow-rs/issues/4939) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[FlightSQL\] SubstraitPlan structure is not exported [\#4932](https://github.com/apache/arrow-rs/issues/4932) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Loading page index breaks skipping of pages with nested types [\#4921](https://github.com/apache/arrow-rs/issues/4921) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- CSV schema inference assumes `Utf8` for empty columns [\#4903](https://github.com/apache/arrow-rs/issues/4903) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet: Field Ids are not read from a Parquet file without serialized arrow schema [\#4877](https://github.com/apache/arrow-rs/issues/4877) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- make\_primitive\_scalar function loses DataType Internal information [\#4851](https://github.com/apache/arrow-rs/issues/4851) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- StructBuilder doesn't handle nulls correctly for empty structs [\#4842](https://github.com/apache/arrow-rs/issues/4842) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `NullArray::is_null()` returns `false` incorrectly [\#4835](https://github.com/apache/arrow-rs/issues/4835) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- cast\_string\_to\_decimal should check precision overflow [\#4829](https://github.com/apache/arrow-rs/issues/4829) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Null fields are omitted by `infer_json_schema_from_seekable` [\#4814](https://github.com/apache/arrow-rs/issues/4814) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -**Documentation updates:** +**Closed issues:** -- doc: deploy crate docs to GitHub pages [\#4436](https://github.com/apache/arrow-rs/pull/4436) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xxchan](https://github.com/xxchan)) +- Support for reading JSON Array to Arrow [\#4905](https://github.com/apache/arrow-rs/issues/4905) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] **Merged pull requests:** -- Append Row to Rows \(\#4466\) [\#4470](https://github.com/apache/arrow-rs/pull/4470) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- feat\(flight-sql\): Allow implementations of FlightSqlService to handle custom actions and commands [\#4463](https://github.com/apache/arrow-rs/pull/4463) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([amartins23](https://github.com/amartins23)) -- Docs: Add clearer API doc links [\#4461](https://github.com/apache/arrow-rs/pull/4461) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) -- Fix empty offset index for all null columns \(\#4459\) [\#4460](https://github.com/apache/arrow-rs/pull/4460) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) -- Bump peaceiris/actions-gh-pages from 3.9.2 to 3.9.3 [\#4455](https://github.com/apache/arrow-rs/pull/4455) ([dependabot[bot]](https://github.com/apps/dependabot)) -- 
Convince the compiler to auto-vectorize the range check in parquet DictionaryBuffer [\#4453](https://github.com/apache/arrow-rs/pull/4453) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jhorstmann](https://github.com/jhorstmann)) -- fix docs deployment [\#4452](https://github.com/apache/arrow-rs/pull/4452) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xxchan](https://github.com/xxchan)) -- Update indexmap requirement from 1.9 to 2.0 [\#4451](https://github.com/apache/arrow-rs/pull/4451) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) -- Update proc-macro2 requirement from =1.0.60 to =1.0.63 [\#4450](https://github.com/apache/arrow-rs/pull/4450) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) -- Bump actions/deploy-pages from 1 to 2 [\#4449](https://github.com/apache/arrow-rs/pull/4449) ([dependabot[bot]](https://github.com/apps/dependabot)) -- Revise error message in From\ for ScalarBuffer [\#4446](https://github.com/apache/arrow-rs/pull/4446) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- minor: remove useless mut [\#4443](https://github.com/apache/arrow-rs/pull/4443) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) -- unify substring for binary&utf8 [\#4442](https://github.com/apache/arrow-rs/pull/4442) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) -- Casting fixedsizelist to list/largelist [\#4433](https://github.com/apache/arrow-rs/pull/4433) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jayzhan211](https://github.com/jayzhan211)) -- feat: support `NullBuilder` [\#4430](https://github.com/apache/arrow-rs/pull/4430) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) -- Remove Float64 -\> Float32 cast in IPC Reader [\#4427](https://github.com/apache/arrow-rs/pull/4427) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ming08108](https://github.com/ming08108)) -- Parse intervals like `.5` the same as `0.5` [\#4425](https://github.com/apache/arrow-rs/pull/4425) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- feat: add strict mode to json reader [\#4421](https://github.com/apache/arrow-rs/pull/4421) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([blinkseb](https://github.com/blinkseb)) -- Add DictionaryArray::occupancy [\#4415](https://github.com/apache/arrow-rs/pull/4415) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Assume Pages Delimit Records When Offset Index Loaded \(\#4921\) [\#4943](https://github.com/apache/arrow-rs/pull/4943) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Update pyo3 requirement from 0.19 to 0.20 [\#4941](https://github.com/apache/arrow-rs/pull/4941) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([crepererum](https://github.com/crepererum)) +- Add `FileWriter` schema getter [\#4940](https://github.com/apache/arrow-rs/pull/4940) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([haixuanTao](https://github.com/haixuanTao)) +- feat: support parsing for parquet writer option [\#4938](https://github.com/apache/arrow-rs/pull/4938) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([fansehep](https://github.com/fansehep)) +- Export `SubstraitPlan` structure in arrow\_flight::sql \(\#4932\) [\#4933](https://github.com/apache/arrow-rs/pull/4933) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([amartins23](https://github.com/amartins23)) +- Update zstd requirement from 0.12.0 to 0.13.0 [\#4923](https://github.com/apache/arrow-rs/pull/4923) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- feat: add method for async read bloom filter [\#4917](https://github.com/apache/arrow-rs/pull/4917) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([hengfeiyang](https://github.com/hengfeiyang)) +- Minor: Clarify rationale for `FlightDataEncoder` API, add examples [\#4916](https://github.com/apache/arrow-rs/pull/4916) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Update regex-syntax requirement from 0.7.1 to 0.8.0 [\#4914](https://github.com/apache/arrow-rs/pull/4914) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- feat: document & streamline flight SQL CLI [\#4912](https://github.com/apache/arrow-rs/pull/4912) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- Support Arbitrary JSON values in JSON Reader \(\#4905\) [\#4911](https://github.com/apache/arrow-rs/pull/4911) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Cleanup CSV WriterBuilder, Default to AutoSI Second Precision \(\#4735\) [\#4909](https://github.com/apache/arrow-rs/pull/4909) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update proc-macro2 requirement from =1.0.68 to =1.0.69 [\#4907](https://github.com/apache/arrow-rs/pull/4907) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- chore: add csv example [\#4904](https://github.com/apache/arrow-rs/pull/4904) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([fansehep](https://github.com/fansehep)) +- feat\(schema\): allow null fields to be merged with other datatypes [\#4902](https://github.com/apache/arrow-rs/pull/4902) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kskalski](https://github.com/kskalski)) +- Update proc-macro2 requirement from =1.0.67 to =1.0.68 [\#4900](https://github.com/apache/arrow-rs/pull/4900) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add option to `FlightDataEncoder` to always resend batch dictionaries [\#4896](https://github.com/apache/arrow-rs/pull/4896) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alexwilcoxson-rel](https://github.com/alexwilcoxson-rel)) +- Fix integration tests [\#4889](https://github.com/apache/arrow-rs/pull/4889) ([tustvold](https://github.com/tustvold)) +- Support Parsing Avro File Headers [\#4888](https://github.com/apache/arrow-rs/pull/4888) ([tustvold](https://github.com/tustvold)) +- Support parquet bloom filter length [\#4885](https://github.com/apache/arrow-rs/pull/4885) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([letian-jiang](https://github.com/letian-jiang)) +- Replace lz4 with lz4\_flex Allowing Compilation for WASM [\#4884](https://github.com/apache/arrow-rs/pull/4884) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Implement Take for UnionArray [\#4883](https://github.com/apache/arrow-rs/pull/4883) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([avantgardnerio](https://github.com/avantgardnerio)) +- Update tonic-build requirement from =0.10.1 to =0.10.2 [\#4881](https://github.com/apache/arrow-rs/pull/4881) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- parquet: Read field IDs from Parquet Schema [\#4878](https://github.com/apache/arrow-rs/pull/4878) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Samrose-Ahmed](https://github.com/Samrose-Ahmed)) +- feat: improve flight CLI error handling [\#4873](https://github.com/apache/arrow-rs/pull/4873) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- Support Encoding Parquet Columns in Parallel [\#4871](https://github.com/apache/arrow-rs/pull/4871) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Check precision overflow for casting floating to decimal [\#4866](https://github.com/apache/arrow-rs/pull/4866) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Make align\_buffers as public API [\#4863](https://github.com/apache/arrow-rs/pull/4863) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Enable new integration tests \(\#4828\) [\#4862](https://github.com/apache/arrow-rs/pull/4862) ([tustvold](https://github.com/tustvold)) +- Faster Serde Integration \(~80% faster\) [\#4861](https://github.com/apache/arrow-rs/pull/4861) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- fix: make\_primitive\_scalar bug [\#4852](https://github.com/apache/arrow-rs/pull/4852) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([JasonLi-cn](https://github.com/JasonLi-cn)) +- Update tonic-build requirement from =0.10.0 to =0.10.1 [\#4846](https://github.com/apache/arrow-rs/pull/4846) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Allow Constructing Non-Empty StructArray with no Fields \(\#4842\) [\#4845](https://github.com/apache/arrow-rs/pull/4845) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Refine documentation to `Array::is_null` [\#4838](https://github.com/apache/arrow-rs/pull/4838) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- fix: add missing precision overflow checking for `cast_string_to_decimal` [\#4830](https://github.com/apache/arrow-rs/pull/4830) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jonahgao](https://github.com/jonahgao)) diff --git a/Cargo.toml b/Cargo.toml index 173bafc6e08a..d59a5af68a19 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ members = [ "arrow", "arrow-arith", "arrow-array", + "arrow-avro", "arrow-buffer", "arrow-cast", "arrow-csv", @@ -61,7 +62,7 @@ exclude = [ ] [workspace.package] -version = "43.0.0" +version = "48.0.0" homepage = "https://github.com/apache/arrow-rs" repository = "https://github.com/apache/arrow-rs" authors = ["Apache Arrow "] @@ -76,18 +77,20 @@ edition = "2021" rust-version = "1.62" [workspace.dependencies] -arrow = { version = "43.0.0", path = "./arrow", default-features = false } -arrow-arith = { version = "43.0.0", path = "./arrow-arith" } -arrow-array = { version = "43.0.0", path = "./arrow-array" } -arrow-buffer = { version = "43.0.0", path = "./arrow-buffer" } -arrow-cast = { version = "43.0.0", path = "./arrow-cast" } -arrow-csv = { version = "43.0.0", path = "./arrow-csv" } -arrow-data = { version = "43.0.0", path = "./arrow-data" } -arrow-ipc = { version = "43.0.0", path = "./arrow-ipc" } -arrow-json = { version = "43.0.0", path = "./arrow-json" } -arrow-ord = { version = "43.0.0", path = "./arrow-ord" } -arrow-row = { version = "43.0.0", path = "./arrow-row" } -arrow-schema = { version = "43.0.0", path = "./arrow-schema" } -arrow-select = { version = "43.0.0", path = "./arrow-select" } -arrow-string = { version = "43.0.0", path = "./arrow-string" } -parquet = { version = "43.0.0", path = "./parquet", default-features = false } +arrow = { version = "48.0.0", path = "./arrow", default-features = false } +arrow-arith = { version = "48.0.0", path = "./arrow-arith" } +arrow-array = { version = "48.0.0", path = "./arrow-array" } +arrow-buffer = { version = "48.0.0", path = "./arrow-buffer" } +arrow-cast = { version = "48.0.0", path = "./arrow-cast" } +arrow-csv = { version = "48.0.0", path = "./arrow-csv" } +arrow-data = { version = "48.0.0", path = "./arrow-data" } +arrow-ipc = { version = "48.0.0", path = "./arrow-ipc" } +arrow-json = { version = "48.0.0", path = "./arrow-json" } +arrow-ord = { version = "48.0.0", path = "./arrow-ord" } +arrow-row = { version = "48.0.0", path = "./arrow-row" } +arrow-schema = { version = "48.0.0", path = "./arrow-schema" } +arrow-select = { version = "48.0.0", path = "./arrow-select" } +arrow-string = { version = "48.0.0", path = "./arrow-string" } +parquet = { version = "48.0.0", path = "./parquet", default-features = false } + +chrono = { version = "0.4.31", default-features = false, features = ["clock"] } diff --git a/arrow-arith/Cargo.toml b/arrow-arith/Cargo.toml index b5ea2e3c4354..57dc033e9645 100644 --- a/arrow-arith/Cargo.toml +++ b/arrow-arith/Cargo.toml @@ -38,7 +38,7 @@ arrow-array = { workspace = true } arrow-buffer = { workspace = true } arrow-data = { workspace = true } arrow-schema = { workspace = true } -chrono = { version = "0.4.23", default-features = false } +chrono = { workspace = true } half = { version = "2.1", default-features = false } num = { version = "0.4", default-features = 
false, features = ["std"] } diff --git a/arrow-arith/src/aggregate.rs b/arrow-arith/src/aggregate.rs index 4961d7efc0f2..04417c666c85 100644 --- a/arrow-arith/src/aggregate.rs +++ b/arrow-arith/src/aggregate.rs @@ -867,8 +867,8 @@ where #[cfg(test)] mod tests { use super::*; - use crate::arithmetic::add; use arrow_array::types::*; + use arrow_buffer::NullBuffer; use std::sync::Arc; #[test] @@ -897,54 +897,35 @@ mod tests { #[test] fn test_primitive_array_sum_large_64() { - let a: Int64Array = (1..=100) - .map(|i| if i % 3 == 0 { Some(i) } else { None }) - .collect(); - let b: Int64Array = (1..=100) - .map(|i| if i % 3 == 0 { Some(0) } else { Some(i) }) - .collect(); // create an array that actually has non-zero values at the invalid indices - let c = add(&a, &b).unwrap(); + let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect()); + let c = Int64Array::new((1..=100).collect(), Some(validity)); + assert_eq!(Some((1..=100).filter(|i| i % 3 == 0).sum()), sum(&c)); } #[test] fn test_primitive_array_sum_large_32() { - let a: Int32Array = (1..=100) - .map(|i| if i % 3 == 0 { Some(i) } else { None }) - .collect(); - let b: Int32Array = (1..=100) - .map(|i| if i % 3 == 0 { Some(0) } else { Some(i) }) - .collect(); // create an array that actually has non-zero values at the invalid indices - let c = add(&a, &b).unwrap(); + let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect()); + let c = Int32Array::new((1..=100).collect(), Some(validity)); assert_eq!(Some((1..=100).filter(|i| i % 3 == 0).sum()), sum(&c)); } #[test] fn test_primitive_array_sum_large_16() { - let a: Int16Array = (1..=100) - .map(|i| if i % 3 == 0 { Some(i) } else { None }) - .collect(); - let b: Int16Array = (1..=100) - .map(|i| if i % 3 == 0 { Some(0) } else { Some(i) }) - .collect(); // create an array that actually has non-zero values at the invalid indices - let c = add(&a, &b).unwrap(); + let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect()); + let c = Int16Array::new((1..=100).collect(), Some(validity)); assert_eq!(Some((1..=100).filter(|i| i % 3 == 0).sum()), sum(&c)); } #[test] fn test_primitive_array_sum_large_8() { // include fewer values than other large tests so the result does not overflow the u8 - let a: UInt8Array = (1..=100) - .map(|i| if i % 33 == 0 { Some(i) } else { None }) - .collect(); - let b: UInt8Array = (1..=100) - .map(|i| if i % 33 == 0 { Some(0) } else { Some(i) }) - .collect(); // create an array that actually has non-zero values at the invalid indices - let c = add(&a, &b).unwrap(); + let validity = NullBuffer::new((1..=100).map(|x| x % 33 == 0).collect()); + let c = UInt8Array::new((1..=100).collect(), Some(validity)); assert_eq!(Some((1..=100).filter(|i| i % 33 == 0).sum()), sum(&c)); } diff --git a/arrow-arith/src/arithmetic.rs b/arrow-arith/src/arithmetic.rs index 8e7ab44042cf..8635ce0ddd80 100644 --- a/arrow-arith/src/arithmetic.rs +++ b/arrow-arith/src/arithmetic.rs @@ -23,7 +23,6 @@ //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. use crate::arity::*; -use arrow_array::cast::*; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::i256; @@ -32,1102 +31,6 @@ use arrow_schema::*; use std::cmp::min; use std::sync::Arc; -/// Helper function to perform math lambda function on values from two arrays. If either -/// left or right value is null then the output value is also null, so `1 + null` is -/// `null`. 
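// Hedged, illustrative sketch (not taken from the arrow-rs sources): the removed
// `math_op` documented here is a thin wrapper over `arrow_arith::arity::binary`,
// which already gives the "null in, null out" behaviour this doc comment describes
// (so `1 + null` is `null`). The function name `wrapping_add_example` below is
// made up for illustration only.
use arrow_arith::arity::binary;
use arrow_array::Int32Array;
use arrow_schema::ArrowError;

fn wrapping_add_example(a: &Int32Array, b: &Int32Array) -> Result<Int32Array, ArrowError> {
    // Errors if the lengths differ; nulls from either input propagate to the output.
    binary(a, b, |x, y| x.wrapping_add(y))
}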
-/// -/// # Errors -/// -/// This function errors if the arrays have different lengths -pub fn math_op( - left: &PrimitiveArray, - right: &PrimitiveArray, - op: F, -) -> Result, ArrowError> -where - LT: ArrowNumericType, - RT: ArrowNumericType, - F: Fn(LT::Native, RT::Native) -> LT::Native, -{ - binary(left, right, op) -} - -/// This is similar to `math_op` as it performs given operation between two input primitive arrays. -/// But the given operation can return `Err` if overflow is detected. For the case, this function -/// returns an `Err`. -fn math_checked_op( - left: &PrimitiveArray, - right: &PrimitiveArray, - op: F, -) -> Result, ArrowError> -where - LT: ArrowNumericType, - RT: ArrowNumericType, - F: Fn(LT::Native, RT::Native) -> Result, -{ - try_binary(left, right, op) -} - -/// Helper function for operations where a valid `0` on the right array should -/// result in an [ArrowError::DivideByZero], namely the division and modulo operations -/// -/// # Errors -/// -/// This function errors if: -/// * the arrays have different lengths -/// * there is an element where both left and right values are valid and the right value is `0` -fn math_checked_divide_op( - left: &PrimitiveArray, - right: &PrimitiveArray, - op: F, -) -> Result, ArrowError> -where - LT: ArrowNumericType, - RT: ArrowNumericType, - F: Fn(LT::Native, RT::Native) -> Result, -{ - math_checked_op(left, right, op) -} - -/// Calculates the modulus operation `left % right` on two SIMD inputs. -/// The lower-most bits of `valid_mask` specify which vector lanes are considered as valid. -/// -/// # Errors -/// -/// This function returns a [`ArrowError::DivideByZero`] if a valid element in `right` is `0` -#[cfg(feature = "simd")] -#[inline] -fn simd_checked_modulus( - valid_mask: Option, - left: T::Simd, - right: T::Simd, -) -> Result { - let zero = T::init(T::Native::ZERO); - let one = T::init(T::Native::ONE); - - let right_no_invalid_zeros = match valid_mask { - Some(mask) => { - let simd_mask = T::mask_from_u64(mask); - // select `1` for invalid lanes, which will be a no-op during division later - T::mask_select(simd_mask, right, one) - } - None => right, - }; - - let zero_mask = T::eq(right_no_invalid_zeros, zero); - - if T::mask_any(zero_mask) { - Err(ArrowError::DivideByZero) - } else { - Ok(T::bin_op(left, right_no_invalid_zeros, |a, b| a % b)) - } -} - -/// Calculates the division operation `left / right` on two SIMD inputs. -/// The lower-most bits of `valid_mask` specify which vector lanes are considered as valid. -/// -/// # Errors -/// -/// This function returns a [`ArrowError::DivideByZero`] if a valid element in `right` is `0` -#[cfg(feature = "simd")] -#[inline] -fn simd_checked_divide( - valid_mask: Option, - left: T::Simd, - right: T::Simd, -) -> Result { - let zero = T::init(T::Native::ZERO); - let one = T::init(T::Native::ONE); - - let right_no_invalid_zeros = match valid_mask { - Some(mask) => { - let simd_mask = T::mask_from_u64(mask); - // select `1` for invalid lanes, which will be a no-op during division later - T::mask_select(simd_mask, right, one) - } - None => right, - }; - - let zero_mask = T::eq(right_no_invalid_zeros, zero); - - if T::mask_any(zero_mask) { - Err(ArrowError::DivideByZero) - } else { - Ok(T::bin_op(left, right_no_invalid_zeros, |a, b| a / b)) - } -} - -/// Applies `op` on the remainder elements of two input chunks and writes the result into -/// the remainder elements of `result_chunks`. -/// The lower-most bits of `valid_mask` specify which elements are considered as valid. 
-/// -/// # Errors -/// -/// This function returns a [`ArrowError::DivideByZero`] if a valid element in `right` is `0` -#[cfg(feature = "simd")] -#[inline] -fn simd_checked_divide_op_remainder( - valid_mask: Option, - left_chunks: std::slice::ChunksExact, - right_chunks: std::slice::ChunksExact, - result_chunks: std::slice::ChunksExactMut, - op: F, -) -> Result<(), ArrowError> -where - T: ArrowNumericType, - F: Fn(T::Native, T::Native) -> T::Native, -{ - let result_remainder = result_chunks.into_remainder(); - let left_remainder = left_chunks.remainder(); - let right_remainder = right_chunks.remainder(); - - result_remainder - .iter_mut() - .zip(left_remainder.iter().zip(right_remainder.iter())) - .enumerate() - .try_for_each(|(i, (result_scalar, (left_scalar, right_scalar)))| { - if valid_mask.map(|mask| mask & (1 << i) != 0).unwrap_or(true) { - if right_scalar.is_zero() { - return Err(ArrowError::DivideByZero); - } - *result_scalar = op(*left_scalar, *right_scalar); - } else { - *result_scalar = T::default_value(); - } - Ok(()) - })?; - - Ok(()) -} - -/// Creates a new PrimitiveArray by applying `simd_op` to the `left` and `right` input array. -/// If the length of the arrays is not multiple of the number of vector lanes -/// then the remainder of the array will be calculated using `scalar_op`. -/// Any operation on a `NULL` value will result in a `NULL` value in the output. -/// -/// # Errors -/// -/// This function errors if: -/// * the arrays have different lengths -/// * there is an element where both left and right values are valid and the right value is `0` -#[cfg(feature = "simd")] -fn simd_checked_divide_op( - left: &PrimitiveArray, - right: &PrimitiveArray, - simd_op: SI, - scalar_op: SC, -) -> Result, ArrowError> -where - T: ArrowNumericType, - SI: Fn(Option, T::Simd, T::Simd) -> Result, - SC: Fn(T::Native, T::Native) -> T::Native, -{ - if left.len() != right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform math operation on arrays of different length".to_string(), - )); - } - - // Create the combined `Bitmap` - let nulls = arrow_buffer::NullBuffer::union(left.nulls(), right.nulls()); - - let lanes = T::lanes(); - let buffer_size = left.len() * std::mem::size_of::(); - let mut result = - arrow_buffer::MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); - - match &nulls { - Some(b) => { - let valid_chunks = b.inner().bit_chunks(); - - // process data in chunks of 64 elements since we also get 64 bits of validity information at a time - - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(64); - let mut left_chunks = left.values().chunks_exact(64); - let mut right_chunks = right.values().chunks_exact(64); - - valid_chunks - .iter() - .zip((&mut result_chunks).zip((&mut left_chunks).zip(&mut right_chunks))) - .try_for_each( - |(mut mask, (result_slice, (left_slice, right_slice)))| { - // split chunks further into slices corresponding to the vector length - // the compiler is able to unroll this inner loop and remove bounds checks - // since the outer chunk size (64) is always a multiple of the number of lanes - result_slice - .chunks_exact_mut(lanes) - .zip(left_slice.chunks_exact(lanes).zip(right_slice.chunks_exact(lanes))) - .try_for_each(|(result_slice, (left_slice, right_slice))| -> Result<(), ArrowError> { - let simd_left = T::load(left_slice); - let simd_right = T::load(right_slice); - - let simd_result = simd_op(Some(mask), simd_left, simd_right)?; - - T::write(simd_result, result_slice); - - // skip the shift and avoid overflow 
for u8 type, which uses 64 lanes. - mask >>= T::lanes() % 64; - - Ok(()) - }) - }, - )?; - - let valid_remainder = valid_chunks.remainder_bits(); - - simd_checked_divide_op_remainder::( - Some(valid_remainder), - left_chunks, - right_chunks, - result_chunks, - scalar_op, - )?; - } - None => { - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes); - let mut left_chunks = left.values().chunks_exact(lanes); - let mut right_chunks = right.values().chunks_exact(lanes); - - (&mut result_chunks) - .zip((&mut left_chunks).zip(&mut right_chunks)) - .try_for_each( - |(result_slice, (left_slice, right_slice))| -> Result<(), ArrowError> { - let simd_left = T::load(left_slice); - let simd_right = T::load(right_slice); - - let simd_result = simd_op(None, simd_left, simd_right)?; - - T::write(simd_result, result_slice); - - Ok(()) - }, - )?; - - simd_checked_divide_op_remainder::( - None, - left_chunks, - right_chunks, - result_chunks, - scalar_op, - )?; - } - } - - Ok(PrimitiveArray::new(result.into(), nulls)) -} - -fn math_safe_divide_op( - left: &PrimitiveArray, - right: &PrimitiveArray, - op: F, -) -> Result -where - LT: ArrowNumericType, - RT: ArrowNumericType, - F: Fn(LT::Native, RT::Native) -> Option, -{ - let array: PrimitiveArray = binary_opt::<_, _, _, LT>(left, right, op)?; - Ok(Arc::new(array) as ArrayRef) -} - -/// Perform `left + right` operation on two arrays. If either left or right value is null -/// then the result is also null. -/// -/// This doesn't detect overflow. Once overflowing, the result will wrap around. -/// For an overflow-checking variant, use `add_checked` instead. -pub fn add( - left: &PrimitiveArray, - right: &PrimitiveArray, -) -> Result, ArrowError> { - math_op(left, right, |a, b| a.add_wrapping(b)) -} - -/// Perform `left + right` operation on two arrays. If either left or right value is null -/// then the result is also null. -/// -/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, -/// use `add` instead. -pub fn add_checked( - left: &PrimitiveArray, - right: &PrimitiveArray, -) -> Result, ArrowError> { - math_checked_op(left, right, |a, b| a.add_checked(b)) -} - -/// Perform `left + right` operation on two arrays. If either left or right value is null -/// then the result is also null. -/// -/// This doesn't detect overflow. Once overflowing, the result will wrap around. -/// For an overflow-checking variant, use `add_dyn_checked` instead. 
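// Hedged, illustrative sketch (not taken from the arrow-rs sources): the
// type-dispatching `add_dyn` / `add_dyn_checked` kernels removed below are
// superseded in 48.0.0 by the Datum-based kernels from #4465
// (`arrow_arith::numeric`). A minimal example of the replacement API, assuming
// the 48.0.0 crates; `datum_add_example` is a made-up name for illustration.
use arrow_arith::numeric::{add, add_wrapping};
use arrow_array::cast::AsArray;
use arrow_array::types::Int32Type;
use arrow_array::{ArrayRef, Int32Array, Scalar};
use arrow_schema::ArrowError;

fn datum_add_example() -> Result<(), ArrowError> {
    let a = Int32Array::from(vec![Some(1), None, Some(3)]);
    // A single-value operand is now expressed as a `Scalar` Datum rather than a
    // dedicated `add_scalar*` kernel.
    let ten = Scalar::new(Int32Array::from(vec![10]));

    // Wrapping variant (like the old `add_dyn`): nulls propagate, overflow wraps.
    let wrapped: ArrayRef = add_wrapping(&a, &ten)?;
    assert_eq!(wrapped.as_primitive::<Int32Type>().value(0), 11);

    // Checked variant (like the old `add_dyn_checked`): overflow returns an Err.
    let checked: ArrayRef = add(&a, &ten)?;
    assert_eq!(checked.as_primitive::<Int32Type>().value(2), 13);
    Ok(())
}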
-pub fn add_dyn(left: &dyn Array, right: &dyn Array) -> Result { - match left.data_type() { - DataType::Date32 => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::add_year_months)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::add_day_time)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::add_month_day_nano)?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Date64 => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::add_year_months)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::add_day_time)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::add_month_day_nano)?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Timestamp(TimeUnit::Second, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampSecondType::add_year_months)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampSecondType::add_day_time)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampSecondType::add_month_day_nano)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - - DataType::Timestamp(TimeUnit::Microsecond, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMicrosecondType::add_year_months)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMicrosecondType::add_day_time)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMicrosecondType::add_month_day_nano)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - - 
DataType::Timestamp(TimeUnit::Millisecond, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMillisecondType::add_year_months)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMillisecondType::add_day_time)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMillisecondType::add_month_day_nano)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampNanosecondType::add_year_months)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampNanosecondType::add_day_time)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampNanosecondType::add_month_day_nano)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - - DataType::Interval(_) - if matches!( - right.data_type(), - DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _) - ) => - { - add_dyn(right, left) - } - _ => { - downcast_primitive_array!( - (left, right) => { - math_op(left, right, |a, b| a.add_wrapping(b)).map(|a| Arc::new(a) as ArrayRef) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported data type {}, {}", - left.data_type(), right.data_type() - ))) - ) - } - } -} - -/// Perform `left + right` operation on two arrays. If either left or right value is null -/// then the result is also null. -/// -/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, -/// use `add_dyn` instead. 
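// Hedged, illustrative sketch (not taken from the arrow-rs sources): the
// Date/Timestamp-plus-Interval dispatch hand-written in `add_dyn` above is
// handled by the same Datum kernel in 48.0.0, assuming `arrow_arith::numeric::add`
// accepts Date32 and IntervalYearMonth operands as it does in that release.
use arrow_arith::numeric::add;
use arrow_array::{Date32Array, IntervalYearMonthArray};
use arrow_schema::{ArrowError, DataType};

fn date_plus_interval_example() -> Result<(), ArrowError> {
    let dates = Date32Array::from(vec![0, 365]); // days since the UNIX epoch
    let months = IntervalYearMonthArray::from(vec![1, 12]); // months to add
    let shifted = add(&dates, &months)?;
    // The result stays a Date32 array; no per-type match arms are needed.
    assert_eq!(shifted.data_type(), &DataType::Date32);
    Ok(())
}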
-pub fn add_dyn_checked( - left: &dyn Array, - right: &dyn Array, -) -> Result { - match left.data_type() { - DataType::Date32 => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::add_year_months)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::add_day_time)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::add_month_day_nano)?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Date64 => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::add_year_months)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::add_day_time)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::add_month_day_nano)?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - _ => { - downcast_primitive_array!( - (left, right) => { - math_checked_op(left, right, |a, b| a.add_checked(b)).map(|a| Arc::new(a) as ArrayRef) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported data type {}, {}", - left.data_type(), right.data_type() - ))) - ) - } - } -} - -/// Add every value in an array by a scalar. If any value in the array is null then the -/// result is also null. -/// -/// This doesn't detect overflow. Once overflowing, the result will wrap around. -/// For an overflow-checking variant, use `add_scalar_checked` instead. -pub fn add_scalar( - array: &PrimitiveArray, - scalar: T::Native, -) -> Result, ArrowError> { - Ok(unary(array, |value| value.add_wrapping(scalar))) -} - -/// Add every value in an array by a scalar. If any value in the array is null then the -/// result is also null. -/// -/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, -/// use `add_scalar` instead. -pub fn add_scalar_checked( - array: &PrimitiveArray, - scalar: T::Native, -) -> Result, ArrowError> { - try_unary(array, |value| value.add_checked(scalar)) -} - -/// Add every value in an array by a scalar. If any value in the array is null then the -/// result is also null. The given array must be a `PrimitiveArray` of the type same as -/// the scalar, or a `DictionaryArray` of the value type same as the scalar. -/// -/// This doesn't detect overflow. Once overflowing, the result will wrap around. -/// For an overflow-checking variant, use `add_scalar_checked_dyn` instead. -/// -/// This returns an `Err` when the input array is not supported for adding operation. -pub fn add_scalar_dyn( - array: &dyn Array, - scalar: T::Native, -) -> Result { - unary_dyn::<_, T>(array, |value| value.add_wrapping(scalar)) -} - -/// Add every value in an array by a scalar. If any value in the array is null then the -/// result is also null. 
The given array must be a `PrimitiveArray` of the type same as -/// the scalar, or a `DictionaryArray` of the value type same as the scalar. -/// -/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, -/// use `add_scalar_dyn` instead. -/// -/// As this kernel has the branching costs and also prevents LLVM from vectorising it correctly, -/// it is usually much slower than non-checking variant. -pub fn add_scalar_checked_dyn( - array: &dyn Array, - scalar: T::Native, -) -> Result { - try_unary_dyn::<_, T>(array, |value| value.add_checked(scalar)) - .map(|a| Arc::new(a) as ArrayRef) -} - -/// Perform `left - right` operation on two arrays. If either left or right value is null -/// then the result is also null. -/// -/// This doesn't detect overflow. Once overflowing, the result will wrap around. -/// For an overflow-checking variant, use `subtract_checked` instead. -pub fn subtract( - left: &PrimitiveArray, - right: &PrimitiveArray, -) -> Result, ArrowError> { - math_op(left, right, |a, b| a.sub_wrapping(b)) -} - -/// Perform `left - right` operation on two arrays. If either left or right value is null -/// then the result is also null. -/// -/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, -/// use `subtract` instead. -pub fn subtract_checked( - left: &PrimitiveArray, - right: &PrimitiveArray, -) -> Result, ArrowError> { - math_checked_op(left, right, |a, b| a.sub_checked(b)) -} - -/// Perform `left - right` operation on two arrays. If either left or right value is null -/// then the result is also null. -/// -/// This doesn't detect overflow. Once overflowing, the result will wrap around. -/// For an overflow-checking variant, use `subtract_dyn_checked` instead. 
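// Hedged, illustrative sketch (not taken from the arrow-rs sources): the removed
// `subtract` / `subtract_checked` above, and `subtract_dyn` just below, map onto
// the Datum kernels `sub_wrapping` and `sub` in 48.0.0; `subtract_example` is a
// made-up name for illustration.
use arrow_arith::numeric::{sub, sub_wrapping};
use arrow_array::cast::AsArray;
use arrow_array::types::Int8Type;
use arrow_array::Int8Array;
use arrow_schema::ArrowError;

fn subtract_example() -> Result<(), ArrowError> {
    let a = Int8Array::from(vec![i8::MIN, 10]);
    let b = Int8Array::from(vec![1, 3]);

    // Wrapping variant: i8::MIN - 1 wraps around to i8::MAX instead of failing.
    let wrapped = sub_wrapping(&a, &b)?;
    assert_eq!(wrapped.as_primitive::<Int8Type>().value(0), i8::MAX);

    // Checked variant: the same subtraction reports the overflow as an error.
    assert!(sub(&a, &b).is_err());
    Ok(())
}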
-pub fn subtract_dyn(left: &dyn Array, right: &dyn Array) -> Result { - match left.data_type() { - DataType::Date32 => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::subtract_year_months)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::subtract_day_time)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::subtract_month_day_nano)?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Date64 => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::subtract_year_months)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::subtract_day_time)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::subtract_month_day_nano)?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Timestamp(TimeUnit::Second, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampSecondType::subtract_year_months)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampSecondType::subtract_day_time)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampSecondType::subtract_month_day_nano)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Timestamp(TimeUnit::Second, _) => { - let r = right.as_primitive::(); - let res: PrimitiveArray = binary(l, r, |a, b| a.wrapping_sub(b))?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMicrosecondType::subtract_year_months)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMicrosecondType::subtract_day_time)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMicrosecondType::subtract_month_day_nano)?; - 
Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - let r = right.as_primitive::(); - let res: PrimitiveArray = binary(l, r, |a, b| a.wrapping_sub(b))?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMillisecondType::subtract_year_months)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMillisecondType::subtract_day_time)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampMillisecondType::subtract_month_day_nano)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - let r = right.as_primitive::(); - let res: PrimitiveArray = binary(l, r, |a, b| a.wrapping_sub(b))?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampNanosecondType::subtract_year_months)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampNanosecondType::subtract_day_time)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_checked_op(l, r, TimestampNanosecondType::subtract_month_day_nano)?; - Ok(Arc::new(res.with_timezone_opt(l.timezone()))) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - let r = right.as_primitive::(); - let res: PrimitiveArray = binary(l, r, |a, b| a.wrapping_sub(b))?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - _ => { - downcast_primitive_array!( - (left, right) => { - math_op(left, right, |a, b| a.sub_wrapping(b)).map(|a| Arc::new(a) as ArrayRef) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported data type {}, {}", - left.data_type(), right.data_type() - ))) - ) - } - } -} - -/// Perform `left - right` operation on two arrays. If either left or right value is null -/// then the result is also null. -/// -/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, -/// use `subtract_dyn` instead. 
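The dynamically typed subtraction above dispatches on `DataType` so that date and timestamp arrays can be combined with interval arrays. The per-element helpers it forwards to (`Date32Type::subtract_year_months` and friends) are public, so the same behaviour can be reproduced with the generic `binary` kernel. A sketch under the assumption that the helper is infallible for `Date32`, as its use with `math_op` in the removed code implies, and that `binary` lives in `arrow_arith::arity`; the inputs and expected result mirror `test_date32_month_subtract` further down in this file:

    use arrow_arith::arity::binary;
    use arrow_array::types::{Date32Type, IntervalYearMonthType};
    use arrow_array::{Date32Array, IntervalYearMonthArray};
    use chrono::NaiveDate;

    fn main() {
        let dates = Date32Array::from(vec![Date32Type::from_naive_date(
            NaiveDate::from_ymd_opt(2000, 7, 1).unwrap(),
        )]);
        // 6 years and 3 months.
        let deltas =
            IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(6, 3)]);

        // Equivalent of the Date32 / IntervalYearMonth arm of `subtract_dyn`.
        let out: Date32Array =
            binary(&dates, &deltas, |d, m| Date32Type::subtract_year_months(d, m)).unwrap();

        assert_eq!(
            out.value(0),
            Date32Type::from_naive_date(NaiveDate::from_ymd_opt(1994, 4, 1).unwrap())
        );
    }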
-pub fn subtract_dyn_checked( - left: &dyn Array, - right: &dyn Array, -) -> Result { - match left.data_type() { - DataType::Date32 => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::subtract_year_months)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::subtract_day_time)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date32Type::subtract_month_day_nano)?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Date64 => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::subtract_year_months)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::DayTime) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::subtract_day_time)?; - Ok(Arc::new(res)) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let r = right.as_primitive::(); - let res = math_op(l, r, Date64Type::subtract_month_day_nano)?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Timestamp(TimeUnit::Second, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Timestamp(TimeUnit::Second, _) => { - let r = right.as_primitive::(); - let res: PrimitiveArray = try_binary(l, r, |a, b| a.sub_checked(b))?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Timestamp(TimeUnit::Microsecond, _) => { - let r = right.as_primitive::(); - let res: PrimitiveArray = try_binary(l, r, |a, b| a.sub_checked(b))?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Timestamp(TimeUnit::Millisecond, _) => { - let r = right.as_primitive::(); - let res: PrimitiveArray = try_binary(l, r, |a, b| a.sub_checked(b))?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), right.data_type() - ))), - } - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - let l = left.as_primitive::(); - match right.data_type() { - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - let r = right.as_primitive::(); - let res: PrimitiveArray = try_binary(l, r, |a, b| a.sub_checked(b))?; - Ok(Arc::new(res)) - } - _ => Err(ArrowError::CastError(format!( - "Cannot perform arithmetic operation between array of type {} and array of type {}", - left.data_type(), 
right.data_type() - ))), - } - } - _ => { - downcast_primitive_array!( - (left, right) => { - math_checked_op(left, right, |a, b| a.sub_checked(b)).map(|a| Arc::new(a) as ArrayRef) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported data type {}, {}", - left.data_type(), right.data_type() - ))) - ) - } - } -} - -/// Subtract every value in an array by a scalar. If any value in the array is null then the -/// result is also null. -/// -/// This doesn't detect overflow. Once overflowing, the result will wrap around. -/// For an overflow-checking variant, use `subtract_scalar_checked` instead. -pub fn subtract_scalar( - array: &PrimitiveArray, - scalar: T::Native, -) -> Result, ArrowError> { - Ok(unary(array, |value| value.sub_wrapping(scalar))) -} - -/// Subtract every value in an array by a scalar. If any value in the array is null then the -/// result is also null. -/// -/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, -/// use `subtract_scalar` instead. -pub fn subtract_scalar_checked( - array: &PrimitiveArray, - scalar: T::Native, -) -> Result, ArrowError> { - try_unary(array, |value| value.sub_checked(scalar)) -} - -/// Subtract every value in an array by a scalar. If any value in the array is null then the -/// result is also null. The given array must be a `PrimitiveArray` of the type same as -/// the scalar, or a `DictionaryArray` of the value type same as the scalar. -/// -/// This doesn't detect overflow. Once overflowing, the result will wrap around. -/// For an overflow-checking variant, use `subtract_scalar_checked_dyn` instead. -pub fn subtract_scalar_dyn( - array: &dyn Array, - scalar: T::Native, -) -> Result { - unary_dyn::<_, T>(array, |value| value.sub_wrapping(scalar)) -} - -/// Subtract every value in an array by a scalar. If any value in the array is null then the -/// result is also null. The given array must be a `PrimitiveArray` of the type same as -/// the scalar, or a `DictionaryArray` of the value type same as the scalar. -/// -/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, -/// use `subtract_scalar_dyn` instead. -pub fn subtract_scalar_checked_dyn( - array: &dyn Array, - scalar: T::Native, -) -> Result { - try_unary_dyn::<_, T>(array, |value| value.sub_checked(scalar)) - .map(|a| Arc::new(a) as ArrayRef) -} - -/// Perform `-` operation on an array. If value is null then the result is also null. -/// -/// This doesn't detect overflow. Once overflowing, the result will wrap around. -/// For an overflow-checking variant, use `negate_checked` instead. -pub fn negate( - array: &PrimitiveArray, -) -> Result, ArrowError> { - Ok(unary(array, |x| x.neg_wrapping())) -} - -/// Perform `-` operation on an array. If value is null then the result is also null. -/// -/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, -/// use `negate` instead. -pub fn negate_checked( - array: &PrimitiveArray, -) -> Result, ArrowError> { - try_unary(array, |value| value.neg_checked()) -} - -/// Perform `left * right` operation on two arrays. If either left or right value is null -/// then the result is also null. -/// -/// This doesn't detect overflow. Once overflowing, the result will wrap around. -/// For an overflow-checking variant, use `multiply_check` instead. 
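The scalar kernels removed in this block (`subtract_scalar`, `subtract_scalar_checked`, `negate`, `negate_checked`, ...) are one-line wrappers over the `unary`/`try_unary` helpers, which remain public. A minimal sketch of the same behaviour, again assuming those helpers are importable from `arrow_arith::arity`:

    use arrow_arith::arity::{try_unary, unary};
    use arrow_array::{ArrowNativeTypeOp, Int32Array};

    fn main() {
        let a = Int32Array::from(vec![15, 14, i32::MIN]);

        // Wrapping variant, as `subtract_scalar` did.
        let wrapped: Int32Array = unary(&a, |v| v.sub_wrapping(3));
        assert_eq!(wrapped.value(0), 12);

        // Checked variant, as `subtract_scalar_checked` did: i32::MIN - 3 overflows.
        let checked: Result<Int32Array, _> = try_unary(&a, |v| v.sub_checked(3));
        assert!(checked.is_err());

        // Checked negation, as `negate_checked` did: i32::MIN cannot be negated.
        let neg: Result<Int32Array, _> = try_unary(&a, |v| v.neg_checked());
        assert!(neg.is_err());
    }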
-pub fn multiply( - left: &PrimitiveArray, - right: &PrimitiveArray, -) -> Result, ArrowError> { - math_op(left, right, |a, b| a.mul_wrapping(b)) -} - -/// Perform `left * right` operation on two arrays. If either left or right value is null -/// then the result is also null. -/// -/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, -/// use `multiply` instead. -pub fn multiply_checked( - left: &PrimitiveArray, - right: &PrimitiveArray, -) -> Result, ArrowError> { - math_checked_op(left, right, |a, b| a.mul_checked(b)) -} - -/// Perform `left * right` operation on two arrays. If either left or right value is null -/// then the result is also null. -/// -/// This doesn't detect overflow. Once overflowing, the result will wrap around. -/// For an overflow-checking variant, use `multiply_dyn_checked` instead. -pub fn multiply_dyn(left: &dyn Array, right: &dyn Array) -> Result { - downcast_primitive_array!( - (left, right) => { - math_op(left, right, |a, b| a.mul_wrapping(b)).map(|a| Arc::new(a) as ArrayRef) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported data type {}, {}", - left.data_type(), right.data_type() - ))) - ) -} - -/// Perform `left * right` operation on two arrays. If either left or right value is null -/// then the result is also null. -/// -/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, -/// use `multiply_dyn` instead. -pub fn multiply_dyn_checked( - left: &dyn Array, - right: &dyn Array, -) -> Result { - downcast_primitive_array!( - (left, right) => { - math_checked_op(left, right, |a, b| a.mul_checked(b)).map(|a| Arc::new(a) as ArrayRef) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported data type {}, {}", - left.data_type(), right.data_type() - ))) - ) -} - /// Returns the precision and scale of the result of a multiplication of two decimal types, /// and the divisor for fixed point multiplication. fn get_fixed_point_info( @@ -1210,8 +113,10 @@ pub fn multiply_fixed_point_checked( )?; if required_scale == product_scale { - return multiply_checked(left, right)? - .with_precision_and_scale(precision, required_scale); + return try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| { + a.mul_checked(b) + })? + .with_precision_and_scale(precision, required_scale); } try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| { @@ -1254,7 +159,7 @@ pub fn multiply_fixed_point( )?; if required_scale == product_scale { - return multiply(left, right)? + return binary(left, right, |a, b| a.mul_wrapping(b))? .with_precision_and_scale(precision, required_scale); } @@ -1289,1829 +194,42 @@ where } } -/// Multiply every value in an array by a scalar. If any value in the array is null then the -/// result is also null. -/// -/// This doesn't detect overflow. Once overflowing, the result will wrap around. -/// For an overflow-checking variant, use `multiply_scalar_checked` instead. -pub fn multiply_scalar( - array: &PrimitiveArray, - scalar: T::Native, -) -> Result, ArrowError> { - Ok(unary(array, |value| value.mul_wrapping(scalar))) -} - -/// Multiply every value in an array by a scalar. If any value in the array is null then the -/// result is also null. -/// -/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, -/// use `multiply_scalar` instead. 
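The surviving hunks above replace the calls to the removed `multiply`/`multiply_checked` inside the fixed-point kernels with direct `binary`/`try_binary` calls over `Decimal128Type`. A standalone sketch of that pattern, assuming `try_binary` is importable from `arrow_arith::arity` and `mul_checked` comes from `ArrowNativeTypeOp`; the values and scales here are illustrative only:

    use arrow_arith::arity::try_binary;
    use arrow_array::types::Decimal128Type;
    use arrow_array::{ArrowNativeTypeOp, Decimal128Array};

    fn main() {
        // 1.23 and 4.56 stored with precision 38, scale 2.
        let a = Decimal128Array::from(vec![123]).with_precision_and_scale(38, 2).unwrap();
        let b = Decimal128Array::from(vec![456]).with_precision_and_scale(38, 2).unwrap();

        // Checked element-wise product of the raw i128 representations; the caller
        // re-attaches the result precision and scale, just as
        // `multiply_fixed_point_checked` does above.
        let raw = try_binary::<_, _, _, Decimal128Type>(&a, &b, |x, y| x.mul_checked(y)).unwrap();
        let product = raw.with_precision_and_scale(38, 4).unwrap();
        assert_eq!(product.value(0), 56088); // 1.23 * 4.56 = 5.6088
    }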
-pub fn multiply_scalar_checked( - array: &PrimitiveArray, - scalar: T::Native, -) -> Result, ArrowError> { - try_unary(array, |value| value.mul_checked(scalar)) -} +#[cfg(test)] +mod tests { + use super::*; + use crate::numeric::mul; -/// Multiply every value in an array by a scalar. If any value in the array is null then the -/// result is also null. The given array must be a `PrimitiveArray` of the type same as -/// the scalar, or a `DictionaryArray` of the value type same as the scalar. -/// -/// This doesn't detect overflow. Once overflowing, the result will wrap around. -/// For an overflow-checking variant, use `multiply_scalar_checked_dyn` instead. -pub fn multiply_scalar_dyn( - array: &dyn Array, - scalar: T::Native, -) -> Result { - unary_dyn::<_, T>(array, |value| value.mul_wrapping(scalar)) -} + #[test] + fn test_decimal_multiply_allow_precision_loss() { + // Overflow happening as i128 cannot hold multiplying result. + // [123456789] + let a = Decimal128Array::from(vec![123456789000000000000000000]) + .with_precision_and_scale(38, 18) + .unwrap(); -/// Subtract every value in an array by a scalar. If any value in the array is null then the -/// result is also null. The given array must be a `PrimitiveArray` of the type same as -/// the scalar, or a `DictionaryArray` of the value type same as the scalar. -/// -/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, -/// use `multiply_scalar_dyn` instead. -pub fn multiply_scalar_checked_dyn( - array: &dyn Array, - scalar: T::Native, -) -> Result { - try_unary_dyn::<_, T>(array, |value| value.mul_checked(scalar)) - .map(|a| Arc::new(a) as ArrayRef) -} + // [10] + let b = Decimal128Array::from(vec![10000000000000000000]) + .with_precision_and_scale(38, 18) + .unwrap(); -/// Perform `left % right` operation on two arrays. If either left or right value is null -/// then the result is also null. If any right hand value is zero then the result of this -/// operation will be `Err(ArrowError::DivideByZero)`. -pub fn modulus( - left: &PrimitiveArray, - right: &PrimitiveArray, -) -> Result, ArrowError> { - #[cfg(feature = "simd")] - return simd_checked_divide_op(&left, &right, simd_checked_modulus::, |a, b| { - a.mod_wrapping(b) - }); - #[cfg(not(feature = "simd"))] - return try_binary(left, right, |a, b| { - if b.is_zero() { - Err(ArrowError::DivideByZero) - } else { - Ok(a.mod_wrapping(b)) - } - }); -} + let err = mul(&a, &b).unwrap_err(); + assert!(err.to_string().contains( + "Overflow happened on: 123456789000000000000000000 * 10000000000000000000" + )); -/// Perform `left % right` operation on two arrays. If either left or right value is null -/// then the result is also null. If any right hand value is zero then the result of this -/// operation will be `Err(ArrowError::DivideByZero)`. -pub fn modulus_dyn(left: &dyn Array, right: &dyn Array) -> Result { - downcast_primitive_array!( - (left, right) => { - math_checked_divide_op(left, right, |a, b| { - if b.is_zero() { - Err(ArrowError::DivideByZero) - } else { - Ok(a.mod_wrapping(b)) - } - }).map(|a| Arc::new(a) as ArrayRef) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported data type {}, {}", - left.data_type(), right.data_type() - ))) - ) -} + // Allow precision loss. 
+ let result = multiply_fixed_point_checked(&a, &b, 28).unwrap(); + // [1234567890] + let expected = + Decimal128Array::from(vec![12345678900000000000000000000000000000]) + .with_precision_and_scale(38, 28) + .unwrap(); -/// Perform `left / right` operation on two arrays. If either left or right value is null -/// then the result is also null. If any right hand value is zero then the result of this -/// operation will be `Err(ArrowError::DivideByZero)`. -/// -/// When `simd` feature is not enabled. This detects overflow and returns an `Err` for that. -/// For an non-overflow-checking variant, use `divide` instead. -pub fn divide_checked( - left: &PrimitiveArray, - right: &PrimitiveArray, -) -> Result, ArrowError> { - #[cfg(feature = "simd")] - return simd_checked_divide_op(&left, &right, simd_checked_divide::, |a, b| { - a.div_wrapping(b) - }); - #[cfg(not(feature = "simd"))] - return math_checked_divide_op(left, right, |a, b| a.div_checked(b)); -} - -/// Perform `left / right` operation on two arrays. If either left or right value is null -/// then the result is also null. -/// -/// If any right hand value is zero, the operation value will be replaced with null in the -/// result. -/// -/// Unlike [`divide`] or [`divide_checked`], division by zero will yield a null value in the -/// result instead of returning an `Err`. -/// -/// For floating point types overflow will saturate at INF or -INF -/// preserving the expected sign value. -/// -/// For integer types overflow will wrap around. -/// -pub fn divide_opt( - left: &PrimitiveArray, - right: &PrimitiveArray, -) -> Result, ArrowError> { - binary_opt(left, right, |a, b| { - if b.is_zero() { - None - } else { - Some(a.div_wrapping(b)) - } - }) -} - -/// Perform `left / right` operation on two arrays. If either left or right value is null -/// then the result is also null. If any right hand value is zero then the result of this -/// operation will be `Err(ArrowError::DivideByZero)`. -/// -/// This doesn't detect overflow. Once overflowing, the result will wrap around. -/// For an overflow-checking variant, use `divide_dyn_checked` instead. -pub fn divide_dyn(left: &dyn Array, right: &dyn Array) -> Result { - downcast_primitive_array!( - (left, right) => { - math_checked_divide_op(left, right, |a, b| { - if b.is_zero() { - Err(ArrowError::DivideByZero) - } else { - Ok(a.div_wrapping(b)) - } - }).map(|a| Arc::new(a) as ArrayRef) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported data type {}, {}", - left.data_type(), right.data_type() - ))) - ) -} - -/// Perform `left / right` operation on two arrays. If either left or right value is null -/// then the result is also null. If any right hand value is zero then the result of this -/// operation will be `Err(ArrowError::DivideByZero)`. -/// -/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, -/// use `divide_dyn` instead. -pub fn divide_dyn_checked( - left: &dyn Array, - right: &dyn Array, -) -> Result { - downcast_primitive_array!( - (left, right) => { - math_checked_divide_op(left, right, |a, b| a.div_checked(b)).map(|a| Arc::new(a) as ArrayRef) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported data type {}, {}", - left.data_type(), right.data_type() - ))) - ) -} - -/// Perform `left / right` operation on two arrays. If either left or right value is null -/// then the result is also null. -/// -/// If any right hand value is zero, the operation value will be replaced with null in the -/// result. 
-/// -/// Unlike `divide_dyn` or `divide_dyn_checked`, division by zero will get a null value instead -/// returning an `Err`, this also doesn't check overflowing, overflowing will just wrap -/// the result around. -pub fn divide_dyn_opt( - left: &dyn Array, - right: &dyn Array, -) -> Result { - downcast_primitive_array!( - (left, right) => { - math_safe_divide_op(left, right, |a, b| { - if b.is_zero() { - None - } else { - Some(a.div_wrapping(b)) - } - }) - } - _ => Err(ArrowError::CastError(format!( - "Unsupported data type {}, {}", - left.data_type(), right.data_type() - ))) - ) -} - -/// Perform `left / right` operation on two arrays without checking for -/// division by zero or overflow. -/// -/// For floating point types, overflow and division by zero follows normal floating point rules -/// -/// For integer types overflow will wrap around. Division by zero will currently panic, although -/// this may be subject to change see -/// -/// If either left or right value is null then the result is also null. -/// -/// For an overflow-checking variant, use `divide_checked` instead. -pub fn divide( - left: &PrimitiveArray, - right: &PrimitiveArray, -) -> Result, ArrowError> { - // TODO: This is incorrect as div_wrapping has side-effects for integer types - // and so may panic on null values (#2647) - math_op(left, right, |a, b| a.div_wrapping(b)) -} - -/// Modulus every value in an array by a scalar. If any value in the array is null then the -/// result is also null. If the scalar is zero then the result of this operation will be -/// `Err(ArrowError::DivideByZero)`. -pub fn modulus_scalar( - array: &PrimitiveArray, - modulo: T::Native, -) -> Result, ArrowError> { - if modulo.is_zero() { - return Err(ArrowError::DivideByZero); - } - - Ok(unary(array, |a| a.mod_wrapping(modulo))) -} - -/// Modulus every value in an array by a scalar. If any value in the array is null then the -/// result is also null. If the scalar is zero then the result of this operation will be -/// `Err(ArrowError::DivideByZero)`. -pub fn modulus_scalar_dyn( - array: &dyn Array, - modulo: T::Native, -) -> Result { - if modulo.is_zero() { - return Err(ArrowError::DivideByZero); - } - unary_dyn::<_, T>(array, |value| value.mod_wrapping(modulo)) -} - -/// Divide every value in an array by a scalar. If any value in the array is null then the -/// result is also null. If the scalar is zero then the result of this operation will be -/// `Err(ArrowError::DivideByZero)`. -pub fn divide_scalar( - array: &PrimitiveArray, - divisor: T::Native, -) -> Result, ArrowError> { - if divisor.is_zero() { - return Err(ArrowError::DivideByZero); - } - Ok(unary(array, |a| a.div_wrapping(divisor))) -} - -/// Divide every value in an array by a scalar. If any value in the array is null then the -/// result is also null. If the scalar is zero then the result of this operation will be -/// `Err(ArrowError::DivideByZero)`. The given array must be a `PrimitiveArray` of the type -/// same as the scalar, or a `DictionaryArray` of the value type same as the scalar. -/// -/// This doesn't detect overflow. Once overflowing, the result will wrap around. -/// For an overflow-checking variant, use `divide_scalar_checked_dyn` instead. -pub fn divide_scalar_dyn( - array: &dyn Array, - divisor: T::Native, -) -> Result { - if divisor.is_zero() { - return Err(ArrowError::DivideByZero); - } - unary_dyn::<_, T>(array, |value| value.div_wrapping(divisor)) -} - -/// Divide every value in an array by a scalar. 
If any value in the array is null then the -/// result is also null. If the scalar is zero then the result of this operation will be -/// `Err(ArrowError::DivideByZero)`. The given array must be a `PrimitiveArray` of the type -/// same as the scalar, or a `DictionaryArray` of the value type same as the scalar. -/// -/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, -/// use `divide_scalar_dyn` instead. -pub fn divide_scalar_checked_dyn( - array: &dyn Array, - divisor: T::Native, -) -> Result { - if divisor.is_zero() { - return Err(ArrowError::DivideByZero); - } - - try_unary_dyn::<_, T>(array, |value| value.div_checked(divisor)) - .map(|a| Arc::new(a) as ArrayRef) -} - -/// Divide every value in an array by a scalar. If any value in the array is null then the -/// result is also null. The given array must be a `PrimitiveArray` of the type -/// same as the scalar, or a `DictionaryArray` of the value type same as the scalar. -/// -/// If any right hand value is zero, the operation value will be replaced with null in the -/// result. -/// -/// Unlike `divide_scalar_dyn` or `divide_scalar_checked_dyn`, division by zero will get a -/// null value instead returning an `Err`, this also doesn't check overflowing, overflowing -/// will just wrap the result around. -pub fn divide_scalar_opt_dyn( - array: &dyn Array, - divisor: T::Native, -) -> Result { - if divisor.is_zero() { - match array.data_type() { - DataType::Dictionary(_, value_type) => { - return Ok(new_null_array(value_type.as_ref(), array.len())) - } - _ => return Ok(new_null_array(array.data_type(), array.len())), - } - } - - unary_dyn::<_, T>(array, |value| value.div_wrapping(divisor)) -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow_array::builder::{ - BooleanBufferBuilder, BufferBuilder, PrimitiveDictionaryBuilder, - }; - use arrow_array::temporal_conversions::SECONDS_IN_DAY; - use arrow_buffer::buffer::NullBuffer; - use arrow_buffer::i256; - use arrow_data::ArrayDataBuilder; - use chrono::NaiveDate; - use half::f16; - - #[test] - fn test_primitive_array_add() { - let a = Int32Array::from(vec![5, 6, 7, 8, 9]); - let b = Int32Array::from(vec![6, 7, 8, 9, 8]); - let c = add(&a, &b).unwrap(); - assert_eq!(11, c.value(0)); - assert_eq!(13, c.value(1)); - assert_eq!(15, c.value(2)); - assert_eq!(17, c.value(3)); - assert_eq!(17, c.value(4)); - } - - #[test] - fn test_date32_month_add() { - let a = Date32Array::from(vec![Date32Type::from_naive_date( - NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(), - )]); - let b = - IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(1, 2)]); - let c = add_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!( - c.value(0), - Date32Type::from_naive_date(NaiveDate::from_ymd_opt(2001, 3, 1).unwrap()) - ); - - let c = add_dyn(&b, &a).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!( - c.value(0), - Date32Type::from_naive_date(NaiveDate::from_ymd_opt(2001, 3, 1).unwrap()) - ); - } - - #[test] - fn test_date32_day_time_add() { - let a = Date32Array::from(vec![Date32Type::from_naive_date( - NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(), - )]); - let b = IntervalDayTimeArray::from(vec![IntervalDayTimeType::make_value(1, 2)]); - let c = add_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!( - c.value(0), - Date32Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 1, 2).unwrap()) - ); - - let c = add_dyn(&b, &a).unwrap(); - let c = 
c.as_any().downcast_ref::().unwrap(); - assert_eq!( - c.value(0), - Date32Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 1, 2).unwrap()) - ); - } - - #[test] - fn test_date32_month_day_nano_add() { - let a = Date32Array::from(vec![Date32Type::from_naive_date( - NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(), - )]); - let b = - IntervalMonthDayNanoArray::from(vec![IntervalMonthDayNanoType::make_value( - 1, 2, 3, - )]); - let c = add_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!( - c.value(0), - Date32Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 2, 3).unwrap()) - ); - - let c = add_dyn(&b, &a).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!( - c.value(0), - Date32Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 2, 3).unwrap()) - ); - } - - #[test] - fn test_date64_month_add() { - let a = Date64Array::from(vec![Date64Type::from_naive_date( - NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(), - )]); - let b = - IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(1, 2)]); - let c = add_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!( - c.value(0), - Date64Type::from_naive_date(NaiveDate::from_ymd_opt(2001, 3, 1).unwrap()) - ); - - let c = add_dyn(&b, &a).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!( - c.value(0), - Date64Type::from_naive_date(NaiveDate::from_ymd_opt(2001, 3, 1).unwrap()) - ); - } - - #[test] - fn test_date64_day_time_add() { - let a = Date64Array::from(vec![Date64Type::from_naive_date( - NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(), - )]); - let b = IntervalDayTimeArray::from(vec![IntervalDayTimeType::make_value(1, 2)]); - let c = add_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!( - c.value(0), - Date64Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 1, 2).unwrap()) - ); - - let c = add_dyn(&b, &a).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!( - c.value(0), - Date64Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 1, 2).unwrap()) - ); - } - - #[test] - fn test_date64_month_day_nano_add() { - let a = Date64Array::from(vec![Date64Type::from_naive_date( - NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(), - )]); - let b = - IntervalMonthDayNanoArray::from(vec![IntervalMonthDayNanoType::make_value( - 1, 2, 3, - )]); - let c = add_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!( - c.value(0), - Date64Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 2, 3).unwrap()) - ); - - let c = add_dyn(&b, &a).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!( - c.value(0), - Date64Type::from_naive_date(NaiveDate::from_ymd_opt(2000, 2, 3).unwrap()) - ); - } - - #[test] - fn test_primitive_array_add_dyn() { - let a = Int32Array::from(vec![Some(5), Some(6), Some(7), Some(8), Some(9)]); - let b = Int32Array::from(vec![Some(6), Some(7), Some(8), None, Some(8)]); - let c = add_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!(11, c.value(0)); - assert_eq!(13, c.value(1)); - assert_eq!(15, c.value(2)); - assert!(c.is_null(3)); - assert_eq!(17, c.value(4)); - } - - #[test] - fn test_primitive_array_add_scalar_dyn() { - let a = Int32Array::from(vec![Some(5), Some(6), Some(7), None, Some(9)]); - let b = 1_i32; - let c = add_scalar_dyn::(&a, b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!(6, c.value(0)); - assert_eq!(7, c.value(1)); - assert_eq!(8, c.value(2)); - assert!(c.is_null(3)); 
- assert_eq!(10, c.value(4)); - - let mut builder = PrimitiveDictionaryBuilder::::new(); - builder.append(5).unwrap(); - builder.append_null(); - builder.append(7).unwrap(); - builder.append(8).unwrap(); - builder.append(9).unwrap(); - let a = builder.finish(); - let b = -1_i32; - - let c = add_scalar_dyn::(&a, b).unwrap(); - let c = c - .as_any() - .downcast_ref::>() - .unwrap(); - let values = c - .values() - .as_any() - .downcast_ref::>() - .unwrap(); - assert_eq!(4, values.value(c.key(0).unwrap())); - assert!(c.is_null(1)); - assert_eq!(6, values.value(c.key(2).unwrap())); - assert_eq!(7, values.value(c.key(3).unwrap())); - assert_eq!(8, values.value(c.key(4).unwrap())); - } - - #[test] - fn test_primitive_array_subtract_dyn() { - let a = Int32Array::from(vec![Some(51), Some(6), Some(15), Some(8), Some(9)]); - let b = Int32Array::from(vec![Some(6), Some(7), Some(8), None, Some(8)]); - let c = subtract_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!(45, c.value(0)); - assert_eq!(-1, c.value(1)); - assert_eq!(7, c.value(2)); - assert!(c.is_null(3)); - assert_eq!(1, c.value(4)); - } - - #[test] - fn test_date32_month_subtract() { - let a = Date32Array::from(vec![Date32Type::from_naive_date( - NaiveDate::from_ymd_opt(2000, 7, 1).unwrap(), - )]); - let b = - IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(6, 3)]); - let c = subtract_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!( - c.value(0), - Date32Type::from_naive_date(NaiveDate::from_ymd_opt(1994, 4, 1).unwrap()) - ); - } - - #[test] - fn test_date32_day_time_subtract() { - let a = Date32Array::from(vec![Date32Type::from_naive_date( - NaiveDate::from_ymd_opt(2023, 3, 29).unwrap(), - )]); - let b = - IntervalDayTimeArray::from(vec![IntervalDayTimeType::make_value(1, 86500)]); - let c = subtract_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!( - c.value(0), - Date32Type::from_naive_date(NaiveDate::from_ymd_opt(2023, 3, 27).unwrap()) - ); - } - - #[test] - fn test_date32_month_day_nano_subtract() { - let a = Date32Array::from(vec![Date32Type::from_naive_date( - NaiveDate::from_ymd_opt(2023, 3, 15).unwrap(), - )]); - let b = - IntervalMonthDayNanoArray::from(vec![IntervalMonthDayNanoType::make_value( - 1, 2, 0, - )]); - let c = subtract_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!( - c.value(0), - Date32Type::from_naive_date(NaiveDate::from_ymd_opt(2023, 2, 13).unwrap()) - ); - } - - #[test] - fn test_date64_month_subtract() { - let a = Date64Array::from(vec![Date64Type::from_naive_date( - NaiveDate::from_ymd_opt(2000, 7, 1).unwrap(), - )]); - let b = - IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(6, 3)]); - let c = subtract_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!( - c.value(0), - Date64Type::from_naive_date(NaiveDate::from_ymd_opt(1994, 4, 1).unwrap()) - ); - } - - #[test] - fn test_date64_day_time_subtract() { - let a = Date64Array::from(vec![Date64Type::from_naive_date( - NaiveDate::from_ymd_opt(2023, 3, 29).unwrap(), - )]); - let b = - IntervalDayTimeArray::from(vec![IntervalDayTimeType::make_value(1, 86500)]); - let c = subtract_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!( - c.value(0), - Date64Type::from_naive_date(NaiveDate::from_ymd_opt(2023, 3, 27).unwrap()) - ); - } - - #[test] - fn test_date64_month_day_nano_subtract() { - let a = 
Date64Array::from(vec![Date64Type::from_naive_date( - NaiveDate::from_ymd_opt(2023, 3, 15).unwrap(), - )]); - let b = - IntervalMonthDayNanoArray::from(vec![IntervalMonthDayNanoType::make_value( - 1, 2, 0, - )]); - let c = subtract_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!( - c.value(0), - Date64Type::from_naive_date(NaiveDate::from_ymd_opt(2023, 2, 13).unwrap()) - ); - } - - #[test] - fn test_primitive_array_subtract_scalar_dyn() { - let a = Int32Array::from(vec![Some(5), Some(6), Some(7), None, Some(9)]); - let b = 1_i32; - let c = subtract_scalar_dyn::(&a, b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!(4, c.value(0)); - assert_eq!(5, c.value(1)); - assert_eq!(6, c.value(2)); - assert!(c.is_null(3)); - assert_eq!(8, c.value(4)); - - let mut builder = PrimitiveDictionaryBuilder::::new(); - builder.append(5).unwrap(); - builder.append_null(); - builder.append(7).unwrap(); - builder.append(8).unwrap(); - builder.append(9).unwrap(); - let a = builder.finish(); - let b = -1_i32; - - let c = subtract_scalar_dyn::(&a, b).unwrap(); - let c = c - .as_any() - .downcast_ref::>() - .unwrap(); - let values = c - .values() - .as_any() - .downcast_ref::>() - .unwrap(); - assert_eq!(6, values.value(c.key(0).unwrap())); - assert!(c.is_null(1)); - assert_eq!(8, values.value(c.key(2).unwrap())); - assert_eq!(9, values.value(c.key(3).unwrap())); - assert_eq!(10, values.value(c.key(4).unwrap())); - } - - #[test] - fn test_primitive_array_multiply_dyn() { - let a = Int32Array::from(vec![Some(5), Some(6), Some(7), Some(8), Some(9)]); - let b = Int32Array::from(vec![Some(6), Some(7), Some(8), None, Some(8)]); - let c = multiply_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!(30, c.value(0)); - assert_eq!(42, c.value(1)); - assert_eq!(56, c.value(2)); - assert!(c.is_null(3)); - assert_eq!(72, c.value(4)); - } - - #[test] - fn test_primitive_array_divide_dyn() { - let a = Int32Array::from(vec![Some(15), Some(6), Some(1), Some(8), Some(9)]); - let b = Int32Array::from(vec![Some(5), Some(3), Some(1), None, Some(3)]); - let c = divide_dyn(&a, &b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!(3, c.value(0)); - assert_eq!(2, c.value(1)); - assert_eq!(1, c.value(2)); - assert!(c.is_null(3)); - assert_eq!(3, c.value(4)); - } - - #[test] - fn test_primitive_array_multiply_scalar_dyn() { - let a = Int32Array::from(vec![Some(5), Some(6), Some(7), None, Some(9)]); - let b = 2_i32; - let c = multiply_scalar_dyn::(&a, b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!(10, c.value(0)); - assert_eq!(12, c.value(1)); - assert_eq!(14, c.value(2)); - assert!(c.is_null(3)); - assert_eq!(18, c.value(4)); - - let mut builder = PrimitiveDictionaryBuilder::::new(); - builder.append(5).unwrap(); - builder.append_null(); - builder.append(7).unwrap(); - builder.append(8).unwrap(); - builder.append(9).unwrap(); - let a = builder.finish(); - let b = -1_i32; - - let c = multiply_scalar_dyn::(&a, b).unwrap(); - let c = c - .as_any() - .downcast_ref::>() - .unwrap(); - let values = c - .values() - .as_any() - .downcast_ref::>() - .unwrap(); - assert_eq!(-5, values.value(c.key(0).unwrap())); - assert!(c.is_null(1)); - assert_eq!(-7, values.value(c.key(2).unwrap())); - assert_eq!(-8, values.value(c.key(3).unwrap())); - assert_eq!(-9, values.value(c.key(4).unwrap())); - } - - #[test] - fn test_primitive_array_add_sliced() { - let a = Int32Array::from(vec![0, 0, 0, 5, 6, 7, 8, 9, 0]); - let b = 
Int32Array::from(vec![0, 0, 0, 6, 7, 8, 9, 8, 0]); - let a = a.slice(3, 5); - let b = b.slice(3, 5); - let a = a.as_any().downcast_ref::().unwrap(); - let b = b.as_any().downcast_ref::().unwrap(); - - assert_eq!(5, a.value(0)); - assert_eq!(6, b.value(0)); - - let c = add(a, b).unwrap(); - assert_eq!(5, c.len()); - assert_eq!(11, c.value(0)); - assert_eq!(13, c.value(1)); - assert_eq!(15, c.value(2)); - assert_eq!(17, c.value(3)); - assert_eq!(17, c.value(4)); - } - - #[test] - fn test_primitive_array_add_mismatched_length() { - let a = Int32Array::from(vec![5, 6, 7, 8, 9]); - let b = Int32Array::from(vec![6, 7, 8]); - let e = add(&a, &b).expect_err("should have failed due to different lengths"); - assert_eq!( - "ComputeError(\"Cannot perform binary operation on arrays of different length\")", - format!("{e:?}") - ); - } - - #[test] - fn test_primitive_array_add_scalar() { - let a = Int32Array::from(vec![15, 14, 9, 8, 1]); - let b = 3; - let c = add_scalar(&a, b).unwrap(); - let expected = Int32Array::from(vec![18, 17, 12, 11, 4]); - assert_eq!(c, expected); - } - - #[test] - fn test_primitive_array_add_scalar_sliced() { - let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]); - let a = a.slice(1, 4); - let actual = add_scalar(&a, 3).unwrap(); - let expected = Int32Array::from(vec![None, Some(12), Some(11), None]); - assert_eq!(actual, expected); - } - - #[test] - fn test_primitive_array_subtract() { - let a = Int32Array::from(vec![1, 2, 3, 4, 5]); - let b = Int32Array::from(vec![5, 4, 3, 2, 1]); - let c = subtract(&a, &b).unwrap(); - assert_eq!(-4, c.value(0)); - assert_eq!(-2, c.value(1)); - assert_eq!(0, c.value(2)); - assert_eq!(2, c.value(3)); - assert_eq!(4, c.value(4)); - } - - #[test] - fn test_primitive_array_subtract_scalar() { - let a = Int32Array::from(vec![15, 14, 9, 8, 1]); - let b = 3; - let c = subtract_scalar(&a, b).unwrap(); - let expected = Int32Array::from(vec![12, 11, 6, 5, -2]); - assert_eq!(c, expected); - } - - #[test] - fn test_primitive_array_subtract_scalar_sliced() { - let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]); - let a = a.slice(1, 4); - let actual = subtract_scalar(&a, 3).unwrap(); - let expected = Int32Array::from(vec![None, Some(6), Some(5), None]); - assert_eq!(actual, expected); - } - - #[test] - fn test_primitive_array_multiply() { - let a = Int32Array::from(vec![5, 6, 7, 8, 9]); - let b = Int32Array::from(vec![6, 7, 8, 9, 8]); - let c = multiply(&a, &b).unwrap(); - assert_eq!(30, c.value(0)); - assert_eq!(42, c.value(1)); - assert_eq!(56, c.value(2)); - assert_eq!(72, c.value(3)); - assert_eq!(72, c.value(4)); - } - - #[test] - fn test_primitive_array_multiply_scalar() { - let a = Int32Array::from(vec![15, 14, 9, 8, 1]); - let b = 3; - let c = multiply_scalar(&a, b).unwrap(); - let expected = Int32Array::from(vec![45, 42, 27, 24, 3]); - assert_eq!(c, expected); - } - - #[test] - fn test_primitive_array_multiply_scalar_sliced() { - let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]); - let a = a.slice(1, 4); - let actual = multiply_scalar(&a, 3).unwrap(); - let expected = Int32Array::from(vec![None, Some(27), Some(24), None]); - assert_eq!(actual, expected); - } - - #[test] - fn test_primitive_array_divide() { - let a = Int32Array::from(vec![15, 15, 8, 1, 9]); - let b = Int32Array::from(vec![5, 6, 8, 9, 1]); - let c = divide(&a, &b).unwrap(); - assert_eq!(3, c.value(0)); - assert_eq!(2, c.value(1)); - assert_eq!(1, c.value(2)); - assert_eq!(0, c.value(3)); - assert_eq!(9, c.value(4)); - } - - 
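All of the wrapping/checked pairs exercised by these tests bottom out in the `ArrowNativeTypeOp` trait on the native values, which this cleanup leaves untouched. A small sketch of the distinction, assuming the trait is exported from the `arrow_array` crate root:

    use arrow_array::ArrowNativeTypeOp;

    fn main() {
        // Wrapping ops never fail; they wrap around on overflow.
        assert_eq!(i32::MAX.add_wrapping(1), i32::MIN);
        assert_eq!(i32::MIN.mod_wrapping(-1), 0);

        // Checked ops surface overflow and division by zero as errors.
        assert!(i32::MAX.add_checked(1).is_err());
        assert!(1_i32.div_checked(0).is_err());
        assert!(!2_i32.is_zero());
    }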
#[test] - fn test_int_array_modulus() { - let a = Int32Array::from(vec![15, 15, 8, 1, 9]); - let b = Int32Array::from(vec![5, 6, 8, 9, 1]); - let c = modulus(&a, &b).unwrap(); - assert_eq!(0, c.value(0)); - assert_eq!(3, c.value(1)); - assert_eq!(0, c.value(2)); - assert_eq!(1, c.value(3)); - assert_eq!(0, c.value(4)); - - let c = modulus_dyn(&a, &b).unwrap(); - let c = c.as_primitive::(); - assert_eq!(0, c.value(0)); - assert_eq!(3, c.value(1)); - assert_eq!(0, c.value(2)); - assert_eq!(1, c.value(3)); - assert_eq!(0, c.value(4)); - } - - #[test] - #[should_panic( - expected = "called `Result::unwrap()` on an `Err` value: DivideByZero" - )] - fn test_int_array_modulus_divide_by_zero() { - let a = Int32Array::from(vec![1]); - let b = Int32Array::from(vec![0]); - modulus(&a, &b).unwrap(); - } - - #[test] - #[should_panic( - expected = "called `Result::unwrap()` on an `Err` value: DivideByZero" - )] - fn test_int_array_modulus_dyn_divide_by_zero() { - let a = Int32Array::from(vec![1]); - let b = Int32Array::from(vec![0]); - modulus_dyn(&a, &b).unwrap(); - } - - #[test] - fn test_int_array_modulus_overflow_wrapping() { - let a = Int32Array::from(vec![i32::MIN]); - let b = Int32Array::from(vec![-1]); - let result = modulus(&a, &b).unwrap(); - assert_eq!(0, result.value(0)) - } - - #[test] - fn test_primitive_array_divide_scalar() { - let a = Int32Array::from(vec![15, 14, 9, 8, 1]); - let b = 3; - let c = divide_scalar(&a, b).unwrap(); - let expected = Int32Array::from(vec![5, 4, 3, 2, 0]); - assert_eq!(c, expected); - } - - #[test] - fn test_primitive_array_divide_scalar_dyn() { - let a = Int32Array::from(vec![Some(5), Some(6), Some(7), None, Some(9)]); - let b = 2_i32; - let c = divide_scalar_dyn::(&a, b).unwrap(); - let c = c.as_any().downcast_ref::().unwrap(); - assert_eq!(2, c.value(0)); - assert_eq!(3, c.value(1)); - assert_eq!(3, c.value(2)); - assert!(c.is_null(3)); - assert_eq!(4, c.value(4)); - - let mut builder = PrimitiveDictionaryBuilder::::new(); - builder.append(5).unwrap(); - builder.append_null(); - builder.append(7).unwrap(); - builder.append(8).unwrap(); - builder.append(9).unwrap(); - let a = builder.finish(); - let b = -2_i32; - - let c = divide_scalar_dyn::(&a, b).unwrap(); - let c = c - .as_any() - .downcast_ref::>() - .unwrap(); - let values = c - .values() - .as_any() - .downcast_ref::>() - .unwrap(); - assert_eq!(-2, values.value(c.key(0).unwrap())); - assert!(c.is_null(1)); - assert_eq!(-3, values.value(c.key(2).unwrap())); - assert_eq!(-4, values.value(c.key(3).unwrap())); - assert_eq!(-4, values.value(c.key(4).unwrap())); - - let e = divide_scalar_dyn::(&a, 0_i32) - .expect_err("should have failed due to divide by zero"); - assert_eq!("DivideByZero", format!("{e:?}")); - } - - #[test] - fn test_primitive_array_divide_scalar_sliced() { - let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]); - let a = a.slice(1, 4); - let actual = divide_scalar(&a, 3).unwrap(); - let expected = Int32Array::from(vec![None, Some(3), Some(2), None]); - assert_eq!(actual, expected); - } - - #[test] - fn test_int_array_modulus_scalar() { - let a = Int32Array::from(vec![15, 14, 9, 8, 1]); - let b = 3; - let c = modulus_scalar(&a, b).unwrap(); - let expected = Int32Array::from(vec![0, 2, 0, 2, 1]); - assert_eq!(c, expected); - - let c = modulus_scalar_dyn::(&a, b).unwrap(); - let c = c.as_primitive::(); - let expected = Int32Array::from(vec![0, 2, 0, 2, 1]); - assert_eq!(c, &expected); - } - - #[test] - fn test_int_array_modulus_scalar_sliced() { - let a = 
Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]); - let a = a.slice(1, 4); - let actual = modulus_scalar(&a, 3).unwrap(); - let expected = Int32Array::from(vec![None, Some(0), Some(2), None]); - assert_eq!(actual, expected); - - let actual = modulus_scalar_dyn::(&a, 3).unwrap(); - let actual = actual.as_primitive::(); - let expected = Int32Array::from(vec![None, Some(0), Some(2), None]); - assert_eq!(actual, &expected); - } - - #[test] - #[should_panic( - expected = "called `Result::unwrap()` on an `Err` value: DivideByZero" - )] - fn test_int_array_modulus_scalar_divide_by_zero() { - let a = Int32Array::from(vec![1]); - modulus_scalar(&a, 0).unwrap(); - } - - #[test] - fn test_int_array_modulus_scalar_overflow_wrapping() { - let a = Int32Array::from(vec![i32::MIN]); - let result = modulus_scalar(&a, -1).unwrap(); - assert_eq!(0, result.value(0)); - - let result = modulus_scalar_dyn::(&a, -1).unwrap(); - let result = result.as_primitive::(); - assert_eq!(0, result.value(0)); - } - - #[test] - fn test_primitive_array_divide_sliced() { - let a = Int32Array::from(vec![0, 0, 0, 15, 15, 8, 1, 9, 0]); - let b = Int32Array::from(vec![0, 0, 0, 5, 6, 8, 9, 1, 0]); - let a = a.slice(3, 5); - let b = b.slice(3, 5); - let a = a.as_any().downcast_ref::().unwrap(); - let b = b.as_any().downcast_ref::().unwrap(); - - let c = divide(a, b).unwrap(); - assert_eq!(5, c.len()); - assert_eq!(3, c.value(0)); - assert_eq!(2, c.value(1)); - assert_eq!(1, c.value(2)); - assert_eq!(0, c.value(3)); - assert_eq!(9, c.value(4)); - } - - #[test] - fn test_primitive_array_modulus_sliced() { - let a = Int32Array::from(vec![0, 0, 0, 15, 15, 8, 1, 9, 0]); - let b = Int32Array::from(vec![0, 0, 0, 5, 6, 8, 9, 1, 0]); - let a = a.slice(3, 5); - let b = b.slice(3, 5); - let a = a.as_any().downcast_ref::().unwrap(); - let b = b.as_any().downcast_ref::().unwrap(); - - let c = modulus(a, b).unwrap(); - assert_eq!(5, c.len()); - assert_eq!(0, c.value(0)); - assert_eq!(3, c.value(1)); - assert_eq!(0, c.value(2)); - assert_eq!(1, c.value(3)); - assert_eq!(0, c.value(4)); - } - - #[test] - fn test_primitive_array_divide_with_nulls() { - let a = Int32Array::from(vec![Some(15), None, Some(8), Some(1), Some(9), None]); - let b = Int32Array::from(vec![Some(5), Some(6), Some(8), Some(9), None, None]); - let c = divide_checked(&a, &b).unwrap(); - assert_eq!(3, c.value(0)); - assert!(c.is_null(1)); - assert_eq!(1, c.value(2)); - assert_eq!(0, c.value(3)); - assert!(c.is_null(4)); - assert!(c.is_null(5)); - } - - #[test] - fn test_primitive_array_modulus_with_nulls() { - let a = Int32Array::from(vec![Some(15), None, Some(8), Some(1), Some(9), None]); - let b = Int32Array::from(vec![Some(5), Some(6), Some(8), Some(9), None, None]); - let c = modulus(&a, &b).unwrap(); - assert_eq!(0, c.value(0)); - assert!(c.is_null(1)); - assert_eq!(0, c.value(2)); - assert_eq!(1, c.value(3)); - assert!(c.is_null(4)); - assert!(c.is_null(5)); - } - - #[test] - fn test_primitive_array_divide_scalar_with_nulls() { - let a = Int32Array::from(vec![Some(15), None, Some(8), Some(1), Some(9), None]); - let b = 3; - let c = divide_scalar(&a, b).unwrap(); - let expected = - Int32Array::from(vec![Some(5), None, Some(2), Some(0), Some(3), None]); - assert_eq!(c, expected); - } - - #[test] - fn test_primitive_array_modulus_scalar_with_nulls() { - let a = Int32Array::from(vec![Some(15), None, Some(8), Some(1), Some(9), None]); - let b = 3; - let c = modulus_scalar(&a, b).unwrap(); - let expected = - Int32Array::from(vec![Some(0), None, Some(2), Some(1), 
Some(0), None]); - assert_eq!(c, expected); - } - - #[test] - fn test_primitive_array_divide_with_nulls_sliced() { - let a = Int32Array::from(vec![ - None, - None, - None, - None, - None, - None, - None, - None, - Some(15), - None, - Some(8), - Some(1), - Some(9), - None, - None, - ]); - let b = Int32Array::from(vec![ - None, - None, - None, - None, - None, - None, - None, - None, - Some(5), - Some(6), - Some(8), - Some(9), - None, - None, - None, - ]); - - let a = a.slice(8, 6); - let a = a.as_any().downcast_ref::().unwrap(); - - let b = b.slice(8, 6); - let b = b.as_any().downcast_ref::().unwrap(); - - let c = divide_checked(a, b).unwrap(); - assert_eq!(6, c.len()); - assert_eq!(3, c.value(0)); - assert!(c.is_null(1)); - assert_eq!(1, c.value(2)); - assert_eq!(0, c.value(3)); - assert!(c.is_null(4)); - assert!(c.is_null(5)); - } - - #[test] - fn test_primitive_array_modulus_with_nulls_sliced() { - let a = Int32Array::from(vec![ - None, - None, - None, - None, - None, - None, - None, - None, - Some(15), - None, - Some(8), - Some(1), - Some(9), - None, - None, - ]); - let b = Int32Array::from(vec![ - None, - None, - None, - None, - None, - None, - None, - None, - Some(5), - Some(6), - Some(8), - Some(9), - None, - None, - None, - ]); - - let a = a.slice(8, 6); - let a = a.as_any().downcast_ref::().unwrap(); - - let b = b.slice(8, 6); - let b = b.as_any().downcast_ref::().unwrap(); - - let c = modulus(a, b).unwrap(); - assert_eq!(6, c.len()); - assert_eq!(0, c.value(0)); - assert!(c.is_null(1)); - assert_eq!(0, c.value(2)); - assert_eq!(1, c.value(3)); - assert!(c.is_null(4)); - assert!(c.is_null(5)); - } - - #[test] - #[should_panic(expected = "DivideByZero")] - fn test_int_array_divide_by_zero_with_checked() { - let a = Int32Array::from(vec![15]); - let b = Int32Array::from(vec![0]); - divide_checked(&a, &b).unwrap(); - } - - #[test] - #[should_panic(expected = "DivideByZero")] - fn test_f32_array_divide_by_zero_with_checked() { - let a = Float32Array::from(vec![15.0]); - let b = Float32Array::from(vec![0.0]); - divide_checked(&a, &b).unwrap(); - } - - #[test] - #[should_panic(expected = "attempt to divide by zero")] - fn test_int_array_divide_by_zero() { - let a = Int32Array::from(vec![15]); - let b = Int32Array::from(vec![0]); - divide(&a, &b).unwrap(); - } - - #[test] - fn test_f32_array_divide_by_zero() { - let a = Float32Array::from(vec![1.5, 0.0, -1.5]); - let b = Float32Array::from(vec![0.0, 0.0, 0.0]); - let result = divide(&a, &b).unwrap(); - assert_eq!(result.value(0), f32::INFINITY); - assert!(result.value(1).is_nan()); - assert_eq!(result.value(2), f32::NEG_INFINITY); - } - - #[test] - #[should_panic(expected = "DivideByZero")] - fn test_int_array_divide_dyn_by_zero() { - let a = Int32Array::from(vec![15]); - let b = Int32Array::from(vec![0]); - divide_dyn(&a, &b).unwrap(); - } - - #[test] - #[should_panic(expected = "DivideByZero")] - fn test_f32_array_divide_dyn_by_zero() { - let a = Float32Array::from(vec![1.5]); - let b = Float32Array::from(vec![0.0]); - divide_dyn(&a, &b).unwrap(); - } - - #[test] - #[should_panic(expected = "DivideByZero")] - fn test_i32_array_modulus_by_zero() { - let a = Int32Array::from(vec![15]); - let b = Int32Array::from(vec![0]); - modulus(&a, &b).unwrap(); - } - - #[test] - #[should_panic(expected = "DivideByZero")] - fn test_i32_array_modulus_dyn_by_zero() { - let a = Int32Array::from(vec![15]); - let b = Int32Array::from(vec![0]); - modulus_dyn(&a, &b).unwrap(); - } - - #[test] - #[should_panic(expected = "DivideByZero")] - fn 
test_f32_array_modulus_by_zero() { - let a = Float32Array::from(vec![1.5]); - let b = Float32Array::from(vec![0.0]); - modulus(&a, &b).unwrap(); - } - - #[test] - #[should_panic(expected = "DivideByZero")] - fn test_f32_array_modulus_dyn_by_zero() { - let a = Float32Array::from(vec![1.5]); - let b = Float32Array::from(vec![0.0]); - modulus_dyn(&a, &b).unwrap(); - } - - #[test] - fn test_f64_array_divide() { - let a = Float64Array::from(vec![15.0, 15.0, 8.0]); - let b = Float64Array::from(vec![5.0, 6.0, 8.0]); - let c = divide(&a, &b).unwrap(); - assert_eq!(3.0, c.value(0)); - assert_eq!(2.5, c.value(1)); - assert_eq!(1.0, c.value(2)); - } - - #[test] - fn test_primitive_array_add_with_nulls() { - let a = Int32Array::from(vec![Some(5), None, Some(7), None]); - let b = Int32Array::from(vec![None, None, Some(6), Some(7)]); - let c = add(&a, &b).unwrap(); - assert!(c.is_null(0)); - assert!(c.is_null(1)); - assert!(!c.is_null(2)); - assert!(c.is_null(3)); - assert_eq!(13, c.value(2)); - } - - #[test] - fn test_primitive_array_negate() { - let a: Int64Array = (0..100).map(Some).collect(); - let actual = negate(&a).unwrap(); - let expected: Int64Array = (0..100).map(|i| Some(-i)).collect(); - assert_eq!(expected, actual); - } - - #[test] - fn test_primitive_array_negate_checked_overflow() { - let a = Int32Array::from(vec![i32::MIN]); - let actual = negate(&a).unwrap(); - let expected = Int32Array::from(vec![i32::MIN]); - assert_eq!(expected, actual); - - let err = negate_checked(&a); - err.expect_err("negate_checked should detect overflow"); - } - - #[test] - fn test_arithmetic_kernel_should_not_rely_on_padding() { - let a: UInt8Array = (0..128_u8).map(Some).collect(); - let a = a.slice(63, 65); - let a = a.as_any().downcast_ref::().unwrap(); - - let b: UInt8Array = (0..128_u8).map(Some).collect(); - let b = b.slice(63, 65); - let b = b.as_any().downcast_ref::().unwrap(); - - let actual = add(a, b).unwrap(); - let actual: Vec> = actual.iter().collect(); - let expected: Vec> = - (63..63_u8 + 65_u8).map(|i| Some(i + i)).collect(); - assert_eq!(expected, actual); - } - - #[test] - fn test_primitive_add_wrapping_overflow() { - let a = Int32Array::from(vec![i32::MAX, i32::MIN]); - let b = Int32Array::from(vec![1, 1]); - - let wrapped = add(&a, &b); - let expected = Int32Array::from(vec![-2147483648, -2147483647]); - assert_eq!(expected, wrapped.unwrap()); - - let overflow = add_checked(&a, &b); - overflow.expect_err("overflow should be detected"); - } - - #[test] - fn test_primitive_subtract_wrapping_overflow() { - let a = Int32Array::from(vec![-2]); - let b = Int32Array::from(vec![i32::MAX]); - - let wrapped = subtract(&a, &b); - let expected = Int32Array::from(vec![i32::MAX]); - assert_eq!(expected, wrapped.unwrap()); - - let overflow = subtract_checked(&a, &b); - overflow.expect_err("overflow should be detected"); - } - - #[test] - fn test_primitive_mul_wrapping_overflow() { - let a = Int32Array::from(vec![10]); - let b = Int32Array::from(vec![i32::MAX]); - - let wrapped = multiply(&a, &b); - let expected = Int32Array::from(vec![-10]); - assert_eq!(expected, wrapped.unwrap()); - - let overflow = multiply_checked(&a, &b); - overflow.expect_err("overflow should be detected"); - } - - #[test] - #[cfg(not(feature = "simd"))] - fn test_primitive_div_wrapping_overflow() { - let a = Int32Array::from(vec![i32::MIN]); - let b = Int32Array::from(vec![-1]); - - let wrapped = divide(&a, &b); - let expected = Int32Array::from(vec![-2147483648]); - assert_eq!(expected, wrapped.unwrap()); - - let overflow = 
divide_checked(&a, &b); - overflow.expect_err("overflow should be detected"); - } - - #[test] - fn test_primitive_add_scalar_wrapping_overflow() { - let a = Int32Array::from(vec![i32::MAX, i32::MIN]); - - let wrapped = add_scalar(&a, 1); - let expected = Int32Array::from(vec![-2147483648, -2147483647]); - assert_eq!(expected, wrapped.unwrap()); - - let overflow = add_scalar_checked(&a, 1); - overflow.expect_err("overflow should be detected"); - } - - #[test] - fn test_primitive_subtract_scalar_wrapping_overflow() { - let a = Int32Array::from(vec![-2]); - - let wrapped = subtract_scalar(&a, i32::MAX); - let expected = Int32Array::from(vec![i32::MAX]); - assert_eq!(expected, wrapped.unwrap()); - - let overflow = subtract_scalar_checked(&a, i32::MAX); - overflow.expect_err("overflow should be detected"); - } - - #[test] - fn test_primitive_mul_scalar_wrapping_overflow() { - let a = Int32Array::from(vec![10]); - - let wrapped = multiply_scalar(&a, i32::MAX); - let expected = Int32Array::from(vec![-10]); - assert_eq!(expected, wrapped.unwrap()); - - let overflow = multiply_scalar_checked(&a, i32::MAX); - overflow.expect_err("overflow should be detected"); - } - - #[test] - fn test_primitive_add_scalar_dyn_wrapping_overflow() { - let a = Int32Array::from(vec![i32::MAX, i32::MIN]); - - let wrapped = add_scalar_dyn::(&a, 1).unwrap(); - let expected = - Arc::new(Int32Array::from(vec![-2147483648, -2147483647])) as ArrayRef; - assert_eq!(&expected, &wrapped); - - let overflow = add_scalar_checked_dyn::(&a, 1); - overflow.expect_err("overflow should be detected"); - } - - #[test] - fn test_primitive_subtract_scalar_dyn_wrapping_overflow() { - let a = Int32Array::from(vec![-2]); - - let wrapped = subtract_scalar_dyn::(&a, i32::MAX).unwrap(); - let expected = Arc::new(Int32Array::from(vec![i32::MAX])) as ArrayRef; - assert_eq!(&expected, &wrapped); - - let overflow = subtract_scalar_checked_dyn::(&a, i32::MAX); - overflow.expect_err("overflow should be detected"); - } - - #[test] - fn test_primitive_mul_scalar_dyn_wrapping_overflow() { - let a = Int32Array::from(vec![10]); - - let wrapped = multiply_scalar_dyn::(&a, i32::MAX).unwrap(); - let expected = Arc::new(Int32Array::from(vec![-10])) as ArrayRef; - assert_eq!(&expected, &wrapped); - - let overflow = multiply_scalar_checked_dyn::(&a, i32::MAX); - overflow.expect_err("overflow should be detected"); - } - - #[test] - fn test_primitive_div_scalar_dyn_wrapping_overflow() { - let a = Int32Array::from(vec![i32::MIN]); - - let wrapped = divide_scalar_dyn::(&a, -1).unwrap(); - let expected = Arc::new(Int32Array::from(vec![-2147483648])) as ArrayRef; - assert_eq!(&expected, &wrapped); - - let overflow = divide_scalar_checked_dyn::(&a, -1); - overflow.expect_err("overflow should be detected"); - } - - #[test] - fn test_primitive_div_opt_overflow_division_by_zero() { - let a = Int32Array::from(vec![i32::MIN]); - let b = Int32Array::from(vec![-1]); - - let wrapped = divide(&a, &b); - let expected = Int32Array::from(vec![-2147483648]); - assert_eq!(expected, wrapped.unwrap()); - - let overflow = divide_opt(&a, &b); - let expected = Int32Array::from(vec![-2147483648]); - assert_eq!(expected, overflow.unwrap()); - - let b = Int32Array::from(vec![0]); - let overflow = divide_opt(&a, &b); - let expected = Int32Array::from(vec![None]); - assert_eq!(expected, overflow.unwrap()); - } - - #[test] - fn test_primitive_add_dyn_wrapping_overflow() { - let a = Int32Array::from(vec![i32::MAX, i32::MIN]); - let b = Int32Array::from(vec![1, 1]); - - let wrapped = add_dyn(&a, 
&b).unwrap(); - let expected = - Arc::new(Int32Array::from(vec![-2147483648, -2147483647])) as ArrayRef; - assert_eq!(&expected, &wrapped); - - let overflow = add_dyn_checked(&a, &b); - overflow.expect_err("overflow should be detected"); - } - - #[test] - fn test_primitive_subtract_dyn_wrapping_overflow() { - let a = Int32Array::from(vec![-2]); - let b = Int32Array::from(vec![i32::MAX]); - - let wrapped = subtract_dyn(&a, &b).unwrap(); - let expected = Arc::new(Int32Array::from(vec![i32::MAX])) as ArrayRef; - assert_eq!(&expected, &wrapped); - - let overflow = subtract_dyn_checked(&a, &b); - overflow.expect_err("overflow should be detected"); - } - - #[test] - fn test_primitive_mul_dyn_wrapping_overflow() { - let a = Int32Array::from(vec![10]); - let b = Int32Array::from(vec![i32::MAX]); - - let wrapped = multiply_dyn(&a, &b).unwrap(); - let expected = Arc::new(Int32Array::from(vec![-10])) as ArrayRef; - assert_eq!(&expected, &wrapped); - - let overflow = multiply_dyn_checked(&a, &b); - overflow.expect_err("overflow should be detected"); - } - - #[test] - fn test_primitive_div_dyn_wrapping_overflow() { - let a = Int32Array::from(vec![i32::MIN]); - let b = Int32Array::from(vec![-1]); - - let wrapped = divide_dyn(&a, &b).unwrap(); - let expected = Arc::new(Int32Array::from(vec![-2147483648])) as ArrayRef; - assert_eq!(&expected, &wrapped); - - let overflow = divide_dyn_checked(&a, &b); - overflow.expect_err("overflow should be detected"); - } - - #[test] - fn test_decimal128() { - let a = Decimal128Array::from_iter_values([1, 2, 4, 5]); - let b = Decimal128Array::from_iter_values([7, -3, 6, 3]); - let e = Decimal128Array::from_iter_values([8, -1, 10, 8]); - let r = add(&a, &b).unwrap(); - assert_eq!(e, r); - - let e = Decimal128Array::from_iter_values([-6, 5, -2, 2]); - let r = subtract(&a, &b).unwrap(); - assert_eq!(e, r); - - let e = Decimal128Array::from_iter_values([7, -6, 24, 15]); - let r = multiply(&a, &b).unwrap(); - assert_eq!(e, r); - - let a = Decimal128Array::from_iter_values([23, 56, 32, 55]); - let b = Decimal128Array::from_iter_values([1, -2, 4, 5]); - let e = Decimal128Array::from_iter_values([23, -28, 8, 11]); - let r = divide(&a, &b).unwrap(); - assert_eq!(e, r); - } - - #[test] - fn test_decimal256() { - let a = Decimal256Array::from_iter_values( - [1, 2, 4, 5].into_iter().map(i256::from_i128), - ); - let b = Decimal256Array::from_iter_values( - [7, -3, 6, 3].into_iter().map(i256::from_i128), - ); - let e = Decimal256Array::from_iter_values( - [8, -1, 10, 8].into_iter().map(i256::from_i128), - ); - let r = add(&a, &b).unwrap(); - assert_eq!(e, r); - - let e = Decimal256Array::from_iter_values( - [-6, 5, -2, 2].into_iter().map(i256::from_i128), - ); - let r = subtract(&a, &b).unwrap(); - assert_eq!(e, r); - - let e = Decimal256Array::from_iter_values( - [7, -6, 24, 15].into_iter().map(i256::from_i128), - ); - let r = multiply(&a, &b).unwrap(); - assert_eq!(e, r); - - let a = Decimal256Array::from_iter_values( - [23, 56, 32, 55].into_iter().map(i256::from_i128), - ); - let b = Decimal256Array::from_iter_values( - [1, -2, 4, 5].into_iter().map(i256::from_i128), - ); - let e = Decimal256Array::from_iter_values( - [23, -28, 8, 11].into_iter().map(i256::from_i128), - ); - let r = divide(&a, &b).unwrap(); - assert_eq!(e, r); - } - - #[test] - fn test_div_scalar_dyn_opt_overflow_division_by_zero() { - let a = Int32Array::from(vec![i32::MIN]); - - let division_by_zero = divide_scalar_opt_dyn::(&a, 0); - let expected = Arc::new(Int32Array::from(vec![None])) as ArrayRef; - 
assert_eq!(&expected, &division_by_zero.unwrap()); - - let mut builder = - PrimitiveDictionaryBuilder::::with_capacity(1, 1); - builder.append(i32::MIN).unwrap(); - let a = builder.finish(); - - let division_by_zero = divide_scalar_opt_dyn::(&a, 0); - assert_eq!(&expected, &division_by_zero.unwrap()); - } - - #[test] - fn test_sum_f16() { - let a = Float16Array::from_iter_values([ - f16::from_f32(0.1), - f16::from_f32(0.2), - f16::from_f32(1.5), - f16::from_f32(-0.1), - ]); - let b = Float16Array::from_iter_values([ - f16::from_f32(5.1), - f16::from_f32(6.2), - f16::from_f32(-1.), - f16::from_f32(-2.1), - ]); - let expected = Float16Array::from_iter_values( - a.values().iter().zip(b.values()).map(|(a, b)| a + b), - ); - - let c = add(&a, &b).unwrap(); - assert_eq!(c, expected); - } - - #[test] - fn test_resize_builder() { - let mut null_buffer_builder = BooleanBufferBuilder::new(16); - null_buffer_builder.append_slice(&[ - false, false, false, false, false, false, false, false, false, false, false, - false, false, true, true, true, - ]); - // `resize` resizes the buffer length to the ceil of byte numbers. - // So the underlying buffer is not changed. - null_buffer_builder.resize(13); - assert_eq!(null_buffer_builder.len(), 13); - - let nulls = null_buffer_builder.finish(); - assert_eq!(nulls.count_set_bits(), 0); - let nulls = NullBuffer::new(nulls); - assert_eq!(nulls.null_count(), 13); - - let mut data_buffer_builder = BufferBuilder::::new(13); - data_buffer_builder.append_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); - let data_buffer = data_buffer_builder.finish(); - - let arg1: Int32Array = ArrayDataBuilder::new(DataType::Int32) - .len(13) - .nulls(Some(nulls)) - .buffers(vec![data_buffer]) - .build() - .unwrap() - .into(); - - assert_eq!(arg1.null_count(), 13); - - let mut data_buffer_builder = BufferBuilder::::new(13); - data_buffer_builder.append_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); - let data_buffer = data_buffer_builder.finish(); - - let arg2: Int32Array = ArrayDataBuilder::new(DataType::Int32) - .len(13) - .buffers(vec![data_buffer]) - .build() - .unwrap() - .into(); - - assert_eq!(arg2.null_count(), 0); - - let result_dyn = add_dyn(&arg1, &arg2).unwrap(); - let result = result_dyn.as_any().downcast_ref::().unwrap(); - - assert_eq!(result.len(), 13); - assert_eq!(result.null_count(), 13); - } - - #[test] - fn test_primitive_array_add_mut_by_binary_mut() { - let a = Int32Array::from(vec![15, 14, 9, 8, 1]); - let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]); - - let c = binary_mut(a, &b, |a, b| a.add_wrapping(b)) - .unwrap() - .unwrap(); - let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]); - assert_eq!(c, expected); - } - - #[test] - fn test_primitive_add_mut_wrapping_overflow_by_try_binary_mut() { - let a = Int32Array::from(vec![i32::MAX, i32::MIN]); - let b = Int32Array::from(vec![1, 1]); - - let wrapped = binary_mut(a, &b, |a, b| a.add_wrapping(b)) - .unwrap() - .unwrap(); - let expected = Int32Array::from(vec![-2147483648, -2147483647]); - assert_eq!(expected, wrapped); - - let a = Int32Array::from(vec![i32::MAX, i32::MIN]); - let b = Int32Array::from(vec![1, 1]); - let overflow = try_binary_mut(a, &b, |a, b| a.add_checked(b)); - let _ = overflow.unwrap().expect_err("overflow should be detected"); - } - - #[test] - fn test_primitive_add_scalar_by_unary_mut() { - let a = Int32Array::from(vec![15, 14, 9, 8, 1]); - let b = 3; - let c = unary_mut(a, |value| value.add_wrapping(b)).unwrap(); - let expected = 
Int32Array::from(vec![18, 17, 12, 11, 4]); - assert_eq!(c, expected); - } - - #[test] - fn test_primitive_add_scalar_overflow_by_try_unary_mut() { - let a = Int32Array::from(vec![i32::MAX, i32::MIN]); - - let wrapped = unary_mut(a, |value| value.add_wrapping(1)).unwrap(); - let expected = Int32Array::from(vec![-2147483648, -2147483647]); - assert_eq!(expected, wrapped); - - let a = Int32Array::from(vec![i32::MAX, i32::MIN]); - let overflow = try_unary_mut(a, |value| value.add_checked(1)); - let _ = overflow.unwrap().expect_err("overflow should be detected"); - } - - #[test] - fn test_decimal_add_scalar_dyn() { - let a = Decimal128Array::from(vec![100, 210, 320]) - .with_precision_and_scale(38, 2) - .unwrap(); - - let result = add_scalar_dyn::(&a, 1).unwrap(); - let result = result - .as_primitive::() - .clone() - .with_precision_and_scale(38, 2) - .unwrap(); - let expected = Decimal128Array::from(vec![101, 211, 321]) - .with_precision_and_scale(38, 2) - .unwrap(); - - assert_eq!(&expected, &result); - } - - #[test] - fn test_decimal_multiply_allow_precision_loss() { - // Overflow happening as i128 cannot hold multiplying result. - // [123456789] - let a = Decimal128Array::from(vec![123456789000000000000000000]) - .with_precision_and_scale(38, 18) - .unwrap(); - - // [10] - let b = Decimal128Array::from(vec![10000000000000000000]) - .with_precision_and_scale(38, 18) - .unwrap(); - - let err = multiply_dyn_checked(&a, &b).unwrap_err(); - assert!(err.to_string().contains( - "Overflow happened on: 123456789000000000000000000 * 10000000000000000000" - )); - - // Allow precision loss. - let result = multiply_fixed_point_checked(&a, &b, 28).unwrap(); - // [1234567890] - let expected = - Decimal128Array::from(vec![12345678900000000000000000000000000000]) - .with_precision_and_scale(38, 28) - .unwrap(); - - assert_eq!(&expected, &result); - assert_eq!( - result.value_as_string(0), - "1234567890.0000000000000000000000000000" - ); + assert_eq!(&expected, &result); + assert_eq!( + result.value_as_string(0), + "1234567890.0000000000000000000000000000" + ); // Rounding case // [0.000000000000000001, 123456789.555555555555555555, 1.555555555555555555] @@ -3165,11 +283,8 @@ mod tests { assert_eq!(result.precision(), 9); assert_eq!(result.scale(), 4); - let expected = multiply_checked(&a, &b) - .unwrap() - .with_precision_and_scale(9, 4) - .unwrap(); - assert_eq!(&expected, &result); + let expected = mul(&a, &b).unwrap(); + assert_eq!(expected.as_ref(), &result); // Required scale cannot be larger than the product of the input scales. let result = multiply_fixed_point_checked(&a, &b, 5).unwrap_err(); @@ -3217,12 +332,8 @@ mod tests { .unwrap(); // `multiply` overflows on this case. - let result = multiply(&a, &b).unwrap(); - let expected = - Decimal128Array::from(vec![-16672482290199102048610367863168958464]) - .with_precision_and_scale(38, 10) - .unwrap(); - assert_eq!(&expected, &result); + let err = mul(&a, &b).unwrap_err(); + assert_eq!(err.to_string(), "Compute error: Overflow happened on: 123456789000000000000000000 * 10000000000000000000"); // Avoid overflow by reducing the scale. 
let result = multiply_fixed_point(&a, &b, 28).unwrap(); @@ -3238,717 +349,4 @@ mod tests { "1234567890.0000000000000000000000000000" ); } - - #[test] - fn test_timestamp_second_add_interval() { - // timestamp second + interval year month - let a = TimestampSecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalYearMonthArray::from(vec![ - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - ]); - - let result = add_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampSecondArray::from(vec![ - 1 + SECONDS_IN_DAY * (365 + 31 + 28), - 2 + SECONDS_IN_DAY * (365 + 31 + 28), - 3 + SECONDS_IN_DAY * (365 + 31 + 28), - 4 + SECONDS_IN_DAY * (365 + 31 + 28), - 5 + SECONDS_IN_DAY * (365 + 31 + 28), - ]); - assert_eq!(result, &expected); - - let result = add_dyn(&b, &a).unwrap(); - let result = result.as_primitive::(); - assert_eq!(result, &expected); - - // timestamp second + interval day time - let a = TimestampSecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalDayTimeArray::from(vec![ - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - ]); - let result = add_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampSecondArray::from(vec![ - 1 + SECONDS_IN_DAY, - 2 + SECONDS_IN_DAY, - 3 + SECONDS_IN_DAY, - 4 + SECONDS_IN_DAY, - 5 + SECONDS_IN_DAY, - ]); - assert_eq!(&expected, result); - let result = add_dyn(&b, &a).unwrap(); - let result = result.as_primitive::(); - assert_eq!(result, &expected); - - // timestamp second + interval month day nanosecond - let a = TimestampSecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalMonthDayNanoArray::from(vec![ - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - ]); - let result = add_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampSecondArray::from(vec![ - 1 + SECONDS_IN_DAY, - 2 + SECONDS_IN_DAY, - 3 + SECONDS_IN_DAY, - 4 + SECONDS_IN_DAY, - 5 + SECONDS_IN_DAY, - ]); - assert_eq!(&expected, result); - let result = add_dyn(&b, &a).unwrap(); - let result = result.as_primitive::(); - assert_eq!(result, &expected); - } - - #[test] - fn test_timestamp_second_subtract_interval() { - // timestamp second + interval year month - let a = TimestampSecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalYearMonthArray::from(vec![ - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - ]); - - let result = subtract_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampSecondArray::from(vec![ - 1 - SECONDS_IN_DAY * (31 + 30 + 365), - 2 - SECONDS_IN_DAY * (31 + 30 + 365), - 3 - SECONDS_IN_DAY * (31 + 30 + 365), - 4 - SECONDS_IN_DAY * (31 + 30 + 365), - 5 - SECONDS_IN_DAY * (31 + 30 + 365), - ]); - assert_eq!(&expected, result); - - 
// timestamp second + interval day time - let a = TimestampSecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalDayTimeArray::from(vec![ - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - ]); - let result = subtract_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampSecondArray::from(vec![ - 1 - SECONDS_IN_DAY, - 2 - SECONDS_IN_DAY, - 3 - SECONDS_IN_DAY, - 4 - SECONDS_IN_DAY, - 5 - SECONDS_IN_DAY, - ]); - assert_eq!(&expected, result); - - // timestamp second + interval month day nanosecond - let a = TimestampSecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalMonthDayNanoArray::from(vec![ - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - ]); - let result = subtract_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampSecondArray::from(vec![ - 1 - SECONDS_IN_DAY, - 2 - SECONDS_IN_DAY, - 3 - SECONDS_IN_DAY, - 4 - SECONDS_IN_DAY, - 5 - SECONDS_IN_DAY, - ]); - assert_eq!(&expected, result); - } - - #[test] - fn test_timestamp_millisecond_add_interval() { - // timestamp millisecond + interval year month - let a = TimestampMillisecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalYearMonthArray::from(vec![ - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - ]); - - let result = add_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampMillisecondArray::from(vec![ - 1 + SECONDS_IN_DAY * (31 + 28 + 365) * 1_000, - 2 + SECONDS_IN_DAY * (31 + 28 + 365) * 1_000, - 3 + SECONDS_IN_DAY * (31 + 28 + 365) * 1_000, - 4 + SECONDS_IN_DAY * (31 + 28 + 365) * 1_000, - 5 + SECONDS_IN_DAY * (31 + 28 + 365) * 1_000, - ]); - assert_eq!(result, &expected); - let result = add_dyn(&b, &a).unwrap(); - let result = result.as_primitive::(); - assert_eq!(result, &expected); - - // timestamp millisecond + interval day time - let a = TimestampMillisecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalDayTimeArray::from(vec![ - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - ]); - let result = add_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampMillisecondArray::from(vec![ - 1 + SECONDS_IN_DAY * 1_000, - 2 + SECONDS_IN_DAY * 1_000, - 3 + SECONDS_IN_DAY * 1_000, - 4 + SECONDS_IN_DAY * 1_000, - 5 + SECONDS_IN_DAY * 1_000, - ]); - assert_eq!(&expected, result); - let result = add_dyn(&b, &a).unwrap(); - let result = result.as_primitive::(); - assert_eq!(result, &expected); - - // timestamp millisecond + interval month day nanosecond - let a = TimestampMillisecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalMonthDayNanoArray::from(vec![ - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - 
Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - ]); - let result = add_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampMillisecondArray::from(vec![ - 1 + SECONDS_IN_DAY * 1_000, - 2 + SECONDS_IN_DAY * 1_000, - 3 + SECONDS_IN_DAY * 1_000, - 4 + SECONDS_IN_DAY * 1_000, - 5 + SECONDS_IN_DAY * 1_000, - ]); - assert_eq!(&expected, result); - let result = add_dyn(&b, &a).unwrap(); - let result = result.as_primitive::(); - assert_eq!(result, &expected); - } - - #[test] - fn test_timestamp_millisecond_subtract_interval() { - // timestamp millisecond + interval year month - let a = TimestampMillisecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalYearMonthArray::from(vec![ - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - ]); - - let result = subtract_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampMillisecondArray::from(vec![ - 1 - SECONDS_IN_DAY * (31 + 30 + 365) * 1_000, - 2 - SECONDS_IN_DAY * (31 + 30 + 365) * 1_000, - 3 - SECONDS_IN_DAY * (31 + 30 + 365) * 1_000, - 4 - SECONDS_IN_DAY * (31 + 30 + 365) * 1_000, - 5 - SECONDS_IN_DAY * (31 + 30 + 365) * 1_000, - ]); - assert_eq!(&expected, result); - - // timestamp millisecond + interval day time - let a = TimestampMillisecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalDayTimeArray::from(vec![ - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - ]); - let result = subtract_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampMillisecondArray::from(vec![ - 1 - SECONDS_IN_DAY * 1_000, - 2 - SECONDS_IN_DAY * 1_000, - 3 - SECONDS_IN_DAY * 1_000, - 4 - SECONDS_IN_DAY * 1_000, - 5 - SECONDS_IN_DAY * 1_000, - ]); - assert_eq!(&expected, result); - - // timestamp millisecond + interval month day nanosecond - let a = TimestampMillisecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalMonthDayNanoArray::from(vec![ - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - ]); - let result = subtract_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampMillisecondArray::from(vec![ - 1 - SECONDS_IN_DAY * 1_000, - 2 - SECONDS_IN_DAY * 1_000, - 3 - SECONDS_IN_DAY * 1_000, - 4 - SECONDS_IN_DAY * 1_000, - 5 - SECONDS_IN_DAY * 1_000, - ]); - assert_eq!(&expected, result); - } - - #[test] - fn test_timestamp_microsecond_add_interval() { - // timestamp microsecond + interval year month - let a = TimestampMicrosecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalYearMonthArray::from(vec![ - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - ]); - - let result = add_dyn(&a, &b).unwrap(); - let 
result = result.as_primitive::(); - - let expected = TimestampMicrosecondArray::from(vec![ - 1 + SECONDS_IN_DAY * (31 + 28 + 365) * 1_000_000, - 2 + SECONDS_IN_DAY * (31 + 28 + 365) * 1_000_000, - 3 + SECONDS_IN_DAY * (31 + 28 + 365) * 1_000_000, - 4 + SECONDS_IN_DAY * (31 + 28 + 365) * 1_000_000, - 5 + SECONDS_IN_DAY * (31 + 28 + 365) * 1_000_000, - ]); - assert_eq!(result, &expected); - let result = add_dyn(&b, &a).unwrap(); - let result = result.as_primitive::(); - assert_eq!(result, &expected); - - // timestamp microsecond + interval day time - let a = TimestampMicrosecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalDayTimeArray::from(vec![ - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - ]); - let result = add_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampMicrosecondArray::from(vec![ - 1 + SECONDS_IN_DAY * 1_000_000, - 2 + SECONDS_IN_DAY * 1_000_000, - 3 + SECONDS_IN_DAY * 1_000_000, - 4 + SECONDS_IN_DAY * 1_000_000, - 5 + SECONDS_IN_DAY * 1_000_000, - ]); - assert_eq!(&expected, result); - let result = add_dyn(&b, &a).unwrap(); - let result = result.as_primitive::(); - assert_eq!(result, &expected); - - // timestamp microsecond + interval month day nanosecond - let a = TimestampMicrosecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalMonthDayNanoArray::from(vec![ - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - ]); - let result = add_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampMicrosecondArray::from(vec![ - 1 + SECONDS_IN_DAY * 1_000_000, - 2 + SECONDS_IN_DAY * 1_000_000, - 3 + SECONDS_IN_DAY * 1_000_000, - 4 + SECONDS_IN_DAY * 1_000_000, - 5 + SECONDS_IN_DAY * 1_000_000, - ]); - assert_eq!(&expected, result); - let result = add_dyn(&b, &a).unwrap(); - let result = result.as_primitive::(); - assert_eq!(result, &expected); - } - - #[test] - fn test_timestamp_microsecond_subtract_interval() { - // timestamp microsecond + interval year month - let a = TimestampMicrosecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalYearMonthArray::from(vec![ - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - ]); - - let result = subtract_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampMicrosecondArray::from(vec![ - 1 - SECONDS_IN_DAY * (31 + 30 + 365) * 1_000_000, - 2 - SECONDS_IN_DAY * (31 + 30 + 365) * 1_000_000, - 3 - SECONDS_IN_DAY * (31 + 30 + 365) * 1_000_000, - 4 - SECONDS_IN_DAY * (31 + 30 + 365) * 1_000_000, - 5 - SECONDS_IN_DAY * (31 + 30 + 365) * 1_000_000, - ]); - assert_eq!(&expected, result); - - // timestamp microsecond + interval day time - let a = TimestampMicrosecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalDayTimeArray::from(vec![ - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - 
Some(IntervalDayTimeType::make_value(1, 0)), - ]); - let result = subtract_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampMicrosecondArray::from(vec![ - 1 - SECONDS_IN_DAY * 1_000_000, - 2 - SECONDS_IN_DAY * 1_000_000, - 3 - SECONDS_IN_DAY * 1_000_000, - 4 - SECONDS_IN_DAY * 1_000_000, - 5 - SECONDS_IN_DAY * 1_000_000, - ]); - assert_eq!(&expected, result); - - // timestamp microsecond + interval month day nanosecond - let a = TimestampMicrosecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalMonthDayNanoArray::from(vec![ - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - ]); - let result = subtract_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampMicrosecondArray::from(vec![ - 1 - SECONDS_IN_DAY * 1_000_000, - 2 - SECONDS_IN_DAY * 1_000_000, - 3 - SECONDS_IN_DAY * 1_000_000, - 4 - SECONDS_IN_DAY * 1_000_000, - 5 - SECONDS_IN_DAY * 1_000_000, - ]); - assert_eq!(&expected, result); - } - - #[test] - fn test_timestamp_nanosecond_add_interval() { - // timestamp nanosecond + interval year month - let a = TimestampNanosecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalYearMonthArray::from(vec![ - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - ]); - - let result = add_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampNanosecondArray::from(vec![ - 1 + SECONDS_IN_DAY * (31 + 28 + 365) * 1_000_000_000, - 2 + SECONDS_IN_DAY * (31 + 28 + 365) * 1_000_000_000, - 3 + SECONDS_IN_DAY * (31 + 28 + 365) * 1_000_000_000, - 4 + SECONDS_IN_DAY * (31 + 28 + 365) * 1_000_000_000, - 5 + SECONDS_IN_DAY * (31 + 28 + 365) * 1_000_000_000, - ]); - assert_eq!(result, &expected); - let result = add_dyn(&b, &a).unwrap(); - let result = result.as_primitive::(); - assert_eq!(result, &expected); - - // timestamp nanosecond + interval day time - let a = TimestampNanosecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalDayTimeArray::from(vec![ - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - ]); - let result = add_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampNanosecondArray::from(vec![ - 1 + SECONDS_IN_DAY * 1_000_000_000, - 2 + SECONDS_IN_DAY * 1_000_000_000, - 3 + SECONDS_IN_DAY * 1_000_000_000, - 4 + SECONDS_IN_DAY * 1_000_000_000, - 5 + SECONDS_IN_DAY * 1_000_000_000, - ]); - assert_eq!(&expected, result); - let result = add_dyn(&b, &a).unwrap(); - let result = result.as_primitive::(); - assert_eq!(result, &expected); - - // timestamp nanosecond + interval month day nanosecond - let a = TimestampNanosecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalMonthDayNanoArray::from(vec![ - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - 
Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - ]); - let result = add_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampNanosecondArray::from(vec![ - 1 + SECONDS_IN_DAY * 1_000_000_000, - 2 + SECONDS_IN_DAY * 1_000_000_000, - 3 + SECONDS_IN_DAY * 1_000_000_000, - 4 + SECONDS_IN_DAY * 1_000_000_000, - 5 + SECONDS_IN_DAY * 1_000_000_000, - ]); - assert_eq!(&expected, result); - let result = add_dyn(&b, &a).unwrap(); - let result = result.as_primitive::(); - assert_eq!(result, &expected); - } - - #[test] - fn test_timestamp_nanosecond_subtract_interval() { - // timestamp nanosecond + interval year month - let a = TimestampNanosecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalYearMonthArray::from(vec![ - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - Some(IntervalYearMonthType::make_value(1, 2)), - ]); - - let result = subtract_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampNanosecondArray::from(vec![ - 1 - SECONDS_IN_DAY * (31 + 30 + 365) * 1_000_000_000, - 2 - SECONDS_IN_DAY * (31 + 30 + 365) * 1_000_000_000, - 3 - SECONDS_IN_DAY * (31 + 30 + 365) * 1_000_000_000, - 4 - SECONDS_IN_DAY * (31 + 30 + 365) * 1_000_000_000, - 5 - SECONDS_IN_DAY * (31 + 30 + 365) * 1_000_000_000, - ]); - assert_eq!(&expected, result); - - // timestamp nanosecond + interval day time - let a = TimestampNanosecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalDayTimeArray::from(vec![ - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - Some(IntervalDayTimeType::make_value(1, 0)), - ]); - let result = subtract_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampNanosecondArray::from(vec![ - 1 - SECONDS_IN_DAY * 1_000_000_000, - 2 - SECONDS_IN_DAY * 1_000_000_000, - 3 - SECONDS_IN_DAY * 1_000_000_000, - 4 - SECONDS_IN_DAY * 1_000_000_000, - 5 - SECONDS_IN_DAY * 1_000_000_000, - ]); - assert_eq!(&expected, result); - - // timestamp nanosecond + interval month day nanosecond - let a = TimestampNanosecondArray::from(vec![1, 2, 3, 4, 5]); - let b = IntervalMonthDayNanoArray::from(vec![ - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - Some(IntervalMonthDayNanoType::make_value(0, 1, 0)), - ]); - let result = subtract_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - - let expected = TimestampNanosecondArray::from(vec![ - 1 - SECONDS_IN_DAY * 1_000_000_000, - 2 - SECONDS_IN_DAY * 1_000_000_000, - 3 - SECONDS_IN_DAY * 1_000_000_000, - 4 - SECONDS_IN_DAY * 1_000_000_000, - 5 - SECONDS_IN_DAY * 1_000_000_000, - ]); - assert_eq!(&expected, result); - } - - #[test] - fn test_timestamp_second_subtract_timestamp() { - let a = TimestampSecondArray::from(vec![0, 2, 4, 6, 8]); - let b = TimestampSecondArray::from(vec![1, 2, 3, 4, 5]); - let expected = DurationSecondArray::from(vec![-1, 0, 1, 2, 3]); - - // unchecked - let result = subtract_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - assert_eq!(&expected, result); - - // checked - let result = subtract_dyn_checked(&a, &b).unwrap(); - let result = 
result.as_primitive::(); - assert_eq!(&expected, result); - } - - #[test] - fn test_timestamp_second_subtract_timestamp_overflow() { - let a = TimestampSecondArray::from(vec![ - ::Native::MAX, - ]); - let b = TimestampSecondArray::from(vec![ - ::Native::MIN, - ]); - - // unchecked - let result = subtract_dyn(&a, &b); - assert!(!&result.is_err()); - - // checked - let result = subtract_dyn_checked(&a, &b); - assert!(&result.is_err()); - } - - #[test] - fn test_timestamp_microsecond_subtract_timestamp() { - let a = TimestampMicrosecondArray::from(vec![0, 2, 4, 6, 8]); - let b = TimestampMicrosecondArray::from(vec![1, 2, 3, 4, 5]); - let expected = DurationMicrosecondArray::from(vec![-1, 0, 1, 2, 3]); - - // unchecked - let result = subtract_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - assert_eq!(&expected, result); - - // checked - let result = subtract_dyn_checked(&a, &b).unwrap(); - let result = result.as_primitive::(); - assert_eq!(&expected, result); - } - - #[test] - fn test_timestamp_microsecond_subtract_timestamp_overflow() { - let a = TimestampMicrosecondArray::from(vec![ - ::Native::MAX, - ]); - let b = TimestampMicrosecondArray::from(vec![ - ::Native::MIN, - ]); - - // unchecked - let result = subtract_dyn(&a, &b); - assert!(!&result.is_err()); - - // checked - let result = subtract_dyn_checked(&a, &b); - assert!(&result.is_err()); - } - - #[test] - fn test_timestamp_millisecond_subtract_timestamp() { - let a = TimestampMillisecondArray::from(vec![0, 2, 4, 6, 8]); - let b = TimestampMillisecondArray::from(vec![1, 2, 3, 4, 5]); - let expected = DurationMillisecondArray::from(vec![-1, 0, 1, 2, 3]); - - // unchecked - let result = subtract_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - assert_eq!(&expected, result); - - // checked - let result = subtract_dyn_checked(&a, &b).unwrap(); - let result = result.as_primitive::(); - assert_eq!(&expected, result); - } - - #[test] - fn test_timestamp_millisecond_subtract_timestamp_overflow() { - let a = TimestampMillisecondArray::from(vec![ - ::Native::MAX, - ]); - let b = TimestampMillisecondArray::from(vec![ - ::Native::MIN, - ]); - - // unchecked - let result = subtract_dyn(&a, &b); - assert!(!&result.is_err()); - - // checked - let result = subtract_dyn_checked(&a, &b); - assert!(&result.is_err()); - } - - #[test] - fn test_timestamp_nanosecond_subtract_timestamp() { - let a = TimestampNanosecondArray::from(vec![0, 2, 4, 6, 8]); - let b = TimestampNanosecondArray::from(vec![1, 2, 3, 4, 5]); - let expected = DurationNanosecondArray::from(vec![-1, 0, 1, 2, 3]); - - // unchecked - let result = subtract_dyn(&a, &b).unwrap(); - let result = result.as_primitive::(); - assert_eq!(&expected, result); - - // checked - let result = subtract_dyn_checked(&a, &b).unwrap(); - let result = result.as_primitive::(); - assert_eq!(&expected, result); - } - - #[test] - fn test_timestamp_nanosecond_subtract_timestamp_overflow() { - let a = TimestampNanosecondArray::from(vec![ - ::Native::MAX, - ]); - let b = TimestampNanosecondArray::from(vec![ - ::Native::MIN, - ]); - - // unchecked - let result = subtract_dyn(&a, &b); - assert!(!&result.is_err()); - - // checked - let result = subtract_dyn_checked(&a, &b); - assert!(&result.is_err()); - } } diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs index ce766aff66f7..f3118d104536 100644 --- a/arrow-arith/src/arity.rs +++ b/arrow-arith/src/arity.rs @@ -18,7 +18,6 @@ //! Defines kernels suitable to perform operations to primitive arrays. 
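// --- Editor's note (not part of this patch) -------------------------------------
// The arity.rs hunks below swap `Array::nulls` for `Array::logical_nulls` when
// unioning validity. The distinction matters for arrays whose nulls are not stored
// as a physical buffer (for example `NullArray`), which is also what the new
// `is_null` / `is_not_null` tests later in this patch rely on. A minimal sketch of
// the difference, assuming `logical_nulls` reports every `NullArray` slot as null:
use arrow_array::{Array, NullArray};

fn logical_vs_physical_nulls_sketch() {
    let a = NullArray::new(3);
    // No physical validity buffer is allocated...
    assert!(a.nulls().is_none());
    // ...but logically every element is null.
    let logical = a.logical_nulls().expect("NullArray is logically all-null");
    assert_eq!(logical.null_count(), 3);
}
// --------------------------------------------------------------------------------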
use arrow_array::builder::BufferBuilder; -use arrow_array::iterator::ArrayIter; use arrow_array::types::ArrowDictionaryKeyType; use arrow_array::*; use arrow_buffer::buffer::NullBuffer; @@ -83,7 +82,7 @@ where { let dict_values = array.values().as_any().downcast_ref().unwrap(); let values = unary::(dict_values, op); - Ok(Arc::new(array.with_values(&values))) + Ok(Arc::new(array.with_values(Arc::new(values)))) } /// A helper function that applies a fallible unary function to a dictionary array with primitive value type. @@ -106,10 +105,11 @@ where let dict_values = array.values().as_any().downcast_ref().unwrap(); let values = try_unary::(dict_values, op)?; - Ok(Arc::new(array.with_values(&values))) + Ok(Arc::new(array.with_values(Arc::new(values)))) } /// Applies an infallible unary function to an array with primitive values. +#[deprecated(note = "Use arrow_array::AnyDictionaryArray")] pub fn unary_dyn(array: &dyn Array, op: F) -> Result where T: ArrowPrimitiveType, @@ -135,6 +135,7 @@ where } /// Applies a fallible unary function to an array with primitive values. +#[deprecated(note = "Use arrow_array::AnyDictionaryArray")] pub fn try_unary_dyn(array: &dyn Array, op: F) -> Result where T: ArrowPrimitiveType, @@ -199,7 +200,7 @@ where return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE))); } - let nulls = NullBuffer::union(a.nulls(), b.nulls()); + let nulls = NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()); let values = a.values().iter().zip(b.values()).map(|(l, r)| op(*l, *r)); // JUSTIFICATION @@ -249,7 +250,7 @@ where )))); } - let nulls = NullBuffer::union(a.nulls(), b.nulls()); + let nulls = NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()); let mut builder = a.into_builder()?; @@ -297,7 +298,9 @@ where if a.null_count() == 0 && b.null_count() == 0 { try_binary_no_nulls(len, a, b, op) } else { - let nulls = NullBuffer::union(a.nulls(), b.nulls()).unwrap(); + let nulls = + NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()) + .unwrap(); let mut buffer = BufferBuilder::::new(len); buffer.append_n_zeroed(len); @@ -356,7 +359,10 @@ where if a.null_count() == 0 && b.null_count() == 0 { try_binary_no_nulls_mut(len, a, b, op) } else { - let nulls = NullBuffer::union(a.nulls(), b.nulls()).unwrap(); + let nulls = + NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()) + .unwrap(); + let mut builder = a.into_builder()?; let slice = builder.values_slice_mut(); @@ -425,76 +431,6 @@ where Ok(Ok(builder.finish())) } -#[inline(never)] -fn try_binary_opt_no_nulls( - len: usize, - a: A, - b: B, - op: F, -) -> Result, ArrowError> -where - O: ArrowPrimitiveType, - F: Fn(A::Item, B::Item) -> Option, -{ - let mut buffer = Vec::with_capacity(10); - for idx in 0..len { - unsafe { - buffer.push(op(a.value_unchecked(idx), b.value_unchecked(idx))); - }; - } - Ok(buffer.iter().collect()) -} - -/// Applies the provided binary operation across `a` and `b`, collecting the optional results -/// into a [`PrimitiveArray`]. If any index is null in either `a` or `b`, the corresponding -/// index in the result will also be null. The binary operation could return `None` which -/// results in a new null in the collected [`PrimitiveArray`]. 
-/// -/// The function is only evaluated for non-null indices -/// -/// # Error -/// -/// This function gives error if the arrays have different lengths -pub(crate) fn binary_opt( - a: A, - b: B, - op: F, -) -> Result, ArrowError> -where - O: ArrowPrimitiveType, - F: Fn(A::Item, B::Item) -> Option, -{ - if a.len() != b.len() { - return Err(ArrowError::ComputeError( - "Cannot perform binary operation on arrays of different length".to_string(), - )); - } - - if a.is_empty() { - return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE))); - } - - if a.null_count() == 0 && b.null_count() == 0 { - return try_binary_opt_no_nulls(a.len(), a, b, op); - } - - let iter_a = ArrayIter::new(a); - let iter_b = ArrayIter::new(b); - - let values = iter_a - .into_iter() - .zip(iter_b.into_iter()) - .map(|(item_a, item_b)| { - if let (Some(a), Some(b)) = (item_a, item_b) { - op(a, b) - } else { - None - } - }); - - Ok(values.collect()) -} - #[cfg(test)] mod tests { use super::*; @@ -502,6 +438,7 @@ mod tests { use arrow_array::types::*; #[test] + #[allow(deprecated)] fn test_unary_f64_slice() { let input = Float64Array::from(vec![Some(5.1f64), None, Some(6.8), None, Some(7.2)]); @@ -521,6 +458,7 @@ mod tests { } #[test] + #[allow(deprecated)] fn test_unary_dict_and_unary_dyn() { let mut builder = PrimitiveDictionaryBuilder::::new(); builder.append(5).unwrap(); diff --git a/arrow-arith/src/boolean.rs b/arrow-arith/src/boolean.rs index 04c9fb229034..46e5998208f1 100644 --- a/arrow-arith/src/boolean.rs +++ b/arrow-arith/src/boolean.rs @@ -311,7 +311,7 @@ pub fn not(left: &BooleanArray) -> Result { /// assert_eq!(a_is_null, BooleanArray::from(vec![false, false, true])); /// ``` pub fn is_null(input: &dyn Array) -> Result { - let values = match input.nulls() { + let values = match input.logical_nulls() { None => BooleanBuffer::new_unset(input.len()), Some(nulls) => !nulls.inner(), }; @@ -331,7 +331,7 @@ pub fn is_null(input: &dyn Array) -> Result { /// assert_eq!(a_is_not_null, BooleanArray::from(vec![true, true, false])); /// ``` pub fn is_not_null(input: &dyn Array) -> Result { - let values = match input.nulls() { + let values = match input.logical_nulls() { None => BooleanBuffer::new_set(input.len()), Some(n) => n.inner().clone(), }; @@ -871,4 +871,28 @@ mod tests { assert_eq!(expected, res); assert!(res.nulls().is_none()); } + + #[test] + fn test_null_array_is_null() { + let a = NullArray::new(3); + + let res = is_null(&a).unwrap(); + + let expected = BooleanArray::from(vec![true, true, true]); + + assert_eq!(expected, res); + assert!(res.nulls().is_none()); + } + + #[test] + fn test_null_array_is_not_null() { + let a = NullArray::new(3); + + let res = is_not_null(&a).unwrap(); + + let expected = BooleanArray::from(vec![false, false, false]); + + assert_eq!(expected, res); + assert!(res.nulls().is_none()); + } } diff --git a/arrow-arith/src/lib.rs b/arrow-arith/src/lib.rs index 60d31c972b66..2d5451e04dd2 100644 --- a/arrow-arith/src/lib.rs +++ b/arrow-arith/src/lib.rs @@ -18,8 +18,10 @@ //! 
Arrow arithmetic and aggregation kernels pub mod aggregate; +#[doc(hidden)] // Kernels to be removed in a future release pub mod arithmetic; pub mod arity; pub mod bitwise; pub mod boolean; +pub mod numeric; pub mod temporal; diff --git a/arrow-arith/src/numeric.rs b/arrow-arith/src/numeric.rs new file mode 100644 index 000000000000..c47731ed5125 --- /dev/null +++ b/arrow-arith/src/numeric.rs @@ -0,0 +1,1525 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines numeric arithmetic kernels on [`PrimitiveArray`], such as [`add`] + +use std::cmp::Ordering; +use std::fmt::Formatter; +use std::sync::Arc; + +use arrow_array::cast::AsArray; +use arrow_array::timezone::Tz; +use arrow_array::types::*; +use arrow_array::*; +use arrow_buffer::ArrowNativeType; +use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit}; + +use crate::arity::{binary, try_binary}; + +/// Perform `lhs + rhs`, returning an error on overflow +pub fn add(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Add, lhs, rhs) +} + +/// Perform `lhs + rhs`, wrapping on overflow for [`DataType::is_integer`] +pub fn add_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::AddWrapping, lhs, rhs) +} + +/// Perform `lhs - rhs`, returning an error on overflow +pub fn sub(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Sub, lhs, rhs) +} + +/// Perform `lhs - rhs`, wrapping on overflow for [`DataType::is_integer`] +pub fn sub_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::SubWrapping, lhs, rhs) +} + +/// Perform `lhs * rhs`, returning an error on overflow +pub fn mul(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Mul, lhs, rhs) +} + +/// Perform `lhs * rhs`, wrapping on overflow for [`DataType::is_integer`] +pub fn mul_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::MulWrapping, lhs, rhs) +} + +/// Perform `lhs / rhs` +/// +/// Overflow or division by zero will result in an error, with exception to +/// floating point numbers, which instead follow the IEEE 754 rules +pub fn div(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Div, lhs, rhs) +} + +/// Perform `lhs % rhs` +/// +/// Overflow or division by zero will result in an error, with exception to +/// floating point numbers, which instead follow the IEEE 754 rules +pub fn rem(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Rem, lhs, rhs) +} + +macro_rules! neg_checked { + ($t:ty, $a:ident) => {{ + let array = $a + .as_primitive::<$t>() + .try_unary::<_, $t, _>(|x| x.neg_checked())?; + Ok(Arc::new(array)) + }}; +} + +macro_rules! 
neg_wrapping { + ($t:ty, $a:ident) => {{ + let array = $a.as_primitive::<$t>().unary::<_, $t>(|x| x.neg_wrapping()); + Ok(Arc::new(array)) + }}; +} + +/// Negates each element of `array`, returning an error on overflow +/// +/// Note: negation of unsigned arrays is not supported and will return in an error, +/// for wrapping unsigned negation consider using [`neg_wrapping`][neg_wrapping()] +pub fn neg(array: &dyn Array) -> Result { + use DataType::*; + use IntervalUnit::*; + use TimeUnit::*; + + match array.data_type() { + Int8 => neg_checked!(Int8Type, array), + Int16 => neg_checked!(Int16Type, array), + Int32 => neg_checked!(Int32Type, array), + Int64 => neg_checked!(Int64Type, array), + Float16 => neg_wrapping!(Float16Type, array), + Float32 => neg_wrapping!(Float32Type, array), + Float64 => neg_wrapping!(Float64Type, array), + Decimal128(p, s) => { + let a = array + .as_primitive::() + .try_unary::<_, Decimal128Type, _>(|x| x.neg_checked())?; + + Ok(Arc::new(a.with_precision_and_scale(*p, *s)?)) + } + Decimal256(p, s) => { + let a = array + .as_primitive::() + .try_unary::<_, Decimal256Type, _>(|x| x.neg_checked())?; + + Ok(Arc::new(a.with_precision_and_scale(*p, *s)?)) + } + Duration(Second) => neg_checked!(DurationSecondType, array), + Duration(Millisecond) => neg_checked!(DurationMillisecondType, array), + Duration(Microsecond) => neg_checked!(DurationMicrosecondType, array), + Duration(Nanosecond) => neg_checked!(DurationNanosecondType, array), + Interval(YearMonth) => neg_checked!(IntervalYearMonthType, array), + Interval(DayTime) => { + let a = array + .as_primitive::() + .try_unary::<_, IntervalDayTimeType, ArrowError>(|x| { + let (days, ms) = IntervalDayTimeType::to_parts(x); + Ok(IntervalDayTimeType::make_value( + days.neg_checked()?, + ms.neg_checked()?, + )) + })?; + Ok(Arc::new(a)) + } + Interval(MonthDayNano) => { + let a = array + .as_primitive::() + .try_unary::<_, IntervalMonthDayNanoType, ArrowError>(|x| { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(x); + Ok(IntervalMonthDayNanoType::make_value( + months.neg_checked()?, + days.neg_checked()?, + nanos.neg_checked()?, + )) + })?; + Ok(Arc::new(a)) + } + t => Err(ArrowError::InvalidArgumentError(format!( + "Invalid arithmetic operation: !{t}" + ))), + } +} + +/// Negates each element of `array`, wrapping on overflow for [`DataType::is_integer`] +pub fn neg_wrapping(array: &dyn Array) -> Result { + downcast_integer! { + array.data_type() => (neg_wrapping, array), + _ => neg(array), + } +} + +/// An enumeration of arithmetic operations +/// +/// This allows sharing the type dispatch logic across the various kernels +#[derive(Debug, Copy, Clone)] +enum Op { + AddWrapping, + Add, + SubWrapping, + Sub, + MulWrapping, + Mul, + Div, + Rem, +} + +impl std::fmt::Display for Op { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Op::AddWrapping | Op::Add => write!(f, "+"), + Op::SubWrapping | Op::Sub => write!(f, "-"), + Op::MulWrapping | Op::Mul => write!(f, "*"), + Op::Div => write!(f, "/"), + Op::Rem => write!(f, "%"), + } + } +} + +impl Op { + fn commutative(&self) -> bool { + matches!(self, Self::Add | Self::AddWrapping) + } +} + +/// Dispatch the given `op` to the appropriate specialized kernel +fn arithmetic_op( + op: Op, + lhs: &dyn Datum, + rhs: &dyn Datum, +) -> Result { + use DataType::*; + use IntervalUnit::*; + use TimeUnit::*; + + macro_rules! 
integer_helper { + ($t:ty, $op:ident, $l:ident, $l_scalar:ident, $r:ident, $r_scalar:ident) => { + integer_op::<$t>($op, $l, $l_scalar, $r, $r_scalar) + }; + } + + let (l, l_scalar) = lhs.get(); + let (r, r_scalar) = rhs.get(); + downcast_integer! { + l.data_type(), r.data_type() => (integer_helper, op, l, l_scalar, r, r_scalar), + (Float16, Float16) => float_op::(op, l, l_scalar, r, r_scalar), + (Float32, Float32) => float_op::(op, l, l_scalar, r, r_scalar), + (Float64, Float64) => float_op::(op, l, l_scalar, r, r_scalar), + (Timestamp(Second, _), _) => timestamp_op::(op, l, l_scalar, r, r_scalar), + (Timestamp(Millisecond, _), _) => timestamp_op::(op, l, l_scalar, r, r_scalar), + (Timestamp(Microsecond, _), _) => timestamp_op::(op, l, l_scalar, r, r_scalar), + (Timestamp(Nanosecond, _), _) => timestamp_op::(op, l, l_scalar, r, r_scalar), + (Duration(Second), Duration(Second)) => duration_op::(op, l, l_scalar, r, r_scalar), + (Duration(Millisecond), Duration(Millisecond)) => duration_op::(op, l, l_scalar, r, r_scalar), + (Duration(Microsecond), Duration(Microsecond)) => duration_op::(op, l, l_scalar, r, r_scalar), + (Duration(Nanosecond), Duration(Nanosecond)) => duration_op::(op, l, l_scalar, r, r_scalar), + (Interval(YearMonth), Interval(YearMonth)) => interval_op::(op, l, l_scalar, r, r_scalar), + (Interval(DayTime), Interval(DayTime)) => interval_op::(op, l, l_scalar, r, r_scalar), + (Interval(MonthDayNano), Interval(MonthDayNano)) => interval_op::(op, l, l_scalar, r, r_scalar), + (Date32, _) => date_op::(op, l, l_scalar, r, r_scalar), + (Date64, _) => date_op::(op, l, l_scalar, r, r_scalar), + (Decimal128(_, _), Decimal128(_, _)) => decimal_op::(op, l, l_scalar, r, r_scalar), + (Decimal256(_, _), Decimal256(_, _)) => decimal_op::(op, l, l_scalar, r, r_scalar), + (l_t, r_t) => match (l_t, r_t) { + (Duration(_) | Interval(_), Date32 | Date64 | Timestamp(_, _)) if op.commutative() => { + arithmetic_op(op, rhs, lhs) + } + _ => Err(ArrowError::InvalidArgumentError( + format!("Invalid arithmetic operation: {l_t} {op} {r_t}") + )) + } + } +} + +/// Perform an infallible binary operation on potentially scalar inputs +macro_rules! op { + ($l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => { + match ($l_s, $r_s) { + (true, true) | (false, false) => binary($l, $r, |$l, $r| $op)?, + (true, false) => match ($l.null_count() == 0).then(|| $l.value(0)) { + None => PrimitiveArray::new_null($r.len()), + Some($l) => $r.unary(|$r| $op), + }, + (false, true) => match ($r.null_count() == 0).then(|| $r.value(0)) { + None => PrimitiveArray::new_null($l.len()), + Some($r) => $l.unary(|$l| $op), + }, + } + }; +} + +/// Same as `op` but with a type hint for the returned array +macro_rules! op_ref { + ($t:ty, $l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {{ + let array: PrimitiveArray<$t> = op!($l, $l_s, $r, $r_s, $op); + Arc::new(array) + }}; +} + +/// Perform a fallible binary operation on potentially scalar inputs +macro_rules! try_op { + ($l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => { + match ($l_s, $r_s) { + (true, true) | (false, false) => try_binary($l, $r, |$l, $r| $op)?, + (true, false) => match ($l.null_count() == 0).then(|| $l.value(0)) { + None => PrimitiveArray::new_null($r.len()), + Some($l) => $r.try_unary(|$r| $op)?, + }, + (false, true) => match ($r.null_count() == 0).then(|| $r.value(0)) { + None => PrimitiveArray::new_null($l.len()), + Some($r) => $l.try_unary(|$l| $op)?, + }, + } + }; +} + +/// Same as `try_op` but with a type hint for the returned array +macro_rules! 
try_op_ref { + ($t:ty, $l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {{ + let array: PrimitiveArray<$t> = try_op!($l, $l_s, $r, $r_s, $op); + Arc::new(array) + }}; +} + +/// Perform an arithmetic operation on integers +fn integer_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + let array: PrimitiveArray = match op { + Op::AddWrapping => op!(l, l_s, r, r_s, l.add_wrapping(r)), + Op::Add => try_op!(l, l_s, r, r_s, l.add_checked(r)), + Op::SubWrapping => op!(l, l_s, r, r_s, l.sub_wrapping(r)), + Op::Sub => try_op!(l, l_s, r, r_s, l.sub_checked(r)), + Op::MulWrapping => op!(l, l_s, r, r_s, l.mul_wrapping(r)), + Op::Mul => try_op!(l, l_s, r, r_s, l.mul_checked(r)), + Op::Div => try_op!(l, l_s, r, r_s, l.div_checked(r)), + Op::Rem => try_op!(l, l_s, r, r_s, l.mod_checked(r)), + }; + Ok(Arc::new(array)) +} + +/// Perform an arithmetic operation on floats +fn float_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + let array: PrimitiveArray = match op { + Op::AddWrapping | Op::Add => op!(l, l_s, r, r_s, l.add_wrapping(r)), + Op::SubWrapping | Op::Sub => op!(l, l_s, r, r_s, l.sub_wrapping(r)), + Op::MulWrapping | Op::Mul => op!(l, l_s, r, r_s, l.mul_wrapping(r)), + Op::Div => op!(l, l_s, r, r_s, l.div_wrapping(r)), + Op::Rem => op!(l, l_s, r, r_s, l.mod_wrapping(r)), + }; + Ok(Arc::new(array)) +} + +/// Arithmetic trait for timestamp arrays +trait TimestampOp: ArrowTimestampType { + type Duration: ArrowPrimitiveType; + + fn add_year_month(timestamp: i64, delta: i32, tz: Tz) -> Option; + fn add_day_time(timestamp: i64, delta: i64, tz: Tz) -> Option; + fn add_month_day_nano(timestamp: i64, delta: i128, tz: Tz) -> Option; + + fn sub_year_month(timestamp: i64, delta: i32, tz: Tz) -> Option; + fn sub_day_time(timestamp: i64, delta: i64, tz: Tz) -> Option; + fn sub_month_day_nano(timestamp: i64, delta: i128, tz: Tz) -> Option; +} + +macro_rules! 
timestamp { + ($t:ty, $d:ty) => { + impl TimestampOp for $t { + type Duration = $d; + + fn add_year_month(left: i64, right: i32, tz: Tz) -> Option { + Self::add_year_months(left, right, tz) + } + + fn add_day_time(left: i64, right: i64, tz: Tz) -> Option { + Self::add_day_time(left, right, tz) + } + + fn add_month_day_nano(left: i64, right: i128, tz: Tz) -> Option { + Self::add_month_day_nano(left, right, tz) + } + + fn sub_year_month(left: i64, right: i32, tz: Tz) -> Option { + Self::subtract_year_months(left, right, tz) + } + + fn sub_day_time(left: i64, right: i64, tz: Tz) -> Option { + Self::subtract_day_time(left, right, tz) + } + + fn sub_month_day_nano(left: i64, right: i128, tz: Tz) -> Option { + Self::subtract_month_day_nano(left, right, tz) + } + } + }; +} +timestamp!(TimestampSecondType, DurationSecondType); +timestamp!(TimestampMillisecondType, DurationMillisecondType); +timestamp!(TimestampMicrosecondType, DurationMicrosecondType); +timestamp!(TimestampNanosecondType, DurationNanosecondType); + +/// Perform arithmetic operation on a timestamp array +fn timestamp_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + use DataType::*; + use IntervalUnit::*; + + let l = l.as_primitive::(); + let l_tz: Tz = l.timezone().unwrap_or("+00:00").parse()?; + + let array: PrimitiveArray = match (op, r.data_type()) { + (Op::Sub | Op::SubWrapping, Timestamp(unit, _)) if unit == &T::UNIT => { + let r = r.as_primitive::(); + return Ok(try_op_ref!(T::Duration, l, l_s, r, r_s, l.sub_checked(r))); + } + + (Op::Add | Op::AddWrapping, Duration(unit)) if unit == &T::UNIT => { + let r = r.as_primitive::(); + try_op!(l, l_s, r, r_s, l.add_checked(r)) + } + (Op::Sub | Op::SubWrapping, Duration(unit)) if unit == &T::UNIT => { + let r = r.as_primitive::(); + try_op!(l, l_s, r, r_s, l.sub_checked(r)) + } + + (Op::Add | Op::AddWrapping, Interval(YearMonth)) => { + let r = r.as_primitive::(); + try_op!( + l, + l_s, + r, + r_s, + T::add_year_month(l, r, l_tz).ok_or(ArrowError::ComputeError( + "Timestamp out of range".to_string() + )) + ) + } + (Op::Sub | Op::SubWrapping, Interval(YearMonth)) => { + let r = r.as_primitive::(); + try_op!( + l, + l_s, + r, + r_s, + T::sub_year_month(l, r, l_tz).ok_or(ArrowError::ComputeError( + "Timestamp out of range".to_string() + )) + ) + } + + (Op::Add | Op::AddWrapping, Interval(DayTime)) => { + let r = r.as_primitive::(); + try_op!( + l, + l_s, + r, + r_s, + T::add_day_time(l, r, l_tz).ok_or(ArrowError::ComputeError( + "Timestamp out of range".to_string() + )) + ) + } + (Op::Sub | Op::SubWrapping, Interval(DayTime)) => { + let r = r.as_primitive::(); + try_op!( + l, + l_s, + r, + r_s, + T::sub_day_time(l, r, l_tz).ok_or(ArrowError::ComputeError( + "Timestamp out of range".to_string() + )) + ) + } + + (Op::Add | Op::AddWrapping, Interval(MonthDayNano)) => { + let r = r.as_primitive::(); + try_op!( + l, + l_s, + r, + r_s, + T::add_month_day_nano(l, r, l_tz).ok_or(ArrowError::ComputeError( + "Timestamp out of range".to_string() + )) + ) + } + (Op::Sub | Op::SubWrapping, Interval(MonthDayNano)) => { + let r = r.as_primitive::(); + try_op!( + l, + l_s, + r, + r_s, + T::sub_month_day_nano(l, r, l_tz).ok_or(ArrowError::ComputeError( + "Timestamp out of range".to_string() + )) + ) + } + _ => { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid timestamp arithmetic operation: {} {op} {}", + l.data_type(), + r.data_type() + ))) + } + }; + Ok(Arc::new(array.with_timezone_opt(l.timezone()))) +} + +/// Arithmetic trait for date 
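Editor note: `timestamp_op` returns a duration array when both sides are timestamps of the same unit, and otherwise applies durations or intervals to the left-hand timestamp. A usage sketch, assuming the kernels defined in this new module are exported as `arrow_arith::numeric::{add, sub, ...}`:

```rust
use arrow_arith::numeric::sub;
use arrow_array::cast::AsArray;
use arrow_array::types::DurationSecondType;
use arrow_array::TimestampSecondArray;

fn main() {
    // Timestamp - Timestamp (same unit) takes the first arm above and
    // produces a Duration array of the corresponding unit.
    let a = TimestampSecondArray::from(vec![100, 200]);
    let b = TimestampSecondArray::from(vec![40, 250]);
    let d = sub(&a, &b).unwrap();
    assert_eq!(d.as_primitive::<DurationSecondType>().values(), &[60, -50]);
}
```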
arrays +/// +/// Note: these should be fallible (#4456) +trait DateOp: ArrowTemporalType { + fn add_year_month(timestamp: Self::Native, delta: i32) -> Self::Native; + fn add_day_time(timestamp: Self::Native, delta: i64) -> Self::Native; + fn add_month_day_nano(timestamp: Self::Native, delta: i128) -> Self::Native; + + fn sub_year_month(timestamp: Self::Native, delta: i32) -> Self::Native; + fn sub_day_time(timestamp: Self::Native, delta: i64) -> Self::Native; + fn sub_month_day_nano(timestamp: Self::Native, delta: i128) -> Self::Native; +} + +macro_rules! date { + ($t:ty) => { + impl DateOp for $t { + fn add_year_month(left: Self::Native, right: i32) -> Self::Native { + Self::add_year_months(left, right) + } + + fn add_day_time(left: Self::Native, right: i64) -> Self::Native { + Self::add_day_time(left, right) + } + + fn add_month_day_nano(left: Self::Native, right: i128) -> Self::Native { + Self::add_month_day_nano(left, right) + } + + fn sub_year_month(left: Self::Native, right: i32) -> Self::Native { + Self::subtract_year_months(left, right) + } + + fn sub_day_time(left: Self::Native, right: i64) -> Self::Native { + Self::subtract_day_time(left, right) + } + + fn sub_month_day_nano(left: Self::Native, right: i128) -> Self::Native { + Self::subtract_month_day_nano(left, right) + } + } + }; +} +date!(Date32Type); +date!(Date64Type); + +/// Arithmetic trait for interval arrays +trait IntervalOp: ArrowPrimitiveType { + fn add(left: Self::Native, right: Self::Native) -> Result; + fn sub(left: Self::Native, right: Self::Native) -> Result; +} + +impl IntervalOp for IntervalYearMonthType { + fn add(left: Self::Native, right: Self::Native) -> Result { + left.add_checked(right) + } + + fn sub(left: Self::Native, right: Self::Native) -> Result { + left.sub_checked(right) + } +} + +impl IntervalOp for IntervalDayTimeType { + fn add(left: Self::Native, right: Self::Native) -> Result { + let (l_days, l_ms) = Self::to_parts(left); + let (r_days, r_ms) = Self::to_parts(right); + let days = l_days.add_checked(r_days)?; + let ms = l_ms.add_checked(r_ms)?; + Ok(Self::make_value(days, ms)) + } + + fn sub(left: Self::Native, right: Self::Native) -> Result { + let (l_days, l_ms) = Self::to_parts(left); + let (r_days, r_ms) = Self::to_parts(right); + let days = l_days.sub_checked(r_days)?; + let ms = l_ms.sub_checked(r_ms)?; + Ok(Self::make_value(days, ms)) + } +} + +impl IntervalOp for IntervalMonthDayNanoType { + fn add(left: Self::Native, right: Self::Native) -> Result { + let (l_months, l_days, l_nanos) = Self::to_parts(left); + let (r_months, r_days, r_nanos) = Self::to_parts(right); + let months = l_months.add_checked(r_months)?; + let days = l_days.add_checked(r_days)?; + let nanos = l_nanos.add_checked(r_nanos)?; + Ok(Self::make_value(months, days, nanos)) + } + + fn sub(left: Self::Native, right: Self::Native) -> Result { + let (l_months, l_days, l_nanos) = Self::to_parts(left); + let (r_months, r_days, r_nanos) = Self::to_parts(right); + let months = l_months.sub_checked(r_months)?; + let days = l_days.sub_checked(r_days)?; + let nanos = l_nanos.sub_checked(r_nanos)?; + Ok(Self::make_value(months, days, nanos)) + } +} + +/// Perform arithmetic operation on an interval array +fn interval_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + match op { + Op::Add | Op::AddWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, T::add(l, r))), + Op::Sub | Op::SubWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, T::sub(l, 
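Editor note: the `IntervalOp` impls above add and subtract intervals component-wise on their decoded parts rather than on the packed native value, so overflow is checked per field. A small illustration of the part-wise arithmetic using the public `make_value`/`to_parts` helpers:

```rust
use arrow_array::types::IntervalMonthDayNanoType;

fn main() {
    let a = IntervalMonthDayNanoType::make_value(1, 2, 3);
    let b = IntervalMonthDayNanoType::make_value(10, 20, 30);

    // Decode to (months, days, nanos), add each component, then re-encode,
    // mirroring IntervalOp::add for IntervalMonthDayNanoType.
    let (m1, d1, n1) = IntervalMonthDayNanoType::to_parts(a);
    let (m2, d2, n2) = IntervalMonthDayNanoType::to_parts(b);
    let sum = IntervalMonthDayNanoType::make_value(m1 + m2, d1 + d2, n1 + n2);
    assert_eq!(IntervalMonthDayNanoType::to_parts(sum), (11, 22, 33));
}
```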
r))), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Invalid interval arithmetic operation: {} {op} {}", + l.data_type(), + r.data_type() + ))), + } +} + +fn duration_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + match op { + Op::Add | Op::AddWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, l.add_checked(r))), + Op::Sub | Op::SubWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, l.sub_checked(r))), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Invalid duration arithmetic operation: {} {op} {}", + l.data_type(), + r.data_type() + ))), + } +} + +/// Perform arithmetic operation on a date array +fn date_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + use DataType::*; + use IntervalUnit::*; + + const NUM_SECONDS_IN_DAY: i64 = 60 * 60 * 24; + + let r_t = r.data_type(); + match (T::DATA_TYPE, op, r_t) { + (Date32, Op::Sub | Op::SubWrapping, Date32) => { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + return Ok(op_ref!( + DurationSecondType, + l, + l_s, + r, + r_s, + ((l as i64) - (r as i64)) * NUM_SECONDS_IN_DAY + )); + } + (Date64, Op::Sub | Op::SubWrapping, Date64) => { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + let result = + try_op_ref!(DurationMillisecondType, l, l_s, r, r_s, l.sub_checked(r)); + return Ok(result); + } + _ => {} + } + + let l = l.as_primitive::(); + match (op, r_t) { + (Op::Add | Op::AddWrapping, Interval(YearMonth)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::add_year_month(l, r))) + } + (Op::Sub | Op::SubWrapping, Interval(YearMonth)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::sub_year_month(l, r))) + } + + (Op::Add | Op::AddWrapping, Interval(DayTime)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::add_day_time(l, r))) + } + (Op::Sub | Op::SubWrapping, Interval(DayTime)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::sub_day_time(l, r))) + } + + (Op::Add | Op::AddWrapping, Interval(MonthDayNano)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::add_month_day_nano(l, r))) + } + (Op::Sub | Op::SubWrapping, Interval(MonthDayNano)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::sub_month_day_nano(l, r))) + } + + _ => Err(ArrowError::InvalidArgumentError(format!( + "Invalid date arithmetic operation: {} {op} {}", + l.data_type(), + r.data_type() + ))), + } +} + +/// Perform arithmetic operation on decimal arrays +fn decimal_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + + let (p1, s1, p2, s2) = match (l.data_type(), r.data_type()) { + (DataType::Decimal128(p1, s1), DataType::Decimal128(p2, s2)) => (p1, s1, p2, s2), + (DataType::Decimal256(p1, s1), DataType::Decimal256(p2, s2)) => (p1, s1, p2, s2), + _ => unreachable!(), + }; + + // Follow the Hive decimal arithmetic rules + // https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + let array: PrimitiveArray = match op { + Op::Add | Op::AddWrapping | Op::Sub | Op::SubWrapping => { + // max(s1, s2) + let result_scale = *s1.max(s2); + + // max(s1, s2) + max(p1-s1, p2-s2) + 1 + let result_precision = + (result_scale.saturating_add((*p1 as i8 - s1).max(*p2 as i8 - s2)) as u8) + .saturating_add(1) + .min(T::MAX_PRECISION); + + let l_mul = 
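Editor note: `date_op` special-cases `Date32 - Date32`, converting the day difference into a `Duration(Second)` array via `NUM_SECONDS_IN_DAY`, while `Date64 - Date64` stays in milliseconds. A usage sketch, again assuming the kernels are exported as `arrow_arith::numeric::*`:

```rust
use arrow_arith::numeric::sub;
use arrow_array::cast::AsArray;
use arrow_array::types::DurationSecondType;
use arrow_array::Date32Array;

fn main() {
    // Date32 values are days since the epoch; subtraction yields seconds.
    let a = Date32Array::from(vec![10, 20]);
    let b = Date32Array::from(vec![3, 25]);
    let d = sub(&a, &b).unwrap();
    assert_eq!(
        d.as_primitive::<DurationSecondType>().values(),
        &[7 * 86_400, -5 * 86_400]
    );
}
```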
T::Native::usize_as(10).pow_checked((result_scale - s1) as _)?; + let r_mul = T::Native::usize_as(10).pow_checked((result_scale - s2) as _)?; + + match op { + Op::Add | Op::AddWrapping => { + try_op!( + l, + l_s, + r, + r_s, + l.mul_checked(l_mul)?.add_checked(r.mul_checked(r_mul)?) + ) + } + Op::Sub | Op::SubWrapping => { + try_op!( + l, + l_s, + r, + r_s, + l.mul_checked(l_mul)?.sub_checked(r.mul_checked(r_mul)?) + ) + } + _ => unreachable!(), + } + .with_precision_and_scale(result_precision, result_scale)? + } + Op::Mul | Op::MulWrapping => { + let result_precision = p1.saturating_add(p2 + 1).min(T::MAX_PRECISION); + let result_scale = s1.saturating_add(*s2); + if result_scale > T::MAX_SCALE { + // SQL standard says that if the resulting scale of a multiply operation goes + // beyond the maximum, rounding is not acceptable and thus an error occurs + return Err(ArrowError::InvalidArgumentError(format!( + "Output scale of {} {op} {} would exceed max scale of {}", + l.data_type(), + r.data_type(), + T::MAX_SCALE + ))); + } + + try_op!(l, l_s, r, r_s, l.mul_checked(r)) + .with_precision_and_scale(result_precision, result_scale)? + } + + Op::Div => { + // Follow postgres and MySQL adding a fixed scale increment of 4 + // s1 + 4 + let result_scale = s1.saturating_add(4).min(T::MAX_SCALE); + let mul_pow = result_scale - s1 + s2; + + // p1 - s1 + s2 + result_scale + let result_precision = + (mul_pow.saturating_add(*p1 as i8) as u8).min(T::MAX_PRECISION); + + let (l_mul, r_mul) = match mul_pow.cmp(&0) { + Ordering::Greater => ( + T::Native::usize_as(10).pow_checked(mul_pow as _)?, + T::Native::ONE, + ), + Ordering::Equal => (T::Native::ONE, T::Native::ONE), + Ordering::Less => ( + T::Native::ONE, + T::Native::usize_as(10).pow_checked(mul_pow.neg_wrapping() as _)?, + ), + }; + + try_op!( + l, + l_s, + r, + r_s, + l.mul_checked(l_mul)?.div_checked(r.mul_checked(r_mul)?) + ) + .with_precision_and_scale(result_precision, result_scale)? + } + + Op::Rem => { + // max(s1, s2) + let result_scale = *s1.max(s2); + // min(p1-s1, p2 -s2) + max( s1,s2 ) + let result_precision = + (result_scale.saturating_add((*p1 as i8 - s1).min(*p2 as i8 - s2)) as u8) + .min(T::MAX_PRECISION); + + let l_mul = T::Native::usize_as(10).pow_wrapping((result_scale - s1) as _); + let r_mul = T::Native::usize_as(10).pow_wrapping((result_scale - s2) as _); + + try_op!( + l, + l_s, + r, + r_s, + l.mul_checked(l_mul)?.mod_checked(r.mul_checked(r_mul)?) + ) + .with_precision_and_scale(result_precision, result_scale)? 
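Editor note: the decimal division arm follows Postgres/MySQL in adding a fixed scale increment of 4, so for `Decimal128(12, 3) / Decimal128(12, 1)` the result scale is `3 + 4 = 7` and the precision is `p1 - s1 + s2 + result_scale = 12 - 3 + 1 + 7 = 17`, matching the decimal test further down. A worked sketch of that result type:

```rust
use arrow_arith::numeric::div;
use arrow_array::{Array, Decimal128Array};
use arrow_schema::DataType;

fn main() {
    let a = Decimal128Array::from(vec![15])
        .with_precision_and_scale(12, 3)
        .unwrap();
    let b = Decimal128Array::from(vec![54])
        .with_precision_and_scale(12, 1)
        .unwrap();

    // 0.015 / 5.4, computed with scale 3 + 4 = 7 and precision 17.
    let r = div(&a, &b).unwrap();
    assert_eq!(r.data_type(), &DataType::Decimal128(17, 7));
}
```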
+ } + }; + + Ok(Arc::new(array)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::temporal_conversions::{as_date, as_datetime}; + use arrow_buffer::{i256, ScalarBuffer}; + use chrono::{DateTime, NaiveDate}; + + fn test_neg_primitive( + input: &[T::Native], + out: Result<&[T::Native], &str>, + ) { + let a = PrimitiveArray::::new(ScalarBuffer::from(input.to_vec()), None); + match out { + Ok(expected) => { + let result = neg(&a).unwrap(); + assert_eq!(result.as_primitive::().values(), expected); + } + Err(e) => { + let err = neg(&a).unwrap_err().to_string(); + assert_eq!(e, err); + } + } + } + + #[test] + fn test_neg() { + let input = &[1, -5, 2, 693, 3929]; + let output = &[-1, 5, -2, -693, -3929]; + test_neg_primitive::(input, Ok(output)); + + let input = &[1, -5, 2, 693, 3929]; + let output = &[-1, 5, -2, -693, -3929]; + test_neg_primitive::(input, Ok(output)); + test_neg_primitive::(input, Ok(output)); + test_neg_primitive::(input, Ok(output)); + test_neg_primitive::(input, Ok(output)); + test_neg_primitive::(input, Ok(output)); + + let input = &[f32::MAX, f32::MIN, f32::INFINITY, 1.3, 0.5]; + let output = &[f32::MIN, f32::MAX, f32::NEG_INFINITY, -1.3, -0.5]; + test_neg_primitive::(input, Ok(output)); + + test_neg_primitive::( + &[i32::MIN], + Err("Compute error: Overflow happened on: -2147483648"), + ); + test_neg_primitive::( + &[i64::MIN], + Err("Compute error: Overflow happened on: -9223372036854775808"), + ); + test_neg_primitive::( + &[i64::MIN], + Err("Compute error: Overflow happened on: -9223372036854775808"), + ); + + let r = neg_wrapping(&Int32Array::from(vec![i32::MIN])).unwrap(); + assert_eq!(r.as_primitive::().value(0), i32::MIN); + + let r = neg_wrapping(&Int64Array::from(vec![i64::MIN])).unwrap(); + assert_eq!(r.as_primitive::().value(0), i64::MIN); + + let err = neg_wrapping(&DurationSecondArray::from(vec![i64::MIN])) + .unwrap_err() + .to_string(); + + assert_eq!( + err, + "Compute error: Overflow happened on: -9223372036854775808" + ); + + let a = Decimal128Array::from(vec![1, 3, -44, 2, 4]) + .with_precision_and_scale(9, 6) + .unwrap(); + + let r = neg(&a).unwrap(); + assert_eq!(r.data_type(), a.data_type()); + assert_eq!( + r.as_primitive::().values(), + &[-1, -3, 44, -2, -4] + ); + + let a = Decimal256Array::from(vec![ + i256::from_i128(342), + i256::from_i128(-4949), + i256::from_i128(3), + ]) + .with_precision_and_scale(9, 6) + .unwrap(); + + let r = neg(&a).unwrap(); + assert_eq!(r.data_type(), a.data_type()); + assert_eq!( + r.as_primitive::().values(), + &[ + i256::from_i128(-342), + i256::from_i128(4949), + i256::from_i128(-3), + ] + ); + + let a = IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(2, 4), + IntervalYearMonthType::make_value(2, -4), + IntervalYearMonthType::make_value(-3, -5), + ]); + let r = neg(&a).unwrap(); + assert_eq!( + r.as_primitive::().values(), + &[ + IntervalYearMonthType::make_value(-2, -4), + IntervalYearMonthType::make_value(-2, 4), + IntervalYearMonthType::make_value(3, 5), + ] + ); + + let a = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(2, 4), + IntervalDayTimeType::make_value(2, -4), + IntervalDayTimeType::make_value(-3, -5), + ]); + let r = neg(&a).unwrap(); + assert_eq!( + r.as_primitive::().values(), + &[ + IntervalDayTimeType::make_value(-2, -4), + IntervalDayTimeType::make_value(-2, 4), + IntervalDayTimeType::make_value(3, 5), + ] + ); + + let a = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(2, 4, 5953394), + 
IntervalMonthDayNanoType::make_value(2, -4, -45839), + IntervalMonthDayNanoType::make_value(-3, -5, 6944), + ]); + let r = neg(&a).unwrap(); + assert_eq!( + r.as_primitive::().values(), + &[ + IntervalMonthDayNanoType::make_value(-2, -4, -5953394), + IntervalMonthDayNanoType::make_value(-2, 4, 45839), + IntervalMonthDayNanoType::make_value(3, 5, -6944), + ] + ); + } + + #[test] + fn test_integer() { + let a = Int32Array::from(vec![4, 3, 5, -6, 100]); + let b = Int32Array::from(vec![6, 2, 5, -7, 3]); + let result = add(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &Int32Array::from(vec![10, 5, 10, -13, 103]) + ); + let result = sub(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &Int32Array::from(vec![-2, 1, 0, 1, 97])); + let result = div(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &Int32Array::from(vec![0, 1, 1, 0, 33])); + let result = mul(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &Int32Array::from(vec![24, 6, 25, 42, 300])); + let result = rem(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &Int32Array::from(vec![4, 1, 0, -6, 1])); + + let a = Int8Array::from(vec![Some(2), None, Some(45)]); + let b = Int8Array::from(vec![Some(5), Some(3), None]); + let result = add(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &Int8Array::from(vec![Some(7), None, None])); + + let a = UInt8Array::from(vec![56, 5, 3]); + let b = UInt8Array::from(vec![200, 2, 5]); + let err = add(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Compute error: Overflow happened on: 56 + 200"); + let result = add_wrapping(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &UInt8Array::from(vec![0, 7, 8])); + + let a = UInt8Array::from(vec![34, 5, 3]); + let b = UInt8Array::from(vec![200, 2, 5]); + let err = sub(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Compute error: Overflow happened on: 34 - 200"); + let result = sub_wrapping(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &UInt8Array::from(vec![90, 3, 254])); + + let a = UInt8Array::from(vec![34, 5, 3]); + let b = UInt8Array::from(vec![200, 2, 5]); + let err = mul(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Compute error: Overflow happened on: 34 * 200"); + let result = mul_wrapping(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &UInt8Array::from(vec![144, 10, 15])); + + let a = Int16Array::from(vec![i16::MIN]); + let b = Int16Array::from(vec![-1]); + let err = div(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Compute error: Overflow happened on: -32768 / -1"); + + let a = Int16Array::from(vec![21]); + let b = Int16Array::from(vec![0]); + let err = div(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Divide by zero error"); + + let a = Int16Array::from(vec![21]); + let b = Int16Array::from(vec![0]); + let err = rem(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Divide by zero error"); + } + + #[test] + fn test_float() { + let a = Float32Array::from(vec![1., f32::MAX, 6., -4., -1., 0.]); + let b = Float32Array::from(vec![1., f32::MAX, f32::MAX, -3., 45., 0.]); + let result = add(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &Float32Array::from(vec![2., f32::INFINITY, f32::MAX, -7., 44.0, 0.]) + ); + + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &Float32Array::from(vec![0., 0., f32::MIN, -1., -46., 0.]) + ); + + let result = mul(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &Float32Array::from(vec![1., f32::INFINITY, f32::INFINITY, 12., -45., 0.]) + ); + + let result = div(&a, &b).unwrap(); + let r = result.as_primitive::(); + assert_eq!(r.value(0), 1.); + 
assert_eq!(r.value(1), 1.); + assert!(r.value(2) < f32::EPSILON); + assert_eq!(r.value(3), -4. / -3.); + assert!(r.value(5).is_nan()); + + let result = rem(&a, &b).unwrap(); + let r = result.as_primitive::(); + assert_eq!(&r.values()[..5], &[0., 0., 6., -1., -1.]); + assert!(r.value(5).is_nan()); + } + + #[test] + fn test_decimal() { + // 0.015 7.842 -0.577 0.334 -0.078 0.003 + let a = Decimal128Array::from(vec![15, 0, -577, 334, -78, 3]) + .with_precision_and_scale(12, 3) + .unwrap(); + + // 5.4 0 -35.6 0.3 0.6 7.45 + let b = Decimal128Array::from(vec![54, 34, -356, 3, 6, 745]) + .with_precision_and_scale(12, 1) + .unwrap(); + + let result = add(&a, &b).unwrap(); + assert_eq!(result.data_type(), &DataType::Decimal128(15, 3)); + assert_eq!( + result.as_primitive::().values(), + &[5415, 3400, -36177, 634, 522, 74503] + ); + + let result = sub(&a, &b).unwrap(); + assert_eq!(result.data_type(), &DataType::Decimal128(15, 3)); + assert_eq!( + result.as_primitive::().values(), + &[-5385, -3400, 35023, 34, -678, -74497] + ); + + let result = mul(&a, &b).unwrap(); + assert_eq!(result.data_type(), &DataType::Decimal128(25, 4)); + assert_eq!( + result.as_primitive::().values(), + &[810, 0, 205412, 1002, -468, 2235] + ); + + let result = div(&a, &b).unwrap(); + assert_eq!(result.data_type(), &DataType::Decimal128(17, 7)); + assert_eq!( + result.as_primitive::().values(), + &[27777, 0, 162078, 11133333, -1300000, 402] + ); + + let result = rem(&a, &b).unwrap(); + assert_eq!(result.data_type(), &DataType::Decimal128(12, 3)); + assert_eq!( + result.as_primitive::().values(), + &[15, 0, -577, 34, -78, 3] + ); + + let a = Decimal128Array::from(vec![1]) + .with_precision_and_scale(3, 3) + .unwrap(); + let b = Decimal128Array::from(vec![1]) + .with_precision_and_scale(37, 37) + .unwrap(); + let err = mul(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Invalid argument error: Output scale of Decimal128(3, 3) * Decimal128(37, 37) would exceed max scale of 38"); + + let a = Decimal128Array::from(vec![1]) + .with_precision_and_scale(3, -2) + .unwrap(); + let err = add(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Compute error: Overflow happened on: 10 ^ 39"); + + let a = Decimal128Array::from(vec![10]) + .with_precision_and_scale(3, -1) + .unwrap(); + let err = add(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Compute error: Overflow happened on: 10 * 100000000000000000000000000000000000000"); + + let b = Decimal128Array::from(vec![0]) + .with_precision_and_scale(1, 1) + .unwrap(); + let err = div(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Divide by zero error"); + let err = rem(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Divide by zero error"); + } + + fn test_timestamp_impl() { + let a = PrimitiveArray::::new(vec![2000000, 434030324, 53943340].into(), None); + let b = PrimitiveArray::::new(vec![329593, 59349, 694994].into(), None); + + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_primitive::().values(), + &[1670407, 433970975, 53248346] + ); + + let r2 = add(&b, &result.as_ref()).unwrap(); + assert_eq!(r2.as_ref(), &a); + + let r3 = add(&result.as_ref(), &b).unwrap(); + assert_eq!(r3.as_ref(), &a); + + let format_array = |x: &dyn Array| -> Vec { + x.as_primitive::() + .values() + .into_iter() + .map(|x| as_datetime::(*x).unwrap().to_string()) + .collect() + }; + + let values = vec![ + "1970-01-01T00:00:00Z", + "2010-04-01T04:00:20Z", + "1960-01-30T04:23:20Z", + ] + .into_iter() + .map(|x| { + 
T::make_value(DateTime::parse_from_rfc3339(x).unwrap().naive_utc()).unwrap() + }) + .collect(); + + let a = PrimitiveArray::::new(values, None); + let b = IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(5, 34), + IntervalYearMonthType::make_value(-2, 4), + IntervalYearMonthType::make_value(7, -4), + ]); + let r4 = add(&a, &b).unwrap(); + assert_eq!( + &format_array(r4.as_ref()), + &[ + "1977-11-01 00:00:00".to_string(), + "2008-08-01 04:00:20".to_string(), + "1966-09-30 04:23:20".to_string() + ] + ); + + let r5 = sub(&r4, &b).unwrap(); + assert_eq!(r5.as_ref(), &a); + + let b = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(5, 454000), + IntervalDayTimeType::make_value(-34, 0), + IntervalDayTimeType::make_value(7, -4000), + ]); + let r6 = add(&a, &b).unwrap(); + assert_eq!( + &format_array(r6.as_ref()), + &[ + "1970-01-06 00:07:34".to_string(), + "2010-02-26 04:00:20".to_string(), + "1960-02-06 04:23:16".to_string() + ] + ); + + let r7 = sub(&r6, &b).unwrap(); + assert_eq!(r7.as_ref(), &a); + + let b = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000), + IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000), + IntervalMonthDayNanoType::make_value(5, 2, 493_000_000_000), + ]); + let r8 = add(&a, &b).unwrap(); + assert_eq!( + &format_array(r8.as_ref()), + &[ + "1998-10-04 23:59:17".to_string(), + "1960-09-29 04:00:33".to_string(), + "1960-07-02 04:31:33".to_string() + ] + ); + + let r9 = sub(&r8, &b).unwrap(); + // Note: subtraction is not the inverse of addition for intervals + assert_eq!( + &format_array(r9.as_ref()), + &[ + "1970-01-02 00:00:00".to_string(), + "2010-04-02 04:00:20".to_string(), + "1960-01-31 04:23:20".to_string() + ] + ); + } + + #[test] + fn test_timestamp() { + test_timestamp_impl::(); + test_timestamp_impl::(); + test_timestamp_impl::(); + test_timestamp_impl::(); + } + + #[test] + fn test_interval() { + let a = IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(32, 4), + IntervalYearMonthType::make_value(32, 4), + ]); + let b = IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(-4, 6), + IntervalYearMonthType::make_value(-3, 23), + ]); + let result = add(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(28, 10), + IntervalYearMonthType::make_value(29, 27) + ]) + ); + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(36, -2), + IntervalYearMonthType::make_value(35, -19) + ]) + ); + + let a = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(32, 4), + IntervalDayTimeType::make_value(32, 4), + ]); + let b = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(-4, 6), + IntervalDayTimeType::make_value(-3, 23), + ]); + let result = add(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(28, 10), + IntervalDayTimeType::make_value(29, 27) + ]) + ); + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(36, -2), + IntervalDayTimeType::make_value(35, -19) + ]) + ); + let a = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(32, 4, 4000000000000), + IntervalMonthDayNanoType::make_value(32, 4, 45463000000000000), + ]); + let b = 
IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(-4, 6, 46000000000000), + IntervalMonthDayNanoType::make_value(-3, 23, 3564000000000000), + ]); + let result = add(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(28, 10, 50000000000000), + IntervalMonthDayNanoType::make_value(29, 27, 49027000000000000) + ]) + ); + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(36, -2, -42000000000000), + IntervalMonthDayNanoType::make_value(35, -19, 41899000000000000) + ]) + ); + let a = IntervalMonthDayNanoArray::from(vec![i64::MAX as i128]); + let b = IntervalMonthDayNanoArray::from(vec![1]); + let err = add(&a, &b).unwrap_err().to_string(); + assert_eq!( + err, + "Compute error: Overflow happened on: 9223372036854775807 + 1" + ); + } + + fn test_duration_impl>() { + let a = PrimitiveArray::::new(vec![1000, 4394, -3944].into(), None); + let b = PrimitiveArray::::new(vec![4, -5, -243].into(), None); + + let result = add(&a, &b).unwrap(); + assert_eq!(result.as_primitive::().values(), &[1004, 4389, -4187]); + let result = sub(&a, &b).unwrap(); + assert_eq!(result.as_primitive::().values(), &[996, 4399, -3701]); + + let err = mul(&a, &b).unwrap_err().to_string(); + assert!( + err.contains("Invalid duration arithmetic operation"), + "{err}" + ); + + let err = div(&a, &b).unwrap_err().to_string(); + assert!( + err.contains("Invalid duration arithmetic operation"), + "{err}" + ); + + let err = rem(&a, &b).unwrap_err().to_string(); + assert!( + err.contains("Invalid duration arithmetic operation"), + "{err}" + ); + + let a = PrimitiveArray::::new(vec![i64::MAX].into(), None); + let b = PrimitiveArray::::new(vec![1].into(), None); + let err = add(&a, &b).unwrap_err().to_string(); + assert_eq!( + err, + "Compute error: Overflow happened on: 9223372036854775807 + 1" + ); + } + + #[test] + fn test_duration() { + test_duration_impl::(); + test_duration_impl::(); + test_duration_impl::(); + test_duration_impl::(); + } + + fn test_date_impl(f: F) + where + F: Fn(NaiveDate) -> T::Native, + T::Native: TryInto, + { + let a = PrimitiveArray::::new( + vec![ + f(NaiveDate::from_ymd_opt(1979, 1, 30).unwrap()), + f(NaiveDate::from_ymd_opt(2010, 4, 3).unwrap()), + f(NaiveDate::from_ymd_opt(2008, 2, 29).unwrap()), + ] + .into(), + None, + ); + + let b = IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(34, 2), + IntervalYearMonthType::make_value(3, -3), + IntervalYearMonthType::make_value(-12, 4), + ]); + + let format_array = |x: &dyn Array| -> Vec { + x.as_primitive::() + .values() + .into_iter() + .map(|x| { + as_date::((*x).try_into().ok().unwrap()) + .unwrap() + .to_string() + }) + .collect() + }; + + let result = add(&a, &b).unwrap(); + assert_eq!( + &format_array(result.as_ref()), + &[ + "2013-03-30".to_string(), + "2013-01-03".to_string(), + "1996-06-29".to_string(), + ] + ); + let result = sub(&result, &b).unwrap(); + assert_eq!(result.as_ref(), &a); + + let b = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(34, 2), + IntervalDayTimeType::make_value(3, -3), + IntervalDayTimeType::make_value(-12, 4), + ]); + + let result = add(&a, &b).unwrap(); + assert_eq!( + &format_array(result.as_ref()), + &[ + "1979-03-05".to_string(), + "2010-04-06".to_string(), + "2008-02-17".to_string(), + ] + ); + let result = sub(&result, &b).unwrap(); + assert_eq!(result.as_ref(), &a); + + let b 
= IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(34, 2, -34353534), + IntervalMonthDayNanoType::make_value(3, -3, 2443), + IntervalMonthDayNanoType::make_value(-12, 4, 2323242423232), + ]); + + let result = add(&a, &b).unwrap(); + assert_eq!( + &format_array(result.as_ref()), + &[ + "1981-12-02".to_string(), + "2010-06-30".to_string(), + "2007-03-04".to_string(), + ] + ); + let result = sub(&result, &b).unwrap(); + assert_eq!( + &format_array(result.as_ref()), + &[ + "1979-01-31".to_string(), + "2010-04-02".to_string(), + "2008-02-29".to_string(), + ] + ); + } + + #[test] + fn test_date() { + test_date_impl::(Date32Type::from_naive_date); + test_date_impl::(Date64Type::from_naive_date); + + let a = Date32Array::from(vec![i32::MIN, i32::MAX, 23, 7684]); + let b = Date32Array::from(vec![i32::MIN, i32::MIN, -2, 45]); + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_primitive::().values(), + &[0, 371085174288000, 2160000, 660009600] + ); + + let a = Date64Array::from(vec![4343, 76676, 3434]); + let b = Date64Array::from(vec![3, -5, 5]); + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_primitive::().values(), + &[4340, 76681, 3429] + ); + + let a = Date64Array::from(vec![i64::MAX]); + let b = Date64Array::from(vec![-1]); + let err = sub(&a, &b).unwrap_err().to_string(); + assert_eq!( + err, + "Compute error: Overflow happened on: 9223372036854775807 - -1" + ); + } +} diff --git a/arrow-arith/src/temporal.rs b/arrow-arith/src/temporal.rs index 0a313718c907..7855b6fc6e46 100644 --- a/arrow-arith/src/temporal.rs +++ b/arrow-arith/src/temporal.rs @@ -181,26 +181,7 @@ pub fn using_chrono_tz_and_utc_naive_date_time( /// the range of [0, 23]. If the given array isn't temporal primitive or dictionary array, /// an `Err` will be returned. pub fn hour_dyn(array: &dyn Array) -> Result { - match array.data_type().clone() { - DataType::Dictionary(_, _) => { - downcast_dictionary_array!( - array => { - let hour_values = hour_dyn(array.values())?; - Ok(Arc::new(array.with_values(&hour_values))) - } - dt => return_compute_error_with!("hour does not support", dt), - ) - } - _ => { - downcast_temporal_array!( - array => { - hour(array) - .map(|a| Arc::new(a) as ArrayRef) - } - dt => return_compute_error_with!("hour does not support", dt), - ) - } - } + time_fraction_dyn(array, "hour", |t| t.hour() as i32) } /// Extracts the hours of a given temporal primitive array as an array of integers within @@ -481,7 +462,7 @@ where downcast_dictionary_array!( array => { let values = time_fraction_dyn(array.values(), name, op)?; - Ok(Arc::new(array.with_values(&values))) + Ok(Arc::new(array.with_values(values))) } dt => return_compute_error_with!(format!("{name} does not support"), dt), ) @@ -940,37 +921,6 @@ mod tests { assert!(err.contains("Invalid timezone"), "{}", err); } - #[cfg(feature = "chrono-tz")] - #[test] - fn test_temporal_array_timestamp_hour_with_timezone_using_chrono_tz() { - let a = TimestampSecondArray::from(vec![60 * 60 * 10]) - .with_timezone("Asia/Kolkata".to_string()); - let b = hour(&a).unwrap(); - assert_eq!(15, b.value(0)); - } - - #[cfg(feature = "chrono-tz")] - #[test] - fn test_temporal_array_timestamp_hour_with_dst_timezone_using_chrono_tz() { - // - // 1635577147 converts to 2021-10-30 17:59:07 in time zone Australia/Sydney (AEDT) - // The offset (difference to UTC) is +11:00. Note that daylight savings is in effect on 2021-10-30. - // When daylight savings is not in effect, Australia/Sydney has an offset difference of +10:00. 
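Editor note: in the `temporal.rs` hunk above, `hour_dyn` becomes a thin wrapper over `time_fraction_dyn`, which already handles the dictionary and temporal downcasts, so callers are unchanged. A quick usage sketch:

```rust
use arrow_arith::temporal::hour_dyn;
use arrow_array::cast::AsArray;
use arrow_array::types::Int32Type;
use arrow_array::TimestampSecondArray;

fn main() {
    // 10:00:00 UTC -> hour 10, extracted through the time_fraction_dyn path.
    let a = TimestampSecondArray::from(vec![60 * 60 * 10]);
    let hours = hour_dyn(&a).unwrap();
    assert_eq!(hours.as_primitive::<Int32Type>().value(0), 10);
}
```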
- - let a = TimestampMillisecondArray::from(vec![Some(1635577147000)]) - .with_timezone("Australia/Sydney".to_string()); - let b = hour(&a).unwrap(); - assert_eq!(17, b.value(0)); - } - - #[cfg(not(feature = "chrono-tz"))] - #[test] - fn test_temporal_array_timestamp_hour_with_timezone_using_chrono_tz() { - let a = TimestampSecondArray::from(vec![60 * 60 * 10]) - .with_timezone("Asia/Kolkatta".to_string()); - assert!(matches!(hour(&a), Err(ArrowError::ParseError(_)))) - } - #[test] fn test_temporal_array_timestamp_week_without_timezone() { // 1970-01-01T00:00:00 -> 1970-01-01T00:00:00 Thursday (week 1) diff --git a/arrow-array/Cargo.toml b/arrow-array/Cargo.toml index d4f0f9fa0d47..4f7ab24f9708 100644 --- a/arrow-array/Cargo.toml +++ b/arrow-array/Cargo.toml @@ -44,12 +44,12 @@ ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] arrow-buffer = { workspace = true } arrow-schema = { workspace = true } arrow-data = { workspace = true } -chrono = { version = "0.4.24", default-features = false, features = ["clock"] } +chrono = { workspace = true } chrono-tz = { version = "0.8", optional = true } -num = { version = "0.4", default-features = false, features = ["std"] } +num = { version = "0.4.1", default-features = false, features = ["std"] } half = { version = "2.1", default-features = false, features = ["num-traits"] } hashbrown = { version = "0.14", default-features = false } -packed_simd = { version = "0.3", default-features = false, optional = true, package = "packed_simd_2" } +packed_simd = { version = "0.3.9", default-features = false, optional = true } [features] simd = ["packed_simd"] diff --git a/arrow-array/src/arithmetic.rs b/arrow-array/src/arithmetic.rs index abeb46b99688..b0ecef70ee19 100644 --- a/arrow-array/src/arithmetic.rs +++ b/arrow-array/src/arithmetic.rs @@ -229,7 +229,10 @@ macro_rules! native_type_op { #[inline] fn pow_checked(self, exp: u32) -> Result { self.checked_pow(exp).ok_or_else(|| { - ArrowError::ComputeError(format!("Overflow happened on: {:?}", self)) + ArrowError::ComputeError(format!( + "Overflow happened on: {:?} ^ {exp:?}", + self + )) }) } diff --git a/arrow-array/src/array/binary_array.rs b/arrow-array/src/array/binary_array.rs index 54839604d192..75880bec30ce 100644 --- a/arrow-array/src/array/binary_array.rs +++ b/arrow-array/src/array/binary_array.rs @@ -19,7 +19,6 @@ use crate::types::{ByteArrayType, GenericBinaryType}; use crate::{ Array, GenericByteArray, GenericListArray, GenericStringArray, OffsetSizeTrait, }; -use arrow_buffer::MutableBuffer; use arrow_data::ArrayData; use arrow_schema::DataType; @@ -83,42 +82,6 @@ impl GenericBinaryArray { Self::from(data) } - /// Creates a [`GenericBinaryArray`] based on an iterator of values without nulls - pub fn from_iter_values(iter: I) -> Self - where - Ptr: AsRef<[u8]>, - I: IntoIterator, - { - let iter = iter.into_iter(); - let (_, data_len) = iter.size_hint(); - let data_len = data_len.expect("Iterator must be sized"); // panic if no upper bound. 
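Editor note: the `arrow-array/src/arithmetic.rs` hunk extends the `pow_checked` overflow message to include the exponent, which makes failures such as the `10 ^ 39` case in the decimal tests easier to read. A sketch of the new message shape:

```rust
use arrow_array::arithmetic::ArrowNativeTypeOp;

fn main() {
    // 10^10 overflows i32; the error now reports both base and exponent.
    let err = 10_i32.pow_checked(10).unwrap_err().to_string();
    assert!(err.contains("Overflow happened on: 10 ^ 10"), "{err}");
}
```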
- - let mut offsets = - MutableBuffer::new((data_len + 1) * std::mem::size_of::()); - let mut values = MutableBuffer::new(0); - - let mut length_so_far = OffsetSize::zero(); - offsets.push(length_so_far); - - for s in iter { - let s = s.as_ref(); - length_so_far += OffsetSize::from_usize(s.len()).unwrap(); - offsets.push(length_so_far); - values.extend_from_slice(s); - } - - // iterator size hint may not be correct so compute the actual number of offsets - assert!(!offsets.is_empty()); // wrote at least one - let actual_len = (offsets.len() / std::mem::size_of::()) - 1; - - let array_data = ArrayData::builder(Self::DATA_TYPE) - .len(actual_len) - .add_buffer(offsets.into()) - .add_buffer(values.into()); - let array_data = unsafe { array_data.build_unchecked() }; - Self::from(array_data) - } - /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` pub fn take_iter<'a>( &'a self, @@ -584,9 +547,7 @@ mod tests { #[test] #[should_panic( - expected = "assertion failed: `(left == right)`\n left: `UInt32`,\n \ - right: `UInt8`: BinaryArray can only be created from List arrays, \ - mismatched data types." + expected = "BinaryArray can only be created from List arrays, mismatched data types." )] fn test_binary_array_from_incorrect_list_array() { let values: [u32; 12] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; diff --git a/arrow-array/src/array/boolean_array.rs b/arrow-array/src/array/boolean_array.rs index 14fa87e138eb..4d19babe3e4b 100644 --- a/arrow-array/src/array/boolean_array.rs +++ b/arrow-array/src/array/boolean_array.rs @@ -18,7 +18,7 @@ use crate::array::print_long_array; use crate::builder::BooleanBuilder; use crate::iterator::BooleanIter; -use crate::{Array, ArrayAccessor, ArrayRef}; +use crate::{Array, ArrayAccessor, ArrayRef, Scalar}; use arrow_buffer::{bit_util, BooleanBuffer, MutableBuffer, NullBuffer}; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::DataType; @@ -101,6 +101,15 @@ impl BooleanArray { } } + /// Create a new [`Scalar`] from `value` + pub fn new_scalar(value: bool) -> Scalar { + let values = match value { + true => BooleanBuffer::new_set(1), + false => BooleanBuffer::new_unset(1), + }; + Scalar::new(Self::new(values, None)) + } + /// Returns the length of this array. 
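Editor note: `BooleanArray::new_scalar` wraps a single boolean in a `Scalar`, the `Datum` wrapper that kernels treat as a value broadcast against the other input. A small sketch of what the wrapper holds; the inspection via `Datum::get` is illustrative:

```rust
use arrow_array::{Array, BooleanArray, Datum};

fn main() {
    let s = BooleanArray::new_scalar(true);

    // A Scalar is a length-1 array flagged as scalar for Datum-based kernels.
    let (array, is_scalar) = s.get();
    assert!(is_scalar);
    assert_eq!(array.len(), 1);
}
```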
pub fn len(&self) -> usize { self.values.len() @@ -205,7 +214,7 @@ impl BooleanArray { where F: FnMut(T::Item) -> bool, { - let nulls = left.nulls().cloned(); + let nulls = left.logical_nulls(); let values = BooleanBuffer::collect_bool(left.len(), |i| unsafe { // SAFETY: i in range 0..len op(left.value_unchecked(i)) @@ -239,7 +248,10 @@ impl BooleanArray { { assert_eq!(left.len(), right.len()); - let nulls = NullBuffer::union(left.nulls(), right.nulls()); + let nulls = NullBuffer::union( + left.logical_nulls().as_ref(), + right.logical_nulls().as_ref(), + ); let values = BooleanBuffer::collect_bool(left.len(), |i| unsafe { // SAFETY: i in range 0..len op(left.value_unchecked(i), right.value_unchecked(i)) @@ -425,6 +437,15 @@ impl>> FromIterator for BooleanArray } } +impl From for BooleanArray { + fn from(values: BooleanBuffer) -> Self { + Self { + values, + nulls: None, + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/arrow-array/src/array/byte_array.rs b/arrow-array/src/array/byte_array.rs index 0a18062d9ae1..37d8de931e99 100644 --- a/arrow-array/src/array/byte_array.rs +++ b/arrow-array/src/array/byte_array.rs @@ -20,7 +20,7 @@ use crate::builder::GenericByteBuilder; use crate::iterator::ArrayIter; use crate::types::bytes::ByteArrayNativeType; use crate::types::ByteArrayType; -use crate::{Array, ArrayAccessor, ArrayRef, OffsetSizeTrait}; +use crate::{Array, ArrayAccessor, ArrayRef, OffsetSizeTrait, Scalar}; use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer}; use arrow_buffer::{NullBuffer, OffsetBuffer}; use arrow_data::{ArrayData, ArrayDataBuilder}; @@ -159,7 +159,7 @@ impl GenericByteArray { /// # Safety /// /// Safe if [`Self::try_new`] would not error - pub fn new_unchecked( + pub unsafe fn new_unchecked( offsets: OffsetBuffer, values: Buffer, nulls: Option, @@ -182,6 +182,46 @@ impl GenericByteArray { } } + /// Create a new [`Scalar`] from `v` + pub fn new_scalar(value: impl AsRef) -> Scalar { + Scalar::new(Self::from_iter_values(std::iter::once(value))) + } + + /// Creates a [`GenericByteArray`] based on an iterator of values without nulls + pub fn from_iter_values(iter: I) -> Self + where + Ptr: AsRef, + I: IntoIterator, + { + let iter = iter.into_iter(); + let (_, data_len) = iter.size_hint(); + let data_len = data_len.expect("Iterator must be sized"); // panic if no upper bound. 
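Editor note: the new `From<BooleanBuffer>` impl gives a direct way to build a non-nullable `BooleanArray` from a packed bitmask:

```rust
use arrow_array::{Array, BooleanArray};
use arrow_buffer::BooleanBuffer;

fn main() {
    // Four set bits become a BooleanArray of four `true` values with no nulls.
    let arr = BooleanArray::from(BooleanBuffer::new_set(4));
    assert_eq!(arr.len(), 4);
    assert_eq!(arr.null_count(), 0);
    assert!(arr.value(2));
}
```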
+ + let mut offsets = + MutableBuffer::new((data_len + 1) * std::mem::size_of::()); + offsets.push(T::Offset::usize_as(0)); + + let mut values = MutableBuffer::new(0); + for s in iter { + let s: &[u8] = s.as_ref().as_ref(); + values.extend_from_slice(s); + offsets.push(T::Offset::usize_as(values.len())); + } + + T::Offset::from_usize(values.len()).expect("offset overflow"); + let offsets = Buffer::from(offsets); + + // Safety: valid by construction + let value_offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) }; + + Self { + data_type: T::DATA_TYPE, + value_data: values.into(), + value_offsets, + nulls: None, + } + } + /// Deconstruct this array into its constituent parts pub fn into_parts(self) -> (OffsetBuffer, Buffer, Option) { (self.value_offsets, self.value_data, self.nulls) diff --git a/arrow-array/src/array/dictionary_array.rs b/arrow-array/src/array/dictionary_array.rs index 5a2f439a8e0f..0cb00878929c 100644 --- a/arrow-array/src/array/dictionary_array.rs +++ b/arrow-array/src/array/dictionary_array.rs @@ -434,6 +434,7 @@ impl DictionaryArray { /// Panics if `values` has a length less than the current values /// /// ``` + /// # use std::sync::Arc; /// # use arrow_array::builder::PrimitiveDictionaryBuilder; /// # use arrow_array::{Int8Array, Int64Array, ArrayAccessor}; /// # use arrow_array::types::{Int32Type, Int8Type}; @@ -451,7 +452,7 @@ impl DictionaryArray { /// let values: Int64Array = typed_dictionary.values().unary(|x| x as i64); /// /// // Create a Dict(Int32, - /// let new = dictionary.with_values(&values); + /// let new = dictionary.with_values(Arc::new(values)); /// /// // Verify values are as expected /// let new_typed = new.downcast_dict::().unwrap(); @@ -460,21 +461,18 @@ impl DictionaryArray { /// } /// ``` /// - pub fn with_values(&self, values: &dyn Array) -> Self { + pub fn with_values(&self, values: ArrayRef) -> Self { assert!(values.len() >= self.values.len()); - - let builder = self - .to_data() - .into_builder() - .data_type(DataType::Dictionary( - Box::new(K::DATA_TYPE), - Box::new(values.data_type().clone()), - )) - .child_data(vec![values.to_data()]); - - // SAFETY: - // Offsets were valid before and verified length is greater than or equal - Self::from(unsafe { builder.build_unchecked() }) + let data_type = DataType::Dictionary( + Box::new(K::DATA_TYPE), + Box::new(values.data_type().clone()), + ); + Self { + data_type, + keys: self.keys.clone(), + values, + is_ordered: false, + } } /// Returns `PrimitiveDictionaryBuilder` of this dictionary array for mutating @@ -729,6 +727,31 @@ impl Array for DictionaryArray { self.keys.nulls() } + fn logical_nulls(&self) -> Option { + match self.values.nulls() { + None => self.nulls().cloned(), + Some(value_nulls) => { + let mut builder = BooleanBufferBuilder::new(self.len()); + match self.keys.nulls() { + Some(n) => builder.append_buffer(n.inner()), + None => builder.append_n(self.len(), true), + } + for (idx, k) in self.keys.values().iter().enumerate() { + let k = k.as_usize(); + // Check range to allow for nulls + if k < value_nulls.len() && value_nulls.is_null(k) { + builder.set_bit(idx, false); + } + } + Some(builder.finish().into()) + } + } + } + + fn is_nullable(&self) -> bool { + !self.is_empty() && (self.nulls().is_some() || self.values.is_nullable()) + } + fn get_buffer_memory_size(&self) -> usize { self.keys.get_buffer_memory_size() + self.values.get_buffer_memory_size() } @@ -777,10 +800,7 @@ pub struct TypedDictionaryArray<'a, K: ArrowDictionaryKeyType, V> { // Manually implement `Clone` to 
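Editor note: `from_iter_values` now lives on `GenericByteArray`, so string and binary arrays share one implementation that builds offsets and values without a null buffer. Usage is unchanged for callers:

```rust
use arrow_array::{Array, StringArray};

fn main() {
    // Builds a non-nullable array directly from the values.
    let a = StringArray::from_iter_values(["a", "bb", "ccc"]);
    assert_eq!(a.null_count(), 0);
    assert_eq!(a.value(1), "bb");
}
```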
avoid `V: Clone` type constraint impl<'a, K: ArrowDictionaryKeyType, V> Clone for TypedDictionaryArray<'a, K, V> { fn clone(&self) -> Self { - Self { - dictionary: self.dictionary, - values: self.values, - } + *self } } @@ -843,6 +863,14 @@ impl<'a, K: ArrowDictionaryKeyType, V: Sync> Array for TypedDictionaryArray<'a, self.dictionary.nulls() } + fn logical_nulls(&self) -> Option { + self.dictionary.logical_nulls() + } + + fn is_nullable(&self) -> bool { + self.dictionary.is_nullable() + } + fn get_buffer_memory_size(&self) -> usize { self.dictionary.get_buffer_memory_size() } @@ -897,6 +925,94 @@ where } } +/// A [`DictionaryArray`] with the key type erased +/// +/// This can be used to efficiently implement kernels for all possible dictionary +/// keys without needing to create specialized implementations for each key type +/// +/// For example +/// +/// ``` +/// # use arrow_array::*; +/// # use arrow_array::cast::AsArray; +/// # use arrow_array::builder::PrimitiveDictionaryBuilder; +/// # use arrow_array::types::*; +/// # use arrow_schema::ArrowError; +/// # use std::sync::Arc; +/// +/// fn to_string(a: &dyn Array) -> Result { +/// if let Some(d) = a.as_any_dictionary_opt() { +/// // Recursively handle dictionary input +/// let r = to_string(d.values().as_ref())?; +/// return Ok(d.with_values(r)); +/// } +/// downcast_primitive_array! { +/// a => Ok(Arc::new(a.iter().map(|x| x.map(|x| x.to_string())).collect::())), +/// d => Err(ArrowError::InvalidArgumentError(format!("{d:?} not supported"))) +/// } +/// } +/// +/// let result = to_string(&Int32Array::from(vec![1, 2, 3])).unwrap(); +/// let actual = result.as_string::().iter().map(Option::unwrap).collect::>(); +/// assert_eq!(actual, &["1", "2", "3"]); +/// +/// let mut dict = PrimitiveDictionaryBuilder::::new(); +/// dict.extend([Some(1), Some(1), Some(2), Some(3), Some(2)]); +/// let dict = dict.finish(); +/// +/// let r = to_string(&dict).unwrap(); +/// let r = r.as_dictionary::().downcast_dict::().unwrap(); +/// assert_eq!(r.keys(), dict.keys()); // Keys are the same +/// +/// let actual = r.into_iter().map(Option::unwrap).collect::>(); +/// assert_eq!(actual, &["1", "1", "2", "3", "2"]); +/// ``` +/// +/// See [`AsArray::as_any_dictionary_opt`] and [`AsArray::as_any_dictionary`] +pub trait AnyDictionaryArray: Array { + /// Returns the primitive keys of this dictionary as an [`Array`] + fn keys(&self) -> &dyn Array; + + /// Returns the values of this dictionary + fn values(&self) -> &ArrayRef; + + /// Returns the keys of this dictionary as usize + /// + /// The values for nulls will be arbitrary, but are guaranteed + /// to be in the range `0..self.values.len()` + /// + /// # Panic + /// + /// Panics if `values.len() == 0` + fn normalized_keys(&self) -> Vec; + + /// Create a new [`DictionaryArray`] replacing `values` with the new values + /// + /// See [`DictionaryArray::with_values`] + fn with_values(&self, values: ArrayRef) -> ArrayRef; +} + +impl AnyDictionaryArray for DictionaryArray { + fn keys(&self) -> &dyn Array { + &self.keys + } + + fn values(&self) -> &ArrayRef { + self.values() + } + + fn normalized_keys(&self) -> Vec { + let v_len = self.values().len(); + assert_ne!(v_len, 0); + let iter = self.keys().values().iter(); + iter.map(|x| x.as_usize().min(v_len - 1)).collect() + } + + fn with_values(&self, values: ArrayRef) -> ArrayRef { + Arc::new(self.with_values(values)) + } +} + #[cfg(test)] mod tests { use super::*; @@ -1253,4 +1369,29 @@ mod tests { assert_eq!(v, expected, "{idx}"); } } + + #[test] + fn 
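Editor note: the new `DictionaryArray::logical_nulls` merges key-level nulls with nulls in the referenced values, which is what distinguishes it from `Array::nulls`. A small sketch of the difference:

```rust
use std::sync::Arc;

use arrow_array::{Array, DictionaryArray, Int32Array};

fn main() {
    // Keys carry no null buffer, but key 1 points at a null value slot.
    let keys = Int32Array::from(vec![0, 1, 0]);
    let values = Int32Array::from(vec![Some(10), None]);
    let dict = DictionaryArray::new(keys, Arc::new(values));

    assert!(dict.nulls().is_none()); // no physical nulls on the keys
    let logical = dict.logical_nulls().unwrap(); // but one logical null
    assert_eq!(logical.null_count(), 1);
    assert!(logical.is_null(1));
}
```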
test_iterator_nulls() { + let keys = Int32Array::new( + vec![0, 700, 1, 2].into(), + Some(NullBuffer::from(vec![true, false, true, true])), + ); + let values = Int32Array::from(vec![Some(50), None, Some(2)]); + let dict = DictionaryArray::new(keys, Arc::new(values)); + let values: Vec<_> = dict + .downcast_dict::() + .unwrap() + .into_iter() + .collect(); + assert_eq!(values, &[Some(50), None, None, Some(2)]) + } + + #[test] + fn test_normalized_keys() { + let values = vec![132, 0, 1].into(); + let nulls = NullBuffer::from(vec![false, true, true]); + let keys = Int32Array::new(values, Some(nulls)); + let dictionary = DictionaryArray::new(keys, Arc::new(Int32Array::new_null(2))); + assert_eq!(&dictionary.normalized_keys(), &[1, 0, 1]) + } } diff --git a/arrow-array/src/array/fixed_size_binary_array.rs b/arrow-array/src/array/fixed_size_binary_array.rs index 74a7c4c7a84a..f0b04c203ceb 100644 --- a/arrow-array/src/array/fixed_size_binary_array.rs +++ b/arrow-array/src/array/fixed_size_binary_array.rs @@ -179,9 +179,18 @@ impl FixedSizeBinaryArray { self.value_length } - /// Returns a clone of the value data buffer - pub fn value_data(&self) -> Buffer { - self.value_data.clone() + /// Returns the values of this array. + /// + /// Unlike [`Self::value_data`] this returns the [`Buffer`] + /// allowing for zero-copy cloning. + #[inline] + pub fn values(&self) -> &Buffer { + &self.value_data + } + + /// Returns the raw value data. + pub fn value_data(&self) -> &[u8] { + self.value_data.as_slice() } /// Returns a zero-copy slice of this array with the indicated offset and length. diff --git a/arrow-array/src/array/fixed_size_list_array.rs b/arrow-array/src/array/fixed_size_list_array.rs index 6c1598ce90df..db3ccbe0617b 100644 --- a/arrow-array/src/array/fixed_size_list_array.rs +++ b/arrow-array/src/array/fixed_size_list_array.rs @@ -26,7 +26,59 @@ use arrow_schema::{ArrowError, DataType, FieldRef}; use std::any::Any; use std::sync::Arc; -/// An array of [fixed size arrays](https://arrow.apache.org/docs/format/Columnar.html#fixed-size-list-layout) +/// An array of [fixed length lists], similar to JSON arrays +/// (e.g. `["A", "B"]`). +/// +/// Lists are represented using a `values` child +/// array where each list has a fixed size of `value_length`. +/// +/// Use [`FixedSizeListBuilder`] to construct a [`FixedSizeListArray`]. +/// +/// # Representation +/// +/// A [`FixedSizeListArray`] can represent a list of values of any other +/// supported Arrow type. Each element of the `FixedSizeListArray` itself is +/// a list which may contain NULL and non-null values, +/// or may itself be NULL. +/// +/// For example, this `FixedSizeListArray` stores lists of strings: +/// +/// ```text +/// ┌─────────────┐ +/// │ [A,B] │ +/// ├─────────────┤ +/// │ NULL │ +/// ├─────────────┤ +/// │ [C,NULL] │ +/// └─────────────┘ +/// ``` +/// +/// The `values` of this `FixedSizeListArray`s are stored in a child +/// [`StringArray`] where logical null values take up `values_length` slots in the array +/// as shown in the following diagram. The logical values +/// are shown on the left, and the actual `FixedSizeListArray` encoding on the right +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─┐ +/// ┌─────────────┐ │ ┌───┐ ┌───┐ ┌──────┐ │ +/// │ [A,B] │ │ 1 │ │ │ 1 │ │ A │ │ 0 +/// ├─────────────┤ │ ├───┤ ├───┤ ├──────┤ │ +/// │ NULL │ │ 0 │ │ │ 1 │ │ B │ │ 1 +/// ├─────────────┤ │ ├───┤ ├───┤ ├──────┤ │ +/// │ [C,NULL] │ │ 1 │ │ │ 0 │ │ ???? 
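Editor note: `FixedSizeBinaryArray::value_data` now hands back the raw bytes as a slice, with the new `values()` accessor exposing the underlying `Buffer` for zero-copy cloning. A usage sketch:

```rust
use arrow_array::FixedSizeBinaryArray;

fn main() {
    let a = FixedSizeBinaryArray::try_from_iter(
        vec![vec![1u8, 2], vec![3, 4]].into_iter(),
    )
    .unwrap();

    // Raw contiguous bytes of both fixed-size values...
    assert_eq!(a.value_data(), &[1u8, 2, 3, 4]);
    // ...and the backing Buffer, cloneable without copying.
    assert_eq!(a.values().len(), 4);
}
```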
│ │ 2 +/// └─────────────┘ │ └───┘ ├───┤ ├──────┤ │ +/// | │ 0 │ │ ???? │ │ 3 +/// Logical Values │ Validity ├───┤ ├──────┤ │ +/// (nulls) │ │ 1 │ │ C │ │ 4 +/// │ ├───┤ ├──────┤ │ +/// │ │ 0 │ │ ???? │ │ 5 +/// │ └───┘ └──────┘ │ +/// │ Values │ +/// │ FixedSizeListArray (Array) │ +/// └ ─ ─ ─ ─ ─ ─ ─ ─┘ +/// └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ +/// ``` /// /// # Example /// @@ -60,6 +112,9 @@ use std::sync::Arc; /// assert_eq!( &[3, 4, 5], list1.as_any().downcast_ref::().unwrap().values()); /// assert_eq!( &[6, 7, 8], list2.as_any().downcast_ref::().unwrap().values()); /// ``` +/// +/// [`StringArray`]: crate::array::StringArray +/// [fixed size arrays](https://arrow.apache.org/docs/format/Columnar.html#fixed-size-list-layout) #[derive(Clone)] pub struct FixedSizeListArray { data_type: DataType, // Must be DataType::FixedSizeList(value_length) @@ -91,7 +146,7 @@ impl FixedSizeListArray { /// * `size < 0` /// * `values.len() / size != nulls.len()` /// * `values.data_type() != field.data_type()` - /// * `!field.is_nullable() && !nulls.expand(size).contains(values.nulls())` + /// * `!field.is_nullable() && !nulls.expand(size).contains(values.logical_nulls())` pub fn try_new( field: FieldRef, size: i32, @@ -105,7 +160,7 @@ impl FixedSizeListArray { )) })?; - let len = values.len() / s; + let len = values.len() / s.max(1); if let Some(n) = nulls.as_ref() { if n.len() != len { return Err(ArrowError::InvalidArgumentError(format!( @@ -125,11 +180,11 @@ impl FixedSizeListArray { ))); } - if let Some(a) = values.nulls() { + if let Some(a) = values.logical_nulls() { let nulls_valid = field.is_nullable() || nulls .as_ref() - .map(|n| n.expand(size as _).contains(a)) + .map(|n| n.expand(size as _).contains(&a)) .unwrap_or_default(); if !nulls_valid { @@ -620,6 +675,9 @@ mod tests { "Invalid argument error: Size cannot be negative, got -1" ); + let list = FixedSizeListArray::new(field.clone(), 0, values.clone(), None); + assert_eq!(list.len(), 6); + let nulls = NullBuffer::new_null(2); let err = FixedSizeListArray::try_new(field, 2, values.clone(), Some(nulls)) .unwrap_err(); diff --git a/arrow-array/src/array/list_array.rs b/arrow-array/src/array/list_array.rs index 0c1fea6f4161..e36d0ac4434f 100644 --- a/arrow-array/src/array/list_array.rs +++ b/arrow-array/src/array/list_array.rs @@ -54,11 +54,76 @@ impl OffsetSizeTrait for i64 { const PREFIX: &'static str = "Large"; } -/// An array of [variable length arrays](https://arrow.apache.org/docs/format/Columnar.html#variable-size-list-layout) +/// An array of [variable length lists], similar to JSON arrays +/// (e.g. `["A", "B", "C"]`). /// -/// See [`ListArray`] and [`LargeListArray`]` +/// Lists are represented using `offsets` into a `values` child +/// array. Offsets are stored in two adjacent entries of an +/// [`OffsetBuffer`]. /// -/// See [`GenericListBuilder`](crate::builder::GenericListBuilder) for how to construct a [`GenericListArray`] +/// Arrow defines [`ListArray`] with `i32` offsets and +/// [`LargeListArray`] with `i64` offsets. +/// +/// Use [`GenericListBuilder`] to construct a [`GenericListArray`]. +/// +/// # Representation +/// +/// A [`ListArray`] can represent a list of values of any other +/// supported Arrow type. Each element of the `ListArray` itself is +/// a list which may be empty, may contain NULL and non-null values, +/// or may itself be NULL. +/// +/// For example, the `ListArray` shown in the following diagram stores +/// lists of strings. 
Note that `[]` represents an empty (length +/// 0), but non NULL list. +/// +/// ```text +/// ┌─────────────┐ +/// │ [A,B,C] │ +/// ├─────────────┤ +/// │ [] │ +/// ├─────────────┤ +/// │ NULL │ +/// ├─────────────┤ +/// │ [D] │ +/// ├─────────────┤ +/// │ [NULL, F] │ +/// └─────────────┘ +/// ``` +/// +/// The `values` are stored in a child [`StringArray`] and the offsets +/// are stored in an [`OffsetBuffer`] as shown in the following +/// diagram. The logical values and offsets are shown on the left, and +/// the actual `ListArray` encoding on the right. +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// ┌ ─ ─ ─ ─ ─ ─ ┐ │ +/// ┌─────────────┐ ┌───────┐ │ ┌───┐ ┌───┐ ┌───┐ ┌───┐ +/// │ [A,B,C] │ │ (0,3) │ │ 1 │ │ 0 │ │ │ 1 │ │ A │ │ 0 │ +/// ├─────────────┤ ├───────┤ │ ├───┤ ├───┤ ├───┤ ├───┤ +/// │ [] │ │ (3,3) │ │ 1 │ │ 3 │ │ │ 1 │ │ B │ │ 1 │ +/// ├─────────────┤ ├───────┤ │ ├───┤ ├───┤ ├───┤ ├───┤ +/// │ NULL │ │ (3,4) │ │ 0 │ │ 3 │ │ │ 1 │ │ C │ │ 2 │ +/// ├─────────────┤ ├───────┤ │ ├───┤ ├───┤ ├───┤ ├───┤ +/// │ [D] │ │ (4,5) │ │ 1 │ │ 4 │ │ │ ? │ │ ? │ │ 3 │ +/// ├─────────────┤ ├───────┤ │ ├───┤ ├───┤ ├───┤ ├───┤ +/// │ [NULL, F] │ │ (5,7) │ │ 1 │ │ 5 │ │ │ 1 │ │ D │ │ 4 │ +/// └─────────────┘ └───────┘ │ └───┘ ├───┤ ├───┤ ├───┤ +/// │ 7 │ │ │ 0 │ │ ? │ │ 5 │ +/// │ Validity └───┘ ├───┤ ├───┤ +/// Logical Logical (nulls) Offsets │ │ 1 │ │ F │ │ 6 │ +/// Values Offsets │ └───┘ └───┘ +/// │ Values │ │ +/// (offsets[i], │ ListArray (Array) +/// offsets[i+1]) └ ─ ─ ─ ─ ─ ─ ┘ │ +/// └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// +/// +/// ``` +/// +/// [`StringArray`]: crate::array::StringArray +/// [variable length lists]: https://arrow.apache.org/docs/format/Columnar.html#variable-size-list-layout pub struct GenericListArray { data_type: DataType, nulls: Option, @@ -95,7 +160,7 @@ impl GenericListArray { /// /// * `offsets.len() - 1 != nulls.len()` /// * `offsets.last() > values.len()` - /// * `!field.is_nullable() && values.null_count() != 0` + /// * `!field.is_nullable() && values.is_nullable()` /// * `field.data_type() != values.data_type()` pub fn try_new( field: FieldRef, @@ -123,7 +188,7 @@ impl GenericListArray { ))); } } - if !field.is_nullable() && values.null_count() != 0 { + if !field.is_nullable() && values.is_nullable() { return Err(ArrowError::InvalidArgumentError(format!( "Non-nullable field of {}ListArray {:?} cannot contain nulls", OffsetSize::PREFIX, @@ -971,13 +1036,17 @@ mod tests { #[should_panic( expected = "Memory pointer is not aligned with the specified scalar type" )] + // Different error messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] fn test_primitive_array_alignment() { let buf = Buffer::from_slice_ref([0_u64]); let buf2 = buf.slice(1); - let array_data = ArrayData::builder(DataType::Int32) - .add_buffer(buf2) - .build() - .unwrap(); + let array_data = unsafe { + ArrayData::builder(DataType::Int32) + .add_buffer(buf2) + .build_unchecked() + }; drop(Int32Array::from(array_data)); } diff --git a/arrow-array/src/array/map_array.rs b/arrow-array/src/array/map_array.rs index fca49cd7836f..77a7b9d4d547 100644 --- a/arrow-array/src/array/map_array.rs +++ b/arrow-array/src/array/map_array.rs @@ -330,7 +330,7 @@ impl MapArray { Arc::new(Field::new( "entries", entry_struct.data_type().clone(), - true, + false, )), false, ); @@ -477,7 +477,7 @@ mod tests { Arc::new(Field::new( "entries", entry_struct.data_type().clone(), - true, + false, )), false, ); @@ -523,7 +523,7 @@ 
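Editor note: `GenericListArray::try_new` now rejects a non-nullable child field whenever the values array reports `is_nullable()`, rather than only when physical nulls are present. A sketch of the error path, with illustrative field and offset values:

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, ListArray};
use arrow_buffer::OffsetBuffer;
use arrow_schema::{DataType, Field};

fn main() {
    // Child field is declared non-nullable, but the values contain a null.
    let field = Arc::new(Field::new("item", DataType::Int32, false));
    let offsets = OffsetBuffer::new(vec![0i32, 2, 3].into());
    let values: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));

    let err = ListArray::try_new(field, offsets, values, None).unwrap_err();
    assert!(err.to_string().contains("cannot contain nulls"), "{err}");
}
```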
mod tests { Arc::new(Field::new( "entries", entry_struct.data_type().clone(), - true, + false, )), false, ); @@ -645,7 +645,7 @@ mod tests { Arc::new(Field::new( "entries", entry_struct.data_type().clone(), - true, + false, )), false, ); diff --git a/arrow-array/src/array/mod.rs b/arrow-array/src/array/mod.rs index 9312770644a3..9b66826f7584 100644 --- a/arrow-array/src/array/mod.rs +++ b/arrow-array/src/array/mod.rs @@ -69,7 +69,7 @@ pub use run_array::*; /// An array in the [arrow columnar format](https://arrow.apache.org/docs/format/Columnar.html) pub trait Array: std::fmt::Debug + Send + Sync { - /// Returns the array as [`Any`](std::any::Any) so that it can be + /// Returns the array as [`Any`] so that it can be /// downcasted to a specific implementation. /// /// # Example: @@ -101,7 +101,7 @@ pub trait Array: std::fmt::Debug + Send + Sync { /// Unlike [`Array::to_data`] this consumes self, allowing it avoid unnecessary clones fn into_data(self) -> ArrayData; - /// Returns a reference to the [`DataType`](arrow_schema::DataType) of this array. + /// Returns a reference to the [`DataType`] of this array. /// /// # Example: /// @@ -173,28 +173,56 @@ pub trait Array: std::fmt::Debug + Send + Sync { /// ``` fn offset(&self) -> usize; - /// Returns the null buffers of this array if any + /// Returns the null buffer of this array if any. + /// + /// The null buffer encodes the "physical" nulls of an array. + /// However, some arrays can also encode nullability in their children, for example, + /// [`DictionaryArray::values`] values or [`RunArray::values`], or without a null buffer, + /// such as [`NullArray`]. To determine if each element of such an array is logically null, + /// you can use the slower [`Array::logical_nulls`] to obtain a computed mask . fn nulls(&self) -> Option<&NullBuffer>; - /// Returns whether the element at `index` is null. - /// When using this function on a slice, the index is relative to the slice. + /// Returns a potentially computed [`NullBuffer`] that represent the logical null values of this array, if any. + /// + /// In most cases this will be the same as [`Array::nulls`], except for: + /// + /// * [`DictionaryArray`] where [`DictionaryArray::values`] contains nulls + /// * [`RunArray`] where [`RunArray::values`] contains nulls + /// * [`NullArray`] where all indices are nulls + /// + /// In these cases a logical [`NullBuffer`] will be computed, encoding the logical nullability + /// of these arrays, beyond what is encoded in [`Array::nulls`] + fn logical_nulls(&self) -> Option { + self.nulls().cloned() + } + + /// Returns whether the element at `index` is null according to [`Array::nulls`] + /// + /// Note: For performance reasons, this method returns nullability solely as determined by the + /// null buffer. This difference can lead to surprising results, for example, [`NullArray::is_null`] always + /// returns `false` as the array lacks a null buffer. Similarly [`DictionaryArray`] and [`RunArray`] may + /// encode nullability in their children. See [`Self::logical_nulls`] for more information. /// /// # Example: /// /// ``` - /// use arrow_array::{Array, Int32Array}; + /// use arrow_array::{Array, Int32Array, NullArray}; /// /// let array = Int32Array::from(vec![Some(1), None]); - /// /// assert_eq!(array.is_null(0), false); /// assert_eq!(array.is_null(1), true); + /// + /// // NullArrays do not have a null buffer, and therefore always + /// // return false for is_null. 
+ /// let array = NullArray::new(1); + /// assert_eq!(array.is_null(0), false); /// ``` fn is_null(&self, index: usize) -> bool { self.nulls().map(|n| n.is_null(index)).unwrap_or_default() } - /// Returns whether the element at `index` is not null. - /// When using this function on a slice, the index is relative to the slice. + /// Returns whether the element at `index` is *not* null, the + /// opposite of [`Self::is_null`]. /// /// # Example: /// @@ -210,7 +238,10 @@ pub trait Array: std::fmt::Debug + Send + Sync { !self.is_null(index) } - /// Returns the total number of null values in this array. + /// Returns the total number of physical null values in this array. + /// + /// Note: this method returns the physical null count, i.e. that encoded in [`Array::nulls`], + /// see [`Array::logical_nulls`] for logical nullability /// /// # Example: /// @@ -226,6 +257,19 @@ pub trait Array: std::fmt::Debug + Send + Sync { self.nulls().map(|n| n.null_count()).unwrap_or_default() } + /// Returns `false` if the array is guaranteed to not contain any logical nulls + /// + /// In general this will be equivalent to `Array::null_count() != 0` but may differ in the + /// presence of logical nullability, see [`Array::logical_nulls`]. + /// + /// Implementations will return `true` unless they can cheaply prove no logical nulls + /// are present. For example a [`DictionaryArray`] with nullable values will still return true, + /// even if the nulls present in [`DictionaryArray::values`] are not referenced by any key, + /// and therefore would not appear in [`Array::logical_nulls`]. + fn is_nullable(&self) -> bool { + self.null_count() != 0 + } + /// Returns the total number of bytes of memory pointed to by this array. /// The buffers store bytes in the Arrow memory format, and include the data as well as the validity map. 
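Since the distinction between `Array::nulls` and `Array::logical_nulls` drawn above is easy to miss, the following sketch (an illustration, assuming `DictionaryArray::new(keys, values)` and the dictionary behaviour described in the new documentation) shows an array whose physical and logical nullability differ:

```rust
use std::sync::Arc;
use arrow_array::{Array, DictionaryArray, Int32Array, StringArray};
use arrow_array::types::Int32Type;

// The dictionary values contain a null; the keys themselves have no null buffer
let values = StringArray::from(vec![Some("a"), None]);
let keys = Int32Array::from(vec![0, 1, 0]);
let dict: DictionaryArray<Int32Type> = DictionaryArray::new(keys, Arc::new(values));

// Physically there are no nulls...
assert_eq!(dict.null_count(), 0);
assert!(!dict.is_null(1));

// ...but logically the second element is null, because key 1 points at a null value
let logical = dict.logical_nulls().unwrap();
assert_eq!(logical.null_count(), 1);
assert!(logical.is_null(1));
```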
fn get_buffer_memory_size(&self) -> usize; @@ -277,6 +321,10 @@ impl Array for ArrayRef { self.as_ref().nulls() } + fn logical_nulls(&self) -> Option { + self.as_ref().logical_nulls() + } + fn is_null(&self, index: usize) -> bool { self.as_ref().is_null(index) } @@ -289,6 +337,10 @@ impl Array for ArrayRef { self.as_ref().null_count() } + fn is_nullable(&self) -> bool { + self.as_ref().is_nullable() + } + fn get_buffer_memory_size(&self) -> usize { self.as_ref().get_buffer_memory_size() } @@ -335,6 +387,10 @@ impl<'a, T: Array> Array for &'a T { T::nulls(self) } + fn logical_nulls(&self) -> Option { + T::logical_nulls(self) + } + fn is_null(&self, index: usize) -> bool { T::is_null(self, index) } @@ -347,6 +403,10 @@ impl<'a, T: Array> Array for &'a T { T::null_count(self) } + fn is_nullable(&self) -> bool { + T::is_nullable(self) + } + fn get_buffer_memory_size(&self) -> usize { T::get_buffer_memory_size(self) } @@ -841,6 +901,8 @@ mod tests { assert_eq!(a.null_count(), 1); assert!(a.is_null(0)) } + + array.to_data().validate_full().unwrap(); } } diff --git a/arrow-array/src/array/null_array.rs b/arrow-array/src/array/null_array.rs index c054c890431b..af3ec0b57d27 100644 --- a/arrow-array/src/array/null_array.rs +++ b/arrow-array/src/array/null_array.rs @@ -36,8 +36,10 @@ use std::sync::Arc; /// /// let array = NullArray::new(10); /// +/// assert!(array.is_nullable()); /// assert_eq!(array.len(), 10); -/// assert_eq!(array.null_count(), 10); +/// assert_eq!(array.null_count(), 0); +/// assert_eq!(array.logical_nulls().unwrap().null_count(), 10); /// ``` #[derive(Clone)] pub struct NullArray { @@ -107,22 +109,12 @@ impl Array for NullArray { None } - /// Returns whether the element at `index` is null. - /// All elements of a `NullArray` are always null. - fn is_null(&self, _index: usize) -> bool { - true + fn logical_nulls(&self) -> Option { + (self.len != 0).then(|| NullBuffer::new_null(self.len)) } - /// Returns whether the element at `index` is valid. - /// All elements of a `NullArray` are always invalid. - fn is_valid(&self, _index: usize) -> bool { - false - } - - /// Returns the total number of null values in this array. - /// The null count of a `NullArray` always equals its length. 
- fn null_count(&self) -> usize { - self.len() + fn is_nullable(&self) -> bool { + !self.is_empty() } fn get_buffer_memory_size(&self) -> usize { @@ -176,8 +168,10 @@ mod tests { let null_arr = NullArray::new(32); assert_eq!(null_arr.len(), 32); - assert_eq!(null_arr.null_count(), 32); - assert!(!null_arr.is_valid(0)); + assert_eq!(null_arr.null_count(), 0); + assert_eq!(null_arr.logical_nulls().unwrap().null_count(), 32); + assert!(null_arr.is_valid(0)); + assert!(null_arr.is_nullable()); } #[test] @@ -186,7 +180,10 @@ mod tests { let array2 = array1.slice(8, 16); assert_eq!(array2.len(), 16); - assert_eq!(array2.null_count(), 16); + assert_eq!(array2.null_count(), 0); + assert_eq!(array2.logical_nulls().unwrap().null_count(), 16); + assert!(array2.is_valid(0)); + assert!(array2.is_nullable()); } #[test] diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index 576f645b0375..4c07e81468aa 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -24,7 +24,7 @@ use crate::temporal_conversions::{ use crate::timezone::Tz; use crate::trusted_len::trusted_len_unzip; use crate::types::*; -use crate::{Array, ArrayAccessor, ArrayRef}; +use crate::{Array, ArrayAccessor, ArrayRef, Scalar}; use arrow_buffer::{i256, ArrowNativeType, Buffer, NullBuffer, ScalarBuffer}; use arrow_data::bit_iterator::try_for_each_valid_idx; use arrow_data::{ArrayData, ArrayDataBuilder}; @@ -517,6 +517,15 @@ impl PrimitiveArray { Self::try_new(values, nulls).unwrap() } + /// Create a new [`PrimitiveArray`] of the given length where all values are null + pub fn new_null(length: usize) -> Self { + Self { + data_type: T::DATA_TYPE, + values: vec![T::Native::usize_as(0); length].into(), + nulls: Some(NullBuffer::new_null(length)), + } + } + /// Create a new [`PrimitiveArray`] from the provided values and nulls /// /// # Errors @@ -544,6 +553,15 @@ impl PrimitiveArray { }) } + /// Create a new [`Scalar`] from `value` + pub fn new_scalar(value: T::Native) -> Scalar { + Scalar::new(Self { + data_type: T::DATA_TYPE, + values: vec![value].into(), + nulls: None, + }) + } + /// Deconstruct this array into its constituent parts pub fn into_parts(self) -> (DataType, ScalarBuffer, Option) { (self.data_type, self.values, self.nulls) @@ -1562,7 +1580,7 @@ mod tests { assert_eq!(3, arr.len()); assert_eq!(0, arr.offset()); assert_eq!(0, arr.null_count()); - let formatted = vec!["00:00:00.001", "10:30:00.005", "23:59:59.210"]; + let formatted = ["00:00:00.001", "10:30:00.005", "23:59:59.210"]; for (i, formatted) in formatted.iter().enumerate().take(3) { // check that we can't create dates or datetimes from time instances assert_eq!(None, arr.value_as_datetime(i)); @@ -1586,7 +1604,7 @@ mod tests { assert_eq!(3, arr.len()); assert_eq!(0, arr.offset()); assert_eq!(0, arr.null_count()); - let formatted = vec!["00:00:00.001", "10:30:00.005", "23:59:59.210"]; + let formatted = ["00:00:00.001", "10:30:00.005", "23:59:59.210"]; for (i, item) in formatted.iter().enumerate().take(3) { // check that we can't create dates or datetimes from time instances assert_eq!(None, arr.value_as_datetime(i)); @@ -2201,7 +2219,7 @@ mod tests { #[test] fn test_decimal_from_iter_values() { - let array = Decimal128Array::from_iter_values(vec![-100, 0, 101].into_iter()); + let array = Decimal128Array::from_iter_values(vec![-100, 0, 101]); assert_eq!(array.len(), 3); assert_eq!(array.data_type(), &DataType::Decimal128(38, 10)); assert_eq!(-100_i128, array.value(0)); @@ -2401,8 
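As a quick illustration of the `PrimitiveArray::new_null` constructor added above (a sketch, not taken from the patch):

```rust
use arrow_array::{Array, Int32Array};

// All elements are null; the backing values are zero-initialized
let a = Int32Array::new_null(3);
assert_eq!(a.len(), 3);
assert_eq!(a.null_count(), 3);
assert!(a.is_null(0));
assert!(a.values().iter().all(|v| *v == 0));
```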
+2419,7 @@ mod tests { expected = "Trying to access an element at index 4 from a PrimitiveArray of length 3" )] fn test_fixed_size_binary_array_get_value_index_out_of_bound() { - let array = Decimal128Array::from_iter_values(vec![-100, 0, 101].into_iter()); - + let array = Decimal128Array::from(vec![-100, 0, 101]); array.value(4); } diff --git a/arrow-array/src/array/run_array.rs b/arrow-array/src/array/run_array.rs index 820d5c9ebfc1..ba6986c28463 100644 --- a/arrow-array/src/array/run_array.rs +++ b/arrow-array/src/array/run_array.rs @@ -18,7 +18,7 @@ use std::any::Any; use std::sync::Arc; -use arrow_buffer::{ArrowNativeType, NullBuffer, RunEndBuffer}; +use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer, RunEndBuffer}; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType, Field}; @@ -349,6 +349,43 @@ impl Array for RunArray { None } + fn logical_nulls(&self) -> Option { + let len = self.len(); + let nulls = self.values.logical_nulls()?; + let mut out = BooleanBufferBuilder::new(len); + let offset = self.run_ends.offset(); + let mut valid_start = 0; + let mut last_end = 0; + for (idx, end) in self.run_ends.values().iter().enumerate() { + let end = end.as_usize(); + if end < offset { + continue; + } + let end = (end - offset).min(len); + if nulls.is_null(idx) { + if valid_start < last_end { + out.append_n(last_end - valid_start, true); + } + out.append_n(end - last_end, false); + valid_start = end; + } + last_end = end; + if end == len { + break; + } + } + if valid_start < len { + out.append_n(len - valid_start, true) + } + // Sanity check + assert_eq!(out.len(), len); + Some(out.finish().into()) + } + + fn is_nullable(&self) -> bool { + !self.is_empty() && self.values.is_nullable() + } + fn get_buffer_memory_size(&self) -> usize { self.run_ends.inner().inner().capacity() + self.values.get_buffer_memory_size() } @@ -500,10 +537,7 @@ pub struct TypedRunArray<'a, R: RunEndIndexType, V> { // Manually implement `Clone` to avoid `V: Clone` type constraint impl<'a, R: RunEndIndexType, V> Clone for TypedRunArray<'a, R, V> { fn clone(&self) -> Self { - Self { - run_array: self.run_array, - values: self.values, - } + *self } } @@ -569,6 +603,14 @@ impl<'a, R: RunEndIndexType, V: Sync> Array for TypedRunArray<'a, R, V> { self.run_array.nulls() } + fn logical_nulls(&self) -> Option { + self.run_array.logical_nulls() + } + + fn is_nullable(&self) -> bool { + self.run_array.is_nullable() + } + fn get_buffer_memory_size(&self) -> usize { self.run_array.get_buffer_memory_size() } @@ -1041,4 +1083,26 @@ mod tests { ); } } + + #[test] + fn test_logical_nulls() { + let run = Int32Array::from(vec![3, 6, 9, 12]); + let values = Int32Array::from(vec![Some(0), None, Some(1), None]); + let array = RunArray::try_new(&run, &values).unwrap(); + + let expected = [ + true, true, true, false, false, false, true, true, true, false, false, false, + ]; + + let n = array.logical_nulls().unwrap(); + assert_eq!(n.null_count(), 6); + + let slices = [(0, 12), (0, 2), (2, 5), (3, 0), (3, 3), (3, 4), (4, 8)]; + for (offset, length) in slices { + let a = array.slice(offset, length); + let n = a.logical_nulls().unwrap(); + let n = n.into_iter().collect::>(); + assert_eq!(&n, &expected[offset..offset + length], "{offset} {length}"); + } + } } diff --git a/arrow-array/src/array/string_array.rs b/arrow-array/src/array/string_array.rs index f9a3a5fbd095..cac4651f4496 100644 --- a/arrow-array/src/array/string_array.rs +++ b/arrow-array/src/array/string_array.rs @@ -17,8 +17,6 @@ use 
crate::types::GenericStringType; use crate::{GenericBinaryArray, GenericByteArray, GenericListArray, OffsetSizeTrait}; -use arrow_buffer::MutableBuffer; -use arrow_data::ArrayData; use arrow_schema::{ArrowError, DataType}; /// A [`GenericByteArray`] for storing `str` @@ -40,42 +38,6 @@ impl GenericStringArray { self.value(i).chars().count() } - /// Creates a [`GenericStringArray`] based on an iterator of values without nulls - pub fn from_iter_values(iter: I) -> Self - where - Ptr: AsRef, - I: IntoIterator, - { - let iter = iter.into_iter(); - let (_, data_len) = iter.size_hint(); - let data_len = data_len.expect("Iterator must be sized"); // panic if no upper bound. - - let mut offsets = - MutableBuffer::new((data_len + 1) * std::mem::size_of::()); - let mut values = MutableBuffer::new(0); - - let mut length_so_far = OffsetSize::zero(); - offsets.push(length_so_far); - - for i in iter { - let s = i.as_ref(); - length_so_far += OffsetSize::from_usize(s.len()).unwrap(); - offsets.push(length_so_far); - values.extend_from_slice(s.as_bytes()); - } - - // iterator size hint may not be correct so compute the actual number of offsets - assert!(!offsets.is_empty()); // wrote at least one - let actual_len = (offsets.len() / std::mem::size_of::()) - 1; - - let array_data = ArrayData::builder(Self::DATA_TYPE) - .len(actual_len) - .add_buffer(offsets.into()) - .add_buffer(values.into()); - let array_data = unsafe { array_data.build_unchecked() }; - Self::from(array_data) - } - /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` pub fn take_iter<'a>( &'a self, @@ -192,7 +154,7 @@ pub type StringArray = GenericStringArray; /// let arr: LargeStringArray = std::iter::repeat(Some("foo")).take(10).collect(); /// ``` /// -/// Constructon and Access +/// Construction and Access /// /// ``` /// use arrow_array::LargeStringArray; @@ -210,6 +172,7 @@ mod tests { use crate::types::UInt8Type; use crate::Array; use arrow_buffer::Buffer; + use arrow_data::ArrayData; use arrow_schema::Field; use std::sync::Arc; @@ -361,14 +324,14 @@ mod tests { #[test] fn test_string_array_from_iter_values() { - let data = vec!["hello", "hello2"]; + let data = ["hello", "hello2"]; let array1 = StringArray::from_iter_values(data.iter()); assert_eq!(array1.value(0), "hello"); assert_eq!(array1.value(1), "hello2"); // Also works with String types. 
- let data2: Vec = vec!["goodbye".into(), "goodbye2".into()]; + let data2 = ["goodbye".to_string(), "goodbye2".to_string()]; let array2 = StringArray::from_iter_values(data2.iter()); assert_eq!(array2.value(0), "goodbye"); diff --git a/arrow-array/src/array/struct_array.rs b/arrow-array/src/array/struct_array.rs index 1a79ebd95f37..0e586ed1ef96 100644 --- a/arrow-array/src/array/struct_array.rs +++ b/arrow-array/src/array/struct_array.rs @@ -143,15 +143,14 @@ impl StructArray { ))); } - if let Some(a) = a.nulls() { - let nulls_valid = f.is_nullable() - || nulls.as_ref().map(|n| n.contains(a)).unwrap_or_default(); - - if !nulls_valid { - return Err(ArrowError::InvalidArgumentError(format!( - "Found unmasked nulls for non-nullable StructArray field {:?}", - f.name() - ))); + if !f.is_nullable() { + if let Some(a) = a.logical_nulls() { + if !nulls.as_ref().map(|n| n.contains(&a)).unwrap_or_default() { + return Err(ArrowError::InvalidArgumentError(format!( + "Found unmasked nulls for non-nullable StructArray field {:?}", + f.name() + ))); + } } } } @@ -198,6 +197,23 @@ impl StructArray { } } + /// Create a new [`StructArray`] containing no fields + /// + /// # Panics + /// + /// If `len != nulls.len()` + pub fn new_empty_fields(len: usize, nulls: Option) -> Self { + if let Some(n) = &nulls { + assert_eq!(len, n.len()) + } + Self { + len, + data_type: DataType::Struct(Fields::empty()), + fields: vec![], + nulls, + } + } + /// Deconstruct this array into its constituent parts pub fn into_parts(self) -> (Fields, Vec, Option) { let f = match self.data_type { @@ -314,7 +330,7 @@ impl TryFrom> for StructArray { .into_iter() .map(|(name, array)| { ( - Field::new(name, array.data_type().clone(), array.nulls().is_some()), + Field::new(name, array.data_type().clone(), array.is_nullable()), array, ) }) diff --git a/arrow-array/src/builder/boolean_builder.rs b/arrow-array/src/builder/boolean_builder.rs index 0def0ec48e3b..5f0013269677 100644 --- a/arrow-array/src/builder/boolean_builder.rs +++ b/arrow-array/src/builder/boolean_builder.rs @@ -169,6 +169,11 @@ impl BooleanBuilder { let array_data = unsafe { builder.build_unchecked() }; BooleanArray::from(array_data) } + + /// Returns the current null buffer as a slice + pub fn validity_slice(&self) -> Option<&[u8]> { + self.null_buffer_builder.as_slice() + } } impl ArrayBuilder for BooleanBuilder { @@ -192,11 +197,6 @@ impl ArrayBuilder for BooleanBuilder { self.values_builder.len() } - /// Returns whether the number of array slots is zero - fn is_empty(&self) -> bool { - self.values_builder.is_empty() - } - /// Builds the array and reset this builder. fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) diff --git a/arrow-array/src/builder/buffer_builder.rs b/arrow-array/src/builder/buffer_builder.rs index f88a6392083e..01e4c1d4e217 100644 --- a/arrow-array/src/builder/buffer_builder.rs +++ b/arrow-array/src/builder/buffer_builder.rs @@ -16,9 +16,8 @@ // under the License. use crate::array::ArrowPrimitiveType; -use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer}; +pub use arrow_buffer::BufferBuilder; use half::f16; -use std::marker::PhantomData; use crate::types::*; @@ -73,7 +72,7 @@ pub type Date64BufferBuilder = BufferBuilder<: /// Buffer builder for 32-bit elaspsed time since midnight of second unit. pub type Time32SecondBufferBuilder = BufferBuilder<::Native>; -/// Buffer builder for 32-bit elaspsed time since midnight of millisecond unit. +/// Buffer builder for 32-bit elaspsed time since midnight of millisecond unit. 
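The relaxed check above validates non-nullable struct fields against `Array::logical_nulls` rather than the raw null buffer. A minimal sketch of the behaviour, assuming the check lives in `StructArray::try_new(fields, arrays, nulls)` (an assumption about the constructor's signature):

```rust
use std::sync::Arc;
use arrow_array::{ArrayRef, Int32Array, StructArray};
use arrow_schema::{DataType, Field, Fields};

let child: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None]));
let fields = Fields::from(vec![Field::new("x", DataType::Int32, false)]);

// A null in the child of a non-nullable field is rejected unless it is
// masked by struct-level validity
assert!(StructArray::try_new(fields, vec![child], None).is_err());

// Declaring the field nullable makes the same child acceptable
let fields_ok = Fields::from(vec![Field::new("x", DataType::Int32, true)]);
let child_ok: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None]));
assert!(StructArray::try_new(fields_ok, vec![child_ok], None).is_ok());
```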
pub type Time32MillisecondBufferBuilder = BufferBuilder<::Native>; /// Buffer builder for 64-bit elaspsed time since midnight of microsecond unit. @@ -106,346 +105,6 @@ pub type DurationMicrosecondBufferBuilder = pub type DurationNanosecondBufferBuilder = BufferBuilder<::Native>; -/// Builder for creating a [`Buffer`](arrow_buffer::Buffer) object. -/// -/// A [`Buffer`](arrow_buffer::Buffer) is the underlying data -/// structure of Arrow's [`Arrays`](crate::Array). -/// -/// For all supported types, there are type definitions for the -/// generic version of `BufferBuilder`, e.g. `UInt8BufferBuilder`. -/// -/// # Example: -/// -/// ``` -/// # use arrow_array::builder::UInt8BufferBuilder; -/// -/// let mut builder = UInt8BufferBuilder::new(100); -/// builder.append_slice(&[42, 43, 44]); -/// builder.append(45); -/// let buffer = builder.finish(); -/// -/// assert_eq!(unsafe { buffer.typed_data::() }, &[42, 43, 44, 45]); -/// ``` -#[derive(Debug)] -pub struct BufferBuilder { - buffer: MutableBuffer, - len: usize, - _marker: PhantomData, -} - -impl BufferBuilder { - /// Creates a new builder with initial capacity for _at least_ `capacity` - /// elements of type `T`. - /// - /// The capacity can later be manually adjusted with the - /// [`reserve()`](BufferBuilder::reserve) method. - /// Also the - /// [`append()`](BufferBuilder::append), - /// [`append_slice()`](BufferBuilder::append_slice) and - /// [`advance()`](BufferBuilder::advance) - /// methods automatically increase the capacity if needed. - /// - /// # Example: - /// - /// ``` - /// # use arrow_array::builder::UInt8BufferBuilder; - /// - /// let mut builder = UInt8BufferBuilder::new(10); - /// - /// assert!(builder.capacity() >= 10); - /// ``` - #[inline] - pub fn new(capacity: usize) -> Self { - let buffer = MutableBuffer::new(capacity * std::mem::size_of::()); - - Self { - buffer, - len: 0, - _marker: PhantomData, - } - } - - /// Creates a new builder from a [`MutableBuffer`] - pub fn new_from_buffer(buffer: MutableBuffer) -> Self { - let buffer_len = buffer.len(); - Self { - buffer, - len: buffer_len / std::mem::size_of::(), - _marker: PhantomData, - } - } - - /// Returns the current number of array elements in the internal buffer. - /// - /// # Example: - /// - /// ``` - /// # use arrow_array::builder::UInt8BufferBuilder; - /// - /// let mut builder = UInt8BufferBuilder::new(10); - /// builder.append(42); - /// - /// assert_eq!(builder.len(), 1); - /// ``` - pub fn len(&self) -> usize { - self.len - } - - /// Returns whether the internal buffer is empty. - /// - /// # Example: - /// - /// ``` - /// # use arrow_array::builder::UInt8BufferBuilder; - /// - /// let mut builder = UInt8BufferBuilder::new(10); - /// builder.append(42); - /// - /// assert_eq!(builder.is_empty(), false); - /// ``` - pub fn is_empty(&self) -> bool { - self.len == 0 - } - - /// Returns the actual capacity (number of elements) of the internal buffer. - /// - /// Note: the internal capacity returned by this method might be larger than - /// what you'd expect after setting the capacity in the `new()` or `reserve()` - /// functions. - pub fn capacity(&self) -> usize { - let byte_capacity = self.buffer.capacity(); - byte_capacity / std::mem::size_of::() - } - - /// Increases the number of elements in the internal buffer by `n` - /// and resizes the buffer as needed. - /// - /// The values of the newly added elements are 0. - /// This method is usually used when appending `NULL` values to the buffer - /// as they still require physical memory space. 
- /// - /// # Example: - /// - /// ``` - /// # use arrow_array::builder::UInt8BufferBuilder; - /// - /// let mut builder = UInt8BufferBuilder::new(10); - /// builder.advance(2); - /// - /// assert_eq!(builder.len(), 2); - /// ``` - #[inline] - pub fn advance(&mut self, i: usize) { - self.buffer.extend_zeros(i * std::mem::size_of::()); - self.len += i; - } - - /// Reserves memory for _at least_ `n` more elements of type `T`. - /// - /// # Example: - /// - /// ``` - /// # use arrow_array::builder::UInt8BufferBuilder; - /// - /// let mut builder = UInt8BufferBuilder::new(10); - /// builder.reserve(10); - /// - /// assert!(builder.capacity() >= 20); - /// ``` - #[inline] - pub fn reserve(&mut self, n: usize) { - self.buffer.reserve(n * std::mem::size_of::()); - } - - /// Appends a value of type `T` into the builder, - /// growing the internal buffer as needed. - /// - /// # Example: - /// - /// ``` - /// # use arrow_array::builder::UInt8BufferBuilder; - /// - /// let mut builder = UInt8BufferBuilder::new(10); - /// builder.append(42); - /// - /// assert_eq!(builder.len(), 1); - /// ``` - #[inline] - pub fn append(&mut self, v: T) { - self.reserve(1); - self.buffer.push(v); - self.len += 1; - } - - /// Appends a value of type `T` into the builder N times, - /// growing the internal buffer as needed. - /// - /// # Example: - /// - /// ``` - /// # use arrow_array::builder::UInt8BufferBuilder; - /// - /// let mut builder = UInt8BufferBuilder::new(10); - /// builder.append_n(10, 42); - /// - /// assert_eq!(builder.len(), 10); - /// ``` - #[inline] - pub fn append_n(&mut self, n: usize, v: T) { - self.reserve(n); - for _ in 0..n { - self.buffer.push(v); - } - self.len += n; - } - - /// Appends `n`, zero-initialized values - /// - /// # Example: - /// - /// ``` - /// # use arrow_array::builder::UInt32BufferBuilder; - /// - /// let mut builder = UInt32BufferBuilder::new(10); - /// builder.append_n_zeroed(3); - /// - /// assert_eq!(builder.len(), 3); - /// assert_eq!(builder.as_slice(), &[0, 0, 0]) - #[inline] - pub fn append_n_zeroed(&mut self, n: usize) { - self.buffer.extend_zeros(n * std::mem::size_of::()); - self.len += n; - } - - /// Appends a slice of type `T`, growing the internal buffer as needed. 
- /// - /// # Example: - /// - /// ``` - /// # use arrow_array::builder::UInt8BufferBuilder; - /// - /// let mut builder = UInt8BufferBuilder::new(10); - /// builder.append_slice(&[42, 44, 46]); - /// - /// assert_eq!(builder.len(), 3); - /// ``` - #[inline] - pub fn append_slice(&mut self, slice: &[T]) { - self.buffer.extend_from_slice(slice); - self.len += slice.len(); - } - - /// View the contents of this buffer as a slice - /// - /// ``` - /// # use arrow_array::builder::Float64BufferBuilder; - /// - /// let mut builder = Float64BufferBuilder::new(10); - /// builder.append(1.3); - /// builder.append_n(2, 2.3); - /// - /// assert_eq!(builder.as_slice(), &[1.3, 2.3, 2.3]); - /// ``` - #[inline] - pub fn as_slice(&self) -> &[T] { - // SAFETY - // - // - MutableBuffer is aligned and initialized for len elements of T - // - MutableBuffer corresponds to a single allocation - // - MutableBuffer does not support modification whilst active immutable borrows - unsafe { std::slice::from_raw_parts(self.buffer.as_ptr() as _, self.len) } - } - - /// View the contents of this buffer as a mutable slice - /// - /// # Example: - /// - /// ``` - /// # use arrow_array::builder::Float32BufferBuilder; - /// - /// let mut builder = Float32BufferBuilder::new(10); - /// - /// builder.append_slice(&[1., 2., 3.4]); - /// assert_eq!(builder.as_slice(), &[1., 2., 3.4]); - /// - /// builder.as_slice_mut()[1] = 4.2; - /// assert_eq!(builder.as_slice(), &[1., 4.2, 3.4]); - /// ``` - #[inline] - pub fn as_slice_mut(&mut self) -> &mut [T] { - // SAFETY - // - // - MutableBuffer is aligned and initialized for len elements of T - // - MutableBuffer corresponds to a single allocation - // - MutableBuffer does not support modification whilst active immutable borrows - unsafe { std::slice::from_raw_parts_mut(self.buffer.as_mut_ptr() as _, self.len) } - } - - /// Shorten this BufferBuilder to `len` items - /// - /// If `len` is greater than the builder's current length, this has no effect - /// - /// # Example: - /// - /// ``` - /// # use arrow_array::builder::UInt16BufferBuilder; - /// - /// let mut builder = UInt16BufferBuilder::new(10); - /// - /// builder.append_slice(&[42, 44, 46]); - /// assert_eq!(builder.as_slice(), &[42, 44, 46]); - /// - /// builder.truncate(2); - /// assert_eq!(builder.as_slice(), &[42, 44]); - /// - /// builder.append(12); - /// assert_eq!(builder.as_slice(), &[42, 44, 12]); - /// ``` - #[inline] - pub fn truncate(&mut self, len: usize) { - self.buffer.truncate(len * std::mem::size_of::()); - self.len = len; - } - - /// # Safety - /// This requires the iterator be a trusted length. This could instead require - /// the iterator implement `TrustedLen` once that is stabilized. - #[inline] - pub unsafe fn append_trusted_len_iter(&mut self, iter: impl IntoIterator) { - let iter = iter.into_iter(); - let len = iter - .size_hint() - .1 - .expect("append_trusted_len_iter expects upper bound"); - self.reserve(len); - for v in iter { - self.buffer.push(v) - } - self.len += len; - } - - /// Resets this builder and returns an immutable [`Buffer`](arrow_buffer::Buffer). 
- /// - /// # Example: - /// - /// ``` - /// # use arrow_array::builder::UInt8BufferBuilder; - /// - /// let mut builder = UInt8BufferBuilder::new(10); - /// builder.append_slice(&[42, 44, 46]); - /// - /// let buffer = builder.finish(); - /// - /// assert_eq!(unsafe { buffer.typed_data::() }, &[42, 44, 46]); - /// ``` - #[inline] - pub fn finish(&mut self) -> Buffer { - let buf = std::mem::replace(&mut self.buffer, MutableBuffer::new(0)); - self.len = 0; - buf.into() - } -} - #[cfg(test)] mod tests { use crate::builder::{ diff --git a/arrow-array/src/builder/fixed_size_binary_builder.rs b/arrow-array/src/builder/fixed_size_binary_builder.rs index a213b3bbf87d..180150e988f3 100644 --- a/arrow-array/src/builder/fixed_size_binary_builder.rs +++ b/arrow-array/src/builder/fixed_size_binary_builder.rs @@ -139,11 +139,6 @@ impl ArrayBuilder for FixedSizeBinaryBuilder { self.null_buffer_builder.len() } - /// Returns whether the number of array slots is zero - fn is_empty(&self) -> bool { - self.null_buffer_builder.is_empty() - } - /// Builds the array and reset this builder. fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) diff --git a/arrow-array/src/builder/fixed_size_list_builder.rs b/arrow-array/src/builder/fixed_size_list_builder.rs index 0dd58044305e..0fe779d5c1a2 100644 --- a/arrow-array/src/builder/fixed_size_list_builder.rs +++ b/arrow-array/src/builder/fixed_size_list_builder.rs @@ -73,7 +73,11 @@ impl FixedSizeListBuilder { /// Creates a new [`FixedSizeListBuilder`] from a given values array builder /// `value_length` is the number of values within each array pub fn new(values_builder: T, value_length: i32) -> Self { - let capacity = values_builder.len(); + let capacity = values_builder + .len() + .checked_div(value_length as _) + .unwrap_or_default(); + Self::with_capacity(values_builder, value_length, capacity) } @@ -113,11 +117,6 @@ where self.null_buffer_builder.len() } - /// Returns whether the number of array slots is zero - fn is_empty(&self) -> bool { - self.null_buffer_builder.is_empty() - } - /// Builds the array and reset this builder. fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) diff --git a/arrow-array/src/builder/generic_byte_run_builder.rs b/arrow-array/src/builder/generic_byte_run_builder.rs index 4e3f36889a1b..41165208de55 100644 --- a/arrow-array/src/builder/generic_byte_run_builder.rs +++ b/arrow-array/src/builder/generic_byte_run_builder.rs @@ -150,11 +150,6 @@ where self.current_run_end_index } - /// Returns whether the number of array slots is zero - fn is_empty(&self) -> bool { - self.current_run_end_index == 0 - } - /// Builds the array and reset this builder. fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) diff --git a/arrow-array/src/builder/generic_bytes_builder.rs b/arrow-array/src/builder/generic_bytes_builder.rs index f77940055bf1..d84be8c2fca6 100644 --- a/arrow-array/src/builder/generic_bytes_builder.rs +++ b/arrow-array/src/builder/generic_bytes_builder.rs @@ -189,11 +189,6 @@ impl ArrayBuilder for GenericByteBuilder { self.null_buffer_builder.len() } - /// Returns whether the number of binary slots is zero - fn is_empty(&self) -> bool { - self.null_buffer_builder.is_empty() - } - /// Builds the array and reset this builder. 
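`BufferBuilder` itself now lives in `arrow_buffer` and is only re-exported here (see the `pub use arrow_buffer::BufferBuilder` above), so removing this implementation does not change the public API. A short sketch mirroring the removed doc example:

```rust
use arrow_array::builder::UInt8BufferBuilder; // re-export of arrow_buffer::BufferBuilder<u8>

let mut builder = UInt8BufferBuilder::new(10);
builder.append_slice(&[42, 43, 44]);
builder.append(45);
assert_eq!(builder.as_slice(), &[42, 43, 44, 45]);

let buffer = builder.finish();
assert_eq!(buffer.len(), 4); // 4 bytes for 4 u8 values
```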
fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) diff --git a/arrow-array/src/builder/generic_bytes_dictionary_builder.rs b/arrow-array/src/builder/generic_bytes_dictionary_builder.rs index d5c62865ff8d..282f423fa6d1 100644 --- a/arrow-array/src/builder/generic_bytes_dictionary_builder.rs +++ b/arrow-array/src/builder/generic_bytes_dictionary_builder.rs @@ -193,11 +193,6 @@ where self.keys_builder.len() } - /// Returns whether the number of array slots is zero - fn is_empty(&self) -> bool { - self.keys_builder.is_empty() - } - /// Builds the array and reset this builder. fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) diff --git a/arrow-array/src/builder/generic_list_builder.rs b/arrow-array/src/builder/generic_list_builder.rs index 99e15d10f3a5..5cc7f7b04e0a 100644 --- a/arrow-array/src/builder/generic_list_builder.rs +++ b/arrow-array/src/builder/generic_list_builder.rs @@ -28,6 +28,60 @@ use std::sync::Arc; /// /// Use [`ListBuilder`] to build [`ListArray`]s and [`LargeListBuilder`] to build [`LargeListArray`]s. /// +/// # Example +/// +/// Here is code that constructs a ListArray with the contents: +/// `[[A,B,C], [], NULL, [D], [NULL, F]]` +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{builder::ListBuilder, builder::StringBuilder, ArrayRef, StringArray, Array}; +/// # +/// let values_builder = StringBuilder::new(); +/// let mut builder = ListBuilder::new(values_builder); +/// +/// // [A, B, C] +/// builder.values().append_value("A"); +/// builder.values().append_value("B"); +/// builder.values().append_value("C"); +/// builder.append(true); +/// +/// // [ ] (empty list) +/// builder.append(true); +/// +/// // Null +/// builder.values().append_value("?"); // irrelevant +/// builder.append(false); +/// +/// // [D] +/// builder.values().append_value("D"); +/// builder.append(true); +/// +/// // [NULL, F] +/// builder.values().append_null(); +/// builder.values().append_value("F"); +/// builder.append(true); +/// +/// // Build the array +/// let array = builder.finish(); +/// +/// // Values is a string array +/// // "A", "B" "C", "?", "D", NULL, "F" +/// assert_eq!( +/// array.values().as_ref(), +/// &StringArray::from(vec![ +/// Some("A"), Some("B"), Some("C"), +/// Some("?"), Some("D"), None, +/// Some("F") +/// ]) +/// ); +/// +/// // Offsets are indexes into the values array +/// assert_eq!( +/// array.value_offsets(), +/// &[0, 3, 3, 4, 5, 7] +/// ); +/// ``` /// /// [`ListBuilder`]: crate::builder::ListBuilder /// [`ListArray`]: crate::array::ListArray @@ -91,11 +145,6 @@ where self.null_buffer_builder.len() } - /// Returns whether the number of array slots is zero - fn is_empty(&self) -> bool { - self.null_buffer_builder.is_empty() - } - /// Builds the array and reset this builder. 
fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) diff --git a/arrow-array/src/builder/map_builder.rs b/arrow-array/src/builder/map_builder.rs index 56b5619ceab1..4e3ec4a7944d 100644 --- a/arrow-array/src/builder/map_builder.rs +++ b/arrow-array/src/builder/map_builder.rs @@ -214,10 +214,6 @@ impl ArrayBuilder for MapBuilder { self.null_buffer_builder.len() } - fn is_empty(&self) -> bool { - self.len() == 0 - } - fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) } diff --git a/arrow-array/src/builder/mod.rs b/arrow-array/src/builder/mod.rs index 1e5e6426be09..38a7500dd55f 100644 --- a/arrow-array/src/builder/mod.rs +++ b/arrow-array/src/builder/mod.rs @@ -237,7 +237,9 @@ pub trait ArrayBuilder: Any + Send { fn len(&self) -> usize; /// Returns whether number of array slots is zero - fn is_empty(&self) -> bool; + fn is_empty(&self) -> bool { + self.len() == 0 + } /// Builds the array fn finish(&mut self) -> ArrayRef; diff --git a/arrow-array/src/builder/null_builder.rs b/arrow-array/src/builder/null_builder.rs index 0b4345006993..53a6b103d541 100644 --- a/arrow-array/src/builder/null_builder.rs +++ b/arrow-array/src/builder/null_builder.rs @@ -40,7 +40,7 @@ use std::sync::Arc; /// let arr = b.finish(); /// /// assert_eq!(8, arr.len()); -/// assert_eq!(8, arr.null_count()); +/// assert_eq!(0, arr.null_count()); /// ``` #[derive(Debug)] pub struct NullBuilder { @@ -133,11 +133,6 @@ impl ArrayBuilder for NullBuilder { self.len } - /// Returns whether the number of array slots is zero - fn is_empty(&self) -> bool { - self.len() == 0 - } - /// Builds the array and reset this builder. fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) @@ -165,7 +160,8 @@ mod tests { let arr = builder.finish(); assert_eq!(20, arr.len()); assert_eq!(0, arr.offset()); - assert_eq!(20, arr.null_count()); + assert_eq!(0, arr.null_count()); + assert!(arr.is_nullable()); } #[test] @@ -175,10 +171,10 @@ mod tests { builder.append_empty_value(); builder.append_empty_values(3); let mut array = builder.finish_cloned(); - assert_eq!(21, array.null_count()); + assert_eq!(21, array.len()); builder.append_empty_values(5); array = builder.finish(); - assert_eq!(26, array.null_count()); + assert_eq!(26, array.len()); } } diff --git a/arrow-array/src/builder/primitive_builder.rs b/arrow-array/src/builder/primitive_builder.rs index 3e31b1d05576..b23d6bba36c4 100644 --- a/arrow-array/src/builder/primitive_builder.rs +++ b/arrow-array/src/builder/primitive_builder.rs @@ -121,11 +121,6 @@ impl ArrayBuilder for PrimitiveBuilder { self.values_builder.len() } - /// Returns whether the number of array slots is zero - fn is_empty(&self) -> bool { - self.values_builder.is_empty() - } - /// Builds the array and reset this builder. fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) diff --git a/arrow-array/src/builder/primitive_dictionary_builder.rs b/arrow-array/src/builder/primitive_dictionary_builder.rs index cde1abe22b7b..7323ee57627d 100644 --- a/arrow-array/src/builder/primitive_dictionary_builder.rs +++ b/arrow-array/src/builder/primitive_dictionary_builder.rs @@ -194,11 +194,6 @@ where self.keys_builder.len() } - /// Returns whether the number of array slots is zero - fn is_empty(&self) -> bool { - self.keys_builder.is_empty() - } - /// Builds the array and reset this builder. 
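With the default `is_empty` added to `ArrayBuilder` above, the individual builders no longer need their own implementations and usage is unchanged. A small sketch:

```rust
use arrow_array::builder::{ArrayBuilder, Int32Builder};

let mut b = Int32Builder::new();
assert!(b.is_empty()); // provided by the new default implementation
b.append_value(1);
assert_eq!(b.len(), 1);
assert!(!b.is_empty());
```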
fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) diff --git a/arrow-array/src/builder/primitive_run_builder.rs b/arrow-array/src/builder/primitive_run_builder.rs index 53674a73b172..01a989199b58 100644 --- a/arrow-array/src/builder/primitive_run_builder.rs +++ b/arrow-array/src/builder/primitive_run_builder.rs @@ -136,11 +136,6 @@ where self.current_run_end_index } - /// Returns whether the number of array slots is zero - fn is_empty(&self) -> bool { - self.current_run_end_index == 0 - } - /// Builds the array and reset this builder. fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) diff --git a/arrow-array/src/builder/struct_builder.rs b/arrow-array/src/builder/struct_builder.rs index 88a23db6d10e..7aa91dacaa8c 100644 --- a/arrow-array/src/builder/struct_builder.rs +++ b/arrow-array/src/builder/struct_builder.rs @@ -52,11 +52,6 @@ impl ArrayBuilder for StructBuilder { self.null_buffer_builder.len() } - /// Returns whether the number of array slots is zero - fn is_empty(&self) -> bool { - self.len() == 0 - } - /// Builds the array. fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) @@ -238,6 +233,12 @@ impl StructBuilder { /// Builds the `StructArray` and reset this builder. pub fn finish(&mut self) -> StructArray { self.validate_content(); + if self.fields.is_empty() { + return StructArray::new_empty_fields( + self.len(), + self.null_buffer_builder.finish(), + ); + } let arrays = self.field_builders.iter_mut().map(|f| f.finish()).collect(); let nulls = self.null_buffer_builder.finish(); @@ -248,6 +249,13 @@ impl StructBuilder { pub fn finish_cloned(&self) -> StructArray { self.validate_content(); + if self.fields.is_empty() { + return StructArray::new_empty_fields( + self.len(), + self.null_buffer_builder.finish_cloned(), + ); + } + let arrays = self .field_builders .iter() @@ -596,4 +604,19 @@ mod tests { let mut sa = StructBuilder::new(fields, field_builders); sa.finish(); } + + #[test] + fn test_empty() { + let mut builder = StructBuilder::new(Fields::empty(), vec![]); + builder.append(true); + builder.append(false); + + let a1 = builder.finish_cloned(); + let a2 = builder.finish(); + assert_eq!(a1, a2); + assert_eq!(a1.len(), 2); + assert_eq!(a1.null_count(), 1); + assert!(a1.is_valid(0)); + assert!(a1.is_null(1)); + } } diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs index bee8823d1f59..b6cda44e8973 100644 --- a/arrow-array/src/cast.rs +++ b/arrow-array/src/cast.rs @@ -799,6 +799,15 @@ pub trait AsArray: private::Sealed { self.as_list_opt().expect("list array") } + /// Downcast this to a [`FixedSizeBinaryArray`] returning `None` if not possible + fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray>; + + /// Downcast this to a [`FixedSizeBinaryArray`] panicking if not possible + fn as_fixed_size_binary(&self) -> &FixedSizeBinaryArray { + self.as_fixed_size_binary_opt() + .expect("fixed size binary array") + } + /// Downcast this to a [`FixedSizeListArray`] returning `None` if not possible fn as_fixed_size_list_opt(&self) -> Option<&FixedSizeListArray>; @@ -824,6 +833,14 @@ pub trait AsArray: private::Sealed { fn as_dictionary(&self) -> &DictionaryArray { self.as_dictionary_opt().expect("dictionary array") } + + /// Downcasts this to a [`AnyDictionaryArray`] returning `None` if not possible + fn as_any_dictionary_opt(&self) -> Option<&dyn AnyDictionaryArray>; + + /// Downcasts this to a [`AnyDictionaryArray`] panicking if not possible + fn as_any_dictionary(&self) -> &dyn AnyDictionaryArray { + 
self.as_any_dictionary_opt().expect("any dictionary array") + } } impl private::Sealed for dyn Array + '_ {} @@ -848,6 +865,10 @@ impl AsArray for dyn Array + '_ { self.as_any().downcast_ref() } + fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray> { + self.as_any().downcast_ref() + } + fn as_fixed_size_list_opt(&self) -> Option<&FixedSizeListArray> { self.as_any().downcast_ref() } @@ -861,6 +882,14 @@ impl AsArray for dyn Array + '_ { ) -> Option<&DictionaryArray> { self.as_any().downcast_ref() } + + fn as_any_dictionary_opt(&self) -> Option<&dyn AnyDictionaryArray> { + let array = self; + downcast_dictionary_array! { + array => Some(array), + _ => None + } + } } impl private::Sealed for ArrayRef {} @@ -885,6 +914,10 @@ impl AsArray for ArrayRef { self.as_ref().as_list_opt() } + fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray> { + self.as_ref().as_fixed_size_binary_opt() + } + fn as_fixed_size_list_opt(&self) -> Option<&FixedSizeListArray> { self.as_ref().as_fixed_size_list_opt() } @@ -898,6 +931,10 @@ impl AsArray for ArrayRef { ) -> Option<&DictionaryArray> { self.as_ref().as_dictionary_opt() } + + fn as_any_dictionary_opt(&self) -> Option<&dyn AnyDictionaryArray> { + self.as_ref().as_any_dictionary_opt() + } } #[cfg(test)] diff --git a/arrow-array/src/delta.rs b/arrow-array/src/delta.rs index 029168242b90..bf9ee5ca685f 100644 --- a/arrow-array/src/delta.rs +++ b/arrow-array/src/delta.rs @@ -23,22 +23,74 @@ // Copied from chronoutil crate //! Contains utility functions for shifting Date objects. -use chrono::{Datelike, Months}; +use chrono::{DateTime, Datelike, Days, Months, TimeZone}; use std::cmp::Ordering; /// Shift a date by the given number of months. -pub(crate) fn shift_months< - D: Datelike - + std::ops::Add - + std::ops::Sub, ->( - date: D, - months: i32, -) -> D { +pub(crate) fn shift_months(date: D, months: i32) -> D +where + D: Datelike + std::ops::Add + std::ops::Sub, +{ match months.cmp(&0) { Ordering::Equal => date, Ordering::Greater => date + Months::new(months as u32), - Ordering::Less => date - Months::new(-months as u32), + Ordering::Less => date - Months::new(months.unsigned_abs()), + } +} + +/// Add the given number of months to the given datetime. +/// +/// Returns `None` when it will result in overflow. +pub(crate) fn add_months_datetime( + dt: DateTime, + months: i32, +) -> Option> { + match months.cmp(&0) { + Ordering::Equal => Some(dt), + Ordering::Greater => dt.checked_add_months(Months::new(months as u32)), + Ordering::Less => dt.checked_sub_months(Months::new(months.unsigned_abs())), + } +} + +/// Add the given number of days to the given datetime. +/// +/// Returns `None` when it will result in overflow. +pub(crate) fn add_days_datetime( + dt: DateTime, + days: i32, +) -> Option> { + match days.cmp(&0) { + Ordering::Equal => Some(dt), + Ordering::Greater => dt.checked_add_days(Days::new(days as u64)), + Ordering::Less => dt.checked_sub_days(Days::new(days.unsigned_abs() as u64)), + } +} + +/// Substract the given number of months to the given datetime. +/// +/// Returns `None` when it will result in overflow. +pub(crate) fn sub_months_datetime( + dt: DateTime, + months: i32, +) -> Option> { + match months.cmp(&0) { + Ordering::Equal => Some(dt), + Ordering::Greater => dt.checked_sub_months(Months::new(months as u32)), + Ordering::Less => dt.checked_add_months(Months::new(months.unsigned_abs())), + } +} + +/// Substract the given number of days to the given datetime. +/// +/// Returns `None` when it will result in overflow. 
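The new `AsArray::as_fixed_size_binary` and `as_any_dictionary` helpers follow the same pattern as the existing downcast accessors. A sketch of the fixed-size-binary case; building the input with `FixedSizeBinaryArray::try_from_iter` is an assumption:

```rust
use std::sync::Arc;
use arrow_array::{cast::AsArray, Array, ArrayRef, FixedSizeBinaryArray};

let array: ArrayRef = Arc::new(
    FixedSizeBinaryArray::try_from_iter(vec![vec![0u8, 1], vec![2, 3]].into_iter()).unwrap(),
);

// Downcast through the AsArray extension trait instead of as_any().downcast_ref()
let fsb = array.as_fixed_size_binary();
assert_eq!(fsb.len(), 2);
assert_eq!(fsb.value_length(), 2);
```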
+pub(crate) fn sub_days_datetime( + dt: DateTime, + days: i32, +) -> Option> { + match days.cmp(&0) { + Ordering::Equal => Some(dt), + Ordering::Greater => dt.checked_sub_days(Days::new(days as u64)), + Ordering::Less => dt.checked_add_days(Days::new(days.unsigned_abs() as u64)), } } diff --git a/arrow-array/src/iterator.rs b/arrow-array/src/iterator.rs index 86f5d991288a..a198332ca5b5 100644 --- a/arrow-array/src/iterator.rs +++ b/arrow-array/src/iterator.rs @@ -22,6 +22,7 @@ use crate::array::{ GenericListArray, GenericStringArray, PrimitiveArray, }; use crate::{FixedSizeListArray, MapArray}; +use arrow_buffer::NullBuffer; /// An iterator that returns Some(T) or None, that can be used on any [`ArrayAccessor`] /// @@ -46,6 +47,7 @@ use crate::{FixedSizeListArray, MapArray}; #[derive(Debug)] pub struct ArrayIter { array: T, + logical_nulls: Option, current: usize, current_end: usize, } @@ -54,12 +56,22 @@ impl ArrayIter { /// create a new iterator pub fn new(array: T) -> Self { let len = array.len(); + let logical_nulls = array.logical_nulls(); ArrayIter { array, + logical_nulls, current: 0, current_end: len, } } + + #[inline] + fn is_null(&self, idx: usize) -> bool { + self.logical_nulls + .as_ref() + .map(|x| x.is_null(idx)) + .unwrap_or_default() + } } impl Iterator for ArrayIter { @@ -69,7 +81,7 @@ impl Iterator for ArrayIter { fn next(&mut self) -> Option { if self.current == self.current_end { None - } else if self.array.is_null(self.current) { + } else if self.is_null(self.current) { self.current += 1; Some(None) } else { @@ -98,7 +110,7 @@ impl DoubleEndedIterator for ArrayIter { None } else { self.current_end -= 1; - Some(if self.array.is_null(self.current_end) { + Some(if self.is_null(self.current_end) { None } else { // Safety: diff --git a/arrow-array/src/lib.rs b/arrow-array/src/lib.rs index 46de381c3244..afb7ec5e6e44 100644 --- a/arrow-array/src/lib.rs +++ b/arrow-array/src/lib.rs @@ -192,6 +192,9 @@ pub use arithmetic::ArrowNativeTypeOp; mod numeric; pub use numeric::*; +mod scalar; +pub use scalar::*; + pub mod builder; pub mod cast; mod delta; diff --git a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs index d2e36780a901..27804447fba6 100644 --- a/arrow-array/src/record_batch.rs +++ b/arrow-array/src/record_batch.rs @@ -43,6 +43,12 @@ pub trait RecordBatchReader: Iterator> { } } +impl RecordBatchReader for Box { + fn schema(&self) -> SchemaRef { + self.as_ref().schema() + } +} + /// Trait for types that can write `RecordBatch`'s. pub trait RecordBatchWriter { /// Write a single batch to the writer. @@ -152,7 +158,6 @@ impl RecordBatch { ))); } - // check that all columns have the same row count let row_count = options .row_count .or_else(|| columns.first().map(|col| col.len())) @@ -171,6 +176,7 @@ impl RecordBatch { } } + // check that all columns have the same row count if columns.iter().any(|c| c.len() != row_count) { let err = match options.row_count { Some(_) => { @@ -232,7 +238,7 @@ impl RecordBatch { }) } - /// Returns the [`Schema`](arrow_schema::Schema) of the record batch. + /// Returns the [`Schema`] of the record batch. 
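These private helpers delegate to chrono's checked, end-of-month-clamping arithmetic, which is what makes their `Option` return values meaningful. A small sketch of the underlying chrono behaviour they rely on (illustrative only, not part of the patch):

```rust
use chrono::{Months, NaiveDate};

let jan31 = NaiveDate::from_ymd_opt(2023, 1, 31).unwrap();
// chrono clamps to the last valid day of the target month...
assert_eq!(
    jan31.checked_add_months(Months::new(1)),
    NaiveDate::from_ymd_opt(2023, 2, 28)
);
// ...and reports overflow as None instead of panicking
assert_eq!(NaiveDate::MAX.checked_add_months(Months::new(12)), None);
```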
pub fn schema(&self) -> SchemaRef { self.schema.clone() } @@ -751,7 +757,7 @@ mod tests { )))) .add_child_data(a2_child.into_data()) .len(2) - .add_buffer(Buffer::from(vec![0i32, 3, 4].to_byte_slice())) + .add_buffer(Buffer::from([0i32, 3, 4].to_byte_slice())) .build() .unwrap(); let a2: ArrayRef = Arc::new(ListArray::from(a2)); @@ -1115,4 +1121,22 @@ mod tests { // Cannot remove metadata batch.with_schema(nullable_schema).unwrap_err(); } + + #[test] + fn test_boxed_reader() { + // Make sure we can pass a boxed reader to a function generic over + // RecordBatchReader. + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let schema = Arc::new(schema); + + let reader = RecordBatchIterator::new(std::iter::empty(), schema); + let reader: Box = Box::new(reader); + + fn get_size(reader: impl RecordBatchReader) -> usize { + reader.size_hint().0 + } + + let size = get_size(reader); + assert_eq!(size, 0); + } } diff --git a/arrow-array/src/run_iterator.rs b/arrow-array/src/run_iterator.rs index 60022113c3dd..489aabf4756a 100644 --- a/arrow-array/src/run_iterator.rs +++ b/arrow-array/src/run_iterator.rs @@ -237,7 +237,7 @@ mod tests { Some(72), ]; let mut builder = PrimitiveRunBuilder::::new(); - builder.extend(input_vec.clone().into_iter()); + builder.extend(input_vec.iter().copied()); let ree_array = builder.finish(); let ree_array = ree_array.downcast::().unwrap(); @@ -261,7 +261,7 @@ mod tests { Some(72), ]; let mut builder = PrimitiveRunBuilder::::new(); - builder.extend(input_vec.into_iter()); + builder.extend(input_vec); let ree_array = builder.finish(); let ree_array = ree_array.downcast::().unwrap(); diff --git a/arrow-array/src/scalar.rs b/arrow-array/src/scalar.rs new file mode 100644 index 000000000000..f2a696a8f329 --- /dev/null +++ b/arrow-array/src/scalar.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::Array; + +/// A possibly [`Scalar`] [`Array`] +/// +/// This allows optimised binary kernels where one or more arguments are constant +/// +/// ``` +/// # use arrow_array::*; +/// # use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer}; +/// # use arrow_schema::ArrowError; +/// # +/// fn eq_impl( +/// a: &PrimitiveArray, +/// a_scalar: bool, +/// b: &PrimitiveArray, +/// b_scalar: bool, +/// ) -> BooleanArray { +/// let (array, scalar) = match (a_scalar, b_scalar) { +/// (true, true) | (false, false) => { +/// let len = a.len().min(b.len()); +/// let nulls = NullBuffer::union(a.nulls(), b.nulls()); +/// let buffer = BooleanBuffer::collect_bool(len, |idx| a.value(idx) == b.value(idx)); +/// return BooleanArray::new(buffer, nulls); +/// } +/// (true, false) => (b, (a.null_count() == 0).then(|| a.value(0))), +/// (false, true) => (a, (b.null_count() == 0).then(|| b.value(0))), +/// }; +/// match scalar { +/// Some(v) => { +/// let len = array.len(); +/// let nulls = array.nulls().cloned(); +/// let buffer = BooleanBuffer::collect_bool(len, |idx| array.value(idx) == v); +/// BooleanArray::new(buffer, nulls) +/// } +/// None => BooleanArray::new_null(array.len()), +/// } +/// } +/// +/// pub fn eq(l: &dyn Datum, r: &dyn Datum) -> Result { +/// let (l_array, l_scalar) = l.get(); +/// let (r_array, r_scalar) = r.get(); +/// downcast_primitive_array!( +/// (l_array, r_array) => Ok(eq_impl(l_array, l_scalar, r_array, r_scalar)), +/// (a, b) => Err(ArrowError::NotYetImplemented(format!("{a} == {b}"))), +/// ) +/// } +/// +/// // Comparison of two arrays +/// let a = Int32Array::from(vec![1, 2, 3, 4, 5]); +/// let b = Int32Array::from(vec![1, 2, 4, 7, 3]); +/// let r = eq(&a, &b).unwrap(); +/// let values: Vec<_> = r.values().iter().collect(); +/// assert_eq!(values, &[true, true, false, false, false]); +/// +/// // Comparison of an array and a scalar +/// let a = Int32Array::from(vec![1, 2, 3, 4, 5]); +/// let b = Int32Array::new_scalar(1); +/// let r = eq(&a, &b).unwrap(); +/// let values: Vec<_> = r.values().iter().collect(); +/// assert_eq!(values, &[true, false, false, false, false]); +pub trait Datum { + /// Returns the value for this [`Datum`] and a boolean indicating if the value is scalar + fn get(&self) -> (&dyn Array, bool); +} + +impl Datum for T { + fn get(&self) -> (&dyn Array, bool) { + (self, false) + } +} + +impl Datum for dyn Array { + fn get(&self) -> (&dyn Array, bool) { + (self, false) + } +} + +impl Datum for &dyn Array { + fn get(&self) -> (&dyn Array, bool) { + (*self, false) + } +} + +/// A wrapper around a single value [`Array`] that implements +/// [`Datum`] and indicates [compute] kernels should treat this array +/// as a scalar value (a single value). +/// +/// Using a [`Scalar`] is often much more efficient than creating an +/// [`Array`] with the same (repeated) value. +/// +/// See [`Datum`] for more information. 
+/// +/// # Example +/// +/// ```rust +/// # use arrow_array::{Scalar, Int32Array, ArrayRef}; +/// # fn get_array() -> ArrayRef { std::sync::Arc::new(Int32Array::from(vec![42])) } +/// // Create a (typed) scalar for Int32Array for the value 42 +/// let scalar = Scalar::new(Int32Array::from(vec![42])); +/// +/// // Create a scalar using PrimtiveArray::scalar +/// let scalar = Int32Array::new_scalar(42); +/// +/// // create a scalar from an ArrayRef (for dynamic typed Arrays) +/// let array: ArrayRef = get_array(); +/// let scalar = Scalar::new(array); +/// ``` +/// +/// [compute]: https://docs.rs/arrow/latest/arrow/compute/index.html +#[derive(Debug, Copy, Clone)] +pub struct Scalar(T); + +impl Scalar { + /// Create a new [`Scalar`] from an [`Array`] + /// + /// # Panics + /// + /// Panics if `array.len() != 1` + pub fn new(array: T) -> Self { + assert_eq!(array.len(), 1); + Self(array) + } +} + +impl Datum for Scalar { + fn get(&self) -> (&dyn Array, bool) { + (&self.0, true) + } +} diff --git a/arrow-array/src/trusted_len.rs b/arrow-array/src/trusted_len.rs index fdec18b78781..781cad38f7e9 100644 --- a/arrow-array/src/trusted_len.rs +++ b/arrow-array/src/trusted_len.rs @@ -63,7 +63,7 @@ mod tests { #[test] fn trusted_len_unzip_good() { - let vec = vec![Some(1u32), None]; + let vec = [Some(1u32), None]; let (null, buffer) = unsafe { trusted_len_unzip(vec.iter()) }; assert_eq!(null.as_slice(), &[0b00000001]); assert_eq!(buffer.as_slice(), &[1u8, 0, 0, 0, 0, 0, 0, 0]); diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs index f99e6a8f6f81..7988fe9f6690 100644 --- a/arrow-array/src/types.rs +++ b/arrow-array/src/types.rs @@ -17,7 +17,12 @@ //! Zero-sized types used to parameterize generic array implementations -use crate::delta::shift_months; +use crate::delta::{ + add_days_datetime, add_months_datetime, shift_months, sub_days_datetime, + sub_months_datetime, +}; +use crate::temporal_conversions::as_datetime_with_timezone; +use crate::timezone::Tz; use crate::{ArrowNativeTypeOp, OffsetSizeTrait}; use arrow_buffer::{i256, Buffer, OffsetBuffer}; use arrow_data::decimal::{validate_decimal256_precision, validate_decimal_precision}; @@ -350,158 +355,184 @@ impl ArrowTimestampType for TimestampNanosecondType { } } +fn add_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, +) -> Option<::Native> { + let months = IntervalYearMonthType::to_months(delta); + let res = as_datetime_with_timezone::(timestamp, tz)?; + let res = add_months_datetime(res, months)?; + let res = res.naive_utc(); + T::make_value(res) +} + +fn add_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, +) -> Option<::Native> { + let (days, ms) = IntervalDayTimeType::to_parts(delta); + let res = as_datetime_with_timezone::(timestamp, tz)?; + let res = add_days_datetime(res, days)?; + let res = res.checked_add_signed(Duration::milliseconds(ms as i64))?; + let res = res.naive_utc(); + T::make_value(res) +} + +fn add_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, +) -> Option<::Native> { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); + let res = as_datetime_with_timezone::(timestamp, tz)?; + let res = add_months_datetime(res, months)?; + let res = add_days_datetime(res, days)?; + let res = res.checked_add_signed(Duration::nanoseconds(nanos))?; + let res = res.naive_utc(); + T::make_value(res) +} + +fn subtract_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, +) -> Option<::Native> { + let months = IntervalYearMonthType::to_months(delta); + 
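A condensed sketch of the `Datum`/`Scalar` interplay defined above (assuming the blanket `impl Datum for T` is bounded on `T: Array`, as the elided generics suggest):

```rust
use arrow_array::{Array, Datum, Int32Array};

// `new_scalar` wraps a one-element array and flags it as a scalar for kernels
let scalar = Int32Array::new_scalar(42);
let (array, is_scalar) = scalar.get();
assert!(is_scalar);
assert_eq!(array.len(), 1);

// A regular array used as a Datum is not flagged as a scalar
let ints = Int32Array::from(vec![1, 2, 3]);
let (_, is_scalar) = Datum::get(&ints);
assert!(!is_scalar);
```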
let res = as_datetime_with_timezone::(timestamp, tz)?; + let res = sub_months_datetime(res, months)?; + let res = res.naive_utc(); + T::make_value(res) +} + +fn subtract_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, +) -> Option<::Native> { + let (days, ms) = IntervalDayTimeType::to_parts(delta); + let res = as_datetime_with_timezone::(timestamp, tz)?; + let res = sub_days_datetime(res, days)?; + let res = res.checked_sub_signed(Duration::milliseconds(ms as i64))?; + let res = res.naive_utc(); + T::make_value(res) +} + +fn subtract_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, +) -> Option<::Native> { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); + let res = as_datetime_with_timezone::(timestamp, tz)?; + let res = sub_months_datetime(res, months)?; + let res = sub_days_datetime(res, days)?; + let res = res.checked_sub_signed(Duration::nanoseconds(nanos))?; + let res = res.naive_utc(); + T::make_value(res) +} + impl TimestampSecondType { - /// Adds the given IntervalYearMonthType to an arrow TimestampSecondType + /// Adds the given IntervalYearMonthType to an arrow TimestampSecondType. + /// + /// Returns `None` when it will result in overflow. /// /// # Arguments /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn add_year_months( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> { - let prior = NaiveDateTime::from_timestamp_opt(timestamp, 0).ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - - let months = IntervalYearMonthType::to_months(delta); - let posterior = shift_months(prior, months); - TimestampSecondType::make_value(posterior) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + add_year_months::(timestamp, delta, tz) } - /// Adds the given IntervalDayTimeType to an arrow TimestampSecondType + /// Adds the given IntervalDayTimeType to an arrow TimestampSecondType. + /// + /// Returns `None` when it will result in overflow. /// /// # Arguments /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn add_day_time( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> { - let (days, ms) = IntervalDayTimeType::to_parts(delta); - let res = NaiveDateTime::from_timestamp_opt(timestamp, 0).ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::days(days as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::milliseconds(ms as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - TimestampSecondType::make_value(res) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + add_day_time::(timestamp, delta, tz) } /// Adds the given IntervalMonthDayNanoType to an arrow TimestampSecondType /// + /// Returns `None` when it will result in overflow. 
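As a brief aside, a minimal sketch (not part of this patch) of how the new `Tz`-aware, `Option`-returning arithmetic is intended to be called; the timezone string and literal values below are illustrative only:

```rust
use arrow_array::timezone::Tz;
use arrow_array::types::{IntervalYearMonthType, TimestampSecondType};

fn example() -> Option<i64> {
    // Interpret the timestamp as UTC; named zones require the `chrono-tz` feature.
    let tz: Tz = "+00:00".parse().ok()?;
    let ts: i64 = 1_577_836_800; // 2020-01-01T00:00:00Z, in seconds
    let delta = IntervalYearMonthType::make_value(1, 2); // 1 year and 2 months
    // Returns `None` instead of an error if the result would overflow
    TimestampSecondType::add_year_months(ts, delta, tz)
}
```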
/// # Arguments /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn add_month_day_nano( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> { - let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); - let res = NaiveDateTime::from_timestamp_opt(timestamp, 0).ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = shift_months(res, months); - let res = res - .checked_add_signed(Duration::days(days as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::nanoseconds(nanos)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - TimestampSecondType::make_value(res) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + add_month_day_nano::(timestamp, delta, tz) } /// Subtracts the given IntervalYearMonthType to an arrow TimestampSecondType /// + /// Returns `None` when it will result in overflow. + /// /// # Arguments /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn subtract_year_months( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> { - let prior = NaiveDateTime::from_timestamp_opt(timestamp, 0).ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let months = IntervalYearMonthType::to_months(-delta); - let posterior = shift_months(prior, months); - TimestampSecondType::make_value(posterior) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + subtract_year_months::(timestamp, delta, tz) } /// Subtracts the given IntervalDayTimeType to an arrow TimestampSecondType /// + /// Returns `None` when it will result in overflow. + /// /// # Arguments /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn subtract_day_time( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> { - let (days, ms) = IntervalDayTimeType::to_parts(-delta); - let res = NaiveDateTime::from_timestamp_opt(timestamp, 0).ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::days(days as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::microseconds(ms as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - TimestampSecondType::make_value(res) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + subtract_day_time::(timestamp, delta, tz) } /// Subtracts the given IntervalMonthDayNanoType to an arrow TimestampSecondType /// + /// Returns `None` when it will result in overflow. 
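A second small sketch (again not part of this patch) of how `IntervalMonthDayNanoType` packs and unpacks the `(months, days, nanos)` triple that `subtract_month_day_nano` operates on:

```rust
use arrow_array::types::IntervalMonthDayNanoType;

fn main() {
    // The three components are packed into a single i128 native value
    let v = IntervalMonthDayNanoType::make_value(1, 15, 500_000_000);
    assert_eq!(IntervalMonthDayNanoType::to_parts(v), (1, 15, 500_000_000));
}
```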
+ /// /// # Arguments /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn subtract_month_day_nano( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> { - let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); - let res = NaiveDateTime::from_timestamp_opt(timestamp, 0).ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = shift_months(res, -months); - let res = res - .checked_add_signed(Duration::days(-days as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::nanoseconds(-nanos)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - TimestampSecondType::make_value(res) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + subtract_month_day_nano::(timestamp, delta, tz) } } @@ -512,18 +543,13 @@ impl TimestampMicrosecondType { /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn add_year_months( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> - { - let prior = NaiveDateTime::from_timestamp_micros(timestamp).ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let months = IntervalYearMonthType::to_months(delta); - let posterior = shift_months(prior, months); - TimestampMicrosecondType::make_value(posterior) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + add_year_months::(timestamp, delta, tz) } /// Adds the given IntervalDayTimeType to an arrow TimestampMicrosecondType @@ -532,27 +558,13 @@ impl TimestampMicrosecondType { /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn add_day_time( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> - { - let (days, ms) = IntervalDayTimeType::to_parts(delta); - let res = NaiveDateTime::from_timestamp_micros(timestamp).ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::days(days as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::milliseconds(ms as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - TimestampMicrosecondType::make_value(res) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + add_day_time::(timestamp, delta, tz) } /// Adds the given IntervalMonthDayNanoType to an arrow TimestampMicrosecondType @@ -561,28 +573,13 @@ impl TimestampMicrosecondType { /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn add_month_day_nano( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> - { - let (months, days, nanos) = 
IntervalMonthDayNanoType::to_parts(delta); - let res = NaiveDateTime::from_timestamp_micros(timestamp).ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = shift_months(res, months); - let res = res - .checked_add_signed(Duration::days(days as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::nanoseconds(nanos)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - TimestampMicrosecondType::make_value(res) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + add_month_day_nano::(timestamp, delta, tz) } /// Subtracts the given IntervalYearMonthType to an arrow TimestampMicrosecondType @@ -591,18 +588,13 @@ impl TimestampMicrosecondType { /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn subtract_year_months( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> - { - let prior = NaiveDateTime::from_timestamp_micros(timestamp).ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let months = IntervalYearMonthType::to_months(-delta); - let posterior = shift_months(prior, months); - TimestampMicrosecondType::make_value(posterior) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + subtract_year_months::(timestamp, delta, tz) } /// Subtracts the given IntervalDayTimeType to an arrow TimestampMicrosecondType @@ -611,27 +603,13 @@ impl TimestampMicrosecondType { /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn subtract_day_time( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> - { - let (days, ms) = IntervalDayTimeType::to_parts(-delta); - let res = NaiveDateTime::from_timestamp_micros(timestamp).ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::days(days as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::milliseconds(ms as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - TimestampMicrosecondType::make_value(res) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + subtract_day_time::(timestamp, delta, tz) } /// Subtracts the given IntervalMonthDayNanoType to an arrow TimestampMicrosecondType @@ -640,28 +618,13 @@ impl TimestampMicrosecondType { /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn subtract_month_day_nano( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> - { - let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); - let res = NaiveDateTime::from_timestamp_micros(timestamp).ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = shift_months(res, -months); - let res = res - 
.checked_add_signed(Duration::days(-days as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::nanoseconds(-nanos)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - TimestampMicrosecondType::make_value(res) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + subtract_month_day_nano::(timestamp, delta, tz) } } @@ -672,18 +635,13 @@ impl TimestampMillisecondType { /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn add_year_months( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> - { - let prior = NaiveDateTime::from_timestamp_millis(timestamp).ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let months = IntervalYearMonthType::to_months(delta); - let posterior = shift_months(prior, months); - TimestampMillisecondType::make_value(posterior) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + add_year_months::(timestamp, delta, tz) } /// Adds the given IntervalDayTimeType to an arrow TimestampMillisecondType @@ -692,27 +650,13 @@ impl TimestampMillisecondType { /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn add_day_time( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> - { - let (days, ms) = IntervalDayTimeType::to_parts(delta); - let res = NaiveDateTime::from_timestamp_millis(timestamp).ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::days(days as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::milliseconds(ms as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - TimestampMillisecondType::make_value(res) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + add_day_time::(timestamp, delta, tz) } /// Adds the given IntervalMonthDayNanoType to an arrow TimestampMillisecondType @@ -721,28 +665,13 @@ impl TimestampMillisecondType { /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn add_month_day_nano( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> - { - let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); - let res = NaiveDateTime::from_timestamp_millis(timestamp).ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = shift_months(res, months); - let res = res - .checked_add_signed(Duration::days(days as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::nanoseconds(nanos)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - TimestampMillisecondType::make_value(res) - .ok_or_else(|| 
ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + add_month_day_nano::(timestamp, delta, tz) } /// Subtracts the given IntervalYearMonthType to an arrow TimestampMillisecondType @@ -751,18 +680,13 @@ impl TimestampMillisecondType { /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn subtract_year_months( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> - { - let prior = NaiveDateTime::from_timestamp_millis(timestamp).ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let months = IntervalYearMonthType::to_months(-delta); - let posterior = shift_months(prior, months); - TimestampMillisecondType::make_value(posterior) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + subtract_year_months::(timestamp, delta, tz) } /// Subtracts the given IntervalDayTimeType to an arrow TimestampMillisecondType @@ -771,27 +695,13 @@ impl TimestampMillisecondType { /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn subtract_day_time( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> - { - let (days, ms) = IntervalDayTimeType::to_parts(-delta); - let res = NaiveDateTime::from_timestamp_millis(timestamp).ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::days(days as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::milliseconds(ms as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - TimestampMillisecondType::make_value(res) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + subtract_day_time::(timestamp, delta, tz) } /// Subtracts the given IntervalMonthDayNanoType to an arrow TimestampMillisecondType @@ -800,28 +710,13 @@ impl TimestampMillisecondType { /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn subtract_month_day_nano( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> - { - let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); - let res = NaiveDateTime::from_timestamp_millis(timestamp).ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = shift_months(res, -months); - let res = res - .checked_add_signed(Duration::days(-days as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::nanoseconds(-nanos)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - TimestampMillisecondType::make_value(res) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + subtract_month_day_nano::(timestamp, delta, tz) } } @@ -832,19 +727,13 @@ impl TimestampNanosecondType { /// /// * `timestamp` - The date on which to 
perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn add_year_months( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> { - let seconds = timestamp / 1_000_000_000; - let nanos = timestamp % 1_000_000_000; - let prior = NaiveDateTime::from_timestamp_opt(seconds, nanos as u32).ok_or_else( - || ArrowError::ComputeError("Timestamp out of range".to_string()), - )?; - let months = IntervalYearMonthType::to_months(delta); - let posterior = shift_months(prior, months); - TimestampNanosecondType::make_value(posterior) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + add_year_months::(timestamp, delta, tz) } /// Adds the given IntervalDayTimeType to an arrow TimestampNanosecondType @@ -853,28 +742,13 @@ impl TimestampNanosecondType { /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn add_day_time( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> { - let (days, ms) = IntervalDayTimeType::to_parts(delta); - let seconds = timestamp / 1_000_000_000; - let nanos = timestamp % 1_000_000_000; - let res = NaiveDateTime::from_timestamp_opt(seconds, nanos as u32).ok_or_else( - || ArrowError::ComputeError("Timestamp out of range".to_string()), - )?; - let res = res - .checked_add_signed(Duration::days(days as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::milliseconds(ms as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - TimestampNanosecondType::make_value(res) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + add_day_time::(timestamp, delta, tz) } /// Adds the given IntervalMonthDayNanoType to an arrow TimestampNanosecondType @@ -883,114 +757,58 @@ impl TimestampNanosecondType { /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn add_month_day_nano( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> { - let seconds = timestamp / 1_000_000_000; - let nanos = timestamp % 1_000_000_000; - let res = NaiveDateTime::from_timestamp_opt(seconds, nanos as u32).ok_or_else( - || ArrowError::ComputeError("Timestamp out of range".to_string()), - )?; + tz: Tz, + ) -> Option<::Native> { + add_month_day_nano::(timestamp, delta, tz) + } - let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); - let res = shift_months(res, months); - let res = res - .checked_add_signed(Duration::days(days as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::nanoseconds(nanos)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - TimestampNanosecondType::make_value(res) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) - } - - /// Subtracs the given IntervalYearMonthType to an arrow TimestampNanosecondType + /// Subtracts the given IntervalYearMonthType to an arrow TimestampNanosecondType /// /// # Arguments /// /// * `timestamp` 
- The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn subtract_year_months( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> { - let seconds = timestamp / 1_000_000_000; - let nanos = timestamp % 1_000_000_000; - let prior = NaiveDateTime::from_timestamp_opt(seconds, nanos as u32).ok_or_else( - || ArrowError::ComputeError("Timestamp out of range".to_string()), - )?; - let months = IntervalYearMonthType::to_months(-delta); - let posterior = shift_months(prior, months); - TimestampNanosecondType::make_value(posterior) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + subtract_year_months::(timestamp, delta, tz) } - /// Subtracs the given IntervalDayTimeType to an arrow TimestampNanosecondType + /// Subtracts the given IntervalDayTimeType to an arrow TimestampNanosecondType /// /// # Arguments /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn subtract_day_time( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> { - let seconds = timestamp / 1_000_000_000; - let nanos = timestamp % 1_000_000_000; - let res = NaiveDateTime::from_timestamp_opt(seconds, nanos as u32).ok_or_else( - || ArrowError::ComputeError("Timestamp out of range".to_string()), - )?; - - let (days, ms) = IntervalDayTimeType::to_parts(-delta); - let res = res - .checked_add_signed(Duration::days(days as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::milliseconds(ms as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - TimestampNanosecondType::make_value(res) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) - } - - /// Subtracs the given IntervalMonthDayNanoType to an arrow TimestampNanosecondType + tz: Tz, + ) -> Option<::Native> { + subtract_day_time::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalMonthDayNanoType to an arrow TimestampNanosecondType /// /// # Arguments /// /// * `timestamp` - The date on which to perform the operation /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` pub fn subtract_month_day_nano( - timestamp: ::Native, + timestamp: ::Native, delta: ::Native, - ) -> Result<::Native, ArrowError> { - let seconds = timestamp / 1_000_000_000; - let nanos = timestamp % 1_000_000_000; - let res = NaiveDateTime::from_timestamp_opt(seconds, nanos as u32).ok_or_else( - || ArrowError::ComputeError("Timestamp out of range".to_string()), - )?; - - let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); - let res = shift_months(res, -months); - let res = res - .checked_add_signed(Duration::days(-days as i64)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - let res = res - .checked_add_signed(Duration::nanoseconds(-nanos)) - .ok_or_else(|| { - ArrowError::ComputeError("Timestamp out of range".to_string()) - })?; - TimestampNanosecondType::make_value(res) - .ok_or_else(|| ArrowError::ComputeError("Timestamp out of range".to_string())) + tz: Tz, + ) -> Option<::Native> { + subtract_month_day_nano::(timestamp, delta, tz) } } @@ -1001,6 +819,7 @@ impl 
IntervalYearMonthType { /// /// * `years` - The number of years (+/-) represented in this interval /// * `months` - The number of months (+/-) represented in this interval + #[inline] pub fn make_value( years: i32, months: i32, @@ -1015,6 +834,7 @@ impl IntervalYearMonthType { /// # Arguments /// /// * `i` - The IntervalYearMonthType::Native to convert + #[inline] pub fn to_months(i: ::Native) -> i32 { i } @@ -1027,6 +847,7 @@ impl IntervalDayTimeType { /// /// * `days` - The number of days (+/-) represented in this interval /// * `millis` - The number of milliseconds (+/-) represented in this interval + #[inline] pub fn make_value( days: i32, millis: i32, @@ -1053,6 +874,7 @@ impl IntervalDayTimeType { /// # Arguments /// /// * `i` - The IntervalDayTimeType to convert + #[inline] pub fn to_parts( i: ::Native, ) -> (i32, i32) { @@ -1070,6 +892,7 @@ impl IntervalMonthDayNanoType { /// * `months` - The number of months (+/-) represented in this interval /// * `days` - The number of days (+/-) represented in this interval /// * `nanos` - The number of nanoseconds (+/-) represented in this interval + #[inline] pub fn make_value( months: i32, days: i32, @@ -1098,6 +921,7 @@ impl IntervalMonthDayNanoType { /// # Arguments /// /// * `i` - The IntervalMonthDayNanoType to convert + #[inline] pub fn to_parts( i: ::Native, ) -> (i32, i32, i64) { @@ -1206,10 +1030,10 @@ impl Date32Type { date: ::Native, delta: ::Native, ) -> ::Native { - let (days, ms) = IntervalDayTimeType::to_parts(-delta); + let (days, ms) = IntervalDayTimeType::to_parts(delta); let res = Date32Type::to_naive_date(date); - let res = res.add(Duration::days(days as i64)); - let res = res.add(Duration::milliseconds(ms as i64)); + let res = res.sub(Duration::days(days as i64)); + let res = res.sub(Duration::milliseconds(ms as i64)); Date32Type::from_naive_date(res) } @@ -1226,8 +1050,8 @@ impl Date32Type { let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); let res = Date32Type::to_naive_date(date); let res = shift_months(res, -months); - let res = res.add(Duration::days(-days as i64)); - let res = res.add(Duration::nanoseconds(-nanos)); + let res = res.sub(Duration::days(days as i64)); + let res = res.sub(Duration::nanoseconds(nanos)); Date32Type::from_naive_date(res) } } @@ -1330,10 +1154,10 @@ impl Date64Type { date: ::Native, delta: ::Native, ) -> ::Native { - let (days, ms) = IntervalDayTimeType::to_parts(-delta); + let (days, ms) = IntervalDayTimeType::to_parts(delta); let res = Date64Type::to_naive_date(date); - let res = res.add(Duration::days(days as i64)); - let res = res.add(Duration::milliseconds(ms as i64)); + let res = res.sub(Duration::days(days as i64)); + let res = res.sub(Duration::milliseconds(ms as i64)); Date64Type::from_naive_date(res) } @@ -1350,8 +1174,8 @@ impl Date64Type { let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); let res = Date64Type::to_naive_date(date); let res = shift_months(res, -months); - let res = res.add(Duration::days(-days as i64)); - let res = res.add(Duration::nanoseconds(-nanos)); + let res = res.sub(Duration::days(days as i64)); + let res = res.sub(Duration::nanoseconds(nanos)); Date64Type::from_naive_date(res) } } @@ -1544,12 +1368,14 @@ pub(crate) mod bytes { } impl ByteArrayNativeType for [u8] { + #[inline] unsafe fn from_bytes_unchecked(b: &[u8]) -> &Self { b } } impl ByteArrayNativeType for str { + #[inline] unsafe fn from_bytes_unchecked(b: &[u8]) -> &Self { std::str::from_utf8_unchecked(b) } @@ -1670,7 +1496,6 @@ pub type LargeBinaryType = 
GenericBinaryType; mod tests { use super::*; use arrow_data::{layout, BufferSpec}; - use std::mem::size_of; #[test] fn month_day_nano_should_roundtrip() { @@ -1717,7 +1542,8 @@ mod tests { assert_eq!( spec, &BufferSpec::FixedWidth { - byte_width: size_of::() + byte_width: std::mem::size_of::(), + alignment: std::mem::align_of::(), } ); } diff --git a/arrow-avro/Cargo.toml b/arrow-avro/Cargo.toml new file mode 100644 index 000000000000..9575874c41d2 --- /dev/null +++ b/arrow-avro/Cargo.toml @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "arrow-avro" +version = { workspace = true } +description = "Support for parsing Avro format into the Arrow format" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "arrow_avro" +path = "src/lib.rs" +bench = false + +[dependencies] +arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +arrow-cast = { workspace = true } +arrow-data = { workspace = true } +arrow-schema = { workspace = true } +serde_json = { version = "1.0", default-features = false, features = ["std"] } +serde = { version = "1.0.188", features = ["derive"] } + +[dev-dependencies] + diff --git a/arrow-avro/src/compression.rs b/arrow-avro/src/compression.rs new file mode 100644 index 000000000000..a1a44fc22b68 --- /dev/null +++ b/arrow-avro/src/compression.rs @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
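The compression module that follows stores the codec under the `avro.codec` metadata key as a JSON-encoded `CompressionCodec`; a minimal sketch (not part of this patch, written as a hypothetical test inside that module, which is currently private) of how `rename_all = "lowercase"` lines the variants up with the codec names used by the Avro specification:

```rust
// Hypothetical test inside the compression module below
#[test]
fn codec_names_match_avro_spec() {
    // Deserializing from the lowercase spec names picks the matching variant
    let codec: CompressionCodec = serde_json::from_str("\"snappy\"").unwrap();
    assert!(matches!(codec, CompressionCodec::Snappy));
    // Serialization produces the same lowercase names
    assert_eq!(
        serde_json::to_string(&CompressionCodec::ZStandard).unwrap(),
        "\"zstandard\""
    );
}
```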
+ +use serde::{Deserialize, Serialize}; + +/// The metadata key used for storing the JSON encoded [`CompressionCodec`] +pub const CODEC_METADATA_KEY: &str = "avro.codec"; + +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum CompressionCodec { + Null, + Deflate, + BZip2, + Snappy, + XZ, + ZStandard, +} diff --git a/arrow-avro/src/lib.rs b/arrow-avro/src/lib.rs new file mode 100644 index 000000000000..c76ecb399a45 --- /dev/null +++ b/arrow-avro/src/lib.rs @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Convert data to / from the [Apache Arrow] memory format and [Apache Avro] +//! +//! [Apache Arrow]: https://arrow.apache.org +//! [Apache Avro]: https://avro.apache.org/ + +#![allow(unused)] // Temporary + +pub mod reader; +mod schema; + +mod compression; + +#[cfg(test)] +mod test_util { + pub fn arrow_test_data(path: &str) -> String { + match std::env::var("ARROW_TEST_DATA") { + Ok(dir) => format!("{dir}/{path}"), + Err(_) => format!("../testing/data/{path}"), + } + } +} diff --git a/arrow-avro/src/reader/block.rs b/arrow-avro/src/reader/block.rs new file mode 100644 index 000000000000..479f0ef90909 --- /dev/null +++ b/arrow-avro/src/reader/block.rs @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Decoder for [`Block`] + +use crate::reader::vlq::VLQDecoder; +use arrow_schema::ArrowError; + +/// A file data block +/// +/// +#[derive(Debug, Default)] +pub struct Block { + /// The number of objects in this block + pub count: usize, + /// The serialized objects within this block + pub data: Vec, + /// The sync marker + pub sync: [u8; 16], +} + +/// A decoder for [`Block`] +#[derive(Debug)] +pub struct BlockDecoder { + state: BlockDecoderState, + in_progress: Block, + vlq_decoder: VLQDecoder, + bytes_remaining: usize, +} + +#[derive(Debug)] +enum BlockDecoderState { + Count, + Size, + Data, + Sync, + Finished, +} + +impl Default for BlockDecoder { + fn default() -> Self { + Self { + state: BlockDecoderState::Count, + in_progress: Default::default(), + vlq_decoder: Default::default(), + bytes_remaining: 0, + } + } +} + +impl BlockDecoder { + /// Parse [`Block`] from `buf`, returning the number of bytes read + /// + /// This method can be called multiple times with consecutive chunks of data, allowing + /// integration with chunked IO systems like [`BufRead::fill_buf`] + /// + /// All errors should be considered fatal, and decoding aborted + /// + /// Once an entire [`Block`] has been decoded this method will not read any further + /// input bytes, until [`Self::flush`] is called. Afterwards [`Self::decode`] + /// can then be used again to read the next block, if any + /// + /// [`BufRead::fill_buf`]: std::io::BufRead::fill_buf + pub fn decode(&mut self, mut buf: &[u8]) -> Result { + let max_read = buf.len(); + while !buf.is_empty() { + match self.state { + BlockDecoderState::Count => { + if let Some(c) = self.vlq_decoder.long(&mut buf) { + self.in_progress.count = c.try_into().map_err(|_| { + ArrowError::ParseError(format!( + "Block count cannot be negative, got {c}" + )) + })?; + + self.state = BlockDecoderState::Size; + } + } + BlockDecoderState::Size => { + if let Some(c) = self.vlq_decoder.long(&mut buf) { + self.bytes_remaining = c.try_into().map_err(|_| { + ArrowError::ParseError(format!( + "Block size cannot be negative, got {c}" + )) + })?; + + self.in_progress.data.reserve(self.bytes_remaining); + self.state = BlockDecoderState::Data; + } + } + BlockDecoderState::Data => { + let to_read = self.bytes_remaining.min(buf.len()); + self.in_progress.data.extend_from_slice(&buf[..to_read]); + buf = &buf[to_read..]; + self.bytes_remaining -= to_read; + if self.bytes_remaining == 0 { + self.bytes_remaining = 16; + self.state = BlockDecoderState::Sync; + } + } + BlockDecoderState::Sync => { + let to_decode = buf.len().min(self.bytes_remaining); + let write = &mut self.in_progress.sync[16 - to_decode..]; + write[..to_decode].copy_from_slice(&buf[..to_decode]); + self.bytes_remaining -= to_decode; + buf = &buf[to_decode..]; + if self.bytes_remaining == 0 { + self.state = BlockDecoderState::Finished; + } + } + BlockDecoderState::Finished => return Ok(max_read - buf.len()), + } + } + Ok(max_read) + } + + /// Flush this decoder returning the parsed [`Block`] if any + pub fn flush(&mut self) -> Option { + match self.state { + BlockDecoderState::Finished => { + self.state = BlockDecoderState::Count; + Some(std::mem::take(&mut self.in_progress)) + } + _ => None, + } + } +} diff --git a/arrow-avro/src/reader/header.rs b/arrow-avro/src/reader/header.rs new file mode 100644 index 000000000000..2d443175a7aa --- /dev/null +++ b/arrow-avro/src/reader/header.rs @@ -0,0 +1,290 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Decoder for [`Header`] + +use crate::reader::vlq::VLQDecoder; +use crate::schema::Schema; +use arrow_schema::ArrowError; + +#[derive(Debug)] +enum HeaderDecoderState { + /// Decoding the [`MAGIC`] prefix + Magic, + /// Decoding a block count + BlockCount, + /// Decoding a block byte length + BlockLen, + /// Decoding a key length + KeyLen, + /// Decoding a key string + Key, + /// Decoding a value length + ValueLen, + /// Decoding a value payload + Value, + /// Decoding sync marker + Sync, + /// Finished decoding + Finished, +} + +/// A decoded header for an [Object Container File](https://avro.apache.org/docs/1.11.1/specification/#object-container-files) +#[derive(Debug, Clone)] +pub struct Header { + meta_offsets: Vec, + meta_buf: Vec, + sync: [u8; 16], +} + +impl Header { + /// Returns an iterator over the meta keys in this header + pub fn metadata(&self) -> impl Iterator { + let mut last = 0; + self.meta_offsets.windows(2).map(move |w| { + let start = last; + last = w[1]; + (&self.meta_buf[start..w[0]], &self.meta_buf[w[0]..w[1]]) + }) + } + + /// Returns the value for a given metadata key if present + pub fn get(&self, key: impl AsRef<[u8]>) -> Option<&[u8]> { + self.metadata() + .find_map(|(k, v)| (k == key.as_ref()).then_some(v)) + } + + /// Returns the sync token for this file + pub fn sync(&self) -> [u8; 16] { + self.sync + } +} + +/// A decoder for [`Header`] +/// +/// The avro file format does not encode the length of the header, and so it +/// is necessary to provide a push-based decoder that can be used with streams +#[derive(Debug)] +pub struct HeaderDecoder { + state: HeaderDecoderState, + vlq_decoder: VLQDecoder, + + /// The end offsets of strings in `meta_buf` + meta_offsets: Vec, + /// The raw binary data of the metadata map + meta_buf: Vec, + + /// The decoded sync marker + sync_marker: [u8; 16], + + /// The number of remaining tuples in the current block + tuples_remaining: usize, + /// The number of bytes remaining in the current string/bytes payload + bytes_remaining: usize, +} + +impl Default for HeaderDecoder { + fn default() -> Self { + Self { + state: HeaderDecoderState::Magic, + meta_offsets: vec![], + meta_buf: vec![], + sync_marker: [0; 16], + vlq_decoder: Default::default(), + tuples_remaining: 0, + bytes_remaining: MAGIC.len(), + } + } +} + +const MAGIC: &[u8; 4] = b"Obj\x01"; + +impl HeaderDecoder { + /// Parse [`Header`] from `buf`, returning the number of bytes read + /// + /// This method can be called multiple times with consecutive chunks of data, allowing + /// integration with chunked IO systems like [`BufRead::fill_buf`] + /// + /// All errors should be considered fatal, and decoding aborted + /// + /// Once the entire [`Header`] has been decoded this method will not read any further + /// input bytes, and the header can be obtained with [`Self::flush`] + 
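A minimal sketch of driving this push-based decoder from a `BufRead`, assuming `Header` and `HeaderDecoder` from this module are in scope; it mirrors the `read_header` helper added later in this patch and is shown here only to illustrate the `decode`/`consume`/`flush` contract:

```rust
use std::io::BufRead;
use arrow_schema::ArrowError;

fn decode_header<R: BufRead>(mut reader: R) -> Result<Header, ArrowError> {
    let mut decoder = HeaderDecoder::default();
    loop {
        let buf = reader.fill_buf()?;
        if buf.is_empty() {
            break; // EOF before the header finished decoding
        }
        let read = buf.len();
        let decoded = decoder.decode(buf)?;
        reader.consume(decoded);
        if decoded != read {
            break; // decoder stopped consuming: the header is complete
        }
    }
    decoder
        .flush()
        .ok_or_else(|| ArrowError::ParseError("Unexpected EOF".to_string()))
}
```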
/// + /// [`BufRead::fill_buf`]: std::io::BufRead::fill_buf + pub fn decode(&mut self, mut buf: &[u8]) -> Result { + let max_read = buf.len(); + while !buf.is_empty() { + match self.state { + HeaderDecoderState::Magic => { + let remaining = &MAGIC[MAGIC.len() - self.bytes_remaining..]; + let to_decode = buf.len().min(remaining.len()); + if !buf.starts_with(&remaining[..to_decode]) { + return Err(ArrowError::ParseError( + "Incorrect avro magic".to_string(), + )); + } + self.bytes_remaining -= to_decode; + buf = &buf[to_decode..]; + if self.bytes_remaining == 0 { + self.state = HeaderDecoderState::BlockCount; + } + } + HeaderDecoderState::BlockCount => { + if let Some(block_count) = self.vlq_decoder.long(&mut buf) { + match block_count.try_into() { + Ok(0) => { + self.state = HeaderDecoderState::Sync; + self.bytes_remaining = 16; + } + Ok(remaining) => { + self.tuples_remaining = remaining; + self.state = HeaderDecoderState::KeyLen; + } + Err(_) => { + self.tuples_remaining = block_count.unsigned_abs() as _; + self.state = HeaderDecoderState::BlockLen; + } + } + } + } + HeaderDecoderState::BlockLen => { + if self.vlq_decoder.long(&mut buf).is_some() { + self.state = HeaderDecoderState::KeyLen + } + } + HeaderDecoderState::Key => { + let to_read = self.bytes_remaining.min(buf.len()); + self.meta_buf.extend_from_slice(&buf[..to_read]); + self.bytes_remaining -= to_read; + buf = &buf[to_read..]; + if self.bytes_remaining == 0 { + self.meta_offsets.push(self.meta_buf.len()); + self.state = HeaderDecoderState::ValueLen; + } + } + HeaderDecoderState::Value => { + let to_read = self.bytes_remaining.min(buf.len()); + self.meta_buf.extend_from_slice(&buf[..to_read]); + self.bytes_remaining -= to_read; + buf = &buf[to_read..]; + if self.bytes_remaining == 0 { + self.meta_offsets.push(self.meta_buf.len()); + + self.tuples_remaining -= 1; + match self.tuples_remaining { + 0 => self.state = HeaderDecoderState::BlockCount, + _ => self.state = HeaderDecoderState::KeyLen, + } + } + } + HeaderDecoderState::KeyLen => { + if let Some(len) = self.vlq_decoder.long(&mut buf) { + self.bytes_remaining = len as _; + self.state = HeaderDecoderState::Key; + } + } + HeaderDecoderState::ValueLen => { + if let Some(len) = self.vlq_decoder.long(&mut buf) { + self.bytes_remaining = len as _; + self.state = HeaderDecoderState::Value; + } + } + HeaderDecoderState::Sync => { + let to_decode = buf.len().min(self.bytes_remaining); + let write = &mut self.sync_marker[16 - to_decode..]; + write[..to_decode].copy_from_slice(&buf[..to_decode]); + self.bytes_remaining -= to_decode; + buf = &buf[to_decode..]; + if self.bytes_remaining == 0 { + self.state = HeaderDecoderState::Finished; + } + } + HeaderDecoderState::Finished => return Ok(max_read - buf.len()), + } + } + Ok(max_read) + } + + /// Flush this decoder returning the parsed [`Header`] if any + pub fn flush(&mut self) -> Option
{ + match self.state { + HeaderDecoderState::Finished => { + self.state = HeaderDecoderState::Magic; + Some(Header { + meta_offsets: std::mem::take(&mut self.meta_offsets), + meta_buf: std::mem::take(&mut self.meta_buf), + sync: self.sync_marker, + }) + } + _ => None, + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::reader::read_header; + use crate::schema::SCHEMA_METADATA_KEY; + use crate::test_util::arrow_test_data; + use std::fs::File; + use std::io::{BufRead, BufReader}; + + #[test] + fn test_header_decode() { + let mut decoder = HeaderDecoder::default(); + for m in MAGIC { + decoder.decode(std::slice::from_ref(m)).unwrap(); + } + + let mut decoder = HeaderDecoder::default(); + assert_eq!(decoder.decode(MAGIC).unwrap(), 4); + + let mut decoder = HeaderDecoder::default(); + decoder.decode(b"Ob").unwrap(); + let err = decoder.decode(b"s").unwrap_err().to_string(); + assert_eq!(err, "Parser error: Incorrect avro magic"); + } + + fn decode_file(file: &str) -> Header { + let file = File::open(file).unwrap(); + read_header(BufReader::with_capacity(100, file)).unwrap() + } + + #[test] + fn test_header() { + let header = decode_file(&arrow_test_data("avro/alltypes_plain.avro")); + let schema_json = header.get(SCHEMA_METADATA_KEY).unwrap(); + let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"id","type":["int","null"]},{"name":"bool_col","type":["boolean","null"]},{"name":"tinyint_col","type":["int","null"]},{"name":"smallint_col","type":["int","null"]},{"name":"int_col","type":["int","null"]},{"name":"bigint_col","type":["long","null"]},{"name":"float_col","type":["float","null"]},{"name":"double_col","type":["double","null"]},{"name":"date_string_col","type":["bytes","null"]},{"name":"string_col","type":["bytes","null"]},{"name":"timestamp_col","type":[{"type":"long","logicalType":"timestamp-micros"},"null"]}]}"#; + assert_eq!(schema_json, expected); + let _schema: Schema<'_> = serde_json::from_slice(schema_json).unwrap(); + assert_eq!( + u128::from_le_bytes(header.sync()), + 226966037233754408753420635932530907102 + ); + + let header = decode_file(&arrow_test_data("avro/fixed_length_decimal.avro")); + let schema_json = header.get(SCHEMA_METADATA_KEY).unwrap(); + let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"value","type":[{"type":"fixed","name":"fixed","namespace":"topLevelRecord.value","size":11,"logicalType":"decimal","precision":25,"scale":2},"null"]}]}"#; + assert_eq!(schema_json, expected); + let _schema: Schema<'_> = serde_json::from_slice(schema_json).unwrap(); + assert_eq!( + u128::from_le_bytes(header.sync()), + 325166208089902833952788552656412487328 + ); + } +} diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs new file mode 100644 index 000000000000..91e2dbf9835b --- /dev/null +++ b/arrow-avro/src/reader/mod.rs @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Read Avro data to Arrow + +use crate::reader::block::{Block, BlockDecoder}; +use crate::reader::header::{Header, HeaderDecoder}; +use arrow_schema::ArrowError; +use std::io::BufRead; + +mod header; + +mod block; + +mod vlq; + +/// Read a [`Header`] from the provided [`BufRead`] +fn read_header(mut reader: R) -> Result { + let mut decoder = HeaderDecoder::default(); + loop { + let buf = reader.fill_buf()?; + if buf.is_empty() { + break; + } + let read = buf.len(); + let decoded = decoder.decode(buf)?; + reader.consume(decoded); + if decoded != read { + break; + } + } + + decoder + .flush() + .ok_or_else(|| ArrowError::ParseError("Unexpected EOF".to_string())) +} + +/// Return an iterator of [`Block`] from the provided [`BufRead`] +fn read_blocks( + mut reader: R, +) -> impl Iterator> { + let mut decoder = BlockDecoder::default(); + + let mut try_next = move || { + loop { + let buf = reader.fill_buf()?; + if buf.is_empty() { + break; + } + let read = buf.len(); + let decoded = decoder.decode(buf)?; + reader.consume(decoded); + if decoded != read { + break; + } + } + Ok(decoder.flush()) + }; + std::iter::from_fn(move || try_next().transpose()) +} + +#[cfg(test)] +mod test { + use crate::reader::{read_blocks, read_header}; + use crate::test_util::arrow_test_data; + use std::fs::File; + use std::io::BufReader; + + #[test] + fn test_mux() { + let file = File::open(arrow_test_data("avro/alltypes_plain.avro")).unwrap(); + let mut reader = BufReader::new(file); + let header = read_header(&mut reader).unwrap(); + for result in read_blocks(reader) { + let block = result.unwrap(); + assert_eq!(block.sync, header.sync()); + } + } +} diff --git a/arrow-avro/src/reader/vlq.rs b/arrow-avro/src/reader/vlq.rs new file mode 100644 index 000000000000..80f1c60eec7d --- /dev/null +++ b/arrow-avro/src/reader/vlq.rs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
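The `vlq` module that follows decodes Avro's variable-length, zig-zag encoded longs; as a short worked sketch (not part of this patch) of just the zig-zag step, matching the final transform in `VLQDecoder::long` below — the base-128 varint reassembly itself is handled by the decoder's byte loop:

```rust
fn zigzag_encode(v: i64) -> u64 {
    ((v << 1) ^ (v >> 63)) as u64
}

fn zigzag_decode(v: u64) -> i64 {
    (v >> 1) as i64 ^ -((v & 1) as i64)
}

fn main() {
    // Small magnitudes get small codes regardless of sign:
    // 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, 2 -> 4, ...
    assert_eq!(zigzag_encode(-2), 3);
    assert_eq!(zigzag_decode(3), -2);
    assert_eq!(zigzag_decode(zigzag_encode(i64::MIN)), i64::MIN);
}
```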
+ +/// Decoder for zig-zag encoded variable length (VLW) integers +/// +/// See also: +/// +/// +#[derive(Debug, Default)] +pub struct VLQDecoder { + /// Scratch space for decoding VLQ integers + in_progress: u64, + shift: u32, +} + +impl VLQDecoder { + /// Decode a signed long from `buf` + pub fn long(&mut self, buf: &mut &[u8]) -> Option { + while let Some(byte) = buf.first().copied() { + *buf = &buf[1..]; + self.in_progress |= ((byte & 0x7F) as u64) << self.shift; + self.shift += 7; + if byte & 0x80 == 0 { + let val = self.in_progress; + self.in_progress = 0; + self.shift = 0; + return Some((val >> 1) as i64 ^ -((val & 1) as i64)); + } + } + None + } +} diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs new file mode 100644 index 000000000000..839ba65bd5fc --- /dev/null +++ b/arrow-avro/src/schema.rs @@ -0,0 +1,484 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// The metadata key used for storing the JSON encoded [`Schema`] +pub const SCHEMA_METADATA_KEY: &str = "avro.schema"; + +/// Either a [`PrimitiveType`] or a reference to a previously defined named type +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum TypeName<'a> { + Primitive(PrimitiveType), + Ref(&'a str), +} + +/// A primitive type +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum PrimitiveType { + Null, + Boolean, + Int, + Long, + Float, + Double, + Bytes, + String, +} + +/// Additional attributes within a [`Schema`] +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Default, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct Attributes<'a> { + /// A logical type name + /// + /// + #[serde(default)] + pub logical_type: Option<&'a str>, + + /// Additional JSON attributes + #[serde(flatten)] + pub additional: HashMap<&'a str, serde_json::Value>, +} + +/// A type definition that is not a variant of [`ComplexType`] +#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct Type<'a> { + #[serde(borrow)] + pub r#type: TypeName<'a>, + #[serde(flatten)] + pub attributes: Attributes<'a>, +} + +/// An Avro schema +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum Schema<'a> { + #[serde(borrow)] + TypeName(TypeName<'a>), + #[serde(borrow)] + Union(Vec>), + #[serde(borrow)] + Complex(ComplexType<'a>), + #[serde(borrow)] + Type(Type<'a>), +} + +/// A complex type +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum ComplexType<'a> { + #[serde(borrow)] + Union(Vec>), + 
#[serde(borrow)] + Record(Record<'a>), + #[serde(borrow)] + Enum(Enum<'a>), + #[serde(borrow)] + Array(Array<'a>), + #[serde(borrow)] + Map(Map<'a>), + #[serde(borrow)] + Fixed(Fixed<'a>), +} + +/// A record +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Record<'a> { + #[serde(borrow)] + pub name: &'a str, + #[serde(borrow, default)] + pub namespace: Option<&'a str>, + #[serde(borrow, default)] + pub doc: Option<&'a str>, + #[serde(borrow, default)] + pub aliases: Vec<&'a str>, + #[serde(borrow)] + pub fields: Vec>, + #[serde(flatten)] + pub attributes: Attributes<'a>, +} + +/// A field within a [`Record`] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Field<'a> { + #[serde(borrow)] + pub name: &'a str, + #[serde(borrow, default)] + pub doc: Option<&'a str>, + #[serde(borrow)] + pub r#type: Schema<'a>, + #[serde(borrow, default)] + pub default: Option<&'a str>, +} + +/// An enumeration +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Enum<'a> { + #[serde(borrow)] + pub name: &'a str, + #[serde(borrow, default)] + pub namespace: Option<&'a str>, + #[serde(borrow, default)] + pub doc: Option<&'a str>, + #[serde(borrow, default)] + pub aliases: Vec<&'a str>, + #[serde(borrow)] + pub symbols: Vec<&'a str>, + #[serde(borrow, default)] + pub default: Option<&'a str>, + #[serde(flatten)] + pub attributes: Attributes<'a>, +} + +/// An array +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Array<'a> { + #[serde(borrow)] + pub items: Box>, + #[serde(flatten)] + pub attributes: Attributes<'a>, +} + +/// A map +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Map<'a> { + #[serde(borrow)] + pub values: Box>, + #[serde(flatten)] + pub attributes: Attributes<'a>, +} + +/// A fixed length binary array +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Fixed<'a> { + #[serde(borrow)] + pub name: &'a str, + #[serde(borrow, default)] + pub namespace: Option<&'a str>, + #[serde(borrow, default)] + pub aliases: Vec<&'a str>, + pub size: usize, + #[serde(flatten)] + pub attributes: Attributes<'a>, +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + #[test] + fn test_deserialize() { + let t: Schema = serde_json::from_str("\"string\"").unwrap(); + assert_eq!( + t, + Schema::TypeName(TypeName::Primitive(PrimitiveType::String)) + ); + + let t: Schema = serde_json::from_str("[\"int\", \"null\"]").unwrap(); + assert_eq!( + t, + Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + ]) + ); + + let t: Type = serde_json::from_str( + r#"{ + "type":"long", + "logicalType":"timestamp-micros" + }"#, + ) + .unwrap(); + + let timestamp = Type { + r#type: TypeName::Primitive(PrimitiveType::Long), + attributes: Attributes { + logical_type: Some("timestamp-micros"), + additional: Default::default(), + }, + }; + + assert_eq!(t, timestamp); + + let t: ComplexType = serde_json::from_str( + r#"{ + "type":"fixed", + "name":"fixed", + "namespace":"topLevelRecord.value", + "size":11, + "logicalType":"decimal", + "precision":25, + "scale":2 + }"#, + ) + .unwrap(); + + let decimal = ComplexType::Fixed(Fixed { + name: "fixed", + namespace: Some("topLevelRecord.value"), + aliases: vec![], + size: 11, + attributes: Attributes { + logical_type: Some("decimal"), + additional: vec![("precision", json!(25)), ("scale", 
json!(2))] + .into_iter() + .collect(), + }, + }); + + assert_eq!(t, decimal); + + let schema: Schema = serde_json::from_str( + r#"{ + "type":"record", + "name":"topLevelRecord", + "fields":[ + { + "name":"value", + "type":[ + { + "type":"fixed", + "name":"fixed", + "namespace":"topLevelRecord.value", + "size":11, + "logicalType":"decimal", + "precision":25, + "scale":2 + }, + "null" + ] + } + ] + }"#, + ) + .unwrap(); + + assert_eq!( + schema, + Schema::Complex(ComplexType::Record(Record { + name: "topLevelRecord", + namespace: None, + doc: None, + aliases: vec![], + fields: vec![Field { + name: "value", + doc: None, + r#type: Schema::Union(vec![ + Schema::Complex(decimal), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + ]), + default: None, + },], + attributes: Default::default(), + })) + ); + + let schema: Schema = serde_json::from_str( + r#"{ + "type": "record", + "name": "LongList", + "aliases": ["LinkedLongs"], + "fields" : [ + {"name": "value", "type": "long"}, + {"name": "next", "type": ["null", "LongList"]} + ] + }"#, + ) + .unwrap(); + + assert_eq!( + schema, + Schema::Complex(ComplexType::Record(Record { + name: "LongList", + namespace: None, + doc: None, + aliases: vec!["LinkedLongs"], + fields: vec![ + Field { + name: "value", + doc: None, + r#type: Schema::TypeName(TypeName::Primitive( + PrimitiveType::Long + )), + default: None, + }, + Field { + name: "next", + doc: None, + r#type: Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + Schema::TypeName(TypeName::Ref("LongList")), + ]), + default: None, + } + ], + attributes: Attributes::default(), + })) + ); + + let schema: Schema = serde_json::from_str( + r#"{ + "type":"record", + "name":"topLevelRecord", + "fields":[ + { + "name":"id", + "type":[ + "int", + "null" + ] + }, + { + "name":"timestamp_col", + "type":[ + { + "type":"long", + "logicalType":"timestamp-micros" + }, + "null" + ] + } + ] + }"#, + ) + .unwrap(); + + assert_eq!( + schema, + Schema::Complex(ComplexType::Record(Record { + name: "topLevelRecord", + namespace: None, + doc: None, + aliases: vec![], + fields: vec![ + Field { + name: "id", + doc: None, + r#type: Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + ]), + default: None, + }, + Field { + name: "timestamp_col", + doc: None, + r#type: Schema::Union(vec![ + Schema::Type(timestamp), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + ]), + default: None, + } + ], + attributes: Default::default(), + })) + ); + + let schema: Schema = serde_json::from_str( + r#"{ + "type": "record", + "name": "HandshakeRequest", "namespace":"org.apache.avro.ipc", + "fields": [ + {"name": "clientHash", + "type": {"type": "fixed", "name": "MD5", "size": 16}}, + {"name": "clientProtocol", "type": ["null", "string"]}, + {"name": "serverHash", "type": "MD5"}, + {"name": "meta", "type": ["null", {"type": "map", "values": "bytes"}]} + ] + }"#, + ) + .unwrap(); + + assert_eq!( + schema, + Schema::Complex(ComplexType::Record(Record { + name: "HandshakeRequest", + namespace: Some("org.apache.avro.ipc"), + doc: None, + aliases: vec![], + fields: vec![ + Field { + name: "clientHash", + doc: None, + r#type: Schema::Complex(ComplexType::Fixed(Fixed { + name: "MD5", + namespace: None, + aliases: vec![], + size: 16, + attributes: Default::default(), + })), + default: None, + }, + Field { + name: "clientProtocol", + doc: None, + r#type: Schema::Union(vec![ + 
Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + ]), + default: None, + }, + Field { + name: "serverHash", + doc: None, + r#type: Schema::TypeName(TypeName::Ref("MD5")), + default: None, + }, + Field { + name: "meta", + doc: None, + r#type: Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + Schema::Complex(ComplexType::Map(Map { + values: Box::new(Schema::TypeName(TypeName::Primitive( + PrimitiveType::Bytes + ))), + attributes: Default::default(), + })), + ]), + default: None, + } + ], + attributes: Default::default(), + })) + ); + } +} diff --git a/arrow-buffer/Cargo.toml b/arrow-buffer/Cargo.toml index 1db388db8398..746045cc8dde 100644 --- a/arrow-buffer/Cargo.toml +++ b/arrow-buffer/Cargo.toml @@ -34,6 +34,7 @@ path = "src/lib.rs" bench = false [dependencies] +bytes = { version = "1.4" } num = { version = "0.4", default-features = false, features = ["std"] } half = { version = "2.1", default-features = false } diff --git a/arrow-buffer/benches/i256.rs b/arrow-buffer/benches/i256.rs index 2c43e0e91070..ebb45e793bd0 100644 --- a/arrow-buffer/benches/i256.rs +++ b/arrow-buffer/benches/i256.rs @@ -21,18 +21,7 @@ use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use std::str::FromStr; -/// Returns fixed seedable RNG -fn seedable_rng() -> StdRng { - StdRng::seed_from_u64(42) -} - -fn create_i256_vec(size: usize) -> Vec { - let mut rng = seedable_rng(); - - (0..size) - .map(|_| i256::from_i128(rng.gen::())) - .collect() -} +const SIZE: usize = 1024; fn criterion_benchmark(c: &mut Criterion) { let numbers = vec![ @@ -54,24 +43,40 @@ fn criterion_benchmark(c: &mut Criterion) { }); } - c.bench_function("i256_div", |b| { + let mut rng = StdRng::seed_from_u64(42); + + let numerators: Vec<_> = (0..SIZE) + .map(|_| { + let high = rng.gen_range(1000..i128::MAX); + let low = rng.gen(); + i256::from_parts(low, high) + }) + .collect(); + + let divisors: Vec<_> = numerators + .iter() + .map(|n| { + let quotient = rng.gen_range(1..100_i32); + n.wrapping_div(i256::from(quotient)) + }) + .collect(); + + c.bench_function("i256_div_rem small quotient", |b| { b.iter(|| { - for number_a in create_i256_vec(10) { - for number_b in create_i256_vec(5) { - number_a.checked_div(number_b); - number_a.wrapping_div(number_b); - } + for (n, d) in numerators.iter().zip(&divisors) { + black_box(n.wrapping_div(*d)); } }); }); - c.bench_function("i256_rem", |b| { + let divisors: Vec<_> = (0..SIZE) + .map(|_| i256::from(rng.gen_range(1..100_i32))) + .collect(); + + c.bench_function("i256_div_rem small divisor", |b| { b.iter(|| { - for number_a in create_i256_vec(10) { - for number_b in create_i256_vec(5) { - number_a.checked_rem(number_b); - number_a.wrapping_rem(number_b); - } + for (n, d) in numerators.iter().zip(&divisors) { + black_box(n.wrapping_div(*d)); } }); }); diff --git a/arrow-buffer/src/alloc/alignment.rs b/arrow-buffer/src/alloc/alignment.rs index 7978baa2bbd8..b3979e1d6a06 100644 --- a/arrow-buffer/src/alloc/alignment.rs +++ b/arrow-buffer/src/alloc/alignment.rs @@ -117,3 +117,7 @@ pub const ALIGNMENT: usize = 1 << 7; /// Cache and allocation multiple alignment size #[cfg(target_arch = "aarch64")] pub const ALIGNMENT: usize = 1 << 6; + +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "loongarch64")] +pub const ALIGNMENT: usize = 1 << 6; diff --git a/arrow-buffer/src/bigint/div.rs b/arrow-buffer/src/bigint/div.rs new file mode 100644 index 000000000000..ba530ffcc6c8 --- /dev/null 
+++ b/arrow-buffer/src/bigint/div.rs @@ -0,0 +1,312 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! N-digit division +//! +//! Implementation heavily inspired by [uint] +//! +//! [uint]: https://github.com/paritytech/parity-common/blob/d3a9327124a66e52ca1114bb8640c02c18c134b8/uint/src/uint.rs#L844 + +/// Unsigned, little-endian, n-digit division with remainder +/// +/// # Panics +/// +/// Panics if divisor is zero +pub fn div_rem( + numerator: &[u64; N], + divisor: &[u64; N], +) -> ([u64; N], [u64; N]) { + let numerator_bits = bits(numerator); + let divisor_bits = bits(divisor); + assert_ne!(divisor_bits, 0, "division by zero"); + + if numerator_bits < divisor_bits { + return ([0; N], *numerator); + } + + if divisor_bits <= 64 { + return div_rem_small(numerator, divisor[0]); + } + + let numerator_words = (numerator_bits + 63) / 64; + let divisor_words = (divisor_bits + 63) / 64; + let n = divisor_words; + let m = numerator_words - divisor_words; + + div_rem_knuth(numerator, divisor, n, m) +} + +/// Return the least number of bits needed to represent the number +fn bits(arr: &[u64]) -> usize { + for (idx, v) in arr.iter().enumerate().rev() { + if *v > 0 { + return 64 - v.leading_zeros() as usize + 64 * idx; + } + } + 0 +} + +/// Division of numerator by a u64 divisor +fn div_rem_small( + numerator: &[u64; N], + divisor: u64, +) -> ([u64; N], [u64; N]) { + let mut rem = 0u64; + let mut numerator = *numerator; + numerator.iter_mut().rev().for_each(|d| { + let (q, r) = div_rem_word(rem, *d, divisor); + *d = q; + rem = r; + }); + + let mut rem_padded = [0; N]; + rem_padded[0] = rem; + (numerator, rem_padded) +} + +/// Use Knuth Algorithm D to compute `numerator / divisor` returning the +/// quotient and remainder +/// +/// `n` is the number of non-zero 64-bit words in `divisor` +/// `m` is the number of non-zero 64-bit words present in `numerator` beyond `divisor`, and +/// therefore the number of words in the quotient +/// +/// A good explanation of the algorithm can be found [here](https://ridiculousfish.com/blog/posts/labor-of-division-episode-iv.html) +fn div_rem_knuth( + numerator: &[u64; N], + divisor: &[u64; N], + n: usize, + m: usize, +) -> ([u64; N], [u64; N]) { + assert!(n + m <= N); + + // The algorithm works by incrementally generating guesses `q_hat`, for the next digit + // of the quotient, starting from the most significant digit. 
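+    // Concretely, this is schoolbook long division in base 2^64: each u64 limb plays
+    // the role of a decimal digit and `q_hat` is the guessed next quotient digit
+    // (just as dividing 7345 by 92 guesses the digit 7, then 9, giving quotient 79).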
+ // + // This relies on the property that for any `q_hat` where + // + // (q_hat << (j * 64)) * divisor <= numerator` + // + // We can set + // + // q += q_hat << (j * 64) + // numerator -= (q_hat << (j * 64)) * divisor + // + // And then iterate until `numerator < divisor` + + // We normalize the divisor so that the highest bit in the highest digit of the + // divisor is set, this ensures our initial guess of `q_hat` is at most 2 off from + // the correct value for q[j] + let shift = divisor[n - 1].leading_zeros(); + // As the shift is computed based on leading zeros, don't need to perform full_shl + let divisor = shl_word(divisor, shift); + // numerator may have fewer leading zeros than divisor, so must add another digit + let mut numerator = full_shl(numerator, shift); + + // The two most significant digits of the divisor + let b0 = divisor[n - 1]; + let b1 = divisor[n - 2]; + + let mut q = [0; N]; + + for j in (0..=m).rev() { + let a0 = numerator[j + n]; + let a1 = numerator[j + n - 1]; + + let mut q_hat = if a0 < b0 { + // The first estimate is [a1, a0] / b0, it may be too large by at most 2 + let (mut q_hat, mut r_hat) = div_rem_word(a0, a1, b0); + + // r_hat = [a1, a0] - q_hat * b0 + // + // Now we want to compute a more precise estimate [a2,a1,a0] / [b1,b0] + // which can only be less or equal to the current q_hat + // + // q_hat is too large if: + // [a2,a1,a0] < q_hat * [b1,b0] + // [a2,r_hat] < q_hat * b1 + let a2 = numerator[j + n - 2]; + loop { + let r = u128::from(q_hat) * u128::from(b1); + let (lo, hi) = (r as u64, (r >> 64) as u64); + if (hi, lo) <= (r_hat, a2) { + break; + } + + q_hat -= 1; + let (new_r_hat, overflow) = r_hat.overflowing_add(b0); + r_hat = new_r_hat; + + if overflow { + break; + } + } + q_hat + } else { + u64::MAX + }; + + // q_hat is now either the correct quotient digit, or in rare cases 1 too large + + // Compute numerator -= (q_hat * divisor) << (j * 64) + let q_hat_v = full_mul_u64(&divisor, q_hat); + let c = sub_assign(&mut numerator[j..], &q_hat_v[..n + 1]); + + // If underflow, q_hat was too large by 1 + if c { + // Reduce q_hat by 1 + q_hat -= 1; + + // Add back one multiple of divisor + let c = add_assign(&mut numerator[j..], &divisor[..n]); + numerator[j + n] = numerator[j + n].wrapping_add(u64::from(c)); + } + + // q_hat is the correct value for q[j] + q[j] = q_hat; + } + + // The remainder is what is left in numerator, with the initial normalization shl reversed + let remainder = full_shr(&numerator, shift); + (q, remainder) +} + +/// Perform narrowing division of a u128 by a u64 divisor, returning the quotient and remainder +/// +/// This method may trap or panic if hi >= divisor, i.e. 
the quotient would not fit +/// into a 64-bit integer +fn div_rem_word(hi: u64, lo: u64, divisor: u64) -> (u64, u64) { + debug_assert!(hi < divisor); + debug_assert_ne!(divisor, 0); + + // LLVM fails to use the div instruction as it is not able to prove + // that hi < divisor, and therefore the result will fit into 64-bits + #[cfg(target_arch = "x86_64")] + unsafe { + let mut quot = lo; + let mut rem = hi; + std::arch::asm!( + "div {divisor}", + divisor = in(reg) divisor, + inout("rax") quot, + inout("rdx") rem, + options(pure, nomem, nostack) + ); + (quot, rem) + } + #[cfg(not(target_arch = "x86_64"))] + { + let x = (u128::from(hi) << 64) + u128::from(lo); + let y = u128::from(divisor); + ((x / y) as u64, (x % y) as u64) + } +} + +/// Perform `a += b` +fn add_assign(a: &mut [u64], b: &[u64]) -> bool { + binop_slice(a, b, u64::overflowing_add) +} + +/// Perform `a -= b` +fn sub_assign(a: &mut [u64], b: &[u64]) -> bool { + binop_slice(a, b, u64::overflowing_sub) +} + +/// Converts an overflowing binary operation on scalars to one on slices +fn binop_slice( + a: &mut [u64], + b: &[u64], + binop: impl Fn(u64, u64) -> (u64, bool) + Copy, +) -> bool { + let mut c = false; + a.iter_mut().zip(b.iter()).for_each(|(x, y)| { + let (res1, overflow1) = y.overflowing_add(u64::from(c)); + let (res2, overflow2) = binop(*x, res1); + *x = res2; + c = overflow1 || overflow2; + }); + c +} + +/// Widening multiplication of an N-digit array with a u64 +fn full_mul_u64(a: &[u64; N], b: u64) -> ArrayPlusOne { + let mut carry = 0; + let mut out = [0; N]; + out.iter_mut().zip(a).for_each(|(o, v)| { + let r = *v as u128 * b as u128 + carry as u128; + *o = r as u64; + carry = (r >> 64) as u64; + }); + ArrayPlusOne(out, carry) +} + +/// Left shift of an N-digit array by at most 63 bits +fn shl_word(v: &[u64; N], shift: u32) -> [u64; N] { + full_shl(v, shift).0 +} + +/// Widening left shift of an N-digit array by at most 63 bits +fn full_shl(v: &[u64; N], shift: u32) -> ArrayPlusOne { + debug_assert!(shift < 64); + if shift == 0 { + return ArrayPlusOne(*v, 0); + } + let mut out = [0u64; N]; + out[0] = v[0] << shift; + for i in 1..N { + out[i] = v[i - 1] >> (64 - shift) | v[i] << shift + } + let carry = v[N - 1] >> (64 - shift); + ArrayPlusOne(out, carry) +} + +/// Narrowing right shift of an (N+1)-digit array by at most 63 bits +fn full_shr(a: &ArrayPlusOne, shift: u32) -> [u64; N] { + debug_assert!(shift < 64); + if shift == 0 { + return a.0; + } + let mut out = [0; N]; + for i in 0..N - 1 { + out[i] = a[i] >> shift | a[i + 1] << (64 - shift) + } + out[N - 1] = a[N - 1] >> shift; + out +} + +/// An array of N + 1 elements +/// +/// This is a hack around lack of support for const arithmetic +#[repr(C)] +struct ArrayPlusOne([T; N], T); + +impl std::ops::Deref for ArrayPlusOne { + type Target = [T]; + + #[inline] + fn deref(&self) -> &Self::Target { + let x = self as *const Self; + unsafe { std::slice::from_raw_parts(x as *const T, N + 1) } + } +} + +impl std::ops::DerefMut for ArrayPlusOne { + fn deref_mut(&mut self) -> &mut Self::Target { + let x = self as *mut Self; + unsafe { std::slice::from_raw_parts_mut(x as *mut T, N + 1) } + } +} diff --git a/arrow-buffer/src/bigint.rs b/arrow-buffer/src/bigint/mod.rs similarity index 92% rename from arrow-buffer/src/bigint.rs rename to arrow-buffer/src/bigint/mod.rs index 86150e67fd91..d064663bf63a 100644 --- a/arrow-buffer/src/bigint.rs +++ b/arrow-buffer/src/bigint/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
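The rewrite below swaps the previous bit-at-a-time long division for the limb-based `div_rem` above; the observable contract of `i256` division is unchanged. A minimal standalone sketch of that contract, assuming only `arrow-buffer` as a dependency (illustrative, not part of the diff):

use arrow_buffer::i256;

fn main() {
    // Illustrative check only, not part of this patch.
    // A negative dividend wider than 128 bits and a small negative divisor.
    let n = i256::from_parts(123_456_789_012_345_678_901_234_567_890, -7);
    let d = i256::from_i128(-97);

    let q = n.wrapping_div(d);
    let r = n.wrapping_rem(d);

    // Truncated (round-toward-zero) division: n == q * d + r ...
    assert_eq!(q.wrapping_mul(d).wrapping_add(r), n);
    // ... with the remainder taking the sign of the dividend.
    assert!(r <= i256::ZERO);
}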
+use crate::bigint::div::div_rem; use num::cast::AsPrimitive; use num::{BigInt, FromPrimitive, ToPrimitive}; use std::cmp::Ordering; @@ -22,6 +23,8 @@ use std::num::ParseIntError; use std::ops::{BitAnd, BitOr, BitXor, Neg, Shl, Shr}; use std::str::FromStr; +mod div; + /// An opaque error similar to [`std::num::ParseIntError`] #[derive(Debug)] pub struct ParseI256Error {} @@ -428,25 +431,6 @@ impl i256 { .then_some(Self { low, high }) } - /// Return the least number of bits needed to represent the number - #[inline] - fn bits_required(&self) -> usize { - let le_bytes = self.to_le_bytes(); - let arr: [u128; 2] = [ - u128::from_le_bytes(le_bytes[0..16].try_into().unwrap()), - u128::from_le_bytes(le_bytes[16..32].try_into().unwrap()), - ]; - - let iter = arr.iter().rev().take(2 - 1); - if self.is_negative() { - let ctr = iter.take_while(|&&b| b == ::core::u128::MAX).count(); - (128 * (2 - ctr)) + 1 - (!arr[2 - ctr - 1]).leading_zeros() as usize - } else { - let ctr = iter.take_while(|&&b| b == ::core::u128::MIN).count(); - (128 * (2 - ctr)) + 1 - arr[2 - ctr - 1].leading_zeros() as usize - } - } - /// Division operation, returns (quotient, remainder). /// This basically implements [Long division]: `` #[inline] @@ -458,41 +442,45 @@ impl i256 { return Err(DivRemError::DivideOverflow); } - if self == Self::MIN || other == Self::MIN { - let l = BigInt::from_signed_bytes_le(&self.to_le_bytes()); - let r = BigInt::from_signed_bytes_le(&other.to_le_bytes()); - let d = i256::from_bigint_with_overflow(&l / &r).0; - let r = i256::from_bigint_with_overflow(&l % &r).0; - return Ok((d, r)); - } - - let mut me = self.checked_abs().unwrap(); - let mut you = other.checked_abs().unwrap(); - let mut ret = [0u128; 2]; - if me < you { - return Ok((Self::from_parts(ret[0], ret[1] as i128), self)); - } + let a = self.wrapping_abs(); + let b = other.wrapping_abs(); - let shift = me.bits_required() - you.bits_required(); - you = you.shl(shift as u8); - for i in (0..=shift).rev() { - if me >= you { - ret[i / 128] |= 1 << (i % 128); - me = me.checked_sub(you).unwrap(); - } - you = you.shr(1); - } + let (div, rem) = div_rem(&a.as_digits(), &b.as_digits()); + let div = Self::from_digits(div); + let rem = Self::from_digits(rem); Ok(( if self.is_negative() == other.is_negative() { - Self::from_parts(ret[0], ret[1] as i128) + div } else { - -Self::from_parts(ret[0], ret[1] as i128) + div.wrapping_neg() + }, + if self.is_negative() { + rem.wrapping_neg() + } else { + rem }, - if self.is_negative() { -me } else { me }, )) } + /// Interpret this [`i256`] as 4 `u64` digits, least significant first + fn as_digits(self) -> [u64; 4] { + [ + self.low as u64, + (self.low >> 64) as u64, + self.high as u64, + (self.high as u128 >> 64) as u64, + ] + } + + /// Interpret 4 `u64` digits, least significant first, as a [`i256`] + fn from_digits(digits: [u64; 4]) -> Self { + Self::from_parts( + digits[0] as u128 | (digits[1] as u128) << 64, + digits[2] as i128 | (digits[3] as i128) << 64, + ) + } + /// Performs wrapping division #[inline] pub fn wrapping_div(self, other: Self) -> Self { @@ -671,6 +659,30 @@ macro_rules! 
derive_op { self.$wrapping(rhs) } } + + impl<'a> std::ops::$t for &'a i256 { + type Output = i256; + + fn $op(self, rhs: i256) -> Self::Output { + (*self).$op(rhs) + } + } + + impl<'a> std::ops::$t<&'a i256> for i256 { + type Output = i256; + + fn $op(self, rhs: &'a i256) -> Self::Output { + self.$op(*rhs) + } + } + + impl<'a, 'b> std::ops::$t<&'b i256> for &'a i256 { + type Output = i256; + + fn $op(self, rhs: &'b i256) -> Self::Output { + (*self).$op(*rhs) + } + } }; } @@ -969,7 +981,7 @@ mod tests { let expected = bl.clone() % br.clone(); let checked = il.checked_rem(ir); - assert_eq!(actual.to_string(), expected.to_string()); + assert_eq!(actual.to_string(), expected.to_string(), "{il} % {ir}"); if ir == i256::MINUS_ONE && il == i256::MIN { assert!(checked.is_none()); @@ -1206,4 +1218,57 @@ mod tests { assert_eq!(i256::from_string(case), expected) } } + + #[allow(clippy::op_ref)] + fn test_reference_op(il: i256, ir: i256) { + let r1 = il + ir; + let r2 = &il + ir; + let r3 = il + &ir; + let r4 = &il + &ir; + assert_eq!(r1, r2); + assert_eq!(r1, r3); + assert_eq!(r1, r4); + + let r1 = il - ir; + let r2 = &il - ir; + let r3 = il - &ir; + let r4 = &il - &ir; + assert_eq!(r1, r2); + assert_eq!(r1, r3); + assert_eq!(r1, r4); + + let r1 = il * ir; + let r2 = &il * ir; + let r3 = il * &ir; + let r4 = &il * &ir; + assert_eq!(r1, r2); + assert_eq!(r1, r3); + assert_eq!(r1, r4); + + let r1 = il / ir; + let r2 = &il / ir; + let r3 = il / &ir; + let r4 = &il / &ir; + assert_eq!(r1, r2); + assert_eq!(r1, r3); + assert_eq!(r1, r4); + } + + #[test] + fn test_i256_reference_op() { + let candidates = [ + i256::ONE, + i256::MINUS_ONE, + i256::from_i128(2), + i256::from_i128(-2), + i256::from_i128(3), + i256::from_i128(-3), + ]; + + for il in candidates { + for ir in candidates { + test_reference_op(il, ir) + } + } + } } diff --git a/arrow-buffer/src/buffer/immutable.rs b/arrow-buffer/src/buffer/immutable.rs index 2ecd3b41913a..bda6dfc5cdee 100644 --- a/arrow-buffer/src/buffer/immutable.rs +++ b/arrow-buffer/src/buffer/immutable.rs @@ -74,7 +74,7 @@ impl Buffer { /// Create a [`Buffer`] from the provided [`Vec`] without copying #[inline] pub fn from_vec(vec: Vec) -> Self { - MutableBuffer::from_vec(vec).into() + MutableBuffer::from(vec).into() } /// Initializes a [Buffer] from a slice of items. @@ -323,6 +323,14 @@ impl Buffer { length, }) } + + /// Returns true if this [`Buffer`] is equal to `other`, using pointer comparisons + /// to determine buffer equality. 
This is cheaper than `PartialEq::eq` but may + /// return false when the arrays are logically equal + #[inline] + pub fn ptr_eq(&self, other: &Self) -> bool { + self.ptr == other.ptr && self.length == other.length + } } /// Creating a `Buffer` instance by copying the memory from a `AsRef<[u8]>` into a newly diff --git a/arrow-buffer/src/buffer/mutable.rs b/arrow-buffer/src/buffer/mutable.rs index 3e66e7f23fa2..2c56f9a5b270 100644 --- a/arrow-buffer/src/buffer/mutable.rs +++ b/arrow-buffer/src/buffer/mutable.rs @@ -112,17 +112,9 @@ impl MutableBuffer { /// Create a [`MutableBuffer`] from the provided [`Vec`] without copying #[inline] + #[deprecated(note = "Use From>")] pub fn from_vec(vec: Vec) -> Self { - // Safety - // Vec::as_ptr guaranteed to not be null and ArrowNativeType are trivially transmutable - let data = unsafe { NonNull::new_unchecked(vec.as_ptr() as _) }; - let len = vec.len() * mem::size_of::(); - // Safety - // Vec guaranteed to have a valid layout matching that of `Layout::array` - // This is based on `RawVec::current_memory` - let layout = unsafe { Layout::array::(vec.capacity()).unwrap_unchecked() }; - mem::forget(vec); - Self { data, len, layout } + Self::from(vec) } /// Allocates a new [MutableBuffer] from given `Bytes`. @@ -168,7 +160,14 @@ impl MutableBuffer { /// `len` of the buffer and so can be used to initialize the memory region from /// `len` to `capacity`. pub fn set_null_bits(&mut self, start: usize, count: usize) { - assert!(start + count <= self.layout.size()); + assert!( + start.saturating_add(count) <= self.layout.size(), + "range start index {start} and count {count} out of bounds for \ + buffer of length {}", + self.layout.size(), + ); + + // Safety: `self.data[start..][..count]` is in-bounds and well-aligned for `u8` unsafe { std::ptr::write_bytes(self.data.as_ptr().add(start), 0, count); } @@ -495,6 +494,21 @@ impl Extend for MutableBuffer { } } +impl From> for MutableBuffer { + fn from(value: Vec) -> Self { + // Safety + // Vec::as_ptr guaranteed to not be null and ArrowNativeType are trivially transmutable + let data = unsafe { NonNull::new_unchecked(value.as_ptr() as _) }; + let len = value.len() * mem::size_of::(); + // Safety + // Vec guaranteed to have a valid layout matching that of `Layout::array` + // This is based on `RawVec::current_memory` + let layout = unsafe { Layout::array::(value.capacity()).unwrap_unchecked() }; + mem::forget(value); + Self { data, len, layout } + } +} + impl MutableBuffer { #[inline] pub(super) fn extend_from_iter>( @@ -643,6 +657,12 @@ impl MutableBuffer { } } +impl Default for MutableBuffer { + fn default() -> Self { + Self::with_capacity(0) + } +} + impl std::ops::Deref for MutableBuffer { type Target = [u8]; @@ -758,6 +778,14 @@ impl std::iter::FromIterator for MutableBuffer { } } +impl std::iter::FromIterator for MutableBuffer { + fn from_iter>(iter: I) -> Self { + let mut buffer = Self::default(); + buffer.extend_from_iter(iter.into_iter()); + buffer + } +} + #[cfg(test)] mod tests { use super::*; @@ -770,6 +798,19 @@ mod tests { assert!(buf.is_empty()); } + #[test] + fn test_mutable_default() { + let buf = MutableBuffer::default(); + assert_eq!(0, buf.capacity()); + assert_eq!(0, buf.len()); + assert!(buf.is_empty()); + + let mut buf = MutableBuffer::default(); + buf.extend_from_slice(b"hello"); + assert_eq!(5, buf.len()); + assert_eq!(b"hello", buf.as_slice()); + } + #[test] fn test_mutable_extend_from_slice() { let mut buf = MutableBuffer::new(100); @@ -932,4 +973,38 @@ mod tests { buffer.shrink_to_fit(); 
assert!(buffer.capacity() >= 64 && buffer.capacity() < 128); } + + #[test] + fn test_mutable_set_null_bits() { + let mut buffer = MutableBuffer::new(8).with_bitset(8, true); + + for i in 0..=buffer.capacity() { + buffer.set_null_bits(i, 0); + assert_eq!(buffer[..8], [255; 8][..]); + } + + buffer.set_null_bits(1, 4); + assert_eq!(buffer[..8], [255, 0, 0, 0, 0, 255, 255, 255][..]); + } + + #[test] + #[should_panic = "out of bounds for buffer of length"] + fn test_mutable_set_null_bits_oob() { + let mut buffer = MutableBuffer::new(64); + buffer.set_null_bits(1, buffer.capacity()); + } + + #[test] + #[should_panic = "out of bounds for buffer of length"] + fn test_mutable_set_null_bits_oob_by_overflow() { + let mut buffer = MutableBuffer::new(0); + buffer.set_null_bits(1, usize::MAX); + } + + #[test] + fn from_iter() { + let buffer = [1u16, 2, 3, 4].into_iter().collect::(); + assert_eq!(buffer.len(), 4 * mem::size_of::()); + assert_eq!(buffer.as_slice(), &[1, 0, 2, 0, 3, 0, 4, 0]); + } } diff --git a/arrow-buffer/src/buffer/offset.rs b/arrow-buffer/src/buffer/offset.rs index 0111d12fbab1..a6f2f7f6cfae 100644 --- a/arrow-buffer/src/buffer/offset.rs +++ b/arrow-buffer/src/buffer/offset.rs @@ -19,7 +19,43 @@ use crate::buffer::ScalarBuffer; use crate::{ArrowNativeType, MutableBuffer}; use std::ops::Deref; -/// A non-empty buffer of monotonically increasing, positive integers +/// A non-empty buffer of monotonically increasing, positive integers. +/// +/// [`OffsetBuffer`] are used to represent ranges of offsets. An +/// `OffsetBuffer` of `N+1` items contains `N` such ranges. The start +/// offset for element `i` is `offsets[i]` and the end offset is +/// `offsets[i+1]`. Equal offsets represent an empty range. +/// +/// # Example +/// +/// This example shows how 5 distinct ranges, are represented using a +/// 6 entry `OffsetBuffer`. The first entry `(0, 3)` represents the +/// three offsets `0, 1, 2`. The entry `(3,3)` represent no offsets +/// (e.g. an empty list). +/// +/// ```text +/// ┌───────┐ ┌───┐ +/// │ (0,3) │ │ 0 │ +/// ├───────┤ ├───┤ +/// │ (3,3) │ │ 3 │ +/// ├───────┤ ├───┤ +/// │ (3,4) │ │ 3 │ +/// ├───────┤ ├───┤ +/// │ (4,5) │ │ 4 │ +/// ├───────┤ ├───┤ +/// │ (5,7) │ │ 5 │ +/// └───────┘ ├───┤ +/// │ 7 │ +/// └───┘ +/// +/// Offsets Buffer +/// Logical +/// Offsets +/// +/// (offsets[i], +/// offsets[i+1]) +/// ``` + #[derive(Debug, Clone)] pub struct OffsetBuffer(ScalarBuffer); @@ -112,6 +148,14 @@ impl OffsetBuffer { pub fn slice(&self, offset: usize, len: usize) -> Self { Self(self.0.slice(offset, len.saturating_add(1))) } + + /// Returns true if this [`OffsetBuffer`] is equal to `other`, using pointer comparisons + /// to determine buffer equality. This is cheaper than `PartialEq::eq` but may + /// return false when the arrays are logically equal + #[inline] + pub fn ptr_eq(&self, other: &Self) -> bool { + self.0.ptr_eq(&other.0) + } } impl Deref for OffsetBuffer { diff --git a/arrow-buffer/src/buffer/scalar.rs b/arrow-buffer/src/buffer/scalar.rs index 70c86f11866d..276e635e825c 100644 --- a/arrow-buffer/src/buffer/scalar.rs +++ b/arrow-buffer/src/buffer/scalar.rs @@ -86,6 +86,14 @@ impl ScalarBuffer { pub fn into_inner(self) -> Buffer { self.buffer } + + /// Returns true if this [`ScalarBuffer`] is equal to `other`, using pointer comparisons + /// to determine buffer equality. 
This is cheaper than `PartialEq::eq` but may + /// return false when the arrays are logically equal + #[inline] + pub fn ptr_eq(&self, other: &Self) -> bool { + self.buffer.ptr_eq(&other.buffer) + } } impl Deref for ScalarBuffer { diff --git a/arrow-buffer/src/builder/boolean.rs b/arrow-buffer/src/builder/boolean.rs index f84cfa79c2dc..f0e7f0f13670 100644 --- a/arrow-buffer/src/builder/boolean.rs +++ b/arrow-buffer/src/builder/boolean.rs @@ -203,6 +203,12 @@ impl BooleanBufferBuilder { ); } + /// Append [`BooleanBuffer`] to this [`BooleanBufferBuilder`] + pub fn append_buffer(&mut self, buffer: &BooleanBuffer) { + let range = buffer.offset()..buffer.offset() + buffer.len(); + self.append_packed_range(range, buffer.values()) + } + /// Returns the packed bits pub fn as_slice(&self) -> &[u8] { self.buffer.as_slice() diff --git a/arrow-buffer/src/builder/mod.rs b/arrow-buffer/src/builder/mod.rs index f9d2d0935300..d5d5a7d3f18d 100644 --- a/arrow-buffer/src/builder/mod.rs +++ b/arrow-buffer/src/builder/mod.rs @@ -21,3 +21,398 @@ mod boolean; pub use boolean::*; mod null; pub use null::*; + +use crate::{ArrowNativeType, Buffer, MutableBuffer}; +use std::{iter, marker::PhantomData}; + +/// Builder for creating a [Buffer] object. +/// +/// A [Buffer] is the underlying data structure of Arrow's Arrays. +/// +/// For all supported types, there are type definitions for the +/// generic version of `BufferBuilder`, e.g. `BufferBuilder`. +/// +/// # Example: +/// +/// ``` +/// # use arrow_buffer::builder::BufferBuilder; +/// +/// let mut builder = BufferBuilder::::new(100); +/// builder.append_slice(&[42, 43, 44]); +/// builder.append(45); +/// let buffer = builder.finish(); +/// +/// assert_eq!(unsafe { buffer.typed_data::() }, &[42, 43, 44, 45]); +/// ``` +#[derive(Debug)] +pub struct BufferBuilder { + buffer: MutableBuffer, + len: usize, + _marker: PhantomData, +} + +impl BufferBuilder { + /// Creates a new builder with initial capacity for _at least_ `capacity` + /// elements of type `T`. + /// + /// The capacity can later be manually adjusted with the + /// [`reserve()`](BufferBuilder::reserve) method. + /// Also the + /// [`append()`](BufferBuilder::append), + /// [`append_slice()`](BufferBuilder::append_slice) and + /// [`advance()`](BufferBuilder::advance) + /// methods automatically increase the capacity if needed. + /// + /// # Example: + /// + /// ``` + /// # use arrow_buffer::builder::BufferBuilder; + /// + /// let mut builder = BufferBuilder::::new(10); + /// + /// assert!(builder.capacity() >= 10); + /// ``` + #[inline] + pub fn new(capacity: usize) -> Self { + let buffer = MutableBuffer::new(capacity * std::mem::size_of::()); + + Self { + buffer, + len: 0, + _marker: PhantomData, + } + } + + /// Creates a new builder from a [`MutableBuffer`] + pub fn new_from_buffer(buffer: MutableBuffer) -> Self { + let buffer_len = buffer.len(); + Self { + buffer, + len: buffer_len / std::mem::size_of::(), + _marker: PhantomData, + } + } + + /// Returns the current number of array elements in the internal buffer. + /// + /// # Example: + /// + /// ``` + /// # use arrow_buffer::builder::BufferBuilder; + /// + /// let mut builder = BufferBuilder::::new(10); + /// builder.append(42); + /// + /// assert_eq!(builder.len(), 1); + /// ``` + pub fn len(&self) -> usize { + self.len + } + + /// Returns whether the internal buffer is empty. 
+ /// + /// # Example: + /// + /// ``` + /// # use arrow_buffer::builder::BufferBuilder; + /// + /// let mut builder = BufferBuilder::::new(10); + /// builder.append(42); + /// + /// assert_eq!(builder.is_empty(), false); + /// ``` + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Returns the actual capacity (number of elements) of the internal buffer. + /// + /// Note: the internal capacity returned by this method might be larger than + /// what you'd expect after setting the capacity in the `new()` or `reserve()` + /// functions. + pub fn capacity(&self) -> usize { + let byte_capacity = self.buffer.capacity(); + byte_capacity / std::mem::size_of::() + } + + /// Increases the number of elements in the internal buffer by `n` + /// and resizes the buffer as needed. + /// + /// The values of the newly added elements are 0. + /// This method is usually used when appending `NULL` values to the buffer + /// as they still require physical memory space. + /// + /// # Example: + /// + /// ``` + /// # use arrow_buffer::builder::BufferBuilder; + /// + /// let mut builder = BufferBuilder::::new(10); + /// builder.advance(2); + /// + /// assert_eq!(builder.len(), 2); + /// ``` + #[inline] + pub fn advance(&mut self, i: usize) { + self.buffer.extend_zeros(i * std::mem::size_of::()); + self.len += i; + } + + /// Reserves memory for _at least_ `n` more elements of type `T`. + /// + /// # Example: + /// + /// ``` + /// # use arrow_buffer::builder::BufferBuilder; + /// + /// let mut builder = BufferBuilder::::new(10); + /// builder.reserve(10); + /// + /// assert!(builder.capacity() >= 20); + /// ``` + #[inline] + pub fn reserve(&mut self, n: usize) { + self.buffer.reserve(n * std::mem::size_of::()); + } + + /// Appends a value of type `T` into the builder, + /// growing the internal buffer as needed. + /// + /// # Example: + /// + /// ``` + /// # use arrow_buffer::builder::BufferBuilder; + /// + /// let mut builder = BufferBuilder::::new(10); + /// builder.append(42); + /// + /// assert_eq!(builder.len(), 1); + /// ``` + #[inline] + pub fn append(&mut self, v: T) { + self.reserve(1); + self.buffer.push(v); + self.len += 1; + } + + /// Appends a value of type `T` into the builder N times, + /// growing the internal buffer as needed. + /// + /// # Example: + /// + /// ``` + /// # use arrow_buffer::builder::BufferBuilder; + /// + /// let mut builder = BufferBuilder::::new(10); + /// builder.append_n(10, 42); + /// + /// assert_eq!(builder.len(), 10); + /// ``` + #[inline] + pub fn append_n(&mut self, n: usize, v: T) { + self.reserve(n); + self.extend(iter::repeat(v).take(n)) + } + + /// Appends `n`, zero-initialized values + /// + /// # Example: + /// + /// ``` + /// # use arrow_buffer::builder::BufferBuilder; + /// + /// let mut builder = BufferBuilder::::new(10); + /// builder.append_n_zeroed(3); + /// + /// assert_eq!(builder.len(), 3); + /// assert_eq!(builder.as_slice(), &[0, 0, 0]) + #[inline] + pub fn append_n_zeroed(&mut self, n: usize) { + self.buffer.extend_zeros(n * std::mem::size_of::()); + self.len += n; + } + + /// Appends a slice of type `T`, growing the internal buffer as needed. 
+ /// + /// # Example: + /// + /// ``` + /// # use arrow_buffer::builder::BufferBuilder; + /// + /// let mut builder = BufferBuilder::::new(10); + /// builder.append_slice(&[42, 44, 46]); + /// + /// assert_eq!(builder.len(), 3); + /// ``` + #[inline] + pub fn append_slice(&mut self, slice: &[T]) { + self.buffer.extend_from_slice(slice); + self.len += slice.len(); + } + + /// View the contents of this buffer as a slice + /// + /// ``` + /// # use arrow_buffer::builder::BufferBuilder; + /// + /// let mut builder = BufferBuilder::::new(10); + /// builder.append(1.3); + /// builder.append_n(2, 2.3); + /// + /// assert_eq!(builder.as_slice(), &[1.3, 2.3, 2.3]); + /// ``` + #[inline] + pub fn as_slice(&self) -> &[T] { + // SAFETY + // + // - MutableBuffer is aligned and initialized for len elements of T + // - MutableBuffer corresponds to a single allocation + // - MutableBuffer does not support modification whilst active immutable borrows + unsafe { std::slice::from_raw_parts(self.buffer.as_ptr() as _, self.len) } + } + + /// View the contents of this buffer as a mutable slice + /// + /// # Example: + /// + /// ``` + /// # use arrow_buffer::builder::BufferBuilder; + /// + /// let mut builder = BufferBuilder::::new(10); + /// + /// builder.append_slice(&[1., 2., 3.4]); + /// assert_eq!(builder.as_slice(), &[1., 2., 3.4]); + /// + /// builder.as_slice_mut()[1] = 4.2; + /// assert_eq!(builder.as_slice(), &[1., 4.2, 3.4]); + /// ``` + #[inline] + pub fn as_slice_mut(&mut self) -> &mut [T] { + // SAFETY + // + // - MutableBuffer is aligned and initialized for len elements of T + // - MutableBuffer corresponds to a single allocation + // - MutableBuffer does not support modification whilst active immutable borrows + unsafe { std::slice::from_raw_parts_mut(self.buffer.as_mut_ptr() as _, self.len) } + } + + /// Shorten this BufferBuilder to `len` items + /// + /// If `len` is greater than the builder's current length, this has no effect + /// + /// # Example: + /// + /// ``` + /// # use arrow_buffer::builder::BufferBuilder; + /// + /// let mut builder = BufferBuilder::::new(10); + /// + /// builder.append_slice(&[42, 44, 46]); + /// assert_eq!(builder.as_slice(), &[42, 44, 46]); + /// + /// builder.truncate(2); + /// assert_eq!(builder.as_slice(), &[42, 44]); + /// + /// builder.append(12); + /// assert_eq!(builder.as_slice(), &[42, 44, 12]); + /// ``` + #[inline] + pub fn truncate(&mut self, len: usize) { + self.buffer.truncate(len * std::mem::size_of::()); + self.len = len; + } + + /// # Safety + /// This requires the iterator be a trusted length. This could instead require + /// the iterator implement `TrustedLen` once that is stabilized. + #[inline] + pub unsafe fn append_trusted_len_iter(&mut self, iter: impl IntoIterator) { + let iter = iter.into_iter(); + let len = iter + .size_hint() + .1 + .expect("append_trusted_len_iter expects upper bound"); + self.reserve(len); + self.extend(iter); + } + + /// Resets this builder and returns an immutable [Buffer]. 
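+    /// The builder is left empty and can be reused; the returned [Buffer] takes
+    /// ownership of the accumulated bytes.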
+ /// + /// # Example: + /// + /// ``` + /// # use arrow_buffer::builder::BufferBuilder; + /// + /// let mut builder = BufferBuilder::::new(10); + /// builder.append_slice(&[42, 44, 46]); + /// + /// let buffer = builder.finish(); + /// + /// assert_eq!(unsafe { buffer.typed_data::() }, &[42, 44, 46]); + /// ``` + #[inline] + pub fn finish(&mut self) -> Buffer { + let buf = std::mem::take(&mut self.buffer); + self.len = 0; + buf.into() + } +} + +impl Default for BufferBuilder { + fn default() -> Self { + Self::new(0) + } +} + +impl Extend for BufferBuilder { + fn extend>(&mut self, iter: I) { + self.buffer.extend(iter.into_iter().inspect(|_| { + self.len += 1; + })) + } +} + +impl From> for BufferBuilder { + fn from(value: Vec) -> Self { + Self::new_from_buffer(MutableBuffer::from(value)) + } +} + +impl FromIterator for BufferBuilder { + fn from_iter>(iter: I) -> Self { + let mut builder = Self::default(); + builder.extend(iter); + builder + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::mem; + + #[test] + fn default() { + let builder = BufferBuilder::::default(); + assert!(builder.is_empty()); + assert!(builder.buffer.is_empty()); + assert_eq!(builder.buffer.capacity(), 0); + } + + #[test] + fn from_iter() { + let input = [1u16, 2, 3, 4]; + let builder = input.into_iter().collect::>(); + assert_eq!(builder.len(), 4); + assert_eq!(builder.buffer.len(), 4 * mem::size_of::()); + } + + #[test] + fn extend() { + let input = [1, 2]; + let mut builder = input.into_iter().collect::>(); + assert_eq!(builder.len(), 2); + builder.extend([3, 4]); + assert_eq!(builder.len(), 4); + } +} diff --git a/arrow-buffer/src/bytes.rs b/arrow-buffer/src/bytes.rs index b3105ed5a3b4..8f5019d5a4cc 100644 --- a/arrow-buffer/src/bytes.rs +++ b/arrow-buffer/src/bytes.rs @@ -148,3 +148,31 @@ impl Debug for Bytes { write!(f, " }}") } } + +impl From for Bytes { + fn from(value: bytes::Bytes) -> Self { + Self { + len: value.len(), + ptr: NonNull::new(value.as_ptr() as _).unwrap(), + deallocation: Deallocation::Custom(std::sync::Arc::new(value)), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_bytes() { + let bytes = bytes::Bytes::from(vec![1, 2, 3, 4]); + let arrow_bytes: Bytes = bytes.clone().into(); + + assert_eq!(bytes.as_ptr(), arrow_bytes.as_ptr()); + + drop(bytes); + drop(arrow_bytes); + + let _ = Bytes::from(bytes::Bytes::new()); + } +} diff --git a/arrow-buffer/src/native.rs b/arrow-buffer/src/native.rs index 8fe6cf2b7894..38074a8dc26c 100644 --- a/arrow-buffer/src/native.rs +++ b/arrow-buffer/src/native.rs @@ -222,7 +222,7 @@ pub trait ToByteSlice { impl ToByteSlice for [T] { #[inline] fn to_byte_slice(&self) -> &[u8] { - let raw_ptr = self.as_ptr() as *const T as *const u8; + let raw_ptr = self.as_ptr() as *const u8; unsafe { std::slice::from_raw_parts(raw_ptr, std::mem::size_of_val(self)) } } } diff --git a/arrow-buffer/src/util/bit_chunk_iterator.rs b/arrow-buffer/src/util/bit_chunk_iterator.rs index 3d9632e73229..6830acae94a1 100644 --- a/arrow-buffer/src/util/bit_chunk_iterator.rs +++ b/arrow-buffer/src/util/bit_chunk_iterator.rs @@ -157,7 +157,7 @@ impl<'a> UnalignedBitChunk<'a> { self.prefix .into_iter() .chain(self.chunks.iter().cloned()) - .chain(self.suffix.into_iter()) + .chain(self.suffix) } /// Counts the number of ones diff --git a/arrow-cast/Cargo.toml b/arrow-cast/Cargo.toml index 494ad104b11c..2e0a9fdd4ebd 100644 --- a/arrow-cast/Cargo.toml +++ b/arrow-cast/Cargo.toml @@ -45,7 +45,7 @@ arrow-buffer = { workspace = true } arrow-data = { workspace = 
true } arrow-schema = { workspace = true } arrow-select = { workspace = true } -chrono = { version = "0.4.23", default-features = false, features = ["clock"] } +chrono = { workspace = true } half = { version = "2.1", default-features = false } num = { version = "0.4", default-features = false, features = ["std"] } lexical-core = { version = "^0.8", default-features = false, features = ["write-integers", "write-floats", "parse-integers", "parse-floats"] } @@ -65,6 +65,10 @@ harness = false name = "parse_time" harness = false +[[bench]] +name = "parse_date" +harness = false + [[bench]] name = "parse_decimal" harness = false diff --git a/arrow-cast/benches/parse_date.rs b/arrow-cast/benches/parse_date.rs new file mode 100644 index 000000000000..e05d38d2f853 --- /dev/null +++ b/arrow-cast/benches/parse_date.rs @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::types::Date32Type; +use arrow_cast::parse::Parser; +use criterion::*; + +fn criterion_benchmark(c: &mut Criterion) { + let timestamps = ["2020-09-08", "2020-9-8", "2020-09-8", "2020-9-08"]; + + for timestamp in timestamps { + let t = black_box(timestamp); + c.bench_function(t, |b| { + b.iter(|| Date32Type::parse(t).unwrap()); + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index 95c0a63a3a4e..54c500f1ac41 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -37,19 +37,19 @@ //! assert_eq!(7.0, c.value(2)); //! 
``` -use chrono::{NaiveTime, Offset, TimeZone, Timelike, Utc}; +use chrono::{NaiveTime, Offset, TimeZone, Utc}; use std::cmp::Ordering; use std::sync::Arc; -use crate::display::{array_value_to_string, ArrayFormatter, FormatOptions}; +use crate::display::{ArrayFormatter, FormatOptions}; use crate::parse::{ parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month, - string_to_datetime, + string_to_datetime, Parser, }; use arrow_array::{ builder::*, cast::*, temporal_conversions::*, timezone::Tz, types::*, *, }; -use arrow_buffer::{i256, ArrowNativeType, Buffer, MutableBuffer, ScalarBuffer}; +use arrow_buffer::{i256, ArrowNativeType, Buffer, OffsetBuffer}; use arrow_data::ArrayData; use arrow_schema::*; use arrow_select::take::take; @@ -364,21 +364,32 @@ where if cast_options.safe { array - .unary_opt::<_, Decimal128Type>(|v| (mul * v.as_()).round().to_i128()) + .unary_opt::<_, Decimal128Type>(|v| { + (mul * v.as_()).round().to_i128().filter(|v| { + Decimal128Type::validate_decimal_precision(*v, precision).is_ok() + }) + }) .with_precision_and_scale(precision, scale) .map(|a| Arc::new(a) as ArrayRef) } else { array .try_unary::<_, Decimal128Type, _>(|v| { - (mul * v.as_()).round().to_i128().ok_or_else(|| { - ArrowError::CastError(format!( - "Cannot cast to {}({}, {}). Overflowing on {:?}", - Decimal128Type::PREFIX, - precision, - scale, - v - )) - }) + (mul * v.as_()) + .round() + .to_i128() + .ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast to {}({}, {}). Overflowing on {:?}", + Decimal128Type::PREFIX, + precision, + scale, + v + )) + }) + .and_then(|v| { + Decimal128Type::validate_decimal_precision(v, precision) + .map(|_| v) + }) })? .with_precision_and_scale(precision, scale) .map(|a| Arc::new(a) as ArrayRef) @@ -398,21 +409,30 @@ where if cast_options.safe { array - .unary_opt::<_, Decimal256Type>(|v| i256::from_f64((v.as_() * mul).round())) + .unary_opt::<_, Decimal256Type>(|v| { + i256::from_f64((v.as_() * mul).round()).filter(|v| { + Decimal256Type::validate_decimal_precision(*v, precision).is_ok() + }) + }) .with_precision_and_scale(precision, scale) .map(|a| Arc::new(a) as ArrayRef) } else { array .try_unary::<_, Decimal256Type, _>(|v| { - i256::from_f64((v.as_() * mul).round()).ok_or_else(|| { - ArrowError::CastError(format!( - "Cannot cast to {}({}, {}). Overflowing on {:?}", - Decimal256Type::PREFIX, - precision, - scale, - v - )) - }) + i256::from_f64((v.as_() * mul).round()) + .ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast to {}({}, {}). Overflowing on {:?}", + Decimal256Type::PREFIX, + precision, + scale, + v + )) + }) + .and_then(|v| { + Decimal256Type::validate_decimal_precision(v, precision) + .map(|_| v) + }) })? 
.with_precision_and_scale(precision, scale) .map(|a| Arc::new(a) as ArrayRef) @@ -447,20 +467,11 @@ fn cast_interval_day_time_to_interval_month_day_nano( } /// Cast the array from interval to duration -fn cast_interval_to_duration>( +fn cast_month_day_nano_to_duration>( array: &dyn Array, cast_options: &CastOptions, ) -> Result { - let array = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - ArrowError::ComputeError( - "Internal Error: Cannot cast interval to IntervalArray of expected type" - .to_string(), - ) - })?; - + let array = array.as_primitive::(); let scale = match D::DATA_TYPE { DataType::Duration(TimeUnit::Second) => 1_000_000_000, DataType::Duration(TimeUnit::Millisecond) => 1_000_000, @@ -470,16 +481,9 @@ fn cast_interval_to_duration>( }; if cast_options.safe { - let iter = array.iter().map(|v| { - v.and_then(|v| { - let v = v / scale; - if v > i64::MAX as i128 { - None - } else { - Some(v as i64) - } - }) - }); + let iter = array + .iter() + .map(|v| v.and_then(|v| (v >> 64 == 0).then_some((v as i64) / scale))); Ok(Arc::new(unsafe { PrimitiveArray::::from_trusted_len_iter(iter) })) @@ -487,17 +491,9 @@ fn cast_interval_to_duration>( let vec = array .iter() .map(|v| { - v.map(|v| { - let v = v / scale; - if v > i64::MAX as i128 { - Err(ArrowError::ComputeError(format!( - "Cannot cast to {:?}. Overflowing on {:?}", - D::DATA_TYPE, - v - ))) - } else { - Ok(v as i64) - } + v.map(|v| match v >> 64 { + 0 => Ok((v as i64) / scale), + _ => Err(ArrowError::ComputeError("Cannot convert interval containing non-zero months or days to duration".to_string())) }) .transpose() }) @@ -646,21 +642,6 @@ where Ok(Arc::new(array)) } -// cast the List array to Utf8 array -macro_rules! cast_list_to_string { - ($ARRAY:expr, $SIZE:ident) => {{ - let mut value_builder: GenericStringBuilder<$SIZE> = GenericStringBuilder::new(); - for i in 0..$ARRAY.len() { - if $ARRAY.is_null(i) { - value_builder.append_null(); - } else { - value_builder.append_value(array_value_to_string($ARRAY, i)?); - } - } - Ok(Arc::new(value_builder.finish())) - }}; -} - fn make_timestamp_array( array: &PrimitiveArray, unit: TimeUnit, @@ -824,8 +805,8 @@ pub fn cast_with_options( } } (List(_) | LargeList(_), _) => match to_type { - Utf8 => cast_list_to_string!(array, i32), - LargeUtf8 => cast_list_to_string!(array, i64), + Utf8 => value_to_string::(array, cast_options), + LargeUtf8 => value_to_string::(array, cast_options), _ => Err(ArrowError::CastError( "Cannot cast list to non-list data types".to_string(), )), @@ -849,12 +830,8 @@ pub fn cast_with_options( } } - (_, List(ref to)) => { - cast_primitive_to_list::(array, to, to_type, cast_options) - } - (_, LargeList(ref to)) => { - cast_primitive_to_list::(array, to, to_type, cast_options) - } + (_, List(ref to)) => cast_values_to_list::(array, to, cast_options), + (_, LargeList(ref to)) => cast_values_to_list::(array, to, cast_options), (Decimal128(_, s1), Decimal128(p2, s2)) => { cast_decimal_to_decimal_same_type::( array.as_primitive(), @@ -952,8 +929,8 @@ pub fn cast_with_options( x as f64 / 10_f64.powi(*scale as i32) }) } - Utf8 => value_to_string::(array, Some(&cast_options.format_options)), - LargeUtf8 => value_to_string::(array, Some(&cast_options.format_options)), + Utf8 => value_to_string::(array, cast_options), + LargeUtf8 => value_to_string::(array, cast_options), Null => Ok(new_null_array(to_type, array.len())), _ => Err(ArrowError::CastError(format!( "Casting from {from_type:?} to {to_type:?} not supported" @@ -1021,8 +998,8 @@ pub fn cast_with_options( 
x.to_f64().unwrap() / 10_f64.powi(*scale as i32) }) } - Utf8 => value_to_string::(array, Some(&cast_options.format_options)), - LargeUtf8 => value_to_string::(array, Some(&cast_options.format_options)), + Utf8 => value_to_string::(array, cast_options), + LargeUtf8 => value_to_string::(array, cast_options), Null => Ok(new_null_array(to_type, array.len())), _ => Err(ArrowError::CastError(format!( "Casting from {from_type:?} to {to_type:?} not supported" @@ -1243,59 +1220,35 @@ pub fn cast_with_options( Float16 => cast_bool_to_numeric::(array, cast_options), Float32 => cast_bool_to_numeric::(array, cast_options), Float64 => cast_bool_to_numeric::(array, cast_options), - Utf8 => { - let array = array.as_any().downcast_ref::().unwrap(); - Ok(Arc::new( - array - .iter() - .map(|value| value.map(|value| if value { "1" } else { "0" })) - .collect::(), - )) - } - LargeUtf8 => { - let array = array.as_any().downcast_ref::().unwrap(); - Ok(Arc::new( - array - .iter() - .map(|value| value.map(|value| if value { "1" } else { "0" })) - .collect::(), - )) - } + Utf8 => value_to_string::(array, cast_options), + LargeUtf8 => value_to_string::(array, cast_options), _ => Err(ArrowError::CastError(format!( "Casting from {from_type:?} to {to_type:?} not supported", ))), }, (Utf8, _) => match to_type { - UInt8 => cast_string_to_numeric::(array, cast_options), - UInt16 => cast_string_to_numeric::(array, cast_options), - UInt32 => cast_string_to_numeric::(array, cast_options), - UInt64 => cast_string_to_numeric::(array, cast_options), - Int8 => cast_string_to_numeric::(array, cast_options), - Int16 => cast_string_to_numeric::(array, cast_options), - Int32 => cast_string_to_numeric::(array, cast_options), - Int64 => cast_string_to_numeric::(array, cast_options), - Float32 => cast_string_to_numeric::(array, cast_options), - Float64 => cast_string_to_numeric::(array, cast_options), - Date32 => cast_string_to_date32::(array, cast_options), - Date64 => cast_string_to_date64::(array, cast_options), + UInt8 => parse_string::(array, cast_options), + UInt16 => parse_string::(array, cast_options), + UInt32 => parse_string::(array, cast_options), + UInt64 => parse_string::(array, cast_options), + Int8 => parse_string::(array, cast_options), + Int16 => parse_string::(array, cast_options), + Int32 => parse_string::(array, cast_options), + Int64 => parse_string::(array, cast_options), + Float32 => parse_string::(array, cast_options), + Float64 => parse_string::(array, cast_options), + Date32 => parse_string::(array, cast_options), + Date64 => parse_string::(array, cast_options), Binary => Ok(Arc::new(BinaryArray::from(array.as_string::().clone()))), LargeBinary => { let binary = BinaryArray::from(array.as_string::().clone()); cast_byte_container::(&binary) } LargeUtf8 => cast_byte_container::(array), - Time32(TimeUnit::Second) => { - cast_string_to_time32second::(array, cast_options) - } - Time32(TimeUnit::Millisecond) => { - cast_string_to_time32millisecond::(array, cast_options) - } - Time64(TimeUnit::Microsecond) => { - cast_string_to_time64microsecond::(array, cast_options) - } - Time64(TimeUnit::Nanosecond) => { - cast_string_to_time64nanosecond::(array, cast_options) - } + Time32(TimeUnit::Second) => parse_string::(array, cast_options), + Time32(TimeUnit::Millisecond) => parse_string::(array, cast_options), + Time64(TimeUnit::Microsecond) => parse_string::(array, cast_options), + Time64(TimeUnit::Nanosecond) => parse_string::(array, cast_options), Timestamp(TimeUnit::Second, to_tz) => { cast_string_to_timestamp::(array, 
to_tz, cast_options) } @@ -1322,18 +1275,18 @@ pub fn cast_with_options( ))), }, (LargeUtf8, _) => match to_type { - UInt8 => cast_string_to_numeric::(array, cast_options), - UInt16 => cast_string_to_numeric::(array, cast_options), - UInt32 => cast_string_to_numeric::(array, cast_options), - UInt64 => cast_string_to_numeric::(array, cast_options), - Int8 => cast_string_to_numeric::(array, cast_options), - Int16 => cast_string_to_numeric::(array, cast_options), - Int32 => cast_string_to_numeric::(array, cast_options), - Int64 => cast_string_to_numeric::(array, cast_options), - Float32 => cast_string_to_numeric::(array, cast_options), - Float64 => cast_string_to_numeric::(array, cast_options), - Date32 => cast_string_to_date32::(array, cast_options), - Date64 => cast_string_to_date64::(array, cast_options), + UInt8 => parse_string::(array, cast_options), + UInt16 => parse_string::(array, cast_options), + UInt32 => parse_string::(array, cast_options), + UInt64 => parse_string::(array, cast_options), + Int8 => parse_string::(array, cast_options), + Int16 => parse_string::(array, cast_options), + Int32 => parse_string::(array, cast_options), + Int64 => parse_string::(array, cast_options), + Float32 => parse_string::(array, cast_options), + Float64 => parse_string::(array, cast_options), + Date32 => parse_string::(array, cast_options), + Date64 => parse_string::(array, cast_options), Utf8 => cast_byte_container::(array), Binary => { let large_binary = @@ -1343,18 +1296,10 @@ pub fn cast_with_options( LargeBinary => Ok(Arc::new(LargeBinaryArray::from( array.as_string::().clone(), ))), - Time32(TimeUnit::Second) => { - cast_string_to_time32second::(array, cast_options) - } - Time32(TimeUnit::Millisecond) => { - cast_string_to_time32millisecond::(array, cast_options) - } - Time64(TimeUnit::Microsecond) => { - cast_string_to_time64microsecond::(array, cast_options) - } - Time64(TimeUnit::Nanosecond) => { - cast_string_to_time64nanosecond::(array, cast_options) - } + Time32(TimeUnit::Second) => parse_string::(array, cast_options), + Time32(TimeUnit::Millisecond) => parse_string::(array, cast_options), + Time64(TimeUnit::Microsecond) => parse_string::(array, cast_options), + Time64(TimeUnit::Nanosecond) => parse_string::(array, cast_options), Timestamp(TimeUnit::Second, to_tz) => { cast_string_to_timestamp::(array, to_tz, cast_options) } @@ -1418,8 +1363,8 @@ pub fn cast_with_options( "Casting from {from_type:?} to {to_type:?} not supported", ))), }, - (from_type, LargeUtf8) if from_type.is_primitive() => value_to_string::(array, Some(&cast_options.format_options)), - (from_type, Utf8) if from_type.is_primitive() => value_to_string::(array, Some(&cast_options.format_options)), + (from_type, LargeUtf8) if from_type.is_primitive() => value_to_string::(array, cast_options), + (from_type, Utf8) if from_type.is_primitive() => value_to_string::(array, cast_options), // start numeric casts (UInt8, UInt16) => { cast_numeric_arrays::(array, cast_options) @@ -2194,16 +2139,16 @@ pub fn cast_with_options( cast_duration_to_interval::(array, cast_options) } (Interval(IntervalUnit::MonthDayNano), Duration(TimeUnit::Second)) => { - cast_interval_to_duration::(array, cast_options) + cast_month_day_nano_to_duration::(array, cast_options) } (Interval(IntervalUnit::MonthDayNano), Duration(TimeUnit::Millisecond)) => { - cast_interval_to_duration::(array, cast_options) + cast_month_day_nano_to_duration::(array, cast_options) } (Interval(IntervalUnit::MonthDayNano), Duration(TimeUnit::Microsecond)) => { - 
cast_interval_to_duration::(array, cast_options) + cast_month_day_nano_to_duration::(array, cast_options) } (Interval(IntervalUnit::MonthDayNano), Duration(TimeUnit::Nanosecond)) => { - cast_interval_to_duration::(array, cast_options) + cast_month_day_nano_to_duration::(array, cast_options) } (Interval(IntervalUnit::YearMonth), Interval(IntervalUnit::MonthDayNano)) => { cast_interval_year_month_to_interval_month_day_nano(array, cast_options) @@ -2505,14 +2450,10 @@ where fn value_to_string( array: &dyn Array, - options: Option<&FormatOptions>, + options: &CastOptions, ) -> Result { let mut builder = GenericStringBuilder::::new(); - let mut fmt_options = &FormatOptions::default(); - if let Some(fmt_opts) = options { - fmt_options = fmt_opts; - }; - let formatter = ArrayFormatter::try_new(array, fmt_options)?; + let formatter = ArrayFormatter::try_new(array, &options.format_options)?; let nulls = array.nulls(); for i in 0..array.len() { match nulls.map(|x| x.is_null(i)).unwrap_or_default() { @@ -2527,422 +2468,35 @@ fn value_to_string( Ok(Arc::new(builder.finish())) } -/// Cast numeric types to Utf8 -fn cast_string_to_numeric( - from: &dyn Array, - cast_options: &CastOptions, -) -> Result -where - T: ArrowPrimitiveType, - ::Native: lexical_core::FromLexical, -{ - Ok(Arc::new(string_to_numeric_cast::( - from.as_any() - .downcast_ref::>() - .unwrap(), - cast_options, - )?)) -} - -fn string_to_numeric_cast( - from: &GenericStringArray, - cast_options: &CastOptions, -) -> Result, ArrowError> -where - T: ArrowPrimitiveType, - ::Native: lexical_core::FromLexical, -{ - if cast_options.safe { - let iter = from - .iter() - .map(|v| v.and_then(|v| lexical_core::parse(v.as_bytes()).ok())); - // Benefit: - // 20% performance improvement - // Soundness: - // The iterator is trustedLen because it comes from an `StringArray`. - Ok(unsafe { PrimitiveArray::::from_trusted_len_iter(iter) }) - } else { - let vec = from - .iter() - .map(|v| { - v.map(|v| { - lexical_core::parse(v.as_bytes()).map_err(|_| { - ArrowError::CastError(format!( - "Cannot cast string '{}' to value of {:?} type", - v, - T::DATA_TYPE, - )) - }) - }) - .transpose() - }) - .collect::, _>>()?; - // Benefit: - // 20% performance improvement - // Soundness: - // The iterator is trustedLen because it comes from an `StringArray`. - Ok(unsafe { PrimitiveArray::::from_trusted_len_iter(vec.iter()) }) - } -} - -/// Casts generic string arrays to Date32Array -fn cast_string_to_date32( - array: &dyn Array, - cast_options: &CastOptions, -) -> Result { - use chrono::Datelike; - let string_array = array - .as_any() - .downcast_ref::>() - .unwrap(); - - let array = if cast_options.safe { - let iter = string_array.iter().map(|v| { - v.and_then(|v| { - v.parse::() - .map(|date| date.num_days_from_ce() - EPOCH_DAYS_FROM_CE) - .ok() - }) - }); - - // Benefit: - // 20% performance improvement - // Soundness: - // The iterator is trustedLen because it comes from an `StringArray`. - unsafe { Date32Array::from_trusted_len_iter(iter) } - } else { - let vec = string_array - .iter() - .map(|v| { - v.map(|v| { - v.parse::() - .map(|date| date.num_days_from_ce() - EPOCH_DAYS_FROM_CE) - .map_err(|_| { - ArrowError::CastError(format!( - "Cannot cast string '{}' to value of {:?} type", - v, - DataType::Date32 - )) - }) - }) - .transpose() - }) - .collect::>, _>>()?; - - // Benefit: - // 20% performance improvement - // Soundness: - // The iterator is trustedLen because it comes from an `StringArray`. 
- unsafe { Date32Array::from_trusted_len_iter(vec.iter()) } - }; - - Ok(Arc::new(array) as ArrayRef) -} - -/// Casts generic string arrays to Date64Array -fn cast_string_to_date64( - array: &dyn Array, - cast_options: &CastOptions, -) -> Result { - let string_array = array - .as_any() - .downcast_ref::>() - .unwrap(); - - let array = if cast_options.safe { - let iter = string_array.iter().map(|v| { - v.and_then(|v| { - v.parse::() - .map(|datetime| datetime.timestamp_millis()) - .ok() - }) - }); - - // Benefit: - // 20% performance improvement - // Soundness: - // The iterator is trustedLen because it comes from an `StringArray`. - unsafe { Date64Array::from_trusted_len_iter(iter) } - } else { - let vec = string_array - .iter() - .map(|v| { - v.map(|v| { - v.parse::() - .map(|datetime| datetime.timestamp_millis()) - .map_err(|_| { - ArrowError::CastError(format!( - "Cannot cast string '{}' to value of {:?} type", - v, - DataType::Date64 - )) - }) - }) - .transpose() - }) - .collect::>, _>>()?; - - // Benefit: - // 20% performance improvement - // Soundness: - // The iterator is trustedLen because it comes from an `StringArray`. - unsafe { Date64Array::from_trusted_len_iter(vec.iter()) } - }; - - Ok(Arc::new(array) as ArrayRef) -} - -/// Casts generic string arrays to `Time32SecondArray` -fn cast_string_to_time32second( - array: &dyn Array, - cast_options: &CastOptions, -) -> Result { - /// The number of nanoseconds per millisecond. - const NANOS_PER_SEC: u32 = 1_000_000_000; - - let string_array = array - .as_any() - .downcast_ref::>() - .unwrap(); - - let array = if cast_options.safe { - let iter = string_array.iter().map(|v| { - v.and_then(|v| { - v.parse::() - .map(|time| { - (time.num_seconds_from_midnight() - + time.nanosecond() / NANOS_PER_SEC) - as i32 - }) - .ok() - }) - }); - - // Benefit: - // 20% performance improvement - // Soundness: - // The iterator is trustedLen because it comes from an `StringArray`. - unsafe { Time32SecondArray::from_trusted_len_iter(iter) } - } else { - let vec = string_array - .iter() - .map(|v| { - v.map(|v| { - v.parse::() - .map(|time| { - (time.num_seconds_from_midnight() - + time.nanosecond() / NANOS_PER_SEC) - as i32 - }) - .map_err(|_| { - ArrowError::CastError(format!( - "Cannot cast string '{}' to value of {:?} type", - v, - DataType::Time32(TimeUnit::Second) - )) - }) - }) - .transpose() - }) - .collect::>, _>>()?; - - // Benefit: - // 20% performance improvement - // Soundness: - // The iterator is trustedLen because it comes from an `StringArray`. - unsafe { Time32SecondArray::from_trusted_len_iter(vec.iter()) } - }; - - Ok(Arc::new(array) as ArrayRef) -} - -/// Casts generic string arrays to `Time32MillisecondArray` -fn cast_string_to_time32millisecond( - array: &dyn Array, - cast_options: &CastOptions, -) -> Result { - /// The number of nanoseconds per millisecond. - const NANOS_PER_MILLI: u32 = 1_000_000; - /// The number of milliseconds per second. - const MILLIS_PER_SEC: u32 = 1_000; - - let string_array = array - .as_any() - .downcast_ref::>() - .unwrap(); - - let array = if cast_options.safe { - let iter = string_array.iter().map(|v| { - v.and_then(|v| { - v.parse::() - .map(|time| { - (time.num_seconds_from_midnight() * MILLIS_PER_SEC - + time.nanosecond() / NANOS_PER_MILLI) - as i32 - }) - .ok() - }) - }); - - // Benefit: - // 20% performance improvement - // Soundness: - // The iterator is trustedLen because it comes from an `StringArray`. 
- unsafe { Time32MillisecondArray::from_trusted_len_iter(iter) } - } else { - let vec = string_array - .iter() - .map(|v| { - v.map(|v| { - v.parse::() - .map(|time| { - (time.num_seconds_from_midnight() * MILLIS_PER_SEC - + time.nanosecond() / NANOS_PER_MILLI) - as i32 - }) - .map_err(|_| { - ArrowError::CastError(format!( - "Cannot cast string '{}' to value of {:?} type", - v, - DataType::Time32(TimeUnit::Millisecond) - )) - }) - }) - .transpose() - }) - .collect::>, _>>()?; - - // Benefit: - // 20% performance improvement - // Soundness: - // The iterator is trustedLen because it comes from an `StringArray`. - unsafe { Time32MillisecondArray::from_trusted_len_iter(vec.iter()) } - }; - - Ok(Arc::new(array) as ArrayRef) -} - -/// Casts generic string arrays to `Time64MicrosecondArray` -fn cast_string_to_time64microsecond( +/// Parse UTF-8 +fn parse_string( array: &dyn Array, cast_options: &CastOptions, ) -> Result { - /// The number of nanoseconds per microsecond. - const NANOS_PER_MICRO: i64 = 1_000; - /// The number of microseconds per second. - const MICROS_PER_SEC: i64 = 1_000_000; - - let string_array = array - .as_any() - .downcast_ref::>() - .unwrap(); - + let string_array = array.as_string::(); let array = if cast_options.safe { - let iter = string_array.iter().map(|v| { - v.and_then(|v| { - v.parse::() - .map(|time| { - time.num_seconds_from_midnight() as i64 * MICROS_PER_SEC - + time.nanosecond() as i64 / NANOS_PER_MICRO - }) - .ok() - }) - }); + let iter = string_array.iter().map(|x| x.and_then(P::parse)); // Benefit: // 20% performance improvement // Soundness: // The iterator is trustedLen because it comes from an `StringArray`. - unsafe { Time64MicrosecondArray::from_trusted_len_iter(iter) } + unsafe { PrimitiveArray::
<P>
::from_trusted_len_iter(iter) } } else { - let vec = string_array + let v = string_array .iter() - .map(|v| { - v.map(|v| { - v.parse::() - .map(|time| { - time.num_seconds_from_midnight() as i64 * MICROS_PER_SEC - + time.nanosecond() as i64 / NANOS_PER_MICRO - }) - .map_err(|_| { - ArrowError::CastError(format!( - "Cannot cast string '{}' to value of {:?} type", - v, - DataType::Time64(TimeUnit::Microsecond) - )) - }) - }) - .transpose() - }) - .collect::>, _>>()?; - - // Benefit: - // 20% performance improvement - // Soundness: - // The iterator is trustedLen because it comes from an `StringArray`. - unsafe { Time64MicrosecondArray::from_trusted_len_iter(vec.iter()) } - }; - - Ok(Arc::new(array) as ArrayRef) -} - -/// Casts generic string arrays to `Time64NanosecondArray` -fn cast_string_to_time64nanosecond( - array: &dyn Array, - cast_options: &CastOptions, -) -> Result { - /// The number of nanoseconds per second. - const NANOS_PER_SEC: i64 = 1_000_000_000; - - let string_array = array - .as_any() - .downcast_ref::>() - .unwrap(); - - let array = if cast_options.safe { - let iter = string_array.iter().map(|v| { - v.and_then(|v| { - v.parse::() - .map(|time| { - time.num_seconds_from_midnight() as i64 * NANOS_PER_SEC - + time.nanosecond() as i64 - }) - .ok() - }) - }); - - // Benefit: - // 20% performance improvement - // Soundness: - // The iterator is trustedLen because it comes from an `StringArray`. - unsafe { Time64NanosecondArray::from_trusted_len_iter(iter) } - } else { - let vec = string_array - .iter() - .map(|v| { - v.map(|v| { - v.parse::() - .map(|time| { - time.num_seconds_from_midnight() as i64 * NANOS_PER_SEC - + time.nanosecond() as i64 - }) - .map_err(|_| { - ArrowError::CastError(format!( - "Cannot cast string '{}' to value of {:?} type", - v, - DataType::Time64(TimeUnit::Nanosecond) - )) - }) - }) - .transpose() + .map(|x| match x { + Some(v) => P::parse(v).ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast string '{}' to value of {:?} type", + v, + P::DATA_TYPE + )) + }), + None => Ok(P::Native::default()), }) - .collect::>, _>>()?; - - // Benefit: - // 20% performance improvement - // Soundness: - // The iterator is trustedLen because it comes from an `StringArray`. 
- unsafe { Time64NanosecondArray::from_trusted_len_iter(vec.iter()) } + .collect::, ArrowError>>()?; + PrimitiveArray::new(v.into(), string_array.nulls().cloned()) }; Ok(Arc::new(array) as ArrayRef) @@ -3267,6 +2821,11 @@ where if cast_options.safe { let iter = from.iter().map(|v| { v.and_then(|v| parse_string_to_decimal_native::(v, scale as usize).ok()) + .and_then(|v| { + T::validate_decimal_precision(v, precision) + .is_ok() + .then_some(v) + }) }); // Benefit: // 20% performance improvement @@ -3281,13 +2840,17 @@ where .iter() .map(|v| { v.map(|v| { - parse_string_to_decimal_native::(v, scale as usize).map_err(|_| { - ArrowError::CastError(format!( - "Cannot cast string '{}' to value of {:?} type", - v, - T::DATA_TYPE, - )) - }) + parse_string_to_decimal_native::(v, scale as usize) + .map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast string '{}' to value of {:?} type", + v, + T::DATA_TYPE, + )) + }) + .and_then(|v| { + T::validate_decimal_precision(v, precision).map(|_| v) + }) }) .transpose() }) @@ -3493,19 +3056,7 @@ where { let dict_array = array.as_dictionary::(); let cast_dict_values = cast_with_options(dict_array.values(), to_type, cast_options)?; - let keys = dict_array.keys(); - match K::DATA_TYPE { - DataType::Int32 => { - // Dictionary guarantees all non-null keys >= 0 - let buffer = ScalarBuffer::new(keys.values().inner().clone(), 0, keys.len()); - let indices = PrimitiveArray::new(buffer, keys.nulls().cloned()); - take::(cast_dict_values.as_ref(), &indices, None) - } - _ => { - let indices = cast_with_options(keys, &DataType::UInt32, cast_options)?; - take::(cast_dict_values.as_ref(), indices.as_primitive(), None) - } - } + take(cast_dict_values.as_ref(), dict_array.keys(), None) } /// Attempts to encode an array into an `ArrayDictionary` with index @@ -3645,39 +3196,15 @@ where } /// Helper function that takes a primitive array and casts to a (generic) list array. -fn cast_primitive_to_list( +fn cast_values_to_list( array: &dyn Array, - to: &Field, - to_type: &DataType, + to: &FieldRef, cast_options: &CastOptions, ) -> Result { - // cast primitive to list's primitive - let cast_array = cast_with_options(array, to.data_type(), cast_options)?; - // create offsets, where if array.len() = 2, we have [0,1,2] - // Safety: - // Length of range can be trusted. - // Note: could not yet create a generic range in stable Rust. - let offsets = unsafe { - MutableBuffer::from_trusted_len_iter( - (0..=array.len()).map(|i| OffsetSize::from(i).expect("integer")), - ) - }; - - let list_data = unsafe { - ArrayData::new_unchecked( - to_type.clone(), - array.len(), - Some(cast_array.null_count()), - cast_array.nulls().map(|b| b.inner().sliced()), - 0, - vec![offsets.into()], - vec![cast_array.into_data()], - ) - }; - let list_array = - Arc::new(GenericListArray::::from(list_data)) as ArrayRef; - - Ok(list_array) + let values = cast_with_options(array, to.data_type(), cast_options)?; + let offsets = OffsetBuffer::from_lengths(std::iter::repeat(1).take(values.len())); + let list = GenericListArray::::new(to.clone(), offsets, values, None); + Ok(Arc::new(list)) } /// Helper function that takes an Generic list container and casts the inner datatype. @@ -3930,50 +3457,24 @@ mod tests { macro_rules! 
generate_cast_test_case { ($INPUT_ARRAY: expr, $OUTPUT_TYPE_ARRAY: ident, $OUTPUT_TYPE: expr, $OUTPUT_VALUES: expr) => { + let output = $OUTPUT_TYPE_ARRAY::from($OUTPUT_VALUES) + .with_data_type($OUTPUT_TYPE.clone()); + // assert cast type let input_array_type = $INPUT_ARRAY.data_type(); assert!(can_cast_types(input_array_type, $OUTPUT_TYPE)); - let casted_array = cast($INPUT_ARRAY, $OUTPUT_TYPE).unwrap(); - let result_array = casted_array - .as_any() - .downcast_ref::<$OUTPUT_TYPE_ARRAY>() - .unwrap(); - assert_eq!($OUTPUT_TYPE, result_array.data_type()); - assert_eq!(result_array.len(), $OUTPUT_VALUES.len()); - for (i, x) in $OUTPUT_VALUES.iter().enumerate() { - match x { - Some(x) => { - assert!(!result_array.is_null(i)); - assert_eq!(result_array.value(i), *x); - } - None => { - assert!(result_array.is_null(i)); - } - } - } + let result = cast($INPUT_ARRAY, $OUTPUT_TYPE).unwrap(); + assert_eq!($OUTPUT_TYPE, result.data_type()); + assert_eq!(result.as_ref(), &output); let cast_option = CastOptions { safe: false, format_options: FormatOptions::default(), }; - let casted_array_with_option = + let result = cast_with_options($INPUT_ARRAY, $OUTPUT_TYPE, &cast_option).unwrap(); - let result_array = casted_array_with_option - .as_any() - .downcast_ref::<$OUTPUT_TYPE_ARRAY>() - .unwrap(); - assert_eq!($OUTPUT_TYPE, result_array.data_type()); - assert_eq!(result_array.len(), $OUTPUT_VALUES.len()); - for (i, x) in $OUTPUT_VALUES.iter().enumerate() { - match x { - Some(x) => { - assert_eq!(result_array.value(i), *x); - } - None => { - assert!(result_array.is_null(i)); - } - } - } + assert_eq!($OUTPUT_TYPE, result.data_type()); + assert_eq!(result.as_ref(), &output); }; } @@ -5098,7 +4599,7 @@ mod tests { ) .unwrap(); assert_eq!(5, b.len()); - assert_eq!(1, b.null_count()); + assert_eq!(0, b.null_count()); let arr = b.as_list::(); assert_eq!(&[0, 1, 2, 3, 4, 5], arr.value_offsets()); assert_eq!(1, arr.value_length(0)); @@ -5127,7 +4628,7 @@ mod tests { ) .unwrap(); assert_eq!(4, b.len()); - assert_eq!(1, b.null_count()); + assert_eq!(0, b.null_count()); let arr = b.as_list::(); assert_eq!(&[0, 1, 2, 3, 4], arr.value_offsets()); assert_eq!(1, arr.value_length(0)); @@ -5218,6 +4719,26 @@ mod tests { assert!(!c.is_valid(2)); } + #[test] + fn test_cast_bool_to_utf8() { + let array = BooleanArray::from(vec![Some(true), Some(false), None]); + let b = cast(&array, &DataType::Utf8).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!("true", c.value(0)); + assert_eq!("false", c.value(1)); + assert!(!c.is_valid(2)); + } + + #[test] + fn test_cast_bool_to_large_utf8() { + let array = BooleanArray::from(vec![Some(true), Some(false), None]); + let b = cast(&array, &DataType::LargeUtf8).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!("true", c.value(0)); + assert_eq!("false", c.value(1)); + assert!(!c.is_valid(2)); + } + #[test] fn test_cast_bool_to_f64() { let array = BooleanArray::from(vec![Some(true), Some(false), None]); @@ -6479,7 +6000,7 @@ mod tests { #[test] fn test_str_to_str_casts() { - for data in vec![ + for data in [ vec![Some("foo"), Some("bar"), Some("ham")], vec![Some("foo"), None, Some("bar")], ] { @@ -7703,9 +7224,8 @@ mod tests { assert_eq!(array.data_type(), &data_type); let cast_array = cast(&array, &DataType::Null).expect("cast failed"); assert_eq!(cast_array.data_type(), &DataType::Null); - for i in 0..4 { - assert!(cast_array.is_null(i)); - } + assert_eq!(cast_array.len(), 4); + assert_eq!(cast_array.logical_nulls().unwrap().null_count(), 4); } 
#[test] @@ -7804,14 +7324,10 @@ mod tests { /// Print the `DictionaryArray` `array` as a vector of strings fn array_to_strings(array: &ArrayRef) -> Vec { + let options = FormatOptions::new().with_null("null"); + let formatter = ArrayFormatter::try_new(array.as_ref(), &options).unwrap(); (0..array.len()) - .map(|i| { - if array.is_null(i) { - "null".to_string() - } else { - array_value_to_string(array, i).expect("Convert array to String") - } - }) + .map(|i| formatter.value(i).to_string()) .collect() } @@ -7874,13 +7390,13 @@ mod tests { assert_eq!(946728000000, c.value(0)); assert!(c.is_valid(1)); // "2020-12-15T12:34:56" assert_eq!(1608035696000, c.value(1)); - assert!(c.is_valid(2)); // "2020-2-2T12:34:56" - assert_eq!(1580646896000, c.value(2)); + assert!(!c.is_valid(2)); // "2020-2-2T12:34:56" - // test invalid inputs assert!(!c.is_valid(3)); // "2000-00-00T12:00:00" - assert!(!c.is_valid(4)); // "2000-01-01 12:00:00" - assert!(!c.is_valid(5)); // "2000-01-01" + assert!(c.is_valid(4)); // "2000-01-01 12:00:00" + assert_eq!(946728000000, c.value(4)); + assert!(c.is_valid(5)); // "2000-01-01" + assert_eq!(946684800000, c.value(5)); } #[test] @@ -8252,6 +7768,68 @@ mod tests { assert!(casted_array.is_err()); } + #[test] + fn test_cast_floating_point_to_decimal128_precision_overflow() { + let array = Float64Array::from(vec![1.1]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(2, 2), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(2, 2), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + let err = casted_array.unwrap_err().to_string(); + let expected_error = "Invalid argument error: 110 is too large to store in a Decimal128 of precision 2. Max is 99"; + assert!( + err.contains(expected_error), + "did not find expected error '{expected_error}' in actual error '{err}'" + ); + } + + #[test] + fn test_cast_floating_point_to_decimal256_precision_overflow() { + let array = Float64Array::from(vec![1.1]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(2, 2), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(2, 2), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + let err = casted_array.unwrap_err().to_string(); + let expected_error = "Invalid argument error: 110 is too large to store in a Decimal256 of precision 2. 
Max is 99"; + assert!( + err.contains(expected_error), + "did not find expected error '{expected_error}' in actual error '{err}'" + ); + } + #[test] fn test_cast_floating_point_to_decimal128_overflow() { let array = Float64Array::from(vec![f64::MAX]); @@ -8665,6 +8243,32 @@ mod tests { ); } + #[test] + fn test_cast_string_to_decimal128_precision_overflow() { + let array = StringArray::from(vec!["1000".to_string()]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(10, 8), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let err = cast_with_options( + &array, + &DataType::Decimal128(10, 8), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal128 of precision 10. Max is 9999999999", err.unwrap_err().to_string()); + } + #[test] fn test_cast_utf8_to_decimal128_overflow() { let overflow_str_array = StringArray::from(vec![ @@ -8722,6 +8326,32 @@ mod tests { assert!(decimal_arr.is_null(6)); } + #[test] + fn test_cast_string_to_decimal256_precision_overflow() { + let array = StringArray::from(vec!["1000".to_string()]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(10, 8), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let err = cast_with_options( + &array, + &DataType::Decimal256(10, 8), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal256 of precision 10. 
Max is 9999999999", err.unwrap_err().to_string()); + } + #[test] fn test_cast_utf8_to_decimal256_overflow() { let overflow_str_array = StringArray::from(vec![ @@ -9055,29 +8685,16 @@ mod tests { } /// helper function to test casting from duration to interval - fn cast_from_duration_to_interval( + fn cast_from_duration_to_interval>( array: Vec, cast_options: &CastOptions, - ) -> Result, ArrowError> - where - arrow_array::PrimitiveArray: From>, - { - let array = PrimitiveArray::::from(array); + ) -> Result, ArrowError> { + let array = PrimitiveArray::::new(array.into(), None); let array = Arc::new(array) as ArrayRef; - let casted_array = cast_with_options( - &array, - &DataType::Interval(IntervalUnit::MonthDayNano), - cast_options, - )?; - casted_array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - ArrowError::ComputeError( - "Failed to downcast to IntervalMonthDayNanoArray".to_string(), - ) - }) - .cloned() + let interval = DataType::Interval(IntervalUnit::MonthDayNano); + let out = cast_with_options(&array, &interval, cast_options)?; + let out = out.as_primitive::().clone(); + Ok(out) } #[test] @@ -9199,11 +8816,9 @@ mod tests { /// helper function to test casting from interval to duration fn cast_from_interval_to_duration( - array: Vec, + array: &IntervalMonthDayNanoArray, cast_options: &CastOptions, ) -> Result, ArrowError> { - let array = IntervalMonthDayNanoArray::from(array); - let array = Arc::new(array) as ArrayRef; let casted_array = cast_with_options(&array, &T::DATA_TYPE, cast_options)?; casted_array .as_any() @@ -9219,125 +8834,89 @@ mod tests { #[test] fn test_cast_from_interval_to_duration() { + let nullable = CastOptions::default(); + let fallible = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + // from interval month day nano to duration second - let array = vec![1234567]; - let casted_array = cast_from_interval_to_duration::( - array, - &CastOptions::default(), - ) - .unwrap(); - assert_eq!( - casted_array.data_type(), - &DataType::Duration(TimeUnit::Second) - ); + let array = vec![1234567].into(); + let casted_array: DurationSecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); assert_eq!(casted_array.value(0), 0); - let array = vec![i128::MAX]; - let casted_array = cast_from_interval_to_duration::( - array.clone(), - &CastOptions::default(), - ) - .unwrap(); + let array = vec![i128::MAX].into(); + let casted_array: DurationSecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); assert!(!casted_array.is_valid(0)); - let casted_array = cast_from_interval_to_duration::( - array, - &CastOptions { - safe: false, - format_options: FormatOptions::default(), - }, - ); - assert!(casted_array.is_err()); + let res = cast_from_interval_to_duration::(&array, &fallible); + assert!(res.is_err()); // from interval month day nano to duration millisecond - let array = vec![1234567]; - let casted_array = cast_from_interval_to_duration::( - array, - &CastOptions::default(), - ) - .unwrap(); + let array = vec![1234567].into(); + let casted_array: DurationMillisecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); assert_eq!(casted_array.value(0), 1); - let array = vec![i128::MAX]; - let casted_array = cast_from_interval_to_duration::( - array.clone(), - &CastOptions::default(), - ) - .unwrap(); + let array = vec![i128::MAX].into(); + let casted_array: DurationMillisecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); assert!(!casted_array.is_valid(0)); - let casted_array 
= cast_from_interval_to_duration::( - array, - &CastOptions { - safe: false, - format_options: FormatOptions::default(), - }, - ); - assert!(casted_array.is_err()); + let res = + cast_from_interval_to_duration::(&array, &fallible); + assert!(res.is_err()); // from interval month day nano to duration microsecond - let array = vec![1234567]; - let casted_array = cast_from_interval_to_duration::( - array, - &CastOptions::default(), - ) - .unwrap(); - assert_eq!( - casted_array.data_type(), - &DataType::Duration(TimeUnit::Microsecond) - ); + let array = vec![1234567].into(); + let casted_array: DurationMicrosecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); assert_eq!(casted_array.value(0), 1234); - let array = vec![i128::MAX]; - let casted_array = cast_from_interval_to_duration::( - array.clone(), - &CastOptions::default(), - ) - .unwrap(); + let array = vec![i128::MAX].into(); + let casted_array = + cast_from_interval_to_duration::(&array, &nullable) + .unwrap(); assert!(!casted_array.is_valid(0)); - let casted_array = cast_from_interval_to_duration::( - array, - &CastOptions { - safe: false, - format_options: FormatOptions::default(), - }, - ); + let casted_array = + cast_from_interval_to_duration::(&array, &fallible); assert!(casted_array.is_err()); // from interval month day nano to duration nanosecond - let array = vec![1234567]; - let casted_array = cast_from_interval_to_duration::( - array, - &CastOptions::default(), - ) - .unwrap(); - assert_eq!( - casted_array.data_type(), - &DataType::Duration(TimeUnit::Nanosecond) - ); + let array = vec![1234567].into(); + let casted_array: DurationNanosecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); assert_eq!(casted_array.value(0), 1234567); - let array = vec![i128::MAX]; - let casted_array = cast_from_interval_to_duration::( - array.clone(), - &CastOptions::default(), - ) - .unwrap(); - assert_eq!( - casted_array.data_type(), - &DataType::Duration(TimeUnit::Nanosecond) - ); + let array = vec![i128::MAX].into(); + let casted_array: DurationNanosecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); assert!(!casted_array.is_valid(0)); - let casted_array = cast_from_interval_to_duration::( - array, - &CastOptions { - safe: false, - format_options: FormatOptions::default(), - }, - ); + let casted_array = + cast_from_interval_to_duration::(&array, &fallible); assert!(casted_array.is_err()); + + let array = vec![ + IntervalMonthDayNanoType::make_value(0, 1, 0), + IntervalMonthDayNanoType::make_value(-1, 0, 0), + IntervalMonthDayNanoType::make_value(1, 1, 0), + IntervalMonthDayNanoType::make_value(1, 0, 1), + IntervalMonthDayNanoType::make_value(0, 0, -1), + ] + .into(); + let casted_array = + cast_from_interval_to_duration::(&array, &nullable) + .unwrap(); + assert!(!casted_array.is_valid(0)); + assert!(!casted_array.is_valid(1)); + assert!(!casted_array.is_valid(2)); + assert!(!casted_array.is_valid(3)); + assert!(casted_array.is_valid(4)); + assert_eq!(casted_array.value(4), -1); } /// helper function to test casting from interval year month to interval month day nano @@ -9448,4 +9027,47 @@ mod tests { assert_eq!("1969-12-31", string_array.value(1)); assert_eq!("1989-12-31", string_array.value(2)); } + + #[test] + fn test_nested_list() { + let mut list = ListBuilder::new(Int32Builder::new()); + list.append_value([Some(1), Some(2), Some(3)]); + list.append_value([Some(4), None, Some(6)]); + let list = list.finish(); + + let to_field = Field::new("nested", 
list.data_type().clone(), false); + let to = DataType::List(Arc::new(to_field)); + let out = cast(&list, &to).unwrap(); + let opts = FormatOptions::default().with_null("null"); + let formatted = ArrayFormatter::try_new(out.as_ref(), &opts).unwrap(); + + assert_eq!(formatted.value(0).to_string(), "[[1], [2], [3]]"); + assert_eq!(formatted.value(1).to_string(), "[[4], [null], [6]]"); + } + + const CAST_OPTIONS: CastOptions<'static> = CastOptions { + safe: true, + format_options: FormatOptions::new(), + }; + + #[test] + #[allow(clippy::assertions_on_constants)] + fn test_const_options() { + assert!(CAST_OPTIONS.safe) + } + + #[test] + fn test_list_format_options() { + let options = CastOptions { + safe: false, + format_options: FormatOptions::default().with_null("null"), + }; + let array = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(0), Some(1), Some(2)]), + Some(vec![Some(0), None, Some(2)]), + ]); + let a = cast_with_options(&array, &DataType::Utf8, &options).unwrap(); + let r: Vec<_> = a.as_string::().iter().map(|x| x.unwrap()).collect(); + assert_eq!(r, &["[0, 1, 2]", "[0, null, 2]"]); + } } diff --git a/arrow-cast/src/display.rs b/arrow-cast/src/display.rs index 07e78f8984f9..246135e114bc 100644 --- a/arrow-cast/src/display.rs +++ b/arrow-cast/src/display.rs @@ -34,6 +34,16 @@ use lexical_core::FormattedSize; type TimeFormat<'a> = Option<&'a str>; +/// Format for displaying durations +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum DurationFormat { + /// ISO 8601 - `P198DT72932.972880S` + ISO8601, + /// A human readable representation - `198 days 16 hours 34 mins 15.407810000 secs` + Pretty, +} + /// Options for formatting arrays /// /// By default nulls are formatted as `""` and temporal types formatted @@ -56,10 +66,18 @@ pub struct FormatOptions<'a> { timestamp_tz_format: TimeFormat<'a>, /// Time format for time arrays time_format: TimeFormat<'a>, + /// Duration format + duration_format: DurationFormat, } impl<'a> Default for FormatOptions<'a> { fn default() -> Self { + Self::new() + } +} + +impl<'a> FormatOptions<'a> { + pub const fn new() -> Self { Self { safe: true, null: "", @@ -68,14 +86,13 @@ impl<'a> Default for FormatOptions<'a> { timestamp_format: None, timestamp_tz_format: None, time_format: None, + duration_format: DurationFormat::ISO8601, } } -} -impl<'a> FormatOptions<'a> { /// If set to `true` any formatting errors will be written to the output /// instead of being converted into a [`std::fmt::Error`] - pub fn with_display_error(mut self, safe: bool) -> Self { + pub const fn with_display_error(mut self, safe: bool) -> Self { self.safe = safe; self } @@ -83,12 +100,12 @@ impl<'a> FormatOptions<'a> { /// Overrides the string used to represent a null /// /// Defaults to `""` - pub fn with_null(self, null: &'a str) -> Self { + pub const fn with_null(self, null: &'a str) -> Self { Self { null, ..self } } /// Overrides the format used for [`DataType::Date32`] columns - pub fn with_date_format(self, date_format: Option<&'a str>) -> Self { + pub const fn with_date_format(self, date_format: Option<&'a str>) -> Self { Self { date_format, ..self @@ -96,7 +113,7 @@ impl<'a> FormatOptions<'a> { } /// Overrides the format used for [`DataType::Date64`] columns - pub fn with_datetime_format(self, datetime_format: Option<&'a str>) -> Self { + pub const fn with_datetime_format(self, datetime_format: Option<&'a str>) -> Self { Self { datetime_format, ..self @@ -104,7 +121,7 @@ impl<'a> FormatOptions<'a> { } /// Overrides the format used for 
[`DataType::Timestamp`] columns without a timezone - pub fn with_timestamp_format(self, timestamp_format: Option<&'a str>) -> Self { + pub const fn with_timestamp_format(self, timestamp_format: Option<&'a str>) -> Self { Self { timestamp_format, ..self @@ -112,7 +129,10 @@ impl<'a> FormatOptions<'a> { } /// Overrides the format used for [`DataType::Timestamp`] columns with a timezone - pub fn with_timestamp_tz_format(self, timestamp_tz_format: Option<&'a str>) -> Self { + pub const fn with_timestamp_tz_format( + self, + timestamp_tz_format: Option<&'a str>, + ) -> Self { Self { timestamp_tz_format, ..self @@ -120,12 +140,22 @@ impl<'a> FormatOptions<'a> { } /// Overrides the format used for [`DataType::Time32`] and [`DataType::Time64`] columns - pub fn with_time_format(self, time_format: Option<&'a str>) -> Self { + pub const fn with_time_format(self, time_format: Option<&'a str>) -> Self { Self { time_format, ..self } } + + /// Overrides the format used for duration columns + /// + /// Defaults to [`DurationFormat::ISO8601`] + pub const fn with_duration_format(self, duration_format: DurationFormat) -> Self { + Self { + duration_format, + ..self + } + } } /// Implements [`Display`] for a specific array value @@ -369,8 +399,15 @@ impl<'a> DisplayIndex for &'a BooleanArray { } } -impl<'a> DisplayIndex for &'a NullArray { - fn write(&self, _idx: usize, _f: &mut dyn Write) -> FormatResult { +impl<'a> DisplayIndexState<'a> for &'a NullArray { + type State = &'a str; + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + Ok(options.null) + } + + fn write(&self, state: &Self::State, _idx: usize, f: &mut dyn Write) -> FormatResult { + f.write_str(state)?; Ok(()) } } @@ -534,20 +571,84 @@ temporal_display!(time64us_to_time, time_format, Time64MicrosecondType); temporal_display!(time64ns_to_time, time_format, Time64NanosecondType); macro_rules! duration_display { - ($convert:ident, $t:ty) => { - impl<'a> DisplayIndex for &'a PrimitiveArray<$t> { - fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { - write!(f, "{}", $convert(self.value(idx)))?; + ($convert:ident, $t:ty, $scale:tt) => { + impl<'a> DisplayIndexState<'a> for &'a PrimitiveArray<$t> { + type State = DurationFormat; + + fn prepare( + &self, + options: &FormatOptions<'a>, + ) -> Result { + Ok(options.duration_format) + } + + fn write( + &self, + fmt: &Self::State, + idx: usize, + f: &mut dyn Write, + ) -> FormatResult { + let v = self.value(idx); + match fmt { + DurationFormat::ISO8601 => write!(f, "{}", $convert(v))?, + DurationFormat::Pretty => duration_fmt!(f, v, $scale)?, + } Ok(()) } } }; } -duration_display!(duration_s_to_duration, DurationSecondType); -duration_display!(duration_ms_to_duration, DurationMillisecondType); -duration_display!(duration_us_to_duration, DurationMicrosecondType); -duration_display!(duration_ns_to_duration, DurationNanosecondType); +macro_rules! 
duration_fmt { + ($f:ident, $v:expr, 0) => {{ + let secs = $v; + let mins = secs / 60; + let hours = mins / 60; + let days = hours / 24; + + let secs = secs - (mins * 60); + let mins = mins - (hours * 60); + let hours = hours - (days * 24); + write!($f, "{days} days {hours} hours {mins} mins {secs} secs") + }}; + ($f:ident, $v:expr, $scale:tt) => {{ + let subsec = $v; + let secs = subsec / 10_i64.pow($scale); + let mins = secs / 60; + let hours = mins / 60; + let days = hours / 24; + + let subsec = subsec - (secs * 10_i64.pow($scale)); + let secs = secs - (mins * 60); + let mins = mins - (hours * 60); + let hours = hours - (days * 24); + match subsec.is_negative() { + true => { + write!( + $f, + concat!("{} days {} hours {} mins -{}.{:0", $scale, "} secs"), + days, + hours, + mins, + secs.abs(), + subsec.abs() + ) + } + false => { + write!( + $f, + concat!("{} days {} hours {} mins {}.{:0", $scale, "} secs"), + days, hours, mins, secs, subsec + ) + } + } + }}; +} + +duration_display!(duration_s_to_duration, DurationSecondType, 0); +duration_display!(duration_ms_to_duration, DurationMillisecondType, 3); +duration_display!(duration_us_to_duration, DurationMicrosecondType, 6); +duration_display!(duration_ns_to_duration, DurationNanosecondType, 9); impl<'a> DisplayIndex for &'a PrimitiveArray { fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { @@ -866,8 +967,18 @@ pub fn lexical_to_string(n: N) -> String { mod tests { use super::*; + /// Test to verify options can be constant. See #4580 + const TEST_CONST_OPTIONS: FormatOptions<'static> = FormatOptions::new() + .with_date_format(Some("foo")) + .with_timestamp_format(Some("404")); + #[test] - fn test_map_arry_to_string() { + fn test_const_options() { + assert_eq!(TEST_CONST_OPTIONS.date_format, Some("foo")); + } + + #[test] + fn test_map_array_to_string() { let keys = vec!["a", "b", "c", "d", "e", "f", "g", "h"]; let values_data = UInt32Array::from(vec![0u32, 10, 20, 30, 40, 50, 60, 70]); @@ -887,25 +998,119 @@ mod tests { ); } + fn format_array(array: &dyn Array, fmt: &FormatOptions) -> Vec { + let fmt = ArrayFormatter::try_new(array, fmt).unwrap(); + (0..array.len()).map(|x| fmt.value(x).to_string()).collect() + } + #[test] fn test_array_value_to_string_duration() { - let ns_array = DurationNanosecondArray::from(vec![Some(1), None]); - assert_eq!( - array_value_to_string(&ns_array, 0).unwrap(), - "PT0.000000001S" - ); - assert_eq!(array_value_to_string(&ns_array, 1).unwrap(), ""); - - let us_array = DurationMicrosecondArray::from(vec![Some(1), None]); - assert_eq!(array_value_to_string(&us_array, 0).unwrap(), "PT0.000001S"); - assert_eq!(array_value_to_string(&us_array, 1).unwrap(), ""); - - let ms_array = DurationMillisecondArray::from(vec![Some(1), None]); - assert_eq!(array_value_to_string(&ms_array, 0).unwrap(), "PT0.001S"); - assert_eq!(array_value_to_string(&ms_array, 1).unwrap(), ""); + let iso_fmt = FormatOptions::new(); + let pretty_fmt = + FormatOptions::new().with_duration_format(DurationFormat::Pretty); + + let array = DurationNanosecondArray::from(vec![ + 1, + -1, + 1000, + -1000, + (45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34) * 1_000_000_000 + 123456789, + -(45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34) * 1_000_000_000 - 123456789, + ]); + let iso = format_array(&array, &iso_fmt); + let pretty = format_array(&array, &pretty_fmt); + + assert_eq!(iso[0], "PT0.000000001S"); + assert_eq!(pretty[0], "0 days 0 hours 0 mins 0.000000001 secs"); + assert_eq!(iso[1], "-PT0.000000001S"); + assert_eq!(pretty[1], "0 days 
0 hours 0 mins -0.000000001 secs"); + assert_eq!(iso[2], "PT0.000001S"); + assert_eq!(pretty[2], "0 days 0 hours 0 mins 0.000001000 secs"); + assert_eq!(iso[3], "-PT0.000001S"); + assert_eq!(pretty[3], "0 days 0 hours 0 mins -0.000001000 secs"); + assert_eq!(iso[4], "P45DT50554.123456789S"); + assert_eq!(pretty[4], "45 days 14 hours 2 mins 34.123456789 secs"); + assert_eq!(iso[5], "-P45DT50554.123456789S"); + assert_eq!(pretty[5], "-45 days -14 hours -2 mins -34.123456789 secs"); + + let array = DurationMicrosecondArray::from(vec![ + 1, + -1, + 1000, + -1000, + (45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34) * 1_000_000 + 123456, + -(45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34) * 1_000_000 - 123456, + ]); + let iso = format_array(&array, &iso_fmt); + let pretty = format_array(&array, &pretty_fmt); + + assert_eq!(iso[0], "PT0.000001S"); + assert_eq!(pretty[0], "0 days 0 hours 0 mins 0.000001 secs"); + assert_eq!(iso[1], "-PT0.000001S"); + assert_eq!(pretty[1], "0 days 0 hours 0 mins -0.000001 secs"); + assert_eq!(iso[2], "PT0.001S"); + assert_eq!(pretty[2], "0 days 0 hours 0 mins 0.001000 secs"); + assert_eq!(iso[3], "-PT0.001S"); + assert_eq!(pretty[3], "0 days 0 hours 0 mins -0.001000 secs"); + assert_eq!(iso[4], "P45DT50554.123456S"); + assert_eq!(pretty[4], "45 days 14 hours 2 mins 34.123456 secs"); + assert_eq!(iso[5], "-P45DT50554.123456S"); + assert_eq!(pretty[5], "-45 days -14 hours -2 mins -34.123456 secs"); + + let array = DurationMillisecondArray::from(vec![ + 1, + -1, + 1000, + -1000, + (45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34) * 1_000 + 123, + -(45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34) * 1_000 - 123, + ]); + let iso = format_array(&array, &iso_fmt); + let pretty = format_array(&array, &pretty_fmt); + + assert_eq!(iso[0], "PT0.001S"); + assert_eq!(pretty[0], "0 days 0 hours 0 mins 0.001 secs"); + assert_eq!(iso[1], "-PT0.001S"); + assert_eq!(pretty[1], "0 days 0 hours 0 mins -0.001 secs"); + assert_eq!(iso[2], "PT1S"); + assert_eq!(pretty[2], "0 days 0 hours 0 mins 1.000 secs"); + assert_eq!(iso[3], "-PT1S"); + assert_eq!(pretty[3], "0 days 0 hours 0 mins -1.000 secs"); + assert_eq!(iso[4], "P45DT50554.123S"); + assert_eq!(pretty[4], "45 days 14 hours 2 mins 34.123 secs"); + assert_eq!(iso[5], "-P45DT50554.123S"); + assert_eq!(pretty[5], "-45 days -14 hours -2 mins -34.123 secs"); + + let array = DurationSecondArray::from(vec![ + 1, + -1, + 1000, + -1000, + 45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34, + -45 * 60 * 60 * 24 - 14 * 60 * 60 - 2 * 60 - 34, + ]); + let iso = format_array(&array, &iso_fmt); + let pretty = format_array(&array, &pretty_fmt); + + assert_eq!(iso[0], "PT1S"); + assert_eq!(pretty[0], "0 days 0 hours 0 mins 1 secs"); + assert_eq!(iso[1], "-PT1S"); + assert_eq!(pretty[1], "0 days 0 hours 0 mins -1 secs"); + assert_eq!(iso[2], "PT1000S"); + assert_eq!(pretty[2], "0 days 0 hours 16 mins 40 secs"); + assert_eq!(iso[3], "-PT1000S"); + assert_eq!(pretty[3], "0 days 0 hours -16 mins -40 secs"); + assert_eq!(iso[4], "P45DT50554S"); + assert_eq!(pretty[4], "45 days 14 hours 2 mins 34 secs"); + assert_eq!(iso[5], "-P45DT50554S"); + assert_eq!(pretty[5], "-45 days -14 hours -2 mins -34 secs"); + } - let s_array = DurationSecondArray::from(vec![Some(1), None]); - assert_eq!(array_value_to_string(&s_array, 0).unwrap(), "PT1S"); - assert_eq!(array_value_to_string(&s_array, 1).unwrap(), ""); + #[test] + fn test_null() { + let array = NullArray::new(2); + let options = FormatOptions::new().with_null("NULL"); + let formatted = format_array(&array, 
&options); + assert_eq!(formatted, &["NULL".to_string(), "NULL".to_string()]) } } diff --git a/arrow-cast/src/parse.rs b/arrow-cast/src/parse.rs index 67477c57d519..3806f0adc5d6 100644 --- a/arrow-cast/src/parse.rs +++ b/arrow-cast/src/parse.rs @@ -1,2232 +1,2307 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow_array::timezone::Tz; -use arrow_array::types::*; -use arrow_array::{ArrowNativeTypeOp, ArrowPrimitiveType}; -use arrow_buffer::ArrowNativeType; -use arrow_schema::ArrowError; -use chrono::prelude::*; -use half::f16; -use std::str::FromStr; - -/// Parse nanoseconds from the first `N` values in digits, subtracting the offset `O` -#[inline] -fn parse_nanos(digits: &[u8]) -> u32 { - digits[..N] - .iter() - .fold(0_u32, |acc, v| acc * 10 + v.wrapping_sub(O) as u32) - * 10_u32.pow((9 - N) as _) -} - -/// Helper for parsing timestamps -struct TimestampParser { - /// The timestamp bytes to parse minus `b'0'` - /// - /// This makes interpretation as an integer inexpensive - digits: [u8; 32], - /// A mask containing a `1` bit where the corresponding byte is a valid ASCII digit - mask: u32, -} - -impl TimestampParser { - fn new(bytes: &[u8]) -> Self { - let mut digits = [0; 32]; - let mut mask = 0; - - // Treating all bytes the same way, helps LLVM vectorise this correctly - for (idx, (o, i)) in digits.iter_mut().zip(bytes).enumerate() { - *o = i.wrapping_sub(b'0'); - mask |= ((*o < 10) as u32) << idx - } - - Self { digits, mask } - } - - /// Returns true if the byte at `idx` in the original string equals `b` - fn test(&self, idx: usize, b: u8) -> bool { - self.digits[idx] == b.wrapping_sub(b'0') - } - - /// Parses a date of the form `1997-01-31` - fn date(&self) -> Option { - if self.mask & 0b1111111111 != 0b1101101111 - || !self.test(4, b'-') - || !self.test(7, b'-') - { - return None; - } - - let year = self.digits[0] as u16 * 1000 - + self.digits[1] as u16 * 100 - + self.digits[2] as u16 * 10 - + self.digits[3] as u16; - - let month = self.digits[5] * 10 + self.digits[6]; - let day = self.digits[8] * 10 + self.digits[9]; - - NaiveDate::from_ymd_opt(year as _, month as _, day as _) - } - - /// Parses a time of any of forms - /// - `09:26:56` - /// - `09:26:56.123` - /// - `09:26:56.123456` - /// - `09:26:56.123456789` - /// - `092656` - /// - /// Returning the end byte offset - fn time(&self) -> Option<(NaiveTime, usize)> { - // Make a NaiveTime handling leap seconds - let time = |hour, min, sec, nano| match sec { - 60 => { - let nano = 1_000_000_000 + nano; - NaiveTime::from_hms_nano_opt(hour as _, min as _, 59, nano) - } - _ => NaiveTime::from_hms_nano_opt(hour as _, min as _, sec as _, nano), - }; - - match (self.mask >> 11) & 0b11111111 { - // 09:26:56 - 0b11011011 if self.test(13, b':') && self.test(16, b':') => { - let 
hour = self.digits[11] * 10 + self.digits[12]; - let minute = self.digits[14] * 10 + self.digits[15]; - let second = self.digits[17] * 10 + self.digits[18]; - - match self.test(19, b'.') { - true => { - let digits = (self.mask >> 20).trailing_ones(); - let nanos = match digits { - 0 => return None, - 1 => parse_nanos::<1, 0>(&self.digits[20..21]), - 2 => parse_nanos::<2, 0>(&self.digits[20..22]), - 3 => parse_nanos::<3, 0>(&self.digits[20..23]), - 4 => parse_nanos::<4, 0>(&self.digits[20..24]), - 5 => parse_nanos::<5, 0>(&self.digits[20..25]), - 6 => parse_nanos::<6, 0>(&self.digits[20..26]), - 7 => parse_nanos::<7, 0>(&self.digits[20..27]), - 8 => parse_nanos::<8, 0>(&self.digits[20..28]), - _ => parse_nanos::<9, 0>(&self.digits[20..29]), - }; - Some((time(hour, minute, second, nanos)?, 20 + digits as usize)) - } - false => Some((time(hour, minute, second, 0)?, 19)), - } - } - // 092656 - 0b111111 => { - let hour = self.digits[11] * 10 + self.digits[12]; - let minute = self.digits[13] * 10 + self.digits[14]; - let second = self.digits[15] * 10 + self.digits[16]; - let time = time(hour, minute, second, 0)?; - Some((time, 17)) - } - _ => None, - } - } -} - -/// Accepts a string and parses it relative to the provided `timezone` -/// -/// In addition to RFC3339 / ISO8601 standard timestamps, it also -/// accepts strings that use a space ` ` to separate the date and time -/// as well as strings that have no explicit timezone offset. -/// -/// Examples of accepted inputs: -/// * `1997-01-31T09:26:56.123Z` # RCF3339 -/// * `1997-01-31T09:26:56.123-05:00` # RCF3339 -/// * `1997-01-31 09:26:56.123-05:00` # close to RCF3339 but with a space rather than T -/// * `2023-01-01 04:05:06.789 -08` # close to RCF3339, no fractional seconds or time separator -/// * `1997-01-31T09:26:56.123` # close to RCF3339 but no timezone offset specified -/// * `1997-01-31 09:26:56.123` # close to RCF3339 but uses a space and no timezone offset -/// * `1997-01-31 09:26:56` # close to RCF3339, no fractional seconds -/// * `1997-01-31 092656` # close to RCF3339, no fractional seconds -/// * `1997-01-31 092656+04:00` # close to RCF3339, no fractional seconds or time separator -/// * `1997-01-31` # close to RCF3339, only date no time -/// -/// [IANA timezones] are only supported if the `arrow-array/chrono-tz` feature is enabled -/// -/// * `2023-01-01 040506 America/Los_Angeles` -/// -/// If a timestamp is ambiguous, for example as a result of daylight-savings time, an error -/// will be returned -/// -/// Some formats supported by PostgresSql -/// are not supported, like -/// -/// * "2023-01-01 04:05:06.789 +07:30:00", -/// * "2023-01-01 040506 +07:30:00", -/// * "2023-01-01 04:05:06.789 PST", -/// -/// [IANA timezones]: https://www.iana.org/time-zones -pub fn string_to_datetime( - timezone: &T, - s: &str, -) -> Result, ArrowError> { - let err = |ctx: &str| { - ArrowError::ParseError(format!("Error parsing timestamp from '{s}': {ctx}")) - }; - - let bytes = s.as_bytes(); - if bytes.len() < 10 { - return Err(err("timestamp must contain at least 10 characters")); - } - - let parser = TimestampParser::new(bytes); - let date = parser.date().ok_or_else(|| err("error parsing date"))?; - if bytes.len() == 10 { - let offset = timezone.offset_from_local_date(&date); - let offset = offset - .single() - .ok_or_else(|| err("error computing timezone offset"))?; - - let time = NaiveTime::from_hms_opt(0, 0, 0).unwrap(); - return Ok(DateTime::from_local(date.and_time(time), offset)); - } - - if !parser.test(10, b'T') && !parser.test(10, 
b't') && !parser.test(10, b' ') { - return Err(err("invalid timestamp separator")); - } - - let (time, mut tz_offset) = parser.time().ok_or_else(|| err("error parsing time"))?; - let datetime = date.and_time(time); - - if tz_offset == 32 { - // Decimal overrun - while tz_offset < bytes.len() && bytes[tz_offset].is_ascii_digit() { - tz_offset += 1; - } - } - - if bytes.len() <= tz_offset { - let offset = timezone.offset_from_local_datetime(&datetime); - let offset = offset - .single() - .ok_or_else(|| err("error computing timezone offset"))?; - return Ok(DateTime::from_local(datetime, offset)); - } - - if bytes[tz_offset] == b'z' || bytes[tz_offset] == b'Z' { - let offset = timezone.offset_from_local_datetime(&datetime); - let offset = offset - .single() - .ok_or_else(|| err("error computing timezone offset"))?; - return Ok(DateTime::from_utc(datetime, offset)); - } - - // Parse remainder of string as timezone - let parsed_tz: Tz = s[tz_offset..].trim_start().parse()?; - let offset = parsed_tz.offset_from_local_datetime(&datetime); - let offset = offset - .single() - .ok_or_else(|| err("error computing timezone offset"))?; - Ok(DateTime::::from_local(datetime, offset).with_timezone(timezone)) -} - -/// Accepts a string in RFC3339 / ISO8601 standard format and some -/// variants and converts it to a nanosecond precision timestamp. -/// -/// See [`string_to_datetime`] for the full set of supported formats -/// -/// Implements the `to_timestamp` function to convert a string to a -/// timestamp, following the model of spark SQL’s to_`timestamp`. -/// -/// Internally, this function uses the `chrono` library for the -/// datetime parsing -/// -/// We hope to extend this function in the future with a second -/// parameter to specifying the format string. -/// -/// ## Timestamp Precision -/// -/// Function uses the maximum precision timestamps supported by -/// Arrow (nanoseconds stored as a 64-bit integer) timestamps. This -/// means the range of dates that timestamps can represent is ~1677 AD -/// to 2262 AM -/// -/// ## Timezone / Offset Handling -/// -/// Numerical values of timestamps are stored compared to offset UTC. -/// -/// This function interprets string without an explicit time zone as timestamps -/// relative to UTC, see [`string_to_datetime`] for alternative semantics -/// -/// In particular: -/// -/// ``` -/// # use arrow_cast::parse::string_to_timestamp_nanos; -/// // Note all three of these timestamps are parsed as the same value -/// let a = string_to_timestamp_nanos("1997-01-31 09:26:56.123Z").unwrap(); -/// let b = string_to_timestamp_nanos("1997-01-31T09:26:56.123").unwrap(); -/// let c = string_to_timestamp_nanos("1997-01-31T14:26:56.123+05:00").unwrap(); -/// -/// assert_eq!(a, b); -/// assert_eq!(b, c); -/// ``` -/// -#[inline] -pub fn string_to_timestamp_nanos(s: &str) -> Result { - to_timestamp_nanos(string_to_datetime(&Utc, s)?.naive_utc()) -} - -/// Defensive check to prevent chrono-rs panics when nanosecond conversion happens on non-supported dates -#[inline] -fn to_timestamp_nanos(dt: NaiveDateTime) -> Result { - if dt.timestamp().checked_mul(1_000_000_000).is_none() { - return Err(ArrowError::ParseError( - ERR_NANOSECONDS_NOT_SUPPORTED.to_string(), - )); - } - - Ok(dt.timestamp_nanos()) -} - -/// Accepts a string in ISO8601 standard format and some -/// variants and converts it to nanoseconds since midnight. 
-/// -/// Examples of accepted inputs: -/// * `09:26:56.123 AM` -/// * `23:59:59` -/// * `6:00 pm` -// -/// Internally, this function uses the `chrono` library for the -/// time parsing -/// -/// ## Timezone / Offset Handling -/// -/// This function does not support parsing strings with a timezone -/// or offset specified, as it considers only time since midnight. -pub fn string_to_time_nanoseconds(s: &str) -> Result { - let nt = string_to_time(s).ok_or_else(|| { - ArrowError::ParseError(format!("Failed to parse \'{s}\' as time")) - })?; - Ok(nt.num_seconds_from_midnight() as i64 * 1_000_000_000 + nt.nanosecond() as i64) -} - -fn string_to_time(s: &str) -> Option { - let bytes = s.as_bytes(); - if bytes.len() < 4 { - return None; - } - - let (am, bytes) = match bytes.get(bytes.len() - 3..) { - Some(b" AM" | b" am" | b" Am" | b" aM") => { - (Some(true), &bytes[..bytes.len() - 3]) - } - Some(b" PM" | b" pm" | b" pM" | b" Pm") => { - (Some(false), &bytes[..bytes.len() - 3]) - } - _ => (None, bytes), - }; - - if bytes.len() < 4 { - return None; - } - - let mut digits = [b'0'; 6]; - - // Extract hour - let bytes = match (bytes[1], bytes[2]) { - (b':', _) => { - digits[1] = bytes[0]; - &bytes[2..] - } - (_, b':') => { - digits[0] = bytes[0]; - digits[1] = bytes[1]; - &bytes[3..] - } - _ => return None, - }; - - if bytes.len() < 2 { - return None; // Minutes required - } - - // Extract minutes - digits[2] = bytes[0]; - digits[3] = bytes[1]; - - let nanoseconds = match bytes.get(2) { - Some(b':') => { - if bytes.len() < 5 { - return None; - } - - // Extract seconds - digits[4] = bytes[3]; - digits[5] = bytes[4]; - - // Extract sub-seconds if any - match bytes.get(5) { - Some(b'.') => { - let decimal = &bytes[6..]; - if decimal.iter().any(|x| !x.is_ascii_digit()) { - return None; - } - match decimal.len() { - 0 => return None, - 1 => parse_nanos::<1, b'0'>(decimal), - 2 => parse_nanos::<2, b'0'>(decimal), - 3 => parse_nanos::<3, b'0'>(decimal), - 4 => parse_nanos::<4, b'0'>(decimal), - 5 => parse_nanos::<5, b'0'>(decimal), - 6 => parse_nanos::<6, b'0'>(decimal), - 7 => parse_nanos::<7, b'0'>(decimal), - 8 => parse_nanos::<8, b'0'>(decimal), - _ => parse_nanos::<9, b'0'>(decimal), - } - } - Some(_) => return None, - None => 0, - } - } - Some(_) => return None, - None => 0, - }; - - digits.iter_mut().for_each(|x| *x = x.wrapping_sub(b'0')); - if digits.iter().any(|x| *x > 9) { - return None; - } - - let hour = match (digits[0] * 10 + digits[1], am) { - (12, Some(true)) => 0, // 12:00 AM -> 00:00 - (h @ 1..=11, Some(true)) => h, // 1:00 AM -> 01:00 - (12, Some(false)) => 12, // 12:00 PM -> 12:00 - (h @ 1..=11, Some(false)) => h + 12, // 1:00 PM -> 13:00 - (_, Some(_)) => return None, - (h, None) => h, - }; - - // Handle leap second - let (second, nanoseconds) = match digits[4] * 10 + digits[5] { - 60 => (59, nanoseconds + 1_000_000_000), - s => (s, nanoseconds), - }; - - NaiveTime::from_hms_nano_opt( - hour as _, - (digits[2] * 10 + digits[3]) as _, - second as _, - nanoseconds, - ) -} - -/// Specialized parsing implementations -/// used by csv and json reader -pub trait Parser: ArrowPrimitiveType { - fn parse(string: &str) -> Option; - - fn parse_formatted(string: &str, _format: &str) -> Option { - Self::parse(string) - } -} - -impl Parser for Float16Type { - fn parse(string: &str) -> Option { - lexical_core::parse(string.as_bytes()) - .ok() - .map(f16::from_f32) - } -} - -impl Parser for Float32Type { - fn parse(string: &str) -> Option { - lexical_core::parse(string.as_bytes()).ok() - } -} - -impl 
Parser for Float64Type { - fn parse(string: &str) -> Option { - lexical_core::parse(string.as_bytes()).ok() - } -} - -macro_rules! parser_primitive { - ($t:ty) => { - impl Parser for $t { - fn parse(string: &str) -> Option { - lexical_core::parse::(string.as_bytes()).ok() - } - } - }; -} -parser_primitive!(UInt64Type); -parser_primitive!(UInt32Type); -parser_primitive!(UInt16Type); -parser_primitive!(UInt8Type); -parser_primitive!(Int64Type); -parser_primitive!(Int32Type); -parser_primitive!(Int16Type); -parser_primitive!(Int8Type); - -impl Parser for TimestampNanosecondType { - fn parse(string: &str) -> Option { - string_to_timestamp_nanos(string).ok() - } -} - -impl Parser for TimestampMicrosecondType { - fn parse(string: &str) -> Option { - let nanos = string_to_timestamp_nanos(string).ok(); - nanos.map(|x| x / 1000) - } -} - -impl Parser for TimestampMillisecondType { - fn parse(string: &str) -> Option { - let nanos = string_to_timestamp_nanos(string).ok(); - nanos.map(|x| x / 1_000_000) - } -} - -impl Parser for TimestampSecondType { - fn parse(string: &str) -> Option { - let nanos = string_to_timestamp_nanos(string).ok(); - nanos.map(|x| x / 1_000_000_000) - } -} - -impl Parser for Time64NanosecondType { - // Will truncate any fractions of a nanosecond - fn parse(string: &str) -> Option { - string_to_time_nanoseconds(string) - .ok() - .or_else(|| string.parse::().ok()) - } - - fn parse_formatted(string: &str, format: &str) -> Option { - let nt = NaiveTime::parse_from_str(string, format).ok()?; - Some( - nt.num_seconds_from_midnight() as i64 * 1_000_000_000 - + nt.nanosecond() as i64, - ) - } -} - -impl Parser for Time64MicrosecondType { - // Will truncate any fractions of a microsecond - fn parse(string: &str) -> Option { - string_to_time_nanoseconds(string) - .ok() - .map(|nanos| nanos / 1_000) - .or_else(|| string.parse::().ok()) - } - - fn parse_formatted(string: &str, format: &str) -> Option { - let nt = NaiveTime::parse_from_str(string, format).ok()?; - Some( - nt.num_seconds_from_midnight() as i64 * 1_000_000 - + nt.nanosecond() as i64 / 1_000, - ) - } -} - -impl Parser for Time32MillisecondType { - // Will truncate any fractions of a millisecond - fn parse(string: &str) -> Option { - string_to_time_nanoseconds(string) - .ok() - .map(|nanos| (nanos / 1_000_000) as i32) - .or_else(|| string.parse::().ok()) - } - - fn parse_formatted(string: &str, format: &str) -> Option { - let nt = NaiveTime::parse_from_str(string, format).ok()?; - Some( - nt.num_seconds_from_midnight() as i32 * 1_000 - + nt.nanosecond() as i32 / 1_000_000, - ) - } -} - -impl Parser for Time32SecondType { - // Will truncate any fractions of a second - fn parse(string: &str) -> Option { - string_to_time_nanoseconds(string) - .ok() - .map(|nanos| (nanos / 1_000_000_000) as i32) - .or_else(|| string.parse::().ok()) - } - - fn parse_formatted(string: &str, format: &str) -> Option { - let nt = NaiveTime::parse_from_str(string, format).ok()?; - Some( - nt.num_seconds_from_midnight() as i32 - + nt.nanosecond() as i32 / 1_000_000_000, - ) - } -} - -/// Number of days between 0001-01-01 and 1970-01-01 -const EPOCH_DAYS_FROM_CE: i32 = 719_163; - -/// Error message if nanosecond conversion request beyond supported interval -const ERR_NANOSECONDS_NOT_SUPPORTED: &str = "The dates that can be represented as nanoseconds have to be between 1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804"; - -impl Parser for Date32Type { - fn parse(string: &str) -> Option { - let parser = TimestampParser::new(string.as_bytes()); - let 
date = parser.date()?; - Some(date.num_days_from_ce() - EPOCH_DAYS_FROM_CE) - } - - fn parse_formatted(string: &str, format: &str) -> Option { - let date = NaiveDate::parse_from_str(string, format).ok()?; - Some(date.num_days_from_ce() - EPOCH_DAYS_FROM_CE) - } -} - -impl Parser for Date64Type { - fn parse(string: &str) -> Option { - let date_time = string_to_datetime(&Utc, string).ok()?; - Some(date_time.timestamp_millis()) - } - - fn parse_formatted(string: &str, format: &str) -> Option { - use chrono::format::Fixed; - use chrono::format::StrftimeItems; - let fmt = StrftimeItems::new(format); - let has_zone = fmt.into_iter().any(|item| match item { - chrono::format::Item::Fixed(fixed_item) => matches!( - fixed_item, - Fixed::RFC2822 - | Fixed::RFC3339 - | Fixed::TimezoneName - | Fixed::TimezoneOffsetColon - | Fixed::TimezoneOffsetColonZ - | Fixed::TimezoneOffset - | Fixed::TimezoneOffsetZ - ), - _ => false, - }); - if has_zone { - let date_time = chrono::DateTime::parse_from_str(string, format).ok()?; - Some(date_time.timestamp_millis()) - } else { - let date_time = NaiveDateTime::parse_from_str(string, format).ok()?; - Some(date_time.timestamp_millis()) - } - } -} - -/// Parse the string format decimal value to i128/i256 format and checking the precision and scale. -/// The result value can't be out of bounds. -pub fn parse_decimal( - s: &str, - precision: u8, - scale: i8, -) -> Result { - let mut result = T::Native::usize_as(0); - let mut fractionals = 0; - let mut digits = 0; - let base = T::Native::usize_as(10); - - let bs = s.as_bytes(); - let (bs, negative) = match bs.first() { - Some(b'-') => (&bs[1..], true), - Some(b'+') => (&bs[1..], false), - _ => (bs, false), - }; - - if bs.is_empty() { - return Err(ArrowError::ParseError(format!( - "can't parse the string value {s} to decimal" - ))); - } - - let mut bs = bs.iter(); - // Overflow checks are not required if 10^(precision - 1) <= T::MAX holds. - // Thus, if we validate the precision correctly, we can skip overflow checks. - while let Some(b) = bs.next() { - match b { - b'0'..=b'9' => { - if digits == 0 && *b == b'0' { - // Ignore leading zeros. - continue; - } - digits += 1; - result = result.mul_wrapping(base); - result = result.add_wrapping(T::Native::usize_as((b - b'0') as usize)); - } - b'.' => { - for b in bs.by_ref() { - if !b.is_ascii_digit() { - return Err(ArrowError::ParseError(format!( - "can't parse the string value {s} to decimal" - ))); - } - if fractionals == scale { - // We have processed all the digits that we need. All that - // is left is to validate that the rest of the string contains - // valid digits. - continue; - } - fractionals += 1; - digits += 1; - result = result.mul_wrapping(base); - result = - result.add_wrapping(T::Native::usize_as((b - b'0') as usize)); - } - - // Fail on "." 
- if digits == 0 { - return Err(ArrowError::ParseError(format!( - "can't parse the string value {s} to decimal" - ))); - } - } - _ => { - return Err(ArrowError::ParseError(format!( - "can't parse the string value {s} to decimal" - ))); - } - } - } - - if fractionals < scale { - let exp = scale - fractionals; - if exp as u8 + digits > precision { - return Err(ArrowError::ParseError("parse decimal overflow".to_string())); - } - let mul = base.pow_wrapping(exp as _); - result = result.mul_wrapping(mul); - } else if digits > precision { - return Err(ArrowError::ParseError("parse decimal overflow".to_string())); - } - - Ok(if negative { - result.neg_wrapping() - } else { - result - }) -} - -pub fn parse_interval_year_month( - value: &str, -) -> Result<::Native, ArrowError> { - let config = IntervalParseConfig::new(IntervalUnit::Year); - let interval = Interval::parse(value, &config)?; - - let months = interval.to_year_months().map_err(|_| ArrowError::CastError(format!( - "Cannot cast {value} to IntervalYearMonth. Only year and month fields are allowed." - )))?; - - Ok(IntervalYearMonthType::make_value(0, months)) -} - -pub fn parse_interval_day_time( - value: &str, -) -> Result<::Native, ArrowError> { - let config = IntervalParseConfig::new(IntervalUnit::Day); - let interval = Interval::parse(value, &config)?; - - let (days, millis) = interval.to_day_time().map_err(|_| ArrowError::CastError(format!( - "Cannot cast {value} to IntervalDayTime because the nanos part isn't multiple of milliseconds" - )))?; - - Ok(IntervalDayTimeType::make_value(days, millis)) -} - -pub fn parse_interval_month_day_nano( - value: &str, -) -> Result<::Native, ArrowError> { - let config = IntervalParseConfig::new(IntervalUnit::Month); - let interval = Interval::parse(value, &config)?; - - let (months, days, nanos) = interval.to_month_day_nanos(); - - Ok(IntervalMonthDayNanoType::make_value(months, days, nanos)) -} - -const NANOS_PER_MILLIS: i64 = 1_000_000; -const NANOS_PER_SECOND: i64 = 1_000 * NANOS_PER_MILLIS; -const NANOS_PER_MINUTE: i64 = 60 * NANOS_PER_SECOND; -const NANOS_PER_HOUR: i64 = 60 * NANOS_PER_MINUTE; -#[cfg(test)] -const NANOS_PER_DAY: i64 = 24 * NANOS_PER_HOUR; - -#[rustfmt::skip] -#[derive(Clone, Copy)] -#[repr(u16)] -enum IntervalUnit { - Century = 0b_0000_0000_0001, - Decade = 0b_0000_0000_0010, - Year = 0b_0000_0000_0100, - Month = 0b_0000_0000_1000, - Week = 0b_0000_0001_0000, - Day = 0b_0000_0010_0000, - Hour = 0b_0000_0100_0000, - Minute = 0b_0000_1000_0000, - Second = 0b_0001_0000_0000, - Millisecond = 0b_0010_0000_0000, - Microsecond = 0b_0100_0000_0000, - Nanosecond = 0b_1000_0000_0000, -} - -impl FromStr for IntervalUnit { - type Err = ArrowError; - - fn from_str(s: &str) -> Result { - match s.to_lowercase().as_str() { - "century" | "centuries" => Ok(Self::Century), - "decade" | "decades" => Ok(Self::Decade), - "year" | "years" => Ok(Self::Year), - "month" | "months" => Ok(Self::Month), - "week" | "weeks" => Ok(Self::Week), - "day" | "days" => Ok(Self::Day), - "hour" | "hours" => Ok(Self::Hour), - "minute" | "minutes" => Ok(Self::Minute), - "second" | "seconds" => Ok(Self::Second), - "millisecond" | "milliseconds" => Ok(Self::Millisecond), - "microsecond" | "microseconds" => Ok(Self::Microsecond), - "nanosecond" | "nanoseconds" => Ok(Self::Nanosecond), - _ => Err(ArrowError::NotYetImplemented(format!( - "Unknown interval type: {s}" - ))), - } - } -} - -pub type MonthDayNano = (i32, i32, i64); - -/// Chosen based on the number of decimal digits in 1 week in nanoseconds -const 
INTERVAL_PRECISION: u32 = 15; - -#[derive(Clone, Copy, Debug, PartialEq)] -struct IntervalAmount { - /// The integer component of the interval amount - integer: i64, - /// The fractional component multiplied by 10^INTERVAL_PRECISION - frac: i64, -} - -#[cfg(test)] -impl IntervalAmount { - fn new(integer: i64, frac: i64) -> Self { - Self { integer, frac } - } -} - -impl FromStr for IntervalAmount { - type Err = ArrowError; - - fn from_str(s: &str) -> Result { - match s.split_once('.') { - Some((integer, frac)) - if frac.len() <= INTERVAL_PRECISION as usize - && !frac.is_empty() - && !frac.starts_with('-') => - { - // integer will be "" for values like ".5" - // and "-" for values like "-.5" - let explicit_neg = integer.starts_with('-'); - let integer = if integer.is_empty() || integer == "-" { - Ok(0) - } else { - integer.parse::().map_err(|_| { - ArrowError::ParseError(format!( - "Failed to parse {s} as interval amount" - )) - }) - }?; - - let frac_unscaled = frac.parse::().map_err(|_| { - ArrowError::ParseError(format!( - "Failed to parse {s} as interval amount" - )) - })?; - - // scale fractional part by interval precision - let frac = - frac_unscaled * 10_i64.pow(INTERVAL_PRECISION - frac.len() as u32); - - // propagate the sign of the integer part to the fractional part - let frac = if integer < 0 || explicit_neg { - -frac - } else { - frac - }; - - let result = Self { integer, frac }; - - Ok(result) - } - Some((_, frac)) if frac.starts_with('-') => Err(ArrowError::ParseError( - format!("Failed to parse {s} as interval amount"), - )), - Some((_, frac)) if frac.len() > INTERVAL_PRECISION as usize => { - Err(ArrowError::ParseError(format!( - "{s} exceeds the precision available for interval amount" - ))) - } - Some(_) | None => { - let integer = s.parse::().map_err(|_| { - ArrowError::ParseError(format!( - "Failed to parse {s} as interval amount" - )) - })?; - - let result = Self { integer, frac: 0 }; - Ok(result) - } - } - } -} - -#[derive(Debug, Default, PartialEq)] -struct Interval { - months: i32, - days: i32, - nanos: i64, -} - -impl Interval { - fn new(months: i32, days: i32, nanos: i64) -> Self { - Self { - months, - days, - nanos, - } - } - - fn to_year_months(&self) -> Result { - match (self.months, self.days, self.nanos) { - (months, days, nanos) if days == 0 && nanos == 0 => Ok(months), - _ => Err(ArrowError::InvalidArgumentError(format!( - "Unable to represent interval with days and nanos as year-months: {:?}", - self - ))), - } - } - - fn to_day_time(&self) -> Result<(i32, i32), ArrowError> { - let days = self.months.mul_checked(30)?.add_checked(self.days)?; - - match self.nanos { - nanos if nanos % NANOS_PER_MILLIS == 0 => { - let millis = (self.nanos / 1_000_000).try_into().map_err(|_| { - ArrowError::InvalidArgumentError(format!( - "Unable to represent {} nanos as milliseconds in a signed 32-bit integer", - self.nanos - )) - })?; - - Ok((days, millis)) - } - nanos => Err(ArrowError::InvalidArgumentError(format!( - "Unable to represent {nanos} as milliseconds" - ))), - } - } - - fn to_month_day_nanos(&self) -> (i32, i32, i64) { - (self.months, self.days, self.nanos) - } - - /// Parse string value in traditional Postgres format such as - /// `1 year 2 months 3 days 4 hours 5 minutes 6 seconds` - fn parse(value: &str, config: &IntervalParseConfig) -> Result { - let components = parse_interval_components(value, config)?; - - let result = components.into_iter().fold( - Ok(Self::default()), - |result, (amount, unit)| match result { - Ok(result) => result.add(amount, unit), - 
Err(e) => Err(e), - }, - )?; - - Ok(result) - } - - /// Interval addition following Postgres behavior. Fractional units will be spilled into smaller units. - /// When the interval unit is larger than months, the result is rounded to total months and not spilled to days/nanos. - /// Fractional parts of weeks and days are represented using days and nanoseconds. - /// e.g. INTERVAL '0.5 MONTH' = 15 days, INTERVAL '1.5 MONTH' = 1 month 15 days - /// e.g. INTERVAL '0.5 DAY' = 12 hours, INTERVAL '1.5 DAY' = 1 day 12 hours - /// [Postgres reference](https://www.postgresql.org/docs/15/datatype-datetime.html#DATATYPE-INTERVAL-INPUT:~:text=Field%20values%20can,fractional%20on%20output.) - fn add( - &self, - amount: IntervalAmount, - unit: IntervalUnit, - ) -> Result { - let result = match unit { - IntervalUnit::Century => { - let months_int = amount.integer.mul_checked(100)?.mul_checked(12)?; - let month_frac = amount.frac * 12 / 10_i64.pow(INTERVAL_PRECISION - 2); - let months = - months_int - .add_checked(month_frac)? - .try_into() - .map_err(|_| { - ArrowError::ParseError(format!( - "Unable to represent {} centuries as months in a signed 32-bit integer", - &amount.integer - )) - })?; - - Self::new(self.months.add_checked(months)?, self.days, self.nanos) - } - IntervalUnit::Decade => { - let months_int = amount.integer.mul_checked(10)?.mul_checked(12)?; - - let month_frac = amount.frac * 12 / 10_i64.pow(INTERVAL_PRECISION - 1); - let months = - months_int - .add_checked(month_frac)? - .try_into() - .map_err(|_| { - ArrowError::ParseError(format!( - "Unable to represent {} decades as months in a signed 32-bit integer", - &amount.integer - )) - })?; - - Self::new(self.months.add_checked(months)?, self.days, self.nanos) - } - IntervalUnit::Year => { - let months_int = amount.integer.mul_checked(12)?; - let month_frac = amount.frac * 12 / 10_i64.pow(INTERVAL_PRECISION); - let months = - months_int - .add_checked(month_frac)? 
- .try_into() - .map_err(|_| { - ArrowError::ParseError(format!( - "Unable to represent {} years as months in a signed 32-bit integer", - &amount.integer - )) - })?; - - Self::new(self.months.add_checked(months)?, self.days, self.nanos) - } - IntervalUnit::Month => { - let months = amount.integer.try_into().map_err(|_| { - ArrowError::ParseError(format!( - "Unable to represent {} months in a signed 32-bit integer", - &amount.integer - )) - })?; - - let days = amount.frac * 3 / 10_i64.pow(INTERVAL_PRECISION - 1); - let days = days.try_into().map_err(|_| { - ArrowError::ParseError(format!( - "Unable to represent {} months as days in a signed 32-bit integer", - amount.frac / 10_i64.pow(INTERVAL_PRECISION) - )) - })?; - - Self::new( - self.months.add_checked(months)?, - self.days.add_checked(days)?, - self.nanos, - ) - } - IntervalUnit::Week => { - let days = amount.integer.mul_checked(7)?.try_into().map_err(|_| { - ArrowError::ParseError(format!( - "Unable to represent {} weeks as days in a signed 32-bit integer", - &amount.integer - )) - })?; - - let nanos = - amount.frac * 7 * 24 * 6 * 6 / 10_i64.pow(INTERVAL_PRECISION - 11); - - Self::new( - self.months, - self.days.add_checked(days)?, - self.nanos.add_checked(nanos)?, - ) - } - IntervalUnit::Day => { - let days = amount.integer.try_into().map_err(|_| { - ArrowError::InvalidArgumentError(format!( - "Unable to represent {} days in a signed 32-bit integer", - amount.integer - )) - })?; - - let nanos = - amount.frac * 24 * 6 * 6 / 10_i64.pow(INTERVAL_PRECISION - 11); - - Self::new( - self.months, - self.days.add_checked(days)?, - self.nanos.add_checked(nanos)?, - ) - } - IntervalUnit::Hour => { - let nanos_int = amount.integer.mul_checked(NANOS_PER_HOUR)?; - let nanos_frac = - amount.frac * 6 * 6 / 10_i64.pow(INTERVAL_PRECISION - 11); - let nanos = nanos_int.add_checked(nanos_frac)?; - - Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) - } - IntervalUnit::Minute => { - let nanos_int = amount.integer.mul_checked(NANOS_PER_MINUTE)?; - let nanos_frac = amount.frac * 6 / 10_i64.pow(INTERVAL_PRECISION - 10); - - let nanos = nanos_int.add_checked(nanos_frac)?; - - Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) - } - IntervalUnit::Second => { - let nanos_int = amount.integer.mul_checked(NANOS_PER_SECOND)?; - let nanos_frac = amount.frac / 10_i64.pow(INTERVAL_PRECISION - 9); - let nanos = nanos_int.add_checked(nanos_frac)?; - - Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) - } - IntervalUnit::Millisecond => { - let nanos_int = amount.integer.mul_checked(NANOS_PER_MILLIS)?; - let nanos_frac = amount.frac / 10_i64.pow(INTERVAL_PRECISION - 6); - let nanos = nanos_int.add_checked(nanos_frac)?; - - Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) - } - IntervalUnit::Microsecond => { - let nanos_int = amount.integer.mul_checked(1_000)?; - let nanos_frac = amount.frac / 10_i64.pow(INTERVAL_PRECISION - 3); - let nanos = nanos_int.add_checked(nanos_frac)?; - - Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) - } - IntervalUnit::Nanosecond => { - let nanos_int = amount.integer; - let nanos_frac = amount.frac / 10_i64.pow(INTERVAL_PRECISION); - let nanos = nanos_int.add_checked(nanos_frac)?; - - Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) - } - }; - - Ok(result) - } -} - -struct IntervalParseConfig { - /// The default unit to use if none is specified - /// e.g. 
`INTERVAL 1` represents `INTERVAL 1 SECOND` when default_unit = IntervalType::Second - default_unit: IntervalUnit, -} - -impl IntervalParseConfig { - fn new(default_unit: IntervalUnit) -> Self { - Self { default_unit } - } -} - -/// parse the string into a vector of interval components i.e. (amount, unit) tuples -fn parse_interval_components( - value: &str, - config: &IntervalParseConfig, -) -> Result, ArrowError> { - let parts = value.split_whitespace(); - - let raw_amounts = parts.clone().step_by(2); - let raw_units = parts.skip(1).step_by(2); - - // parse amounts - let (amounts, invalid_amounts) = raw_amounts - .map(IntervalAmount::from_str) - .partition::, _>(Result::is_ok); - - // invalid amounts? - if !invalid_amounts.is_empty() { - return Err(ArrowError::NotYetImplemented(format!( - "Unsupported Interval Expression with value {value:?}" - ))); - } - - // parse units - let (units, invalid_units): (Vec<_>, Vec<_>) = raw_units - .clone() - .map(IntervalUnit::from_str) - .partition(Result::is_ok); - - // invalid units? - if !invalid_units.is_empty() { - return Err(ArrowError::ParseError(format!( - "Invalid input syntax for type interval: {value:?}" - ))); - } - - // collect parsed results - let amounts = amounts.into_iter().map(Result::unwrap).collect::>(); - let units = units.into_iter().map(Result::unwrap).collect::>(); - - // if only an amount is specified, use the default unit - if amounts.len() == 1 && units.is_empty() { - return Ok(vec![(amounts[0], config.default_unit)]); - }; - - // duplicate units? - let mut observed_interval_types = 0; - for (unit, raw_unit) in units.iter().zip(raw_units) { - if observed_interval_types & (*unit as u16) != 0 { - return Err(ArrowError::ParseError(format!( - "Invalid input syntax for type interval: {value:?}. 
Repeated type '{raw_unit}'", - ))); - } - - observed_interval_types |= *unit as u16; - } - - let result = amounts.iter().copied().zip(units.iter().copied()); - - Ok(result.collect::>()) -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow_array::timezone::Tz; - use arrow_buffer::i256; - - #[test] - fn test_parse_nanos() { - assert_eq!(parse_nanos::<3, 0>(&[1, 2, 3]), 123_000_000); - assert_eq!(parse_nanos::<5, 0>(&[1, 2, 3, 4, 5]), 123_450_000); - assert_eq!(parse_nanos::<6, b'0'>(b"123456"), 123_456_000); - } - - #[test] - fn string_to_timestamp_timezone() { - // Explicit timezone - assert_eq!( - 1599572549190855000, - parse_timestamp("2020-09-08T13:42:29.190855+00:00").unwrap() - ); - assert_eq!( - 1599572549190855000, - parse_timestamp("2020-09-08T13:42:29.190855Z").unwrap() - ); - assert_eq!( - 1599572549000000000, - parse_timestamp("2020-09-08T13:42:29Z").unwrap() - ); // no fractional part - assert_eq!( - 1599590549190855000, - parse_timestamp("2020-09-08T13:42:29.190855-05:00").unwrap() - ); - } - - #[test] - fn string_to_timestamp_timezone_space() { - // Ensure space rather than T between time and date is accepted - assert_eq!( - 1599572549190855000, - parse_timestamp("2020-09-08 13:42:29.190855+00:00").unwrap() - ); - assert_eq!( - 1599572549190855000, - parse_timestamp("2020-09-08 13:42:29.190855Z").unwrap() - ); - assert_eq!( - 1599572549000000000, - parse_timestamp("2020-09-08 13:42:29Z").unwrap() - ); // no fractional part - assert_eq!( - 1599590549190855000, - parse_timestamp("2020-09-08 13:42:29.190855-05:00").unwrap() - ); - } - - #[test] - #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function: mktime - fn string_to_timestamp_no_timezone() { - // This test is designed to succeed in regardless of the local - // timezone the test machine is running. 
Thus it is still - // somewhat susceptible to bugs in the use of chrono - let naive_datetime = NaiveDateTime::new( - NaiveDate::from_ymd_opt(2020, 9, 8).unwrap(), - NaiveTime::from_hms_nano_opt(13, 42, 29, 190855000).unwrap(), - ); - - // Ensure both T and ' ' variants work - assert_eq!( - naive_datetime.timestamp_nanos(), - parse_timestamp("2020-09-08T13:42:29.190855").unwrap() - ); - - assert_eq!( - naive_datetime.timestamp_nanos(), - parse_timestamp("2020-09-08 13:42:29.190855").unwrap() - ); - - // Also ensure that parsing timestamps with no fractional - // second part works as well - let naive_datetime_whole_secs = NaiveDateTime::new( - NaiveDate::from_ymd_opt(2020, 9, 8).unwrap(), - NaiveTime::from_hms_opt(13, 42, 29).unwrap(), - ); - - // Ensure both T and ' ' variants work - assert_eq!( - naive_datetime_whole_secs.timestamp_nanos(), - parse_timestamp("2020-09-08T13:42:29").unwrap() - ); - - assert_eq!( - naive_datetime_whole_secs.timestamp_nanos(), - parse_timestamp("2020-09-08 13:42:29").unwrap() - ); - - // ensure without time work - // no time, should be the nano second at - // 2020-09-08 0:0:0 - let naive_datetime_no_time = NaiveDateTime::new( - NaiveDate::from_ymd_opt(2020, 9, 8).unwrap(), - NaiveTime::from_hms_opt(0, 0, 0).unwrap(), - ); - - assert_eq!( - naive_datetime_no_time.timestamp_nanos(), - parse_timestamp("2020-09-08").unwrap() - ) - } - - #[test] - fn string_to_timestamp_chrono() { - let cases = [ - "2020-09-08T13:42:29Z", - "1969-01-01T00:00:00.1Z", - "2020-09-08T12:00:12.12345678+00:00", - "2020-09-08T12:00:12+00:00", - "2020-09-08T12:00:12.1+00:00", - "2020-09-08T12:00:12.12+00:00", - "2020-09-08T12:00:12.123+00:00", - "2020-09-08T12:00:12.1234+00:00", - "2020-09-08T12:00:12.12345+00:00", - "2020-09-08T12:00:12.123456+00:00", - "2020-09-08T12:00:12.1234567+00:00", - "2020-09-08T12:00:12.12345678+00:00", - "2020-09-08T12:00:12.123456789+00:00", - "2020-09-08T12:00:12.12345678912z", - "2020-09-08T12:00:12.123456789123Z", - "2020-09-08T12:00:12.123456789123+02:00", - "2020-09-08T12:00:12.12345678912345Z", - "2020-09-08T12:00:12.1234567891234567+02:00", - "2020-09-08T12:00:60Z", - "2020-09-08T12:00:60.123Z", - "2020-09-08T12:00:60.123456+02:00", - "2020-09-08T12:00:60.1234567891234567+02:00", - "2020-09-08T12:00:60.999999999+02:00", - "2020-09-08t12:00:12.12345678+00:00", - "2020-09-08t12:00:12+00:00", - "2020-09-08t12:00:12Z", - ]; - - for case in cases { - let chrono = DateTime::parse_from_rfc3339(case).unwrap(); - let chrono_utc = chrono.with_timezone(&Utc); - - let custom = string_to_datetime(&Utc, case).unwrap(); - assert_eq!(chrono_utc, custom) - } - } - - #[test] - fn string_to_timestamp_naive() { - let cases = [ - "2018-11-13T17:11:10.011375885995", - "2030-12-04T17:11:10.123", - "2030-12-04T17:11:10.1234", - "2030-12-04T17:11:10.123456", - ]; - for case in cases { - let chrono = - NaiveDateTime::parse_from_str(case, "%Y-%m-%dT%H:%M:%S%.f").unwrap(); - let custom = string_to_datetime(&Utc, case).unwrap(); - assert_eq!(chrono, custom.naive_utc()) - } - } - - #[test] - fn string_to_timestamp_invalid() { - // Test parsing invalid formats - let cases = [ - ("", "timestamp must contain at least 10 characters"), - ("SS", "timestamp must contain at least 10 characters"), - ("Wed, 18 Feb 2015 23:16:09 GMT", "error parsing date"), - ("1997-01-31H09:26:56.123Z", "invalid timestamp separator"), - ("1997-01-31 09:26:56.123Z", "error parsing time"), - ("1997:01:31T09:26:56.123Z", "error parsing date"), - ("1997:1:31T09:26:56.123Z", "error parsing date"), - 
("1997-01-32T09:26:56.123Z", "error parsing date"), - ("1997-13-32T09:26:56.123Z", "error parsing date"), - ("1997-02-29T09:26:56.123Z", "error parsing date"), - ("2015-02-30T17:35:20-08:00", "error parsing date"), - ("1997-01-10T9:26:56.123Z", "error parsing time"), - ("2015-01-20T25:35:20-08:00", "error parsing time"), - ("1997-01-10T09:61:56.123Z", "error parsing time"), - ("1997-01-10T09:61:90.123Z", "error parsing time"), - ("1997-01-10T12:00:6.123Z", "error parsing time"), - ("1997-01-31T092656.123Z", "error parsing time"), - ("1997-01-10T12:00:06.", "error parsing time"), - ("1997-01-10T12:00:06. ", "error parsing time"), - ]; - - for (s, ctx) in cases { - let expected = - format!("Parser error: Error parsing timestamp from '{s}': {ctx}"); - let actual = string_to_datetime(&Utc, s).unwrap_err().to_string(); - assert_eq!(actual, expected) - } - } - - // Parse a timestamp to timestamp int with a useful human readable error message - fn parse_timestamp(s: &str) -> Result { - let result = string_to_timestamp_nanos(s); - if let Err(e) = &result { - eprintln!("Error parsing timestamp '{s}': {e:?}"); - } - result - } - - #[test] - fn string_without_timezone_to_timestamp() { - // string without timezone should always output the same regardless the local or session timezone - - let naive_datetime = NaiveDateTime::new( - NaiveDate::from_ymd_opt(2020, 9, 8).unwrap(), - NaiveTime::from_hms_nano_opt(13, 42, 29, 190855000).unwrap(), - ); - - // Ensure both T and ' ' variants work - assert_eq!( - naive_datetime.timestamp_nanos(), - parse_timestamp("2020-09-08T13:42:29.190855").unwrap() - ); - - assert_eq!( - naive_datetime.timestamp_nanos(), - parse_timestamp("2020-09-08 13:42:29.190855").unwrap() - ); - - let naive_datetime = NaiveDateTime::new( - NaiveDate::from_ymd_opt(2020, 9, 8).unwrap(), - NaiveTime::from_hms_nano_opt(13, 42, 29, 0).unwrap(), - ); - - // Ensure both T and ' ' variants work - assert_eq!( - naive_datetime.timestamp_nanos(), - parse_timestamp("2020-09-08T13:42:29").unwrap() - ); - - assert_eq!( - naive_datetime.timestamp_nanos(), - parse_timestamp("2020-09-08 13:42:29").unwrap() - ); - - let tz: Tz = "+02:00".parse().unwrap(); - let date = string_to_datetime(&tz, "2020-09-08 13:42:29").unwrap(); - let utc = date.naive_utc().to_string(); - assert_eq!(utc, "2020-09-08 11:42:29"); - let local = date.naive_local().to_string(); - assert_eq!(local, "2020-09-08 13:42:29"); - - let date = string_to_datetime(&tz, "2020-09-08 13:42:29Z").unwrap(); - let utc = date.naive_utc().to_string(); - assert_eq!(utc, "2020-09-08 13:42:29"); - let local = date.naive_local().to_string(); - assert_eq!(local, "2020-09-08 15:42:29"); - - let dt = - NaiveDateTime::parse_from_str("2020-09-08T13:42:29Z", "%Y-%m-%dT%H:%M:%SZ") - .unwrap(); - let local: Tz = "+08:00".parse().unwrap(); - - // Parsed as offset from UTC - let date = string_to_datetime(&local, "2020-09-08T13:42:29Z").unwrap(); - assert_eq!(dt, date.naive_utc()); - assert_ne!(dt, date.naive_local()); - - // Parsed as offset from local - let date = string_to_datetime(&local, "2020-09-08 13:42:29").unwrap(); - assert_eq!(dt, date.naive_local()); - assert_ne!(dt, date.naive_utc()); - } - - #[test] - fn parse_time64_nanos() { - assert_eq!( - Time64NanosecondType::parse("02:10:01.1234567899999999"), - Some(7_801_123_456_789) - ); - assert_eq!( - Time64NanosecondType::parse("02:10:01.1234567"), - Some(7_801_123_456_700) - ); - assert_eq!( - Time64NanosecondType::parse("2:10:01.1234567"), - Some(7_801_123_456_700) - ); - assert_eq!( - 
Time64NanosecondType::parse("12:10:01.123456789 AM"), - Some(601_123_456_789) - ); - assert_eq!( - Time64NanosecondType::parse("12:10:01.123456789 am"), - Some(601_123_456_789) - ); - assert_eq!( - Time64NanosecondType::parse("2:10:01.12345678 PM"), - Some(51_001_123_456_780) - ); - assert_eq!( - Time64NanosecondType::parse("2:10:01.12345678 pm"), - Some(51_001_123_456_780) - ); - assert_eq!( - Time64NanosecondType::parse("02:10:01"), - Some(7_801_000_000_000) - ); - assert_eq!( - Time64NanosecondType::parse("2:10:01"), - Some(7_801_000_000_000) - ); - assert_eq!( - Time64NanosecondType::parse("12:10:01 AM"), - Some(601_000_000_000) - ); - assert_eq!( - Time64NanosecondType::parse("12:10:01 am"), - Some(601_000_000_000) - ); - assert_eq!( - Time64NanosecondType::parse("2:10:01 PM"), - Some(51_001_000_000_000) - ); - assert_eq!( - Time64NanosecondType::parse("2:10:01 pm"), - Some(51_001_000_000_000) - ); - assert_eq!( - Time64NanosecondType::parse("02:10"), - Some(7_800_000_000_000) - ); - assert_eq!(Time64NanosecondType::parse("2:10"), Some(7_800_000_000_000)); - assert_eq!( - Time64NanosecondType::parse("12:10 AM"), - Some(600_000_000_000) - ); - assert_eq!( - Time64NanosecondType::parse("12:10 am"), - Some(600_000_000_000) - ); - assert_eq!( - Time64NanosecondType::parse("2:10 PM"), - Some(51_000_000_000_000) - ); - assert_eq!( - Time64NanosecondType::parse("2:10 pm"), - Some(51_000_000_000_000) - ); - - // parse directly as nanoseconds - assert_eq!(Time64NanosecondType::parse("1"), Some(1)); - - // leap second - assert_eq!( - Time64NanosecondType::parse("23:59:60"), - Some(86_400_000_000_000) - ); - - // custom format - assert_eq!( - Time64NanosecondType::parse_formatted( - "02 - 10 - 01 - .1234567", - "%H - %M - %S - %.f" - ), - Some(7_801_123_456_700) - ); - } - - #[test] - fn parse_time64_micros() { - // expected formats - assert_eq!( - Time64MicrosecondType::parse("02:10:01.1234"), - Some(7_801_123_400) - ); - assert_eq!( - Time64MicrosecondType::parse("2:10:01.1234"), - Some(7_801_123_400) - ); - assert_eq!( - Time64MicrosecondType::parse("12:10:01.123456 AM"), - Some(601_123_456) - ); - assert_eq!( - Time64MicrosecondType::parse("12:10:01.123456 am"), - Some(601_123_456) - ); - assert_eq!( - Time64MicrosecondType::parse("2:10:01.12345 PM"), - Some(51_001_123_450) - ); - assert_eq!( - Time64MicrosecondType::parse("2:10:01.12345 pm"), - Some(51_001_123_450) - ); - assert_eq!( - Time64MicrosecondType::parse("02:10:01"), - Some(7_801_000_000) - ); - assert_eq!(Time64MicrosecondType::parse("2:10:01"), Some(7_801_000_000)); - assert_eq!( - Time64MicrosecondType::parse("12:10:01 AM"), - Some(601_000_000) - ); - assert_eq!( - Time64MicrosecondType::parse("12:10:01 am"), - Some(601_000_000) - ); - assert_eq!( - Time64MicrosecondType::parse("2:10:01 PM"), - Some(51_001_000_000) - ); - assert_eq!( - Time64MicrosecondType::parse("2:10:01 pm"), - Some(51_001_000_000) - ); - assert_eq!(Time64MicrosecondType::parse("02:10"), Some(7_800_000_000)); - assert_eq!(Time64MicrosecondType::parse("2:10"), Some(7_800_000_000)); - assert_eq!(Time64MicrosecondType::parse("12:10 AM"), Some(600_000_000)); - assert_eq!(Time64MicrosecondType::parse("12:10 am"), Some(600_000_000)); - assert_eq!( - Time64MicrosecondType::parse("2:10 PM"), - Some(51_000_000_000) - ); - assert_eq!( - Time64MicrosecondType::parse("2:10 pm"), - Some(51_000_000_000) - ); - - // parse directly as microseconds - assert_eq!(Time64MicrosecondType::parse("1"), Some(1)); - - // leap second - assert_eq!( - 
Time64MicrosecondType::parse("23:59:60"), - Some(86_400_000_000) - ); - - // custom format - assert_eq!( - Time64MicrosecondType::parse_formatted( - "02 - 10 - 01 - .1234", - "%H - %M - %S - %.f" - ), - Some(7_801_123_400) - ); - } - - #[test] - fn parse_time32_millis() { - // expected formats - assert_eq!(Time32MillisecondType::parse("02:10:01.1"), Some(7_801_100)); - assert_eq!(Time32MillisecondType::parse("2:10:01.1"), Some(7_801_100)); - assert_eq!( - Time32MillisecondType::parse("12:10:01.123 AM"), - Some(601_123) - ); - assert_eq!( - Time32MillisecondType::parse("12:10:01.123 am"), - Some(601_123) - ); - assert_eq!( - Time32MillisecondType::parse("2:10:01.12 PM"), - Some(51_001_120) - ); - assert_eq!( - Time32MillisecondType::parse("2:10:01.12 pm"), - Some(51_001_120) - ); - assert_eq!(Time32MillisecondType::parse("02:10:01"), Some(7_801_000)); - assert_eq!(Time32MillisecondType::parse("2:10:01"), Some(7_801_000)); - assert_eq!(Time32MillisecondType::parse("12:10:01 AM"), Some(601_000)); - assert_eq!(Time32MillisecondType::parse("12:10:01 am"), Some(601_000)); - assert_eq!(Time32MillisecondType::parse("2:10:01 PM"), Some(51_001_000)); - assert_eq!(Time32MillisecondType::parse("2:10:01 pm"), Some(51_001_000)); - assert_eq!(Time32MillisecondType::parse("02:10"), Some(7_800_000)); - assert_eq!(Time32MillisecondType::parse("2:10"), Some(7_800_000)); - assert_eq!(Time32MillisecondType::parse("12:10 AM"), Some(600_000)); - assert_eq!(Time32MillisecondType::parse("12:10 am"), Some(600_000)); - assert_eq!(Time32MillisecondType::parse("2:10 PM"), Some(51_000_000)); - assert_eq!(Time32MillisecondType::parse("2:10 pm"), Some(51_000_000)); - - // parse directly as milliseconds - assert_eq!(Time32MillisecondType::parse("1"), Some(1)); - - // leap second - assert_eq!(Time32MillisecondType::parse("23:59:60"), Some(86_400_000)); - - // custom format - assert_eq!( - Time32MillisecondType::parse_formatted( - "02 - 10 - 01 - .1", - "%H - %M - %S - %.f" - ), - Some(7_801_100) - ); - } - - #[test] - fn parse_time32_secs() { - // expected formats - assert_eq!(Time32SecondType::parse("02:10:01.1"), Some(7_801)); - assert_eq!(Time32SecondType::parse("02:10:01"), Some(7_801)); - assert_eq!(Time32SecondType::parse("2:10:01"), Some(7_801)); - assert_eq!(Time32SecondType::parse("12:10:01 AM"), Some(601)); - assert_eq!(Time32SecondType::parse("12:10:01 am"), Some(601)); - assert_eq!(Time32SecondType::parse("2:10:01 PM"), Some(51_001)); - assert_eq!(Time32SecondType::parse("2:10:01 pm"), Some(51_001)); - assert_eq!(Time32SecondType::parse("02:10"), Some(7_800)); - assert_eq!(Time32SecondType::parse("2:10"), Some(7_800)); - assert_eq!(Time32SecondType::parse("12:10 AM"), Some(600)); - assert_eq!(Time32SecondType::parse("12:10 am"), Some(600)); - assert_eq!(Time32SecondType::parse("2:10 PM"), Some(51_000)); - assert_eq!(Time32SecondType::parse("2:10 pm"), Some(51_000)); - - // parse directly as seconds - assert_eq!(Time32SecondType::parse("1"), Some(1)); - - // leap second - assert_eq!(Time32SecondType::parse("23:59:60"), Some(86400)); - - // custom format - assert_eq!( - Time32SecondType::parse_formatted("02 - 10 - 01", "%H - %M - %S"), - Some(7_801) - ); - } - - #[test] - fn test_string_to_time_invalid() { - let cases = [ - "25:00", - "9:00:", - "009:00", - "09:0:00", - "25:00:00", - "13:00 AM", - "13:00 PM", - "12:00. 
AM", - "09:0:00", - "09:01:0", - "09:01:1", - "9:1:0", - "09:01:0", - "1:00.123", - "1:00:00.123f", - " 9:00:00", - ":09:00", - "T9:00:00", - "AM", - ]; - for case in cases { - assert!(string_to_time(case).is_none(), "{case}"); - } - } - - #[test] - fn test_string_to_time_chrono() { - let cases = [ - ("1:00", "%H:%M"), - ("12:00", "%H:%M"), - ("13:00", "%H:%M"), - ("24:00", "%H:%M"), - ("1:00:00", "%H:%M:%S"), - ("12:00:30", "%H:%M:%S"), - ("13:00:59", "%H:%M:%S"), - ("24:00:60", "%H:%M:%S"), - ("09:00:00", "%H:%M:%S%.f"), - ("0:00:30.123456", "%H:%M:%S%.f"), - ("0:00 AM", "%I:%M %P"), - ("1:00 AM", "%I:%M %P"), - ("12:00 AM", "%I:%M %P"), - ("13:00 AM", "%I:%M %P"), - ("0:00 PM", "%I:%M %P"), - ("1:00 PM", "%I:%M %P"), - ("12:00 PM", "%I:%M %P"), - ("13:00 PM", "%I:%M %P"), - ("1:00 pM", "%I:%M %P"), - ("1:00 Pm", "%I:%M %P"), - ("1:00 aM", "%I:%M %P"), - ("1:00 Am", "%I:%M %P"), - ("1:00:30.123456 PM", "%I:%M:%S%.f %P"), - ("1:00:30.123456789 PM", "%I:%M:%S%.f %P"), - ("1:00:30.123456789123 PM", "%I:%M:%S%.f %P"), - ("1:00:30.1234 PM", "%I:%M:%S%.f %P"), - ("1:00:30.123456 PM", "%I:%M:%S%.f %P"), - ("1:00:30.123456789123456789 PM", "%I:%M:%S%.f %P"), - ("1:00:30.12F456 PM", "%I:%M:%S%.f %P"), - ]; - for (s, format) in cases { - let chrono = NaiveTime::parse_from_str(s, format).ok(); - let custom = string_to_time(s); - assert_eq!(chrono, custom, "{s}"); - } - } - - #[test] - fn test_parse_interval() { - let config = IntervalParseConfig::new(IntervalUnit::Month); - - assert_eq!( - Interval::new(1i32, 0i32, 0i64), - Interval::parse("1 month", &config).unwrap(), - ); - - assert_eq!( - Interval::new(2i32, 0i32, 0i64), - Interval::parse("2 month", &config).unwrap(), - ); - - assert_eq!( - Interval::new(-1i32, -18i32, -(NANOS_PER_DAY / 5)), - Interval::parse("-1.5 months -3.2 days", &config).unwrap(), - ); - - assert_eq!( - Interval::new(0i32, 15i32, 0), - Interval::parse("0.5 months", &config).unwrap(), - ); - - assert_eq!( - Interval::new(0i32, 15i32, 0), - Interval::parse(".5 months", &config).unwrap(), - ); - - assert_eq!( - Interval::new(0i32, -15i32, 0), - Interval::parse("-0.5 months", &config).unwrap(), - ); - - assert_eq!( - Interval::new(0i32, -15i32, 0), - Interval::parse("-.5 months", &config).unwrap(), - ); - - assert_eq!( - Interval::new(2i32, 10i32, 9 * NANOS_PER_HOUR), - Interval::parse("2.1 months 7.25 days 3 hours", &config).unwrap(), - ); - - assert_eq!( - Interval::parse("1 centurys 1 month", &config) - .unwrap_err() - .to_string(), - r#"Parser error: Invalid input syntax for type interval: "1 centurys 1 month""# - ); - - assert_eq!( - Interval::new(37i32, 0i32, 0i64), - Interval::parse("3 year 1 month", &config).unwrap(), - ); - - assert_eq!( - Interval::new(35i32, 0i32, 0i64), - Interval::parse("3 year -1 month", &config).unwrap(), - ); - - assert_eq!( - Interval::new(-37i32, 0i32, 0i64), - Interval::parse("-3 year -1 month", &config).unwrap(), - ); - - assert_eq!( - Interval::new(-35i32, 0i32, 0i64), - Interval::parse("-3 year 1 month", &config).unwrap(), - ); - - assert_eq!( - Interval::new(0i32, 5i32, 0i64), - Interval::parse("5 days", &config).unwrap(), - ); - - assert_eq!( - Interval::new(0i32, 7i32, 3 * NANOS_PER_HOUR), - Interval::parse("7 days 3 hours", &config).unwrap(), - ); - - assert_eq!( - Interval::new(0i32, 7i32, 5 * NANOS_PER_MINUTE), - Interval::parse("7 days 5 minutes", &config).unwrap(), - ); - - assert_eq!( - Interval::new(0i32, 7i32, -5 * NANOS_PER_MINUTE), - Interval::parse("7 days -5 minutes", &config).unwrap(), - ); - - assert_eq!( - 
Interval::new(0i32, -7i32, 5 * NANOS_PER_HOUR), - Interval::parse("-7 days 5 hours", &config).unwrap(), - ); - - assert_eq!( - Interval::new( - 0i32, - -7i32, - -5 * NANOS_PER_HOUR - 5 * NANOS_PER_MINUTE - 5 * NANOS_PER_SECOND - ), - Interval::parse("-7 days -5 hours -5 minutes -5 seconds", &config).unwrap(), - ); - - assert_eq!( - Interval::new(12i32, 0i32, 25 * NANOS_PER_MILLIS), - Interval::parse("1 year 25 millisecond", &config).unwrap(), - ); - - assert_eq!( - Interval::new( - 12i32, - 1i32, - (NANOS_PER_SECOND as f64 * 0.000000001_f64) as i64 - ), - Interval::parse("1 year 1 day 0.000000001 seconds", &config).unwrap(), - ); - - assert_eq!( - Interval::new(12i32, 1i32, NANOS_PER_MILLIS / 10), - Interval::parse("1 year 1 day 0.1 milliseconds", &config).unwrap(), - ); - - assert_eq!( - Interval::new(12i32, 1i32, 1000i64), - Interval::parse("1 year 1 day 1 microsecond", &config).unwrap(), - ); - - assert_eq!( - Interval::new(12i32, 1i32, 1i64), - Interval::parse("1 year 1 day 1 nanoseconds", &config).unwrap(), - ); - - assert_eq!( - Interval::new(1i32, 0i32, -NANOS_PER_SECOND), - Interval::parse("1 month -1 second", &config).unwrap(), - ); - - assert_eq!( - Interval::new(-13i32, -8i32, -NANOS_PER_HOUR - NANOS_PER_MINUTE - NANOS_PER_SECOND - (1.11_f64 * NANOS_PER_MILLIS as f64) as i64), - Interval::parse("-1 year -1 month -1 week -1 day -1 hour -1 minute -1 second -1.11 millisecond", &config).unwrap(), - ); - } - - #[test] - fn test_duplicate_interval_type() { - let config = IntervalParseConfig::new(IntervalUnit::Month); - - let err = Interval::parse("1 month 1 second 1 second", &config) - .expect_err("parsing interval should have failed"); - assert_eq!( - r#"ParseError("Invalid input syntax for type interval: \"1 month 1 second 1 second\". Repeated type 'second'")"#, - format!("{err:?}") - ); - - // test with singular and plural forms - let err = Interval::parse("1 century 2 centuries", &config) - .expect_err("parsing interval should have failed"); - assert_eq!( - r#"ParseError("Invalid input syntax for type interval: \"1 century 2 centuries\". 
Repeated type 'centuries'")"#, - format!("{err:?}") - ); - } - - #[test] - fn test_interval_amount_parsing() { - // integer - let result = IntervalAmount::from_str("123").unwrap(); - let expected = IntervalAmount::new(123, 0); - - assert_eq!(result, expected); - - // positive w/ fractional - let result = IntervalAmount::from_str("0.3").unwrap(); - let expected = IntervalAmount::new(0, 3 * 10_i64.pow(INTERVAL_PRECISION - 1)); - - assert_eq!(result, expected); - - // negative w/ fractional - let result = IntervalAmount::from_str("-3.5").unwrap(); - let expected = IntervalAmount::new(-3, -5 * 10_i64.pow(INTERVAL_PRECISION - 1)); - - assert_eq!(result, expected); - - // invalid: missing fractional - let result = IntervalAmount::from_str("3."); - assert!(result.is_err()); - - // invalid: sign in fractional - let result = IntervalAmount::from_str("3.-5"); - assert!(result.is_err()); - } - - #[test] - fn test_interval_precision() { - let config = IntervalParseConfig::new(IntervalUnit::Month); - - let result = Interval::parse("100000.1 days", &config).unwrap(); - let expected = Interval::new(0_i32, 100_000_i32, NANOS_PER_DAY / 10); - - assert_eq!(result, expected); - } - - #[test] - fn test_interval_addition() { - // add 4.1 centuries - let start = Interval::new(1, 2, 3); - let expected = Interval::new(4921, 2, 3); - - let result = start - .add( - IntervalAmount::new(4, 10_i64.pow(INTERVAL_PRECISION - 1)), - IntervalUnit::Century, - ) - .unwrap(); - - assert_eq!(result, expected); - - // add 10.25 decades - let start = Interval::new(1, 2, 3); - let expected = Interval::new(1231, 2, 3); - - let result = start - .add( - IntervalAmount::new(10, 25 * 10_i64.pow(INTERVAL_PRECISION - 2)), - IntervalUnit::Decade, - ) - .unwrap(); - - assert_eq!(result, expected); - - // add 30.3 years (reminder: Postgres logic does not spill to days/nanos when interval is larger than a month) - let start = Interval::new(1, 2, 3); - let expected = Interval::new(364, 2, 3); - - let result = start - .add( - IntervalAmount::new(30, 3 * 10_i64.pow(INTERVAL_PRECISION - 1)), - IntervalUnit::Year, - ) - .unwrap(); - - assert_eq!(result, expected); - - // add 1.5 months - let start = Interval::new(1, 2, 3); - let expected = Interval::new(2, 17, 3); - - let result = start - .add( - IntervalAmount::new(1, 5 * 10_i64.pow(INTERVAL_PRECISION - 1)), - IntervalUnit::Month, - ) - .unwrap(); - - assert_eq!(result, expected); - - // add -2 weeks - let start = Interval::new(1, 25, 3); - let expected = Interval::new(1, 11, 3); - - let result = start - .add(IntervalAmount::new(-2, 0), IntervalUnit::Week) - .unwrap(); - - assert_eq!(result, expected); - - // add 2.2 days - let start = Interval::new(12, 15, 3); - let expected = Interval::new(12, 17, 3 + 17_280 * NANOS_PER_SECOND); - - let result = start - .add( - IntervalAmount::new(2, 2 * 10_i64.pow(INTERVAL_PRECISION - 1)), - IntervalUnit::Day, - ) - .unwrap(); - - assert_eq!(result, expected); - - // add 12.5 hours - let start = Interval::new(1, 2, 3); - let expected = Interval::new(1, 2, 3 + 45_000 * NANOS_PER_SECOND); - - let result = start - .add( - IntervalAmount::new(12, 5 * 10_i64.pow(INTERVAL_PRECISION - 1)), - IntervalUnit::Hour, - ) - .unwrap(); - - assert_eq!(result, expected); - - // add -1.5 minutes - let start = Interval::new(0, 0, -3); - let expected = Interval::new(0, 0, -90_000_000_000 - 3); - - let result = start - .add( - IntervalAmount::new(-1, -5 * 10_i64.pow(INTERVAL_PRECISION - 1)), - IntervalUnit::Minute, - ) - .unwrap(); - - assert_eq!(result, expected); - } - - 
#[test] - fn string_to_timestamp_old() { - parse_timestamp("1677-06-14T07:29:01.256") - .map_err(|e| assert!(e.to_string().ends_with(ERR_NANOSECONDS_NOT_SUPPORTED))) - .unwrap_err(); - } - - #[test] - fn test_parse_decimal_with_parameter() { - let tests = [ - ("0", 0i128), - ("123.123", 123123i128), - ("123.1234", 123123i128), - ("123.1", 123100i128), - ("123", 123000i128), - ("-123.123", -123123i128), - ("-123.1234", -123123i128), - ("-123.1", -123100i128), - ("-123", -123000i128), - ("0.0000123", 0i128), - ("12.", 12000i128), - ("-12.", -12000i128), - ("00.1", 100i128), - ("-00.1", -100i128), - ("12345678912345678.1234", 12345678912345678123i128), - ("-12345678912345678.1234", -12345678912345678123i128), - ("99999999999999999.999", 99999999999999999999i128), - ("-99999999999999999.999", -99999999999999999999i128), - (".123", 123i128), - ("-.123", -123i128), - ("123.", 123000i128), - ("-123.", -123000i128), - ]; - for (s, i) in tests { - let result_128 = parse_decimal::(s, 20, 3); - assert_eq!(i, result_128.unwrap()); - let result_256 = parse_decimal::(s, 20, 3); - assert_eq!(i256::from_i128(i), result_256.unwrap()); - } - let can_not_parse_tests = ["123,123", ".", "123.123.123", "", "+", "-"]; - for s in can_not_parse_tests { - let result_128 = parse_decimal::(s, 20, 3); - assert_eq!( - format!("Parser error: can't parse the string value {s} to decimal"), - result_128.unwrap_err().to_string() - ); - let result_256 = parse_decimal::(s, 20, 3); - assert_eq!( - format!("Parser error: can't parse the string value {s} to decimal"), - result_256.unwrap_err().to_string() - ); - } - let overflow_parse_tests = ["12345678", "12345678.9", "99999999.99"]; - for s in overflow_parse_tests { - let result_128 = parse_decimal::(s, 10, 3); - let expected_128 = "Parser error: parse decimal overflow"; - let actual_128 = result_128.unwrap_err().to_string(); - - assert!( - actual_128.contains(expected_128), - "actual: '{actual_128}', expected: '{expected_128}'" - ); - - let result_256 = parse_decimal::(s, 10, 3); - let expected_256 = "Parser error: parse decimal overflow"; - let actual_256 = result_256.unwrap_err().to_string(); - - assert!( - actual_256.contains(expected_256), - "actual: '{actual_256}', expected: '{expected_256}'" - ); - } - - let edge_tests_128 = [ - ( - "99999999999999999999999999999999999999", - 99999999999999999999999999999999999999i128, - 0, - ), - ( - "999999999999999999999999999999999999.99", - 99999999999999999999999999999999999999i128, - 2, - ), - ( - "9999999999999999999999999.9999999999999", - 99999999999999999999999999999999999999i128, - 13, - ), - ( - "9999999999999999999999999", - 99999999999999999999999990000000000000i128, - 13, - ), - ( - "0.99999999999999999999999999999999999999", - 99999999999999999999999999999999999999i128, - 38, - ), - ]; - for (s, i, scale) in edge_tests_128 { - let result_128 = parse_decimal::(s, 38, scale); - assert_eq!(i, result_128.unwrap()); - } - let edge_tests_256 = [ - ( - "9999999999999999999999999999999999999999999999999999999999999999999999999999", -i256::from_string("9999999999999999999999999999999999999999999999999999999999999999999999999999").unwrap(), - 0, - ), - ( - "999999999999999999999999999999999999999999999999999999999999999999999999.9999", - i256::from_string("9999999999999999999999999999999999999999999999999999999999999999999999999999").unwrap(), - 4, - ), - ( - "99999999999999999999999999999999999999999999999999.99999999999999999999999999", - 
i256::from_string("9999999999999999999999999999999999999999999999999999999999999999999999999999").unwrap(), - 26, - ), - ( - "99999999999999999999999999999999999999999999999999", - i256::from_string("9999999999999999999999999999999999999999999999999900000000000000000000000000").unwrap(), - 26, - ), - ]; - for (s, i, scale) in edge_tests_256 { - let result = parse_decimal::(s, 76, scale); - assert_eq!(i, result.unwrap()); - } - } -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::timezone::Tz; +use arrow_array::types::*; +use arrow_array::{ArrowNativeTypeOp, ArrowPrimitiveType}; +use arrow_buffer::ArrowNativeType; +use arrow_schema::ArrowError; +use chrono::prelude::*; +use half::f16; +use std::str::FromStr; + +/// Parse nanoseconds from the first `N` values in digits, subtracting the offset `O` +#[inline] +fn parse_nanos(digits: &[u8]) -> u32 { + digits[..N] + .iter() + .fold(0_u32, |acc, v| acc * 10 + v.wrapping_sub(O) as u32) + * 10_u32.pow((9 - N) as _) +} + +/// Helper for parsing RFC3339 timestamps +struct TimestampParser { + /// The timestamp bytes to parse minus `b'0'` + /// + /// This makes interpretation as an integer inexpensive + digits: [u8; 32], + /// A mask containing a `1` bit where the corresponding byte is a valid ASCII digit + mask: u32, +} + +impl TimestampParser { + fn new(bytes: &[u8]) -> Self { + let mut digits = [0; 32]; + let mut mask = 0; + + // Treating all bytes the same way, helps LLVM vectorise this correctly + for (idx, (o, i)) in digits.iter_mut().zip(bytes).enumerate() { + *o = i.wrapping_sub(b'0'); + mask |= ((*o < 10) as u32) << idx + } + + Self { digits, mask } + } + + /// Returns true if the byte at `idx` in the original string equals `b` + fn test(&self, idx: usize, b: u8) -> bool { + self.digits[idx] == b.wrapping_sub(b'0') + } + + /// Parses a date of the form `1997-01-31` + fn date(&self) -> Option { + if self.mask & 0b1111111111 != 0b1101101111 + || !self.test(4, b'-') + || !self.test(7, b'-') + { + return None; + } + + let year = self.digits[0] as u16 * 1000 + + self.digits[1] as u16 * 100 + + self.digits[2] as u16 * 10 + + self.digits[3] as u16; + + let month = self.digits[5] * 10 + self.digits[6]; + let day = self.digits[8] * 10 + self.digits[9]; + + NaiveDate::from_ymd_opt(year as _, month as _, day as _) + } + + /// Parses a time of any of forms + /// - `09:26:56` + /// - `09:26:56.123` + /// - `09:26:56.123456` + /// - `09:26:56.123456789` + /// - `092656` + /// + /// Returning the end byte offset + fn time(&self) -> Option<(NaiveTime, usize)> { + // Make a NaiveTime handling leap seconds + let time = |hour, min, sec, nano| match sec { + 60 => { + let nano = 1_000_000_000 + nano; + NaiveTime::from_hms_nano_opt(hour as _, min as _, 59, nano) + } + _ => 
NaiveTime::from_hms_nano_opt(hour as _, min as _, sec as _, nano), + }; + + match (self.mask >> 11) & 0b11111111 { + // 09:26:56 + 0b11011011 if self.test(13, b':') && self.test(16, b':') => { + let hour = self.digits[11] * 10 + self.digits[12]; + let minute = self.digits[14] * 10 + self.digits[15]; + let second = self.digits[17] * 10 + self.digits[18]; + + match self.test(19, b'.') { + true => { + let digits = (self.mask >> 20).trailing_ones(); + let nanos = match digits { + 0 => return None, + 1 => parse_nanos::<1, 0>(&self.digits[20..21]), + 2 => parse_nanos::<2, 0>(&self.digits[20..22]), + 3 => parse_nanos::<3, 0>(&self.digits[20..23]), + 4 => parse_nanos::<4, 0>(&self.digits[20..24]), + 5 => parse_nanos::<5, 0>(&self.digits[20..25]), + 6 => parse_nanos::<6, 0>(&self.digits[20..26]), + 7 => parse_nanos::<7, 0>(&self.digits[20..27]), + 8 => parse_nanos::<8, 0>(&self.digits[20..28]), + _ => parse_nanos::<9, 0>(&self.digits[20..29]), + }; + Some((time(hour, minute, second, nanos)?, 20 + digits as usize)) + } + false => Some((time(hour, minute, second, 0)?, 19)), + } + } + // 092656 + 0b111111 => { + let hour = self.digits[11] * 10 + self.digits[12]; + let minute = self.digits[13] * 10 + self.digits[14]; + let second = self.digits[15] * 10 + self.digits[16]; + let time = time(hour, minute, second, 0)?; + Some((time, 17)) + } + _ => None, + } + } +} + +/// Accepts a string and parses it relative to the provided `timezone` +/// +/// In addition to RFC3339 / ISO8601 standard timestamps, it also +/// accepts strings that use a space ` ` to separate the date and time +/// as well as strings that have no explicit timezone offset. +/// +/// Examples of accepted inputs: +/// * `1997-01-31T09:26:56.123Z` # RCF3339 +/// * `1997-01-31T09:26:56.123-05:00` # RCF3339 +/// * `1997-01-31 09:26:56.123-05:00` # close to RCF3339 but with a space rather than T +/// * `2023-01-01 04:05:06.789 -08` # close to RCF3339, no fractional seconds or time separator +/// * `1997-01-31T09:26:56.123` # close to RCF3339 but no timezone offset specified +/// * `1997-01-31 09:26:56.123` # close to RCF3339 but uses a space and no timezone offset +/// * `1997-01-31 09:26:56` # close to RCF3339, no fractional seconds +/// * `1997-01-31 092656` # close to RCF3339, no fractional seconds +/// * `1997-01-31 092656+04:00` # close to RCF3339, no fractional seconds or time separator +/// * `1997-01-31` # close to RCF3339, only date no time +/// +/// [IANA timezones] are only supported if the `arrow-array/chrono-tz` feature is enabled +/// +/// * `2023-01-01 040506 America/Los_Angeles` +/// +/// If a timestamp is ambiguous, for example as a result of daylight-savings time, an error +/// will be returned +/// +/// Some formats supported by PostgresSql +/// are not supported, like +/// +/// * "2023-01-01 04:05:06.789 +07:30:00", +/// * "2023-01-01 040506 +07:30:00", +/// * "2023-01-01 04:05:06.789 PST", +/// +/// [IANA timezones]: https://www.iana.org/time-zones +pub fn string_to_datetime( + timezone: &T, + s: &str, +) -> Result, ArrowError> { + let err = |ctx: &str| { + ArrowError::ParseError(format!("Error parsing timestamp from '{s}': {ctx}")) + }; + + let bytes = s.as_bytes(); + if bytes.len() < 10 { + return Err(err("timestamp must contain at least 10 characters")); + } + + let parser = TimestampParser::new(bytes); + let date = parser.date().ok_or_else(|| err("error parsing date"))?; + if bytes.len() == 10 { + let datetime = date.and_time(NaiveTime::from_hms_opt(0, 0, 0).unwrap()); + return timezone + .from_local_datetime(&datetime) 
+            .single()
+            .ok_or_else(|| err("error computing timezone offset"));
+    }
+
+    if !parser.test(10, b'T') && !parser.test(10, b't') && !parser.test(10, b' ') {
+        return Err(err("invalid timestamp separator"));
+    }
+
+    let (time, mut tz_offset) = parser.time().ok_or_else(|| err("error parsing time"))?;
+    let datetime = date.and_time(time);
+
+    if tz_offset == 32 {
+        // Decimal overrun
+        while tz_offset < bytes.len() && bytes[tz_offset].is_ascii_digit() {
+            tz_offset += 1;
+        }
+    }
+
+    if bytes.len() <= tz_offset {
+        return timezone
+            .from_local_datetime(&datetime)
+            .single()
+            .ok_or_else(|| err("error computing timezone offset"));
+    }
+
+    if bytes[tz_offset] == b'z' || bytes[tz_offset] == b'Z' {
+        return Ok(timezone.from_utc_datetime(&datetime));
+    }
+
+    // Parse remainder of string as timezone
+    let parsed_tz: Tz = s[tz_offset..].trim_start().parse()?;
+    let parsed = parsed_tz
+        .from_local_datetime(&datetime)
+        .single()
+        .ok_or_else(|| err("error computing timezone offset"))?;
+
+    Ok(parsed.with_timezone(timezone))
+}
+
+/// Accepts a string in RFC3339 / ISO8601 standard format and some
+/// variants and converts it to a nanosecond precision timestamp.
+///
+/// See [`string_to_datetime`] for the full set of supported formats
+///
+/// Implements the `to_timestamp` function to convert a string to a
+/// timestamp, following the model of Spark SQL's `to_timestamp`.
+///
+/// Internally, this function uses the `chrono` library for the
+/// datetime parsing
+///
+/// We hope to extend this function in the future with a second
+/// parameter specifying the format string.
+///
+/// ## Timestamp Precision
+///
+/// This function uses the maximum precision timestamps supported by
+/// Arrow (nanoseconds stored as a 64-bit integer). This
+/// means the range of dates that timestamps can represent is ~1677 AD
+/// to ~2262 AD
+///
+/// ## Timezone / Offset Handling
+///
+/// Numerical values of timestamps are stored relative to UTC.
+///
+/// This function interprets strings without an explicit time zone as timestamps
+/// relative to UTC, see [`string_to_datetime`] for alternative semantics
+///
+/// In particular:
+///
+/// ```
+/// # use arrow_cast::parse::string_to_timestamp_nanos;
+/// // Note all three of these timestamps are parsed as the same value
+/// let a = string_to_timestamp_nanos("1997-01-31 09:26:56.123Z").unwrap();
+/// let b = string_to_timestamp_nanos("1997-01-31T09:26:56.123").unwrap();
+/// let c = string_to_timestamp_nanos("1997-01-31T14:26:56.123+05:00").unwrap();
+///
+/// assert_eq!(a, b);
+/// assert_eq!(b, c);
+/// ```
+///
+#[inline]
+pub fn string_to_timestamp_nanos(s: &str) -> Result<i64, ArrowError> {
+    to_timestamp_nanos(string_to_datetime(&Utc, s)?.naive_utc())
+}
+
+/// Fallible conversion of [`NaiveDateTime`] to `i64` nanoseconds
+#[inline]
+fn to_timestamp_nanos(dt: NaiveDateTime) -> Result<i64, ArrowError> {
+    dt.timestamp_nanos_opt()
+        .ok_or_else(|| ArrowError::ParseError(ERR_NANOSECONDS_NOT_SUPPORTED.to_string()))
+}
+
+/// Accepts a string in ISO8601 standard format and some
+/// variants and converts it to nanoseconds since midnight.
+///
+/// Examples of accepted inputs:
+/// * `09:26:56.123 AM`
+/// * `23:59:59`
+/// * `6:00 pm`
+///
+/// Internally, this function uses the `chrono` library for the
+/// time parsing
+///
+/// ## Timezone / Offset Handling
+///
+/// This function does not support parsing strings with a timezone
+/// or offset specified, as it considers only time since midnight.
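+///
+/// For illustration, a minimal usage sketch (the call below is an editorial example,
+/// not part of the original documentation; it assumes this function is exercised via
+/// the public `arrow_cast::parse` module):
+///
+/// ```
+/// # use arrow_cast::parse::string_to_time_nanoseconds;
+/// // 09:26:56 is 9 hours, 26 minutes and 56 seconds after midnight
+/// let nanos = string_to_time_nanoseconds("09:26:56").unwrap();
+/// assert_eq!(nanos, (9 * 3600 + 26 * 60 + 56) * 1_000_000_000_i64);
+/// ```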
+pub fn string_to_time_nanoseconds(s: &str) -> Result<i64, ArrowError> {
+    let nt = string_to_time(s).ok_or_else(|| {
+        ArrowError::ParseError(format!("Failed to parse '{s}' as time"))
+    })?;
+    Ok(nt.num_seconds_from_midnight() as i64 * 1_000_000_000 + nt.nanosecond() as i64)
+}
+
+fn string_to_time(s: &str) -> Option<NaiveTime> {
+    let bytes = s.as_bytes();
+    if bytes.len() < 4 {
+        return None;
+    }
+
+    let (am, bytes) = match bytes.get(bytes.len() - 3..) {
+        Some(b" AM" | b" am" | b" Am" | b" aM") => {
+            (Some(true), &bytes[..bytes.len() - 3])
+        }
+        Some(b" PM" | b" pm" | b" pM" | b" Pm") => {
+            (Some(false), &bytes[..bytes.len() - 3])
+        }
+        _ => (None, bytes),
+    };
+
+    if bytes.len() < 4 {
+        return None;
+    }
+
+    let mut digits = [b'0'; 6];
+
+    // Extract hour
+    let bytes = match (bytes[1], bytes[2]) {
+        (b':', _) => {
+            digits[1] = bytes[0];
+            &bytes[2..]
+        }
+        (_, b':') => {
+            digits[0] = bytes[0];
+            digits[1] = bytes[1];
+            &bytes[3..]
+        }
+        _ => return None,
+    };
+
+    if bytes.len() < 2 {
+        return None; // Minutes required
+    }
+
+    // Extract minutes
+    digits[2] = bytes[0];
+    digits[3] = bytes[1];
+
+    let nanoseconds = match bytes.get(2) {
+        Some(b':') => {
+            if bytes.len() < 5 {
+                return None;
+            }
+
+            // Extract seconds
+            digits[4] = bytes[3];
+            digits[5] = bytes[4];
+
+            // Extract sub-seconds if any
+            match bytes.get(5) {
+                Some(b'.') => {
+                    let decimal = &bytes[6..];
+                    if decimal.iter().any(|x| !x.is_ascii_digit()) {
+                        return None;
+                    }
+                    match decimal.len() {
+                        0 => return None,
+                        1 => parse_nanos::<1, b'0'>(decimal),
+                        2 => parse_nanos::<2, b'0'>(decimal),
+                        3 => parse_nanos::<3, b'0'>(decimal),
+                        4 => parse_nanos::<4, b'0'>(decimal),
+                        5 => parse_nanos::<5, b'0'>(decimal),
+                        6 => parse_nanos::<6, b'0'>(decimal),
+                        7 => parse_nanos::<7, b'0'>(decimal),
+                        8 => parse_nanos::<8, b'0'>(decimal),
+                        _ => parse_nanos::<9, b'0'>(decimal),
+                    }
+                }
+                Some(_) => return None,
+                None => 0,
+            }
+        }
+        Some(_) => return None,
+        None => 0,
+    };
+
+    digits.iter_mut().for_each(|x| *x = x.wrapping_sub(b'0'));
+    if digits.iter().any(|x| *x > 9) {
+        return None;
+    }
+
+    let hour = match (digits[0] * 10 + digits[1], am) {
+        (12, Some(true)) => 0,               // 12:00 AM -> 00:00
+        (h @ 1..=11, Some(true)) => h,       // 1:00 AM -> 01:00
+        (12, Some(false)) => 12,             // 12:00 PM -> 12:00
+        (h @ 1..=11, Some(false)) => h + 12, // 1:00 PM -> 13:00
+        (_, Some(_)) => return None,
+        (h, None) => h,
+    };
+
+    // Handle leap second
+    let (second, nanoseconds) = match digits[4] * 10 + digits[5] {
+        60 => (59, nanoseconds + 1_000_000_000),
+        s => (s, nanoseconds),
+    };
+
+    NaiveTime::from_hms_nano_opt(
+        hour as _,
+        (digits[2] * 10 + digits[3]) as _,
+        second as _,
+        nanoseconds,
+    )
+}
+
+/// Specialized parsing implementations
+/// used by the csv and json readers
+pub trait Parser: ArrowPrimitiveType {
+    fn parse(string: &str) -> Option<Self::Native>;
+
+    fn parse_formatted(string: &str, _format: &str) -> Option<Self::Native> {
+        Self::parse(string)
+    }
+}
+
+impl Parser for Float16Type {
+    fn parse(string: &str) -> Option<f16> {
+        lexical_core::parse(string.as_bytes())
+            .ok()
+            .map(f16::from_f32)
+    }
+}
+
+impl Parser for Float32Type {
+    fn parse(string: &str) -> Option<f32> {
+        lexical_core::parse(string.as_bytes()).ok()
+    }
+}
+
+impl Parser for Float64Type {
+    fn parse(string: &str) -> Option<f64> {
+        lexical_core::parse(string.as_bytes()).ok()
+    }
+}
+
+macro_rules! parser_primitive {
+    ($t:ty) => {
+        impl Parser for $t {
+            fn parse(string: &str) -> Option<Self::Native> {
+                lexical_core::parse::<Self::Native>(string.as_bytes()).ok()
+            }
+        }
+    };
+}
+parser_primitive!(UInt64Type);
+parser_primitive!(UInt32Type);
+parser_primitive!(UInt16Type);
+parser_primitive!(UInt8Type);
+parser_primitive!(Int64Type);
+parser_primitive!(Int32Type);
+parser_primitive!(Int16Type);
+parser_primitive!(Int8Type);
+
+impl Parser for TimestampNanosecondType {
+    fn parse(string: &str) -> Option<i64> {
+        string_to_timestamp_nanos(string).ok()
+    }
+}
+
+impl Parser for TimestampMicrosecondType {
+    fn parse(string: &str) -> Option<i64> {
+        let nanos = string_to_timestamp_nanos(string).ok();
+        nanos.map(|x| x / 1000)
+    }
+}
+
+impl Parser for TimestampMillisecondType {
+    fn parse(string: &str) -> Option<i64> {
+        let nanos = string_to_timestamp_nanos(string).ok();
+        nanos.map(|x| x / 1_000_000)
+    }
+}
+
+impl Parser for TimestampSecondType {
+    fn parse(string: &str) -> Option<i64> {
+        let nanos = string_to_timestamp_nanos(string).ok();
+        nanos.map(|x| x / 1_000_000_000)
+    }
+}
+
+impl Parser for Time64NanosecondType {
+    // Will truncate any fractions of a nanosecond
+    fn parse(string: &str) -> Option<Self::Native> {
+        string_to_time_nanoseconds(string)
+            .ok()
+            .or_else(|| string.parse::<Self::Native>().ok())
+    }
+
+    fn parse_formatted(string: &str, format: &str) -> Option<Self::Native> {
+        let nt = NaiveTime::parse_from_str(string, format).ok()?;
+        Some(
+            nt.num_seconds_from_midnight() as i64 * 1_000_000_000
+                + nt.nanosecond() as i64,
+        )
+    }
+}
+
+impl Parser for Time64MicrosecondType {
+    // Will truncate any fractions of a microsecond
+    fn parse(string: &str) -> Option<Self::Native> {
+        string_to_time_nanoseconds(string)
+            .ok()
+            .map(|nanos| nanos / 1_000)
+            .or_else(|| string.parse::<Self::Native>().ok())
+    }
+
+    fn parse_formatted(string: &str, format: &str) -> Option<Self::Native> {
+        let nt = NaiveTime::parse_from_str(string, format).ok()?;
+        Some(
+            nt.num_seconds_from_midnight() as i64 * 1_000_000
+                + nt.nanosecond() as i64 / 1_000,
+        )
+    }
+}
+
+impl Parser for Time32MillisecondType {
+    // Will truncate any fractions of a millisecond
+    fn parse(string: &str) -> Option<Self::Native> {
+        string_to_time_nanoseconds(string)
+            .ok()
+            .map(|nanos| (nanos / 1_000_000) as i32)
+            .or_else(|| string.parse::<Self::Native>().ok())
+    }
+
+    fn parse_formatted(string: &str, format: &str) -> Option<Self::Native> {
+        let nt = NaiveTime::parse_from_str(string, format).ok()?;
+        Some(
+            nt.num_seconds_from_midnight() as i32 * 1_000
+                + nt.nanosecond() as i32 / 1_000_000,
+        )
+    }
+}
+
+impl Parser for Time32SecondType {
+    // Will truncate any fractions of a second
+    fn parse(string: &str) -> Option<Self::Native> {
+        string_to_time_nanoseconds(string)
+            .ok()
+            .map(|nanos| (nanos / 1_000_000_000) as i32)
+            .or_else(|| string.parse::<Self::Native>().ok())
+    }
+
+    fn parse_formatted(string: &str, format: &str) -> Option<Self::Native> {
+        let nt = NaiveTime::parse_from_str(string, format).ok()?;
+        Some(
+            nt.num_seconds_from_midnight() as i32
+                + nt.nanosecond() as i32 / 1_000_000_000,
+        )
+    }
+}
+
+/// Number of days between 0001-01-01 and 1970-01-01
+const EPOCH_DAYS_FROM_CE: i32 = 719_163;
+
+/// Error message if a nanosecond conversion is requested beyond the supported interval
+const ERR_NANOSECONDS_NOT_SUPPORTED: &str = "The dates that can be represented as nanoseconds have to be between 1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804";
+
+fn parse_date(string: &str) -> Option<NaiveDate> {
+    if string.len() > 10 {
+        return None;
+    }
+    let mut digits = [0; 10];
+    let mut mask = 0;
+
+    // Treating all bytes the same way helps LLVM vectorise this correctly
+    for (idx, (o, i)) in
digits.iter_mut().zip(string.bytes()).enumerate() { + *o = i.wrapping_sub(b'0'); + mask |= ((*o < 10) as u16) << idx + } + + const HYPHEN: u8 = b'-'.wrapping_sub(b'0'); + + if digits[4] != HYPHEN { + return None; + } + + let (month, day) = match mask { + 0b1101101111 => { + if digits[7] != HYPHEN { + return None; + } + (digits[5] * 10 + digits[6], digits[8] * 10 + digits[9]) + } + 0b101101111 => { + if digits[7] != HYPHEN { + return None; + } + (digits[5] * 10 + digits[6], digits[8]) + } + 0b110101111 => { + if digits[6] != HYPHEN { + return None; + } + (digits[5], digits[7] * 10 + digits[8]) + } + 0b10101111 => { + if digits[6] != HYPHEN { + return None; + } + (digits[5], digits[7]) + } + _ => return None, + }; + + let year = digits[0] as u16 * 1000 + + digits[1] as u16 * 100 + + digits[2] as u16 * 10 + + digits[3] as u16; + + NaiveDate::from_ymd_opt(year as _, month as _, day as _) +} + +impl Parser for Date32Type { + fn parse(string: &str) -> Option { + let date = parse_date(string)?; + Some(date.num_days_from_ce() - EPOCH_DAYS_FROM_CE) + } + + fn parse_formatted(string: &str, format: &str) -> Option { + let date = NaiveDate::parse_from_str(string, format).ok()?; + Some(date.num_days_from_ce() - EPOCH_DAYS_FROM_CE) + } +} + +impl Parser for Date64Type { + fn parse(string: &str) -> Option { + if string.len() <= 10 { + let date = parse_date(string)?; + Some(NaiveDateTime::new(date, NaiveTime::default()).timestamp_millis()) + } else { + let date_time = string_to_datetime(&Utc, string).ok()?; + Some(date_time.timestamp_millis()) + } + } + + fn parse_formatted(string: &str, format: &str) -> Option { + use chrono::format::Fixed; + use chrono::format::StrftimeItems; + let fmt = StrftimeItems::new(format); + let has_zone = fmt.into_iter().any(|item| match item { + chrono::format::Item::Fixed(fixed_item) => matches!( + fixed_item, + Fixed::RFC2822 + | Fixed::RFC3339 + | Fixed::TimezoneName + | Fixed::TimezoneOffsetColon + | Fixed::TimezoneOffsetColonZ + | Fixed::TimezoneOffset + | Fixed::TimezoneOffsetZ + ), + _ => false, + }); + if has_zone { + let date_time = chrono::DateTime::parse_from_str(string, format).ok()?; + Some(date_time.timestamp_millis()) + } else { + let date_time = NaiveDateTime::parse_from_str(string, format).ok()?; + Some(date_time.timestamp_millis()) + } + } +} + +/// Parse the string format decimal value to i128/i256 format and checking the precision and scale. +/// The result value can't be out of bounds. +pub fn parse_decimal( + s: &str, + precision: u8, + scale: i8, +) -> Result { + let mut result = T::Native::usize_as(0); + let mut fractionals = 0; + let mut digits = 0; + let base = T::Native::usize_as(10); + + let bs = s.as_bytes(); + let (bs, negative) = match bs.first() { + Some(b'-') => (&bs[1..], true), + Some(b'+') => (&bs[1..], false), + _ => (bs, false), + }; + + if bs.is_empty() { + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); + } + + let mut bs = bs.iter(); + // Overflow checks are not required if 10^(precision - 1) <= T::MAX holds. + // Thus, if we validate the precision correctly, we can skip overflow checks. + while let Some(b) = bs.next() { + match b { + b'0'..=b'9' => { + if digits == 0 && *b == b'0' { + // Ignore leading zeros. + continue; + } + digits += 1; + result = result.mul_wrapping(base); + result = result.add_wrapping(T::Native::usize_as((b - b'0') as usize)); + } + b'.' 
=> { + for b in bs.by_ref() { + if !b.is_ascii_digit() { + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); + } + if fractionals == scale { + // We have processed all the digits that we need. All that + // is left is to validate that the rest of the string contains + // valid digits. + continue; + } + fractionals += 1; + digits += 1; + result = result.mul_wrapping(base); + result = + result.add_wrapping(T::Native::usize_as((b - b'0') as usize)); + } + + // Fail on "." + if digits == 0 { + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); + } + } + _ => { + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); + } + } + } + + if fractionals < scale { + let exp = scale - fractionals; + if exp as u8 + digits > precision { + return Err(ArrowError::ParseError("parse decimal overflow".to_string())); + } + let mul = base.pow_wrapping(exp as _); + result = result.mul_wrapping(mul); + } else if digits > precision { + return Err(ArrowError::ParseError("parse decimal overflow".to_string())); + } + + Ok(if negative { + result.neg_wrapping() + } else { + result + }) +} + +pub fn parse_interval_year_month( + value: &str, +) -> Result<::Native, ArrowError> { + let config = IntervalParseConfig::new(IntervalUnit::Year); + let interval = Interval::parse(value, &config)?; + + let months = interval.to_year_months().map_err(|_| ArrowError::CastError(format!( + "Cannot cast {value} to IntervalYearMonth. Only year and month fields are allowed." + )))?; + + Ok(IntervalYearMonthType::make_value(0, months)) +} + +pub fn parse_interval_day_time( + value: &str, +) -> Result<::Native, ArrowError> { + let config = IntervalParseConfig::new(IntervalUnit::Day); + let interval = Interval::parse(value, &config)?; + + let (days, millis) = interval.to_day_time().map_err(|_| ArrowError::CastError(format!( + "Cannot cast {value} to IntervalDayTime because the nanos part isn't multiple of milliseconds" + )))?; + + Ok(IntervalDayTimeType::make_value(days, millis)) +} + +pub fn parse_interval_month_day_nano( + value: &str, +) -> Result<::Native, ArrowError> { + let config = IntervalParseConfig::new(IntervalUnit::Month); + let interval = Interval::parse(value, &config)?; + + let (months, days, nanos) = interval.to_month_day_nanos(); + + Ok(IntervalMonthDayNanoType::make_value(months, days, nanos)) +} + +const NANOS_PER_MILLIS: i64 = 1_000_000; +const NANOS_PER_SECOND: i64 = 1_000 * NANOS_PER_MILLIS; +const NANOS_PER_MINUTE: i64 = 60 * NANOS_PER_SECOND; +const NANOS_PER_HOUR: i64 = 60 * NANOS_PER_MINUTE; +#[cfg(test)] +const NANOS_PER_DAY: i64 = 24 * NANOS_PER_HOUR; + +#[rustfmt::skip] +#[derive(Clone, Copy)] +#[repr(u16)] +enum IntervalUnit { + Century = 0b_0000_0000_0001, + Decade = 0b_0000_0000_0010, + Year = 0b_0000_0000_0100, + Month = 0b_0000_0000_1000, + Week = 0b_0000_0001_0000, + Day = 0b_0000_0010_0000, + Hour = 0b_0000_0100_0000, + Minute = 0b_0000_1000_0000, + Second = 0b_0001_0000_0000, + Millisecond = 0b_0010_0000_0000, + Microsecond = 0b_0100_0000_0000, + Nanosecond = 0b_1000_0000_0000, +} + +impl FromStr for IntervalUnit { + type Err = ArrowError; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "century" | "centuries" => Ok(Self::Century), + "decade" | "decades" => Ok(Self::Decade), + "year" | "years" => Ok(Self::Year), + "month" | "months" => Ok(Self::Month), + "week" | "weeks" => Ok(Self::Week), + "day" | "days" => Ok(Self::Day), + "hour" | "hours" 
=> Ok(Self::Hour), + "minute" | "minutes" => Ok(Self::Minute), + "second" | "seconds" => Ok(Self::Second), + "millisecond" | "milliseconds" => Ok(Self::Millisecond), + "microsecond" | "microseconds" => Ok(Self::Microsecond), + "nanosecond" | "nanoseconds" => Ok(Self::Nanosecond), + _ => Err(ArrowError::NotYetImplemented(format!( + "Unknown interval type: {s}" + ))), + } + } +} + +pub type MonthDayNano = (i32, i32, i64); + +/// Chosen based on the number of decimal digits in 1 week in nanoseconds +const INTERVAL_PRECISION: u32 = 15; + +#[derive(Clone, Copy, Debug, PartialEq)] +struct IntervalAmount { + /// The integer component of the interval amount + integer: i64, + /// The fractional component multiplied by 10^INTERVAL_PRECISION + frac: i64, +} + +#[cfg(test)] +impl IntervalAmount { + fn new(integer: i64, frac: i64) -> Self { + Self { integer, frac } + } +} + +impl FromStr for IntervalAmount { + type Err = ArrowError; + + fn from_str(s: &str) -> Result { + match s.split_once('.') { + Some((integer, frac)) + if frac.len() <= INTERVAL_PRECISION as usize + && !frac.is_empty() + && !frac.starts_with('-') => + { + // integer will be "" for values like ".5" + // and "-" for values like "-.5" + let explicit_neg = integer.starts_with('-'); + let integer = if integer.is_empty() || integer == "-" { + Ok(0) + } else { + integer.parse::().map_err(|_| { + ArrowError::ParseError(format!( + "Failed to parse {s} as interval amount" + )) + }) + }?; + + let frac_unscaled = frac.parse::().map_err(|_| { + ArrowError::ParseError(format!( + "Failed to parse {s} as interval amount" + )) + })?; + + // scale fractional part by interval precision + let frac = + frac_unscaled * 10_i64.pow(INTERVAL_PRECISION - frac.len() as u32); + + // propagate the sign of the integer part to the fractional part + let frac = if integer < 0 || explicit_neg { + -frac + } else { + frac + }; + + let result = Self { integer, frac }; + + Ok(result) + } + Some((_, frac)) if frac.starts_with('-') => Err(ArrowError::ParseError( + format!("Failed to parse {s} as interval amount"), + )), + Some((_, frac)) if frac.len() > INTERVAL_PRECISION as usize => { + Err(ArrowError::ParseError(format!( + "{s} exceeds the precision available for interval amount" + ))) + } + Some(_) | None => { + let integer = s.parse::().map_err(|_| { + ArrowError::ParseError(format!( + "Failed to parse {s} as interval amount" + )) + })?; + + let result = Self { integer, frac: 0 }; + Ok(result) + } + } + } +} + +#[derive(Debug, Default, PartialEq)] +struct Interval { + months: i32, + days: i32, + nanos: i64, +} + +impl Interval { + fn new(months: i32, days: i32, nanos: i64) -> Self { + Self { + months, + days, + nanos, + } + } + + fn to_year_months(&self) -> Result { + match (self.months, self.days, self.nanos) { + (months, days, nanos) if days == 0 && nanos == 0 => Ok(months), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Unable to represent interval with days and nanos as year-months: {:?}", + self + ))), + } + } + + fn to_day_time(&self) -> Result<(i32, i32), ArrowError> { + let days = self.months.mul_checked(30)?.add_checked(self.days)?; + + match self.nanos { + nanos if nanos % NANOS_PER_MILLIS == 0 => { + let millis = (self.nanos / 1_000_000).try_into().map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "Unable to represent {} nanos as milliseconds in a signed 32-bit integer", + self.nanos + )) + })?; + + Ok((days, millis)) + } + nanos => Err(ArrowError::InvalidArgumentError(format!( + "Unable to represent {nanos} as milliseconds" + ))), + } + } 
+ + fn to_month_day_nanos(&self) -> (i32, i32, i64) { + (self.months, self.days, self.nanos) + } + + /// Parse string value in traditional Postgres format such as + /// `1 year 2 months 3 days 4 hours 5 minutes 6 seconds` + fn parse(value: &str, config: &IntervalParseConfig) -> Result { + let components = parse_interval_components(value, config)?; + + components + .into_iter() + .try_fold(Self::default(), |result, (amount, unit)| { + result.add(amount, unit) + }) + } + + /// Interval addition following Postgres behavior. Fractional units will be spilled into smaller units. + /// When the interval unit is larger than months, the result is rounded to total months and not spilled to days/nanos. + /// Fractional parts of weeks and days are represented using days and nanoseconds. + /// e.g. INTERVAL '0.5 MONTH' = 15 days, INTERVAL '1.5 MONTH' = 1 month 15 days + /// e.g. INTERVAL '0.5 DAY' = 12 hours, INTERVAL '1.5 DAY' = 1 day 12 hours + /// [Postgres reference](https://www.postgresql.org/docs/15/datatype-datetime.html#DATATYPE-INTERVAL-INPUT:~:text=Field%20values%20can,fractional%20on%20output.) + fn add( + &self, + amount: IntervalAmount, + unit: IntervalUnit, + ) -> Result { + let result = match unit { + IntervalUnit::Century => { + let months_int = amount.integer.mul_checked(100)?.mul_checked(12)?; + let month_frac = amount.frac * 12 / 10_i64.pow(INTERVAL_PRECISION - 2); + let months = + months_int + .add_checked(month_frac)? + .try_into() + .map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} centuries as months in a signed 32-bit integer", + &amount.integer + )) + })?; + + Self::new(self.months.add_checked(months)?, self.days, self.nanos) + } + IntervalUnit::Decade => { + let months_int = amount.integer.mul_checked(10)?.mul_checked(12)?; + + let month_frac = amount.frac * 12 / 10_i64.pow(INTERVAL_PRECISION - 1); + let months = + months_int + .add_checked(month_frac)? + .try_into() + .map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} decades as months in a signed 32-bit integer", + &amount.integer + )) + })?; + + Self::new(self.months.add_checked(months)?, self.days, self.nanos) + } + IntervalUnit::Year => { + let months_int = amount.integer.mul_checked(12)?; + let month_frac = amount.frac * 12 / 10_i64.pow(INTERVAL_PRECISION); + let months = + months_int + .add_checked(month_frac)? 
+ .try_into() + .map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} years as months in a signed 32-bit integer", + &amount.integer + )) + })?; + + Self::new(self.months.add_checked(months)?, self.days, self.nanos) + } + IntervalUnit::Month => { + let months = amount.integer.try_into().map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} months in a signed 32-bit integer", + &amount.integer + )) + })?; + + let days = amount.frac * 3 / 10_i64.pow(INTERVAL_PRECISION - 1); + let days = days.try_into().map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} months as days in a signed 32-bit integer", + amount.frac / 10_i64.pow(INTERVAL_PRECISION) + )) + })?; + + Self::new( + self.months.add_checked(months)?, + self.days.add_checked(days)?, + self.nanos, + ) + } + IntervalUnit::Week => { + let days = amount.integer.mul_checked(7)?.try_into().map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} weeks as days in a signed 32-bit integer", + &amount.integer + )) + })?; + + let nanos = + amount.frac * 7 * 24 * 6 * 6 / 10_i64.pow(INTERVAL_PRECISION - 11); + + Self::new( + self.months, + self.days.add_checked(days)?, + self.nanos.add_checked(nanos)?, + ) + } + IntervalUnit::Day => { + let days = amount.integer.try_into().map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "Unable to represent {} days in a signed 32-bit integer", + amount.integer + )) + })?; + + let nanos = + amount.frac * 24 * 6 * 6 / 10_i64.pow(INTERVAL_PRECISION - 11); + + Self::new( + self.months, + self.days.add_checked(days)?, + self.nanos.add_checked(nanos)?, + ) + } + IntervalUnit::Hour => { + let nanos_int = amount.integer.mul_checked(NANOS_PER_HOUR)?; + let nanos_frac = + amount.frac * 6 * 6 / 10_i64.pow(INTERVAL_PRECISION - 11); + let nanos = nanos_int.add_checked(nanos_frac)?; + + Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) + } + IntervalUnit::Minute => { + let nanos_int = amount.integer.mul_checked(NANOS_PER_MINUTE)?; + let nanos_frac = amount.frac * 6 / 10_i64.pow(INTERVAL_PRECISION - 10); + + let nanos = nanos_int.add_checked(nanos_frac)?; + + Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) + } + IntervalUnit::Second => { + let nanos_int = amount.integer.mul_checked(NANOS_PER_SECOND)?; + let nanos_frac = amount.frac / 10_i64.pow(INTERVAL_PRECISION - 9); + let nanos = nanos_int.add_checked(nanos_frac)?; + + Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) + } + IntervalUnit::Millisecond => { + let nanos_int = amount.integer.mul_checked(NANOS_PER_MILLIS)?; + let nanos_frac = amount.frac / 10_i64.pow(INTERVAL_PRECISION - 6); + let nanos = nanos_int.add_checked(nanos_frac)?; + + Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) + } + IntervalUnit::Microsecond => { + let nanos_int = amount.integer.mul_checked(1_000)?; + let nanos_frac = amount.frac / 10_i64.pow(INTERVAL_PRECISION - 3); + let nanos = nanos_int.add_checked(nanos_frac)?; + + Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) + } + IntervalUnit::Nanosecond => { + let nanos_int = amount.integer; + let nanos_frac = amount.frac / 10_i64.pow(INTERVAL_PRECISION); + let nanos = nanos_int.add_checked(nanos_frac)?; + + Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) + } + }; + + Ok(result) + } +} + +struct IntervalParseConfig { + /// The default unit to use if none is specified + /// e.g. 
`INTERVAL 1` represents `INTERVAL 1 SECOND` when default_unit = IntervalType::Second + default_unit: IntervalUnit, +} + +impl IntervalParseConfig { + fn new(default_unit: IntervalUnit) -> Self { + Self { default_unit } + } +} + +/// parse the string into a vector of interval components i.e. (amount, unit) tuples +fn parse_interval_components( + value: &str, + config: &IntervalParseConfig, +) -> Result, ArrowError> { + let parts = value.split_whitespace(); + + let raw_amounts = parts.clone().step_by(2); + let raw_units = parts.skip(1).step_by(2); + + // parse amounts + let (amounts, invalid_amounts) = raw_amounts + .map(IntervalAmount::from_str) + .partition::, _>(Result::is_ok); + + // invalid amounts? + if !invalid_amounts.is_empty() { + return Err(ArrowError::NotYetImplemented(format!( + "Unsupported Interval Expression with value {value:?}" + ))); + } + + // parse units + let (units, invalid_units): (Vec<_>, Vec<_>) = raw_units + .clone() + .map(IntervalUnit::from_str) + .partition(Result::is_ok); + + // invalid units? + if !invalid_units.is_empty() { + return Err(ArrowError::ParseError(format!( + "Invalid input syntax for type interval: {value:?}" + ))); + } + + // collect parsed results + let amounts = amounts.into_iter().map(Result::unwrap).collect::>(); + let units = units.into_iter().map(Result::unwrap).collect::>(); + + // if only an amount is specified, use the default unit + if amounts.len() == 1 && units.is_empty() { + return Ok(vec![(amounts[0], config.default_unit)]); + }; + + // duplicate units? + let mut observed_interval_types = 0; + for (unit, raw_unit) in units.iter().zip(raw_units) { + if observed_interval_types & (*unit as u16) != 0 { + return Err(ArrowError::ParseError(format!( + "Invalid input syntax for type interval: {value:?}. 
Repeated type '{raw_unit}'", + ))); + } + + observed_interval_types |= *unit as u16; + } + + let result = amounts.iter().copied().zip(units.iter().copied()); + + Ok(result.collect::>()) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::temporal_conversions::date32_to_datetime; + use arrow_array::timezone::Tz; + use arrow_buffer::i256; + + #[test] + fn test_parse_nanos() { + assert_eq!(parse_nanos::<3, 0>(&[1, 2, 3]), 123_000_000); + assert_eq!(parse_nanos::<5, 0>(&[1, 2, 3, 4, 5]), 123_450_000); + assert_eq!(parse_nanos::<6, b'0'>(b"123456"), 123_456_000); + } + + #[test] + fn string_to_timestamp_timezone() { + // Explicit timezone + assert_eq!( + 1599572549190855000, + parse_timestamp("2020-09-08T13:42:29.190855+00:00").unwrap() + ); + assert_eq!( + 1599572549190855000, + parse_timestamp("2020-09-08T13:42:29.190855Z").unwrap() + ); + assert_eq!( + 1599572549000000000, + parse_timestamp("2020-09-08T13:42:29Z").unwrap() + ); // no fractional part + assert_eq!( + 1599590549190855000, + parse_timestamp("2020-09-08T13:42:29.190855-05:00").unwrap() + ); + } + + #[test] + fn string_to_timestamp_timezone_space() { + // Ensure space rather than T between time and date is accepted + assert_eq!( + 1599572549190855000, + parse_timestamp("2020-09-08 13:42:29.190855+00:00").unwrap() + ); + assert_eq!( + 1599572549190855000, + parse_timestamp("2020-09-08 13:42:29.190855Z").unwrap() + ); + assert_eq!( + 1599572549000000000, + parse_timestamp("2020-09-08 13:42:29Z").unwrap() + ); // no fractional part + assert_eq!( + 1599590549190855000, + parse_timestamp("2020-09-08 13:42:29.190855-05:00").unwrap() + ); + } + + #[test] + #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function: mktime + fn string_to_timestamp_no_timezone() { + // This test is designed to succeed in regardless of the local + // timezone the test machine is running. 
Thus it is still + // somewhat susceptible to bugs in the use of chrono + let naive_datetime = NaiveDateTime::new( + NaiveDate::from_ymd_opt(2020, 9, 8).unwrap(), + NaiveTime::from_hms_nano_opt(13, 42, 29, 190855000).unwrap(), + ); + + // Ensure both T and ' ' variants work + assert_eq!( + naive_datetime.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08T13:42:29.190855").unwrap() + ); + + assert_eq!( + naive_datetime.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08 13:42:29.190855").unwrap() + ); + + // Also ensure that parsing timestamps with no fractional + // second part works as well + let naive_datetime_whole_secs = NaiveDateTime::new( + NaiveDate::from_ymd_opt(2020, 9, 8).unwrap(), + NaiveTime::from_hms_opt(13, 42, 29).unwrap(), + ); + + // Ensure both T and ' ' variants work + assert_eq!( + naive_datetime_whole_secs.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08T13:42:29").unwrap() + ); + + assert_eq!( + naive_datetime_whole_secs.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08 13:42:29").unwrap() + ); + + // ensure without time work + // no time, should be the nano second at + // 2020-09-08 0:0:0 + let naive_datetime_no_time = NaiveDateTime::new( + NaiveDate::from_ymd_opt(2020, 9, 8).unwrap(), + NaiveTime::from_hms_opt(0, 0, 0).unwrap(), + ); + + assert_eq!( + naive_datetime_no_time.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08").unwrap() + ) + } + + #[test] + fn string_to_timestamp_chrono() { + let cases = [ + "2020-09-08T13:42:29Z", + "1969-01-01T00:00:00.1Z", + "2020-09-08T12:00:12.12345678+00:00", + "2020-09-08T12:00:12+00:00", + "2020-09-08T12:00:12.1+00:00", + "2020-09-08T12:00:12.12+00:00", + "2020-09-08T12:00:12.123+00:00", + "2020-09-08T12:00:12.1234+00:00", + "2020-09-08T12:00:12.12345+00:00", + "2020-09-08T12:00:12.123456+00:00", + "2020-09-08T12:00:12.1234567+00:00", + "2020-09-08T12:00:12.12345678+00:00", + "2020-09-08T12:00:12.123456789+00:00", + "2020-09-08T12:00:12.12345678912z", + "2020-09-08T12:00:12.123456789123Z", + "2020-09-08T12:00:12.123456789123+02:00", + "2020-09-08T12:00:12.12345678912345Z", + "2020-09-08T12:00:12.1234567891234567+02:00", + "2020-09-08T12:00:60Z", + "2020-09-08T12:00:60.123Z", + "2020-09-08T12:00:60.123456+02:00", + "2020-09-08T12:00:60.1234567891234567+02:00", + "2020-09-08T12:00:60.999999999+02:00", + "2020-09-08t12:00:12.12345678+00:00", + "2020-09-08t12:00:12+00:00", + "2020-09-08t12:00:12Z", + ]; + + for case in cases { + let chrono = DateTime::parse_from_rfc3339(case).unwrap(); + let chrono_utc = chrono.with_timezone(&Utc); + + let custom = string_to_datetime(&Utc, case).unwrap(); + assert_eq!(chrono_utc, custom) + } + } + + #[test] + fn string_to_timestamp_naive() { + let cases = [ + "2018-11-13T17:11:10.011375885995", + "2030-12-04T17:11:10.123", + "2030-12-04T17:11:10.1234", + "2030-12-04T17:11:10.123456", + ]; + for case in cases { + let chrono = + NaiveDateTime::parse_from_str(case, "%Y-%m-%dT%H:%M:%S%.f").unwrap(); + let custom = string_to_datetime(&Utc, case).unwrap(); + assert_eq!(chrono, custom.naive_utc()) + } + } + + #[test] + fn string_to_timestamp_invalid() { + // Test parsing invalid formats + let cases = [ + ("", "timestamp must contain at least 10 characters"), + ("SS", "timestamp must contain at least 10 characters"), + ("Wed, 18 Feb 2015 23:16:09 GMT", "error parsing date"), + ("1997-01-31H09:26:56.123Z", "invalid timestamp separator"), + ("1997-01-31 09:26:56.123Z", "error parsing time"), + ("1997:01:31T09:26:56.123Z", "error parsing date"), + 
("1997:1:31T09:26:56.123Z", "error parsing date"), + ("1997-01-32T09:26:56.123Z", "error parsing date"), + ("1997-13-32T09:26:56.123Z", "error parsing date"), + ("1997-02-29T09:26:56.123Z", "error parsing date"), + ("2015-02-30T17:35:20-08:00", "error parsing date"), + ("1997-01-10T9:26:56.123Z", "error parsing time"), + ("2015-01-20T25:35:20-08:00", "error parsing time"), + ("1997-01-10T09:61:56.123Z", "error parsing time"), + ("1997-01-10T09:61:90.123Z", "error parsing time"), + ("1997-01-10T12:00:6.123Z", "error parsing time"), + ("1997-01-31T092656.123Z", "error parsing time"), + ("1997-01-10T12:00:06.", "error parsing time"), + ("1997-01-10T12:00:06. ", "error parsing time"), + ]; + + for (s, ctx) in cases { + let expected = + format!("Parser error: Error parsing timestamp from '{s}': {ctx}"); + let actual = string_to_datetime(&Utc, s).unwrap_err().to_string(); + assert_eq!(actual, expected) + } + } + + // Parse a timestamp to timestamp int with a useful human readable error message + fn parse_timestamp(s: &str) -> Result { + let result = string_to_timestamp_nanos(s); + if let Err(e) = &result { + eprintln!("Error parsing timestamp '{s}': {e:?}"); + } + result + } + + #[test] + fn string_without_timezone_to_timestamp() { + // string without timezone should always output the same regardless the local or session timezone + + let naive_datetime = NaiveDateTime::new( + NaiveDate::from_ymd_opt(2020, 9, 8).unwrap(), + NaiveTime::from_hms_nano_opt(13, 42, 29, 190855000).unwrap(), + ); + + // Ensure both T and ' ' variants work + assert_eq!( + naive_datetime.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08T13:42:29.190855").unwrap() + ); + + assert_eq!( + naive_datetime.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08 13:42:29.190855").unwrap() + ); + + let naive_datetime = NaiveDateTime::new( + NaiveDate::from_ymd_opt(2020, 9, 8).unwrap(), + NaiveTime::from_hms_nano_opt(13, 42, 29, 0).unwrap(), + ); + + // Ensure both T and ' ' variants work + assert_eq!( + naive_datetime.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08T13:42:29").unwrap() + ); + + assert_eq!( + naive_datetime.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08 13:42:29").unwrap() + ); + + let tz: Tz = "+02:00".parse().unwrap(); + let date = string_to_datetime(&tz, "2020-09-08 13:42:29").unwrap(); + let utc = date.naive_utc().to_string(); + assert_eq!(utc, "2020-09-08 11:42:29"); + let local = date.naive_local().to_string(); + assert_eq!(local, "2020-09-08 13:42:29"); + + let date = string_to_datetime(&tz, "2020-09-08 13:42:29Z").unwrap(); + let utc = date.naive_utc().to_string(); + assert_eq!(utc, "2020-09-08 13:42:29"); + let local = date.naive_local().to_string(); + assert_eq!(local, "2020-09-08 15:42:29"); + + let dt = + NaiveDateTime::parse_from_str("2020-09-08T13:42:29Z", "%Y-%m-%dT%H:%M:%SZ") + .unwrap(); + let local: Tz = "+08:00".parse().unwrap(); + + // Parsed as offset from UTC + let date = string_to_datetime(&local, "2020-09-08T13:42:29Z").unwrap(); + assert_eq!(dt, date.naive_utc()); + assert_ne!(dt, date.naive_local()); + + // Parsed as offset from local + let date = string_to_datetime(&local, "2020-09-08 13:42:29").unwrap(); + assert_eq!(dt, date.naive_local()); + assert_ne!(dt, date.naive_utc()); + } + + #[test] + fn parse_date32() { + let cases = [ + "2020-09-08", + "2020-9-8", + "2020-09-8", + "2020-9-08", + "2020-12-1", + "1690-2-5", + ]; + for case in cases { + let v = date32_to_datetime(Date32Type::parse(case).unwrap()).unwrap(); + let expected: NaiveDate 
= case.parse().unwrap(); + assert_eq!(v.date(), expected); + } + + let err_cases = [ + "", + "80-01-01", + "342", + "Foo", + "2020-09-08-03", + "2020--04-03", + "2020--", + ]; + for case in err_cases { + assert_eq!(Date32Type::parse(case), None); + } + } + + #[test] + fn parse_time64_nanos() { + assert_eq!( + Time64NanosecondType::parse("02:10:01.1234567899999999"), + Some(7_801_123_456_789) + ); + assert_eq!( + Time64NanosecondType::parse("02:10:01.1234567"), + Some(7_801_123_456_700) + ); + assert_eq!( + Time64NanosecondType::parse("2:10:01.1234567"), + Some(7_801_123_456_700) + ); + assert_eq!( + Time64NanosecondType::parse("12:10:01.123456789 AM"), + Some(601_123_456_789) + ); + assert_eq!( + Time64NanosecondType::parse("12:10:01.123456789 am"), + Some(601_123_456_789) + ); + assert_eq!( + Time64NanosecondType::parse("2:10:01.12345678 PM"), + Some(51_001_123_456_780) + ); + assert_eq!( + Time64NanosecondType::parse("2:10:01.12345678 pm"), + Some(51_001_123_456_780) + ); + assert_eq!( + Time64NanosecondType::parse("02:10:01"), + Some(7_801_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("2:10:01"), + Some(7_801_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("12:10:01 AM"), + Some(601_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("12:10:01 am"), + Some(601_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("2:10:01 PM"), + Some(51_001_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("2:10:01 pm"), + Some(51_001_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("02:10"), + Some(7_800_000_000_000) + ); + assert_eq!(Time64NanosecondType::parse("2:10"), Some(7_800_000_000_000)); + assert_eq!( + Time64NanosecondType::parse("12:10 AM"), + Some(600_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("12:10 am"), + Some(600_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("2:10 PM"), + Some(51_000_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("2:10 pm"), + Some(51_000_000_000_000) + ); + + // parse directly as nanoseconds + assert_eq!(Time64NanosecondType::parse("1"), Some(1)); + + // leap second + assert_eq!( + Time64NanosecondType::parse("23:59:60"), + Some(86_400_000_000_000) + ); + + // custom format + assert_eq!( + Time64NanosecondType::parse_formatted( + "02 - 10 - 01 - .1234567", + "%H - %M - %S - %.f" + ), + Some(7_801_123_456_700) + ); + } + + #[test] + fn parse_time64_micros() { + // expected formats + assert_eq!( + Time64MicrosecondType::parse("02:10:01.1234"), + Some(7_801_123_400) + ); + assert_eq!( + Time64MicrosecondType::parse("2:10:01.1234"), + Some(7_801_123_400) + ); + assert_eq!( + Time64MicrosecondType::parse("12:10:01.123456 AM"), + Some(601_123_456) + ); + assert_eq!( + Time64MicrosecondType::parse("12:10:01.123456 am"), + Some(601_123_456) + ); + assert_eq!( + Time64MicrosecondType::parse("2:10:01.12345 PM"), + Some(51_001_123_450) + ); + assert_eq!( + Time64MicrosecondType::parse("2:10:01.12345 pm"), + Some(51_001_123_450) + ); + assert_eq!( + Time64MicrosecondType::parse("02:10:01"), + Some(7_801_000_000) + ); + assert_eq!(Time64MicrosecondType::parse("2:10:01"), Some(7_801_000_000)); + assert_eq!( + Time64MicrosecondType::parse("12:10:01 AM"), + Some(601_000_000) + ); + assert_eq!( + Time64MicrosecondType::parse("12:10:01 am"), + Some(601_000_000) + ); + assert_eq!( + Time64MicrosecondType::parse("2:10:01 PM"), + Some(51_001_000_000) + ); + assert_eq!( + Time64MicrosecondType::parse("2:10:01 pm"), + Some(51_001_000_000) + ); + 
assert_eq!(Time64MicrosecondType::parse("02:10"), Some(7_800_000_000)); + assert_eq!(Time64MicrosecondType::parse("2:10"), Some(7_800_000_000)); + assert_eq!(Time64MicrosecondType::parse("12:10 AM"), Some(600_000_000)); + assert_eq!(Time64MicrosecondType::parse("12:10 am"), Some(600_000_000)); + assert_eq!( + Time64MicrosecondType::parse("2:10 PM"), + Some(51_000_000_000) + ); + assert_eq!( + Time64MicrosecondType::parse("2:10 pm"), + Some(51_000_000_000) + ); + + // parse directly as microseconds + assert_eq!(Time64MicrosecondType::parse("1"), Some(1)); + + // leap second + assert_eq!( + Time64MicrosecondType::parse("23:59:60"), + Some(86_400_000_000) + ); + + // custom format + assert_eq!( + Time64MicrosecondType::parse_formatted( + "02 - 10 - 01 - .1234", + "%H - %M - %S - %.f" + ), + Some(7_801_123_400) + ); + } + + #[test] + fn parse_time32_millis() { + // expected formats + assert_eq!(Time32MillisecondType::parse("02:10:01.1"), Some(7_801_100)); + assert_eq!(Time32MillisecondType::parse("2:10:01.1"), Some(7_801_100)); + assert_eq!( + Time32MillisecondType::parse("12:10:01.123 AM"), + Some(601_123) + ); + assert_eq!( + Time32MillisecondType::parse("12:10:01.123 am"), + Some(601_123) + ); + assert_eq!( + Time32MillisecondType::parse("2:10:01.12 PM"), + Some(51_001_120) + ); + assert_eq!( + Time32MillisecondType::parse("2:10:01.12 pm"), + Some(51_001_120) + ); + assert_eq!(Time32MillisecondType::parse("02:10:01"), Some(7_801_000)); + assert_eq!(Time32MillisecondType::parse("2:10:01"), Some(7_801_000)); + assert_eq!(Time32MillisecondType::parse("12:10:01 AM"), Some(601_000)); + assert_eq!(Time32MillisecondType::parse("12:10:01 am"), Some(601_000)); + assert_eq!(Time32MillisecondType::parse("2:10:01 PM"), Some(51_001_000)); + assert_eq!(Time32MillisecondType::parse("2:10:01 pm"), Some(51_001_000)); + assert_eq!(Time32MillisecondType::parse("02:10"), Some(7_800_000)); + assert_eq!(Time32MillisecondType::parse("2:10"), Some(7_800_000)); + assert_eq!(Time32MillisecondType::parse("12:10 AM"), Some(600_000)); + assert_eq!(Time32MillisecondType::parse("12:10 am"), Some(600_000)); + assert_eq!(Time32MillisecondType::parse("2:10 PM"), Some(51_000_000)); + assert_eq!(Time32MillisecondType::parse("2:10 pm"), Some(51_000_000)); + + // parse directly as milliseconds + assert_eq!(Time32MillisecondType::parse("1"), Some(1)); + + // leap second + assert_eq!(Time32MillisecondType::parse("23:59:60"), Some(86_400_000)); + + // custom format + assert_eq!( + Time32MillisecondType::parse_formatted( + "02 - 10 - 01 - .1", + "%H - %M - %S - %.f" + ), + Some(7_801_100) + ); + } + + #[test] + fn parse_time32_secs() { + // expected formats + assert_eq!(Time32SecondType::parse("02:10:01.1"), Some(7_801)); + assert_eq!(Time32SecondType::parse("02:10:01"), Some(7_801)); + assert_eq!(Time32SecondType::parse("2:10:01"), Some(7_801)); + assert_eq!(Time32SecondType::parse("12:10:01 AM"), Some(601)); + assert_eq!(Time32SecondType::parse("12:10:01 am"), Some(601)); + assert_eq!(Time32SecondType::parse("2:10:01 PM"), Some(51_001)); + assert_eq!(Time32SecondType::parse("2:10:01 pm"), Some(51_001)); + assert_eq!(Time32SecondType::parse("02:10"), Some(7_800)); + assert_eq!(Time32SecondType::parse("2:10"), Some(7_800)); + assert_eq!(Time32SecondType::parse("12:10 AM"), Some(600)); + assert_eq!(Time32SecondType::parse("12:10 am"), Some(600)); + assert_eq!(Time32SecondType::parse("2:10 PM"), Some(51_000)); + assert_eq!(Time32SecondType::parse("2:10 pm"), Some(51_000)); + + // parse directly as seconds + 
assert_eq!(Time32SecondType::parse("1"), Some(1)); + + // leap second + assert_eq!(Time32SecondType::parse("23:59:60"), Some(86400)); + + // custom format + assert_eq!( + Time32SecondType::parse_formatted("02 - 10 - 01", "%H - %M - %S"), + Some(7_801) + ); + } + + #[test] + fn test_string_to_time_invalid() { + let cases = [ + "25:00", + "9:00:", + "009:00", + "09:0:00", + "25:00:00", + "13:00 AM", + "13:00 PM", + "12:00. AM", + "09:0:00", + "09:01:0", + "09:01:1", + "9:1:0", + "09:01:0", + "1:00.123", + "1:00:00.123f", + " 9:00:00", + ":09:00", + "T9:00:00", + "AM", + ]; + for case in cases { + assert!(string_to_time(case).is_none(), "{case}"); + } + } + + #[test] + fn test_string_to_time_chrono() { + let cases = [ + ("1:00", "%H:%M"), + ("12:00", "%H:%M"), + ("13:00", "%H:%M"), + ("24:00", "%H:%M"), + ("1:00:00", "%H:%M:%S"), + ("12:00:30", "%H:%M:%S"), + ("13:00:59", "%H:%M:%S"), + ("24:00:60", "%H:%M:%S"), + ("09:00:00", "%H:%M:%S%.f"), + ("0:00:30.123456", "%H:%M:%S%.f"), + ("0:00 AM", "%I:%M %P"), + ("1:00 AM", "%I:%M %P"), + ("12:00 AM", "%I:%M %P"), + ("13:00 AM", "%I:%M %P"), + ("0:00 PM", "%I:%M %P"), + ("1:00 PM", "%I:%M %P"), + ("12:00 PM", "%I:%M %P"), + ("13:00 PM", "%I:%M %P"), + ("1:00 pM", "%I:%M %P"), + ("1:00 Pm", "%I:%M %P"), + ("1:00 aM", "%I:%M %P"), + ("1:00 Am", "%I:%M %P"), + ("1:00:30.123456 PM", "%I:%M:%S%.f %P"), + ("1:00:30.123456789 PM", "%I:%M:%S%.f %P"), + ("1:00:30.123456789123 PM", "%I:%M:%S%.f %P"), + ("1:00:30.1234 PM", "%I:%M:%S%.f %P"), + ("1:00:30.123456 PM", "%I:%M:%S%.f %P"), + ("1:00:30.123456789123456789 PM", "%I:%M:%S%.f %P"), + ("1:00:30.12F456 PM", "%I:%M:%S%.f %P"), + ]; + for (s, format) in cases { + let chrono = NaiveTime::parse_from_str(s, format).ok(); + let custom = string_to_time(s); + assert_eq!(chrono, custom, "{s}"); + } + } + + #[test] + fn test_parse_interval() { + let config = IntervalParseConfig::new(IntervalUnit::Month); + + assert_eq!( + Interval::new(1i32, 0i32, 0i64), + Interval::parse("1 month", &config).unwrap(), + ); + + assert_eq!( + Interval::new(2i32, 0i32, 0i64), + Interval::parse("2 month", &config).unwrap(), + ); + + assert_eq!( + Interval::new(-1i32, -18i32, -(NANOS_PER_DAY / 5)), + Interval::parse("-1.5 months -3.2 days", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, 15i32, 0), + Interval::parse("0.5 months", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, 15i32, 0), + Interval::parse(".5 months", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, -15i32, 0), + Interval::parse("-0.5 months", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, -15i32, 0), + Interval::parse("-.5 months", &config).unwrap(), + ); + + assert_eq!( + Interval::new(2i32, 10i32, 9 * NANOS_PER_HOUR), + Interval::parse("2.1 months 7.25 days 3 hours", &config).unwrap(), + ); + + assert_eq!( + Interval::parse("1 centurys 1 month", &config) + .unwrap_err() + .to_string(), + r#"Parser error: Invalid input syntax for type interval: "1 centurys 1 month""# + ); + + assert_eq!( + Interval::new(37i32, 0i32, 0i64), + Interval::parse("3 year 1 month", &config).unwrap(), + ); + + assert_eq!( + Interval::new(35i32, 0i32, 0i64), + Interval::parse("3 year -1 month", &config).unwrap(), + ); + + assert_eq!( + Interval::new(-37i32, 0i32, 0i64), + Interval::parse("-3 year -1 month", &config).unwrap(), + ); + + assert_eq!( + Interval::new(-35i32, 0i32, 0i64), + Interval::parse("-3 year 1 month", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, 5i32, 0i64), + Interval::parse("5 days", 
&config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, 7i32, 3 * NANOS_PER_HOUR), + Interval::parse("7 days 3 hours", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, 7i32, 5 * NANOS_PER_MINUTE), + Interval::parse("7 days 5 minutes", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, 7i32, -5 * NANOS_PER_MINUTE), + Interval::parse("7 days -5 minutes", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, -7i32, 5 * NANOS_PER_HOUR), + Interval::parse("-7 days 5 hours", &config).unwrap(), + ); + + assert_eq!( + Interval::new( + 0i32, + -7i32, + -5 * NANOS_PER_HOUR - 5 * NANOS_PER_MINUTE - 5 * NANOS_PER_SECOND + ), + Interval::parse("-7 days -5 hours -5 minutes -5 seconds", &config).unwrap(), + ); + + assert_eq!( + Interval::new(12i32, 0i32, 25 * NANOS_PER_MILLIS), + Interval::parse("1 year 25 millisecond", &config).unwrap(), + ); + + assert_eq!( + Interval::new( + 12i32, + 1i32, + (NANOS_PER_SECOND as f64 * 0.000000001_f64) as i64 + ), + Interval::parse("1 year 1 day 0.000000001 seconds", &config).unwrap(), + ); + + assert_eq!( + Interval::new(12i32, 1i32, NANOS_PER_MILLIS / 10), + Interval::parse("1 year 1 day 0.1 milliseconds", &config).unwrap(), + ); + + assert_eq!( + Interval::new(12i32, 1i32, 1000i64), + Interval::parse("1 year 1 day 1 microsecond", &config).unwrap(), + ); + + assert_eq!( + Interval::new(12i32, 1i32, 1i64), + Interval::parse("1 year 1 day 1 nanoseconds", &config).unwrap(), + ); + + assert_eq!( + Interval::new(1i32, 0i32, -NANOS_PER_SECOND), + Interval::parse("1 month -1 second", &config).unwrap(), + ); + + assert_eq!( + Interval::new(-13i32, -8i32, -NANOS_PER_HOUR - NANOS_PER_MINUTE - NANOS_PER_SECOND - (1.11_f64 * NANOS_PER_MILLIS as f64) as i64), + Interval::parse("-1 year -1 month -1 week -1 day -1 hour -1 minute -1 second -1.11 millisecond", &config).unwrap(), + ); + } + + #[test] + fn test_duplicate_interval_type() { + let config = IntervalParseConfig::new(IntervalUnit::Month); + + let err = Interval::parse("1 month 1 second 1 second", &config) + .expect_err("parsing interval should have failed"); + assert_eq!( + r#"ParseError("Invalid input syntax for type interval: \"1 month 1 second 1 second\". Repeated type 'second'")"#, + format!("{err:?}") + ); + + // test with singular and plural forms + let err = Interval::parse("1 century 2 centuries", &config) + .expect_err("parsing interval should have failed"); + assert_eq!( + r#"ParseError("Invalid input syntax for type interval: \"1 century 2 centuries\". 
Repeated type 'centuries'")"#, + format!("{err:?}") + ); + } + + #[test] + fn test_interval_amount_parsing() { + // integer + let result = IntervalAmount::from_str("123").unwrap(); + let expected = IntervalAmount::new(123, 0); + + assert_eq!(result, expected); + + // positive w/ fractional + let result = IntervalAmount::from_str("0.3").unwrap(); + let expected = IntervalAmount::new(0, 3 * 10_i64.pow(INTERVAL_PRECISION - 1)); + + assert_eq!(result, expected); + + // negative w/ fractional + let result = IntervalAmount::from_str("-3.5").unwrap(); + let expected = IntervalAmount::new(-3, -5 * 10_i64.pow(INTERVAL_PRECISION - 1)); + + assert_eq!(result, expected); + + // invalid: missing fractional + let result = IntervalAmount::from_str("3."); + assert!(result.is_err()); + + // invalid: sign in fractional + let result = IntervalAmount::from_str("3.-5"); + assert!(result.is_err()); + } + + #[test] + fn test_interval_precision() { + let config = IntervalParseConfig::new(IntervalUnit::Month); + + let result = Interval::parse("100000.1 days", &config).unwrap(); + let expected = Interval::new(0_i32, 100_000_i32, NANOS_PER_DAY / 10); + + assert_eq!(result, expected); + } + + #[test] + fn test_interval_addition() { + // add 4.1 centuries + let start = Interval::new(1, 2, 3); + let expected = Interval::new(4921, 2, 3); + + let result = start + .add( + IntervalAmount::new(4, 10_i64.pow(INTERVAL_PRECISION - 1)), + IntervalUnit::Century, + ) + .unwrap(); + + assert_eq!(result, expected); + + // add 10.25 decades + let start = Interval::new(1, 2, 3); + let expected = Interval::new(1231, 2, 3); + + let result = start + .add( + IntervalAmount::new(10, 25 * 10_i64.pow(INTERVAL_PRECISION - 2)), + IntervalUnit::Decade, + ) + .unwrap(); + + assert_eq!(result, expected); + + // add 30.3 years (reminder: Postgres logic does not spill to days/nanos when interval is larger than a month) + let start = Interval::new(1, 2, 3); + let expected = Interval::new(364, 2, 3); + + let result = start + .add( + IntervalAmount::new(30, 3 * 10_i64.pow(INTERVAL_PRECISION - 1)), + IntervalUnit::Year, + ) + .unwrap(); + + assert_eq!(result, expected); + + // add 1.5 months + let start = Interval::new(1, 2, 3); + let expected = Interval::new(2, 17, 3); + + let result = start + .add( + IntervalAmount::new(1, 5 * 10_i64.pow(INTERVAL_PRECISION - 1)), + IntervalUnit::Month, + ) + .unwrap(); + + assert_eq!(result, expected); + + // add -2 weeks + let start = Interval::new(1, 25, 3); + let expected = Interval::new(1, 11, 3); + + let result = start + .add(IntervalAmount::new(-2, 0), IntervalUnit::Week) + .unwrap(); + + assert_eq!(result, expected); + + // add 2.2 days + let start = Interval::new(12, 15, 3); + let expected = Interval::new(12, 17, 3 + 17_280 * NANOS_PER_SECOND); + + let result = start + .add( + IntervalAmount::new(2, 2 * 10_i64.pow(INTERVAL_PRECISION - 1)), + IntervalUnit::Day, + ) + .unwrap(); + + assert_eq!(result, expected); + + // add 12.5 hours + let start = Interval::new(1, 2, 3); + let expected = Interval::new(1, 2, 3 + 45_000 * NANOS_PER_SECOND); + + let result = start + .add( + IntervalAmount::new(12, 5 * 10_i64.pow(INTERVAL_PRECISION - 1)), + IntervalUnit::Hour, + ) + .unwrap(); + + assert_eq!(result, expected); + + // add -1.5 minutes + let start = Interval::new(0, 0, -3); + let expected = Interval::new(0, 0, -90_000_000_000 - 3); + + let result = start + .add( + IntervalAmount::new(-1, -5 * 10_i64.pow(INTERVAL_PRECISION - 1)), + IntervalUnit::Minute, + ) + .unwrap(); + + assert_eq!(result, expected); + } + + 
#[test] + fn string_to_timestamp_old() { + parse_timestamp("1677-06-14T07:29:01.256") + .map_err(|e| assert!(e.to_string().ends_with(ERR_NANOSECONDS_NOT_SUPPORTED))) + .unwrap_err(); + } + + #[test] + fn test_parse_decimal_with_parameter() { + let tests = [ + ("0", 0i128), + ("123.123", 123123i128), + ("123.1234", 123123i128), + ("123.1", 123100i128), + ("123", 123000i128), + ("-123.123", -123123i128), + ("-123.1234", -123123i128), + ("-123.1", -123100i128), + ("-123", -123000i128), + ("0.0000123", 0i128), + ("12.", 12000i128), + ("-12.", -12000i128), + ("00.1", 100i128), + ("-00.1", -100i128), + ("12345678912345678.1234", 12345678912345678123i128), + ("-12345678912345678.1234", -12345678912345678123i128), + ("99999999999999999.999", 99999999999999999999i128), + ("-99999999999999999.999", -99999999999999999999i128), + (".123", 123i128), + ("-.123", -123i128), + ("123.", 123000i128), + ("-123.", -123000i128), + ]; + for (s, i) in tests { + let result_128 = parse_decimal::(s, 20, 3); + assert_eq!(i, result_128.unwrap()); + let result_256 = parse_decimal::(s, 20, 3); + assert_eq!(i256::from_i128(i), result_256.unwrap()); + } + let can_not_parse_tests = ["123,123", ".", "123.123.123", "", "+", "-"]; + for s in can_not_parse_tests { + let result_128 = parse_decimal::(s, 20, 3); + assert_eq!( + format!("Parser error: can't parse the string value {s} to decimal"), + result_128.unwrap_err().to_string() + ); + let result_256 = parse_decimal::(s, 20, 3); + assert_eq!( + format!("Parser error: can't parse the string value {s} to decimal"), + result_256.unwrap_err().to_string() + ); + } + let overflow_parse_tests = ["12345678", "12345678.9", "99999999.99"]; + for s in overflow_parse_tests { + let result_128 = parse_decimal::(s, 10, 3); + let expected_128 = "Parser error: parse decimal overflow"; + let actual_128 = result_128.unwrap_err().to_string(); + + assert!( + actual_128.contains(expected_128), + "actual: '{actual_128}', expected: '{expected_128}'" + ); + + let result_256 = parse_decimal::(s, 10, 3); + let expected_256 = "Parser error: parse decimal overflow"; + let actual_256 = result_256.unwrap_err().to_string(); + + assert!( + actual_256.contains(expected_256), + "actual: '{actual_256}', expected: '{expected_256}'" + ); + } + + let edge_tests_128 = [ + ( + "99999999999999999999999999999999999999", + 99999999999999999999999999999999999999i128, + 0, + ), + ( + "999999999999999999999999999999999999.99", + 99999999999999999999999999999999999999i128, + 2, + ), + ( + "9999999999999999999999999.9999999999999", + 99999999999999999999999999999999999999i128, + 13, + ), + ( + "9999999999999999999999999", + 99999999999999999999999990000000000000i128, + 13, + ), + ( + "0.99999999999999999999999999999999999999", + 99999999999999999999999999999999999999i128, + 38, + ), + ]; + for (s, i, scale) in edge_tests_128 { + let result_128 = parse_decimal::(s, 38, scale); + assert_eq!(i, result_128.unwrap()); + } + let edge_tests_256 = [ + ( + "9999999999999999999999999999999999999999999999999999999999999999999999999999", +i256::from_string("9999999999999999999999999999999999999999999999999999999999999999999999999999").unwrap(), + 0, + ), + ( + "999999999999999999999999999999999999999999999999999999999999999999999999.9999", + i256::from_string("9999999999999999999999999999999999999999999999999999999999999999999999999999").unwrap(), + 4, + ), + ( + "99999999999999999999999999999999999999999999999999.99999999999999999999999999", + 
i256::from_string("9999999999999999999999999999999999999999999999999999999999999999999999999999").unwrap(), + 26, + ), + ( + "99999999999999999999999999999999999999999999999999", + i256::from_string("9999999999999999999999999999999999999999999999999900000000000000000000000000").unwrap(), + 26, + ), + ]; + for (s, i, scale) in edge_tests_256 { + let result = parse_decimal::(s, 76, scale); + assert_eq!(i, result.unwrap()); + } + } +} diff --git a/arrow-cast/src/pretty.rs b/arrow-cast/src/pretty.rs index 13d1df6a118d..59a9f9d605e2 100644 --- a/arrow-cast/src/pretty.rs +++ b/arrow-cast/src/pretty.rs @@ -848,7 +848,7 @@ mod tests { let mut buf = String::new(); write!(&mut buf, "{}", pretty_format_batches(&[batch]).unwrap()).unwrap(); - let s = vec![ + let s = [ "+---+-----+", "| a | b |", "+---+-----+", diff --git a/arrow-csv/Cargo.toml b/arrow-csv/Cargo.toml index 1f1a762d5065..66a6d7dbcaa5 100644 --- a/arrow-csv/Cargo.toml +++ b/arrow-csv/Cargo.toml @@ -39,7 +39,7 @@ arrow-buffer = { workspace = true } arrow-cast = { workspace = true } arrow-data = { workspace = true } arrow-schema = { workspace = true } -chrono = { version = "0.4.23", default-features = false, features = ["clock"] } +chrono = { workspace = true } csv = { version = "1.1", default-features = false } csv-core = { version = "0.1" } lazy_static = { version = "1.4", default-features = false } diff --git a/arrow-csv/examples/README.md b/arrow-csv/examples/README.md new file mode 100644 index 000000000000..340413e76d94 --- /dev/null +++ b/arrow-csv/examples/README.md @@ -0,0 +1,21 @@ + + +# Examples +- [`csv_calculation.rs`](csv_calculation.rs): performs a simple calculation using the CSV reader \ No newline at end of file diff --git a/arrow-csv/examples/csv_calculation.rs b/arrow-csv/examples/csv_calculation.rs new file mode 100644 index 000000000000..6ce963e2b012 --- /dev/null +++ b/arrow-csv/examples/csv_calculation.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_array::cast::AsArray; +use arrow_array::types::Int16Type; +use arrow_csv::ReaderBuilder; + +use arrow_schema::{DataType, Field, Schema}; +use std::fs::File; +use std::sync::Arc; + +fn main() { + // read csv from file + let file = File::open("arrow-csv/test/data/example.csv").unwrap(); + let csv_schema = Schema::new(vec![ + Field::new("c1", DataType::Int16, true), + Field::new("c2", DataType::Float32, true), + Field::new("c3", DataType::Utf8, true), + Field::new("c4", DataType::Boolean, true), + ]); + let mut reader = ReaderBuilder::new(Arc::new(csv_schema)) + .with_header(true) + .build(file) + .unwrap(); + + match reader.next() { + Some(r) => match r { + Ok(r) => { + // get the column(0) max value + let col = r.column(0).as_primitive::(); + let max = col.iter().max().flatten(); + println!("max value column(0): {max:?}") + } + Err(e) => { + println!("{e:?}"); + } + }, + None => { + println!("csv is empty"); + } + } +} diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs index 328c2cd41f3b..a194b35ffa46 100644 --- a/arrow-csv/src/reader/mod.rs +++ b/arrow-csv/src/reader/mod.rs @@ -133,8 +133,8 @@ use arrow_schema::*; use chrono::{TimeZone, Utc}; use csv::StringRecord; use lazy_static::lazy_static; -use regex::RegexSet; -use std::fmt; +use regex::{Regex, RegexSet}; +use std::fmt::{self, Debug}; use std::fs::File; use std::io::{BufRead, BufReader as StdBufReader, Read, Seek, SeekFrom}; use std::sync::Arc; @@ -157,6 +157,22 @@ lazy_static! { ]).unwrap(); } +/// A wrapper over `Option` to check if the value is `NULL`. +#[derive(Debug, Clone, Default)] +struct NullRegex(Option); + +impl NullRegex { + /// Returns true if the value should be considered as `NULL` according to + /// the provided regular expression. + #[inline] + fn is_null(&self, s: &str) -> bool { + match &self.0 { + Some(r) => r.is_match(s), + None => s.is_empty(), + } + } +} + #[derive(Default, Copy, Clone)] struct InferredDataType { /// Packed booleans indicating type @@ -177,6 +193,7 @@ impl InferredDataType { /// Returns the inferred data type fn get(&self) -> DataType { match self.packed { + 0 => DataType::Null, 1 => DataType::Boolean, 2 => DataType::Int64, 4 | 6 => DataType::Float64, // Promote Int64 to Float64 @@ -208,16 +225,17 @@ impl InferredDataType { /// The format specification for the CSV file #[derive(Debug, Clone, Default)] pub struct Format { - has_header: bool, + header: bool, delimiter: Option, escape: Option, quote: Option, terminator: Option, + null_regex: NullRegex, } impl Format { pub fn with_header(mut self, has_header: bool) -> Self { - self.has_header = has_header; + self.header = has_header; self } @@ -241,6 +259,12 @@ impl Format { self } + /// Provide a regex to match null values, defaults to `^$` + pub fn with_null_regex(mut self, null_regex: Regex) -> Self { + self.null_regex = NullRegex(Some(null_regex)); + self + } + /// Infer schema of CSV records from the provided `reader` /// /// If `max_records` is `None`, all records will be read, otherwise up to `max_records` @@ -256,7 +280,7 @@ impl Format { // get or create header names // when has_header is false, creates default column names with column_ prefix - let headers: Vec = if self.has_header { + let headers: Vec = if self.header { let headers = &csv_reader.headers().map_err(map_csv_error)?.clone(); headers.iter().map(|s| s.to_string()).collect() } else { @@ -287,7 +311,7 @@ impl Format { column_types.iter_mut().enumerate().take(header_length) { if let Some(string) = record.get(i) { - if !string.is_empty() { + if 
!self.null_regex.is_null(string) { column_type.update(string) } } @@ -307,7 +331,7 @@ impl Format { /// Build a [`csv::Reader`] for this [`Format`] fn build_reader(&self, reader: R) -> csv::Reader { let mut builder = csv::ReaderBuilder::new(); - builder.has_headers(self.has_header); + builder.has_headers(self.header); if let Some(c) = self.delimiter { builder.delimiter(c); @@ -379,7 +403,7 @@ pub fn infer_reader_schema( ) -> Result<(Schema, usize), ArrowError> { let format = Format { delimiter: Some(delimiter), - has_header, + header: has_header, ..Default::default() }; format.infer_schema(reader, max_read_records) @@ -401,7 +425,7 @@ pub fn infer_schema_from_files( let mut records_to_read = max_read_records.unwrap_or(usize::MAX); let format = Format { delimiter: Some(delimiter), - has_header, + header: has_header, ..Default::default() }; @@ -557,6 +581,9 @@ pub struct Decoder { /// A decoder for [`StringRecords`] record_decoder: RecordDecoder, + + /// Check if the string matches this pattern for `NULL`. + null_regex: NullRegex, } impl Decoder { @@ -603,6 +630,7 @@ impl Decoder { Some(self.schema.metadata.clone()), self.projection.as_ref(), self.line_number, + &self.null_regex, )?; self.line_number += rows.len(); Ok(Some(batch)) @@ -621,6 +649,7 @@ fn parse( metadata: Option>, projection: Option<&Vec>, line_number: usize, + null_regex: &NullRegex, ) -> Result { let projection: Vec = match projection { Some(v) => v.clone(), @@ -633,7 +662,9 @@ fn parse( let i = *i; let field = &fields[i]; match field.data_type() { - DataType::Boolean => build_boolean_array(line_number, rows, i), + DataType::Boolean => { + build_boolean_array(line_number, rows, i, null_regex) + } DataType::Decimal128(precision, scale) => { build_decimal_array::( line_number, @@ -641,6 +672,7 @@ fn parse( i, *precision, *scale, + null_regex, ) } DataType::Decimal256(precision, scale) => { @@ -650,53 +682,73 @@ fn parse( i, *precision, *scale, + null_regex, ) } - DataType::Int8 => build_primitive_array::(line_number, rows, i), + DataType::Int8 => { + build_primitive_array::(line_number, rows, i, null_regex) + } DataType::Int16 => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, null_regex) } DataType::Int32 => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, null_regex) } DataType::Int64 => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, null_regex) } DataType::UInt8 => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, null_regex) } DataType::UInt16 => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, null_regex) } DataType::UInt32 => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, null_regex) } DataType::UInt64 => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, null_regex) } DataType::Float32 => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, null_regex) } DataType::Float64 => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, null_regex) } DataType::Date32 => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, null_regex) } DataType::Date64 => { - build_primitive_array::(line_number, rows, i) - } - DataType::Time32(TimeUnit::Second) => { - 
build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, null_regex) } + DataType::Time32(TimeUnit::Second) => build_primitive_array::< + Time32SecondType, + >( + line_number, rows, i, null_regex + ), DataType::Time32(TimeUnit::Millisecond) => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::( + line_number, + rows, + i, + null_regex, + ) } DataType::Time64(TimeUnit::Microsecond) => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::( + line_number, + rows, + i, + null_regex, + ) } DataType::Time64(TimeUnit::Nanosecond) => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::( + line_number, + rows, + i, + null_regex, + ) } DataType::Timestamp(TimeUnit::Second, tz) => { build_timestamp_array::( @@ -704,6 +756,7 @@ fn parse( rows, i, tz.as_deref(), + null_regex, ) } DataType::Timestamp(TimeUnit::Millisecond, tz) => { @@ -712,6 +765,7 @@ fn parse( rows, i, tz.as_deref(), + null_regex, ) } DataType::Timestamp(TimeUnit::Microsecond, tz) => { @@ -720,6 +774,7 @@ fn parse( rows, i, tz.as_deref(), + null_regex, ) } DataType::Timestamp(TimeUnit::Nanosecond, tz) => { @@ -728,11 +783,18 @@ fn parse( rows, i, tz.as_deref(), + null_regex, ) } + DataType::Null => { + Ok(Arc::new(NullArray::builder(rows.len()).finish()) as ArrayRef) + } DataType::Utf8 => Ok(Arc::new( rows.iter() - .map(|row| Some(row.get(i))) + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + }) .collect::(), ) as ArrayRef), DataType::Dictionary(key_type, value_type) @@ -827,11 +889,12 @@ fn build_decimal_array( col_idx: usize, precision: u8, scale: i8, + null_regex: &NullRegex, ) -> Result { let mut decimal_builder = PrimitiveBuilder::::with_capacity(rows.len()); for row in rows.iter() { let s = row.get(col_idx); - if s.is_empty() { + if null_regex.is_null(s) { // append null decimal_builder.append_null(); } else { @@ -859,12 +922,13 @@ fn build_primitive_array( line_number: usize, rows: &StringRecords<'_>, col_idx: usize, + null_regex: &NullRegex, ) -> Result { rows.iter() .enumerate() .map(|(row_index, row)| { let s = row.get(col_idx); - if s.is_empty() { + if null_regex.is_null(s) { return Ok(None); } @@ -888,14 +952,27 @@ fn build_timestamp_array( rows: &StringRecords<'_>, col_idx: usize, timezone: Option<&str>, + null_regex: &NullRegex, ) -> Result { Ok(Arc::new(match timezone { Some(timezone) => { let tz: Tz = timezone.parse()?; - build_timestamp_array_impl::(line_number, rows, col_idx, &tz)? - .with_timezone(timezone) + build_timestamp_array_impl::( + line_number, + rows, + col_idx, + &tz, + null_regex, + )? 
+ .with_timezone(timezone) } - None => build_timestamp_array_impl::(line_number, rows, col_idx, &Utc)?, + None => build_timestamp_array_impl::( + line_number, + rows, + col_idx, + &Utc, + null_regex, + )?, })) } @@ -904,29 +981,36 @@ fn build_timestamp_array_impl( rows: &StringRecords<'_>, col_idx: usize, timezone: &Tz, + null_regex: &NullRegex, ) -> Result, ArrowError> { rows.iter() .enumerate() .map(|(row_index, row)| { let s = row.get(col_idx); - if s.is_empty() { + if null_regex.is_null(s) { return Ok(None); } - let date = string_to_datetime(timezone, s).map_err(|e| { - ArrowError::ParseError(format!( - "Error parsing column {col_idx} at line {}: {}", - line_number + row_index, - e - )) - })?; - - Ok(Some(match T::UNIT { - TimeUnit::Second => date.timestamp(), - TimeUnit::Millisecond => date.timestamp_millis(), - TimeUnit::Microsecond => date.timestamp_micros(), - TimeUnit::Nanosecond => date.timestamp_nanos(), - })) + let date = string_to_datetime(timezone, s) + .and_then(|date| match T::UNIT { + TimeUnit::Second => Ok(date.timestamp()), + TimeUnit::Millisecond => Ok(date.timestamp_millis()), + TimeUnit::Microsecond => Ok(date.timestamp_micros()), + TimeUnit::Nanosecond => date.timestamp_nanos_opt().ok_or_else(|| { + ArrowError::ParseError(format!( + "{} would overflow 64-bit signed nanoseconds", + date.to_rfc3339(), + )) + }), + }) + .map_err(|e| { + ArrowError::ParseError(format!( + "Error parsing column {col_idx} at line {}: {}", + line_number + row_index, + e + )) + })?; + Ok(Some(date)) }) .collect() } @@ -936,12 +1020,13 @@ fn build_boolean_array( line_number: usize, rows: &StringRecords<'_>, col_idx: usize, + null_regex: &NullRegex, ) -> Result { rows.iter() .enumerate() .map(|(row_index, row)| { let s = row.get(col_idx); - if s.is_empty() { + if null_regex.is_null(s) { return Ok(None); } let parsed = parse_bool(s); @@ -1010,8 +1095,16 @@ impl ReaderBuilder { } /// Set whether the CSV file has headers + #[deprecated(note = "Use with_header")] + #[doc(hidden)] pub fn has_header(mut self, has_header: bool) -> Self { - self.format.has_header = has_header; + self.format.header = has_header; + self + } + + /// Set whether the CSV file has a header + pub fn with_header(mut self, has_header: bool) -> Self { + self.format.header = has_header; self } @@ -1042,6 +1135,12 @@ impl ReaderBuilder { self } + /// Provide a regex to match null values, defaults to `^$` + pub fn with_null_regex(mut self, null_regex: Regex) -> Self { + self.format.null_regex = NullRegex(Some(null_regex)); + self + } + /// Set the batch size (number of records to load at one time) pub fn with_batch_size(mut self, batch_size: usize) -> Self { self.batch_size = batch_size; @@ -1085,7 +1184,7 @@ impl ReaderBuilder { let delimiter = self.format.build_parser(); let record_decoder = RecordDecoder::new(delimiter, self.schema.fields().len()); - let header = self.format.has_header as usize; + let header = self.format.header as usize; let (start, end) = match self.bounds { Some((start, end)) => (start + header, end + header), @@ -1100,6 +1199,7 @@ impl ReaderBuilder { end, projection: self.projection, batch_size: self.batch_size, + null_regex: self.format.null_regex, } } } @@ -1225,7 +1325,7 @@ mod tests { .chain(Cursor::new("\n".to_string())) .chain(file_without_headers); let mut csv = ReaderBuilder::new(Arc::new(schema)) - .has_header(true) + .with_header(true) .build(both_files) .unwrap(); let batch = csv.next().unwrap().unwrap(); @@ -1243,7 +1343,7 @@ mod tests { .unwrap(); file.rewind().unwrap(); - let builder = 
ReaderBuilder::new(Arc::new(schema)).has_header(true); + let builder = ReaderBuilder::new(Arc::new(schema)).with_header(true); let mut csv = builder.build(file).unwrap(); let expected_schema = Schema::new(vec![ @@ -1406,14 +1506,14 @@ mod tests { let schema = Arc::new(Schema::new(vec![ Field::new("c_int", DataType::UInt64, false), Field::new("c_float", DataType::Float32, true), - Field::new("c_string", DataType::Utf8, false), + Field::new("c_string", DataType::Utf8, true), Field::new("c_bool", DataType::Boolean, false), ])); let file = File::open("test/data/null_test.csv").unwrap(); let mut csv = ReaderBuilder::new(schema) - .has_header(true) + .with_header(true) .build(file) .unwrap(); @@ -1426,6 +1526,91 @@ mod tests { assert!(!batch.column(1).is_null(4)); } + #[test] + fn test_init_nulls() { + let schema = Arc::new(Schema::new(vec![ + Field::new("c_int", DataType::UInt64, true), + Field::new("c_float", DataType::Float32, true), + Field::new("c_string", DataType::Utf8, true), + Field::new("c_bool", DataType::Boolean, true), + Field::new("c_null", DataType::Null, true), + ])); + let file = File::open("test/data/init_null_test.csv").unwrap(); + + let mut csv = ReaderBuilder::new(schema) + .with_header(true) + .build(file) + .unwrap(); + + let batch = csv.next().unwrap().unwrap(); + + assert!(batch.column(1).is_null(0)); + assert!(!batch.column(1).is_null(1)); + assert!(batch.column(1).is_null(2)); + assert!(!batch.column(1).is_null(3)); + assert!(!batch.column(1).is_null(4)); + } + + #[test] + fn test_init_nulls_with_inference() { + let format = Format::default().with_header(true).with_delimiter(b','); + + let mut file = File::open("test/data/init_null_test.csv").unwrap(); + let (schema, _) = format.infer_schema(&mut file, None).unwrap(); + file.rewind().unwrap(); + + let expected_schema = Schema::new(vec![ + Field::new("c_int", DataType::Int64, true), + Field::new("c_float", DataType::Float64, true), + Field::new("c_string", DataType::Utf8, true), + Field::new("c_bool", DataType::Boolean, true), + Field::new("c_null", DataType::Null, true), + ]); + assert_eq!(schema, expected_schema); + + let mut csv = ReaderBuilder::new(Arc::new(schema)) + .with_format(format) + .build(file) + .unwrap(); + + let batch = csv.next().unwrap().unwrap(); + + assert!(batch.column(1).is_null(0)); + assert!(!batch.column(1).is_null(1)); + assert!(batch.column(1).is_null(2)); + assert!(!batch.column(1).is_null(3)); + assert!(!batch.column(1).is_null(4)); + } + + #[test] + fn test_custom_nulls() { + let schema = Arc::new(Schema::new(vec![ + Field::new("c_int", DataType::UInt64, true), + Field::new("c_float", DataType::Float32, true), + Field::new("c_string", DataType::Utf8, true), + Field::new("c_bool", DataType::Boolean, true), + ])); + + let file = File::open("test/data/custom_null_test.csv").unwrap(); + + let null_regex = Regex::new("^nil$").unwrap(); + + let mut csv = ReaderBuilder::new(schema) + .with_header(true) + .with_null_regex(null_regex) + .build(file) + .unwrap(); + + let batch = csv.next().unwrap().unwrap(); + + // "nil"s should be NULL + assert!(batch.column(0).is_null(1)); + assert!(batch.column(1).is_null(2)); + assert!(batch.column(3).is_null(4)); + assert!(batch.column(2).is_null(3)); + assert!(!batch.column(2).is_null(4)); + } + #[test] fn test_nulls_with_inference() { let mut file = File::open("test/data/various_types.csv").unwrap(); @@ -1485,6 +1670,42 @@ mod tests { assert!(!batch.column(1).is_null(4)); } + #[test] + fn test_custom_nulls_with_inference() { + let mut file = 
File::open("test/data/custom_null_test.csv").unwrap(); + + let null_regex = Regex::new("^nil$").unwrap(); + + let format = Format::default() + .with_header(true) + .with_null_regex(null_regex); + + let (schema, _) = format.infer_schema(&mut file, None).unwrap(); + file.rewind().unwrap(); + + let expected_schema = Schema::new(vec![ + Field::new("c_int", DataType::Int64, true), + Field::new("c_float", DataType::Float64, true), + Field::new("c_string", DataType::Utf8, true), + Field::new("c_bool", DataType::Boolean, true), + ]); + + assert_eq!(schema, expected_schema); + + let builder = ReaderBuilder::new(Arc::new(schema)) + .with_format(format) + .with_batch_size(512) + .with_projection(vec![0, 1, 2, 3]); + + let mut csv = builder.build(file).unwrap(); + let batch = csv.next().unwrap().unwrap(); + + assert_eq!(5, batch.num_rows()); + assert_eq!(4, batch.num_columns()); + + assert_eq!(batch.schema().as_ref(), &expected_schema); + } + #[test] fn test_parse_invalid_csv() { let file = File::open("test/data/various_types_invalid.csv").unwrap(); @@ -1497,7 +1718,7 @@ mod tests { ]); let builder = ReaderBuilder::new(Arc::new(schema)) - .has_header(true) + .with_header(true) .with_delimiter(b'|') .with_batch_size(512) .with_projection(vec![0, 1, 2, 3]); @@ -1824,7 +2045,7 @@ mod tests { Field::new("text2", DataType::Utf8, false), ]); let builder = ReaderBuilder::new(Arc::new(schema)) - .has_header(false) + .with_header(false) .with_quote(b'~'); // default is ", change to ~ let mut csv_text = Vec::new(); @@ -1856,7 +2077,7 @@ mod tests { Field::new("text2", DataType::Utf8, false), ]); let builder = ReaderBuilder::new(Arc::new(schema)) - .has_header(false) + .with_header(false) .with_escape(b'\\'); // default is None, change to \ let mut csv_text = Vec::new(); @@ -1888,7 +2109,7 @@ mod tests { Field::new("text2", DataType::Utf8, false), ]); let builder = ReaderBuilder::new(Arc::new(schema)) - .has_header(false) + .with_header(false) .with_terminator(b'\n'); // default is CRLF, change to LF let mut csv_text = Vec::new(); @@ -1930,7 +2151,7 @@ mod tests { ])); for (idx, (bounds, has_header, expected)) in tests.into_iter().enumerate() { - let mut reader = ReaderBuilder::new(schema.clone()).has_header(has_header); + let mut reader = ReaderBuilder::new(schema.clone()).with_header(has_header); if let Some((start, end)) = bounds { reader = reader.with_bounds(start, end); } @@ -1995,7 +2216,7 @@ mod tests { for capacity in [1, 3, 7, 100] { let reader = ReaderBuilder::new(schema.clone()) .with_batch_size(batch_size) - .has_header(has_header) + .with_header(has_header) .build(File::open(path).unwrap()) .unwrap(); @@ -2013,7 +2234,7 @@ mod tests { let reader = ReaderBuilder::new(schema.clone()) .with_batch_size(batch_size) - .has_header(has_header) + .with_header(has_header) .build_buffered(buffered) .unwrap(); @@ -2026,8 +2247,8 @@ mod tests { fn err_test(csv: &[u8], expected: &str) { let schema = Arc::new(Schema::new(vec![ - Field::new("text1", DataType::Utf8, false), - Field::new("text2", DataType::Utf8, false), + Field::new("text1", DataType::Utf8, true), + Field::new("text2", DataType::Utf8, true), ])); let buffer = std::io::BufReader::with_capacity(2, Cursor::new(csv)); let b = ReaderBuilder::new(schema) @@ -2132,7 +2353,7 @@ mod tests { #[test] fn test_inference() { let cases: &[(&[&str], DataType)] = &[ - (&[], DataType::Utf8), + (&[], DataType::Null), (&["false", "12"], DataType::Utf8), (&["12", "cupcakes"], DataType::Utf8), (&["12", "12.4"], DataType::Float64), diff --git a/arrow-csv/src/writer.rs 
b/arrow-csv/src/writer.rs index 840e8e8a93cc..1ca956e2c73f 100644 --- a/arrow-csv/src/writer.rs +++ b/arrow-csv/src/writer.rs @@ -70,11 +70,6 @@ use csv::ByteRecord; use std::io::Write; use crate::map_csv_error; - -const DEFAULT_DATE_FORMAT: &str = "%F"; -const DEFAULT_TIME_FORMAT: &str = "%T"; -const DEFAULT_TIMESTAMP_FORMAT: &str = "%FT%H:%M:%S.%9f"; -const DEFAULT_TIMESTAMP_TZ_FORMAT: &str = "%FT%H:%M:%S.%9f%:z"; const DEFAULT_NULL_VALUE: &str = ""; /// A CSV writer @@ -82,41 +77,29 @@ const DEFAULT_NULL_VALUE: &str = ""; pub struct Writer { /// The object to write to writer: csv::Writer, - /// Whether file should be written with headers. Defaults to `true` + /// Whether file should be written with headers, defaults to `true` has_headers: bool, - /// The date format for date arrays + /// The date format for date arrays, defaults to RFC3339 date_format: Option, - /// The datetime format for datetime arrays + /// The datetime format for datetime arrays, defaults to RFC3339 datetime_format: Option, - /// The timestamp format for timestamp arrays + /// The timestamp format for timestamp arrays, defaults to RFC3339 timestamp_format: Option, - /// The timestamp format for timestamp (with timezone) arrays + /// The timestamp format for timestamp (with timezone) arrays, defaults to RFC3339 timestamp_tz_format: Option, - /// The time format for time arrays + /// The time format for time arrays, defaults to RFC3339 time_format: Option, /// Is the beginning-of-writer beginning: bool, - /// The value to represent null entries - null_value: String, + /// The value to represent null entries, defaults to [`DEFAULT_NULL_VALUE`] + null_value: Option, } impl Writer { /// Create a new CsvWriter from a writable object, with default options pub fn new(writer: W) -> Self { let delimiter = b','; - let mut builder = csv::WriterBuilder::new(); - let writer = builder.delimiter(delimiter).from_writer(writer); - Writer { - writer, - has_headers: true, - date_format: Some(DEFAULT_DATE_FORMAT.to_string()), - datetime_format: Some(DEFAULT_TIMESTAMP_FORMAT.to_string()), - time_format: Some(DEFAULT_TIME_FORMAT.to_string()), - timestamp_format: Some(DEFAULT_TIMESTAMP_FORMAT.to_string()), - timestamp_tz_format: Some(DEFAULT_TIMESTAMP_TZ_FORMAT.to_string()), - beginning: true, - null_value: DEFAULT_NULL_VALUE.to_string(), - } + WriterBuilder::new().with_delimiter(delimiter).build(writer) } /// Write a vector of record batches to a writable object @@ -138,7 +121,7 @@ impl Writer { } let options = FormatOptions::default() - .with_null(&self.null_value) + .with_null(self.null_value.as_deref().unwrap_or(DEFAULT_NULL_VALUE)) .with_date_format(self.date_format.as_deref()) .with_datetime_format(self.datetime_format.as_deref()) .with_timestamp_format(self.timestamp_format.as_deref()) @@ -207,9 +190,9 @@ impl RecordBatchWriter for Writer { #[derive(Clone, Debug)] pub struct WriterBuilder { /// Optional column delimiter. Defaults to `b','` - delimiter: Option, + delimiter: u8, /// Whether to write column names as file headers. 
Defaults to `true` - has_headers: bool, + has_header: bool, /// Optional date format for date arrays date_format: Option, /// Optional datetime format for datetime arrays @@ -227,14 +210,14 @@ pub struct WriterBuilder { impl Default for WriterBuilder { fn default() -> Self { Self { - has_headers: true, - delimiter: None, - date_format: Some(DEFAULT_DATE_FORMAT.to_string()), - datetime_format: Some(DEFAULT_TIMESTAMP_FORMAT.to_string()), - time_format: Some(DEFAULT_TIME_FORMAT.to_string()), - timestamp_format: Some(DEFAULT_TIMESTAMP_FORMAT.to_string()), - timestamp_tz_format: Some(DEFAULT_TIMESTAMP_TZ_FORMAT.to_string()), - null_value: Some(DEFAULT_NULL_VALUE.to_string()), + has_header: true, + delimiter: b',', + date_format: None, + datetime_format: None, + time_format: None, + timestamp_format: None, + timestamp_tz_format: None, + null_value: None, } } } @@ -254,7 +237,7 @@ impl WriterBuilder { /// let file = File::create("target/out.csv").unwrap(); /// /// // create a builder that doesn't write headers - /// let builder = WriterBuilder::new().has_headers(false); + /// let builder = WriterBuilder::new().with_header(false); /// let writer = builder.build(file); /// /// writer @@ -265,48 +248,92 @@ impl WriterBuilder { } /// Set whether to write headers + #[deprecated(note = "Use Self::with_header")] + #[doc(hidden)] pub fn has_headers(mut self, has_headers: bool) -> Self { - self.has_headers = has_headers; + self.has_header = has_headers; + self + } + + /// Set whether to write the CSV file with a header + pub fn with_header(mut self, header: bool) -> Self { + self.has_header = header; self } + /// Returns `true` if this writer is configured to write a header + pub fn header(&self) -> bool { + self.has_header + } + /// Set the CSV file's column delimiter as a byte character pub fn with_delimiter(mut self, delimiter: u8) -> Self { - self.delimiter = Some(delimiter); + self.delimiter = delimiter; self } + /// Get the CSV file's column delimiter as a byte character + pub fn delimiter(&self) -> u8 { + self.delimiter + } + /// Set the CSV file's date format pub fn with_date_format(mut self, format: String) -> Self { self.date_format = Some(format); self } + /// Get the CSV file's date format if set, defaults to RFC3339 + pub fn date_format(&self) -> Option<&str> { + self.date_format.as_deref() + } + /// Set the CSV file's datetime format pub fn with_datetime_format(mut self, format: String) -> Self { self.datetime_format = Some(format); self } + /// Get the CSV file's datetime format if set, defaults to RFC3339 + pub fn datetime_format(&self) -> Option<&str> { + self.datetime_format.as_deref() + } + /// Set the CSV file's time format pub fn with_time_format(mut self, format: String) -> Self { self.time_format = Some(format); self } + /// Get the CSV file's datetime time if set, defaults to RFC3339 + pub fn time_format(&self) -> Option<&str> { + self.time_format.as_deref() + } + /// Set the CSV file's timestamp format pub fn with_timestamp_format(mut self, format: String) -> Self { self.timestamp_format = Some(format); self } + /// Get the CSV file's timestamp format if set, defaults to RFC3339 + pub fn timestamp_format(&self) -> Option<&str> { + self.timestamp_format.as_deref() + } + /// Set the value to represent null in output pub fn with_null(mut self, null_value: String) -> Self { self.null_value = Some(null_value); self } - /// Use RFC3339 format for date/time/timestamps + /// Get the value to represent null in output + pub fn null(&self) -> &str { + 
self.null_value.as_deref().unwrap_or(DEFAULT_NULL_VALUE) + } + + /// Use RFC3339 format for date/time/timestamps (default) + #[deprecated(note = "Use WriterBuilder::default()")] pub fn with_rfc3339(mut self) -> Self { self.date_format = None; self.datetime_format = None; @@ -318,21 +345,18 @@ impl WriterBuilder { /// Create a new `Writer` pub fn build(self, writer: W) -> Writer { - let delimiter = self.delimiter.unwrap_or(b','); let mut builder = csv::WriterBuilder::new(); - let writer = builder.delimiter(delimiter).from_writer(writer); + let writer = builder.delimiter(self.delimiter).from_writer(writer); Writer { writer, - has_headers: self.has_headers, + beginning: true, + has_headers: self.has_header, date_format: self.date_format, datetime_format: self.datetime_format, time_format: self.time_format, timestamp_format: self.timestamp_format, timestamp_tz_format: self.timestamp_tz_format, - beginning: true, - null_value: self - .null_value - .unwrap_or_else(|| DEFAULT_NULL_VALUE.to_string()), + null_value: self.null_value, } } } @@ -411,11 +435,11 @@ mod tests { let expected = r#"c1,c2,c3,c4,c5,c6,c7 Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34,cupcakes -consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378000000,06:51:20,cupcakes -sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,foo +consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378,06:51:20,cupcakes +sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555,23:46:03,foo Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34,cupcakes -consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378000000,06:51:20,cupcakes -sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,foo +consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378,06:51:20,cupcakes +sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555,23:46:03,foo "#; assert_eq!(expected.to_string(), String::from_utf8(buffer).unwrap()); } @@ -512,7 +536,7 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,foo let mut file = tempfile::tempfile().unwrap(); let builder = WriterBuilder::new() - .has_headers(false) + .with_header(false) .with_delimiter(b'|') .with_null("NULL".to_string()) .with_time_format("%r".to_string()); @@ -560,7 +584,7 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,foo ) .unwrap(); - let builder = WriterBuilder::new().has_headers(false); + let builder = WriterBuilder::new().with_header(false); let mut buf: Cursor> = Default::default(); // drop the writer early to release the borrow. 
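// Illustrative sketch (not part of this change set): wiring up the renamed writer options.
// With the new defaults above, dates/times/timestamps are emitted as RFC3339 without calling
// the now-deprecated with_rfc3339(); the output path used here is a placeholder.
fn write_piped_csv(batch: &arrow_array::RecordBatch) {
    let file = std::fs::File::create("/tmp/out.csv").unwrap();
    let mut writer = arrow_csv::WriterBuilder::new()
        .with_header(true)             // replaces the deprecated has_headers(true)
        .with_delimiter(b'|')
        .with_null("NULL".to_string()) // written in place of the default empty string
        .build(file);
    writer.write(batch).unwrap();
}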
@@ -652,7 +676,7 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,foo let mut file = tempfile::tempfile().unwrap(); - let builder = WriterBuilder::new().with_rfc3339(); + let builder = WriterBuilder::new(); let mut writer = builder.build(&mut file); let batches = vec![&batch]; for batch in batches { diff --git a/arrow-csv/test/data/custom_null_test.csv b/arrow-csv/test/data/custom_null_test.csv new file mode 100644 index 000000000000..39f9fc4b3eff --- /dev/null +++ b/arrow-csv/test/data/custom_null_test.csv @@ -0,0 +1,6 @@ +c_int,c_float,c_string,c_bool +1,1.1,"1.11",True +nil,2.2,"2.22",TRUE +3,nil,"3.33",true +4,4.4,nil,False +5,6.6,"",nil diff --git a/arrow-csv/test/data/example.csv b/arrow-csv/test/data/example.csv new file mode 100644 index 000000000000..0c03cee84528 --- /dev/null +++ b/arrow-csv/test/data/example.csv @@ -0,0 +1,4 @@ +c1,c2,c3,c4 +1,1.1,"hong kong",true +3,323.12,"XiAn",false +10,131323.12,"cheng du",false \ No newline at end of file diff --git a/arrow-csv/test/data/init_null_test.csv b/arrow-csv/test/data/init_null_test.csv new file mode 100644 index 000000000000..f7d8a299645d --- /dev/null +++ b/arrow-csv/test/data/init_null_test.csv @@ -0,0 +1,6 @@ +c_int,c_float,c_string,c_bool,c_null +,,,, +2,2.2,"a",TRUE, +3,,"b",true, +4,4.4,,False, +5,6.6,"",FALSE, \ No newline at end of file diff --git a/arrow-data/src/data/mod.rs b/arrow-data/src/data.rs similarity index 90% rename from arrow-data/src/data/mod.rs rename to arrow-data/src/data.rs index 32aae1e92a51..5f87dddd4217 100644 --- a/arrow-data/src/data/mod.rs +++ b/arrow-data/src/data.rs @@ -20,7 +20,7 @@ use crate::bit_iterator::BitSliceIterator; use arrow_buffer::buffer::{BooleanBuffer, NullBuffer}; -use arrow_buffer::{bit_util, ArrowNativeType, Buffer, MutableBuffer}; +use arrow_buffer::{bit_util, i256, ArrowNativeType, Buffer, MutableBuffer}; use arrow_schema::{ArrowError, DataType, UnionMode}; use std::convert::TryInto; use std::mem; @@ -29,8 +29,10 @@ use std::sync::Arc; use crate::equal; -mod buffers; -pub use buffers::*; +/// A collection of [`Buffer`] +#[doc(hidden)] +#[deprecated(note = "Use [Buffer]")] +pub type Buffers<'a> = &'a [Buffer]; #[inline] pub(crate) fn contains_nulls( @@ -172,7 +174,7 @@ pub(crate) fn into_buffers( } } -/// An generic representation of Arrow array data which encapsulates common attributes and +/// A generic representation of Arrow array data which encapsulates common attributes and /// operations for Arrow array. Specific operations for different arrays types (e.g., /// primitive, list, struct) are implemented in `Array`. /// @@ -345,10 +347,9 @@ impl ArrayData { &self.data_type } - /// Returns the [`Buffers`] storing data for this [`ArrayData`] - pub fn buffers(&self) -> Buffers<'_> { - // In future ArrayData won't store data contiguously as `Vec` (#1799) - Buffers::from_slice(&self.buffers) + /// Returns the [`Buffer`] storing data for this [`ArrayData`] + pub fn buffers(&self) -> &[Buffer] { + &self.buffers } /// Returns a slice of children [`ArrayData`]. This will be non @@ -450,7 +451,7 @@ impl ArrayData { for spec in layout.buffers.iter() { match spec { - BufferSpec::FixedWidth { byte_width } => { + BufferSpec::FixedWidth { byte_width, .. 
} => { let buffer_size = self.len.checked_mul(*byte_width).ok_or_else(|| { ArrowError::ComputeError( @@ -634,9 +635,12 @@ impl ArrayData { let children = f .iter() .enumerate() - .map(|(idx, (_, f))| match idx { - 0 => Self::new_null(f.data_type(), len), - _ => Self::new_empty(f.data_type()), + .map(|(idx, (_, f))| { + if idx == 0 || *mode == UnionMode::Sparse { + Self::new_null(f.data_type(), len) + } else { + Self::new_empty(f.data_type()) + } }) .collect(); @@ -695,6 +699,23 @@ impl ArrayData { Self::new_null(data_type, 0) } + /// Verifies that the buffers meet the minimum alignment requirements for the data type + /// + /// Buffers that are not adequately aligned will be copied to a new aligned allocation + /// + /// This can be useful for when interacting with data sent over IPC or FFI, that may + /// not meet the minimum alignment requirements + pub fn align_buffers(&mut self) { + let layout = layout(&self.data_type); + for (buffer, spec) in self.buffers.iter_mut().zip(&layout.buffers) { + if let BufferSpec::FixedWidth { alignment, .. } = spec { + if buffer.as_ptr().align_offset(*alignment) != 0 { + *buffer = Buffer::from_slice_ref(buffer.as_ref()) + } + } + } + } + /// "cheap" validation of an `ArrayData`. Ensures buffers are /// sufficiently sized to store `len` + `offset` total elements of /// `data_type` and performs other inexpensive consistency checks. @@ -732,10 +753,11 @@ impl ArrayData { self.buffers.iter().zip(layout.buffers.iter()).enumerate() { match spec { - BufferSpec::FixedWidth { byte_width } => { - let min_buffer_size = len_plus_offset - .checked_mul(*byte_width) - .expect("integer overflow computing min buffer size"); + BufferSpec::FixedWidth { + byte_width, + alignment, + } => { + let min_buffer_size = len_plus_offset.saturating_mul(*byte_width); if buffer.len() < min_buffer_size { return Err(ArrowError::InvalidArgumentError(format!( @@ -743,6 +765,14 @@ impl ArrayData { min_buffer_size, i, self.data_type, buffer.len() ))); } + + let align_offset = buffer.as_ptr().align_offset(*alignment); + if align_offset != 0 { + return Err(ArrowError::InvalidArgumentError(format!( + "Misaligned buffers[{i}] in array of type {:?}, offset from expected alignment of {alignment} by {}", + self.data_type, align_offset.min(alignment - align_offset) + ))); + } } BufferSpec::VariableWidth => { // not cheap to validate (need to look at the @@ -1489,7 +1519,8 @@ impl ArrayData { pub fn layout(data_type: &DataType) -> DataTypeLayout { // based on C/C++ implementation in // https://github.com/apache/arrow/blob/661c7d749150905a63dd3b52e0a04dac39030d95/cpp/src/arrow/type.h (and .cc) - use std::mem::size_of; + use arrow_schema::IntervalUnit::*; + match data_type { DataType::Null => DataTypeLayout { buffers: vec![], @@ -1499,44 +1530,52 @@ pub fn layout(data_type: &DataType) -> DataTypeLayout { buffers: vec![BufferSpec::BitMap], can_contain_null_mask: true, }, - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float16 - | DataType::Float32 - | DataType::Float64 - | DataType::Timestamp(_, _) - | DataType::Date32 - | DataType::Date64 - | DataType::Time32(_) - | DataType::Time64(_) - | DataType::Interval(_) => { - DataTypeLayout::new_fixed_width(data_type.primitive_width().unwrap()) - } - DataType::Duration(_) => DataTypeLayout::new_fixed_width(size_of::()), - DataType::Binary => DataTypeLayout::new_binary(size_of::()), - DataType::FixedSizeBinary(bytes_per_value) => { - let 
bytes_per_value: usize = (*bytes_per_value) - .try_into() - .expect("negative size for fixed size binary"); - DataTypeLayout::new_fixed_width(bytes_per_value) + DataType::Int8 => DataTypeLayout::new_fixed_width::(), + DataType::Int16 => DataTypeLayout::new_fixed_width::(), + DataType::Int32 => DataTypeLayout::new_fixed_width::(), + DataType::Int64 => DataTypeLayout::new_fixed_width::(), + DataType::UInt8 => DataTypeLayout::new_fixed_width::(), + DataType::UInt16 => DataTypeLayout::new_fixed_width::(), + DataType::UInt32 => DataTypeLayout::new_fixed_width::(), + DataType::UInt64 => DataTypeLayout::new_fixed_width::(), + DataType::Float16 => DataTypeLayout::new_fixed_width::(), + DataType::Float32 => DataTypeLayout::new_fixed_width::(), + DataType::Float64 => DataTypeLayout::new_fixed_width::(), + DataType::Timestamp(_, _) => DataTypeLayout::new_fixed_width::(), + DataType::Date32 => DataTypeLayout::new_fixed_width::(), + DataType::Date64 => DataTypeLayout::new_fixed_width::(), + DataType::Time32(_) => DataTypeLayout::new_fixed_width::(), + DataType::Time64(_) => DataTypeLayout::new_fixed_width::(), + DataType::Interval(YearMonth) => DataTypeLayout::new_fixed_width::(), + DataType::Interval(DayTime) => DataTypeLayout::new_fixed_width::(), + DataType::Interval(MonthDayNano) => DataTypeLayout::new_fixed_width::(), + DataType::Duration(_) => DataTypeLayout::new_fixed_width::(), + DataType::Decimal128(_, _) => DataTypeLayout::new_fixed_width::(), + DataType::Decimal256(_, _) => DataTypeLayout::new_fixed_width::(), + DataType::FixedSizeBinary(size) => { + let spec = BufferSpec::FixedWidth { + byte_width: (*size).try_into().unwrap(), + alignment: mem::align_of::(), + }; + DataTypeLayout { + buffers: vec![spec], + can_contain_null_mask: true, + } } - DataType::LargeBinary => DataTypeLayout::new_binary(size_of::()), - DataType::Utf8 => DataTypeLayout::new_binary(size_of::()), - DataType::LargeUtf8 => DataTypeLayout::new_binary(size_of::()), - DataType::List(_) => DataTypeLayout::new_fixed_width(size_of::()), + DataType::Binary => DataTypeLayout::new_binary::(), + DataType::LargeBinary => DataTypeLayout::new_binary::(), + DataType::Utf8 => DataTypeLayout::new_binary::(), + DataType::LargeUtf8 => DataTypeLayout::new_binary::(), DataType::FixedSizeList(_, _) => DataTypeLayout::new_empty(), // all in child data - DataType::LargeList(_) => DataTypeLayout::new_fixed_width(size_of::()), + DataType::List(_) => DataTypeLayout::new_fixed_width::(), + DataType::LargeList(_) => DataTypeLayout::new_fixed_width::(), + DataType::Map(_, _) => DataTypeLayout::new_fixed_width::(), DataType::Struct(_) => DataTypeLayout::new_empty(), // all in child data, DataType::RunEndEncoded(_, _) => DataTypeLayout::new_empty(), // all in child data, DataType::Union(_, mode) => { let type_ids = BufferSpec::FixedWidth { - byte_width: size_of::(), + byte_width: mem::size_of::(), + alignment: mem::align_of::(), }; DataTypeLayout { @@ -1548,7 +1587,8 @@ pub fn layout(data_type: &DataType) -> DataTypeLayout { vec![ type_ids, BufferSpec::FixedWidth { - byte_width: size_of::(), + byte_width: mem::size_of::(), + alignment: mem::align_of::(), }, ] } @@ -1557,19 +1597,6 @@ pub fn layout(data_type: &DataType) -> DataTypeLayout { } } DataType::Dictionary(key_type, _value_type) => layout(key_type), - DataType::Decimal128(_, _) => { - // Decimals are always some fixed width; The rust implementation - // always uses 16 bytes / size of i128 - DataTypeLayout::new_fixed_width(size_of::()) - } - DataType::Decimal256(_, _) => { - // Decimals are 
always some fixed width. - DataTypeLayout::new_fixed_width(32) - } - DataType::Map(_, _) => { - // same as ListType - DataTypeLayout::new_fixed_width(size_of::()) - } } } @@ -1585,10 +1612,13 @@ pub struct DataTypeLayout { } impl DataTypeLayout { - /// Describes a basic numeric array where each element has a fixed width - pub fn new_fixed_width(byte_width: usize) -> Self { + /// Describes a basic numeric array where each element has type `T` + pub fn new_fixed_width() -> Self { Self { - buffers: vec![BufferSpec::FixedWidth { byte_width }], + buffers: vec![BufferSpec::FixedWidth { + byte_width: mem::size_of::(), + alignment: mem::align_of::(), + }], can_contain_null_mask: true, } } @@ -1604,14 +1634,15 @@ impl DataTypeLayout { } /// Describes a basic numeric array where each element has a fixed - /// with offset buffer of `offset_byte_width` bytes, followed by a + /// with offset buffer of type `T`, followed by a /// variable width data buffer - pub fn new_binary(offset_byte_width: usize) -> Self { + pub fn new_binary() -> Self { Self { buffers: vec![ // offsets BufferSpec::FixedWidth { - byte_width: offset_byte_width, + byte_width: mem::size_of::(), + alignment: mem::align_of::(), }, // values BufferSpec::VariableWidth, @@ -1624,8 +1655,17 @@ impl DataTypeLayout { /// Layout specification for a single data type buffer #[derive(Debug, PartialEq, Eq)] pub enum BufferSpec { - /// each element has a fixed width - FixedWidth { byte_width: usize }, + /// Each element is a fixed width primitive, with the given `byte_width` and `alignment` + /// + /// `alignment` is the alignment required by Rust for an array of the corresponding primitive, + /// see [`Layout::array`](std::alloc::Layout::array) and [`std::mem::align_of`]. + /// + /// Arrow-rs requires that all buffers have at least this alignment, to allow for + /// [slice](std::slice) based APIs. Alignment in excess of this is not required to allow + /// for array slicing and interoperability with `Vec`, which cannot be over-aligned. + /// + /// Note that these alignment requirements will vary between architectures + FixedWidth { byte_width: usize, alignment: usize }, /// Variable width, such as string data for utf8 data VariableWidth, /// Buffer holds a bitmap. @@ -1737,6 +1777,15 @@ impl ArrayDataBuilder { /// apply. 
#[allow(clippy::let_and_return)] pub unsafe fn build_unchecked(self) -> ArrayData { + let data = self.build_impl(); + // Provide a force_validate mode + #[cfg(feature = "force_validate")] + data.validate_data().unwrap(); + data + } + + /// Same as [`Self::build_unchecked`] but ignoring `force_validate` feature flag + unsafe fn build_impl(self) -> ArrayData { let nulls = self.nulls.or_else(|| { let buffer = self.null_bit_buffer?; let buffer = BooleanBuffer::new(buffer, self.offset, self.len); @@ -1746,26 +1795,41 @@ impl ArrayDataBuilder { }) }); - let data = ArrayData { + ArrayData { data_type: self.data_type, len: self.len, offset: self.offset, buffers: self.buffers, child_data: self.child_data, nulls: nulls.filter(|b| b.null_count() != 0), - }; - - // Provide a force_validate mode - #[cfg(feature = "force_validate")] - data.validate_data().unwrap(); - data + } } /// Creates an array data, validating all inputs - #[allow(clippy::let_and_return)] pub fn build(self) -> Result { - let data = unsafe { self.build_unchecked() }; - #[cfg(not(feature = "force_validate"))] + let data = unsafe { self.build_impl() }; + data.validate_data()?; + Ok(data) + } + + /// Creates an array data, validating all inputs, and aligning any buffers + /// + /// Rust requires that arrays are aligned to their corresponding primitive, + /// see [`Layout::array`](std::alloc::Layout::array) and [`std::mem::align_of`]. + /// + /// [`ArrayData`] therefore requires that all buffers have at least this alignment, + /// to allow for [slice](std::slice) based APIs. See [`BufferSpec::FixedWidth`]. + /// + /// As this alignment is architecture specific, and not guaranteed by all arrow implementations, + /// this method is provided to automatically copy buffers to a new correctly aligned allocation + /// when necessary, making it useful when interacting with buffers produced by other systems, + /// e.g. IPC or FFI. + /// + /// This is unlike `[Self::build`] which will instead return an error on encountering + /// insufficiently aligned buffers. + pub fn build_aligned(self) -> Result { + let mut data = unsafe { self.build_impl() }; + data.align_buffers(); data.validate_data()?; Ok(data) } @@ -2053,4 +2117,31 @@ mod tests { assert_eq!(buffers.len(), layout.buffers.len()); } } + + #[test] + fn test_alignment() { + let buffer = Buffer::from_vec(vec![1_i32, 2_i32, 3_i32]); + let sliced = buffer.slice(1); + + let mut data = ArrayData { + data_type: DataType::Int32, + len: 0, + offset: 0, + buffers: vec![buffer], + child_data: vec![], + nulls: None, + }; + data.validate_full().unwrap(); + + data.buffers[0] = sliced; + let err = data.validate().unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: Misaligned buffers[0] in array of type Int32, offset from expected alignment of 4 by 1" + ); + + data.align_buffers(); + data.validate_full().unwrap(); + } } diff --git a/arrow-data/src/data/buffers.rs b/arrow-data/src/data/buffers.rs deleted file mode 100644 index 883e92e36d82..000000000000 --- a/arrow-data/src/data/buffers.rs +++ /dev/null @@ -1,96 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow_buffer::Buffer; -use std::iter::Chain; -use std::ops::Index; - -/// A collection of [`Buffer`] -#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)] -pub struct Buffers<'a>([Option<&'a Buffer>; 2]); - -impl<'a> Buffers<'a> { - /// Temporary will be removed once ArrayData does not store `Vec` directly (#3769) - pub(crate) fn from_slice(a: &'a [Buffer]) -> Self { - match a.len() { - 0 => Self([None, None]), - 1 => Self([Some(&a[0]), None]), - _ => Self([Some(&a[0]), Some(&a[1])]), - } - } - - /// Returns the number of [`Buffer`] in this collection - #[inline] - pub fn len(&self) -> usize { - self.0[0].is_some() as usize + self.0[1].is_some() as usize - } - - /// Returns `true` if this collection is empty - #[inline] - pub fn is_empty(&self) -> bool { - self.0[0].is_none() && self.0[1].is_none() - } - - #[inline] - pub fn iter(&self) -> IntoIter<'a> { - self.into_iter() - } - - /// Converts this [`Buffers`] to a `Vec` - #[inline] - pub fn to_vec(&self) -> Vec { - self.iter().cloned().collect() - } -} - -impl<'a> Index for Buffers<'a> { - type Output = &'a Buffer; - - #[inline] - fn index(&self, index: usize) -> &Self::Output { - self.0[index].as_ref().unwrap() - } -} - -impl<'a> IntoIterator for Buffers<'a> { - type Item = &'a Buffer; - type IntoIter = IntoIter<'a>; - - #[inline] - fn into_iter(self) -> Self::IntoIter { - IntoIter(self.0[0].into_iter().chain(self.0[1].into_iter())) - } -} - -type OptionIter<'a> = std::option::IntoIter<&'a Buffer>; - -/// [`Iterator`] for [`Buffers`] -pub struct IntoIter<'a>(Chain, OptionIter<'a>>); - -impl<'a> Iterator for IntoIter<'a> { - type Item = &'a Buffer; - - #[inline] - fn next(&mut self) -> Option { - self.0.next() - } - - #[inline] - fn size_hint(&self) -> (usize, Option) { - self.0.size_hint() - } -} diff --git a/arrow-data/src/equal/fixed_binary.rs b/arrow-data/src/equal/fixed_binary.rs index 9e0e77ff7eca..40dacdddd3a0 100644 --- a/arrow-data/src/equal/fixed_binary.rs +++ b/arrow-data/src/equal/fixed_binary.rs @@ -80,7 +80,7 @@ pub(super) fn fixed_binary_equal( lhs_start + lhs_nulls.offset(), len, ); - let rhs_nulls = lhs.nulls().unwrap(); + let rhs_nulls = rhs.nulls().unwrap(); let rhs_slices_iter = BitSliceIterator::new( rhs_nulls.validity(), rhs_start + rhs_nulls.offset(), diff --git a/arrow-data/src/equal/variable_size.rs b/arrow-data/src/equal/variable_size.rs index ae880437450b..92f00818b4a0 100644 --- a/arrow-data/src/equal/variable_size.rs +++ b/arrow-data/src/equal/variable_size.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
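// Illustrative sketch (not part of this change set): the buffer-alignment handling added to
// arrow-data above. A buffer sliced one byte in is no longer aligned for i32, so build()
// now rejects it, while build_aligned() copies it to a correctly aligned allocation first.
use arrow_buffer::Buffer;
use arrow_data::ArrayData;
use arrow_schema::DataType;

fn align_sketch() {
    let misaligned = Buffer::from_vec(vec![1_i32, 2, 3]).slice(1);
    let data = ArrayData::builder(DataType::Int32)
        .len(2)
        .add_buffer(misaligned)
        .build_aligned() // plain build() would return a "Misaligned buffers" error instead
        .unwrap();
    assert_eq!(data.len(), 2);
}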
-use crate::data::{count_nulls, ArrayData}; +use crate::data::{contains_nulls, ArrayData}; use arrow_buffer::ArrowNativeType; use num::Integer; @@ -59,14 +59,9 @@ pub(super) fn variable_sized_equal( let lhs_values = lhs.buffers()[1].as_slice(); let rhs_values = rhs.buffers()[1].as_slice(); - let lhs_null_count = count_nulls(lhs.nulls(), lhs_start, len); - let rhs_null_count = count_nulls(rhs.nulls(), rhs_start, len); - - if lhs_null_count == 0 - && rhs_null_count == 0 - && !lhs_values.is_empty() - && !rhs_values.is_empty() - { + // Only checking one null mask here because by the time the control flow reaches + // this point, the equality of the two masks would have already been verified. + if !contains_nulls(lhs.nulls(), lhs_start, len) { offset_value_equal( lhs_values, rhs_values, diff --git a/arrow-data/src/transform/union.rs b/arrow-data/src/transform/union.rs index 8d1ea34c314d..d7083588d782 100644 --- a/arrow-data/src/transform/union.rs +++ b/arrow-data/src/transform/union.rs @@ -39,6 +39,9 @@ pub(super) fn build_extend_sparse(array: &ArrayData) -> Extend { pub(super) fn build_extend_dense(array: &ArrayData) -> Extend { let type_ids = array.buffer::(0); let offsets = array.buffer::(1); + let arrow_schema::DataType::Union(src_fields, _) = array.data_type() else { + unreachable!(); + }; Box::new( move |mutable: &mut _MutableArrayData, index: usize, start: usize, len: usize| { @@ -48,14 +51,18 @@ pub(super) fn build_extend_dense(array: &ArrayData) -> Extend { .extend_from_slice(&type_ids[start..start + len]); (start..start + len).for_each(|i| { - let type_id = type_ids[i] as usize; + let type_id = type_ids[i]; + let child_index = src_fields + .iter() + .position(|(r, _)| r == type_id) + .expect("invalid union type ID"); let src_offset = offsets[i] as usize; - let child_data = &mut mutable.child_data[type_id]; + let child_data = &mut mutable.child_data[child_index]; let dst_offset = child_data.len(); // Extend offsets mutable.buffer2.push(dst_offset as i32); - mutable.child_data[type_id].extend(index, src_offset, src_offset + 1) + mutable.child_data[child_index].extend(index, src_offset, src_offset + 1) }) }, ) diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml index ae9759b6685f..70227eedea0e 100644 --- a/arrow-flight/Cargo.toml +++ b/arrow-flight/Cargo.toml @@ -20,7 +20,7 @@ name = "arrow-flight" description = "Apache Arrow Flight" version = { workspace = true } edition = { workspace = true } -rust-version = { workspace = true } +rust-version = "1.70.0" authors = { workspace = true } homepage = { workspace = true } repository = { workspace = true } @@ -44,14 +44,15 @@ bytes = { version = "1", default-features = false } futures = { version = "0.3", default-features = false, features = ["alloc"] } once_cell = { version = "1", optional = true } paste = { version = "1.0" } -prost = { version = "0.11", default-features = false, features = ["prost-derive"] } +prost = { version = "0.12.1", default-features = false, features = ["prost-derive"] } tokio = { version = "1.0", default-features = false, features = ["macros", "rt", "rt-multi-thread"] } -tonic = { version = "0.9", default-features = false, features = ["transport", "codegen", "prost"] } +tonic = { version = "0.10.0", default-features = false, features = ["transport", "codegen", "prost"] } # CLI-related dependencies -clap = { version = "4.1", default-features = false, features = ["std", "derive", "env", "help", "error-context", "usage"], optional = true } +anyhow = { version = "1.0", optional = true } +clap = { version = 
"4.4.6", default-features = false, features = ["std", "derive", "env", "help", "error-context", "usage", "wrap_help", "color", "suggestions"], optional = true } tracing-log = { version = "0.1", optional = true } -tracing-subscriber = { version = "0.3.1", default-features = false, features = ["ansi", "fmt"], optional = true } +tracing-subscriber = { version = "0.3.1", default-features = false, features = ["ansi", "env-filter", "fmt"], optional = true } [package.metadata.docs.rs] all-features = true @@ -62,11 +63,14 @@ flight-sql-experimental = ["arrow-arith", "arrow-data", "arrow-ord", "arrow-row" tls = ["tonic/tls"] # Enable CLI tools -cli = ["arrow-cast/prettyprint", "clap", "tracing-log", "tracing-subscriber", "tonic/tls-webpki-roots"] +cli = ["anyhow", "arrow-cast/prettyprint", "clap", "tracing-log", "tracing-subscriber", "tonic/tls-webpki-roots"] [dev-dependencies] arrow-cast = { workspace = true, features = ["prettyprint"] } assert_cmd = "2.0.8" +http = "0.2.9" +http-body = "0.4.5" +pin-project-lite = "0.2" tempfile = "3.3" tokio-stream = { version = "0.1", features = ["net"] } tower = "0.4.13" diff --git a/arrow-flight/README.md b/arrow-flight/README.md index 9194b209fe72..b80772ac927e 100644 --- a/arrow-flight/README.md +++ b/arrow-flight/README.md @@ -44,5 +44,33 @@ that demonstrate how to build a Flight server implemented with [tonic](https://d ## Feature Flags - `flight-sql-experimental`: Enables experimental support for - [Apache Arrow FlightSQL](https://arrow.apache.org/docs/format/FlightSql.html), - a protocol for interacting with SQL databases. + [Apache Arrow FlightSQL], a protocol for interacting with SQL databases. + +## CLI + +This crates offers a basic [Apache Arrow FlightSQL] command line interface. + +The client can be installed from the repository: + +```console +$ cargo install --features=cli,flight-sql-experimental,tls --bin=flight_sql_client --path=. --locked +``` + +The client comes with extensive help text: + +```console +$ flight_sql_client help +``` + +A query can be executed using: + +```console +$ flight_sql_client --host example.com statement-query "SELECT 1;" ++----------+ +| Int64(1) | ++----------+ +| 1 | ++----------+ +``` + +[apache arrow flightsql]: https://arrow.apache.org/docs/format/FlightSql.html diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs index f717d9b621b2..013f7e7788f8 100644 --- a/arrow-flight/examples/flight_sql_server.rs +++ b/arrow-flight/examples/flight_sql_server.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use arrow_flight::sql::server::PeekableFlightDataStream; use base64::prelude::BASE64_STANDARD; use base64::Engine; use futures::{stream, Stream, TryStreamExt}; @@ -196,9 +197,9 @@ impl FlightSqlService for FlightSqlServiceImpl { self.check_token(&request)?; let batch = Self::fake_result().map_err(|e| status!("Could not fake a result", e))?; - let schema = (*batch.schema()).clone(); + let schema = batch.schema(); let batches = vec![batch]; - let flight_data = batches_to_flight_data(schema, batches) + let flight_data = batches_to_flight_data(schema.as_ref(), batches) .map_err(|e| status!("Could not convert batches", e))? 
.into_iter() .map(Ok); @@ -602,7 +603,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_statement_update( &self, _ticket: CommandStatementUpdate, - _request: Request>, + _request: Request, ) -> Result { Ok(FAKE_UPDATE_RESULT) } @@ -610,7 +611,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_substrait_plan( &self, _ticket: CommandStatementSubstraitPlan, - _request: Request>, + _request: Request, ) -> Result { Err(Status::unimplemented( "do_put_substrait_plan not implemented", @@ -620,7 +621,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_prepared_statement_query( &self, _query: CommandPreparedStatementQuery, - _request: Request>, + _request: Request, ) -> Result::DoPutStream>, Status> { Err(Status::unimplemented( "do_put_prepared_statement_query not implemented", @@ -630,7 +631,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_prepared_statement_update( &self, _query: CommandPreparedStatementUpdate, - _request: Request>, + _request: Request, ) -> Result { Err(Status::unimplemented( "do_put_prepared_statement_update not implemented", @@ -788,7 +789,6 @@ mod tests { use arrow_cast::pretty::pretty_format_batches; use arrow_flight::sql::client::FlightSqlServiceClient; - use arrow_flight::utils::flight_data_to_batches; use tonic::transport::server::TcpIncoming; use tonic::transport::{Certificate, Endpoint}; use tower::service_fn; @@ -802,7 +802,7 @@ mod tests { fn endpoint(uri: String) -> Result { let endpoint = Endpoint::new(uri) - .map_err(|_| ArrowError::IoError("Cannot create endpoint".to_string()))? + .map_err(|_| ArrowError::IpcError("Cannot create endpoint".to_string()))? .connect_timeout(Duration::from_secs(20)) .timeout(Duration::from_secs(20)) .tcp_nodelay(true) // Disable Nagle's Algorithm since we don't want packets to wait @@ -954,8 +954,7 @@ mod tests { let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap().clone(); let flight_data = client.do_get(ticket).await.unwrap(); - let flight_data: Vec = flight_data.try_collect().await.unwrap(); - let batches = flight_data_to_batches(&flight_data).unwrap(); + let batches: Vec<_> = flight_data.try_collect().await.unwrap(); let res = pretty_format_batches(batches.as_slice()).unwrap(); let expected = r#" diff --git a/arrow-flight/examples/server.rs b/arrow-flight/examples/server.rs index 1d473103af8e..1ed21acef9b8 100644 --- a/arrow-flight/examples/server.rs +++ b/arrow-flight/examples/server.rs @@ -15,9 +15,7 @@ // specific language governing permissions and limitations // under the License. 
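The `do_get` test change above now collects decoded `RecordBatch`es straight from the stream instead of gathering `FlightData` and calling `flight_data_to_batches`. A minimal sketch of the same pattern in client code (not part of this patch), assuming an already connected `FlightSqlServiceClient<Channel>` and a `Ticket` obtained from a prior `execute`/`get_flight_info` call; error handling via `expect` mirrors the test:

```rust
// Minimal sketch (not part of this patch): with the updated client, the
// stream returned by `do_get` yields decoded `RecordBatch`es, so it can be
// collected directly without a `flight_data_to_batches` step.
// Assumes `client` is an already connected `FlightSqlServiceClient<Channel>`
// and `ticket` came from a prior `execute`/`get_flight_info` call.
use arrow_array::RecordBatch;
use arrow_flight::sql::client::FlightSqlServiceClient;
use arrow_flight::Ticket;
use futures::TryStreamExt;
use tonic::transport::Channel;

async fn collect_batches(
    client: &mut FlightSqlServiceClient<Channel>,
    ticket: Ticket,
) -> Vec<RecordBatch> {
    let stream = client.do_get(ticket).await.expect("do_get");
    // `FlightRecordBatchStream` is a `Stream` of `Result<RecordBatch, _>`,
    // so `try_collect` yields the batches (or the first error) directly.
    stream.try_collect().await.expect("collect record batches")
}
```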
-use std::pin::Pin; - -use futures::Stream; +use futures::stream::BoxStream; use tonic::transport::Server; use tonic::{Request, Response, Status, Streaming}; @@ -32,27 +30,13 @@ pub struct FlightServiceImpl {} #[tonic::async_trait] impl FlightService for FlightServiceImpl { - type HandshakeStream = Pin< - Box> + Send + Sync + 'static>, - >; - type ListFlightsStream = - Pin> + Send + Sync + 'static>>; - type DoGetStream = - Pin> + Send + Sync + 'static>>; - type DoPutStream = - Pin> + Send + Sync + 'static>>; - type DoActionStream = Pin< - Box< - dyn Stream> - + Send - + Sync - + 'static, - >, - >; - type ListActionsStream = - Pin> + Send + Sync + 'static>>; - type DoExchangeStream = - Pin> + Send + Sync + 'static>>; + type HandshakeStream = BoxStream<'static, Result>; + type ListFlightsStream = BoxStream<'static, Result>; + type DoGetStream = BoxStream<'static, Result>; + type DoPutStream = BoxStream<'static, Result>; + type DoActionStream = BoxStream<'static, Result>; + type ListActionsStream = BoxStream<'static, Result>; + type DoExchangeStream = BoxStream<'static, Result>; async fn handshake( &self, diff --git a/arrow-flight/gen/Cargo.toml b/arrow-flight/gen/Cargo.toml index 743df85dc800..036281528c19 100644 --- a/arrow-flight/gen/Cargo.toml +++ b/arrow-flight/gen/Cargo.toml @@ -32,6 +32,6 @@ publish = false [dependencies] # Pin specific version of the tonic-build dependencies to avoid auto-generated # (and checked in) arrow.flight.protocol.rs from changing -proc-macro2 = { version = "=1.0.63", default-features = false } -prost-build = { version = "=0.11.9", default-features = false } -tonic-build = { version = "=0.9.2", default-features = false, features = ["transport", "prost"] } +proc-macro2 = { version = "=1.0.69", default-features = false } +prost-build = { version = "=0.12.1", default-features = false } +tonic-build = { version = "=0.10.2", default-features = false, features = ["transport", "prost"] } diff --git a/arrow-flight/src/arrow.flight.protocol.rs b/arrow-flight/src/arrow.flight.protocol.rs index 10dc7ace0356..e76013bd7c5f 100644 --- a/arrow-flight/src/arrow.flight.protocol.rs +++ b/arrow-flight/src/arrow.flight.protocol.rs @@ -685,7 +685,7 @@ pub mod flight_service_server { #[async_trait] pub trait FlightService: Send + Sync + 'static { /// Server streaming response type for the Handshake method. - type HandshakeStream: futures_core::Stream< + type HandshakeStream: tonic::codegen::tokio_stream::Stream< Item = std::result::Result, > + Send @@ -700,7 +700,7 @@ pub mod flight_service_server { request: tonic::Request>, ) -> std::result::Result, tonic::Status>; /// Server streaming response type for the ListFlights method. - type ListFlightsStream: futures_core::Stream< + type ListFlightsStream: tonic::codegen::tokio_stream::Stream< Item = std::result::Result, > + Send @@ -744,7 +744,7 @@ pub mod flight_service_server { request: tonic::Request, ) -> std::result::Result, tonic::Status>; /// Server streaming response type for the DoGet method. - type DoGetStream: futures_core::Stream< + type DoGetStream: tonic::codegen::tokio_stream::Stream< Item = std::result::Result, > + Send @@ -759,7 +759,7 @@ pub mod flight_service_server { request: tonic::Request, ) -> std::result::Result, tonic::Status>; /// Server streaming response type for the DoPut method. 
- type DoPutStream: futures_core::Stream< + type DoPutStream: tonic::codegen::tokio_stream::Stream< Item = std::result::Result, > + Send @@ -776,7 +776,7 @@ pub mod flight_service_server { request: tonic::Request>, ) -> std::result::Result, tonic::Status>; /// Server streaming response type for the DoExchange method. - type DoExchangeStream: futures_core::Stream< + type DoExchangeStream: tonic::codegen::tokio_stream::Stream< Item = std::result::Result, > + Send @@ -792,7 +792,7 @@ pub mod flight_service_server { request: tonic::Request>, ) -> std::result::Result, tonic::Status>; /// Server streaming response type for the DoAction method. - type DoActionStream: futures_core::Stream< + type DoActionStream: tonic::codegen::tokio_stream::Stream< Item = std::result::Result, > + Send @@ -809,7 +809,7 @@ pub mod flight_service_server { request: tonic::Request, ) -> std::result::Result, tonic::Status>; /// Server streaming response type for the ListActions method. - type ListActionsStream: futures_core::Stream< + type ListActionsStream: tonic::codegen::tokio_stream::Stream< Item = std::result::Result, > + Send @@ -930,7 +930,9 @@ pub mod flight_service_server { >, ) -> Self::Future { let inner = Arc::clone(&self.0); - let fut = async move { (*inner).handshake(request).await }; + let fut = async move { + ::handshake(&inner, request).await + }; Box::pin(fut) } } @@ -976,7 +978,7 @@ pub mod flight_service_server { ) -> Self::Future { let inner = Arc::clone(&self.0); let fut = async move { - (*inner).list_flights(request).await + ::list_flights(&inner, request).await }; Box::pin(fut) } @@ -1022,7 +1024,7 @@ pub mod flight_service_server { ) -> Self::Future { let inner = Arc::clone(&self.0); let fut = async move { - (*inner).get_flight_info(request).await + ::get_flight_info(&inner, request).await }; Box::pin(fut) } @@ -1067,7 +1069,9 @@ pub mod flight_service_server { request: tonic::Request, ) -> Self::Future { let inner = Arc::clone(&self.0); - let fut = async move { (*inner).get_schema(request).await }; + let fut = async move { + ::get_schema(&inner, request).await + }; Box::pin(fut) } } @@ -1112,7 +1116,9 @@ pub mod flight_service_server { request: tonic::Request, ) -> Self::Future { let inner = Arc::clone(&self.0); - let fut = async move { (*inner).do_get(request).await }; + let fut = async move { + ::do_get(&inner, request).await + }; Box::pin(fut) } } @@ -1157,7 +1163,9 @@ pub mod flight_service_server { request: tonic::Request>, ) -> Self::Future { let inner = Arc::clone(&self.0); - let fut = async move { (*inner).do_put(request).await }; + let fut = async move { + ::do_put(&inner, request).await + }; Box::pin(fut) } } @@ -1202,7 +1210,9 @@ pub mod flight_service_server { request: tonic::Request>, ) -> Self::Future { let inner = Arc::clone(&self.0); - let fut = async move { (*inner).do_exchange(request).await }; + let fut = async move { + ::do_exchange(&inner, request).await + }; Box::pin(fut) } } @@ -1247,7 +1257,9 @@ pub mod flight_service_server { request: tonic::Request, ) -> Self::Future { let inner = Arc::clone(&self.0); - let fut = async move { (*inner).do_action(request).await }; + let fut = async move { + ::do_action(&inner, request).await + }; Box::pin(fut) } } @@ -1293,7 +1305,7 @@ pub mod flight_service_server { ) -> Self::Future { let inner = Arc::clone(&self.0); let fut = async move { - (*inner).list_actions(request).await + ::list_actions(&inner, request).await }; Box::pin(fut) } diff --git a/arrow-flight/src/bin/flight_sql_client.rs b/arrow-flight/src/bin/flight_sql_client.rs 
index e5aacc2e779a..296efc1c308e 100644 --- a/arrow-flight/src/bin/flight_sql_client.rs +++ b/arrow-flight/src/bin/flight_sql_client.rs @@ -17,62 +17,58 @@ use std::{sync::Arc, time::Duration}; -use arrow_array::RecordBatch; -use arrow_cast::pretty::pretty_format_batches; -use arrow_flight::{ - sql::client::FlightSqlServiceClient, utils::flight_data_to_batches, FlightData, -}; -use arrow_schema::{ArrowError, Schema}; -use clap::Parser; +use anyhow::{bail, Context, Result}; +use arrow_array::{ArrayRef, Datum, RecordBatch, StringArray}; +use arrow_cast::{cast_with_options, pretty::pretty_format_batches, CastOptions}; +use arrow_flight::{sql::client::FlightSqlServiceClient, FlightInfo}; +use arrow_schema::Schema; +use clap::{Parser, Subcommand}; use futures::TryStreamExt; -use tonic::transport::{Channel, ClientTlsConfig, Endpoint}; +use tonic::{ + metadata::MetadataMap, + transport::{Channel, ClientTlsConfig, Endpoint}, +}; use tracing_log::log::info; -/// A ':' separated key value pair -#[derive(Debug, Clone)] -struct KeyValue { - pub key: K, - pub value: V, -} - -impl std::str::FromStr for KeyValue -where - K: std::str::FromStr, - V: std::str::FromStr, - K::Err: std::fmt::Display, - V::Err: std::fmt::Display, -{ - type Err = String; - - fn from_str(s: &str) -> std::result::Result { - let parts = s.splitn(2, ':').collect::>(); - match parts.as_slice() { - [key, value] => { - let key = K::from_str(key).map_err(|e| e.to_string())?; - let value = V::from_str(value.trim()).map_err(|e| e.to_string())?; - Ok(Self { key, value }) - } - _ => Err(format!( - "Invalid key value pair - expected 'KEY:VALUE' got '{s}'" - )), - } - } +/// Logging CLI config. +#[derive(Debug, Parser)] +pub struct LoggingArgs { + /// Log verbosity. + /// + /// Defaults to "warn". + /// + /// Use `-v` for "info", `-vv` for "debug", `-vvv` for "trace". + /// + /// Note you can also set logging level using `RUST_LOG` environment variable: + /// `RUST_LOG=debug`. + #[clap( + short = 'v', + long = "verbose", + action = clap::ArgAction::Count, + )] + log_verbose_count: u8, } #[derive(Debug, Parser)] struct ClientArgs { /// Additional headers. /// - /// Values should be key value pairs separated by ':' - #[clap(long, value_delimiter = ',')] - headers: Vec>, + /// Can be given multiple times. Headers and values are separated by '='. + /// + /// Example: `-H foo=bar -H baz=42` + #[clap(long = "header", short = 'H', value_parser = parse_key_val)] + headers: Vec<(String, String)>, - /// Username - #[clap(long)] + /// Username. + /// + /// Optional. If given, `password` must also be set. + #[clap(long, requires = "password")] username: Option, - /// Password - #[clap(long)] + /// Password. + /// + /// Optional. If given, `username` must also be set. + #[clap(long, requires = "username")] password: Option, /// Auth token. @@ -80,78 +76,199 @@ struct ClientArgs { token: Option, /// Use TLS. + /// + /// If not provided, use cleartext connection. #[clap(long)] tls: bool, /// Server host. + /// + /// Required. #[clap(long)] host: String, /// Server port. + /// + /// Defaults to `443` if `tls` is set, otherwise defaults to `80`. #[clap(long)] port: Option, } #[derive(Debug, Parser)] struct Args { + /// Logging args. + #[clap(flatten)] + logging_args: LoggingArgs, + /// Client args. #[clap(flatten)] client_args: ClientArgs, - /// SQL query. - query: String, + #[clap(subcommand)] + cmd: Command, +} + +/// Different available commands. +#[derive(Debug, Subcommand)] +enum Command { + /// Execute given statement. 
+ StatementQuery { + /// SQL query. + /// + /// Required. + query: String, + }, + + /// Prepare given statement and then execute it. + PreparedStatementQuery { + /// SQL query. + /// + /// Required. + /// + /// Can contains placeholders like `$1`. + /// + /// Example: `SELECT * FROM t WHERE x = $1` + query: String, + + /// Additional parameters. + /// + /// Can be given multiple times. Names and values are separated by '='. Values will be + /// converted to the type that the server reported for the prepared statement. + /// + /// Example: `-p $1=42` + #[clap(short, value_parser = parse_key_val)] + params: Vec<(String, String)>, + }, } #[tokio::main] -async fn main() { +async fn main() -> Result<()> { let args = Args::parse(); - setup_logging(); - let mut client = setup_client(args.client_args).await.expect("setup client"); + setup_logging(args.logging_args)?; + let mut client = setup_client(args.client_args) + .await + .context("setup client")?; + + let flight_info = match args.cmd { + Command::StatementQuery { query } => client + .execute(query, None) + .await + .context("execute statement")?, + Command::PreparedStatementQuery { query, params } => { + let mut prepared_stmt = client + .prepare(query, None) + .await + .context("prepare statement")?; + + if !params.is_empty() { + prepared_stmt + .set_parameters( + construct_record_batch_from_params( + ¶ms, + prepared_stmt + .parameter_schema() + .context("get parameter schema")?, + ) + .context("construct parameters")?, + ) + .context("bind parameters")?; + } - let info = client - .execute(args.query, None) + prepared_stmt + .execute() + .await + .context("execute prepared statement")? + } + }; + + let batches = execute_flight(&mut client, flight_info) .await - .expect("prepare statement"); - info!("got flight info"); + .context("read flight data")?; + + let res = pretty_format_batches(batches.as_slice()).context("format results")?; + println!("{res}"); + + Ok(()) +} - let schema = Arc::new(Schema::try_from(info.clone()).expect("valid schema")); +async fn execute_flight( + client: &mut FlightSqlServiceClient, + info: FlightInfo, +) -> Result> { + let schema = Arc::new(Schema::try_from(info.clone()).context("valid schema")?); let mut batches = Vec::with_capacity(info.endpoint.len() + 1); batches.push(RecordBatch::new_empty(schema)); info!("decoded schema"); for endpoint in info.endpoint { let Some(ticket) = &endpoint.ticket else { - panic!("did not get ticket"); + bail!("did not get ticket"); }; - let flight_data = client.do_get(ticket.clone()).await.expect("do get"); - let flight_data: Vec = flight_data + + let mut flight_data = client.do_get(ticket.clone()).await.context("do get")?; + log_metadata(flight_data.headers(), "header"); + + let mut endpoint_batches: Vec<_> = (&mut flight_data) .try_collect() .await - .expect("collect data stream"); - let mut endpoint_batches = flight_data_to_batches(&flight_data) - .expect("convert flight data to record batches"); + .context("collect data stream")?; batches.append(&mut endpoint_batches); + + if let Some(trailers) = flight_data.trailers() { + log_metadata(&trailers, "trailer"); + } } info!("received data"); - let res = pretty_format_batches(batches.as_slice()).expect("format results"); - println!("{res}"); + Ok(batches) +} + +fn construct_record_batch_from_params( + params: &[(String, String)], + parameter_schema: &Schema, +) -> Result { + let mut items = Vec::<(&String, ArrayRef)>::new(); + + for (name, value) in params { + let field = parameter_schema.field_with_name(name)?; + let value_as_array 
= StringArray::new_scalar(value); + let casted = cast_with_options( + value_as_array.get().0, + field.data_type(), + &CastOptions::default(), + )?; + items.push((name, casted)) + } + + Ok(RecordBatch::try_from_iter(items)?) } -fn setup_logging() { - tracing_log::LogTracer::init().expect("tracing log init"); - tracing_subscriber::fmt::init(); +fn setup_logging(args: LoggingArgs) -> Result<()> { + use tracing_subscriber::{util::SubscriberInitExt, EnvFilter, FmtSubscriber}; + + tracing_log::LogTracer::init().context("tracing log init")?; + + let filter = match args.log_verbose_count { + 0 => "warn", + 1 => "info", + 2 => "debug", + _ => "trace", + }; + let filter = EnvFilter::try_new(filter).context("set up log env filter")?; + + let subscriber = FmtSubscriber::builder().with_env_filter(filter).finish(); + subscriber.try_init().context("init logging subscriber")?; + + Ok(()) } -async fn setup_client( - args: ClientArgs, -) -> Result, ArrowError> { +async fn setup_client(args: ClientArgs) -> Result> { let port = args.port.unwrap_or(if args.tls { 443 } else { 80 }); let protocol = if args.tls { "https" } else { "http" }; let mut endpoint = Endpoint::new(format!("{}://{}:{}", protocol, args.host, port)) - .map_err(|_| ArrowError::IoError("Cannot create endpoint".to_string()))? + .context("create endpoint")? .connect_timeout(Duration::from_secs(20)) .timeout(Duration::from_secs(20)) .tcp_nodelay(true) // Disable Nagle's Algorithm since we don't want packets to wait @@ -164,19 +281,16 @@ async fn setup_client( let tls_config = ClientTlsConfig::new(); endpoint = endpoint .tls_config(tls_config) - .map_err(|_| ArrowError::IoError("Cannot create TLS endpoint".to_string()))?; + .context("create TLS endpoint")?; } - let channel = endpoint - .connect() - .await - .map_err(|e| ArrowError::IoError(format!("Cannot connect to endpoint: {e}")))?; + let channel = endpoint.connect().await.context("connect to endpoint")?; let mut client = FlightSqlServiceClient::new(channel); info!("connected"); - for kv in args.headers { - client.set_header(kv.key, kv.value); + for (k, v) in args.headers { + client.set_header(k, v); } if let Some(token) = args.token { @@ -190,16 +304,48 @@ async fn setup_client( client .handshake(&username, &password) .await - .expect("handshake"); + .context("handshake")?; info!("performed handshake"); } (Some(_), None) => { - panic!("when username is set, you also need to set a password") + bail!("when username is set, you also need to set a password") } (None, Some(_)) => { - panic!("when password is set, you also need to set a username") + bail!("when password is set, you also need to set a username") } } Ok(client) } + +/// Parse a single key-value pair +fn parse_key_val(s: &str) -> Result<(String, String), String> { + let pos = s + .find('=') + .ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?; + Ok((s[..pos].to_owned(), s[pos + 1..].to_owned())) +} + +/// Log headers/trailers. 
+fn log_metadata(map: &MetadataMap, what: &'static str) { + for k_v in map.iter() { + match k_v { + tonic::metadata::KeyAndValueRef::Ascii(k, v) => { + info!( + "{}: {}={}", + what, + k.as_str(), + v.to_str().unwrap_or(""), + ); + } + tonic::metadata::KeyAndValueRef::Binary(k, v) => { + info!( + "{}: {}={}", + what, + k.as_str(), + String::from_utf8_lossy(v.as_ref()), + ); + } + } + } +} diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs index f843bbf7cd0c..8793f7834bfb 100644 --- a/arrow-flight/src/client.rs +++ b/arrow-flight/src/client.rs @@ -18,9 +18,9 @@ use std::task::Poll; use crate::{ - decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, Action, - ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, - HandshakeRequest, PutResult, Ticket, + decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, + trailers::extract_lazy_trailers, Action, ActionType, Criteria, Empty, FlightData, + FlightDescriptor, FlightInfo, HandshakeRequest, PutResult, Ticket, }; use arrow_schema::Schema; use bytes::Bytes; @@ -74,7 +74,7 @@ pub struct FlightClient { } impl FlightClient { - /// Creates a client client with the provided [`Channel`](tonic::transport::Channel) + /// Creates a client client with the provided [`Channel`] pub fn new(channel: Channel) -> Self { Self::new_from_inner(FlightServiceClient::new(channel)) } @@ -204,16 +204,14 @@ impl FlightClient { pub async fn do_get(&mut self, ticket: Ticket) -> Result { let request = self.make_request(ticket); - let response_stream = self - .inner - .do_get(request) - .await? - .into_inner() - .map_err(FlightError::Tonic); + let (md, response_stream, _ext) = self.inner.do_get(request).await?.into_parts(); + let (response_stream, trailers) = extract_lazy_trailers(response_stream); Ok(FlightRecordBatchStream::new_from_flight_data( - response_stream, - )) + response_stream.map_err(FlightError::Tonic), + ) + .with_headers(md) + .with_trailers(trailers)) } /// Make a `GetFlightInfo` call to the server with the provided @@ -262,7 +260,7 @@ impl FlightClient { } /// Make a `DoPut` call to the server with the provided - /// [`Stream`](futures::Stream) of [`FlightData`] and returning a + /// [`Stream`] of [`FlightData`] and returning a /// stream of [`PutResult`]. /// /// # Note @@ -340,7 +338,7 @@ impl FlightClient { } /// Make a `DoExchange` call to the server with the provided - /// [`Stream`](futures::Stream) of [`FlightData`] and returning a + /// [`Stream`] of [`FlightData`] and returning a /// stream of [`FlightData`]. /// /// # Example: @@ -391,7 +389,7 @@ impl FlightClient { } /// Make a `ListFlights` call to the server with the provided - /// criteria and returning a [`Stream`](futures::Stream) of [`FlightInfo`]. + /// criteria and returning a [`Stream`] of [`FlightInfo`]. /// /// # Example: /// ```no_run @@ -469,7 +467,7 @@ impl FlightClient { } /// Make a `ListActions` call to the server and returning a - /// [`Stream`](futures::Stream) of [`ActionType`]. + /// [`Stream`] of [`ActionType`]. /// /// # Example: /// ```no_run @@ -506,7 +504,7 @@ impl FlightClient { } /// Make a `DoAction` call to the server and returning a - /// [`Stream`](futures::Stream) of opaque [`Bytes`]. + /// [`Stream`] of opaque [`Bytes`]. 
/// /// # Example: /// ```no_run diff --git a/arrow-flight/src/decode.rs b/arrow-flight/src/decode.rs index fe132e3e8448..dfcdd260602c 100644 --- a/arrow-flight/src/decode.rs +++ b/arrow-flight/src/decode.rs @@ -15,14 +15,16 @@ // specific language governing permissions and limitations // under the License. -use crate::{utils::flight_data_to_arrow_batch, FlightData}; +use crate::{trailers::LazyTrailers, utils::flight_data_to_arrow_batch, FlightData}; use arrow_array::{ArrayRef, RecordBatch}; +use arrow_buffer::Buffer; use arrow_schema::{Schema, SchemaRef}; use bytes::Bytes; use futures::{ready, stream::BoxStream, Stream, StreamExt}; use std::{ collections::HashMap, convert::TryFrom, fmt::Debug, pin::Pin, sync::Arc, task::Poll, }; +use tonic::metadata::MetadataMap; use crate::error::{FlightError, Result}; @@ -81,13 +83,23 @@ use crate::error::{FlightError, Result}; /// ``` #[derive(Debug)] pub struct FlightRecordBatchStream { + /// Optional grpc header metadata. + headers: MetadataMap, + + /// Optional grpc trailer metadata. + trailers: Option, + inner: FlightDataDecoder, } impl FlightRecordBatchStream { /// Create a new [`FlightRecordBatchStream`] from a decoded stream pub fn new(inner: FlightDataDecoder) -> Self { - Self { inner } + Self { + inner, + headers: MetadataMap::default(), + trailers: None, + } } /// Create a new [`FlightRecordBatchStream`] from a stream of [`FlightData`] @@ -97,9 +109,37 @@ impl FlightRecordBatchStream { { Self { inner: FlightDataDecoder::new(inner), + headers: MetadataMap::default(), + trailers: None, + } + } + + /// Record response headers. + pub fn with_headers(self, headers: MetadataMap) -> Self { + Self { headers, ..self } + } + + /// Record response trailers. + pub fn with_trailers(self, trailers: LazyTrailers) -> Self { + Self { + trailers: Some(trailers), + ..self } } + /// Headers attached to this stream. + pub fn headers(&self) -> &MetadataMap { + &self.headers + } + + /// Trailers attached to this stream. + /// + /// Note that this will return `None` until the entire stream is consumed. + /// Only after calling `next()` returns `None`, might any available trailers be returned. + pub fn trailers(&self) -> Option { + self.trailers.as_ref().and_then(|trailers| trailers.get()) + } + /// Has a message defining the schema been received yet? #[deprecated = "use schema().is_some() instead"] pub fn got_schema(&self) -> bool { @@ -116,6 +156,7 @@ impl FlightRecordBatchStream { self.inner } } + impl futures::Stream for FlightRecordBatchStream { type Item = Result; @@ -258,7 +299,7 @@ impl FlightDataDecoder { )); }; - let buffer: arrow_buffer::Buffer = data.data_body.into(); + let buffer = Buffer::from_bytes(data.data_body.into()); let dictionary_batch = message.header_as_dictionary_batch().ok_or_else(|| { FlightError::protocol( diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index 9650031d8b5f..9ae7f1637982 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -24,16 +24,23 @@ use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef}; use bytes::Bytes; use futures::{ready, stream::BoxStream, Stream, StreamExt}; -/// Creates a [`Stream`](futures::Stream) of [`FlightData`]s from a +/// Creates a [`Stream`] of [`FlightData`]s from a /// `Stream` of [`Result`]<[`RecordBatch`], [`FlightError`]>. 
/// /// This can be used to implement [`FlightService::do_get`] in an /// Arrow Flight implementation; /// +/// This structure encodes a stream of `Result`s rather than `RecordBatch`es to +/// propagate errors from streaming execution, where the generation of the +/// `RecordBatch`es is incremental, and an error may occur even after +/// several have already been successfully produced. +/// /// # Caveats -/// 1. [`DictionaryArray`](arrow_array::array::DictionaryArray)s -/// are converted to their underlying types prior to transport, due to -/// . +/// 1. When [`DictionaryHandling`] is [`DictionaryHandling::Hydrate`], [`DictionaryArray`](arrow_array::array::DictionaryArray)s +/// are converted to their underlying types prior to transport. +/// When [`DictionaryHandling`] is [`DictionaryHandling::Resend`], Dictionary [`FlightData`] is sent with every +/// [`RecordBatch`] that contains a [`DictionaryArray`](arrow_array::array::DictionaryArray). +/// See . /// /// # Example /// ```no_run @@ -41,14 +48,14 @@ use futures::{ready, stream::BoxStream, Stream, StreamExt}; /// # use arrow_array::{ArrayRef, RecordBatch, UInt32Array}; /// # async fn f() { /// # let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); -/// # let record_batch = RecordBatch::try_from_iter(vec![ +/// # let batch = RecordBatch::try_from_iter(vec![ /// # ("a", Arc::new(c1) as ArrayRef) /// # ]) /// # .expect("cannot create record batch"); /// use arrow_flight::encode::FlightDataEncoderBuilder; /// /// // Get an input stream of Result -/// let input_stream = futures::stream::iter(vec![Ok(record_batch)]); +/// let input_stream = futures::stream::iter(vec![Ok(batch)]); /// /// // Build a stream of `Result` (e.g. to return for do_get) /// let flight_data_stream = FlightDataEncoderBuilder::new() @@ -59,6 +66,39 @@ use futures::{ready, stream::BoxStream, Stream, StreamExt}; /// # } /// ``` /// +/// # Example: Sending `Vec` +/// +/// You can create a [`Stream`] to pass to [`Self::build`] from an existing +/// `Vec` of `RecordBatch`es like this: +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, RecordBatch, UInt32Array}; +/// # async fn f() { +/// # fn make_batches() -> Vec { +/// # let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); +/// # let batch = RecordBatch::try_from_iter(vec![ +/// # ("a", Arc::new(c1) as ArrayRef) +/// # ]) +/// # .expect("cannot create record batch"); +/// # vec![batch.clone(), batch.clone()] +/// # } +/// use arrow_flight::encode::FlightDataEncoderBuilder; +/// +/// // Get batches that you want to send via Flight +/// let batches: Vec = make_batches(); +/// +/// // Create an input stream of Result +/// let input_stream = futures::stream::iter( +/// batches.into_iter().map(Ok) +/// ); +/// +/// // Build a stream of `Result` (e.g. to return for do_get) +/// let flight_data_stream = FlightDataEncoderBuilder::new() +/// .build(input_stream); +/// # } +/// ``` +/// /// [`FlightService::do_get`]: crate::flight_service_server::FlightService::do_get /// [`FlightError`]: crate::error::FlightError #[derive(Debug)] @@ -74,6 +114,9 @@ pub struct FlightDataEncoderBuilder { schema: Option, /// Optional flight descriptor, if known before data. descriptor: Option, + /// Deterimines how `DictionaryArray`s are encoded for transport. + /// See [`DictionaryHandling`] for more information. + dictionary_handling: DictionaryHandling, } /// Default target size for encoded [`FlightData`]. 
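Because the new `dictionary_handling` field is only described in prose here, a small sketch (not part of this patch) of how a `do_get` implementation might opt into resending dictionaries; the `with_dictionary_handling` builder method and the `DictionaryHandling` enum are added further down in this file, and `batches` is assumed to already exist:

```rust
// Minimal sketch (not part of this patch): build a FlightData stream that
// keeps DictionaryArrays dictionary-encoded on the wire instead of hydrating
// them to their value type. Assumes `batches: Vec<RecordBatch>` already exists.
use arrow_array::RecordBatch;
use arrow_flight::encode::{DictionaryHandling, FlightDataEncoder, FlightDataEncoderBuilder};

fn encode_with_dictionaries(batches: Vec<RecordBatch>) -> FlightDataEncoder {
    // An input stream of Result<RecordBatch, FlightError>, as `build` expects.
    let input = futures::stream::iter(batches.into_iter().map(Ok));
    FlightDataEncoderBuilder::new()
        // Resend: dictionary FlightData accompanies every batch that contains
        // a DictionaryArray (see `DictionaryHandling` below for the tradeoffs).
        .with_dictionary_handling(DictionaryHandling::Resend)
        .build(input)
}
```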
@@ -90,6 +133,7 @@ impl Default for FlightDataEncoderBuilder { app_metadata: Bytes::new(), schema: None, descriptor: None, + dictionary_handling: DictionaryHandling::Hydrate, } } } @@ -114,6 +158,15 @@ impl FlightDataEncoderBuilder { self } + /// Set [`DictionaryHandling`] for encoder + pub fn with_dictionary_handling( + mut self, + dictionary_handling: DictionaryHandling, + ) -> Self { + self.dictionary_handling = dictionary_handling; + self + } + /// Specify application specific metadata included in the /// [`FlightData::app_metadata`] field of the the first Schema /// message @@ -146,8 +199,10 @@ impl FlightDataEncoderBuilder { self } - /// Return a [`Stream`](futures::Stream) of [`FlightData`], - /// consuming self. More details on [`FlightDataEncoder`] + /// Takes a [`Stream`] of [`Result`] and returns a [`Stream`] + /// of [`FlightData`], consuming self. + /// + /// See example on [`Self`] and [`FlightDataEncoder`] for more details pub fn build(self, input: S) -> FlightDataEncoder where S: Stream> + Send + 'static, @@ -158,6 +213,7 @@ impl FlightDataEncoderBuilder { app_metadata, schema, descriptor, + dictionary_handling, } = self; FlightDataEncoder::new( @@ -167,6 +223,7 @@ impl FlightDataEncoderBuilder { options, app_metadata, descriptor, + dictionary_handling, ) } } @@ -192,6 +249,9 @@ pub struct FlightDataEncoder { done: bool, /// cleared after the first FlightData message is sent descriptor: Option, + /// Deterimines how `DictionaryArray`s are encoded for transport. + /// See [`DictionaryHandling`] for more information. + dictionary_handling: DictionaryHandling, } impl FlightDataEncoder { @@ -202,16 +262,21 @@ impl FlightDataEncoder { options: IpcWriteOptions, app_metadata: Bytes, descriptor: Option, + dictionary_handling: DictionaryHandling, ) -> Self { let mut encoder = Self { inner, schema: None, max_flight_data_size, - encoder: FlightIpcEncoder::new(options), + encoder: FlightIpcEncoder::new( + options, + dictionary_handling != DictionaryHandling::Resend, + ), app_metadata: Some(app_metadata), queue: VecDeque::new(), done: false, descriptor, + dictionary_handling, }; // If schema is known up front, enqueue it immediately @@ -242,7 +307,8 @@ impl FlightDataEncoder { fn encode_schema(&mut self, schema: &SchemaRef) -> SchemaRef { // The first message is the schema message, and all // batches have the same schema - let schema = Arc::new(prepare_schema_for_flight(schema)); + let send_dictionaries = self.dictionary_handling == DictionaryHandling::Resend; + let schema = Arc::new(prepare_schema_for_flight(schema, send_dictionaries)); let mut schema_flight_data = self.encoder.encode_schema(&schema); // attach any metadata requested @@ -264,7 +330,8 @@ impl FlightDataEncoder { }; // encode the batch - let batch = prepare_batch_for_flight(&batch, schema)?; + let send_dictionaries = self.dictionary_handling == DictionaryHandling::Resend; + let batch = prepare_batch_for_flight(&batch, schema, send_dictionaries)?; for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) { let (flight_dictionaries, flight_batch) = @@ -325,17 +392,46 @@ impl Stream for FlightDataEncoder { } } +/// Defines how a [`FlightDataEncoder`] encodes [`DictionaryArray`]s +/// +/// [`DictionaryArray`]: arrow_array::DictionaryArray +#[derive(Debug, PartialEq)] +pub enum DictionaryHandling { + /// Expands to the underlying type (default). 
This likely sends more data + /// over the network but requires less memory (dictionaries are not tracked) + /// and is more compatible with other arrow flight client implementations + /// that may not support `DictionaryEncoding` + /// + /// An IPC response, streaming or otherwise, defines its schema up front + /// which defines the mapping from dictionary IDs. It then sends these + /// dictionaries over the wire. + /// + /// This requires identifying the different dictionaries in use, assigning + /// them IDs, and sending new dictionaries, delta or otherwise, when needed + /// + /// See also: + /// * + Hydrate, + /// Send dictionary FlightData with every RecordBatch that contains a + /// [`DictionaryArray`]. See [`Self::Hydrate`] for more tradeoffs. No + /// attempt is made to skip sending the same (logical) dictionary values + /// twice. + /// + /// [`DictionaryArray`]: arrow_array::DictionaryArray + Resend, +} + /// Prepare an arrow Schema for transport over the Arrow Flight protocol /// /// Convert dictionary types to underlying types /// /// See hydrate_dictionary for more information -fn prepare_schema_for_flight(schema: &Schema) -> Schema { +fn prepare_schema_for_flight(schema: &Schema, send_dictionaries: bool) -> Schema { let fields: Fields = schema .fields() .iter() .map(|field| match field.data_type() { - DataType::Dictionary(_, value_type) => Field::new( + DataType::Dictionary(_, value_type) if !send_dictionaries => Field::new( field.name(), value_type.as_ref().clone(), field.is_nullable(), @@ -394,8 +490,7 @@ struct FlightIpcEncoder { } impl FlightIpcEncoder { - fn new(options: IpcWriteOptions) -> Self { - let error_on_replacement = true; + fn new(options: IpcWriteOptions, error_on_replacement: bool) -> Self { Self { options, data_gen: IpcDataGenerator::default(), @@ -438,12 +533,14 @@ impl FlightIpcEncoder { fn prepare_batch_for_flight( batch: &RecordBatch, schema: SchemaRef, + send_dictionaries: bool, ) -> Result { let columns = batch .columns() .iter() - .map(hydrate_dictionary) + .map(|c| hydrate_dictionary(c, send_dictionaries)) .collect::>>()?; + let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); Ok(RecordBatch::try_new_with_options( @@ -451,35 +548,28 @@ fn prepare_batch_for_flight( )?) } -/// Hydrates a dictionary to its underlying type -/// -/// An IPC response, streaming or otherwise, defines its schema up front -/// which defines the mapping from dictionary IDs. It then sends these -/// dictionaries over the wire. -/// -/// This requires identifying the different dictionaries in use, assigning -/// them IDs, and sending new dictionaries, delta or otherwise, when needed -/// -/// See also: -/// * -/// -/// For now we just hydrate the dictionaries to their underlying type -fn hydrate_dictionary(array: &ArrayRef) -> Result { - let arr = if let DataType::Dictionary(_, value) = array.data_type() { - arrow_cast::cast(array, value)? - } else { - Arc::clone(array) +/// Hydrates a dictionary to its underlying type if send_dictionaries is false. If send_dictionaries +/// is true, dictionaries are sent with every batch which is not as optimal as described in [DictionaryHandling::Hydrate] above, +/// but does enable sending DictionaryArray's via Flight. +fn hydrate_dictionary(array: &ArrayRef, send_dictionaries: bool) -> Result { + let arr = match array.data_type() { + DataType::Dictionary(_, value) if !send_dictionaries => { + arrow_cast::cast(array, value)? 
+ } + _ => Arc::clone(array), }; Ok(arr) } #[cfg(test)] mod tests { - use arrow_array::types::*; use arrow_array::*; + use arrow_array::{cast::downcast_array, types::*}; use arrow_cast::pretty::pretty_format_batches; use std::collections::HashMap; + use crate::decode::{DecodedPayload, FlightDataDecoder}; + use super::*; #[test] @@ -497,7 +587,7 @@ mod tests { let big_batch = batch.slice(0, batch.num_rows() - 1); let optimized_big_batch = - prepare_batch_for_flight(&big_batch, Arc::clone(&schema)) + prepare_batch_for_flight(&big_batch, Arc::clone(&schema), false) .expect("failed to optimize"); let (_, optimized_big_flight_batch) = make_flight_data(&optimized_big_batch, &options); @@ -509,7 +599,7 @@ mod tests { let small_batch = batch.slice(0, 1); let optimized_small_batch = - prepare_batch_for_flight(&small_batch, Arc::clone(&schema)) + prepare_batch_for_flight(&small_batch, Arc::clone(&schema), false) .expect("failed to optimize"); let (_, optimized_small_flight_batch) = make_flight_data(&optimized_small_batch, &options); @@ -520,6 +610,84 @@ mod tests { ); } + #[tokio::test] + async fn test_dictionary_hydration() { + let arr: DictionaryArray = vec!["a", "a", "b"].into_iter().collect(); + let schema = Arc::new(Schema::new(vec![Field::new_dictionary( + "dict", + DataType::UInt16, + DataType::Utf8, + false, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(arr)]).unwrap(); + let encoder = FlightDataEncoderBuilder::default() + .build(futures::stream::once(async { Ok(batch) })); + let mut decoder = FlightDataDecoder::new(encoder); + let expected_schema = + Schema::new(vec![Field::new("dict", DataType::Utf8, false)]); + let expected_schema = Arc::new(expected_schema); + while let Some(decoded) = decoder.next().await { + let decoded = decoded.unwrap(); + match decoded.payload { + DecodedPayload::None => {} + DecodedPayload::Schema(s) => assert_eq!(s, expected_schema), + DecodedPayload::RecordBatch(b) => { + assert_eq!(b.schema(), expected_schema); + let expected_array = StringArray::from(vec!["a", "a", "b"]); + let actual_array = b.column_by_name("dict").unwrap(); + let actual_array = downcast_array::(actual_array); + + assert_eq!(actual_array, expected_array); + } + } + } + } + + #[tokio::test] + async fn test_send_dictionaries() { + let schema = Arc::new(Schema::new(vec![Field::new_dictionary( + "dict", + DataType::UInt16, + DataType::Utf8, + false, + )])); + + let arr_one: Arc> = + Arc::new(vec!["a", "a", "b"].into_iter().collect()); + let arr_two: Arc> = + Arc::new(vec!["b", "a", "c"].into_iter().collect()); + let batch_one = + RecordBatch::try_new(schema.clone(), vec![arr_one.clone()]).unwrap(); + let batch_two = + RecordBatch::try_new(schema.clone(), vec![arr_two.clone()]).unwrap(); + + let encoder = FlightDataEncoderBuilder::default() + .with_dictionary_handling(DictionaryHandling::Resend) + .build(futures::stream::iter(vec![Ok(batch_one), Ok(batch_two)])); + + let mut decoder = FlightDataDecoder::new(encoder); + let mut expected_array = arr_one; + while let Some(decoded) = decoder.next().await { + let decoded = decoded.unwrap(); + match decoded.payload { + DecodedPayload::None => {} + DecodedPayload::Schema(s) => assert_eq!(s, schema), + DecodedPayload::RecordBatch(b) => { + assert_eq!(b.schema(), schema); + + let actual_array = + Arc::new(downcast_array::>( + b.column_by_name("dict").unwrap(), + )); + + assert_eq!(actual_array, expected_array); + + expected_array = arr_two.clone(); + } + } + } + } + #[test] fn test_schema_metadata_encoded() { let schema = @@ -527,7 
+695,7 @@ mod tests { HashMap::from([("some_key".to_owned(), "some_value".to_owned())]), ); - let got = prepare_schema_for_flight(&schema); + let got = prepare_schema_for_flight(&schema, false); assert!(got.metadata().contains_key("some_key")); } @@ -540,7 +708,8 @@ mod tests { ) .expect("cannot create record batch"); - prepare_batch_for_flight(&batch, batch.schema()).expect("failed to optimize"); + prepare_batch_for_flight(&batch, batch.schema(), false) + .expect("failed to optimize"); } pub fn make_flight_data( diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs index 4163f2ceaa27..3035f109c685 100644 --- a/arrow-flight/src/lib.rs +++ b/arrow-flight/src/lib.rs @@ -111,6 +111,9 @@ pub use gen::Result; pub use gen::SchemaResult; pub use gen::Ticket; +/// Helper to extract HTTP/gRPC trailers from a tonic stream. +mod trailers; + pub mod utils; #[cfg(feature = "flight-sql-experimental")] @@ -313,16 +316,6 @@ impl TryFrom> for SchemaResult { } } -// TryFrom... - -impl TryFrom for DescriptorType { - type Error = ArrowError; - - fn try_from(value: i32) -> ArrowResult { - value.try_into() - } -} - impl TryFrom> for IpcMessage { type Error = ArrowError; diff --git a/arrow-flight/src/sql/arrow.flight.protocol.sql.rs b/arrow-flight/src/sql/arrow.flight.protocol.sql.rs index b2137d8543d3..c7c23311e61e 100644 --- a/arrow-flight/src/sql/arrow.flight.protocol.sql.rs +++ b/arrow-flight/src/sql/arrow.flight.protocol.sql.rs @@ -1077,10 +1077,10 @@ pub enum SqlInfo { /// The returned bitmask should be parsed in order to retrieve the supported commands. /// /// For instance: - /// - return 0 (\b0) => [] (GROUP BY is unsupported); + /// - return 0 (\b0) => \[\] (GROUP BY is unsupported); /// - return 1 (\b1) => \[SQL_GROUP_BY_UNRELATED\]; /// - return 2 (\b10) => \[SQL_GROUP_BY_BEYOND_SELECT\]; - /// - return 3 (\b11) => [SQL_GROUP_BY_UNRELATED, SQL_GROUP_BY_BEYOND_SELECT]. + /// - return 3 (\b11) => \[SQL_GROUP_BY_UNRELATED, SQL_GROUP_BY_BEYOND_SELECT\]. /// Valid GROUP BY types are described under `arrow.flight.protocol.sql.SqlSupportedGroupBy`. SqlSupportedGroupBy = 522, /// @@ -1104,14 +1104,14 @@ pub enum SqlInfo { /// The returned bitmask should be parsed in order to retrieve the supported grammar levels. /// /// For instance: - /// - return 0 (\b0) => [] (SQL grammar is unsupported); + /// - return 0 (\b0) => \[\] (SQL grammar is unsupported); /// - return 1 (\b1) => \[SQL_MINIMUM_GRAMMAR\]; /// - return 2 (\b10) => \[SQL_CORE_GRAMMAR\]; - /// - return 3 (\b11) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR]; + /// - return 3 (\b11) => \[SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR\]; /// - return 4 (\b100) => \[SQL_EXTENDED_GRAMMAR\]; - /// - return 5 (\b101) => [SQL_MINIMUM_GRAMMAR, SQL_EXTENDED_GRAMMAR]; - /// - return 6 (\b110) => [SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR]; - /// - return 7 (\b111) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR]. + /// - return 5 (\b101) => \[SQL_MINIMUM_GRAMMAR, SQL_EXTENDED_GRAMMAR\]; + /// - return 6 (\b110) => \[SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR\]; + /// - return 7 (\b111) => \[SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR\]. /// Valid SQL grammar levels are described under `arrow.flight.protocol.sql.SupportedSqlGrammar`. SqlSupportedGrammar = 525, /// @@ -1121,14 +1121,14 @@ pub enum SqlInfo { /// The returned bitmask should be parsed in order to retrieve the supported commands. 
/// /// For instance: - /// - return 0 (\b0) => [] (ANSI92 SQL grammar is unsupported); + /// - return 0 (\b0) => \[\] (ANSI92 SQL grammar is unsupported); /// - return 1 (\b1) => \[ANSI92_ENTRY_SQL\]; /// - return 2 (\b10) => \[ANSI92_INTERMEDIATE_SQL\]; - /// - return 3 (\b11) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL]; + /// - return 3 (\b11) => \[ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL\]; /// - return 4 (\b100) => \[ANSI92_FULL_SQL\]; - /// - return 5 (\b101) => [ANSI92_ENTRY_SQL, ANSI92_FULL_SQL]; - /// - return 6 (\b110) => [ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL]; - /// - return 7 (\b111) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL]. + /// - return 5 (\b101) => \[ANSI92_ENTRY_SQL, ANSI92_FULL_SQL\]; + /// - return 6 (\b110) => \[ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL\]; + /// - return 7 (\b111) => \[ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL\]. /// Valid ANSI92 SQL grammar levels are described under `arrow.flight.protocol.sql.SupportedAnsi92SqlGrammarLevel`. SqlAnsi92SupportedLevel = 526, /// @@ -1165,14 +1165,14 @@ pub enum SqlInfo { /// The returned bitmask should be parsed in order to retrieve the supported actions for a SQL schema. /// /// For instance: - /// - return 0 (\b0) => [] (no supported actions for SQL schema); + /// - return 0 (\b0) => \[\] (no supported actions for SQL schema); /// - return 1 (\b1) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS\]; /// - return 2 (\b10) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; - /// - return 3 (\b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS]; + /// - return 3 (\b11) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; /// - return 4 (\b100) => \[SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; - /// - return 5 (\b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; - /// - return 6 (\b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; - /// - return 7 (\b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]. + /// - return 5 (\b101) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; + /// - return 6 (\b110) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; + /// - return 7 (\b111) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]. /// Valid actions for a SQL schema described under `arrow.flight.protocol.sql.SqlSupportedElementActions`. SqlSchemasSupportedActions = 533, /// @@ -1182,14 +1182,14 @@ pub enum SqlInfo { /// The returned bitmask should be parsed in order to retrieve the supported actions for a SQL catalog. 
/// /// For instance: - /// - return 0 (\b0) => [] (no supported actions for SQL catalog); + /// - return 0 (\b0) => \[\] (no supported actions for SQL catalog); /// - return 1 (\b1) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS\]; /// - return 2 (\b10) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; - /// - return 3 (\b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS]; + /// - return 3 (\b11) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; /// - return 4 (\b100) => \[SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; - /// - return 5 (\b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; - /// - return 6 (\b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; - /// - return 7 (\b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]. + /// - return 5 (\b101) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; + /// - return 6 (\b110) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; + /// - return 7 (\b111) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]. /// Valid actions for a SQL catalog are described under `arrow.flight.protocol.sql.SqlSupportedElementActions`. SqlCatalogsSupportedActions = 534, /// @@ -1199,10 +1199,10 @@ pub enum SqlInfo { /// The returned bitmask should be parsed in order to retrieve the supported SQL positioned commands. /// /// For instance: - /// - return 0 (\b0) => [] (no supported SQL positioned commands); + /// - return 0 (\b0) => \[\] (no supported SQL positioned commands); /// - return 1 (\b1) => \[SQL_POSITIONED_DELETE\]; /// - return 2 (\b10) => \[SQL_POSITIONED_UPDATE\]; - /// - return 3 (\b11) => [SQL_POSITIONED_DELETE, SQL_POSITIONED_UPDATE]. + /// - return 3 (\b11) => \[SQL_POSITIONED_DELETE, SQL_POSITIONED_UPDATE\]. /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlSupportedPositionedCommands`. SqlSupportedPositionedCommands = 535, /// @@ -1227,22 +1227,22 @@ pub enum SqlInfo { /// The returned bitmask should be parsed in order to retrieve the supported SQL subqueries. 
/// /// For instance: - /// - return 0 (\b0) => [] (no supported SQL subqueries); + /// - return 0 (\b0) => \[\] (no supported SQL subqueries); /// - return 1 (\b1) => \[SQL_SUBQUERIES_IN_COMPARISONS\]; /// - return 2 (\b10) => \[SQL_SUBQUERIES_IN_EXISTS\]; - /// - return 3 (\b11) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS]; + /// - return 3 (\b11) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS\]; /// - return 4 (\b100) => \[SQL_SUBQUERIES_IN_INS\]; - /// - return 5 (\b101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS]; - /// - return 6 (\b110) => [SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_EXISTS]; - /// - return 7 (\b111) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS]; + /// - return 5 (\b101) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS\]; + /// - return 6 (\b110) => \[SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_EXISTS\]; + /// - return 7 (\b111) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS\]; /// - return 8 (\b1000) => \[SQL_SUBQUERIES_IN_QUANTIFIEDS\]; - /// - return 9 (\b1001) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 10 (\b1010) => [SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 11 (\b1011) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 12 (\b1100) => [SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 13 (\b1101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 14 (\b1110) => [SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 15 (\b1111) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + /// - return 9 (\b1001) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 10 (\b1010) => \[SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 11 (\b1011) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 12 (\b1100) => \[SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 13 (\b1101) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 14 (\b1110) => \[SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 15 (\b1111) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; /// - ... /// Valid SQL subqueries are described under `arrow.flight.protocol.sql.SqlSupportedSubqueries`. SqlSupportedSubqueries = 538, @@ -1260,10 +1260,10 @@ pub enum SqlInfo { /// The returned bitmask should be parsed in order to retrieve the supported SQL UNIONs. /// /// For instance: - /// - return 0 (\b0) => [] (no supported SQL positioned commands); + /// - return 0 (\b0) => \[\] (no supported SQL positioned commands); /// - return 1 (\b1) => \[SQL_UNION\]; /// - return 2 (\b10) => \[SQL_UNION_ALL\]; - /// - return 3 (\b11) => [SQL_UNION, SQL_UNION_ALL]. + /// - return 3 (\b11) => \[SQL_UNION, SQL_UNION_ALL\]. /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlSupportedUnions`. SqlSupportedUnions = 540, /// Retrieves a int64 value representing the maximum number of hex characters allowed in an inline binary literal. 
@@ -1341,22 +1341,22 @@ pub enum SqlInfo { /// The returned bitmask should be parsed in order to retrieve the supported transactions isolation levels. /// /// For instance: - /// - return 0 (\b0) => [] (no supported SQL transactions isolation levels); + /// - return 0 (\b0) => \[\] (no supported SQL transactions isolation levels); /// - return 1 (\b1) => \[SQL_TRANSACTION_NONE\]; /// - return 2 (\b10) => \[SQL_TRANSACTION_READ_UNCOMMITTED\]; - /// - return 3 (\b11) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED]; + /// - return 3 (\b11) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED\]; /// - return 4 (\b100) => \[SQL_TRANSACTION_REPEATABLE_READ\]; - /// - return 5 (\b101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 6 (\b110) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 7 (\b111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 5 (\b101) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 6 (\b110) => \[SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 7 (\b111) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ\]; /// - return 8 (\b1000) => \[SQL_TRANSACTION_REPEATABLE_READ\]; - /// - return 9 (\b1001) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 10 (\b1010) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 11 (\b1011) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 12 (\b1100) => [SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 13 (\b1101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 14 (\b1110) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 15 (\b1111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 9 (\b1001) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 10 (\b1010) => \[SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 11 (\b1011) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 12 (\b1100) => \[SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 13 (\b1101) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 14 (\b1110) => \[SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 15 (\b1111) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ\]; /// - return 16 (\b10000) => \[SQL_TRANSACTION_SERIALIZABLE\]; /// - ... /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. @@ -1381,14 +1381,14 @@ pub enum SqlInfo { /// The returned bitmask should be parsed in order to retrieve the supported result set types. 
/// /// For instance: - /// - return 0 (\b0) => [] (no supported result set types); + /// - return 0 (\b0) => \[\] (no supported result set types); /// - return 1 (\b1) => \[SQL_RESULT_SET_TYPE_UNSPECIFIED\]; /// - return 2 (\b10) => \[SQL_RESULT_SET_TYPE_FORWARD_ONLY\]; - /// - return 3 (\b11) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY]; + /// - return 3 (\b11) => \[SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY\]; /// - return 4 (\b100) => \[SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE\]; - /// - return 5 (\b101) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; - /// - return 6 (\b110) => [SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; - /// - return 7 (\b111) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; + /// - return 5 (\b101) => \[SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE\]; + /// - return 6 (\b110) => \[SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE\]; + /// - return 7 (\b111) => \[SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE\]; /// - return 8 (\b1000) => \[SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE\]; /// - ... /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetType`. @@ -1398,14 +1398,14 @@ pub enum SqlInfo { /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_UNSPECIFIED`. /// /// For instance: - /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) + /// - return 0 (\b0) => \[\] (no supported concurrency types for this result set type) /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] - /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + /// - return 3 (\b11) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] - /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// - return 5 (\b101) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 6 (\b110) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 7 (\b111) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. SqlSupportedConcurrenciesForResultSetUnspecified = 568, /// @@ -1413,14 +1413,14 @@ pub enum SqlInfo { /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_FORWARD_ONLY`. 
/// /// For instance: - /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) + /// - return 0 (\b0) => \[\] (no supported concurrency types for this result set type) /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] - /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + /// - return 3 (\b11) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] - /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// - return 5 (\b101) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 6 (\b110) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 7 (\b111) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. SqlSupportedConcurrenciesForResultSetForwardOnly = 569, /// @@ -1428,14 +1428,14 @@ pub enum SqlInfo { /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE`. /// /// For instance: - /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) + /// - return 0 (\b0) => \[\] (no supported concurrency types for this result set type) /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] - /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + /// - return 3 (\b11) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] - /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// - return 5 (\b101) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 6 (\b110) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 7 (\b111) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. SqlSupportedConcurrenciesForResultSetScrollSensitive = 570, /// @@ -1443,14 +1443,14 @@ pub enum SqlInfo { /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE`. 
/// /// For instance: - /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) + /// - return 0 (\b0) => \[\] (no supported concurrency types for this result set type) /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] - /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + /// - return 3 (\b11) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] - /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// - return 5 (\b101) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 6 (\b110) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 7 (\b111) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. SqlSupportedConcurrenciesForResultSetScrollInsensitive = 571, /// diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index c9adc2b98b12..7685813ff844 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -24,6 +24,9 @@ use std::collections::HashMap; use std::str::FromStr; use tonic::metadata::AsciiMetadataKey; +use crate::decode::FlightRecordBatchStream; +use crate::encode::FlightDataEncoderBuilder; +use crate::error::FlightError; use crate::flight_service_client::FlightServiceClient; use crate::sql::server::{CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT}; use crate::sql::{ @@ -32,9 +35,10 @@ use crate::sql::{ CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, - CommandPreparedStatementQuery, CommandStatementQuery, CommandStatementUpdate, - DoPutUpdateResult, ProstMessageExt, SqlInfo, + CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, + CommandStatementUpdate, DoPutUpdateResult, ProstMessageExt, SqlInfo, }; +use crate::trailers::extract_lazy_trailers; use crate::{ Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, IpcMessage, PutResult, Ticket, @@ -150,7 +154,7 @@ impl FlightSqlServiceClient { .flight_client .handshake(req) .await - .map_err(|e| ArrowError::IoError(format!("Can't handshake {e}")))?; + .map_err(|e| ArrowError::IpcError(format!("Can't handshake {e}")))?; if let Some(auth) = resp.metadata().get("authorization") { let auth = auth.to_str().map_err(|_| { ArrowError::ParseError("Can't read auth header".to_string()) @@ -229,14 +233,22 @@ impl FlightSqlServiceClient { pub async fn do_get( &mut self, ticket: impl IntoRequest, - ) -> Result, ArrowError> { + ) -> Result { let req = self.set_request_headers(ticket.into_request())?; - Ok(self + + let (md, response_stream, _ext) = self .flight_client .do_get(req) .await .map_err(status_to_arrow_error)? 
- .into_inner()) + .into_parts(); + let (response_stream, trailers) = extract_lazy_trailers(response_stream); + + Ok(FlightRecordBatchStream::new_from_flight_data( + response_stream.map_err(FlightError::Tonic), + ) + .with_headers(md) + .with_trailers(trailers)) } /// Push a stream to the flight service associated with a particular flight stream. @@ -390,16 +402,20 @@ impl FlightSqlServiceClient { ) -> Result, ArrowError> { for (k, v) in &self.headers { let k = AsciiMetadataKey::from_str(k.as_str()).map_err(|e| { - ArrowError::IoError(format!("Cannot convert header key \"{k}\": {e}")) + ArrowError::ParseError(format!("Cannot convert header key \"{k}\": {e}")) })?; let v = v.parse().map_err(|e| { - ArrowError::IoError(format!("Cannot convert header value \"{v}\": {e}")) + ArrowError::ParseError(format!( + "Cannot convert header value \"{v}\": {e}" + )) })?; req.metadata_mut().insert(k, v); } if let Some(token) = &self.token { let val = format!("Bearer {token}").parse().map_err(|e| { - ArrowError::IoError(format!("Cannot convert token to header value: {e}")) + ArrowError::ParseError(format!( + "Cannot convert token to header value: {e}" + )) })?; req.metadata_mut().insert("authorization", val); } @@ -435,9 +451,12 @@ impl PreparedStatement { /// Executes the prepared statement query on the server. pub async fn execute(&mut self) -> Result { + self.write_bind_params().await?; + let cmd = CommandPreparedStatementQuery { prepared_statement_handle: self.handle.clone(), }; + let result = self .flight_sql_client .get_flight_info_for_command(cmd) @@ -447,7 +466,9 @@ impl PreparedStatement { /// Executes the prepared statement update query on the server. pub async fn execute_update(&mut self) -> Result { - let cmd = CommandPreparedStatementQuery { + self.write_bind_params().await?; + + let cmd = CommandPreparedStatementUpdate { prepared_statement_handle: self.handle.clone(), }; let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); @@ -488,6 +509,36 @@ impl PreparedStatement { Ok(()) } + /// Submit parameters to the server, if any have been set on this prepared statement instance + async fn write_bind_params(&mut self) -> Result<(), ArrowError> { + if let Some(ref params_batch) = self.parameter_binding { + let cmd = CommandPreparedStatementQuery { + prepared_statement_handle: self.handle.clone(), + }; + + let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); + let flight_stream_builder = FlightDataEncoderBuilder::new() + .with_flight_descriptor(Some(descriptor)) + .with_schema(params_batch.schema()); + let flight_data = flight_stream_builder + .build(futures::stream::iter( + self.parameter_binding.clone().map(Ok), + )) + .try_collect::>() + .await + .map_err(flight_error_to_arrow_error)?; + + self.flight_sql_client + .do_put(stream::iter(flight_data)) + .await? + .try_collect::>() + .await + .map_err(status_to_arrow_error)?; + } + + Ok(()) + } + /// Close the prepared statement, so that this PreparedStatement can not used /// anymore and server can free up any resources. 
pub async fn close(mut self) -> Result<(), ArrowError> { @@ -504,11 +555,18 @@ impl PreparedStatement { } fn decode_error_to_arrow_error(err: prost::DecodeError) -> ArrowError { - ArrowError::IoError(err.to_string()) + ArrowError::IpcError(err.to_string()) } fn status_to_arrow_error(status: tonic::Status) -> ArrowError { - ArrowError::IoError(format!("{status:?}")) + ArrowError::IpcError(format!("{status:?}")) +} + +fn flight_error_to_arrow_error(err: FlightError) -> ArrowError { + match err { + FlightError::Arrow(e) => e, + e => ArrowError::ExternalError(Box::new(e)), + } } // A polymorphic structure to natively represent different types of data contained in `FlightData` @@ -538,7 +596,7 @@ pub fn arrow_data_from_flight_data( let dictionaries_by_field = HashMap::new(); let record_batch = read_record_batch( - &Buffer::from(&flight_data.data_body), + &Buffer::from_bytes(flight_data.data_body.into()), ipc_record_batch, arrow_schema_ref.clone(), &dictionaries_by_field, diff --git a/arrow-flight/src/sql/metadata/db_schemas.rs b/arrow-flight/src/sql/metadata/db_schemas.rs index 7b10e1c14299..642802b058d5 100644 --- a/arrow-flight/src/sql/metadata/db_schemas.rs +++ b/arrow-flight/src/sql/metadata/db_schemas.rs @@ -22,11 +22,11 @@ use std::sync::Arc; use arrow_arith::boolean::and; -use arrow_array::{builder::StringBuilder, ArrayRef, RecordBatch}; -use arrow_ord::comparison::eq_utf8_scalar; +use arrow_array::{builder::StringBuilder, ArrayRef, RecordBatch, StringArray}; +use arrow_ord::cmp::eq; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use arrow_select::{filter::filter_record_batch, take::take}; -use arrow_string::like::like_utf8_scalar; +use arrow_string::like::like; use once_cell::sync::Lazy; use super::lexsort_to_indices; @@ -122,14 +122,13 @@ impl GetDbSchemasBuilder { if let Some(db_schema_filter_pattern) = db_schema_filter_pattern { // use like kernel to get wildcard matching - filters.push(like_utf8_scalar( - &db_schema_name, - &db_schema_filter_pattern, - )?) + let scalar = StringArray::new_scalar(db_schema_filter_pattern); + filters.push(like(&db_schema_name, &scalar)?) } if let Some(catalog_filter_name) = catalog_filter { - filters.push(eq_utf8_scalar(&catalog_name, &catalog_filter_name)?); + let scalar = StringArray::new_scalar(catalog_filter_name); + filters.push(eq(&catalog_name, &scalar)?); } // `AND` any filters together diff --git a/arrow-flight/src/sql/metadata/mod.rs b/arrow-flight/src/sql/metadata/mod.rs index 72c882f385d3..1e9881ffa70e 100644 --- a/arrow-flight/src/sql/metadata/mod.rs +++ b/arrow-flight/src/sql/metadata/mod.rs @@ -21,10 +21,14 @@ //! - [`GetCatalogsBuilder`] for building responses to [`CommandGetCatalogs`] queries. //! - [`GetDbSchemasBuilder`] for building responses to [`CommandGetDbSchemas`] queries. //! - [`GetTablesBuilder`]for building responses to [`CommandGetTables`] queries. +//! - [`SqlInfoDataBuilder`]for building responses to [`CommandGetSqlInfo`] queries. +//! - [`XdbcTypeInfoDataBuilder`]for building responses to [`CommandGetXdbcTypeInfo`] queries. //! //! [`CommandGetCatalogs`]: crate::sql::CommandGetCatalogs //! [`CommandGetDbSchemas`]: crate::sql::CommandGetDbSchemas //! [`CommandGetTables`]: crate::sql::CommandGetTables +//! [`CommandGetSqlInfo`]: crate::sql::CommandGetSqlInfo +//! 
[`CommandGetXdbcTypeInfo`]: crate::sql::CommandGetXdbcTypeInfo mod catalogs; mod db_schemas; @@ -49,7 +53,7 @@ fn lexsort_to_indices(arrays: &[ArrayRef]) -> UInt32Array { .iter() .map(|a| SortField::new(a.data_type().clone())) .collect(); - let mut converter = RowConverter::new(fields).unwrap(); + let converter = RowConverter::new(fields).unwrap(); let rows = converter.convert_columns(arrays).unwrap(); let mut sort: Vec<_> = rows.iter().enumerate().collect(); sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b)); diff --git a/arrow-flight/src/sql/metadata/sql_info.rs b/arrow-flight/src/sql/metadata/sql_info.rs index d0c9cedbcf7c..88c97227814d 100644 --- a/arrow-flight/src/sql/metadata/sql_info.rs +++ b/arrow-flight/src/sql/metadata/sql_info.rs @@ -33,10 +33,9 @@ use arrow_array::builder::{ ArrayBuilder, BooleanBuilder, Int32Builder, Int64Builder, Int8Builder, ListBuilder, MapBuilder, StringBuilder, UInt32Builder, }; -use arrow_array::cast::downcast_array; -use arrow_array::RecordBatch; +use arrow_array::{RecordBatch, Scalar}; use arrow_data::ArrayData; -use arrow_ord::comparison::eq_scalar; +use arrow_ord::cmp::eq; use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef, UnionFields, UnionMode}; use arrow_select::filter::filter_record_batch; use once_cell::sync::Lazy; @@ -334,8 +333,8 @@ impl SqlInfoUnionBuilder { /// [`CommandGetSqlInfo`] are metadata requests used by a Flight SQL /// server to communicate supported capabilities to Flight SQL clients. /// -/// Servers constuct - usually static - [`SqlInfoData`] via the [SqlInfoDataBuilder`], -/// and build responses by passing the [`GetSqlInfoBuilder`]. +/// Servers construct - usually static - [`SqlInfoData`] via the [`SqlInfoDataBuilder`], +/// and build responses using [`CommandGetSqlInfo::into_builder`] #[derive(Debug, Clone, PartialEq)] pub struct SqlInfoDataBuilder { /// Use BTreeMap to ensure the values are sorted by value as @@ -425,13 +424,16 @@ impl SqlInfoData { &self, info: impl IntoIterator, ) -> Result { - let arr: UInt32Array = downcast_array(self.batch.column(0).as_ref()); + let arr = self.batch.column(0); let type_filter = info .into_iter() - .map(|tt| eq_scalar(&arr, tt)) + .map(|tt| { + let s = UInt32Array::from(vec![tt]); + eq(arr, &Scalar::new(&s)) + }) .collect::, _>>()? .into_iter() - // We know the arrays are of same length as they are produced fromn the same root array + // We know the arrays are of same length as they are produced from the same root array .reduce(|filter, arr| or(&filter, &arr).unwrap()); if let Some(filter) = type_filter { Ok(filter_record_batch(&self.batch, &filter)?)
diff --git a/arrow-flight/src/sql/metadata/tables.rs b/arrow-flight/src/sql/metadata/tables.rs index 67193969d46d..00502a76db53 100644 --- a/arrow-flight/src/sql/metadata/tables.rs +++ b/arrow-flight/src/sql/metadata/tables.rs @@ -23,11 +23,11 @@ use std::sync::Arc; use arrow_arith::boolean::{and, or}; use arrow_array::builder::{BinaryBuilder, StringBuilder}; -use arrow_array::{ArrayRef, RecordBatch}; -use arrow_ord::comparison::eq_utf8_scalar; +use arrow_array::{ArrayRef, RecordBatch, StringArray}; +use arrow_ord::cmp::eq; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use arrow_select::{filter::filter_record_batch, take::take}; -use arrow_string::like::like_utf8_scalar; +use arrow_string::like::like; use once_cell::sync::Lazy; use super::lexsort_to_indices; @@ -184,12 +184,13 @@ impl GetTablesBuilder { let mut filters = vec![]; if let Some(catalog_filter_name) = catalog_filter { - filters.push(eq_utf8_scalar(&catalog_name, &catalog_filter_name)?); + let scalar = StringArray::new_scalar(catalog_filter_name); + filters.push(eq(&catalog_name, &scalar)?); } let tt_filter = table_types_filter .into_iter() - .map(|tt| eq_utf8_scalar(&table_type, &tt)) + .map(|tt| eq(&table_type, &StringArray::new_scalar(tt))) .collect::, _>>()? .into_iter() // We know the arrays are of same length as they are produced fromn the same root array @@ -200,15 +201,14 @@ impl GetTablesBuilder { if let Some(db_schema_filter_pattern) = db_schema_filter_pattern { // use like kernel to get wildcard matching - filters.push(like_utf8_scalar( - &db_schema_name, - &db_schema_filter_pattern, - )?) + let scalar = StringArray::new_scalar(db_schema_filter_pattern); + filters.push(like(&db_schema_name, &scalar)?) } if let Some(table_name_filter_pattern) = table_name_filter_pattern { // use like kernel to get wildcard matching - filters.push(like_utf8_scalar(&table_name, &table_name_filter_pattern)?) + let scalar = StringArray::new_scalar(table_name_filter_pattern); + filters.push(like(&table_name, &scalar)?) } let batch = if let Some(table_schema) = table_schema { diff --git a/arrow-flight/src/sql/metadata/xdbc_info.rs b/arrow-flight/src/sql/metadata/xdbc_info.rs index cecef1b49e8b..8212c847a4fa 100644 --- a/arrow-flight/src/sql/metadata/xdbc_info.rs +++ b/arrow-flight/src/sql/metadata/xdbc_info.rs @@ -27,9 +27,8 @@ use std::sync::Arc; use arrow_array::builder::{BooleanBuilder, Int32Builder, ListBuilder, StringBuilder}; -use arrow_array::cast::downcast_array; -use arrow_array::{ArrayRef, Int32Array, ListArray, RecordBatch}; -use arrow_ord::comparison::eq_scalar; +use arrow_array::{ArrayRef, Int32Array, ListArray, RecordBatch, Scalar}; +use arrow_ord::cmp::eq; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use arrow_select::filter::filter_record_batch; use arrow_select::take::take; @@ -70,8 +69,8 @@ pub struct XdbcTypeInfo { /// [`CommandGetXdbcTypeInfo`] are metadata requests used by a Flight SQL /// server to communicate supported capabilities to Flight SQL clients. /// -/// Servers constuct - usually static - [`XdbcTypeInfoData`] via the [XdbcTypeInfoDataBuilder`], -/// and build responses by passing the [`GetXdbcTypeInfoBuilder`]. +/// Servers construct - usually static - [`XdbcTypeInfoData`] via the [`XdbcTypeInfoDataBuilder`], +/// and build responses using [`CommandGetXdbcTypeInfo::into_builder`].
pub struct XdbcTypeInfoData { batch: RecordBatch, } @@ -81,8 +80,8 @@ impl XdbcTypeInfoData { /// from [`CommandGetXdbcTypeInfo`] pub fn record_batch(&self, data_type: impl Into>) -> Result { if let Some(dt) = data_type.into() { - let arr: Int32Array = downcast_array(self.batch.column(1).as_ref()); - let filter = eq_scalar(&arr, dt)?; + let scalar = Int32Array::from(vec![dt]); + let filter = eq(self.batch.column(1), &Scalar::new(&scalar))?; Ok(filter_record_batch(&self.batch, &filter)?) } else { Ok(self.batch.clone()) diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs index 4bb8ce8b36e5..4042ce8efc46 100644 --- a/arrow-flight/src/sql/mod.rs +++ b/arrow-flight/src/sql/mod.rs @@ -93,6 +93,7 @@ pub use gen::SqlSupportedTransactions; pub use gen::SqlSupportedUnions; pub use gen::SqlSupportsConvert; pub use gen::SqlTransactionIsolationLevel; +pub use gen::SubstraitPlan; pub use gen::SupportedSqlGrammar; pub use gen::TicketStatementQuery; pub use gen::UpdateDeleteRules; diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index f599fbca46a5..a158ed77f54d 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -19,7 +19,7 @@ use std::pin::Pin; -use futures::Stream; +use futures::{stream::Peekable, Stream, StreamExt}; use prost::Message; use tonic::{Request, Response, Status, Streaming}; @@ -87,186 +87,286 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static { /// Get a FlightInfo for executing a SQL query. async fn get_flight_info_statement( &self, - query: CommandStatementQuery, - request: Request, - ) -> Result, Status>; + _query: CommandStatementQuery, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_statement has no default implementation", + )) + } /// Get a FlightInfo for executing a substrait plan. async fn get_flight_info_substrait_plan( &self, - query: CommandStatementSubstraitPlan, - request: Request, - ) -> Result, Status>; + _query: CommandStatementSubstraitPlan, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_substrait_plan has no default implementation", + )) + } /// Get a FlightInfo for executing an already created prepared statement. async fn get_flight_info_prepared_statement( &self, - query: CommandPreparedStatementQuery, - request: Request, - ) -> Result, Status>; + _query: CommandPreparedStatementQuery, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_prepared_statement has no default implementation", + )) + } /// Get a FlightInfo for listing catalogs. async fn get_flight_info_catalogs( &self, - query: CommandGetCatalogs, - request: Request, - ) -> Result, Status>; + _query: CommandGetCatalogs, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_catalogs has no default implementation", + )) + } /// Get a FlightInfo for listing schemas. async fn get_flight_info_schemas( &self, - query: CommandGetDbSchemas, - request: Request, - ) -> Result, Status>; + _query: CommandGetDbSchemas, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_schemas has no default implementation", + )) + } /// Get a FlightInfo for listing tables. 
async fn get_flight_info_tables( &self, - query: CommandGetTables, - request: Request, - ) -> Result, Status>; + _query: CommandGetTables, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_tables has no default implementation", + )) + } /// Get a FlightInfo to extract information about the table types. async fn get_flight_info_table_types( &self, - query: CommandGetTableTypes, - request: Request, - ) -> Result, Status>; + _query: CommandGetTableTypes, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_table_types has no default implementation", + )) + } /// Get a FlightInfo for retrieving other information (See SqlInfo). async fn get_flight_info_sql_info( &self, - query: CommandGetSqlInfo, - request: Request, - ) -> Result, Status>; + _query: CommandGetSqlInfo, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_sql_info has no default implementation", + )) + } /// Get a FlightInfo to extract information about primary and foreign keys. async fn get_flight_info_primary_keys( &self, - query: CommandGetPrimaryKeys, - request: Request, - ) -> Result, Status>; + _query: CommandGetPrimaryKeys, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_primary_keys has no default implementation", + )) + } /// Get a FlightInfo to extract information about exported keys. async fn get_flight_info_exported_keys( &self, - query: CommandGetExportedKeys, - request: Request, - ) -> Result, Status>; + _query: CommandGetExportedKeys, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_exported_keys has no default implementation", + )) + } /// Get a FlightInfo to extract information about imported keys. async fn get_flight_info_imported_keys( &self, - query: CommandGetImportedKeys, - request: Request, - ) -> Result, Status>; + _query: CommandGetImportedKeys, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_imported_keys has no default implementation", + )) + } /// Get a FlightInfo to extract information about cross reference. async fn get_flight_info_cross_reference( &self, - query: CommandGetCrossReference, - request: Request, - ) -> Result, Status>; + _query: CommandGetCrossReference, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_cross_reference has no default implementation", + )) + } /// Get a FlightInfo to extract information about the supported XDBC types. async fn get_flight_info_xdbc_type_info( &self, - query: CommandGetXdbcTypeInfo, - request: Request, - ) -> Result, Status>; + _query: CommandGetXdbcTypeInfo, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_xdbc_type_info has no default implementation", + )) + } // do_get /// Get a FlightDataStream containing the query results. async fn do_get_statement( &self, - ticket: TicketStatementQuery, - request: Request, - ) -> Result::DoGetStream>, Status>; + _ticket: TicketStatementQuery, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_statement has no default implementation", + )) + } /// Get a FlightDataStream containing the prepared statement query results. 
async fn do_get_prepared_statement( &self, - query: CommandPreparedStatementQuery, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandPreparedStatementQuery, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_prepared_statement has no default implementation", + )) + } /// Get a FlightDataStream containing the list of catalogs. async fn do_get_catalogs( &self, - query: CommandGetCatalogs, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetCatalogs, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_catalogs has no default implementation", + )) + } /// Get a FlightDataStream containing the list of schemas. async fn do_get_schemas( &self, - query: CommandGetDbSchemas, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetDbSchemas, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_schemas has no default implementation", + )) + } /// Get a FlightDataStream containing the list of tables. async fn do_get_tables( &self, - query: CommandGetTables, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetTables, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_tables has no default implementation", + )) + } /// Get a FlightDataStream containing the data related to the table types. async fn do_get_table_types( &self, - query: CommandGetTableTypes, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetTableTypes, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_table_types has no default implementation", + )) + } /// Get a FlightDataStream containing the list of SqlInfo results. async fn do_get_sql_info( &self, - query: CommandGetSqlInfo, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetSqlInfo, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_sql_info has no default implementation", + )) + } /// Get a FlightDataStream containing the data related to the primary and foreign keys. async fn do_get_primary_keys( &self, - query: CommandGetPrimaryKeys, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetPrimaryKeys, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_primary_keys has no default implementation", + )) + } /// Get a FlightDataStream containing the data related to the exported keys. async fn do_get_exported_keys( &self, - query: CommandGetExportedKeys, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetExportedKeys, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_exported_keys has no default implementation", + )) + } /// Get a FlightDataStream containing the data related to the imported keys. async fn do_get_imported_keys( &self, - query: CommandGetImportedKeys, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetImportedKeys, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_imported_keys has no default implementation", + )) + } /// Get a FlightDataStream containing the data related to the cross reference. 
async fn do_get_cross_reference( &self, - query: CommandGetCrossReference, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetCrossReference, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_cross_reference has no default implementation", + )) + } /// Get a FlightDataStream containing the data related to the supported XDBC types. async fn do_get_xdbc_type_info( &self, - query: CommandGetXdbcTypeInfo, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetXdbcTypeInfo, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_xdbc_type_info has no default implementation", + )) + } // do_put /// Implementors may override to handle additional calls to do_put() async fn do_put_fallback( &self, - _request: Request>, + _request: Request, message: Any, ) -> Result::DoPutStream>, Status> { Err(Status::unimplemented(format!( @@ -278,30 +378,46 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static { /// Execute an update SQL statement. async fn do_put_statement_update( &self, - ticket: CommandStatementUpdate, - request: Request>, - ) -> Result; + _ticket: CommandStatementUpdate, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_put_statement_update has no default implementation", + )) + } /// Bind parameters to given prepared statement. async fn do_put_prepared_statement_query( &self, - query: CommandPreparedStatementQuery, - request: Request>, - ) -> Result::DoPutStream>, Status>; + _query: CommandPreparedStatementQuery, + _request: Request, + ) -> Result::DoPutStream>, Status> { + Err(Status::unimplemented( + "do_put_prepared_statement_query has no default implementation", + )) + } /// Execute an update SQL prepared statement. async fn do_put_prepared_statement_update( &self, - query: CommandPreparedStatementUpdate, - request: Request>, - ) -> Result; + _query: CommandPreparedStatementUpdate, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_put_prepared_statement_update has no default implementation", + )) + } /// Execute a substrait plan async fn do_put_substrait_plan( &self, - query: CommandStatementSubstraitPlan, - request: Request>, - ) -> Result; + _query: CommandStatementSubstraitPlan, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_put_substrait_plan has no default implementation", + )) + } // do_action @@ -324,58 +440,90 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static { /// Create a prepared statement from given SQL statement. async fn do_action_create_prepared_statement( &self, - query: ActionCreatePreparedStatementRequest, - request: Request, - ) -> Result; + _query: ActionCreatePreparedStatementRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_action_create_prepared_statement has no default implementation", + )) + } /// Close a prepared statement. async fn do_action_close_prepared_statement( &self, - query: ActionClosePreparedStatementRequest, - request: Request, - ) -> Result<(), Status>; + _query: ActionClosePreparedStatementRequest, + _request: Request, + ) -> Result<(), Status> { + Err(Status::unimplemented( + "do_action_close_prepared_statement has no default implementation", + )) + } /// Create a prepared substrait plan. 
async fn do_action_create_prepared_substrait_plan( &self, - query: ActionCreatePreparedSubstraitPlanRequest, - request: Request, - ) -> Result; + _query: ActionCreatePreparedSubstraitPlanRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_action_create_prepared_substrait_plan has no default implementation", + )) + } /// Begin a transaction async fn do_action_begin_transaction( &self, - query: ActionBeginTransactionRequest, - request: Request, - ) -> Result; + _query: ActionBeginTransactionRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_action_begin_transaction has no default implementation", + )) + } /// End a transaction async fn do_action_end_transaction( &self, - query: ActionEndTransactionRequest, - request: Request, - ) -> Result<(), Status>; + _query: ActionEndTransactionRequest, + _request: Request, + ) -> Result<(), Status> { + Err(Status::unimplemented( + "do_action_end_transaction has no default implementation", + )) + } /// Begin a savepoint async fn do_action_begin_savepoint( &self, - query: ActionBeginSavepointRequest, - request: Request, - ) -> Result; + _query: ActionBeginSavepointRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_action_begin_savepoint has no default implementation", + )) + } /// End a savepoint async fn do_action_end_savepoint( &self, - query: ActionEndSavepointRequest, - request: Request, - ) -> Result<(), Status>; + _query: ActionEndSavepointRequest, + _request: Request, + ) -> Result<(), Status> { + Err(Status::unimplemented( + "do_action_end_savepoint has no default implementation", + )) + } /// Cancel a query async fn do_action_cancel_query( &self, - query: ActionCancelQueryRequest, - request: Request, - ) -> Result; + _query: ActionCancelQueryRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_action_cancel_query has no default implementation", + )) + } /// do_exchange @@ -540,9 +688,17 @@ where async fn do_put( &self, - mut request: Request>, + request: Request>, ) -> Result, Status> { - let cmd = request.get_mut().message().await?.unwrap(); + // See issue #4658: https://github.com/apache/arrow-rs/issues/4658 + // To dispatch to the correct `do_put` method, we cannot discard the first message, + // as it may contain the Arrow schema, which the `do_put` handler may need. + // To allow the first message to be reused by the `do_put` handler, + // we wrap this stream in a `Peekable` one, which allows us to peek at + // the first message without discarding it. + let mut request = request.map(PeekableFlightDataStream::new); + let cmd = Pin::new(request.get_mut()).peek().await.unwrap().clone()?; + let message = Any::decode(&*cmd.flight_descriptor.unwrap().cmd) .map_err(decode_error_to_status)?; match Command::try_from(message).map_err(arrow_error_to_status)? { @@ -809,3 +965,89 @@ fn decode_error_to_status(err: prost::DecodeError) -> Status { fn arrow_error_to_status(err: arrow_schema::ArrowError) -> Status { Status::internal(format!("{err:?}")) } + +/// A wrapper around [`Streaming`] that allows "peeking" at the +/// message at the front of the stream without consuming it. +/// This is needed because sometimes the first message in the stream will contain +/// a [`FlightDescriptor`] in addition to potentially any data, and the dispatch logic +/// must inspect this information. 
+/// +/// # Example +/// +/// [`PeekableFlightDataStream::peek`] can be used to peek at the first message without +/// discarding it; otherwise, `PeekableFlightDataStream` can be used as a regular stream. +/// See the following example: +/// +/// ```no_run +/// use arrow_array::RecordBatch; +/// use arrow_flight::decode::FlightRecordBatchStream; +/// use arrow_flight::FlightDescriptor; +/// use arrow_flight::error::FlightError; +/// use arrow_flight::sql::server::PeekableFlightDataStream; +/// use tonic::{Request, Status}; +/// use futures::TryStreamExt; +/// +/// #[tokio::main] +/// async fn main() -> Result<(), Status> { +/// let request: Request = todo!(); +/// let stream: PeekableFlightDataStream = request.into_inner(); +/// +/// // The first message contains the flight descriptor and the schema. +/// // Read the flight descriptor without discarding the schema: +/// let flight_descriptor: FlightDescriptor = stream +/// .peek() +/// .await +/// .cloned() +/// .transpose()? +/// .and_then(|data| data.flight_descriptor) +/// .expect("first message should contain flight descriptor"); +/// +/// // Pass the stream through a decoder +/// let batches: Vec = FlightRecordBatchStream::new_from_flight_data( +/// request.into_inner().map_err(|e| e.into()), +/// ) +/// .try_collect() +/// .await?; +/// } +/// ``` +pub struct PeekableFlightDataStream { + inner: Peekable>, +} + +impl PeekableFlightDataStream { + fn new(stream: Streaming) -> Self { + Self { + inner: stream.peekable(), + } + } + + /// Convert this stream into a `Streaming`. + /// Any messages observed through [`Self::peek`] will be lost + /// after the conversion. + pub fn into_inner(self) -> Streaming { + self.inner.into_inner() + } + + /// Convert this stream into a `Peekable>`. + /// Preserves the state of the stream, so that calls to [`Self::peek`] + /// and [`Self::poll_next`] are the same. + pub fn into_peekable(self) -> Peekable> { + self.inner + } + + /// Peek at the head of this stream without advancing it. + pub async fn peek(&mut self) -> Option<&Result> { + Pin::new(&mut self.inner).peek().await + } +} + +impl Stream for PeekableFlightDataStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_next_unpin(cx) + } +} diff --git a/arrow-flight/src/trailers.rs b/arrow-flight/src/trailers.rs new file mode 100644 index 000000000000..d652542da779 --- /dev/null +++ b/arrow-flight/src/trailers.rs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::{ + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, +}; + +use futures::{ready, FutureExt, Stream, StreamExt}; +use tonic::{metadata::MetadataMap, Status, Streaming}; + +/// Extract [`LazyTrailers`] from [`Streaming`] [tonic] response. +/// +/// Note that [`LazyTrailers`] has inner mutability and will only hold actual data after [`ExtractTrailersStream`] is +/// fully consumed (dropping it is not required though). +pub fn extract_lazy_trailers( + s: Streaming, +) -> (ExtractTrailersStream, LazyTrailers) { + let trailers: SharedTrailers = Default::default(); + let stream = ExtractTrailersStream { + inner: s, + trailers: Arc::clone(&trailers), + }; + let lazy_trailers = LazyTrailers { trailers }; + (stream, lazy_trailers) +} + +type SharedTrailers = Arc>>; + +/// [Stream] that stores the gRPC trailers into [`LazyTrailers`]. +/// +/// See [`extract_lazy_trailers`] for construction. +#[derive(Debug)] +pub struct ExtractTrailersStream { + inner: Streaming, + trailers: SharedTrailers, +} + +impl Stream for ExtractTrailersStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let res = ready!(self.inner.poll_next_unpin(cx)); + + if res.is_none() { + // stream exhausted => trailers should be available + if let Some(trailers) = self + .inner + .trailers() + .now_or_never() + .and_then(|res| res.ok()) + .flatten() + { + *self.trailers.lock().expect("poisoned") = Some(trailers); + } + } + + Poll::Ready(res) + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +/// gRPC trailers that are extracted by [`ExtractTrailersStream`]. +/// +/// See [`extract_lazy_trailers`] for construction. +#[derive(Debug)] +pub struct LazyTrailers { + trailers: SharedTrailers, +} + +impl LazyTrailers { + /// gRPC trailers that are known at the end of a stream.
+ pub fn get(&self) -> Option { + self.trailers.lock().expect("poisoned").clone() + } +} diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs index ccf1e73866e1..145626b6608f 100644 --- a/arrow-flight/src/utils.rs +++ b/arrow-flight/src/utils.rs @@ -147,11 +147,11 @@ pub fn ipc_message_from_arrow_schema( /// Convert `RecordBatch`es to wire protocol `FlightData`s pub fn batches_to_flight_data( - schema: Schema, + schema: &Schema, batches: Vec, ) -> Result, ArrowError> { let options = IpcWriteOptions::default(); - let schema_flight_data: FlightData = SchemaAsIpc::new(&schema, &options).into(); + let schema_flight_data: FlightData = SchemaAsIpc::new(schema, &options).into(); let mut dictionaries = vec![]; let mut flight_data = vec![]; @@ -166,8 +166,8 @@ pub fn batches_to_flight_data( flight_data.push(encoded_batch.into()); } let mut stream = vec![schema_flight_data]; - stream.extend(dictionaries.into_iter()); - stream.extend(flight_data.into_iter()); + stream.extend(dictionaries); + stream.extend(flight_data); let flight_data: Vec<_> = stream.into_iter().collect(); Ok(flight_data) } diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs index 8ea542879a27..1b9891e121fa 100644 --- a/arrow-flight/tests/client.rs +++ b/arrow-flight/tests/client.rs @@ -19,6 +19,7 @@ mod common { pub mod server; + pub mod trailers_layer; } use arrow_array::{RecordBatch, UInt64Array}; use arrow_flight::{ @@ -28,7 +29,7 @@ use arrow_flight::{ }; use arrow_schema::{DataType, Field, Schema}; use bytes::Bytes; -use common::server::TestFlightServer; +use common::{server::TestFlightServer, trailers_layer::TrailersLayer}; use futures::{Future, StreamExt, TryStreamExt}; use tokio::{net::TcpListener, task::JoinHandle}; use tonic::{ @@ -158,18 +159,42 @@ async fn test_do_get() { let response = vec![Ok(batch.clone())]; test_server.set_do_get_response(response); - let response_stream = client + let mut response_stream = client .do_get(ticket.clone()) .await .expect("error making request"); + assert_eq!( + response_stream + .headers() + .get("test-resp-header") + .expect("header exists") + .to_str() + .unwrap(), + "some_val", + ); + + // trailers are not available before stream exhaustion + assert!(response_stream.trailers().is_none()); + let expected_response = vec![batch]; - let response: Vec<_> = response_stream + let response: Vec<_> = (&mut response_stream) .try_collect() .await .expect("Error streaming data"); - assert_eq!(response, expected_response); + + assert_eq!( + response_stream + .trailers() + .expect("stream exhausted") + .get("test-trailer") + .expect("trailer exists") + .to_str() + .unwrap(), + "trailer_val", + ); + assert_eq!(test_server.take_do_get_request(), Some(ticket)); ensure_metadata(&client, &test_server); }) @@ -932,6 +957,7 @@ impl TestFixture { let serve_future = tonic::transport::Server::builder() .timeout(server_timeout) + .layer(TrailersLayer) .add_service(test_server.service()) .serve_with_incoming_shutdown( tokio_stream::wrappers::TcpListenerStream::new(listener), diff --git a/arrow-flight/tests/common/server.rs b/arrow-flight/tests/common/server.rs index b87019d632c4..c575d12bbf52 100644 --- a/arrow-flight/tests/common/server.rs +++ b/arrow-flight/tests/common/server.rs @@ -359,7 +359,11 @@ impl FlightService for TestFlightServer { .build(batch_stream) .map_err(Into::into); - Ok(Response::new(stream.boxed())) + let mut resp = Response::new(stream.boxed()); + resp.metadata_mut() + .insert("test-resp-header", "some_val".parse().unwrap()); + + Ok(resp) } async fn 
do_put( diff --git a/arrow-flight/tests/common/trailers_layer.rs b/arrow-flight/tests/common/trailers_layer.rs new file mode 100644 index 000000000000..9e6be0dcf0da --- /dev/null +++ b/arrow-flight/tests/common/trailers_layer.rs @@ -0,0 +1,138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::ready; +use http::{HeaderValue, Request, Response}; +use http_body::SizeHint; +use pin_project_lite::pin_project; +use tower::{Layer, Service}; + +#[derive(Debug, Copy, Clone, Default)] +pub struct TrailersLayer; + +impl Layer for TrailersLayer { + type Service = TrailersService; + + fn layer(&self, service: S) -> Self::Service { + TrailersService { service } + } +} + +#[derive(Debug, Clone)] +pub struct TrailersService { + service: S, +} + +impl Service> for TrailersService +where + S: Service, Response = Response>, + ResBody: http_body::Body, +{ + type Response = Response>; + type Error = S::Error; + type Future = WrappedFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, request: Request) -> Self::Future { + WrappedFuture { + inner: self.service.call(request), + } + } +} + +pin_project! { + #[derive(Debug)] + pub struct WrappedFuture { + #[pin] + inner: F, + } +} + +impl Future for WrappedFuture +where + F: Future, Error>>, + ResBody: http_body::Body, +{ + type Output = Result>, Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let result: Result, Error> = + ready!(self.as_mut().project().inner.poll(cx)); + + match result { + Ok(response) => { + Poll::Ready(Ok(response.map(|body| WrappedBody { inner: body }))) + } + Err(e) => Poll::Ready(Err(e)), + } + } +} + +pin_project! 
{ + #[derive(Debug)] + pub struct WrappedBody { + #[pin] + inner: B, + } +} + +impl http_body::Body for WrappedBody { + type Data = B::Data; + type Error = B::Error; + + fn poll_data( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + self.as_mut().project().inner.poll_data(cx) + } + + fn poll_trailers( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + let result: Result, Self::Error> = + ready!(self.as_mut().project().inner.poll_trailers(cx)); + + let mut trailers = http::header::HeaderMap::new(); + trailers.insert("test-trailer", HeaderValue::from_static("trailer_val")); + + match result { + Ok(Some(mut existing)) => { + existing.extend(trailers.iter().map(|(k, v)| (k.clone(), v.clone()))); + Poll::Ready(Ok(Some(existing))) + } + Ok(None) => Poll::Ready(Ok(Some(trailers))), + Err(e) => Poll::Ready(Err(e)), + } + } + + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + + fn size_hint(&self) -> SizeHint { + self.inner.size_hint() + } +} diff --git a/arrow-flight/tests/encode_decode.rs b/arrow-flight/tests/encode_decode.rs index 4f1a8e667ffc..71bcf4e0521a 100644 --- a/arrow-flight/tests/encode_decode.rs +++ b/arrow-flight/tests/encode_decode.rs @@ -386,7 +386,7 @@ async fn test_mismatched_schema_message() { do_test( make_primitive_batch(5), make_dictionary_batch(3), - "Error decoding ipc RecordBatch: Io error: Invalid data for schema", + "Error decoding ipc RecordBatch: Schema error: Invalid data for schema", ) .await; diff --git a/arrow-flight/tests/flight_sql_client_cli.rs b/arrow-flight/tests/flight_sql_client_cli.rs index c4ae9280c898..221e776218c3 100644 --- a/arrow-flight/tests/flight_sql_client_cli.rs +++ b/arrow-flight/tests/flight_sql_client_cli.rs @@ -19,11 +19,13 @@ use std::{net::SocketAddr, pin::Pin, sync::Arc, time::Duration}; use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray}; use arrow_flight::{ + decode::FlightRecordBatchStream, flight_service_server::{FlightService, FlightServiceServer}, sql::{ - server::FlightSqlService, ActionBeginSavepointRequest, - ActionBeginSavepointResult, ActionBeginTransactionRequest, - ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult, + server::{FlightSqlService, PeekableFlightDataStream}, + ActionBeginSavepointRequest, ActionBeginSavepointResult, + ActionBeginTransactionRequest, ActionBeginTransactionResult, + ActionCancelQueryRequest, ActionCancelQueryResult, ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest, ActionEndSavepointRequest, ActionEndTransactionRequest, Any, CommandGetCatalogs, @@ -36,18 +38,20 @@ use arrow_flight::{ }, utils::batches_to_flight_data, Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, - HandshakeResponse, Ticket, + HandshakeResponse, IpcMessage, PutResult, SchemaAsIpc, Ticket, }; +use arrow_ipc::writer::IpcWriteOptions; use arrow_schema::{ArrowError, DataType, Field, Schema}; use assert_cmd::Command; -use futures::Stream; +use bytes::Bytes; +use futures::{Stream, StreamExt, TryStreamExt}; use prost::Message; use tokio::{net::TcpListener, task::JoinHandle}; use tonic::{Request, Response, Status, Streaming}; const QUERY: &str = "SELECT * FROM table;"; -#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +#[tokio::test] async fn test_simple() { let test_server = FlightSqlServiceImpl {}; let fixture = TestFixture::new(&test_server).await; @@ -63,6 +67,7 @@ async fn 
test_simple() { .arg(addr.ip().to_string()) .arg("--port") .arg(addr.port().to_string()) + .arg("statement-query") .arg(QUERY) .assert() .success() @@ -87,10 +92,56 @@ async fn test_simple() { ); } +const PREPARED_QUERY: &str = "SELECT * FROM table WHERE field = $1"; +const PREPARED_STATEMENT_HANDLE: &str = "prepared_statement_handle"; + +#[tokio::test] +async fn test_do_put_prepared_statement() { + let test_server = FlightSqlServiceImpl {}; + let fixture = TestFixture::new(&test_server).await; + let addr = fixture.addr; + + let stdout = tokio::task::spawn_blocking(move || { + Command::cargo_bin("flight_sql_client") + .unwrap() + .env_clear() + .env("RUST_BACKTRACE", "1") + .env("RUST_LOG", "warn") + .arg("--host") + .arg(addr.ip().to_string()) + .arg("--port") + .arg(addr.port().to_string()) + .arg("prepared-statement-query") + .arg(PREPARED_QUERY) + .args(["-p", "$1=string"]) + .args(["-p", "$2=64"]) + .assert() + .success() + .get_output() + .stdout + .clone() + }) + .await + .unwrap(); + + fixture.shutdown_and_wait().await; + + assert_eq!( + std::str::from_utf8(&stdout).unwrap().trim(), + "+--------------+-----------+\ + \n| field_string | field_int |\ + \n+--------------+-----------+\ + \n| Hello | 42 |\ + \n| lovely | |\ + \n| FlightSQL! | 1337 |\ + \n+--------------+-----------+", + ); +} + /// All tests must complete within this many seconds or else the test server is shutdown const DEFAULT_TIMEOUT_SECONDS: u64 = 30; -#[derive(Clone)] +#[derive(Clone, Default)] pub struct FlightSqlServiceImpl {} impl FlightSqlServiceImpl { @@ -116,6 +167,59 @@ impl FlightSqlServiceImpl { ]; RecordBatch::try_new(Arc::new(schema), cols) } + + fn create_fake_prepared_stmt( + ) -> Result { + let handle = PREPARED_STATEMENT_HANDLE.to_string(); + let schema = Schema::new(vec![ + Field::new("field_string", DataType::Utf8, false), + Field::new("field_int", DataType::Int64, true), + ]); + + let parameter_schema = Schema::new(vec![ + Field::new("$1", DataType::Utf8, false), + Field::new("$2", DataType::Int64, true), + ]); + + Ok(ActionCreatePreparedStatementResult { + prepared_statement_handle: handle.into(), + dataset_schema: serialize_schema(&schema)?, + parameter_schema: serialize_schema(¶meter_schema)?, + }) + } + + fn fake_flight_info(&self) -> Result { + let batch = Self::fake_result()?; + + Ok(FlightInfo::new() + .try_with_schema(&batch.schema()) + .expect("encoding schema") + .with_endpoint( + FlightEndpoint::new().with_ticket(Ticket::new( + FetchResults { + handle: String::from("part_1"), + } + .as_any() + .encode_to_vec(), + )), + ) + .with_endpoint( + FlightEndpoint::new().with_ticket(Ticket::new( + FetchResults { + handle: String::from("part_2"), + } + .as_any() + .encode_to_vec(), + )), + ) + .with_total_records(batch.num_rows() as i64) + .with_total_bytes(batch.get_array_memory_size() as i64) + .with_ordered(false)) + } +} + +fn serialize_schema(schema: &Schema) -> Result { + Ok(IpcMessage::try_from(SchemaAsIpc::new(schema, &IpcWriteOptions::default()))?.0) } #[tonic::async_trait] @@ -144,9 +248,9 @@ impl FlightSqlService for FlightSqlServiceImpl { "part_2" => batch.slice(2, 1), ticket => panic!("Invalid ticket: {ticket:?}"), }; - let schema = (*batch.schema()).clone(); + let schema = batch.schema(); let batches = vec![batch]; - let flight_data = batches_to_flight_data(schema, batches) + let flight_data = batches_to_flight_data(schema.as_ref(), batches) .unwrap() .into_iter() .map(Ok); @@ -164,45 +268,21 @@ impl FlightSqlService for FlightSqlServiceImpl { ) -> Result, Status> { 
assert_eq!(query.query, QUERY); - let batch = Self::fake_result().unwrap(); - - let info = FlightInfo::new() - .try_with_schema(&batch.schema()) - .expect("encoding schema") - .with_endpoint( - FlightEndpoint::new().with_ticket(Ticket::new( - FetchResults { - handle: String::from("part_1"), - } - .as_any() - .encode_to_vec(), - )), - ) - .with_endpoint( - FlightEndpoint::new().with_ticket(Ticket::new( - FetchResults { - handle: String::from("part_2"), - } - .as_any() - .encode_to_vec(), - )), - ) - .with_total_records(batch.num_rows() as i64) - .with_total_bytes(batch.get_array_memory_size() as i64) - .with_ordered(false); - - let resp = Response::new(info); + let resp = Response::new(self.fake_flight_info().unwrap()); Ok(resp) } async fn get_flight_info_prepared_statement( &self, - _cmd: CommandPreparedStatementQuery, + cmd: CommandPreparedStatementQuery, _request: Request, ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_prepared_statement not implemented", - )) + assert_eq!( + cmd.prepared_statement_handle, + PREPARED_STATEMENT_HANDLE.as_bytes() + ); + let resp = Response::new(self.fake_flight_info().unwrap()); + Ok(resp) } async fn get_flight_info_substrait_plan( @@ -426,7 +506,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_statement_update( &self, _ticket: CommandStatementUpdate, - _request: Request>, + _request: Request, ) -> Result { Err(Status::unimplemented( "do_put_statement_update not implemented", @@ -436,7 +516,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_substrait_plan( &self, _ticket: CommandStatementSubstraitPlan, - _request: Request>, + _request: Request, ) -> Result { Err(Status::unimplemented( "do_put_substrait_plan not implemented", @@ -446,17 +526,36 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_prepared_statement_query( &self, _query: CommandPreparedStatementQuery, - _request: Request>, + request: Request, ) -> Result::DoPutStream>, Status> { - Err(Status::unimplemented( - "do_put_prepared_statement_query not implemented", + // just make sure decoding the parameters works + let parameters = FlightRecordBatchStream::new_from_flight_data( + request.into_inner().map_err(|e| e.into()), + ) + .try_collect::>() + .await?; + + for (left, right) in parameters[0].schema().all_fields().iter().zip(vec![ + Field::new("$1", DataType::Utf8, false), + Field::new("$2", DataType::Int64, true), + ]) { + if left.name() != right.name() || left.data_type() != right.data_type() { + return Err(Status::invalid_argument(format!( + "Parameters did not match parameter schema\ngot {}", + parameters[0].schema(), + ))); + } + } + + Ok(Response::new( + futures::stream::once(async { Ok(PutResult::default()) }).boxed(), )) } async fn do_put_prepared_statement_update( &self, _query: CommandPreparedStatementUpdate, - _request: Request>, + _request: Request, ) -> Result { Err(Status::unimplemented( "do_put_prepared_statement_update not implemented", @@ -468,9 +567,8 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: ActionCreatePreparedStatementRequest, _request: Request, ) -> Result { - Err(Status::unimplemented( - "do_action_create_prepared_statement not implemented", - )) + Self::create_fake_prepared_stmt() + .map_err(|e| Status::internal(format!("Unable to serialize schema: {e}"))) } async fn do_action_close_prepared_statement( diff --git a/arrow-integration-test/src/lib.rs b/arrow-integration-test/src/lib.rs index 04bbcf3f6f23..07b69bffd07d 100644 --- a/arrow-integration-test/src/lib.rs +++ 
b/arrow-integration-test/src/lib.rs @@ -183,7 +183,8 @@ impl ArrowJson { return Ok(false); } } - _ => return Ok(false), + Some(Err(e)) => return Err(e), + None => return Ok(false), } } diff --git a/arrow-integration-testing/Cargo.toml b/arrow-integration-testing/Cargo.toml index 7f78cf50a9d7..86c2cb27d297 100644 --- a/arrow-integration-testing/Cargo.toml +++ b/arrow-integration-testing/Cargo.toml @@ -39,11 +39,11 @@ async-trait = { version = "0.1.41", default-features = false } clap = { version = "4", default-features = false, features = ["std", "derive", "help", "error-context", "usage"] } futures = { version = "0.3", default-features = false } hex = { version = "0.4", default-features = false, features = ["std"] } -prost = { version = "0.11", default-features = false } +prost = { version = "0.12", default-features = false } serde = { version = "1.0", default-features = false, features = ["rc", "derive"] } serde_json = { version = "1.0", default-features = false, features = ["std"] } tokio = { version = "1.0", default-features = false } -tonic = { version = "0.9", default-features = false } +tonic = { version = "0.10", default-features = false } tracing-subscriber = { version = "0.3.1", default-features = false, features = ["fmt"], optional = true } num = { version = "0.4", default-features = false, features = ["std"] } flate2 = { version = "1", default-features = false, features = ["rust_backend"] } diff --git a/arrow-integration-testing/src/bin/arrow-json-integration-test.rs b/arrow-integration-testing/src/bin/arrow-json-integration-test.rs index 2c36e8d9b8ae..db5df8b58a6f 100644 --- a/arrow-integration-testing/src/bin/arrow-json-integration-test.rs +++ b/arrow-integration-testing/src/bin/arrow-json-integration-test.rs @@ -124,7 +124,7 @@ fn canonicalize_schema(schema: &Schema) -> Schema { let key_field = Arc::new(Field::new( "key", first_field.data_type().clone(), - first_field.is_nullable(), + false, )); let second_field = fields.get(1).unwrap(); let value_field = Arc::new(Field::new( @@ -135,8 +135,7 @@ fn canonicalize_schema(schema: &Schema) -> Schema { let fields = Fields::from([key_field, value_field]); let struct_type = DataType::Struct(fields); - let child_field = - Field::new("entries", struct_type, child_field.is_nullable()); + let child_field = Field::new("entries", struct_type, false); Arc::new(Field::new( field.name().as_str(), diff --git a/arrow-integration-testing/tests/ipc_reader.rs b/arrow-integration-testing/tests/ipc_reader.rs index 9205f4318393..696ab6e6053a 100644 --- a/arrow-integration-testing/tests/ipc_reader.rs +++ b/arrow-integration-testing/tests/ipc_reader.rs @@ -27,7 +27,7 @@ use std::fs::File; fn read_0_1_4() { let testdata = arrow_test_data(); let version = "0.14.1"; - let paths = vec![ + let paths = [ "generated_interval", "generated_datetime", "generated_dictionary", @@ -48,7 +48,7 @@ fn read_0_1_4() { fn read_0_1_7() { let testdata = arrow_test_data(); let version = "0.17.1"; - let paths = vec!["generated_union"]; + let paths = ["generated_union"]; paths.iter().for_each(|path| { verify_arrow_file(&testdata, version, path); verify_arrow_stream(&testdata, version, path); @@ -76,7 +76,7 @@ fn read_1_0_0_bigendian_dictionary_should_panic() { #[test] fn read_1_0_0_bigendian() { let testdata = arrow_test_data(); - let paths = vec![ + let paths = [ "generated_interval", "generated_datetime", "generated_map", @@ -145,7 +145,7 @@ fn read_2_0_0_compression() { let version = "2.0.0-compression"; // the test is repetitive, thus we can read all supported files at 
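For reference, the canonicalized map layout that `canonicalize_schema` now produces — a non-nullable `entries` struct whose `key` child is also non-nullable — can be written out directly. This is only a sketch mirroring the change, with a Utf8 key assumed for illustration.

```rust
use std::sync::Arc;

use arrow_schema::{DataType, Field, Fields};

/// Sketch of the canonical map type: neither `entries` nor `key` is nullable.
fn canonical_map(value_type: DataType) -> DataType {
    let entries = Fields::from(vec![
        Field::new("key", DataType::Utf8, false),
        Field::new("value", value_type, true),
    ]);
    DataType::Map(
        Arc::new(Field::new("entries", DataType::Struct(entries), false)),
        false, // keys_sorted
    )
}
```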
once - let paths = vec!["generated_lz4", "generated_zstd"]; + let paths = ["generated_lz4", "generated_zstd"]; paths.iter().for_each(|path| { verify_arrow_file(&testdata, version, path); verify_arrow_stream(&testdata, version, path); diff --git a/arrow-integration-testing/tests/ipc_writer.rs b/arrow-integration-testing/tests/ipc_writer.rs index 40f356b1d442..11707d935540 100644 --- a/arrow-integration-testing/tests/ipc_writer.rs +++ b/arrow-integration-testing/tests/ipc_writer.rs @@ -27,7 +27,7 @@ use std::io::Seek; fn write_0_1_4() { let testdata = arrow_test_data(); let version = "0.14.1"; - let paths = vec![ + let paths = [ "generated_interval", "generated_datetime", "generated_dictionary", @@ -48,7 +48,7 @@ fn write_0_1_4() { fn write_0_1_7() { let testdata = arrow_test_data(); let version = "0.17.1"; - let paths = vec!["generated_union"]; + let paths = ["generated_union"]; paths.iter().for_each(|path| { roundtrip_arrow_file(&testdata, version, path); roundtrip_arrow_stream(&testdata, version, path); @@ -59,7 +59,7 @@ fn write_0_1_7() { fn write_1_0_0_littleendian() { let testdata = arrow_test_data(); let version = "1.0.0-littleendian"; - let paths = vec![ + let paths = [ "generated_datetime", "generated_custom_metadata", "generated_decimal", @@ -94,10 +94,10 @@ fn write_1_0_0_littleendian() { fn write_2_0_0_compression() { let testdata = arrow_test_data(); let version = "2.0.0-compression"; - let paths = vec!["generated_lz4", "generated_zstd"]; + let paths = ["generated_lz4", "generated_zstd"]; // writer options for each compression type - let all_options = vec![ + let all_options = [ IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5) .unwrap() .try_with_compression(Some(ipc::CompressionType::LZ4_FRAME)) @@ -187,11 +187,12 @@ fn roundtrip_arrow_file_with_options( let rewrite_reader = FileReader::try_new(&tempfile, None).unwrap(); // Compare to original reader - reader.into_iter().zip(rewrite_reader.into_iter()).for_each( - |(batch1, batch2)| { + reader + .into_iter() + .zip(rewrite_reader) + .for_each(|(batch1, batch2)| { assert_eq!(batch1.unwrap(), batch2.unwrap()); - }, - ); + }); } } @@ -264,10 +265,11 @@ fn roundtrip_arrow_stream_with_options( let rewrite_reader = StreamReader::try_new(&tempfile, None).unwrap(); // Compare to original reader - reader.into_iter().zip(rewrite_reader.into_iter()).for_each( - |(batch1, batch2)| { + reader + .into_iter() + .zip(rewrite_reader) + .for_each(|(batch1, batch2)| { assert_eq!(batch1.unwrap(), batch2.unwrap()); - }, - ); + }); } } diff --git a/arrow-ipc/Cargo.toml b/arrow-ipc/Cargo.toml index a03f53d6641c..83ad044d25e7 100644 --- a/arrow-ipc/Cargo.toml +++ b/arrow-ipc/Cargo.toml @@ -40,8 +40,12 @@ arrow-cast = { workspace = true } arrow-data = { workspace = true } arrow-schema = { workspace = true } flatbuffers = { version = "23.1.21", default-features = false } -lz4 = { version = "1.23", default-features = false, optional = true } -zstd = { version = "0.12.0", default-features = false, optional = true } +lz4_flex = { version = "0.11", default-features = false, features = ["std", "frame"], optional = true } +zstd = { version = "0.13.0", default-features = false, optional = true } + +[features] +default = [] +lz4 = ["lz4_flex"] [dev-dependencies] tempfile = "3.3" diff --git a/arrow-ipc/src/compression.rs b/arrow-ipc/src/compression.rs index db05e9a6a6c6..fafc2c5c9b6d 100644 --- a/arrow-ipc/src/compression.rs +++ b/arrow-ipc/src/compression.rs @@ -103,13 +103,15 @@ impl CompressionCodec { } else if decompressed_length == 
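The compression options exercised by these round-trip tests follow the usual `IpcWriteOptions` pattern; a minimal sketch of writing a compressed IPC file outside the test harness might look like the following, assuming the relevant compression feature (here `zstd`) is enabled.

```rust
use std::fs::File;
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch};
use arrow_ipc::writer::{FileWriter, IpcWriteOptions};
use arrow_ipc::{CompressionType, MetadataVersion};
use arrow_schema::ArrowError;

fn write_compressed(path: &str) -> Result<(), ArrowError> {
    let batch = RecordBatch::try_from_iter([(
        "a",
        Arc::new(Int32Array::from(vec![1, 2, 3])) as _,
    )])?;
    // 8-byte alignment, non-legacy format, metadata V5, zstd-compressed buffers.
    let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5)?
        .try_with_compression(Some(CompressionType::ZSTD))?;
    let mut writer =
        FileWriter::try_new_with_options(File::create(path)?, batch.schema().as_ref(), options)?;
    writer.write(&batch)?;
    writer.finish()
}
```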
LENGTH_NO_COMPRESSED_DATA { // no compression input.slice(LENGTH_OF_PREFIX_DATA as usize) - } else { + } else if let Ok(decompressed_length) = usize::try_from(decompressed_length) { // decompress data using the codec - let mut uncompressed_buffer = - Vec::with_capacity(decompressed_length as usize); let input_data = &input[(LENGTH_OF_PREFIX_DATA as usize)..]; - self.decompress(input_data, &mut uncompressed_buffer)?; - Buffer::from(uncompressed_buffer) + self.decompress(input_data, decompressed_length as _)? + .into() + } else { + return Err(ArrowError::IpcError(format!( + "Invalid uncompressed length: {decompressed_length}" + ))); }; Ok(buffer) } @@ -128,21 +130,30 @@ impl CompressionCodec { fn decompress( &self, input: &[u8], - output: &mut Vec, - ) -> Result { - match self { - CompressionCodec::Lz4Frame => decompress_lz4(input, output), - CompressionCodec::Zstd => decompress_zstd(input, output), + decompressed_size: usize, + ) -> Result, ArrowError> { + let ret = match self { + CompressionCodec::Lz4Frame => decompress_lz4(input, decompressed_size)?, + CompressionCodec::Zstd => decompress_zstd(input, decompressed_size)?, + }; + if ret.len() != decompressed_size { + return Err(ArrowError::IpcError(format!( + "Expected compressed length of {decompressed_size} got {}", + ret.len() + ))); } + Ok(ret) } } #[cfg(feature = "lz4")] fn compress_lz4(input: &[u8], output: &mut Vec) -> Result<(), ArrowError> { use std::io::Write; - let mut encoder = lz4::EncoderBuilder::new().build(output)?; + let mut encoder = lz4_flex::frame::FrameEncoder::new(output); encoder.write_all(input)?; - encoder.finish().1?; + encoder + .finish() + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; Ok(()) } @@ -155,14 +166,19 @@ fn compress_lz4(_input: &[u8], _output: &mut Vec) -> Result<(), ArrowError> } #[cfg(feature = "lz4")] -fn decompress_lz4(input: &[u8], output: &mut Vec) -> Result { +fn decompress_lz4(input: &[u8], decompressed_size: usize) -> Result, ArrowError> { use std::io::Read; - Ok(lz4::Decoder::new(input)?.read_to_end(output)?) + let mut output = Vec::with_capacity(decompressed_size); + lz4_flex::frame::FrameDecoder::new(input).read_to_end(&mut output)?; + Ok(output) } #[cfg(not(feature = "lz4"))] #[allow(clippy::ptr_arg)] -fn decompress_lz4(_input: &[u8], _output: &mut Vec) -> Result { +fn decompress_lz4( + _input: &[u8], + _decompressed_size: usize, +) -> Result, ArrowError> { Err(ArrowError::InvalidArgumentError( "lz4 IPC decompression requires the lz4 feature".to_string(), )) @@ -186,14 +202,22 @@ fn compress_zstd(_input: &[u8], _output: &mut Vec) -> Result<(), ArrowError> } #[cfg(feature = "zstd")] -fn decompress_zstd(input: &[u8], output: &mut Vec) -> Result { +fn decompress_zstd( + input: &[u8], + decompressed_size: usize, +) -> Result, ArrowError> { use std::io::Read; - Ok(zstd::Decoder::new(input)?.read_to_end(output)?) 
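The move from the `lz4` bindings to the pure-Rust `lz4_flex` frame codec boils down to the round trip below — a standalone sketch, assuming `lz4_flex` with its `frame` feature, mirroring how `compress_lz4`/`decompress_lz4` now drive it.

```rust
use std::io::{Read, Write};

fn lz4_frame_roundtrip(input: &[u8]) -> std::io::Result<Vec<u8>> {
    // Compress into an in-memory LZ4 frame, as compress_lz4 does.
    let mut encoder = lz4_flex::frame::FrameEncoder::new(Vec::new());
    encoder.write_all(input)?;
    let compressed = encoder
        .finish()
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;

    // Decompress, pre-sizing the output as decompress_lz4 now does.
    let mut output = Vec::with_capacity(input.len());
    lz4_flex::frame::FrameDecoder::new(compressed.as_slice()).read_to_end(&mut output)?;
    Ok(output)
}
```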
+ let mut output = Vec::with_capacity(decompressed_size); + zstd::Decoder::with_buffer(input)?.read_to_end(&mut output)?; + Ok(output) } #[cfg(not(feature = "zstd"))] #[allow(clippy::ptr_arg)] -fn decompress_zstd(_input: &[u8], _output: &mut Vec) -> Result { +fn decompress_zstd( + _input: &[u8], + _decompressed_size: usize, +) -> Result, ArrowError> { Err(ArrowError::InvalidArgumentError( "zstd IPC decompression requires the zstd feature".to_string(), )) @@ -216,28 +240,26 @@ mod tests { #[test] #[cfg(feature = "lz4")] fn test_lz4_compression() { - let input_bytes = "hello lz4".as_bytes(); + let input_bytes = b"hello lz4"; let codec = super::CompressionCodec::Lz4Frame; let mut output_bytes: Vec = Vec::new(); codec.compress(input_bytes, &mut output_bytes).unwrap(); - let mut result_output_bytes: Vec = Vec::new(); - codec - .decompress(output_bytes.as_slice(), &mut result_output_bytes) + let result = codec + .decompress(output_bytes.as_slice(), input_bytes.len()) .unwrap(); - assert_eq!(input_bytes, result_output_bytes.as_slice()); + assert_eq!(input_bytes, result.as_slice()); } #[test] #[cfg(feature = "zstd")] fn test_zstd_compression() { - let input_bytes = "hello zstd".as_bytes(); + let input_bytes = b"hello zstd"; let codec = super::CompressionCodec::Zstd; let mut output_bytes: Vec = Vec::new(); codec.compress(input_bytes, &mut output_bytes).unwrap(); - let mut result_output_bytes: Vec = Vec::new(); - codec - .decompress(output_bytes.as_slice(), &mut result_output_bytes) + let result = codec + .decompress(output_bytes.as_slice(), input_bytes.len()) .unwrap(); - assert_eq!(input_bytes, result_output_bytes.as_slice()); + assert_eq!(input_bytes, result.as_slice()); } } diff --git a/arrow-ipc/src/convert.rs b/arrow-ipc/src/convert.rs index 07f716dea843..a78ccde6e169 100644 --- a/arrow-ipc/src/convert.rs +++ b/arrow-ipc/src/convert.rs @@ -150,12 +150,12 @@ pub fn try_schema_from_flatbuffer_bytes(bytes: &[u8]) -> Result( RunEndEncoded(run_ends, values) => { let run_ends_field = build_field(fbb, run_ends); let values_field = build_field(fbb, values); - let children = vec![run_ends_field, values_field]; + let children = [run_ends_field, values_field]; FBFieldType { type_type: crate::Type::RunEndEncoded, type_: crate::RunEndEncodedBuilder::new(fbb) diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index 0908d580d59a..75c91be21dde 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -20,7 +20,6 @@ //! The `FileReader` and `StreamReader` have similar interfaces, //! however the `FileReader` expects a reader that supports `Seek`ing -use arrow_buffer::i256; use flatbuffers::VectorIter; use std::collections::HashMap; use std::fmt; @@ -129,7 +128,7 @@ fn create_array(reader: &mut ArrayReader, field: &Field) -> Result Result Result Result { + _ if data_type.is_primitive() + || matches!(data_type, Boolean | FixedSizeBinary(_)) => + { // read 2 buffers: null buffer (optional) and data buffer ArrayData::builder(data_type.clone()) .len(length) .add_buffer(buffers[1].clone()) .null_bit_buffer(null_buffer) - .build()? - } - Interval(IntervalUnit::MonthDayNano) | Decimal128(_, _) => { - let buffer = get_aligned_buffer::(&buffers[1], length); - - // read 2 buffers: null buffer (optional) and data buffer - ArrayData::builder(data_type.clone()) - .len(length) - .add_buffer(buffer) - .null_bit_buffer(null_buffer) - .build()? 
- } - Decimal256(_, _) => { - let buffer = get_aligned_buffer::(&buffers[1], length); - - // read 2 buffers: null buffer (optional) and data buffer - ArrayData::builder(data_type.clone()) - .len(length) - .add_buffer(buffer) - .null_bit_buffer(null_buffer) - .build()? + .build_aligned()? } t => unreachable!("Data type {:?} either unsupported or not primitive", t), }; @@ -286,28 +248,10 @@ fn create_primitive_array( Ok(make_array(array_data)) } -/// Checks if given `Buffer` is properly aligned with `T`. -/// If not, copying the data and padded it for alignment. -fn get_aligned_buffer(buffer: &Buffer, length: usize) -> Buffer { - let ptr = buffer.as_ptr(); - let align_req = std::mem::align_of::(); - let align_offset = ptr.align_offset(align_req); - // The buffer is not aligned properly. The writer might use a smaller alignment - // e.g. 8 bytes, but on some platform (e.g. ARM) i128 requires 16 bytes alignment. - // We need to copy the buffer as fallback. - if align_offset != 0 { - let len_in_bytes = (length * std::mem::size_of::()).min(buffer.len()); - let slice = &buffer.as_slice()[0..len_in_bytes]; - Buffer::from_slice_ref(slice) - } else { - buffer.clone() - } -} - /// Reads the correct number of buffers based on list type and null_count, and creates a /// list array ref fn create_list_array( - field_node: &crate::FieldNode, + field_node: &FieldNode, data_type: &DataType, buffers: &[Buffer], child_array: ArrayRef, @@ -329,13 +273,13 @@ fn create_list_array( _ => unreachable!("Cannot create list or map array from {:?}", data_type), }; - Ok(make_array(builder.build()?)) + Ok(make_array(builder.build_aligned()?)) } /// Reads the correct number of buffers based on list type and null_count, and creates a /// list array ref fn create_dictionary_array( - field_node: &crate::FieldNode, + field_node: &FieldNode, data_type: &DataType, buffers: &[Buffer], value_array: ArrayRef, @@ -348,7 +292,7 @@ fn create_dictionary_array( .add_child_data(value_array.into_data()) .null_bit_buffer(null_buffer); - Ok(make_array(builder.build()?)) + Ok(make_array(builder.build_aligned()?)) } else { unreachable!("Cannot create dictionary array from {:?}", data_type) } @@ -381,7 +325,7 @@ impl<'a> ArrayReader<'a> { fn next_node(&mut self, field: &Field) -> Result<&'a FieldNode, ArrowError> { self.nodes.next().ok_or_else(|| { - ArrowError::IoError(format!( + ArrowError::SchemaError(format!( "Invalid data for schema. 
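The hand-rolled `get_aligned_buffer` is gone because `ArrayDataBuilder::build_aligned` takes over that responsibility, copying a buffer only when it does not meet the alignment the data type requires (if I read the change correctly). A rough usage sketch — illustrative only; the byte slice must hold at least `4 * len` bytes:

```rust
use arrow_buffer::Buffer;
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType};

fn int32_data_from_le_bytes(bytes: &[u8], len: usize) -> Result<ArrayData, ArrowError> {
    // Validates like `build`, but realigns (copies) the buffer first if needed.
    ArrayData::builder(DataType::Int32)
        .len(len)
        .add_buffer(Buffer::from_slice_ref(bytes))
        .build_aligned()
}
```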
{} refers to node not found in schema", field )) @@ -458,10 +402,10 @@ pub fn read_record_batch( metadata: &MetadataVersion, ) -> Result { let buffers = batch.buffers().ok_or_else(|| { - ArrowError::IoError("Unable to get buffers from IPC RecordBatch".to_string()) + ArrowError::IpcError("Unable to get buffers from IPC RecordBatch".to_string()) })?; let field_nodes = batch.nodes().ok_or_else(|| { - ArrowError::IoError("Unable to get field nodes from IPC RecordBatch".to_string()) + ArrowError::IpcError("Unable to get field nodes from IPC RecordBatch".to_string()) })?; let batch_compression = batch.compression(); let compression = batch_compression @@ -518,7 +462,7 @@ pub fn read_dictionary( metadata: &crate::MetadataVersion, ) -> Result<(), ArrowError> { if batch.isDelta() { - return Err(ArrowError::IoError( + return Err(ArrowError::InvalidArgumentError( "delta dictionary batches not supported".to_string(), )); } @@ -625,14 +569,14 @@ impl FileReader { let mut magic_buffer: [u8; 6] = [0; 6]; reader.read_exact(&mut magic_buffer)?; if magic_buffer != super::ARROW_MAGIC { - return Err(ArrowError::IoError( + return Err(ArrowError::ParseError( "Arrow file does not contain correct header".to_string(), )); } reader.seek(SeekFrom::End(-6))?; reader.read_exact(&mut magic_buffer)?; if magic_buffer != super::ARROW_MAGIC { - return Err(ArrowError::IoError( + return Err(ArrowError::ParseError( "Arrow file does not contain correct footer".to_string(), )); } @@ -648,11 +592,11 @@ impl FileReader { reader.read_exact(&mut footer_data)?; let footer = crate::root_as_footer(&footer_data[..]).map_err(|err| { - ArrowError::IoError(format!("Unable to get root as footer: {err:?}")) + ArrowError::ParseError(format!("Unable to get root as footer: {err:?}")) })?; let blocks = footer.recordBatches().ok_or_else(|| { - ArrowError::IoError( + ArrowError::ParseError( "Unable to get record batches from IPC Footer".to_string(), ) })?; @@ -689,7 +633,9 @@ impl FileReader { reader.read_exact(&mut block_data)?; let message = crate::root_as_message(&block_data[..]).map_err(|err| { - ArrowError::IoError(format!("Unable to get root as message: {err:?}")) + ArrowError::ParseError(format!( + "Unable to get root as message: {err:?}" + )) })?; match message.header_type() { @@ -713,7 +659,7 @@ impl FileReader { )?; } t => { - return Err(ArrowError::IoError(format!( + return Err(ArrowError::ParseError(format!( "Expecting DictionaryBatch in dictionary blocks, found {t:?}." 
))); } @@ -761,7 +707,7 @@ impl FileReader { /// Sets the current block to the index, allowing random reads pub fn set_index(&mut self, index: usize) -> Result<(), ArrowError> { if index >= self.total_blocks { - Err(ArrowError::IoError(format!( + Err(ArrowError::InvalidArgumentError(format!( "Cannot set batch to index {} from {} total batches", index, self.total_blocks ))) @@ -788,25 +734,25 @@ impl FileReader { let mut block_data = vec![0; meta_len as usize]; self.reader.read_exact(&mut block_data)?; let message = crate::root_as_message(&block_data[..]).map_err(|err| { - ArrowError::IoError(format!("Unable to get root as footer: {err:?}")) + ArrowError::ParseError(format!("Unable to get root as footer: {err:?}")) })?; // some old test data's footer metadata is not set, so we account for that if self.metadata_version != crate::MetadataVersion::V1 && message.version() != self.metadata_version { - return Err(ArrowError::IoError( + return Err(ArrowError::IpcError( "Could not read IPC message as metadata versions mismatch".to_string(), )); } match message.header_type() { - crate::MessageHeader::Schema => Err(ArrowError::IoError( + crate::MessageHeader::Schema => Err(ArrowError::IpcError( "Not expecting a schema when messages are read".to_string(), )), crate::MessageHeader::RecordBatch => { let batch = message.header_as_record_batch().ok_or_else(|| { - ArrowError::IoError( + ArrowError::IpcError( "Unable to read IPC message as record batch".to_string(), ) })?; @@ -830,7 +776,7 @@ impl FileReader { crate::MessageHeader::NONE => { Ok(None) } - t => Err(ArrowError::IoError(format!( + t => Err(ArrowError::InvalidArgumentError(format!( "Reading types other than record batches not yet supported, unable to read {t:?}" ))), } @@ -942,11 +888,11 @@ impl StreamReader { reader.read_exact(&mut meta_buffer)?; let message = crate::root_as_message(meta_buffer.as_slice()).map_err(|err| { - ArrowError::IoError(format!("Unable to get root as message: {err:?}")) + ArrowError::ParseError(format!("Unable to get root as message: {err:?}")) })?; // message header is a Schema, so read it let ipc_schema: crate::Schema = message.header_as_schema().ok_or_else(|| { - ArrowError::IoError("Unable to read IPC message as schema".to_string()) + ArrowError::ParseError("Unable to read IPC message as schema".to_string()) })?; let schema = crate::convert::fb_to_schema(ipc_schema); @@ -1021,16 +967,16 @@ impl StreamReader { let vecs = &meta_buffer.to_vec(); let message = crate::root_as_message(vecs).map_err(|err| { - ArrowError::IoError(format!("Unable to get root as message: {err:?}")) + ArrowError::ParseError(format!("Unable to get root as message: {err:?}")) })?; match message.header_type() { - crate::MessageHeader::Schema => Err(ArrowError::IoError( + crate::MessageHeader::Schema => Err(ArrowError::IpcError( "Not expecting a schema when messages are read".to_string(), )), crate::MessageHeader::RecordBatch => { let batch = message.header_as_record_batch().ok_or_else(|| { - ArrowError::IoError( + ArrowError::IpcError( "Unable to read IPC message as record batch".to_string(), ) })?; @@ -1042,7 +988,7 @@ impl StreamReader { } crate::MessageHeader::DictionaryBatch => { let batch = message.header_as_dictionary_batch().ok_or_else(|| { - ArrowError::IoError( + ArrowError::IpcError( "Unable to read IPC message as dictionary batch".to_string(), ) })?; @@ -1060,7 +1006,7 @@ impl StreamReader { crate::MessageHeader::NONE => { Ok(None) } - t => Err(ArrowError::IoError( + t => Err(ArrowError::InvalidArgumentError( format!("Reading types 
other than record batches not yet supported, unable to read {t:?} ") )), } @@ -1097,10 +1043,11 @@ impl RecordBatchReader for StreamReader { #[cfg(test)] mod tests { - use crate::writer::unslice_run_array; + use crate::writer::{unslice_run_array, DictionaryTracker, IpcDataGenerator}; use super::*; + use crate::root_as_message; use arrow_array::builder::{PrimitiveRunBuilder, UnionBuilder}; use arrow_array::types::*; use arrow_buffer::ArrowNativeType; @@ -1209,7 +1156,7 @@ mod tests { let array10_input = vec![Some(1_i32), None, None]; let mut array10_builder = PrimitiveRunBuilder::::new(); - array10_builder.extend(array10_input.into_iter()); + array10_builder.extend(array10_input); let array10 = array10_builder.finish(); let array11 = BooleanArray::from(vec![false, false, true]); @@ -1357,8 +1304,7 @@ mod tests { writer.finish().unwrap(); drop(writer); - let mut reader = - crate::reader::FileReader::try_new(std::io::Cursor::new(buf), None).unwrap(); + let mut reader = FileReader::try_new(std::io::Cursor::new(buf), None).unwrap(); reader.next().unwrap().unwrap() } @@ -1465,7 +1411,7 @@ mod tests { let run_array_2_inupt = vec![Some(1_i32), None, None, Some(2), Some(2)]; let mut run_array_2_builder = PrimitiveRunBuilder::::new(); - run_array_2_builder.extend(run_array_2_inupt.into_iter()); + run_array_2_builder.extend(run_array_2_inupt); let run_array_2 = run_array_2_builder.finish(); let schema = Arc::new(Schema::new(vec![ @@ -1541,7 +1487,7 @@ mod tests { let keys_field = Arc::new(Field::new_dict( "keys", DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)), - true, + true, // It is technically not legal for this field to be null. 1, false, )); @@ -1560,7 +1506,7 @@ mod tests { Arc::new(Field::new( "entries", entry_struct.data_type().clone(), - true, + false, )), false, ); @@ -1704,4 +1650,40 @@ mod tests { let output_batch = roundtrip_ipc_stream(&input_batch); assert_eq!(input_batch, output_batch); } + + #[test] + fn test_unaligned() { + let batch = RecordBatch::try_from_iter(vec![( + "i32", + Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _, + )]) + .unwrap(); + + let gen = IpcDataGenerator {}; + let mut dict_tracker = DictionaryTracker::new(false); + let (_, encoded) = gen + .encoded_batch(&batch, &mut dict_tracker, &Default::default()) + .unwrap(); + + let message = root_as_message(&encoded.ipc_message).unwrap(); + + // Construct an unaligned buffer + let mut buffer = MutableBuffer::with_capacity(encoded.arrow_data.len() + 1); + buffer.push(0_u8); + buffer.extend_from_slice(&encoded.arrow_data); + let b = Buffer::from(buffer).slice(1); + assert_ne!(b.as_ptr().align_offset(8), 0); + + let ipc_batch = message.header_as_record_batch().unwrap(); + let roundtrip = read_record_batch( + &b, + ipc_batch, + batch.schema(), + &Default::default(), + None, + &message.version(), + ) + .unwrap(); + assert_eq!(batch, roundtrip); + } } diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index 59657bc4be09..567fa2e94171 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -23,6 +23,7 @@ use std::cmp::min; use std::collections::HashMap; use std::io::{BufWriter, Write}; +use std::sync::Arc; use flatbuffers::FlatBufferBuilder; @@ -696,7 +697,7 @@ pub struct FileWriter { /// IPC write options write_options: IpcWriteOptions, /// A reference to the schema, used in validating record batches - schema: Schema, + schema: SchemaRef, /// The number of bytes between each block of bytes, as an offset for random access block_offsets: usize, /// Dictionary blocks that will 
be written as part of the IPC footer @@ -739,7 +740,7 @@ impl FileWriter { Ok(Self { writer, write_options, - schema: schema.clone(), + schema: Arc::new(schema.clone()), block_offsets: meta + data + header_size, dictionary_blocks: vec![], record_blocks: vec![], @@ -757,7 +758,7 @@ impl FileWriter { /// Write a record batch to the file pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { if self.finished { - return Err(ArrowError::IoError( + return Err(ArrowError::IpcError( "Cannot write record batch to file writer as it is closed".to_string(), )); } @@ -794,7 +795,7 @@ impl FileWriter { /// Write footer and closing tag, then mark the writer as done pub fn finish(&mut self) -> Result<(), ArrowError> { if self.finished { - return Err(ArrowError::IoError( + return Err(ArrowError::IpcError( "Cannot write footer to file writer as it is closed".to_string(), )); } @@ -832,6 +833,11 @@ impl FileWriter { Ok(()) } + /// Returns the arrow [`SchemaRef`] for this arrow file. + pub fn schema(&self) -> &SchemaRef { + &self.schema + } + /// Gets a reference to the underlying writer. pub fn get_ref(&self) -> &W { self.writer.get_ref() @@ -909,7 +915,7 @@ impl StreamWriter { /// Write a record batch to the stream pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { if self.finished { - return Err(ArrowError::IoError( + return Err(ArrowError::IpcError( "Cannot write record batch to stream writer as it is closed".to_string(), )); } @@ -930,7 +936,7 @@ impl StreamWriter { /// Write continuation bytes, and mark the stream as done pub fn finish(&mut self) -> Result<(), ArrowError> { if self.finished { - return Err(ArrowError::IoError( + return Err(ArrowError::IpcError( "Cannot write footer to stream writer as it is closed".to_string(), )); } @@ -1146,7 +1152,7 @@ fn buffer_need_truncate( #[inline] fn get_buffer_element_width(spec: &BufferSpec) -> usize { match spec { - BufferSpec::FixedWidth { byte_width } => *byte_width, + BufferSpec::FixedWidth { byte_width, .. 
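The new `schema()` accessor on `FileWriter` simply exposes the `SchemaRef` the writer validates batches against; a tiny sketch of using it against an in-memory sink:

```rust
use arrow_ipc::writer::FileWriter;
use arrow_schema::{DataType, Field, Schema};

fn writer_schema_accessor() {
    let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
    let mut writer = FileWriter::try_new(Vec::<u8>::new(), &schema).unwrap();
    // The accessor returns the schema captured at construction time.
    assert_eq!(writer.schema().as_ref(), &schema);
    writer.finish().unwrap();
}
```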
} => *byte_width, _ => 0, } } @@ -2138,7 +2144,7 @@ mod tests { let u32 = UInt32Builder::new(); let mut ls = ListBuilder::new(u32); - for list in vec![vec![1u32, 2, 3], vec![4, 5, 6], vec![7, 8, 9, 10]] { + for list in [vec![1u32, 2, 3], vec![4, 5, 6], vec![7, 8, 9, 10]] { for value in list { ls.values().append_value(value); } diff --git a/arrow-json/Cargo.toml b/arrow-json/Cargo.toml index 137d53557790..df38a52811c2 100644 --- a/arrow-json/Cargo.toml +++ b/arrow-json/Cargo.toml @@ -44,7 +44,7 @@ indexmap = { version = "2.0", default-features = false, features = ["std"] } num = { version = "0.4", default-features = false, features = ["std"] } serde = { version = "1.0", default-features = false } serde_json = { version = "1.0", default-features = false, features = ["std"] } -chrono = { version = "0.4.23", default-features = false, features = ["clock"] } +chrono = { workspace = true } lexical-core = { version = "0.8", default-features = false } [dev-dependencies] @@ -54,3 +54,10 @@ serde = { version = "1.0", default-features = false, features = ["derive"] } futures = "0.3" tokio = { version = "1.27", default-features = false, features = ["io-util"] } bytes = "1.4" +criterion = { version = "0.5", default-features = false } +rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } + +[[bench]] +name = "serde" +harness = false + diff --git a/arrow-json/benches/serde.rs b/arrow-json/benches/serde.rs new file mode 100644 index 000000000000..7636b9c9dff9 --- /dev/null +++ b/arrow-json/benches/serde.rs @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_json::ReaderBuilder; +use arrow_schema::{DataType, Field, Schema}; +use criterion::*; +use rand::{thread_rng, Rng}; +use serde::Serialize; +use std::sync::Arc; + +#[allow(deprecated)] +fn do_bench(c: &mut Criterion, name: &str, rows: &[R], schema: &Schema) { + let schema = Arc::new(schema.clone()); + c.bench_function(name, |b| { + b.iter(|| { + let builder = ReaderBuilder::new(schema.clone()).with_batch_size(64); + let mut decoder = builder.build_decoder().unwrap(); + decoder.serialize(rows) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let mut rng = thread_rng(); + let schema = Schema::new(vec![Field::new("i32", DataType::Int32, false)]); + let v: Vec = (0..2048).map(|_| rng.gen_range(0..10000)).collect(); + + do_bench(c, "small_i32", &v, &schema); + let v: Vec = (0..2048).map(|_| rng.gen()).collect(); + do_bench(c, "large_i32", &v, &schema); + + let schema = Schema::new(vec![Field::new("i64", DataType::Int64, false)]); + let v: Vec = (0..2048).map(|_| rng.gen_range(0..10000)).collect(); + do_bench(c, "small_i64", &v, &schema); + let v: Vec = (0..2048).map(|_| rng.gen_range(0..i32::MAX as _)).collect(); + do_bench(c, "medium_i64", &v, &schema); + let v: Vec = (0..2048).map(|_| rng.gen()).collect(); + do_bench(c, "large_i64", &v, &schema); + + let schema = Schema::new(vec![Field::new("f32", DataType::Float32, false)]); + let v: Vec = (0..2048).map(|_| rng.gen_range(0.0..10000.)).collect(); + do_bench(c, "small_f32", &v, &schema); + let v: Vec = (0..2048).map(|_| rng.gen_range(0.0..f32::MAX)).collect(); + do_bench(c, "large_f32", &v, &schema); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index 4e98e2fd873a..c1cef0ec81b4 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -17,9 +17,13 @@ //! JSON reader //! -//! This JSON reader allows JSON line-delimited files to be read into the Arrow memory -//! model. Records are loaded in batches and are then converted from row-based data to -//! columnar data. +//! This JSON reader allows JSON records to be read into the Arrow memory +//! model. Records are loaded in batches and are then converted from the record-oriented +//! representation to the columnar arrow data model. +//! +//! The reader ignores whitespace between JSON values, including `\n` and `\r`, allowing +//! parsing of sequences of one or more arbitrarily formatted JSON values, including +//! but not limited to newline-delimited JSON. //! //! # Basic Usage //! @@ -130,6 +134,7 @@ //! 
use std::io::BufRead; +use std::sync::Arc; use chrono::Utc; use serde::Serialize; @@ -137,9 +142,11 @@ use serde::Serialize; use arrow_array::timezone::Tz; use arrow_array::types::Float32Type; use arrow_array::types::*; -use arrow_array::{downcast_integer, RecordBatch, RecordBatchReader, StructArray}; +use arrow_array::{ + downcast_integer, make_array, RecordBatch, RecordBatchReader, StructArray, +}; use arrow_data::ArrayData; -use arrow_schema::{ArrowError, DataType, SchemaRef, TimeUnit}; +use arrow_schema::{ArrowError, DataType, FieldRef, Schema, SchemaRef, TimeUnit}; pub use schema::*; use crate::reader::boolean_array::BooleanArrayDecoder; @@ -150,7 +157,7 @@ use crate::reader::null_array::NullArrayDecoder; use crate::reader::primitive_array::PrimitiveArrayDecoder; use crate::reader::string_array::StringArrayDecoder; use crate::reader::struct_array::StructArrayDecoder; -use crate::reader::tape::{Tape, TapeDecoder, TapeElement}; +use crate::reader::tape::{Tape, TapeDecoder}; use crate::reader::timestamp_array::TimestampArrayDecoder; mod boolean_array; @@ -171,6 +178,7 @@ pub struct ReaderBuilder { batch_size: usize, coerce_primitive: bool, strict_mode: bool, + is_field: bool, schema: SchemaRef, } @@ -189,10 +197,51 @@ impl ReaderBuilder { batch_size: 1024, coerce_primitive: false, strict_mode: false, + is_field: false, schema, } } + /// Create a new [`ReaderBuilder`] that will parse JSON values of `field.data_type()` + /// + /// Unlike [`ReaderBuilder::new`] this does not require the root of the JSON data + /// to be an object, i.e. `{..}`, allowing for parsing of any valid JSON value(s) + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::cast::AsArray; + /// # use arrow_array::types::Int32Type; + /// # use arrow_json::ReaderBuilder; + /// # use arrow_schema::{DataType, Field}; + /// // Root of JSON schema is a numeric type + /// let data = "1\n2\n3\n"; + /// let field = Arc::new(Field::new("int", DataType::Int32, true)); + /// let mut reader = ReaderBuilder::new_with_field(field.clone()).build(data.as_bytes()).unwrap(); + /// let b = reader.next().unwrap().unwrap(); + /// let values = b.column(0).as_primitive::().values(); + /// assert_eq!(values, &[1, 2, 3]); + /// + /// // Root of JSON schema is a list type + /// let data = "[1, 2, 3, 4, 5, 6, 7]\n[1, 2, 3]"; + /// let field = Field::new_list("int", field.clone(), true); + /// let mut reader = ReaderBuilder::new_with_field(field).build(data.as_bytes()).unwrap(); + /// let b = reader.next().unwrap().unwrap(); + /// let list = b.column(0).as_list::(); + /// + /// assert_eq!(list.offsets().as_ref(), &[0, 7, 10]); + /// let list_values = list.values().as_primitive::(); + /// assert_eq!(list_values.values(), &[1, 2, 3, 4, 5, 6, 7, 1, 2, 3]); + /// ``` + pub fn new_with_field(field: impl Into) -> Self { + Self { + batch_size: 1024, + coerce_primitive: false, + strict_mode: false, + is_field: true, + schema: Arc::new(Schema::new([field.into()])), + } + } + /// Sets the batch size in rows to read pub fn with_batch_size(self, batch_size: usize) -> Self { Self { batch_size, ..self } @@ -233,16 +282,22 @@ impl ReaderBuilder { /// Create a [`Decoder`] pub fn build_decoder(self) -> Result { - let decoder = make_decoder( - DataType::Struct(self.schema.fields.clone()), - self.coerce_primitive, - self.strict_mode, - false, - )?; + let (data_type, nullable) = match self.is_field { + false => (DataType::Struct(self.schema.fields.clone()), false), + true => { + let field = &self.schema.fields[0]; + (field.data_type().clone(), 
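The reworded module docs stress that the reader now tolerates arbitrary whitespace between top-level JSON values rather than requiring newline-delimited input. A small sketch of that behaviour, modelled on the crate's existing doc examples:

```rust
use std::sync::Arc;

use arrow_array::cast::AsArray;
use arrow_array::types::Int64Type;
use arrow_json::ReaderBuilder;
use arrow_schema::{DataType, Field, Schema};

fn read_whitespace_separated_values() {
    // Two objects separated by spaces, a carriage return and a tab, not a bare newline.
    let data = "{\"a\": 1}  \r\n\t {\"a\": 2}";
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
    let mut reader = ReaderBuilder::new(schema).build(data.as_bytes()).unwrap();
    let batch = reader.next().unwrap().unwrap();
    let values = batch.column(0).as_primitive::<Int64Type>().values();
    assert_eq!(values, &[1, 2]);
}
```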
field.is_nullable()) + } + }; + + let decoder = + make_decoder(data_type, self.coerce_primitive, self.strict_mode, nullable)?; + let num_fields = self.schema.all_fields().len(); Ok(Decoder { decoder, + is_field: self.is_field, tape_decoder: TapeDecoder::new(self.batch_size, num_fields), batch_size: self.batch_size, schema: self.schema, @@ -344,6 +399,7 @@ pub struct Decoder { tape_decoder: TapeDecoder, decoder: Box, batch_size: usize, + is_field: bool, schema: SchemaRef, } @@ -563,24 +619,20 @@ impl Decoder { let mut next_object = 1; let pos: Vec<_> = (0..tape.num_rows()) .map(|_| { - let end = match tape.get(next_object) { - TapeElement::StartObject(end) => end, - _ => unreachable!("corrupt tape"), - }; - std::mem::replace(&mut next_object, end + 1) + let next = tape.next(next_object, "row").unwrap(); + std::mem::replace(&mut next_object, next) }) .collect(); let decoded = self.decoder.decode(&tape, &pos)?; self.tape_decoder.clear(); - // Sanity check - assert!(matches!(decoded.data_type(), DataType::Struct(_))); - assert_eq!(decoded.null_count(), 0); - assert_eq!(decoded.len(), pos.len()); + let batch = match self.is_field { + true => RecordBatch::try_new(self.schema.clone(), vec![make_array(decoded)])?, + false => RecordBatch::from(StructArray::from(decoded)) + .with_schema(self.schema.clone())?, + }; - let batch = RecordBatch::from(StructArray::from(decoded)) - .with_schema(self.schema.clone())?; Ok(Some(batch)) } } @@ -2175,4 +2227,16 @@ mod tests { let values = batch.column(0).as_primitive::(); assert_eq!(values.values(), &[1681319393, -7200]); } + + #[test] + fn test_serde_field() { + let field = Field::new("int", DataType::Int32, true); + let mut decoder = ReaderBuilder::new_with_field(field) + .build_decoder() + .unwrap(); + decoder.serialize(&[1_i32, 2, 3, 4]).unwrap(); + let b = decoder.flush().unwrap().unwrap(); + let values = b.column(0).as_primitive::().values(); + assert_eq!(values, &[1, 2, 3, 4]); + } } diff --git a/arrow-json/src/reader/primitive_array.rs b/arrow-json/src/reader/primitive_array.rs index c78e4d914060..6cf0bac86737 100644 --- a/arrow-json/src/reader/primitive_array.rs +++ b/arrow-json/src/reader/primitive_array.rs @@ -91,11 +91,12 @@ impl PrimitiveArrayDecoder

<P> {

impl<P> ArrayDecoder for PrimitiveArrayDecoder<P>
where P: ArrowPrimitiveType + Parser, - P::Native: ParseJsonNumber, + P::Native: ParseJsonNumber + NumCast, { fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { let mut builder = PrimitiveBuilder::
<P>
::with_capacity(pos.len()) .with_data_type(self.data_type.clone()); + let d = &self.data_type; for p in pos { match tape.get(*p) { @@ -103,10 +104,7 @@ where TapeElement::String(idx) => { let s = tape.get_string(idx); let value = P::parse(s).ok_or_else(|| { - ArrowError::JsonError(format!( - "failed to parse \"{s}\" as {}", - self.data_type - )) + ArrowError::JsonError(format!("failed to parse \"{s}\" as {d}",)) })?; builder.append_value(value) @@ -115,14 +113,44 @@ where let s = tape.get_string(idx); let value = ParseJsonNumber::parse(s.as_bytes()).ok_or_else(|| { - ArrowError::JsonError(format!( - "failed to parse {s} as {}", - self.data_type - )) + ArrowError::JsonError(format!("failed to parse {s} as {d}",)) })?; builder.append_value(value) } + TapeElement::F32(v) => { + let v = f32::from_bits(v); + let value = NumCast::from(v).ok_or_else(|| { + ArrowError::JsonError(format!("failed to parse {v} as {d}",)) + })?; + builder.append_value(value) + } + TapeElement::I32(v) => { + let value = NumCast::from(v).ok_or_else(|| { + ArrowError::JsonError(format!("failed to parse {v} as {d}",)) + })?; + builder.append_value(value) + } + TapeElement::F64(high) => match tape.get(p + 1) { + TapeElement::F32(low) => { + let v = f64::from_bits((high as u64) << 32 | low as u64); + let value = NumCast::from(v).ok_or_else(|| { + ArrowError::JsonError(format!("failed to parse {v} as {d}",)) + })?; + builder.append_value(value) + } + _ => unreachable!(), + }, + TapeElement::I64(high) => match tape.get(p + 1) { + TapeElement::I32(low) => { + let v = (high as i64) << 32 | low as i64; + let value = NumCast::from(v).ok_or_else(|| { + ArrowError::JsonError(format!("failed to parse {v} as {d}",)) + })?; + builder.append_value(value) + } + _ => unreachable!(), + }, _ => return Err(tape.error(*p, "primitive")), } } diff --git a/arrow-json/src/reader/schema.rs b/arrow-json/src/reader/schema.rs index c8250ac37716..126a85df3931 100644 --- a/arrow-json/src/reader/schema.rs +++ b/arrow-json/src/reader/schema.rs @@ -72,6 +72,15 @@ impl InferredType { Ok(()) } + + fn is_none_or_any(ty: Option<&Self>) -> bool { + matches!(ty, Some(Self::Any) | None) + } +} + +/// Shorthand for building list data type of `ty` +fn list_type_of(ty: DataType) -> DataType { + DataType::List(Arc::new(Field::new("item", ty, true))) } /// Coerce data type during inference @@ -84,23 +93,18 @@ fn coerce_data_type(dt: Vec<&DataType>) -> DataType { let dt_init = dt_iter.next().unwrap_or(DataType::Utf8); dt_iter.fold(dt_init, |l, r| match (l, r) { + (DataType::Null, o) | (o, DataType::Null) => o, (DataType::Boolean, DataType::Boolean) => DataType::Boolean, (DataType::Int64, DataType::Int64) => DataType::Int64, (DataType::Float64, DataType::Float64) | (DataType::Float64, DataType::Int64) | (DataType::Int64, DataType::Float64) => DataType::Float64, - (DataType::List(l), DataType::List(r)) => DataType::List(Arc::new(Field::new( - "item", - coerce_data_type(vec![l.data_type(), r.data_type()]), - true, - ))), + (DataType::List(l), DataType::List(r)) => { + list_type_of(coerce_data_type(vec![l.data_type(), r.data_type()])) + } // coerce scalar and scalar array into scalar array (DataType::List(e), not_list) | (not_list, DataType::List(e)) => { - DataType::List(Arc::new(Field::new( - "item", - coerce_data_type(vec![e.data_type(), ¬_list]), - true, - ))) + list_type_of(coerce_data_type(vec![e.data_type(), ¬_list])) } _ => DataType::Utf8, }) @@ -110,11 +114,7 @@ fn generate_datatype(t: &InferredType) -> Result { Ok(match t { InferredType::Scalar(hs) => 
coerce_data_type(hs.iter().collect()), InferredType::Object(spec) => DataType::Struct(generate_fields(spec)?), - InferredType::Array(ele_type) => DataType::List(Arc::new(Field::new( - "item", - generate_datatype(ele_type)?, - true, - ))), + InferredType::Array(ele_type) => list_type_of(generate_datatype(ele_type)?), InferredType::Any => DataType::Null, }) } @@ -277,7 +277,7 @@ fn set_object_scalar_field_type( key: &str, ftype: DataType, ) -> Result<(), ArrowError> { - if !field_types.contains_key(key) { + if InferredType::is_none_or_any(field_types.get(key)) { field_types.insert(key.to_string(), InferredType::Scalar(HashSet::new())); } @@ -388,7 +388,7 @@ fn collect_field_types_from_object( Value::Array(array) => { let ele_type = infer_array_element_type(array)?; - if !field_types.contains_key(k) { + if InferredType::is_none_or_any(field_types.get(k)) { match ele_type { InferredType::Scalar(_) => { field_types.insert( @@ -438,8 +438,11 @@ fn collect_field_types_from_object( set_object_scalar_field_type(field_types, k, DataType::Boolean)?; } Value::Null => { - // do nothing, we treat json as nullable by default when - // inferring + // we treat json as nullable by default when inferring, so just + // mark existence of a field if it wasn't known before + if !field_types.contains_key(k) { + field_types.insert(k.to_string(), InferredType::Any); + } } Value::Number(n) => { if n.is_i64() { @@ -520,21 +523,9 @@ mod tests { fn test_json_infer_schema() { let schema = Schema::new(vec![ Field::new("a", DataType::Int64, true), - Field::new( - "b", - DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), - true, - ), - Field::new( - "c", - DataType::List(Arc::new(Field::new("item", DataType::Boolean, true))), - true, - ), - Field::new( - "d", - DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), - true, - ), + Field::new("b", list_type_of(DataType::Float64), true), + Field::new("c", list_type_of(DataType::Boolean), true), + Field::new("d", list_type_of(DataType::Utf8), true), ]); let mut reader = @@ -589,22 +580,18 @@ mod tests { let schema = Schema::new(vec![ Field::new( "c1", - DataType::List(Arc::new(Field::new( - "item", - DataType::Struct(Fields::from(vec![ - Field::new("a", DataType::Utf8, true), - Field::new("b", DataType::Int64, true), - Field::new("c", DataType::Boolean, true), - ])), - true, - ))), + list_type_of(DataType::Struct(Fields::from(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Boolean, true), + ]))), true, ), Field::new("c2", DataType::Float64, true), Field::new( "c3", // empty json array's inner types are inferred as null - DataType::List(Arc::new(Field::new("item", DataType::Null, true))), + list_type_of(DataType::Null), true, ), ]); @@ -629,15 +616,7 @@ mod tests { #[test] fn test_json_infer_schema_nested_list() { let schema = Schema::new(vec![ - Field::new( - "c1", - DataType::List(Arc::new(Field::new( - "item", - DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), - true, - ))), - true, - ), + Field::new("c1", list_type_of(list_type_of(DataType::Utf8)), true), Field::new("c2", DataType::Float64, true), ]); @@ -682,36 +661,22 @@ mod tests { #[test] fn test_coercion_scalar_and_list() { - use arrow_schema::DataType::*; - assert_eq!( - List(Arc::new(Field::new("item", Float64, true))), - coerce_data_type(vec![ - &Float64, - &List(Arc::new(Field::new("item", Float64, true))) - ]) + list_type_of(DataType::Float64), + coerce_data_type(vec![&DataType::Float64, 
&list_type_of(DataType::Float64)]) ); assert_eq!( - List(Arc::new(Field::new("item", Float64, true))), - coerce_data_type(vec![ - &Float64, - &List(Arc::new(Field::new("item", Int64, true))) - ]) + list_type_of(DataType::Float64), + coerce_data_type(vec![&DataType::Float64, &list_type_of(DataType::Int64)]) ); assert_eq!( - List(Arc::new(Field::new("item", Int64, true))), - coerce_data_type(vec![ - &Int64, - &List(Arc::new(Field::new("item", Int64, true))) - ]) + list_type_of(DataType::Int64), + coerce_data_type(vec![&DataType::Int64, &list_type_of(DataType::Int64)]) ); // boolean and number are incompatible, return utf8 assert_eq!( - List(Arc::new(Field::new("item", Utf8, true))), - coerce_data_type(vec![ - &Boolean, - &List(Arc::new(Field::new("item", Float64, true))) - ]) + list_type_of(DataType::Utf8), + coerce_data_type(vec![&DataType::Boolean, &list_type_of(DataType::Float64)]) ); } @@ -723,4 +688,26 @@ mod tests { "Json error: Not valid JSON: expected value at line 1 column 1", ); } + + #[test] + fn test_null_field_inferred_as_null() { + let data = r#" + {"in":1, "ni":null, "ns":null, "sn":"4", "n":null, "an":[], "na": null, "nas":null} + {"in":null, "ni":2, "ns":"3", "sn":null, "n":null, "an":null, "na": [], "nas":["8"]} + {"in":1, "ni":null, "ns":null, "sn":"4", "n":null, "an":[], "na": null, "nas":[]} + "#; + let inferred_schema = + infer_json_schema_from_seekable(Cursor::new(data), None).expect("infer"); + let schema = Schema::new(vec![ + Field::new("an", list_type_of(DataType::Null), true), + Field::new("in", DataType::Int64, true), + Field::new("n", DataType::Null, true), + Field::new("na", list_type_of(DataType::Null), true), + Field::new("nas", list_type_of(DataType::Utf8), true), + Field::new("ni", DataType::Int64, true), + Field::new("ns", DataType::Utf8, true), + Field::new("sn", DataType::Utf8, true), + ]); + assert_eq!(inferred_schema, schema); + } } diff --git a/arrow-json/src/reader/serializer.rs b/arrow-json/src/reader/serializer.rs index 2aa72de943f7..2fd250bdfcc3 100644 --- a/arrow-json/src/reader/serializer.rs +++ b/arrow-json/src/reader/serializer.rs @@ -77,22 +77,6 @@ impl<'a> TapeSerializer<'a> { } } -/// The tape stores all values as strings, and so must serialize numeric types -/// -/// Formatting to a string only to parse it back again is rather wasteful, -/// it may be possible to tweak the tape representation to avoid this -/// -/// Need to use macro as const generic expressions are unstable -/// -macro_rules! 
serialize_numeric { - ($s:ident, $t:ty, $v:ident) => {{ - let mut buffer = [0_u8; <$t>::FORMATTED_SIZE]; - let s = lexical_core::write($v, &mut buffer); - $s.serialize_number(s); - Ok(()) - }}; -} - impl<'a, 'b> Serializer for &'a mut TapeSerializer<'b> { type Ok = (); @@ -115,43 +99,63 @@ impl<'a, 'b> Serializer for &'a mut TapeSerializer<'b> { } fn serialize_i8(self, v: i8) -> Result<(), SerializerError> { - serialize_numeric!(self, i8, v) + self.serialize_i32(v as _) } fn serialize_i16(self, v: i16) -> Result<(), SerializerError> { - serialize_numeric!(self, i16, v) + self.serialize_i32(v as _) } fn serialize_i32(self, v: i32) -> Result<(), SerializerError> { - serialize_numeric!(self, i32, v) + self.elements.push(TapeElement::I32(v)); + Ok(()) } fn serialize_i64(self, v: i64) -> Result<(), SerializerError> { - serialize_numeric!(self, i64, v) + let low = v as i32; + let high = (v >> 32) as i32; + self.elements.push(TapeElement::I64(high)); + self.elements.push(TapeElement::I32(low)); + Ok(()) } fn serialize_u8(self, v: u8) -> Result<(), SerializerError> { - serialize_numeric!(self, u8, v) + self.serialize_i32(v as _) } fn serialize_u16(self, v: u16) -> Result<(), SerializerError> { - serialize_numeric!(self, u16, v) + self.serialize_i32(v as _) } fn serialize_u32(self, v: u32) -> Result<(), SerializerError> { - serialize_numeric!(self, u32, v) + match i32::try_from(v) { + Ok(v) => self.serialize_i32(v), + Err(_) => self.serialize_i64(v as _), + } } fn serialize_u64(self, v: u64) -> Result<(), SerializerError> { - serialize_numeric!(self, u64, v) + match i64::try_from(v) { + Ok(v) => self.serialize_i64(v), + Err(_) => { + let mut buffer = [0_u8; u64::FORMATTED_SIZE]; + let s = lexical_core::write(v, &mut buffer); + self.serialize_number(s); + Ok(()) + } + } } fn serialize_f32(self, v: f32) -> Result<(), SerializerError> { - serialize_numeric!(self, f32, v) + self.elements.push(TapeElement::F32(v.to_bits())); + Ok(()) } fn serialize_f64(self, v: f64) -> Result<(), SerializerError> { - serialize_numeric!(self, f64, v) + let bits = v.to_bits(); + self.elements.push(TapeElement::F64((bits >> 32) as u32)); + self.elements.push(TapeElement::F32(bits as u32)); + Ok(()) } fn serialize_char(self, v: char) -> Result<(), SerializerError> { diff --git a/arrow-json/src/reader/tape.rs b/arrow-json/src/reader/tape.rs index 5eca7b43dcc7..b39caede7047 100644 --- a/arrow-json/src/reader/tape.rs +++ b/arrow-json/src/reader/tape.rs @@ -18,6 +18,7 @@ use crate::reader::serializer::TapeSerializer; use arrow_schema::ArrowError; use serde::Serialize; +use std::fmt::Write; /// We decode JSON to a flattened tape representation, /// allowing for efficient traversal of the JSON data @@ -54,6 +55,25 @@ pub enum TapeElement { /// /// Contains the offset into the [`Tape`] string data Number(u32), + + /// The high bits of a i64 + /// + /// Followed by [`Self::I32`] containing the low bits + I64(i32), + + /// A 32-bit signed integer + /// + /// May be preceded by [`Self::I64`] containing high bits + I32(i32), + + /// The high bits of a 64-bit float + /// + /// Followed by [`Self::F32`] containing the low bits + F64(u32), + + /// A 32-bit float or the low-bits of a 64-bit float if preceded by [`Self::F64`] + F32(u32), + /// A true literal True, /// A false literal @@ -104,10 +124,15 @@ impl<'a> Tape<'a> { | TapeElement::Number(_) | TapeElement::True | TapeElement::False - | TapeElement::Null => Ok(cur_idx + 1), + | TapeElement::Null + | TapeElement::I32(_) + | TapeElement::F32(_) => Ok(cur_idx + 1), + 
TapeElement::I64(_) | TapeElement::F64(_) => Ok(cur_idx + 2), TapeElement::StartList(end_idx) => Ok(end_idx + 1), TapeElement::StartObject(end_idx) => Ok(end_idx + 1), - _ => Err(self.error(cur_idx, expected)), + TapeElement::EndObject(_) | TapeElement::EndList(_) => { + Err(self.error(cur_idx, expected)) + } } } @@ -153,6 +178,28 @@ impl<'a> Tape<'a> { TapeElement::True => out.push_str("true"), TapeElement::False => out.push_str("false"), TapeElement::Null => out.push_str("null"), + TapeElement::I64(high) => match self.get(idx + 1) { + TapeElement::I32(low) => { + let val = (high as i64) << 32 | low as i64; + let _ = write!(out, "{val}"); + return idx + 2; + } + _ => unreachable!(), + }, + TapeElement::I32(val) => { + let _ = write!(out, "{val}"); + } + TapeElement::F64(high) => match self.get(idx + 1) { + TapeElement::F32(low) => { + let val = f64::from_bits((high as u64) << 32 | low as u64); + let _ = write!(out, "{val}"); + return idx + 2; + } + _ => unreachable!(), + }, + TapeElement::F32(val) => { + let _ = write!(out, "{}", f32::from_bits(val)); + } } idx + 1 } @@ -250,7 +297,8 @@ macro_rules! next { pub struct TapeDecoder { elements: Vec, - num_rows: usize, + /// The number of rows decoded, including any in progress if `!stack.is_empty()` + cur_row: usize, /// Number of rows to read per batch batch_size: usize, @@ -283,36 +331,34 @@ impl TapeDecoder { offsets, elements, batch_size, - num_rows: 0, + cur_row: 0, bytes: Vec::with_capacity(num_fields * 2 * 8), stack: Vec::with_capacity(10), } } pub fn decode(&mut self, buf: &[u8]) -> Result { - if self.num_rows >= self.batch_size { - return Ok(0); - } - let mut iter = BufIter::new(buf); while !iter.is_empty() { - match self.stack.last_mut() { - // Start of row + let state = match self.stack.last_mut() { + Some(l) => l, None => { - // Skip over leading whitespace iter.skip_whitespace(); - match next!(iter) { - b'{' => { - let idx = self.elements.len() as u32; - self.stack.push(DecoderState::Object(idx)); - self.elements.push(TapeElement::StartObject(u32::MAX)); - } - b => return Err(err(b, "trimming leading whitespace")), + if iter.is_empty() || self.cur_row >= self.batch_size { + break; } + + // Start of row + self.cur_row += 1; + self.stack.push(DecoderState::Value); + self.stack.last_mut().unwrap() } + }; + + match state { // Decoding an object - Some(DecoderState::Object(start_idx)) => { + DecoderState::Object(start_idx) => { iter.advance_until(|b| !json_whitespace(b) && b != b','); match next!(iter) { b'"' => { @@ -327,16 +373,12 @@ impl TapeDecoder { TapeElement::StartObject(end_idx); self.elements.push(TapeElement::EndObject(start_idx)); self.stack.pop(); - self.num_rows += self.stack.is_empty() as usize; - if self.num_rows >= self.batch_size { - break; - } } b => return Err(err(b, "parsing object")), } } // Decoding a list - Some(DecoderState::List(start_idx)) => { + DecoderState::List(start_idx) => { iter.advance_until(|b| !json_whitespace(b) && b != b','); match iter.peek() { Some(b']') => { @@ -353,7 +395,7 @@ impl TapeDecoder { } } // Decoding a string - Some(DecoderState::String) => { + DecoderState::String => { let s = iter.advance_until(|b| matches!(b, b'\\' | b'"')); self.bytes.extend_from_slice(s); @@ -368,7 +410,7 @@ impl TapeDecoder { b => unreachable!("{}", b), } } - Some(state @ DecoderState::Value) => { + state @ DecoderState::Value => { iter.skip_whitespace(); *state = match next!(iter) { b'"' => DecoderState::String, @@ -392,7 +434,7 @@ impl TapeDecoder { b => return Err(err(b, "parsing value")), }; } - 
Some(DecoderState::Number) => { + DecoderState::Number => { let s = iter.advance_until(|b| { !matches!(b, b'0'..=b'9' | b'-' | b'+' | b'.' | b'e' | b'E') }); @@ -405,14 +447,14 @@ impl TapeDecoder { self.offsets.push(self.bytes.len()); } } - Some(DecoderState::Colon) => { + DecoderState::Colon => { iter.skip_whitespace(); match next!(iter) { b':' => self.stack.pop(), b => return Err(err(b, "parsing colon")), }; } - Some(DecoderState::Literal(literal, idx)) => { + DecoderState::Literal(literal, idx) => { let bytes = literal.bytes(); let expected = bytes.iter().skip(*idx as usize).copied(); for (expected, b) in expected.zip(&mut iter) { @@ -427,7 +469,7 @@ impl TapeDecoder { self.elements.push(element); } } - Some(DecoderState::Escape) => { + DecoderState::Escape => { let v = match next!(iter) { b'u' => { self.stack.pop(); @@ -449,7 +491,7 @@ impl TapeDecoder { self.bytes.push(v); } // Parse a unicode escape sequence - Some(DecoderState::Unicode(high, low, idx)) => loop { + DecoderState::Unicode(high, low, idx) => loop { match *idx { 0..=3 => *high = *high << 4 | parse_hex(next!(iter))? as u16, 4 => { @@ -500,7 +542,7 @@ impl TapeDecoder { .try_for_each(|row| row.serialize(&mut serializer)) .map_err(|e| ArrowError::JsonError(e.to_string()))?; - self.num_rows += rows.len(); + self.cur_row += rows.len(); Ok(()) } @@ -544,7 +586,7 @@ impl TapeDecoder { strings, elements: &self.elements, string_offsets: &self.offsets, - num_rows: self.num_rows, + num_rows: self.cur_row, }) } @@ -552,7 +594,7 @@ impl TapeDecoder { pub fn clear(&mut self) { assert!(self.stack.is_empty()); - self.num_rows = 0; + self.cur_row = 0; self.bytes.clear(); self.elements.clear(); self.elements.push(TapeElement::Null); @@ -790,7 +832,7 @@ mod tests { let err = decoder.decode(b"hello").unwrap_err().to_string(); assert_eq!( err, - "Json error: Encountered unexpected 'h' whilst trimming leading whitespace" + "Json error: Encountered unexpected 'h' whilst parsing value" ); let mut decoder = TapeDecoder::new(16, 2); diff --git a/arrow-json/src/reader/timestamp_array.rs b/arrow-json/src/reader/timestamp_array.rs index ef69deabce2d..09672614107c 100644 --- a/arrow-json/src/reader/timestamp_array.rs +++ b/arrow-json/src/reader/timestamp_array.rs @@ -71,7 +71,14 @@ where TimeUnit::Second => date.timestamp(), TimeUnit::Millisecond => date.timestamp_millis(), TimeUnit::Microsecond => date.timestamp_micros(), - TimeUnit::Nanosecond => date.timestamp_nanos(), + TimeUnit::Nanosecond => { + date.timestamp_nanos_opt().ok_or_else(|| { + ArrowError::ParseError(format!( + "{} would overflow 64-bit signed nanoseconds", + date.to_rfc3339(), + )) + })? 
+ } }; builder.append_value(value) } @@ -89,6 +96,13 @@ where builder.append_value(value) } + TapeElement::I32(v) => builder.append_value(v as i64), + TapeElement::I64(high) => match tape.get(p + 1) { + TapeElement::I32(low) => { + builder.append_value((high as i64) << 32 | low as i64) + } + _ => unreachable!(), + }, _ => return Err(tape.error(*p, "primitive")), } } diff --git a/arrow-json/src/writer.rs b/arrow-json/src/writer.rs index 571e95a1a4ec..8c4145bc95b4 100644 --- a/arrow-json/src/writer.rs +++ b/arrow-json/src/writer.rs @@ -320,11 +320,9 @@ fn set_column_for_json_rows( } DataType::Struct(_) => { let inner_objs = struct_array_to_jsonmap_array(array.as_struct())?; - rows.iter_mut() - .zip(inner_objs.into_iter()) - .for_each(|(row, obj)| { - row.insert(col_name.to_string(), Value::Object(obj)); - }); + rows.iter_mut().zip(inner_objs).for_each(|(row, obj)| { + row.insert(col_name.to_string(), Value::Object(obj)); + }); } DataType::List(_) => { let listarr = as_list_array(array); @@ -374,7 +372,7 @@ fn set_column_for_json_rows( let keys = keys.as_string::(); let values = array_to_json_array(values)?; - let mut kv = keys.iter().zip(values.into_iter()); + let mut kv = keys.iter().zip(values); for (i, row) in rows.iter_mut().enumerate() { if maparr.is_null(i) { @@ -759,7 +757,8 @@ mod tests { let ts_nanos = ts_string .parse::() .unwrap() - .timestamp_nanos(); + .timestamp_nanos_opt() + .unwrap(); let ts_micros = ts_nanos / 1000; let ts_millis = ts_micros / 1000; let ts_secs = ts_millis / 1000; @@ -811,7 +810,8 @@ mod tests { let ts_nanos = ts_string .parse::() .unwrap() - .timestamp_nanos(); + .timestamp_nanos_opt() + .unwrap(); let ts_micros = ts_nanos / 1000; let ts_millis = ts_micros / 1000; let ts_secs = ts_millis / 1000; @@ -1338,11 +1338,7 @@ mod tests { let batch = reader.next().unwrap().unwrap(); - let list_row = batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); + let list_row = batch.column(0).as_list::(); let values = list_row.values(); assert_eq!(values.len(), 4); assert_eq!(values.null_count(), 1); @@ -1387,7 +1383,7 @@ mod tests { Arc::new(Field::new( "entries", entry_struct.data_type().clone(), - true, + false, )), false, ); diff --git a/arrow-ord/Cargo.toml b/arrow-ord/Cargo.toml index fb061b9b5499..c9c30074fe6e 100644 --- a/arrow-ord/Cargo.toml +++ b/arrow-ord/Cargo.toml @@ -44,10 +44,3 @@ half = { version = "2.1", default-features = false, features = ["num-traits"] } [dev-dependencies] rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } - -[package.metadata.docs.rs] -features = ["dyn_cmp_dict"] - -[features] -dyn_cmp_dict = [] -simd = ["arrow-array/simd"] diff --git a/arrow-ord/src/cmp.rs b/arrow-ord/src/cmp.rs new file mode 100644 index 000000000000..feb168335568 --- /dev/null +++ b/arrow-ord/src/cmp.rs @@ -0,0 +1,723 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
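// Small illustration of the `timestamp_nanos_opt` switch above (assumes chrono 0.4.31 or
// later, which provides the `_opt` variants; helper name and dates are illustrative):
// timestamps outside roughly 1677-09-21..2262-04-11 do not fit in 64-bit signed
// nanoseconds, and the `_opt` variant reports that as `None` instead of panicking, which
// the reader then surfaces as a ParseError.
use chrono::{DateTime, Utc};

fn to_nanos(ts: &str) -> Option<i64> {
    let date: DateTime<Utc> = ts.parse().ok()?;
    date.timestamp_nanos_opt()
}

#[test]
fn nanosecond_overflow_is_reported() {
    assert!(to_nanos("2020-03-11T02:45:21Z").is_some());
    // Far outside the range representable in i64 nanoseconds.
    assert_eq!(to_nanos("3000-01-01T00:00:00Z"), None);
}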
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Comparison kernels for `Array`s. +//! +//! These kernels can leverage SIMD if available on your system. Currently no runtime +//! detection is provided, you should enable the specific SIMD intrinsics using +//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation +//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. +//! + +use arrow_array::cast::AsArray; +use arrow_array::types::ByteArrayType; +use arrow_array::{ + downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, + Datum, FixedSizeBinaryArray, GenericByteArray, +}; +use arrow_buffer::bit_util::ceil; +use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer}; +use arrow_schema::ArrowError; +use arrow_select::take::take; +use std::ops::Not; + +#[derive(Debug, Copy, Clone)] +enum Op { + Equal, + NotEqual, + Less, + LessEqual, + Greater, + GreaterEqual, + Distinct, + NotDistinct, +} + +impl std::fmt::Display for Op { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Op::Equal => write!(f, "=="), + Op::NotEqual => write!(f, "!="), + Op::Less => write!(f, "<"), + Op::LessEqual => write!(f, "<="), + Op::Greater => write!(f, ">"), + Op::GreaterEqual => write!(f, ">="), + Op::Distinct => write!(f, "IS DISTINCT FROM"), + Op::NotDistinct => write!(f, "IS NOT DISTINCT FROM"), + } + } +} + +/// Perform `left == right` operation on two [`Datum`] +/// +/// For floating values like f32 and f64, this comparison produces an ordering in accordance to +/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. +/// Note that totalOrder treats positive and negative zeros as different. If it is necessary +/// to treat them as equal, please normalize zeros before calling this kernel. +/// +/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`] +pub fn eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + compare_op(Op::Equal, lhs, rhs) +} + +/// Perform `left != right` operation on two [`Datum`] +/// +/// For floating values like f32 and f64, this comparison produces an ordering in accordance to +/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. +/// Note that totalOrder treats positive and negative zeros as different. If it is necessary +/// to treat them as equal, please normalize zeros before calling this kernel. +/// +/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`] +pub fn neq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + compare_op(Op::NotEqual, lhs, rhs) +} + +/// Perform `left < right` operation on two [`Datum`] +/// +/// For floating values like f32 and f64, this comparison produces an ordering in accordance to +/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. +/// Note that totalOrder treats positive and negative zeros as different. If it is necessary +/// to treat them as equal, please normalize zeros before calling this kernel. +/// +/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`] +pub fn lt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + compare_op(Op::Less, lhs, rhs) +} + +/// Perform `left <= right` operation on two [`Datum`] +/// +/// For floating values like f32 and f64, this comparison produces an ordering in accordance to +/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. 
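// Usage sketch for the new Datum-based kernels (mirrors the tests further down; values
// are illustrative): arrays and `Scalar`-wrapped one-element arrays both implement
// `Datum`, so a single `eq`/`lt` call covers what previously needed separate array,
// scalar and `_dyn` entry points.
use arrow_array::{BooleanArray, Int32Array};
use arrow_ord::cmp::{eq, lt};

fn datum_kernels_example() {
    let a = Int32Array::from(vec![1, 2, 3]);
    let b = Int32Array::from(vec![3, 2, 1]);

    // Array against array.
    assert_eq!(eq(&a, &b).unwrap(), BooleanArray::from(vec![false, true, false]));

    // Array against scalar: wrap a one-element array in `Scalar` via `new_scalar`.
    let two = Int32Array::new_scalar(2);
    assert_eq!(lt(&a, &two).unwrap(), BooleanArray::from(vec![true, false, false]));
}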
+/// Note that totalOrder treats positive and negative zeros as different. If it is necessary +/// to treat them as equal, please normalize zeros before calling this kernel. +/// +/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`] +pub fn lt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + compare_op(Op::LessEqual, lhs, rhs) +} + +/// Perform `left > right` operation on two [`Datum`] +/// +/// For floating values like f32 and f64, this comparison produces an ordering in accordance to +/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. +/// Note that totalOrder treats positive and negative zeros as different. If it is necessary +/// to treat them as equal, please normalize zeros before calling this kernel. +/// +/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`] +pub fn gt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + compare_op(Op::Greater, lhs, rhs) +} + +/// Perform `left >= right` operation on two [`Datum`] +/// +/// For floating values like f32 and f64, this comparison produces an ordering in accordance to +/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. +/// Note that totalOrder treats positive and negative zeros as different. If it is necessary +/// to treat them as equal, please normalize zeros before calling this kernel. +/// +/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`] +pub fn gt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + compare_op(Op::GreaterEqual, lhs, rhs) +} + +/// Perform `left IS DISTINCT FROM right` operation on two [`Datum`] +/// +/// [`distinct`] is similar to [`neq`], only differing in null handling. In particular, two +/// operands are considered DISTINCT if they have a different value or if one of them is NULL +/// and the other isn't. The result of [`distinct`] is never NULL. +/// +/// For floating values like f32 and f64, this comparison produces an ordering in accordance to +/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. +/// Note that totalOrder treats positive and negative zeros as different. If it is necessary +/// to treat them as equal, please normalize zeros before calling this kernel. +/// +/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`] +pub fn distinct(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + compare_op(Op::Distinct, lhs, rhs) +} + +/// Perform `left IS NOT DISTINCT FROM right` operation on two [`Datum`] +/// +/// [`not_distinct`] is similar to [`eq`], only differing in null handling. In particular, two +/// operands are considered `NOT DISTINCT` if they have the same value or if both of them +/// is NULL. The result of [`not_distinct`] is never NULL. +/// +/// For floating values like f32 and f64, this comparison produces an ordering in accordance to +/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. +/// Note that totalOrder treats positive and negative zeros as different. If it is necessary +/// to treat them as equal, please normalize zeros before calling this kernel. 
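// Sketch of how `distinct` differs from `neq` in null handling (illustrative values):
// `neq` propagates nulls from either input, while IS DISTINCT FROM treats NULL as an
// ordinary comparable value and never returns NULL.
use arrow_array::{BooleanArray, Int32Array};
use arrow_ord::cmp::{distinct, neq};

fn distinct_vs_neq_example() {
    let a = Int32Array::from(vec![Some(1), None, None]);
    let b = Int32Array::from(vec![Some(2), Some(2), None]);

    // neq: any null input yields a null output slot.
    assert_eq!(
        neq(&a, &b).unwrap(),
        BooleanArray::from(vec![Some(true), None, None])
    );

    // distinct: null vs value is distinct, null vs null is not, and the result is never null.
    assert_eq!(
        distinct(&a, &b).unwrap(),
        BooleanArray::from(vec![true, true, false])
    );
}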
+/// +/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`] +pub fn not_distinct( + lhs: &dyn Datum, + rhs: &dyn Datum, +) -> Result { + compare_op(Op::NotDistinct, lhs, rhs) +} + +/// Perform `op` on the provided `Datum` +#[inline(never)] +fn compare_op( + op: Op, + lhs: &dyn Datum, + rhs: &dyn Datum, +) -> Result { + use arrow_schema::DataType::*; + let (l, l_s) = lhs.get(); + let (r, r_s) = rhs.get(); + + let l_len = l.len(); + let r_len = r.len(); + + if l_len != r_len && !l_s && !r_s { + return Err(ArrowError::InvalidArgumentError(format!( + "Cannot compare arrays of different lengths, got {l_len} vs {r_len}" + ))); + } + + let len = match l_s { + true => r_len, + false => l_len, + }; + + let l_nulls = l.logical_nulls(); + let r_nulls = r.logical_nulls(); + + let l_v = l.as_any_dictionary_opt(); + let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l); + let l_t = l.data_type(); + + let r_v = r.as_any_dictionary_opt(); + let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r); + let r_t = r.data_type(); + + if l_t != r_t || l_t.is_nested() { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid comparison operation: {l_t} {op} {r_t}" + ))); + } + + // Defer computation as may not be necessary + let values = || -> BooleanBuffer { + let d = downcast_primitive_array! { + (l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v), + (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v), + (Utf8, Utf8) => apply(op, l.as_string::(), l_s, l_v, r.as_string::(), r_s, r_v), + (LargeUtf8, LargeUtf8) => apply(op, l.as_string::(), l_s, l_v, r.as_string::(), r_s, r_v), + (Binary, Binary) => apply(op, l.as_binary::(), l_s, l_v, r.as_binary::(), r_s, r_v), + (LargeBinary, LargeBinary) => apply(op, l.as_binary::(), l_s, l_v, r.as_binary::(), r_s, r_v), + (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v), + (Null, Null) => None, + _ => unreachable!(), + }; + d.unwrap_or_else(|| BooleanBuffer::new_unset(len)) + }; + + let l_nulls = l_nulls.filter(|n| n.null_count() > 0); + let r_nulls = r_nulls.filter(|n| n.null_count() > 0); + Ok(match (l_nulls, l_s, r_nulls, r_s) { + (Some(l), true, Some(r), true) | (Some(l), false, Some(r), false) => { + // Either both sides are scalar or neither side is scalar + match op { + Op::Distinct => { + let values = values(); + let l = l.inner().bit_chunks().iter_padded(); + let r = r.inner().bit_chunks().iter_padded(); + let ne = values.bit_chunks().iter_padded(); + + let c = |((l, r), n)| ((l ^ r) | (l & r & n)); + let buffer = l.zip(r).zip(ne).map(c).collect(); + BooleanBuffer::new(buffer, 0, len).into() + } + Op::NotDistinct => { + let values = values(); + let l = l.inner().bit_chunks().iter_padded(); + let r = r.inner().bit_chunks().iter_padded(); + let e = values.bit_chunks().iter_padded(); + + let c = |((l, r), e)| u64::not(l | r) | (l & r & e); + let buffer = l.zip(r).zip(e).map(c).collect(); + BooleanBuffer::new(buffer, 0, len).into() + } + _ => BooleanArray::new(values(), NullBuffer::union(Some(&l), Some(&r))), + } + } + (Some(_), true, Some(a), false) | (Some(a), false, Some(_), true) => { + // Scalar is null, other side is non-scalar and nullable + match op { + Op::Distinct => a.into_inner().into(), + Op::NotDistinct => a.into_inner().not().into(), + _ => BooleanArray::new_null(len), + } + } + (Some(nulls), is_scalar, None, _) | (None, _, Some(nulls), is_scalar) => { + // Only one side is nullable + match is_scalar { + 
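// Sketch of the length rules `compare_op` enforces (illustrative values): two plain
// arrays must have the same length, while a `Scalar` argument is treated as length one
// and broadcast to the other side's length.
use arrow_array::{Int32Array, Scalar};
use arrow_ord::cmp::eq;

fn length_rules_example() {
    let a = Int32Array::from(vec![1, 2, 3]);
    let b = Int32Array::from(vec![1, 2]);
    // Mismatched lengths and neither side scalar: an error.
    assert!(eq(&a, &b).is_err());

    // A Scalar is broadcast, so the result takes the array's length.
    let s = Scalar::new(Int32Array::from(vec![2]));
    assert_eq!(eq(&a, &s).unwrap().len(), 3);
}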
true => match op { + // Scalar is null, other side is not nullable + Op::Distinct => BooleanBuffer::new_set(len).into(), + Op::NotDistinct => BooleanBuffer::new_unset(len).into(), + _ => BooleanArray::new_null(len), + }, + false => match op { + Op::Distinct => { + let values = values(); + let l = nulls.inner().bit_chunks().iter_padded(); + let ne = values.bit_chunks().iter_padded(); + let c = |(l, n)| u64::not(l) | n; + let buffer = l.zip(ne).map(c).collect(); + BooleanBuffer::new(buffer, 0, len).into() + } + Op::NotDistinct => (nulls.inner() & &values()).into(), + _ => BooleanArray::new(values(), Some(nulls)), + }, + } + } + // Neither side is nullable + (None, _, None, _) => BooleanArray::new(values(), None), + }) +} + +/// Perform a potentially vectored `op` on the provided `ArrayOrd` +fn apply( + op: Op, + l: T, + l_s: bool, + l_v: Option<&dyn AnyDictionaryArray>, + r: T, + r_s: bool, + r_v: Option<&dyn AnyDictionaryArray>, +) -> Option { + if l.len() == 0 || r.len() == 0 { + return None; // Handle empty dictionaries + } + + if !l_s && !r_s && (l_v.is_some() || r_v.is_some()) { + // Not scalar and at least one side has a dictionary, need to perform vectored comparison + let l_v = l_v + .map(|x| x.normalized_keys()) + .unwrap_or_else(|| (0..l.len()).collect()); + + let r_v = r_v + .map(|x| x.normalized_keys()) + .unwrap_or_else(|| (0..r.len()).collect()); + + assert_eq!(l_v.len(), r_v.len()); // Sanity check + + Some(match op { + Op::Equal | Op::NotDistinct => { + apply_op_vectored(l, &l_v, r, &r_v, false, T::is_eq) + } + Op::NotEqual | Op::Distinct => { + apply_op_vectored(l, &l_v, r, &r_v, true, T::is_eq) + } + Op::Less => apply_op_vectored(l, &l_v, r, &r_v, false, T::is_lt), + Op::LessEqual => apply_op_vectored(r, &r_v, l, &l_v, true, T::is_lt), + Op::Greater => apply_op_vectored(r, &r_v, l, &l_v, false, T::is_lt), + Op::GreaterEqual => apply_op_vectored(l, &l_v, r, &r_v, true, T::is_lt), + }) + } else { + let l_s = l_s.then(|| l_v.map(|x| x.normalized_keys()[0]).unwrap_or_default()); + let r_s = r_s.then(|| r_v.map(|x| x.normalized_keys()[0]).unwrap_or_default()); + + let buffer = match op { + Op::Equal | Op::NotDistinct => apply_op(l, l_s, r, r_s, false, T::is_eq), + Op::NotEqual | Op::Distinct => apply_op(l, l_s, r, r_s, true, T::is_eq), + Op::Less => apply_op(l, l_s, r, r_s, false, T::is_lt), + Op::LessEqual => apply_op(r, r_s, l, l_s, true, T::is_lt), + Op::Greater => apply_op(r, r_s, l, l_s, false, T::is_lt), + Op::GreaterEqual => apply_op(l, l_s, r, r_s, true, T::is_lt), + }; + + // If a side had a dictionary, and was not scalar, we need to materialize this + Some(match (l_v, r_v) { + (Some(l_v), _) if l_s.is_none() => take_bits(l_v, buffer), + (_, Some(r_v)) if r_s.is_none() => take_bits(r_v, buffer), + _ => buffer, + }) + } +} + +/// Perform a take operation on `buffer` with the given dictionary +fn take_bits(v: &dyn AnyDictionaryArray, buffer: BooleanBuffer) -> BooleanBuffer { + let array = take(&BooleanArray::new(buffer, None), v.keys(), None).unwrap(); + array.as_boolean().values().clone() +} + +/// Invokes `f` with values `0..len` collecting the boolean results into a new `BooleanBuffer` +/// +/// This is similar to [`MutableBuffer::collect_bool`] but with +/// the option to efficiently negate the result +fn collect_bool(len: usize, neg: bool, f: impl Fn(usize) -> bool) -> BooleanBuffer { + let mut buffer = MutableBuffer::new(ceil(len, 64) * 8); + + let chunks = len / 64; + let remainder = len % 64; + for chunk in 0..chunks { + let mut packed = 0; + for bit_idx in 0..64 
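// Usage sketch for dictionary-encoded input (keys and values chosen for illustration):
// the kernel compares the dictionary values once and gathers the result through the
// keys, much like `take_bits` above, so dictionaries no longer need the old
// `dyn_cmp_dict` feature or special-cased entry points.
use std::sync::Arc;
use arrow_array::{BooleanArray, DictionaryArray, Int32Array, StringArray};
use arrow_ord::cmp::eq;

fn dictionary_example() {
    let values = StringArray::from(vec!["us-west", "us-east"]);
    let keys = Int32Array::from(vec![0, 1, 0]);
    let dict = DictionaryArray::new(keys, Arc::new(values));

    let plain = StringArray::from(vec!["us-west", "us-west", "us-west"]);
    assert_eq!(
        eq(&dict, &plain).unwrap(),
        BooleanArray::from(vec![true, false, true])
    );
}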
{ + let i = bit_idx + chunk * 64; + packed |= (f(i) as u64) << bit_idx; + } + if neg { + packed = !packed + } + + // SAFETY: Already allocated sufficient capacity + unsafe { buffer.push_unchecked(packed) } + } + + if remainder != 0 { + let mut packed = 0; + for bit_idx in 0..remainder { + let i = bit_idx + chunks * 64; + packed |= (f(i) as u64) << bit_idx; + } + if neg { + packed = !packed + } + + // SAFETY: Already allocated sufficient capacity + unsafe { buffer.push_unchecked(packed) } + } + BooleanBuffer::new(buffer.into(), 0, len) +} + +/// Applies `op` to possibly scalar `ArrayOrd` +/// +/// If l is scalar `l_s` will be `Some(idx)` where `idx` is the index of the scalar value in `l` +/// If r is scalar `r_s` will be `Some(idx)` where `idx` is the index of the scalar value in `r` +/// +/// If `neg` is true the result of `op` will be negated +fn apply_op( + l: T, + l_s: Option, + r: T, + r_s: Option, + neg: bool, + op: impl Fn(T::Item, T::Item) -> bool, +) -> BooleanBuffer { + match (l_s, r_s) { + (None, None) => { + assert_eq!(l.len(), r.len()); + collect_bool(l.len(), neg, |idx| unsafe { + op(l.value_unchecked(idx), r.value_unchecked(idx)) + }) + } + (Some(l_s), Some(r_s)) => { + let a = l.value(l_s); + let b = r.value(r_s); + std::iter::once(op(a, b) ^ neg).collect() + } + (Some(l_s), None) => { + let v = l.value(l_s); + collect_bool(r.len(), neg, |idx| op(v, unsafe { r.value_unchecked(idx) })) + } + (None, Some(r_s)) => { + let v = r.value(r_s); + collect_bool(l.len(), neg, |idx| op(unsafe { l.value_unchecked(idx) }, v)) + } + } +} + +/// Applies `op` to possibly scalar `ArrayOrd` with the given indices +fn apply_op_vectored( + l: T, + l_v: &[usize], + r: T, + r_v: &[usize], + neg: bool, + op: impl Fn(T::Item, T::Item) -> bool, +) -> BooleanBuffer { + assert_eq!(l_v.len(), r_v.len()); + collect_bool(l_v.len(), neg, |idx| unsafe { + let l_idx = *l_v.get_unchecked(idx); + let r_idx = *r_v.get_unchecked(idx); + op(l.value_unchecked(l_idx), r.value_unchecked(r_idx)) + }) +} + +trait ArrayOrd { + type Item: Copy + Default; + + fn len(&self) -> usize; + + fn value(&self, idx: usize) -> Self::Item { + assert!(idx < self.len()); + unsafe { self.value_unchecked(idx) } + } + + /// # Safety + /// + /// Safe if `idx < self.len()` + unsafe fn value_unchecked(&self, idx: usize) -> Self::Item; + + fn is_eq(l: Self::Item, r: Self::Item) -> bool; + + fn is_lt(l: Self::Item, r: Self::Item) -> bool; +} + +impl<'a> ArrayOrd for &'a BooleanArray { + type Item = bool; + + fn len(&self) -> usize { + Array::len(self) + } + + unsafe fn value_unchecked(&self, idx: usize) -> Self::Item { + BooleanArray::value_unchecked(self, idx) + } + + fn is_eq(l: Self::Item, r: Self::Item) -> bool { + l == r + } + + fn is_lt(l: Self::Item, r: Self::Item) -> bool { + !l & r + } +} + +impl ArrayOrd for &[T] { + type Item = T; + + fn len(&self) -> usize { + (*self).len() + } + + unsafe fn value_unchecked(&self, idx: usize) -> Self::Item { + *self.get_unchecked(idx) + } + + fn is_eq(l: Self::Item, r: Self::Item) -> bool { + l.is_eq(r) + } + + fn is_lt(l: Self::Item, r: Self::Item) -> bool { + l.is_lt(r) + } +} + +impl<'a, T: ByteArrayType> ArrayOrd for &'a GenericByteArray { + type Item = &'a [u8]; + + fn len(&self) -> usize { + Array::len(self) + } + + unsafe fn value_unchecked(&self, idx: usize) -> Self::Item { + GenericByteArray::value_unchecked(self, idx).as_ref() + } + + fn is_eq(l: Self::Item, r: Self::Item) -> bool { + l == r + } + + fn is_lt(l: Self::Item, r: Self::Item) -> bool { + l < r + } +} + +impl<'a> ArrayOrd 
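// Standalone sketch of the bit-packing idea behind `collect_bool` above (simplified and
// hypothetical, not the actual implementation): evaluate the predicate one index at a
// time, OR each result into a u64 word, and negate whole words at once when the
// requested comparison is the negation of the one computed (e.g. `neq` from `eq`).
fn pack_bools(len: usize, neg: bool, f: impl Fn(usize) -> bool) -> Vec<u64> {
    let mut words = vec![0u64; (len + 63) / 64];
    for i in 0..len {
        if f(i) {
            words[i / 64] |= 1 << (i % 64);
        }
    }
    if neg {
        for w in &mut words {
            *w = !*w;
        }
        // Clear the padding bits in the final partial word.
        if len % 64 != 0 {
            if let Some(last) = words.last_mut() {
                *last &= (1u64 << (len % 64)) - 1;
            }
        }
    }
    words
}

// pack_bools(3, false, |i| i == 1) == vec![0b010]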
for &'a FixedSizeBinaryArray { + type Item = &'a [u8]; + + fn len(&self) -> usize { + Array::len(self) + } + + unsafe fn value_unchecked(&self, idx: usize) -> Self::Item { + FixedSizeBinaryArray::value_unchecked(self, idx) + } + + fn is_eq(l: Self::Item, r: Self::Item) -> bool { + l == r + } + + fn is_lt(l: Self::Item, r: Self::Item) -> bool { + l < r + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow_array::{DictionaryArray, Int32Array, Scalar, StringArray}; + + use super::*; + + #[test] + fn test_null_dict() { + let a = DictionaryArray::new( + Int32Array::new_null(10), + Arc::new(Int32Array::new_null(0)), + ); + let r = eq(&a, &a).unwrap(); + assert_eq!(r.null_count(), 10); + + let a = DictionaryArray::new( + Int32Array::from(vec![1, 2, 3, 4, 5, 6]), + Arc::new(Int32Array::new_null(10)), + ); + let r = eq(&a, &a).unwrap(); + assert_eq!(r.null_count(), 6); + + let scalar = DictionaryArray::new( + Int32Array::new_null(1), + Arc::new(Int32Array::new_null(0)), + ); + let r = eq(&a, &Scalar::new(&scalar)).unwrap(); + assert_eq!(r.null_count(), 6); + + let scalar = DictionaryArray::new( + Int32Array::new_null(1), + Arc::new(Int32Array::new_null(0)), + ); + let r = eq(&Scalar::new(&scalar), &Scalar::new(&scalar)).unwrap(); + assert_eq!(r.null_count(), 1); + + let a = DictionaryArray::new( + Int32Array::from(vec![0, 1, 2]), + Arc::new(Int32Array::from(vec![3, 2, 1])), + ); + let r = eq(&a, &Scalar::new(&scalar)).unwrap(); + assert_eq!(r.null_count(), 3); + } + + #[test] + fn is_distinct_from_non_nulls() { + let left_int_array = Int32Array::from(vec![0, 1, 2, 3, 4]); + let right_int_array = Int32Array::from(vec![4, 3, 2, 1, 0]); + + assert_eq!( + BooleanArray::from(vec![true, true, false, true, true,]), + distinct(&left_int_array, &right_int_array).unwrap() + ); + assert_eq!( + BooleanArray::from(vec![false, false, true, false, false,]), + not_distinct(&left_int_array, &right_int_array).unwrap() + ); + } + + #[test] + fn is_distinct_from_nulls() { + // [0, 0, NULL, 0, 0, 0] + let left_int_array = Int32Array::new( + vec![0, 0, 1, 3, 0, 0].into(), + Some(NullBuffer::from(vec![true, true, false, true, true, true])), + ); + // [0, NULL, NULL, NULL, 0, NULL] + let right_int_array = Int32Array::new( + vec![0; 6].into(), + Some(NullBuffer::from(vec![ + true, false, false, false, true, false, + ])), + ); + + assert_eq!( + BooleanArray::from(vec![false, true, false, true, false, true,]), + distinct(&left_int_array, &right_int_array).unwrap() + ); + + assert_eq!( + BooleanArray::from(vec![true, false, true, false, true, false,]), + not_distinct(&left_int_array, &right_int_array).unwrap() + ); + } + + #[test] + fn test_distinct_scalar() { + let a = Int32Array::new_scalar(12); + let b = Int32Array::new_scalar(12); + assert!(!distinct(&a, &b).unwrap().value(0)); + assert!(not_distinct(&a, &b).unwrap().value(0)); + + let a = Int32Array::new_scalar(12); + let b = Int32Array::new_null(1); + assert!(distinct(&a, &b).unwrap().value(0)); + assert!(!not_distinct(&a, &b).unwrap().value(0)); + assert!(distinct(&b, &a).unwrap().value(0)); + assert!(!not_distinct(&b, &a).unwrap().value(0)); + + let b = Scalar::new(b); + assert!(distinct(&a, &b).unwrap().value(0)); + assert!(!not_distinct(&a, &b).unwrap().value(0)); + + assert!(!distinct(&b, &b).unwrap().value(0)); + assert!(not_distinct(&b, &b).unwrap().value(0)); + + let a = Int32Array::new( + vec![0, 1, 2, 3].into(), + Some(vec![false, false, true, true].into()), + ); + let expected = BooleanArray::from(vec![false, false, true, true]); + 
assert_eq!(distinct(&a, &b).unwrap(), expected); + assert_eq!(distinct(&b, &a).unwrap(), expected); + + let expected = BooleanArray::from(vec![true, true, false, false]); + assert_eq!(not_distinct(&a, &b).unwrap(), expected); + assert_eq!(not_distinct(&b, &a).unwrap(), expected); + + let b = Int32Array::new_scalar(1); + let expected = BooleanArray::from(vec![true; 4]); + assert_eq!(distinct(&a, &b).unwrap(), expected); + assert_eq!(distinct(&b, &a).unwrap(), expected); + let expected = BooleanArray::from(vec![false; 4]); + assert_eq!(not_distinct(&a, &b).unwrap(), expected); + assert_eq!(not_distinct(&b, &a).unwrap(), expected); + + let b = Int32Array::new_scalar(3); + let expected = BooleanArray::from(vec![true, true, true, false]); + assert_eq!(distinct(&a, &b).unwrap(), expected); + assert_eq!(distinct(&b, &a).unwrap(), expected); + let expected = BooleanArray::from(vec![false, false, false, true]); + assert_eq!(not_distinct(&a, &b).unwrap(), expected); + assert_eq!(not_distinct(&b, &a).unwrap(), expected); + } + + #[test] + fn test_scalar_negation() { + let a = Int32Array::new_scalar(54); + let b = Int32Array::new_scalar(54); + let r = eq(&a, &b).unwrap(); + assert!(r.value(0)); + + let r = neq(&a, &b).unwrap(); + assert!(!r.value(0)) + } + + #[test] + fn test_scalar_empty() { + let a = Int32Array::new_null(0); + let b = Int32Array::new_scalar(23); + let r = eq(&a, &b).unwrap(); + assert_eq!(r.len(), 0); + let r = eq(&b, &a).unwrap(); + assert_eq!(r.len(), 0); + } + + #[test] + fn test_dictionary_nulls() { + let values = StringArray::from(vec![Some("us-west"), Some("us-east")]); + let nulls = NullBuffer::from(vec![false, true, true]); + + let key_values = vec![100i32, 1i32, 0i32].into(); + let keys = Int32Array::new(key_values, Some(nulls)); + let col = DictionaryArray::try_new(keys, Arc::new(values)).unwrap(); + + neq(&col.slice(0, col.len() - 1), &col.slice(1, col.len() - 1)).unwrap(); + } +} diff --git a/arrow-ord/src/comparison.rs b/arrow-ord/src/comparison.rs index 4f8b9a322620..ffd35a6070b8 100644 --- a/arrow-ord/src/comparison.rs +++ b/arrow-ord/src/comparison.rs @@ -23,15 +23,235 @@ //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. //! +use half::f16; +use std::sync::Arc; + use arrow_array::cast::*; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::i256; -use arrow_buffer::{bit_util, BooleanBuffer, Buffer, MutableBuffer, NullBuffer}; -use arrow_data::ArrayData; +use arrow_buffer::{bit_util, BooleanBuffer, MutableBuffer, NullBuffer}; use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit}; -use arrow_select::take::take; -use half::f16; + +/// Calls $RIGHT.$TY() (e.g. `right.to_i128()`) with a nice error message. +/// Type of expression is `Result<.., ArrowError>` +macro_rules! 
try_to_type { + ($RIGHT: expr, $TY: ident) => { + try_to_type_result($RIGHT.$TY(), &format!("{:?}", $RIGHT), stringify!($TY)) + }; +} + +// Avoids creating a closure for each combination of `$RIGHT` and `$TY` +fn try_to_type_result( + value: Option, + right: &str, + ty: &str, +) -> Result { + value.ok_or_else(|| { + ArrowError::ComputeError(format!("Could not convert {right} with {ty}",)) + }) +} + +fn make_primitive_scalar( + d: &DataType, + scalar: T, +) -> Result { + match d { + DataType::Int8 => { + let right = try_to_type!(scalar, to_i8)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::Int16 => { + let right = try_to_type!(scalar, to_i16)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::Int32 => { + let right = try_to_type!(scalar, to_i32)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::Int64 => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::UInt8 => { + let right = try_to_type!(scalar, to_u8)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::UInt16 => { + let right = try_to_type!(scalar, to_u16)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::UInt32 => { + let right = try_to_type!(scalar, to_u32)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::UInt64 => { + let right = try_to_type!(scalar, to_u64)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::Float16 => { + let right = try_to_type!(scalar, to_f32)?; + Ok(Arc::new(PrimitiveArray::::from(vec![ + f16::from_f32(right), + ]))) + } + DataType::Float32 => { + let right = try_to_type!(scalar, to_f32)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::Float64 => { + let right = try_to_type!(scalar, to_f64)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::Decimal128(_, _) => { + let right = try_to_type!(scalar, to_i128)?; + Ok(Arc::new( + PrimitiveArray::::from(vec![right]) + .with_data_type(d.clone()), + )) + } + DataType::Decimal256(_, _) => { + let right = try_to_type!(scalar, to_i128)?; + Ok(Arc::new( + PrimitiveArray::::from(vec![i256::from_i128(right)]) + .with_data_type(d.clone()), + )) + } + DataType::Date32 => { + let right = try_to_type!(scalar, to_i32)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::Date64 => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new( + PrimitiveArray::::from(vec![right]) + .with_data_type(d.clone()), + )) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new( + PrimitiveArray::::from(vec![right]) + .with_data_type(d.clone()), + )) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new( + PrimitiveArray::::from(vec![right]) + .with_data_type(d.clone()), + )) + } + DataType::Timestamp(TimeUnit::Second, _) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new( + PrimitiveArray::::from(vec![right]) + .with_data_type(d.clone()), + )) + } + DataType::Time32(TimeUnit::Second) => { + let right = try_to_type!(scalar, to_i32)?; + Ok(Arc::new(PrimitiveArray::::from(vec![ + right, + ]))) + } + DataType::Time32(TimeUnit::Millisecond) => { + let right = try_to_type!(scalar, to_i32)?; + Ok(Arc::new(PrimitiveArray::::from( + vec![right], + ))) + } + 
DataType::Time64(TimeUnit::Microsecond) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from( + vec![right], + ))) + } + DataType::Time64(TimeUnit::Nanosecond) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from( + vec![right], + ))) + } + DataType::Interval(IntervalUnit::YearMonth) => { + let right = try_to_type!(scalar, to_i32)?; + Ok(Arc::new(PrimitiveArray::::from( + vec![right], + ))) + } + DataType::Interval(IntervalUnit::DayTime) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from(vec![ + right, + ]))) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + let right = try_to_type!(scalar, to_i128)?; + Ok(Arc::new(PrimitiveArray::::from( + vec![right], + ))) + } + DataType::Duration(TimeUnit::Second) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from(vec![ + right, + ]))) + } + DataType::Duration(TimeUnit::Millisecond) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from( + vec![right], + ))) + } + DataType::Duration(TimeUnit::Microsecond) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from( + vec![right], + ))) + } + DataType::Duration(TimeUnit::Nanosecond) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from( + vec![right], + ))) + } + DataType::Dictionary(_, v) => make_primitive_scalar(v.as_ref(), scalar), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Unsupported primitive scalar data type {d:?}", + ))), + } +} + +fn make_binary_scalar(d: &DataType, scalar: &[u8]) -> Result { + match d { + DataType::Binary => Ok(Arc::new(BinaryArray::from_iter_values([scalar]))), + DataType::FixedSizeBinary(_) => Ok(Arc::new( + FixedSizeBinaryArray::try_from_iter([scalar].into_iter())?, + )), + DataType::LargeBinary => { + Ok(Arc::new(LargeBinaryArray::from_iter_values([scalar]))) + } + DataType::Dictionary(_, v) => make_binary_scalar(v.as_ref(), scalar), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Unsupported binary scalar data type {d:?}", + ))), + } +} + +fn make_utf8_scalar(d: &DataType, scalar: &str) -> Result { + match d { + DataType::Utf8 => Ok(Arc::new(StringArray::from_iter_values([scalar]))), + DataType::LargeUtf8 => Ok(Arc::new(LargeStringArray::from_iter_values([scalar]))), + DataType::Dictionary(_, v) => make_utf8_scalar(v.as_ref(), scalar), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Unsupported utf8 scalar data type {d:?}", + ))), + } +} /// Helper function to perform boolean lambda function on values from two array accessors, this /// version does not attempt to use SIMD. @@ -67,6 +287,7 @@ where /// Evaluate `op(left, right)` for [`PrimitiveArray`]s using a specified /// comparison function. +#[deprecated(note = "Use BooleanArray::from_binary")] pub fn no_simd_compare_op( left: &PrimitiveArray, right: &PrimitiveArray, @@ -81,6 +302,7 @@ where /// Evaluate `op(left, right)` for [`PrimitiveArray`] and scalar using /// a specified comparison function. +#[deprecated(note = "Use BooleanArray::from_unary")] pub fn no_simd_compare_op_scalar( left: &PrimitiveArray, right: T::Native, @@ -94,617 +316,345 @@ where } /// Perform `left == right` operation on [`StringArray`] / [`LargeStringArray`]. 
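// Migration sketch for the deprecated string kernels below: each shim now just wraps the
// scalar in a one-element array and forwards to the corresponding `cmp` kernel, and user
// code can do the same directly. Names and values are illustrative.
use arrow_array::{Scalar, StringArray};
use arrow_ord::cmp::eq;

fn utf8_scalar_migration_example() {
    let names = StringArray::from(vec!["alpha", "beta", "alpha"]);

    // Before: eq_utf8_scalar(&names, "alpha")
    let needle = StringArray::from(vec!["alpha"]);
    let mask = eq(&names, &Scalar::new(&needle)).unwrap();
    assert_eq!(mask.true_count(), 2);
}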
+#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op(left, right, |a, b| a == b) -} - -fn utf8_empty( - left: &GenericStringArray, -) -> Result { - let null_bit_buffer = left.nulls().map(|b| b.inner().sliced()); - - let buffer = unsafe { - MutableBuffer::from_trusted_len_iter_bool(left.value_offsets().windows(2).map( - |offset| { - if EQ { - offset[1].as_usize() == offset[0].as_usize() - } else { - offset[1].as_usize() > offset[0].as_usize() - } - }, - )) - }; - - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - left.len(), - None, - null_bit_buffer, - 0, - vec![Buffer::from(buffer)], - vec![], - ) - }; - Ok(BooleanArray::from(data)) + crate::cmp::eq(left, right) } /// Perform `left == right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. +#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - if right.is_empty() { - return utf8_empty::<_, true>(left); - } - compare_op_scalar(left, |a| a == right) + let right = GenericStringArray::::from(vec![right]); + crate::cmp::eq(&left, &Scalar::new(&right)) } /// Perform `left == right` operation on [`BooleanArray`] +#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_bool( left: &BooleanArray, right: &BooleanArray, ) -> Result { - compare_op(left, right, |a, b| !(a ^ b)) + crate::cmp::eq(&left, &right) } /// Perform `left != right` operation on [`BooleanArray`] +#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn neq_bool( left: &BooleanArray, right: &BooleanArray, ) -> Result { - compare_op(left, right, |a, b| (a ^ b)) + crate::cmp::neq(&left, &right) } /// Perform `left < right` operation on [`BooleanArray`] +#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt_bool( left: &BooleanArray, right: &BooleanArray, ) -> Result { - compare_op(left, right, |a, b| ((!a) & b)) + crate::cmp::lt(&left, &right) } /// Perform `left <= right` operation on [`BooleanArray`] +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq_bool( left: &BooleanArray, right: &BooleanArray, ) -> Result { - compare_op(left, right, |a, b| !(a & (!b))) + crate::cmp::lt_eq(&left, &right) } /// Perform `left > right` operation on [`BooleanArray`] +#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_bool( left: &BooleanArray, right: &BooleanArray, ) -> Result { - compare_op(left, right, |a, b| (a & (!b))) + crate::cmp::gt(&left, &right) } /// Perform `left >= right` operation on [`BooleanArray`] +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq_bool( left: &BooleanArray, right: &BooleanArray, ) -> Result { - compare_op(left, right, |a, b| !((!a) & b)) + crate::cmp::gt_eq(&left, &right) } /// Perform `left == right` operation on [`BooleanArray`] and a scalar +#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_bool_scalar( left: &BooleanArray, right: bool, ) -> Result { - let values = match right { - true => left.values().clone(), - false => !left.values(), - }; - - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - values.len(), - None, - left.nulls().map(|b| b.inner().sliced()), - values.offset(), - vec![values.into_inner()], - vec![], - ) - }; - - Ok(BooleanArray::from(data)) + let right = BooleanArray::from(vec![right]); + crate::cmp::eq(&left, &Scalar::new(&right)) } /// Perform `left < right` operation on [`BooleanArray`] and a scalar +#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt_bool_scalar( left: 
&BooleanArray, right: bool, ) -> Result { - compare_op_scalar(left, |a: bool| !a & right) + let right = BooleanArray::from(vec![right]); + crate::cmp::lt(&left, &Scalar::new(&right)) } /// Perform `left <= right` operation on [`BooleanArray`] and a scalar +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq_bool_scalar( left: &BooleanArray, right: bool, ) -> Result { - compare_op_scalar(left, |a| a <= right) + let right = BooleanArray::from(vec![right]); + crate::cmp::lt_eq(&left, &Scalar::new(&right)) } /// Perform `left > right` operation on [`BooleanArray`] and a scalar +#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_bool_scalar( left: &BooleanArray, right: bool, ) -> Result { - compare_op_scalar(left, |a: bool| a & !right) + let right = BooleanArray::from(vec![right]); + crate::cmp::gt(&left, &Scalar::new(&right)) } /// Perform `left >= right` operation on [`BooleanArray`] and a scalar +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq_bool_scalar( left: &BooleanArray, right: bool, ) -> Result { - compare_op_scalar(left, |a| a >= right) + let right = BooleanArray::from(vec![right]); + crate::cmp::gt_eq(&left, &Scalar::new(&right)) } /// Perform `left != right` operation on [`BooleanArray`] and a scalar +#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn neq_bool_scalar( left: &BooleanArray, right: bool, ) -> Result { - eq_bool_scalar(left, !right) + let right = BooleanArray::from(vec![right]); + crate::cmp::neq(&left, &Scalar::new(&right)) } /// Perform `left == right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. +#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op(left, right, |a, b| a == b) + crate::cmp::eq(left, right) } /// Perform `left == right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar +#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar(left, |a| a == right) + let right = GenericBinaryArray::::from_iter_values([right]); + crate::cmp::eq(left, &Scalar::new(&right)) } /// Perform `left != right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. +#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn neq_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op(left, right, |a, b| a != b) + crate::cmp::neq(left, right) } /// Perform `left != right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. +#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn neq_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar(left, |a| a != right) + let right = GenericBinaryArray::::from_iter_values([right]); + crate::cmp::neq(left, &Scalar::new(&right)) } /// Perform `left < right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. +#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op(left, right, |a, b| a < b) + crate::cmp::lt(left, right) } /// Perform `left < right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. 
+#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar(left, |a| a < right) + let right = GenericBinaryArray::::from_iter_values([right]); + crate::cmp::lt(left, &Scalar::new(&right)) } /// Perform `left <= right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op(left, right, |a, b| a <= b) + crate::cmp::lt_eq(left, right) } /// Perform `left <= right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar(left, |a| a <= right) + let right = GenericBinaryArray::::from_iter_values([right]); + crate::cmp::lt_eq(left, &Scalar::new(&right)) } /// Perform `left > right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. +#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op(left, right, |a, b| a > b) + crate::cmp::gt(left, right) } /// Perform `left > right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. +#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar(left, |a| a > right) + let right = GenericBinaryArray::::from_iter_values([right]); + crate::cmp::gt(left, &Scalar::new(&right)) } /// Perform `left >= right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op(left, right, |a, b| a >= b) + crate::cmp::gt_eq(left, right) } /// Perform `left >= right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar(left, |a| a >= right) + let right = GenericBinaryArray::::from_iter_values([right]); + crate::cmp::gt_eq(left, &Scalar::new(&right)) } /// Perform `left != right` operation on [`StringArray`] / [`LargeStringArray`]. +#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn neq_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op(left, right, |a, b| a != b) + crate::cmp::neq(left, right) } /// Perform `left != right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. +#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn neq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - if right.is_empty() { - return utf8_empty::<_, false>(left); - } - compare_op_scalar(left, |a| a != right) + let right = GenericStringArray::::from_iter_values([right]); + crate::cmp::neq(left, &Scalar::new(&right)) } /// Perform `left < right` operation on [`StringArray`] / [`LargeStringArray`]. +#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op(left, right, |a, b| a < b) + crate::cmp::lt(left, right) } /// Perform `left < right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. 
+#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar(left, |a| a < right) + let right = GenericStringArray::::from_iter_values([right]); + crate::cmp::lt(left, &Scalar::new(&right)) } /// Perform `left <= right` operation on [`StringArray`] / [`LargeStringArray`]. +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op(left, right, |a, b| a <= b) + crate::cmp::lt_eq(left, right) } /// Perform `left <= right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar(left, |a| a <= right) + let right = GenericStringArray::::from_iter_values([right]); + crate::cmp::lt_eq(left, &Scalar::new(&right)) } /// Perform `left > right` operation on [`StringArray`] / [`LargeStringArray`]. +#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op(left, right, |a, b| a > b) + crate::cmp::gt(left, right) } /// Perform `left > right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. +#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar(left, |a| a > right) + let right = GenericStringArray::::from_iter_values([right]); + crate::cmp::gt(left, &Scalar::new(&right)) } /// Perform `left >= right` operation on [`StringArray`] / [`LargeStringArray`]. +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op(left, right, |a, b| a >= b) + crate::cmp::gt_eq(left, right) } /// Perform `left >= right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar(left, |a| a >= right) -} - -// Avoids creating a closure for each combination of `$RIGHT` and `$TY` -fn try_to_type_result( - value: Option, - right: &str, - ty: &str, -) -> Result { - value.ok_or_else(|| { - ArrowError::ComputeError(format!("Could not convert {right} with {ty}",)) - }) -} - -/// Calls $RIGHT.$TY() (e.g. `right.to_i128()`) with a nice error message. -/// Type of expression is `Result<.., ArrowError>` -macro_rules! try_to_type { - ($RIGHT: expr, $TY: ident) => { - try_to_type_result($RIGHT.$TY(), &format!("{:?}", $RIGHT), stringify!($TY)) - }; -} - -macro_rules! 
dyn_compare_scalar { - // Applies `LEFT OP RIGHT` when `LEFT` is a `PrimitiveArray` - ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{ - match $LEFT.data_type() { - DataType::Int8 => { - let right = try_to_type!($RIGHT, to_i8)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Int16 => { - let right = try_to_type!($RIGHT, to_i16)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Int32 => { - let right = try_to_type!($RIGHT, to_i32)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Int64 => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::UInt8 => { - let right = try_to_type!($RIGHT, to_u8)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::UInt16 => { - let right = try_to_type!($RIGHT, to_u16)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::UInt32 => { - let right = try_to_type!($RIGHT, to_u32)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::UInt64 => { - let right = try_to_type!($RIGHT, to_u64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Float16 => { - let right = try_to_type!($RIGHT, to_f32)?; - let left = as_primitive_array::($LEFT); - $OP::(left, f16::from_f32(right)) - } - DataType::Float32 => { - let right = try_to_type!($RIGHT, to_f32)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Float64 => { - let right = try_to_type!($RIGHT, to_f64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Decimal128(_, _) => { - let right = try_to_type!($RIGHT, to_i128)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Decimal256(_, _) => { - let right = try_to_type!($RIGHT, to_i128)?; - let left = as_primitive_array::($LEFT); - $OP::(left, i256::from_i128(right)) - } - DataType::Date32 => { - let right = try_to_type!($RIGHT, to_i32)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Date64 => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Timestamp(TimeUnit::Second, _) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Time32(TimeUnit::Second) => { - let right = try_to_type!($RIGHT, to_i32)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Time32(TimeUnit::Millisecond) => { - let right = try_to_type!($RIGHT, to_i32)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Time64(TimeUnit::Microsecond) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Time64(TimeUnit::Nanosecond) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - 
DataType::Interval(IntervalUnit::YearMonth) => { - let right = try_to_type!($RIGHT, to_i32)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Interval(IntervalUnit::DayTime) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let right = try_to_type!($RIGHT, to_i128)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Duration(TimeUnit::Second) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Duration(TimeUnit::Millisecond) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Duration(TimeUnit::Microsecond) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Duration(TimeUnit::Nanosecond) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - _ => Err(ArrowError::ComputeError(format!( - "Unsupported data type {:?} for comparison {} with {:?}", - $LEFT.data_type(), - stringify!($OP), - $RIGHT - ))), - } - }}; - // Applies `LEFT OP RIGHT` when `LEFT` is a `DictionaryArray` with keys of type `KT` - ($LEFT: expr, $RIGHT: expr, $KT: ident, $OP: ident) => {{ - match $KT.as_ref() { - DataType::UInt8 => { - let left = as_dictionary_array::($LEFT); - unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) - } - DataType::UInt16 => { - let left = as_dictionary_array::($LEFT); - unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) - } - DataType::UInt32 => { - let left = as_dictionary_array::($LEFT); - unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) - } - DataType::UInt64 => { - let left = as_dictionary_array::($LEFT); - unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) - } - DataType::Int8 => { - let left = as_dictionary_array::($LEFT); - unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) - } - DataType::Int16 => { - let left = as_dictionary_array::($LEFT); - unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) - } - DataType::Int32 => { - let left = as_dictionary_array::($LEFT); - unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) - } - DataType::Int64 => { - let left = as_dictionary_array::($LEFT); - unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) - } - _ => Err(ArrowError::ComputeError(format!( - "Unsupported dictionary key type {:?}", - $KT.as_ref() - ))), - } - }}; -} - -macro_rules! dyn_compare_utf8_scalar { - ($LEFT: expr, $RIGHT: expr, $KT: ident, $OP: ident) => {{ - match $KT.as_ref() { - DataType::UInt8 => { - let left = as_dictionary_array::($LEFT); - let values = as_string_array(left.values()); - unpack_dict_comparison(left, $OP(values, $RIGHT)?) - } - DataType::UInt16 => { - let left = as_dictionary_array::($LEFT); - let values = as_string_array(left.values()); - unpack_dict_comparison(left, $OP(values, $RIGHT)?) - } - DataType::UInt32 => { - let left = as_dictionary_array::($LEFT); - let values = as_string_array(left.values()); - unpack_dict_comparison(left, $OP(values, $RIGHT)?) - } - DataType::UInt64 => { - let left = as_dictionary_array::($LEFT); - let values = as_string_array(left.values()); - unpack_dict_comparison(left, $OP(values, $RIGHT)?) 
- } - DataType::Int8 => { - let left = as_dictionary_array::($LEFT); - let values = as_string_array(left.values()); - unpack_dict_comparison(left, $OP(values, $RIGHT)?) - } - DataType::Int16 => { - let left = as_dictionary_array::($LEFT); - let values = as_string_array(left.values()); - unpack_dict_comparison(left, $OP(values, $RIGHT)?) - } - DataType::Int32 => { - let left = as_dictionary_array::($LEFT); - let values = as_string_array(left.values()); - unpack_dict_comparison(left, $OP(values, $RIGHT)?) - } - DataType::Int64 => { - let left = as_dictionary_array::($LEFT); - let values = as_string_array(left.values()); - unpack_dict_comparison(left, $OP(values, $RIGHT)?) - } - _ => Err(ArrowError::ComputeError(String::from("Unknown key type"))), - } - }}; + let right = GenericStringArray::::from_iter_values([right]); + crate::cmp::gt_eq(left, &Scalar::new(&right)) } /// Perform `left == right` operation on an array and a numeric scalar @@ -716,16 +666,13 @@ macro_rules! dyn_compare_utf8_scalar { /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_dyn_scalar(left: &dyn Array, right: T) -> Result where T: num::ToPrimitive + std::fmt::Debug, { - match left.data_type() { - DataType::Dictionary(key_type, _value_type) => { - dyn_compare_scalar!(left, right, key_type, eq_dyn_scalar) - } - _ => dyn_compare_scalar!(left, right, eq_scalar), - } + let right = make_primitive_scalar(left.data_type(), right)?; + crate::cmp::eq(&left, &Scalar::new(&right)) } /// Perform `left < right` operation on an array and a numeric scalar @@ -737,16 +684,13 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt_dyn_scalar(left: &dyn Array, right: T) -> Result where T: num::ToPrimitive + std::fmt::Debug, { - match left.data_type() { - DataType::Dictionary(key_type, _value_type) => { - dyn_compare_scalar!(left, right, key_type, lt_dyn_scalar) - } - _ => dyn_compare_scalar!(left, right, lt_scalar), - } + let right = make_primitive_scalar(left.data_type(), right)?; + crate::cmp::lt(&left, &Scalar::new(&right)) } /// Perform `left <= right` operation on an array and a numeric scalar @@ -758,16 +702,13 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq_dyn_scalar(left: &dyn Array, right: T) -> Result where T: num::ToPrimitive + std::fmt::Debug, { - match left.data_type() { - DataType::Dictionary(key_type, _value_type) => { - dyn_compare_scalar!(left, right, key_type, lt_eq_dyn_scalar) - } - _ => dyn_compare_scalar!(left, right, lt_eq_scalar), - } + let right = make_primitive_scalar(left.data_type(), right)?; + crate::cmp::lt_eq(&left, &Scalar::new(&right)) } /// Perform `left > right` operation on an array and a numeric scalar @@ -779,16 +720,13 @@ where /// Note that totalOrder treats positive and negative zeros are different. 
If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_dyn_scalar(left: &dyn Array, right: T) -> Result where T: num::ToPrimitive + std::fmt::Debug, { - match left.data_type() { - DataType::Dictionary(key_type, _value_type) => { - dyn_compare_scalar!(left, right, key_type, gt_dyn_scalar) - } - _ => dyn_compare_scalar!(left, right, gt_scalar), - } + let right = make_primitive_scalar(left.data_type(), right)?; + crate::cmp::gt(&left, &Scalar::new(&right)) } /// Perform `left >= right` operation on an array and a numeric scalar @@ -800,16 +738,13 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq_dyn_scalar(left: &dyn Array, right: T) -> Result where T: num::ToPrimitive + std::fmt::Debug, { - match left.data_type() { - DataType::Dictionary(key_type, _value_type) => { - dyn_compare_scalar!(left, right, key_type, gt_eq_dyn_scalar) - } - _ => dyn_compare_scalar!(left, right, gt_eq_scalar), - } + let right = make_primitive_scalar(left.data_type(), right)?; + crate::cmp::gt_eq(&left, &Scalar::new(&right)) } /// Perform `left != right` operation on an array and a numeric scalar @@ -821,1317 +756,211 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn neq_dyn_scalar(left: &dyn Array, right: T) -> Result where T: num::ToPrimitive + std::fmt::Debug, { - match left.data_type() { - DataType::Dictionary(key_type, _value_type) => { - dyn_compare_scalar!(left, right, key_type, neq_dyn_scalar) - } - _ => dyn_compare_scalar!(left, right, neq_scalar), - } + let right = make_primitive_scalar(left.data_type(), right)?; + crate::cmp::neq(&left, &Scalar::new(&right)) } /// Perform `left == right` operation on an array and a numeric scalar /// value. Supports BinaryArray and LargeBinaryArray +#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_dyn_binary_scalar( left: &dyn Array, right: &[u8], ) -> Result { - match left.data_type() { - DataType::Binary => eq_binary_scalar(left.as_binary::(), right), - DataType::LargeBinary => eq_binary_scalar(left.as_binary::(), right), - _ => Err(ArrowError::ComputeError( - "eq_dyn_binary_scalar only supports Binary or LargeBinary arrays".to_string(), - )), - } + let right = make_binary_scalar(left.data_type(), right)?; + crate::cmp::eq(&left, &Scalar::new(&right)) } /// Perform `left != right` operation on an array and a numeric scalar /// value. 
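Note (not part of the diff): the numeric `*_dyn_scalar` kernels above now delegate to a `make_primitive_scalar` helper plus the unified `cmp` kernels. The sketch below is an assumed caller-side migration; the function name `gt_five`, the `Int32Array` example type, and the `arrow_schema::ArrowError` import are illustrative choices, not from the patch.

```rust
// Hypothetical caller-side replacement for the deprecated gt_dyn_scalar-style calls.
use arrow_array::{Array, BooleanArray, Int32Array, Scalar};
use arrow_ord::cmp;
use arrow_schema::ArrowError;

fn gt_five(values: &dyn Array) -> Result<BooleanArray, ArrowError> {
    // The caller now chooses the concrete scalar type; a one-element array
    // wrapped in `Scalar` is broadcast against every row of `values`.
    let five = Int32Array::from(vec![5]);
    cmp::gt(&values, &Scalar::new(&five))
}
```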
Supports BinaryArray and LargeBinaryArray +#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn neq_dyn_binary_scalar( left: &dyn Array, right: &[u8], ) -> Result { - match left.data_type() { - DataType::Binary => neq_binary_scalar(left.as_binary::(), right), - DataType::LargeBinary => neq_binary_scalar(left.as_binary::(), right), - _ => Err(ArrowError::ComputeError( - "neq_dyn_binary_scalar only supports Binary or LargeBinary arrays" - .to_string(), - )), - } + let right = make_binary_scalar(left.data_type(), right)?; + crate::cmp::neq(&left, &Scalar::new(&right)) } /// Perform `left < right` operation on an array and a numeric scalar /// value. Supports BinaryArray and LargeBinaryArray +#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt_dyn_binary_scalar( left: &dyn Array, right: &[u8], ) -> Result { - match left.data_type() { - DataType::Binary => lt_binary_scalar(left.as_binary::(), right), - DataType::LargeBinary => lt_binary_scalar(left.as_binary::(), right), - _ => Err(ArrowError::ComputeError( - "lt_dyn_binary_scalar only supports Binary or LargeBinary arrays".to_string(), - )), - } + let right = make_binary_scalar(left.data_type(), right)?; + crate::cmp::lt(&left, &Scalar::new(&right)) } /// Perform `left <= right` operation on an array and a numeric scalar /// value. Supports BinaryArray and LargeBinaryArray +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq_dyn_binary_scalar( left: &dyn Array, right: &[u8], ) -> Result { - match left.data_type() { - DataType::Binary => lt_eq_binary_scalar(left.as_binary::(), right), - DataType::LargeBinary => lt_eq_binary_scalar(left.as_binary::(), right), - _ => Err(ArrowError::ComputeError( - "lt_eq_dyn_binary_scalar only supports Binary or LargeBinary arrays" - .to_string(), - )), - } + let right = make_binary_scalar(left.data_type(), right)?; + crate::cmp::lt_eq(&left, &Scalar::new(&right)) } /// Perform `left > right` operation on an array and a numeric scalar /// value. Supports BinaryArray and LargeBinaryArray +#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_dyn_binary_scalar( left: &dyn Array, right: &[u8], ) -> Result { - match left.data_type() { - DataType::Binary => gt_binary_scalar(left.as_binary::(), right), - DataType::LargeBinary => gt_binary_scalar(left.as_binary::(), right), - _ => Err(ArrowError::ComputeError( - "gt_dyn_binary_scalar only supports Binary or LargeBinary arrays".to_string(), - )), - } + let right = make_binary_scalar(left.data_type(), right)?; + crate::cmp::gt(&left, &Scalar::new(&right)) } /// Perform `left >= right` operation on an array and a numeric scalar /// value. Supports BinaryArray and LargeBinaryArray +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq_dyn_binary_scalar( left: &dyn Array, right: &[u8], ) -> Result { - match left.data_type() { - DataType::Binary => gt_eq_binary_scalar(left.as_binary::(), right), - DataType::LargeBinary => gt_eq_binary_scalar(left.as_binary::(), right), - _ => Err(ArrowError::ComputeError( - "gt_eq_dyn_binary_scalar only supports Binary or LargeBinary arrays" - .to_string(), - )), - } + let right = make_binary_scalar(left.data_type(), right)?; + crate::cmp::gt_eq(&left, &Scalar::new(&right)) } /// Perform `left == right` operation on an array and a numeric scalar /// value. 
Supports StringArrays, and DictionaryArrays that have string values +#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_dyn_utf8_scalar( left: &dyn Array, right: &str, ) -> Result { - let result = match left.data_type() { - DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { - DataType::Utf8 | DataType::LargeUtf8 => { - dyn_compare_utf8_scalar!(left, right, key_type, eq_utf8_scalar) - } - _ => Err(ArrowError::ComputeError( - "eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), - )), - }, - DataType::Utf8 => { - eq_utf8_scalar(left.as_string::(), right) - } - DataType::LargeUtf8 => { - eq_utf8_scalar(left.as_string::(), right) - } - _ => Err(ArrowError::ComputeError( - "eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), - )), - }; - result + let right = make_utf8_scalar(left.data_type(), right)?; + crate::cmp::eq(&left, &Scalar::new(&right)) } /// Perform `left < right` operation on an array and a numeric scalar /// value. Supports StringArrays, and DictionaryArrays that have string values +#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt_dyn_utf8_scalar( left: &dyn Array, right: &str, ) -> Result { - let result = match left.data_type() { - DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { - DataType::Utf8 | DataType::LargeUtf8 => { - dyn_compare_utf8_scalar!(left, right, key_type, lt_utf8_scalar) - } - _ => Err(ArrowError::ComputeError( - "lt_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), - )), - }, - DataType::Utf8 => { - lt_utf8_scalar(left.as_string::(), right) - } - DataType::LargeUtf8 => { - lt_utf8_scalar(left.as_string::(), right) - } - _ => Err(ArrowError::ComputeError( - "lt_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), - )), - }; - result + let right = make_utf8_scalar(left.data_type(), right)?; + crate::cmp::lt(&left, &Scalar::new(&right)) } /// Perform `left >= right` operation on an array and a numeric scalar /// value. Supports StringArrays, and DictionaryArrays that have string values +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq_dyn_utf8_scalar( left: &dyn Array, right: &str, ) -> Result { - let result = match left.data_type() { - DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { - DataType::Utf8 | DataType::LargeUtf8 => { - dyn_compare_utf8_scalar!(left, right, key_type, gt_eq_utf8_scalar) - } - _ => Err(ArrowError::ComputeError( - "gt_eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), - )), - }, - DataType::Utf8 => { - gt_eq_utf8_scalar(left.as_string::(), right) - } - DataType::LargeUtf8 => { - gt_eq_utf8_scalar(left.as_string::(), right) - } - _ => Err(ArrowError::ComputeError( - "gt_eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), - )), - }; - result + let right = make_utf8_scalar(left.data_type(), right)?; + crate::cmp::gt_eq(&left, &Scalar::new(&right)) } /// Perform `left <= right` operation on an array and a numeric scalar /// value. 
Supports StringArrays, and DictionaryArrays that have string values +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq_dyn_utf8_scalar( left: &dyn Array, right: &str, ) -> Result { - let result = match left.data_type() { - DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { - DataType::Utf8 | DataType::LargeUtf8 => { - dyn_compare_utf8_scalar!(left, right, key_type, lt_eq_utf8_scalar) - } - _ => Err(ArrowError::ComputeError( - "lt_eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), - )), - }, - DataType::Utf8 => { - lt_eq_utf8_scalar(left.as_string::(), right) - } - DataType::LargeUtf8 => { - lt_eq_utf8_scalar(left.as_string::(), right) - } - _ => Err(ArrowError::ComputeError( - "lt_eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), - )), - }; - result + let right = make_utf8_scalar(left.data_type(), right)?; + crate::cmp::lt_eq(&left, &Scalar::new(&right)) } /// Perform `left > right` operation on an array and a numeric scalar /// value. Supports StringArrays, and DictionaryArrays that have string values +#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_dyn_utf8_scalar( left: &dyn Array, right: &str, ) -> Result { - let result = match left.data_type() { - DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { - DataType::Utf8 | DataType::LargeUtf8 => { - dyn_compare_utf8_scalar!(left, right, key_type, gt_utf8_scalar) - } - _ => Err(ArrowError::ComputeError( - "gt_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), - )), - }, - DataType::Utf8 => { - gt_utf8_scalar(left.as_string::(), right) - } - DataType::LargeUtf8 => { - gt_utf8_scalar(left.as_string::(), right) - } - _ => Err(ArrowError::ComputeError( - "gt_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), - )), - }; - result + let right = make_utf8_scalar(left.data_type(), right)?; + crate::cmp::gt(&left, &Scalar::new(&right)) } /// Perform `left != right` operation on an array and a numeric scalar /// value. Supports StringArrays, and DictionaryArrays that have string values +#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn neq_dyn_utf8_scalar( left: &dyn Array, right: &str, ) -> Result { - let result = match left.data_type() { - DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { - DataType::Utf8 | DataType::LargeUtf8 => { - dyn_compare_utf8_scalar!(left, right, key_type, neq_utf8_scalar) - } - _ => Err(ArrowError::ComputeError( - "neq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), - )), - }, - DataType::Utf8 => { - neq_utf8_scalar(left.as_string::(), right) - } - DataType::LargeUtf8 => { - neq_utf8_scalar(left.as_string::(), right) - } - _ => Err(ArrowError::ComputeError( - "neq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), - )), - }; - result + let right = make_utf8_scalar(left.data_type(), right)?; + crate::cmp::neq(&left, &Scalar::new(&right)) } /// Perform `left == right` operation on an array and a numeric scalar /// value. 
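Note (not part of the diff): each deprecated `*_dyn_utf8_scalar` body above now wraps the string in a one-element array and calls the shared `cmp` kernel. A minimal sketch of the same pattern from a caller's perspective follows; `contains_arrow` and the literal values are made up for illustration, and `arrow_schema::ArrowError` is assumed as the error type.

```rust
// Hypothetical migration off eq_dyn_utf8_scalar.
use arrow_array::{BooleanArray, Scalar, StringArray};
use arrow_ord::cmp;
use arrow_schema::ArrowError;

fn contains_arrow(haystack: &StringArray) -> Result<BooleanArray, ArrowError> {
    // Wrap the scalar in a one-element array; `Scalar` marks it for broadcasting.
    let needle = StringArray::from(vec!["arrow"]);
    cmp::eq(haystack, &Scalar::new(&needle))
}
```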
+#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_dyn_bool_scalar( left: &dyn Array, right: bool, ) -> Result { - let result = match left.data_type() { - DataType::Boolean => eq_bool_scalar(left.as_boolean(), right), - _ => Err(ArrowError::ComputeError( - "eq_dyn_bool_scalar only supports BooleanArray".to_string(), - )), - }; - result + let right = BooleanArray::from(vec![right]); + crate::cmp::eq(&left, &Scalar::new(&right)) } /// Perform `left < right` operation on an array and a numeric scalar /// value. Supports BooleanArrays. +#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt_dyn_bool_scalar( left: &dyn Array, - right: bool, -) -> Result { - let result = match left.data_type() { - DataType::Boolean => lt_bool_scalar(left.as_boolean(), right), - _ => Err(ArrowError::ComputeError( - "lt_dyn_bool_scalar only supports BooleanArray".to_string(), - )), - }; - result -} - -/// Perform `left > right` operation on an array and a numeric scalar -/// value. Supports BooleanArrays. -pub fn gt_dyn_bool_scalar( - left: &dyn Array, - right: bool, -) -> Result { - let result = match left.data_type() { - DataType::Boolean => gt_bool_scalar(left.as_boolean(), right), - _ => Err(ArrowError::ComputeError( - "gt_dyn_bool_scalar only supports BooleanArray".to_string(), - )), - }; - result -} - -/// Perform `left <= right` operation on an array and a numeric scalar -/// value. Supports BooleanArrays. -pub fn lt_eq_dyn_bool_scalar( - left: &dyn Array, - right: bool, -) -> Result { - let result = match left.data_type() { - DataType::Boolean => lt_eq_bool_scalar(left.as_boolean(), right), - _ => Err(ArrowError::ComputeError( - "lt_eq_dyn_bool_scalar only supports BooleanArray".to_string(), - )), - }; - result -} - -/// Perform `left >= right` operation on an array and a numeric scalar -/// value. Supports BooleanArrays. -pub fn gt_eq_dyn_bool_scalar( - left: &dyn Array, - right: bool, -) -> Result { - let result = match left.data_type() { - DataType::Boolean => gt_eq_bool_scalar(left.as_boolean(), right), - _ => Err(ArrowError::ComputeError( - "gt_eq_dyn_bool_scalar only supports BooleanArray".to_string(), - )), - }; - result -} - -/// Perform `left != right` operation on an array and a numeric scalar -/// value. Supports BooleanArrays. -pub fn neq_dyn_bool_scalar( - left: &dyn Array, - right: bool, -) -> Result { - let result = match left.data_type() { - DataType::Boolean => neq_bool_scalar(left.as_boolean(), right), - _ => Err(ArrowError::ComputeError( - "neq_dyn_bool_scalar only supports BooleanArray".to_string(), - )), - }; - result -} - -/// unpacks the results of comparing left.values (as a boolean) -/// -/// TODO add example -/// -fn unpack_dict_comparison( - dict: &DictionaryArray, - dict_comparison: BooleanArray, -) -> Result -where - K: ArrowDictionaryKeyType, - K::Native: num::ToPrimitive, -{ - let array = take(&dict_comparison, dict.keys(), None)? - .as_boolean() - .clone(); - Ok(array) -} - -/// Helper function to perform boolean lambda function on values from two arrays using -/// SIMD. 
-#[cfg(feature = "simd")] -fn simd_compare_op( - left: &PrimitiveArray, - right: &PrimitiveArray, - simd_op: SI, - scalar_op: SC, -) -> Result -where - T: ArrowNumericType, - SI: Fn(T::Simd, T::Simd) -> T::SimdMask, - SC: Fn(T::Native, T::Native) -> bool, -{ - use std::borrow::BorrowMut; - - let len = left.len(); - if len != right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform comparison operation on arrays of different length" - .to_string(), - )); - } - - let nulls = NullBuffer::union(left.nulls(), right.nulls()); - - // we process the data in chunks so that each iteration results in one u64 of comparison result bits - const CHUNK_SIZE: usize = 64; - let lanes = T::lanes(); - - // this is currently the case for all our datatypes and allows us to always append full bytes - assert!( - lanes <= CHUNK_SIZE, - "Number of vector lanes must be at most 64" - ); - - let buffer_size = bit_util::ceil(len, 8); - let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); - - let mut left_chunks = left.values().chunks_exact(CHUNK_SIZE); - let mut right_chunks = right.values().chunks_exact(CHUNK_SIZE); - - let result_chunks = result.typed_data_mut(); - let result_remainder = left_chunks - .borrow_mut() - .zip(right_chunks.borrow_mut()) - .fold(result_chunks, |result_slice, (left_slice, right_slice)| { - let mut i = 0; - let mut bitmask = 0_u64; - while i < CHUNK_SIZE { - let simd_left = T::load(&left_slice[i..]); - let simd_right = T::load(&right_slice[i..]); - let simd_result = simd_op(simd_left, simd_right); - - let m = T::mask_to_u64(&simd_result); - bitmask |= m << i; - - i += lanes; - } - let bytes = bitmask.to_le_bytes(); - result_slice[0..8].copy_from_slice(&bytes); - - &mut result_slice[8..] - }); - - let left_remainder = left_chunks.remainder(); - let right_remainder = right_chunks.remainder(); - - assert_eq!(left_remainder.len(), right_remainder.len()); - - if !left_remainder.is_empty() { - let remainder_bitmask = left_remainder - .iter() - .zip(right_remainder.iter()) - .enumerate() - .fold(0_u64, |mut mask, (i, (scalar_left, scalar_right))| { - let bit = scalar_op(*scalar_left, *scalar_right) as u64; - mask |= bit << i; - mask - }); - let remainder_mask_as_bytes = - &remainder_bitmask.to_le_bytes()[0..bit_util::ceil(left_remainder.len(), 8)]; - result_remainder.copy_from_slice(remainder_mask_as_bytes); - } - - let values = BooleanBuffer::new(result.into(), 0, len); - Ok(BooleanArray::new(values, nulls)) -} - -/// Helper function to perform boolean lambda function on values from an array and a scalar value using -/// SIMD. 
-#[cfg(feature = "simd")] -fn simd_compare_op_scalar( - left: &PrimitiveArray, - right: T::Native, - simd_op: SI, - scalar_op: SC, -) -> Result -where - T: ArrowNumericType, - SI: Fn(T::Simd, T::Simd) -> T::SimdMask, - SC: Fn(T::Native, T::Native) -> bool, -{ - use std::borrow::BorrowMut; - - let len = left.len(); - - // we process the data in chunks so that each iteration results in one u64 of comparison result bits - const CHUNK_SIZE: usize = 64; - let lanes = T::lanes(); - - // this is currently the case for all our datatypes and allows us to always append full bytes - assert!( - lanes <= CHUNK_SIZE, - "Number of vector lanes must be at most 64" - ); - - let buffer_size = bit_util::ceil(len, 8); - let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); - - let mut left_chunks = left.values().chunks_exact(CHUNK_SIZE); - let simd_right = T::init(right); - - let result_chunks = result.typed_data_mut(); - let result_remainder = - left_chunks - .borrow_mut() - .fold(result_chunks, |result_slice, left_slice| { - let mut i = 0; - let mut bitmask = 0_u64; - while i < CHUNK_SIZE { - let simd_left = T::load(&left_slice[i..]); - let simd_result = simd_op(simd_left, simd_right); - - let m = T::mask_to_u64(&simd_result); - bitmask |= m << i; - - i += lanes; - } - let bytes = bitmask.to_le_bytes(); - result_slice[0..8].copy_from_slice(&bytes); - - &mut result_slice[8..] - }); - - let left_remainder = left_chunks.remainder(); - - if !left_remainder.is_empty() { - let remainder_bitmask = left_remainder.iter().enumerate().fold( - 0_u64, - |mut mask, (i, scalar_left)| { - let bit = scalar_op(*scalar_left, right) as u64; - mask |= bit << i; - mask - }, - ); - let remainder_mask_as_bytes = - &remainder_bitmask.to_le_bytes()[0..bit_util::ceil(left_remainder.len(), 8)]; - result_remainder.copy_from_slice(remainder_mask_as_bytes); - } - - let null_bit_buffer = left.nulls().map(|b| b.inner().sliced()); - - // null count is the same as in the input since the right side of the scalar comparison cannot be null - let null_count = left.null_count(); - - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - len, - Some(null_count), - null_bit_buffer, - 0, - vec![result.into()], - vec![], - ) - }; - Ok(BooleanArray::from(data)) -} - -fn cmp_primitive_array( - left: &dyn Array, - right: &dyn Array, - op: F, -) -> Result -where - F: Fn(T::Native, T::Native) -> bool, -{ - let left_array = left.as_primitive::(); - let right_array = right.as_primitive::(); - compare_op(left_array, right_array, op) -} - -#[cfg(feature = "dyn_cmp_dict")] -macro_rules! 
typed_dict_non_dict_cmp { - ($LEFT: expr, $RIGHT: expr, $LEFT_KEY_TYPE: expr, $RIGHT_TYPE: tt, $OP_BOOL: expr, $OP: expr) => {{ - match $LEFT_KEY_TYPE { - DataType::Int8 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::Int16 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::Int32 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::Int64 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::UInt8 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::UInt16 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::UInt32 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::UInt64 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - t => Err(ArrowError::NotYetImplemented(format!( - "Cannot compare dictionary array of key type {}", - t - ))), - } - }}; -} - -#[cfg(feature = "dyn_cmp_dict")] -macro_rules! typed_dict_string_array_cmp { - ($LEFT: expr, $RIGHT: expr, $LEFT_KEY_TYPE: expr, $RIGHT_TYPE: tt, $OP: expr) => {{ - match $LEFT_KEY_TYPE { - DataType::Int8 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::Int16 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::Int32 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::Int64 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::UInt8 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::UInt16 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::UInt32 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::UInt64 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - t => Err(ArrowError::NotYetImplemented(format!( - "Cannot compare dictionary array of key type {}", - t - ))), - } - }}; -} - -#[cfg(feature = "dyn_cmp_dict")] -macro_rules! 
typed_cmp_dict_non_dict { - ($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr, $OP_FLOAT: expr) => {{ - match ($LEFT.data_type(), $RIGHT.data_type()) { - (DataType::Dictionary(left_key_type, left_value_type), right_type) => { - match (left_value_type.as_ref(), right_type) { - (DataType::Boolean, DataType::Boolean) => { - let left = $LEFT; - downcast_dictionary_array!( - left => { - cmp_dict_boolean_array::<_, _>(left, $RIGHT, $OP) - } - _ => Err(ArrowError::NotYetImplemented(format!( - "Cannot compare dictionary array of key type {}", - left_key_type.as_ref() - ))), - ) - } - (DataType::Int8, DataType::Int8) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Int8Type, $OP_BOOL, $OP) - } - (DataType::Int16, DataType::Int16) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Int16Type, $OP_BOOL, $OP) - } - (DataType::Int32, DataType::Int32) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Int32Type, $OP_BOOL, $OP) - } - (DataType::Int64, DataType::Int64) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Int64Type, $OP_BOOL, $OP) - } - (DataType::UInt8, DataType::UInt8) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), UInt8Type, $OP_BOOL, $OP) - } - (DataType::UInt16, DataType::UInt16) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), UInt16Type, $OP_BOOL, $OP) - } - (DataType::UInt32, DataType::UInt32) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), UInt32Type, $OP_BOOL, $OP) - } - (DataType::UInt64, DataType::UInt64) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), UInt64Type, $OP_BOOL, $OP) - } - (DataType::Float16, DataType::Float16) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Float16Type, $OP_BOOL, $OP_FLOAT) - } - (DataType::Float32, DataType::Float32) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Float32Type, $OP_BOOL, $OP_FLOAT) - } - (DataType::Float64, DataType::Float64) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Float64Type, $OP_BOOL, $OP_FLOAT) - } - (DataType::Decimal128(_, s1), DataType::Decimal128(_, s2)) if s1 == s2 => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Decimal128Type, $OP_BOOL, $OP) - } - (DataType::Decimal256(_, s1), DataType::Decimal256(_, s2)) if s1 == s2 => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Decimal256Type, $OP_BOOL, $OP) - } - (DataType::Utf8, DataType::Utf8) => { - typed_dict_string_array_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), i32, $OP) - } - (DataType::LargeUtf8, DataType::LargeUtf8) => { - typed_dict_string_array_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), i64, $OP) - } - (DataType::Binary, DataType::Binary) => { - let left = $LEFT; - downcast_dictionary_array!( - left => { - cmp_dict_binary_array::<_, i32, _>(left, $RIGHT, $OP) - } - _ => Err(ArrowError::NotYetImplemented(format!( - "Cannot compare dictionary array of key type {}", - left_key_type.as_ref() - ))), - ) - } - (DataType::LargeBinary, DataType::LargeBinary) => { - let left = $LEFT; - downcast_dictionary_array!( - left => { - cmp_dict_binary_array::<_, i64, _>(left, $RIGHT, $OP) - } - _ => Err(ArrowError::NotYetImplemented(format!( - "Cannot compare dictionary array of key type {}", - left_key_type.as_ref() - ))), - ) - } - (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( - "Comparing dictionary array of type {} with array of type {} is not yet implemented", - t1, t2 - 
))), - (t1, t2) => Err(ArrowError::CastError(format!( - "Cannot compare dictionary array with array of different value types ({} and {})", - t1, t2 - ))), - } - } - _ => unreachable!("Should not reach this branch"), - } - }}; -} - -#[cfg(not(feature = "dyn_cmp_dict"))] -macro_rules! typed_cmp_dict_non_dict { - ($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr, $OP_FLOAT: expr) => {{ - Err(ArrowError::CastError(format!( - "Comparing dictionary array of type {} with array of type {} requires \"dyn_cmp_dict\" feature", - $LEFT.data_type(), $RIGHT.data_type() - ))) - }} -} - -macro_rules! typed_compares { - ($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr, $OP_FLOAT: expr) => {{ - match ($LEFT.data_type(), $RIGHT.data_type()) { - (DataType::Boolean, DataType::Boolean) => { - compare_op(as_boolean_array($LEFT), as_boolean_array($RIGHT), $OP_BOOL) - } - (DataType::Int8, DataType::Int8) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::Int16, DataType::Int16) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::Int32, DataType::Int32) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::Int64, DataType::Int64) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::UInt8, DataType::UInt8) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::UInt16, DataType::UInt16) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::UInt32, DataType::UInt32) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::UInt64, DataType::UInt64) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::Float16, DataType::Float16) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP_FLOAT) - } - (DataType::Float32, DataType::Float32) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP_FLOAT) - } - (DataType::Float64, DataType::Float64) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP_FLOAT) - } - (DataType::Decimal128(_, s1), DataType::Decimal128(_, s2)) if s1 == s2 => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::Decimal256(_, s1), DataType::Decimal256(_, s2)) if s1 == s2 => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::Utf8, DataType::Utf8) => { - compare_op(as_string_array($LEFT), as_string_array($RIGHT), $OP) - } - (DataType::LargeUtf8, DataType::LargeUtf8) => compare_op( - as_largestring_array($LEFT), - as_largestring_array($RIGHT), - $OP, - ), - (DataType::FixedSizeBinary(_), DataType::FixedSizeBinary(_)) => { - let lhs = $LEFT - .as_any() - .downcast_ref::() - .unwrap(); - let rhs = $RIGHT - .as_any() - .downcast_ref::() - .unwrap(); - - compare_op(lhs, rhs, $OP) - } - (DataType::Binary, DataType::Binary) => compare_op( - as_generic_binary_array::($LEFT), - as_generic_binary_array::($RIGHT), - $OP, - ), - (DataType::LargeBinary, DataType::LargeBinary) => compare_op( - as_generic_binary_array::($LEFT), - as_generic_binary_array::($RIGHT), - $OP, - ), - ( - DataType::Timestamp(TimeUnit::Nanosecond, _), - DataType::Timestamp(TimeUnit::Nanosecond, _), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Timestamp(TimeUnit::Microsecond, _), - DataType::Timestamp(TimeUnit::Microsecond, _), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Timestamp(TimeUnit::Millisecond, _), - DataType::Timestamp(TimeUnit::Millisecond, _), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Timestamp(TimeUnit::Second, _), - DataType::Timestamp(TimeUnit::Second, _), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - (DataType::Date32, DataType::Date32) => { - 
cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::Date64, DataType::Date64) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::Time32(TimeUnit::Second), DataType::Time32(TimeUnit::Second)) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - ( - DataType::Time32(TimeUnit::Millisecond), - DataType::Time32(TimeUnit::Millisecond), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Time64(TimeUnit::Microsecond), - DataType::Time64(TimeUnit::Microsecond), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Time64(TimeUnit::Nanosecond), - DataType::Time64(TimeUnit::Nanosecond), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Interval(IntervalUnit::YearMonth), - DataType::Interval(IntervalUnit::YearMonth), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Interval(IntervalUnit::DayTime), - DataType::Interval(IntervalUnit::DayTime), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Interval(IntervalUnit::MonthDayNano), - DataType::Interval(IntervalUnit::MonthDayNano), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Duration(TimeUnit::Second), - DataType::Duration(TimeUnit::Second), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Duration(TimeUnit::Millisecond), - DataType::Duration(TimeUnit::Millisecond), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Duration(TimeUnit::Microsecond), - DataType::Duration(TimeUnit::Microsecond), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Duration(TimeUnit::Nanosecond), - DataType::Duration(TimeUnit::Nanosecond), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( - "Comparing arrays of type {} is not yet implemented", - t1 - ))), - (t1, t2) => Err(ArrowError::CastError(format!( - "Cannot compare two arrays of different types ({} and {})", - t1, t2 - ))), - } - }}; -} - -/// Applies $OP to $LEFT and $RIGHT which are two dictionaries which have (the same) key type $KT -#[cfg(feature = "dyn_cmp_dict")] -macro_rules! 
typed_dict_cmp { - ($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_FLOAT: expr, $OP_BOOL: expr, $KT: tt) => {{ - match ($LEFT.value_type(), $RIGHT.value_type()) { - (DataType::Boolean, DataType::Boolean) => { - cmp_dict_bool::<$KT, _>($LEFT, $RIGHT, $OP_BOOL) - } - (DataType::Int8, DataType::Int8) => { - cmp_dict::<$KT, Int8Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::Int16, DataType::Int16) => { - cmp_dict::<$KT, Int16Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::Int32, DataType::Int32) => { - cmp_dict::<$KT, Int32Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::Int64, DataType::Int64) => { - cmp_dict::<$KT, Int64Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::UInt8, DataType::UInt8) => { - cmp_dict::<$KT, UInt8Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::UInt16, DataType::UInt16) => { - cmp_dict::<$KT, UInt16Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::UInt32, DataType::UInt32) => { - cmp_dict::<$KT, UInt32Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::UInt64, DataType::UInt64) => { - cmp_dict::<$KT, UInt64Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::Float16, DataType::Float16) => { - cmp_dict::<$KT, Float16Type, _>($LEFT, $RIGHT, $OP_FLOAT) - } - (DataType::Float32, DataType::Float32) => { - cmp_dict::<$KT, Float32Type, _>($LEFT, $RIGHT, $OP_FLOAT) - } - (DataType::Float64, DataType::Float64) => { - cmp_dict::<$KT, Float64Type, _>($LEFT, $RIGHT, $OP_FLOAT) - } - (DataType::Decimal128(_, s1), DataType::Decimal128(_, s2)) if s1 == s2 => { - cmp_dict::<$KT, Decimal128Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::Decimal256(_, s1), DataType::Decimal256(_, s2)) if s1 == s2 => { - cmp_dict::<$KT, Decimal256Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::Utf8, DataType::Utf8) => { - cmp_dict_utf8::<$KT, i32, _>($LEFT, $RIGHT, $OP) - } - (DataType::LargeUtf8, DataType::LargeUtf8) => { - cmp_dict_utf8::<$KT, i64, _>($LEFT, $RIGHT, $OP) - } - (DataType::Binary, DataType::Binary) => { - cmp_dict_binary::<$KT, i32, _>($LEFT, $RIGHT, $OP) - } - (DataType::LargeBinary, DataType::LargeBinary) => { - cmp_dict_binary::<$KT, i64, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Timestamp(TimeUnit::Nanosecond, _), - DataType::Timestamp(TimeUnit::Nanosecond, _), - ) => { - cmp_dict::<$KT, TimestampNanosecondType, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Timestamp(TimeUnit::Microsecond, _), - DataType::Timestamp(TimeUnit::Microsecond, _), - ) => { - cmp_dict::<$KT, TimestampMicrosecondType, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Timestamp(TimeUnit::Millisecond, _), - DataType::Timestamp(TimeUnit::Millisecond, _), - ) => { - cmp_dict::<$KT, TimestampMillisecondType, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Timestamp(TimeUnit::Second, _), - DataType::Timestamp(TimeUnit::Second, _), - ) => { - cmp_dict::<$KT, TimestampSecondType, _>($LEFT, $RIGHT, $OP) - } - (DataType::Date32, DataType::Date32) => { - cmp_dict::<$KT, Date32Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::Date64, DataType::Date64) => { - cmp_dict::<$KT, Date64Type, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Time32(TimeUnit::Second), - DataType::Time32(TimeUnit::Second), - ) => { - cmp_dict::<$KT, Time32SecondType, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Time32(TimeUnit::Millisecond), - DataType::Time32(TimeUnit::Millisecond), - ) => { - cmp_dict::<$KT, Time32MillisecondType, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Time64(TimeUnit::Microsecond), - DataType::Time64(TimeUnit::Microsecond), - ) => { - cmp_dict::<$KT, Time64MicrosecondType, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Time64(TimeUnit::Nanosecond), - DataType::Time64(TimeUnit::Nanosecond), 
- ) => { - cmp_dict::<$KT, Time64NanosecondType, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Interval(IntervalUnit::YearMonth), - DataType::Interval(IntervalUnit::YearMonth), - ) => { - cmp_dict::<$KT, IntervalYearMonthType, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Interval(IntervalUnit::DayTime), - DataType::Interval(IntervalUnit::DayTime), - ) => { - cmp_dict::<$KT, IntervalDayTimeType, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Interval(IntervalUnit::MonthDayNano), - DataType::Interval(IntervalUnit::MonthDayNano), - ) => { - cmp_dict::<$KT, IntervalMonthDayNanoType, _>($LEFT, $RIGHT, $OP) - } - (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( - "Comparing dictionary arrays of value type {} is not yet implemented", - t1 - ))), - (t1, t2) => Err(ArrowError::CastError(format!( - "Cannot compare two dictionary arrays of different value types ({} and {})", - t1, t2 - ))), - } - }}; -} - -#[cfg(feature = "dyn_cmp_dict")] -macro_rules! typed_dict_compares { - // Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray` - ($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_FLOAT: expr, $OP_BOOL: expr) => {{ - match ($LEFT.data_type(), $RIGHT.data_type()) { - (DataType::Dictionary(left_key_type, _), DataType::Dictionary(right_key_type, _))=> { - match (left_key_type.as_ref(), right_key_type.as_ref()) { - (DataType::Int8, DataType::Int8) => { - let left = as_dictionary_array::($LEFT); - let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, Int8Type) - } - (DataType::Int16, DataType::Int16) => { - let left = as_dictionary_array::($LEFT); - let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, Int16Type) - } - (DataType::Int32, DataType::Int32) => { - let left = as_dictionary_array::($LEFT); - let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, Int32Type) - } - (DataType::Int64, DataType::Int64) => { - let left = as_dictionary_array::($LEFT); - let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, Int64Type) - } - (DataType::UInt8, DataType::UInt8) => { - let left = as_dictionary_array::($LEFT); - let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, UInt8Type) - } - (DataType::UInt16, DataType::UInt16) => { - let left = as_dictionary_array::($LEFT); - let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, UInt16Type) - } - (DataType::UInt32, DataType::UInt32) => { - let left = as_dictionary_array::($LEFT); - let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, UInt32Type) - } - (DataType::UInt64, DataType::UInt64) => { - let left = as_dictionary_array::($LEFT); - let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, UInt64Type) - } - (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( - "Comparing dictionary arrays of type {} is not yet implemented", - t1 - ))), - (t1, t2) => Err(ArrowError::CastError(format!( - "Cannot compare two dictionary arrays of different key types ({} and {})", - t1, t2 - ))), - } - } - (t1, t2) => Err(ArrowError::CastError(format!( - "Cannot compare dictionary array with non-dictionary array ({} and {})", - t1, t2 - ))), - } - }}; -} - -#[cfg(not(feature = "dyn_cmp_dict"))] -macro_rules! 
typed_dict_compares { - ($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_FLOAT: expr, $OP_BOOL: expr) => {{ - Err(ArrowError::CastError(format!( - "Comparing array of type {} with array of type {} requires \"dyn_cmp_dict\" feature", - $LEFT.data_type(), $RIGHT.data_type() - ))) - }} -} - -/// Perform given operation on `DictionaryArray` and `PrimitiveArray`. The value -/// type of `DictionaryArray` is same as `PrimitiveArray`'s type. -#[cfg(feature = "dyn_cmp_dict")] -fn cmp_dict_primitive( - left: &DictionaryArray, - right: &dyn Array, - op: F, -) -> Result -where - K: ArrowDictionaryKeyType, - T: ArrowPrimitiveType + Sync + Send, - F: Fn(T::Native, T::Native) -> bool, -{ - compare_op( - left.downcast_dict::>().unwrap(), - right.as_primitive::(), - op, - ) -} - -/// Perform given operation on `DictionaryArray` and `GenericStringArray`. The value -/// type of `DictionaryArray` is same as `GenericStringArray`'s type. -#[cfg(feature = "dyn_cmp_dict")] -fn cmp_dict_string_array( - left: &DictionaryArray, - right: &dyn Array, - op: F, -) -> Result -where - K: ArrowDictionaryKeyType, - F: Fn(&str, &str) -> bool, -{ - compare_op( - left.downcast_dict::>() - .unwrap(), - right - .as_any() - .downcast_ref::>() - .unwrap(), - op, - ) -} - -/// Perform given operation on `DictionaryArray` and `BooleanArray`. The value -/// type of `DictionaryArray` is same as `BooleanArray`'s type. -#[cfg(feature = "dyn_cmp_dict")] -fn cmp_dict_boolean_array( - left: &DictionaryArray, - right: &dyn Array, - op: F, -) -> Result -where - K: ArrowDictionaryKeyType, - F: Fn(bool, bool) -> bool, -{ - compare_op( - left.downcast_dict::().unwrap(), - right.as_any().downcast_ref::().unwrap(), - op, - ) -} - -/// Perform given operation on `DictionaryArray` and `GenericBinaryArray`. The value -/// type of `DictionaryArray` is same as `GenericBinaryArray`'s type. -#[cfg(feature = "dyn_cmp_dict")] -fn cmp_dict_binary_array( - left: &DictionaryArray, - right: &dyn Array, - op: F, -) -> Result -where - K: ArrowDictionaryKeyType, - F: Fn(&[u8], &[u8]) -> bool, -{ - compare_op( - left.downcast_dict::>() - .unwrap(), - right - .as_any() - .downcast_ref::>() - .unwrap(), - op, - ) -} - -/// Perform given operation on two `DictionaryArray`s which value type is -/// primitive type. Returns an error if the two arrays have different value -/// type -#[cfg(feature = "dyn_cmp_dict")] -pub fn cmp_dict( - left: &DictionaryArray, - right: &DictionaryArray, - op: F, -) -> Result -where - K: ArrowDictionaryKeyType, - T: ArrowPrimitiveType + Sync + Send, - F: Fn(T::Native, T::Native) -> bool, -{ - compare_op( - left.downcast_dict::>().unwrap(), - right.downcast_dict::>().unwrap(), - op, - ) + right: bool, +) -> Result { + let right = BooleanArray::from(vec![right]); + crate::cmp::lt(&left, &Scalar::new(&right)) } -/// Perform the given operation on two `DictionaryArray`s which value type is -/// `DataType::Boolean`. -#[cfg(feature = "dyn_cmp_dict")] -pub fn cmp_dict_bool( - left: &DictionaryArray, - right: &DictionaryArray, - op: F, -) -> Result -where - K: ArrowDictionaryKeyType, - F: Fn(bool, bool) -> bool, -{ - compare_op( - left.downcast_dict::().unwrap(), - right.downcast_dict::().unwrap(), - op, - ) +/// Perform `left > right` operation on an array and a numeric scalar +/// value. Supports BooleanArrays. 
+#[deprecated(note = "Use arrow_ord::cmp::gt")] +pub fn gt_dyn_bool_scalar( + left: &dyn Array, + right: bool, +) -> Result { + let right = BooleanArray::from(vec![right]); + crate::cmp::gt(&left, &Scalar::new(&right)) } -/// Perform the given operation on two `DictionaryArray`s which value type is -/// `DataType::Utf8` or `DataType::LargeUtf8`. -#[cfg(feature = "dyn_cmp_dict")] -pub fn cmp_dict_utf8( - left: &DictionaryArray, - right: &DictionaryArray, - op: F, -) -> Result -where - K: ArrowDictionaryKeyType, - F: Fn(&str, &str) -> bool, -{ - compare_op( - left.downcast_dict::>() - .unwrap(), - right - .downcast_dict::>() - .unwrap(), - op, - ) +/// Perform `left <= right` operation on an array and a numeric scalar +/// value. Supports BooleanArrays. +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] +pub fn lt_eq_dyn_bool_scalar( + left: &dyn Array, + right: bool, +) -> Result { + let right = BooleanArray::from(vec![right]); + crate::cmp::lt_eq(&left, &Scalar::new(&right)) } -/// Perform the given operation on two `DictionaryArray`s which value type is -/// `DataType::Binary` or `DataType::LargeBinary`. -#[cfg(feature = "dyn_cmp_dict")] -pub fn cmp_dict_binary( - left: &DictionaryArray, - right: &DictionaryArray, - op: F, -) -> Result -where - K: ArrowDictionaryKeyType, - F: Fn(&[u8], &[u8]) -> bool, -{ - compare_op( - left.downcast_dict::>() - .unwrap(), - right - .downcast_dict::>() - .unwrap(), - op, - ) +/// Perform `left >= right` operation on an array and a numeric scalar +/// value. Supports BooleanArrays. +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] +pub fn gt_eq_dyn_bool_scalar( + left: &dyn Array, + right: bool, +) -> Result { + let right = BooleanArray::from(vec![right]); + crate::cmp::gt_eq(&left, &Scalar::new(&right)) +} + +/// Perform `left != right` operation on an array and a numeric scalar +/// value. Supports BooleanArrays. +#[deprecated(note = "Use arrow_ord::cmp::neq")] +pub fn neq_dyn_bool_scalar( + left: &dyn Array, + right: bool, +) -> Result { + let right = BooleanArray::from(vec![right]); + crate::cmp::neq(&left, &Scalar::new(&right)) } /// Perform `left == right` operation on two (dynamic) [`Array`]s. @@ -2154,29 +983,9 @@ where /// let result = eq_dyn(&array1, &array2).unwrap(); /// assert_eq!(BooleanArray::from(vec![Some(true), None, Some(false)]), result); /// ``` +#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { - match left.data_type() { - DataType::Dictionary(_, _) - if matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_dict_compares!(left, right, |a, b| a == b, |a, b| a.is_eq(b), |a, b| a - == b) - } - DataType::Dictionary(_, _) - if !matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_cmp_dict_non_dict!(left, right, |a, b| a == b, |a, b| a == b, |a, b| a - .is_eq(b)) - } - _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(right, left, |a, b| a == b, |a, b| a == b, |a, b| b - .is_eq(a)) - } - _ => { - typed_compares!(left, right, |a, b| !(a ^ b), |a, b| a == b, |a, b| a - .is_eq(b)) - } - } + crate::cmp::eq(&left, &right) } /// Perform `left != right` operation on two (dynamic) [`Array`]s. 
@@ -2201,29 +1010,9 @@ pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result Result { - match left.data_type() { - DataType::Dictionary(_, _) - if matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_dict_compares!(left, right, |a, b| a != b, |a, b| a.is_ne(b), |a, b| a - != b) - } - DataType::Dictionary(_, _) - if !matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_cmp_dict_non_dict!(left, right, |a, b| a != b, |a, b| a != b, |a, b| a - .is_ne(b)) - } - _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(right, left, |a, b| a != b, |a, b| a != b, |a, b| b - .is_ne(a)) - } - _ => { - typed_compares!(left, right, |a, b| (a ^ b), |a, b| a != b, |a, b| a - .is_ne(b)) - } - } + crate::cmp::neq(&left, &right) } /// Perform `left < right` operation on two (dynamic) [`Array`]s. @@ -2247,30 +1036,9 @@ pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result Result { - match left.data_type() { - DataType::Dictionary(_, _) - if matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_dict_compares!(left, right, |a, b| a < b, |a, b| a.is_lt(b), |a, b| a - < b) - } - DataType::Dictionary(_, _) - if !matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_cmp_dict_non_dict!(left, right, |a, b| a < b, |a, b| a < b, |a, b| a - .is_lt(b)) - } - _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(right, left, |a, b| a > b, |a, b| a > b, |a, b| b - .is_lt(a)) - } - _ => { - typed_compares!(left, right, |a, b| ((!a) & b), |a, b| a < b, |a, b| a - .is_lt(b)) - } - } + crate::cmp::lt(&left, &right) } /// Perform `left <= right` operation on two (dynamic) [`Array`]s. @@ -2294,32 +1062,12 @@ pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result Result { - match left.data_type() { - DataType::Dictionary(_, _) - if matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_dict_compares!(left, right, |a, b| a <= b, |a, b| a.is_le(b), |a, b| a - <= b) - } - DataType::Dictionary(_, _) - if !matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_cmp_dict_non_dict!(left, right, |a, b| a <= b, |a, b| a <= b, |a, b| a - .is_le(b)) - } - _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(right, left, |a, b| a >= b, |a, b| a >= b, |a, b| b - .is_le(a)) - } - _ => { - typed_compares!(left, right, |a, b| !(a & (!b)), |a, b| a <= b, |a, b| a - .is_le(b)) - } - } + crate::cmp::lt_eq(&left, &right) } /// Perform `left > right` operation on two (dynamic) [`Array`]s. 
@@ -2342,30 +1090,9 @@ pub fn lt_eq_dyn( /// let result = gt_dyn(&array1, &array2).unwrap(); /// assert_eq!(BooleanArray::from(vec![Some(true), Some(false), None]), result); /// ``` -#[allow(clippy::bool_comparison)] +#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result { - match left.data_type() { - DataType::Dictionary(_, _) - if matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_dict_compares!(left, right, |a, b| a > b, |a, b| a.is_gt(b), |a, b| a - > b) - } - DataType::Dictionary(_, _) - if !matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_cmp_dict_non_dict!(left, right, |a, b| a > b, |a, b| a > b, |a, b| a - .is_gt(b)) - } - _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(right, left, |a, b| a < b, |a, b| a < b, |a, b| b - .is_gt(a)) - } - _ => { - typed_compares!(left, right, |a, b| (a & (!b)), |a, b| a > b, |a, b| a - .is_gt(b)) - } - } + crate::cmp::gt(&left, &right) } /// Perform `left >= right` operation on two (dynamic) [`Array`]s. @@ -2388,32 +1115,12 @@ pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result Result { - match left.data_type() { - DataType::Dictionary(_, _) - if matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_dict_compares!(left, right, |a, b| a >= b, |a, b| a.is_ge(b), |a, b| a - >= b) - } - DataType::Dictionary(_, _) - if !matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_cmp_dict_non_dict!(left, right, |a, b| a >= b, |a, b| a >= b, |a, b| a - .is_ge(b)) - } - _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(right, left, |a, b| a <= b, |a, b| a <= b, |a, b| b - .is_ge(a)) - } - _ => { - typed_compares!(left, right, |a, b| !((!a) & b), |a, b| a >= b, |a, b| a - .is_ge(b)) - } - } + crate::cmp::gt_eq(&left, &right) } /// Perform `left == right` operation on two [`PrimitiveArray`]s. @@ -2424,6 +1131,7 @@ pub fn gt_eq_dyn( /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq( left: &PrimitiveArray, right: &PrimitiveArray, @@ -2432,20 +1140,17 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op(left, right, T::eq, |a, b| a == b); - #[cfg(not(feature = "simd"))] - return compare_op(left, right, |a, b| a.is_eq(b)); + crate::cmp::eq(&left, &right) } /// Perform `left == right` operation on a [`PrimitiveArray`] and a scalar value. /// -/// If `simd` feature flag is not enabled: /// For floating values like f32 and f64, this comparison produces an ordering in accordance to /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. 
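Note (not part of the diff): the `eq_dyn` / `lt_dyn` / `gt_dyn` family shown here become thin deprecated wrappers over `arrow_ord::cmp`, which dispatches on the runtime data types (dictionary-encoded inputs included) without the removed `dyn_cmp_dict` feature. A small sketch of that usage, with an assumed helper name `eq_any` and example data:

```rust
// Illustrative replacement for the deprecated eq_dyn / lt_dyn family.
use arrow_array::{Array, BooleanArray, Int32Array};
use arrow_ord::cmp;
use arrow_schema::ArrowError;

fn eq_any(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray, ArrowError> {
    // One kernel handles all supported type pairs at runtime.
    cmp::eq(&left, &right)
}

fn example() -> Result<(), ArrowError> {
    let a = Int32Array::from(vec![Some(1), None, Some(3)]);
    let b = Int32Array::from(vec![Some(1), Some(2), Some(3)]);
    let mask = eq_any(&a, &b)?;
    // Null slots propagate to the result, valid slots compare by value.
    assert_eq!(mask, BooleanArray::from(vec![Some(true), None, Some(true)]));
    Ok(())
}
```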
+#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_scalar( left: &PrimitiveArray, right: T::Native, @@ -2454,10 +1159,8 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op_scalar(left, right, T::eq, |a, b| a == b); - #[cfg(not(feature = "simd"))] - return compare_op_scalar(left, |a| a.is_eq(right)); + let right = PrimitiveArray::::new(vec![right].into(), None); + crate::cmp::eq(&left, &Scalar::new(&right)) } /// Applies an unary and infallible comparison function to a primitive array. @@ -2480,6 +1183,7 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn neq( left: &PrimitiveArray, right: &PrimitiveArray, @@ -2488,10 +1192,7 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op(left, right, T::ne, |a, b| a != b); - #[cfg(not(feature = "simd"))] - return compare_op(left, right, |a, b| a.is_ne(b)); + crate::cmp::neq(&left, &right) } /// Perform `left != right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2502,6 +1203,7 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn neq_scalar( left: &PrimitiveArray, right: T::Native, @@ -2510,10 +1212,8 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op_scalar(left, right, T::ne, |a, b| a != b); - #[cfg(not(feature = "simd"))] - return compare_op_scalar(left, |a| a.is_ne(right)); + let right = PrimitiveArray::::new(vec![right].into(), None); + crate::cmp::neq(&left, &Scalar::new(&right)) } /// Perform `left < right` operation on two [`PrimitiveArray`]s. Null values are less than non-null @@ -2525,6 +1225,7 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt( left: &PrimitiveArray, right: &PrimitiveArray, @@ -2533,10 +1234,7 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op(left, right, T::lt, |a, b| a < b); - #[cfg(not(feature = "simd"))] - return compare_op(left, right, |a, b| a.is_lt(b)); + crate::cmp::lt(&left, &right) } /// Perform `left < right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2548,6 +1246,7 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. 
+#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt_scalar( left: &PrimitiveArray, right: T::Native, @@ -2556,10 +1255,8 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op_scalar(left, right, T::lt, |a, b| a < b); - #[cfg(not(feature = "simd"))] - return compare_op_scalar(left, |a| a.is_lt(right)); + let right = PrimitiveArray::::new(vec![right].into(), None); + crate::cmp::lt(&left, &Scalar::new(&right)) } /// Perform `left <= right` operation on two [`PrimitiveArray`]s. Null values are less than non-null @@ -2571,6 +1268,7 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq( left: &PrimitiveArray, right: &PrimitiveArray, @@ -2579,10 +1277,7 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op(left, right, T::le, |a, b| a <= b); - #[cfg(not(feature = "simd"))] - return compare_op(left, right, |a, b| a.is_le(b)); + crate::cmp::lt_eq(&left, &right) } /// Perform `left <= right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2594,6 +1289,7 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq_scalar( left: &PrimitiveArray, right: T::Native, @@ -2602,10 +1298,8 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op_scalar(left, right, T::le, |a, b| a <= b); - #[cfg(not(feature = "simd"))] - return compare_op_scalar(left, |a| a.is_le(right)); + let right = PrimitiveArray::::new(vec![right].into(), None); + crate::cmp::lt_eq(&left, &Scalar::new(&right)) } /// Perform `left > right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null @@ -2617,6 +1311,7 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt( left: &PrimitiveArray, right: &PrimitiveArray, @@ -2625,10 +1320,7 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op(left, right, T::gt, |a, b| a > b); - #[cfg(not(feature = "simd"))] - return compare_op(left, right, |a, b| a.is_gt(b)); + crate::cmp::gt(&left, &right) } /// Perform `left > right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2640,6 +1332,7 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. 
+#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_scalar( left: &PrimitiveArray, right: T::Native, @@ -2648,10 +1341,8 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op_scalar(left, right, T::gt, |a, b| a > b); - #[cfg(not(feature = "simd"))] - return compare_op_scalar(left, |a| a.is_gt(right)); + let right = PrimitiveArray::::new(vec![right].into(), None); + crate::cmp::gt(&left, &Scalar::new(&right)) } /// Perform `left >= right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null @@ -2663,6 +1354,7 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq( left: &PrimitiveArray, right: &PrimitiveArray, @@ -2671,10 +1363,7 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op(left, right, T::ge, |a, b| a >= b); - #[cfg(not(feature = "simd"))] - return compare_op(left, right, |a, b| a.is_ge(b)); + crate::cmp::gt_eq(&left, &right) } /// Perform `left >= right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2686,6 +1375,7 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq_scalar( left: &PrimitiveArray, right: T::Native, @@ -2694,10 +1384,8 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op_scalar(left, right, T::ge, |a, b| a >= b); - #[cfg(not(feature = "simd"))] - return compare_op_scalar(left, |a| a.is_ge(right)); + let right = PrimitiveArray::::new(vec![right].into(), None); + crate::cmp::gt_eq(&left, &Scalar::new(&right)) } /// Checks if a [`GenericListArray`] contains a value in the [`PrimitiveArray`] @@ -2785,14 +1473,18 @@ where // disable wrapping inside literal vectors used for test data and assertions #[rustfmt::skip::macros(vec)] #[cfg(test)] +#[allow(deprecated)] mod tests { - use super::*; + use std::sync::Arc; + use arrow_array::builder::{ ListBuilder, PrimitiveDictionaryBuilder, StringBuilder, StringDictionaryBuilder, }; - use arrow_buffer::i256; + use arrow_buffer::{i256, Buffer}; + use arrow_data::ArrayData; use arrow_schema::Field; - use std::sync::Arc; + + use super::*; /// Evaluate `KERNEL` with two vectors as inputs and assert against the expected output. 
/// `A_VEC` and `B_VEC` can be of type `Vec` or `Vec>` where `T` is the native @@ -3501,25 +2193,17 @@ mod tests { ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { #[test] fn $test_name() { + let expected = BooleanArray::from($expected); + let left = BinaryArray::from_vec($left); let right = BinaryArray::from_vec($right); let res = $op(&left, &right).unwrap(); - let expected = $expected; - assert_eq!(expected.len(), res.len()); - for i in 0..res.len() { - let v = res.value(i); - assert_eq!(v, expected[i]); - } + assert_eq!(res, expected); let left = LargeBinaryArray::from_vec($left); let right = LargeBinaryArray::from_vec($right); let res = $op(&left, &right).unwrap(); - let expected = $expected; - assert_eq!(expected.len(), res.len()); - for i in 0..res.len() { - let v = res.value(i); - assert_eq!(v, expected[i]); - } + assert_eq!(res, expected); } }; } @@ -3542,37 +2226,15 @@ mod tests { ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { #[test] fn $test_name() { + let expected = BooleanArray::from($expected); + let left = BinaryArray::from_vec($left); let res = $op(&left, $right).unwrap(); - let expected = $expected; - assert_eq!(expected.len(), res.len()); - for i in 0..res.len() { - let v = res.value(i); - assert_eq!( - v, - expected[i], - "unexpected result when comparing {:?} at position {} to {:?} ", - left.value(i), - i, - $right - ); - } + assert_eq!(res, expected); let left = LargeBinaryArray::from_vec($left); let res = $op(&left, $right).unwrap(); - let expected = $expected; - assert_eq!(expected.len(), res.len()); - for i in 0..res.len() { - let v = res.value(i); - assert_eq!( - v, - expected[i], - "unexpected result when comparing {:?} at position {} to {:?} ", - left.value(i), - i, - $right - ); - } + assert_eq!(res, expected); } }; } @@ -3806,14 +2468,14 @@ mod tests { vec!["arrow", "arrow", "arrow", "arrow"], vec!["arrow", "parquet", "datafusion", "flight"], eq_utf8, - vec![true, false, false, false] + [true, false, false, false] ); test_utf8_scalar!( test_utf8_array_eq_scalar, vec!["arrow", "parquet", "datafusion", "flight"], "arrow", eq_utf8_scalar, - vec![true, false, false, false] + [true, false, false, false] ); test_utf8!( @@ -3821,14 +2483,14 @@ mod tests { vec!["arrow", "arrow", "arrow", "arrow"], vec!["arrow", "parquet", "datafusion", "flight"], neq_utf8, - vec![false, true, true, true] + [false, true, true, true] ); test_utf8_scalar!( test_utf8_array_neq_scalar, vec!["arrow", "parquet", "datafusion", "flight"], "arrow", neq_utf8_scalar, - vec![false, true, true, true] + [false, true, true, true] ); test_utf8!( @@ -3836,14 +2498,14 @@ mod tests { vec!["arrow", "datafusion", "flight", "parquet"], vec!["flight", "flight", "flight", "flight"], lt_utf8, - vec![true, true, false, false] + [true, true, false, false] ); test_utf8_scalar!( test_utf8_array_lt_scalar, vec!["arrow", "datafusion", "flight", "parquet"], "flight", lt_utf8_scalar, - vec![true, true, false, false] + [true, true, false, false] ); test_utf8!( @@ -3851,14 +2513,14 @@ mod tests { vec!["arrow", "datafusion", "flight", "parquet"], vec!["flight", "flight", "flight", "flight"], lt_eq_utf8, - vec![true, true, true, false] + [true, true, true, false] ); test_utf8_scalar!( test_utf8_array_lt_eq_scalar, vec!["arrow", "datafusion", "flight", "parquet"], "flight", lt_eq_utf8_scalar, - vec![true, true, true, false] + [true, true, true, false] ); test_utf8!( @@ -3866,14 +2528,14 @@ mod tests { vec!["arrow", "datafusion", "flight", "parquet"], vec!["flight", 
"flight", "flight", "flight"], gt_utf8, - vec![false, false, false, true] + [false, false, false, true] ); test_utf8_scalar!( test_utf8_array_gt_scalar, vec!["arrow", "datafusion", "flight", "parquet"], "flight", gt_utf8_scalar, - vec![false, false, false, true] + [false, false, false, true] ); test_utf8!( @@ -3881,14 +2543,14 @@ mod tests { vec!["arrow", "datafusion", "flight", "parquet"], vec!["flight", "flight", "flight", "flight"], gt_eq_utf8, - vec![false, false, true, true] + [false, false, true, true] ); test_utf8_scalar!( test_utf8_array_gt_eq_scalar, vec!["arrow", "datafusion", "flight", "parquet"], "flight", gt_eq_utf8_scalar, - vec![false, false, true, true] + [false, false, true, true] ); #[test] @@ -4276,6 +2938,15 @@ mod tests { eq_dyn_binary_scalar(&large_array, scalar).unwrap(), expected ); + + let fsb_array = FixedSizeBinaryArray::try_from_iter( + vec![vec![0u8], vec![0u8], vec![0u8], vec![1u8]].into_iter(), + ) + .unwrap(); + let scalar = &[1u8]; + let expected = + BooleanArray::from(vec![Some(false), Some(false), Some(false), Some(true)]); + assert_eq!(eq_dyn_binary_scalar(&fsb_array, scalar).unwrap(), expected); } #[test] @@ -4293,6 +2964,15 @@ mod tests { neq_dyn_binary_scalar(&large_array, scalar).unwrap(), expected ); + + let fsb_array = FixedSizeBinaryArray::try_from_iter( + vec![vec![0u8], vec![0u8], vec![0u8], vec![1u8]].into_iter(), + ) + .unwrap(); + let scalar = &[1u8]; + let expected = + BooleanArray::from(vec![Some(true), Some(true), Some(true), Some(false)]); + assert_eq!(neq_dyn_binary_scalar(&fsb_array, scalar).unwrap(), expected); } #[test] @@ -4619,7 +3299,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_i8_array() { // Construct a value array let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); @@ -4641,7 +3320,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_u64_array() { let values = UInt64Array::from_iter_values([10_u64, 11, 12, 13, 14, 15, 16, 17]); let values = Arc::new(values) as ArrayRef; @@ -4662,10 +3340,9 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_utf8_array() { - let test1 = vec!["a", "a", "b", "c"]; - let test2 = vec!["a", "b", "b", "c"]; + let test1 = ["a", "a", "b", "c"]; + let test2 = ["a", "b", "b", "c"]; let dict_array1: DictionaryArray = test1 .iter() @@ -4690,7 +3367,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_binary_array() { let values: BinaryArray = ["hello", "", "parquet"] .into_iter() @@ -4714,7 +3390,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_interval_array() { let values = IntervalDayTimeArray::from(vec![1, 6, 10, 2, 3, 5]); let values = Arc::new(values) as ArrayRef; @@ -4735,7 +3410,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_date_array() { let values = Date32Array::from(vec![1, 6, 10, 2, 3, 5]); let values = Arc::new(values) as ArrayRef; @@ -4756,7 +3430,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_bool_array() { let values = BooleanArray::from(vec![true, false]); let values = Arc::new(values) as ArrayRef; @@ -4777,7 +3450,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_lt_dyn_gt_dyn_dictionary_i8_array() { // Construct a value array let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); @@ -4808,7 +3480,6 @@ 
mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_lt_dyn_gt_dyn_dictionary_bool_array() { let values = BooleanArray::from(vec![true, false]); let values = Arc::new(values) as ArrayRef; @@ -4840,7 +3511,7 @@ mod tests { #[test] fn test_unary_cmp() { let a = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]); - let values = vec![1_i32, 3]; + let values = [1_i32, 3]; let a_eq = unary_cmp(&a, |a| values.contains(&a)).unwrap(); assert_eq!( @@ -4850,7 +3521,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_i8_i8_array() { let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); let keys = Int8Array::from_iter_values([2_i8, 3, 4]); @@ -4885,7 +3555,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_lt_dyn_lt_eq_dyn_gt_dyn_gt_eq_dyn_dictionary_i8_i8_array() { let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); let keys = Int8Array::from_iter_values([2_i8, 3, 4]); @@ -4945,20 +3614,13 @@ mod tests { #[test] fn test_eq_dyn_neq_dyn_float_nan() { - let array1: Float16Array = vec![f16::NAN, f16::from_f32(7.0), f16::from_f32(8.0), f16::from_f32(8.0), f16::from_f32(10.0)] - .into_iter() - .map(Some) - .collect(); - let array2: Float16Array = vec![f16::NAN, f16::NAN, f16::from_f32(8.0), f16::from_f32(8.0), f16::from_f32(10.0)] - .into_iter() - .map(Some) - .collect(); + let array1 = Float16Array::from(vec![f16::NAN, f16::from_f32(7.0), f16::from_f32(8.0), f16::from_f32(8.0), f16::from_f32(10.0)]); + let array2 = Float16Array::from(vec![f16::NAN, f16::NAN, f16::from_f32(8.0), f16::from_f32(8.0), f16::from_f32(10.0)]); let expected = BooleanArray::from( vec![Some(true), Some(false), Some(true), Some(true), Some(true)], ); assert_eq!(eq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(eq(&array1, &array2).unwrap(), expected); let expected = BooleanArray::from( @@ -4966,23 +3628,15 @@ mod tests { ); assert_eq!(neq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(neq(&array1, &array2).unwrap(), expected); - let array1: Float32Array = vec![f32::NAN, 7.0, 8.0, 8.0, 10.0] - .into_iter() - .map(Some) - .collect(); - let array2: Float32Array = vec![f32::NAN, f32::NAN, 8.0, 8.0, 10.0] - .into_iter() - .map(Some) - .collect(); + let array1 = Float32Array::from(vec![f32::NAN, 7.0, 8.0, 8.0, 10.0]); + let array2 = Float32Array::from(vec![f32::NAN, f32::NAN, 8.0, 8.0, 10.0]); let expected = BooleanArray::from( vec![Some(true), Some(false), Some(true), Some(true), Some(true)], ); assert_eq!(eq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(eq(&array1, &array2).unwrap(), expected); let expected = BooleanArray::from( @@ -4990,24 +3644,16 @@ mod tests { ); assert_eq!(neq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(neq(&array1, &array2).unwrap(), expected); - let array1: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 10.0] - .into_iter() - .map(Some) - .collect(); - let array2: Float64Array = vec![f64::NAN, f64::NAN, 8.0, 8.0, 10.0] - .into_iter() - .map(Some) - .collect(); + let array1 = Float64Array::from(vec![f64::NAN, 7.0, 8.0, 8.0, 10.0]); + let array2 = Float64Array::from(vec![f64::NAN, f64::NAN, 8.0, 8.0, 10.0]); let expected = BooleanArray::from( vec![Some(true), Some(false), Some(true), Some(true), Some(true)], ); assert_eq!(eq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(eq(&array1, 
&array2).unwrap(), expected); let expected = BooleanArray::from( @@ -5015,27 +3661,19 @@ mod tests { ); assert_eq!(neq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(neq(&array1, &array2).unwrap(), expected); } #[test] fn test_lt_dyn_lt_eq_dyn_float_nan() { - let array1: Float16Array = vec![f16::NAN, f16::from_f32(7.0), f16::from_f32(8.0), f16::from_f32(8.0), f16::from_f32(11.0), f16::NAN] - .into_iter() - .map(Some) - .collect(); - let array2: Float16Array = vec![f16::NAN, f16::NAN, f16::from_f32(8.0), f16::from_f32(9.0), f16::from_f32(10.0), f16::from_f32(1.0)] - .into_iter() - .map(Some) - .collect(); + let array1 = Float16Array::from(vec![f16::NAN, f16::from_f32(7.0), f16::from_f32(8.0), f16::from_f32(8.0), f16::from_f32(11.0), f16::NAN]); + let array2 = Float16Array::from(vec![f16::NAN, f16::NAN, f16::from_f32(8.0), f16::from_f32(9.0), f16::from_f32(10.0), f16::from_f32(1.0)]); let expected = BooleanArray::from( vec![Some(false), Some(true), Some(false), Some(true), Some(false), Some(false)], ); assert_eq!(lt_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(lt(&array1, &array2).unwrap(), expected); let expected = BooleanArray::from( @@ -5043,24 +3681,16 @@ mod tests { ); assert_eq!(lt_eq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(lt_eq(&array1, &array2).unwrap(), expected); - let array1: Float32Array = vec![f32::NAN, 7.0, 8.0, 8.0, 11.0, f32::NAN] - .into_iter() - .map(Some) - .collect(); - let array2: Float32Array = vec![f32::NAN, f32::NAN, 8.0, 9.0, 10.0, 1.0] - .into_iter() - .map(Some) - .collect(); + let array1 = Float32Array::from(vec![f32::NAN, 7.0, 8.0, 8.0, 11.0, f32::NAN]); + let array2 = Float32Array::from(vec![f32::NAN, f32::NAN, 8.0, 9.0, 10.0, 1.0]); let expected = BooleanArray::from( vec![Some(false), Some(true), Some(false), Some(true), Some(false), Some(false)], ); assert_eq!(lt_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(lt(&array1, &array2).unwrap(), expected); let expected = BooleanArray::from( @@ -5068,7 +3698,6 @@ mod tests { ); assert_eq!(lt_eq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(lt_eq(&array1, &array2).unwrap(), expected); let array1: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 11.0, f64::NAN] @@ -5085,7 +3714,6 @@ mod tests { ); assert_eq!(lt_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(lt(&array1, &array2).unwrap(), expected); let expected = BooleanArray::from( @@ -5093,27 +3721,19 @@ mod tests { ); assert_eq!(lt_eq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(lt_eq(&array1, &array2).unwrap(), expected); } #[test] fn test_gt_dyn_gt_eq_dyn_float_nan() { - let array1: Float16Array = vec![f16::NAN, f16::from_f32(7.0), f16::from_f32(8.0), f16::from_f32(8.0), f16::from_f32(11.0), f16::NAN] - .into_iter() - .map(Some) - .collect(); - let array2: Float16Array = vec![f16::NAN, f16::NAN, f16::from_f32(8.0), f16::from_f32(9.0), f16::from_f32(10.0), f16::from_f32(1.0)] - .into_iter() - .map(Some) - .collect(); + let array1 = Float16Array::from(vec![f16::NAN, f16::from_f32(7.0), f16::from_f32(8.0), f16::from_f32(8.0), f16::from_f32(11.0), f16::NAN]); + let array2 = Float16Array::from(vec![f16::NAN, f16::NAN, f16::from_f32(8.0), f16::from_f32(9.0), f16::from_f32(10.0), f16::from_f32(1.0)]); let expected = BooleanArray::from( vec![Some(false), Some(false), Some(false), Some(false), 
Some(true), Some(true)], ); assert_eq!(gt_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(gt(&array1, &array2).unwrap(), expected); let expected = BooleanArray::from( @@ -5121,24 +3741,16 @@ mod tests { ); assert_eq!(gt_eq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(gt_eq(&array1, &array2).unwrap(), expected); - let array1: Float32Array = vec![f32::NAN, 7.0, 8.0, 8.0, 11.0, f32::NAN] - .into_iter() - .map(Some) - .collect(); - let array2: Float32Array = vec![f32::NAN, f32::NAN, 8.0, 9.0, 10.0, 1.0] - .into_iter() - .map(Some) - .collect(); + let array1 = Float32Array::from(vec![f32::NAN, 7.0, 8.0, 8.0, 11.0, f32::NAN]); + let array2 = Float32Array::from(vec![f32::NAN, f32::NAN, 8.0, 9.0, 10.0, 1.0]); let expected = BooleanArray::from( vec![Some(false), Some(false), Some(false), Some(false), Some(true), Some(true)], ); assert_eq!(gt_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(gt(&array1, &array2).unwrap(), expected); let expected = BooleanArray::from( @@ -5146,24 +3758,16 @@ mod tests { ); assert_eq!(gt_eq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(gt_eq(&array1, &array2).unwrap(), expected); - let array1: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 11.0, f64::NAN] - .into_iter() - .map(Some) - .collect(); - let array2: Float64Array = vec![f64::NAN, f64::NAN, 8.0, 9.0, 10.0, 1.0] - .into_iter() - .map(Some) - .collect(); + let array1 = Float64Array::from(vec![f64::NAN, 7.0, 8.0, 8.0, 11.0, f64::NAN]); + let array2 = Float64Array::from(vec![f64::NAN, f64::NAN, 8.0, 9.0, 10.0, 1.0]); let expected = BooleanArray::from( vec![Some(false), Some(false), Some(false), Some(false), Some(true), Some(true)], ); assert_eq!(gt_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(gt(&array1, &array2).unwrap(), expected); let expected = BooleanArray::from( @@ -5171,79 +3775,40 @@ mod tests { ); assert_eq!(gt_eq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(gt_eq(&array1, &array2).unwrap(), expected); } #[test] fn test_eq_dyn_scalar_neq_dyn_scalar_float_nan() { - let array: Float16Array = vec![f16::NAN, f16::from_f32(7.0), f16::from_f32(8.0), f16::from_f32(8.0), f16::from_f32(10.0)] - .into_iter() - .map(Some) - .collect(); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] + let array = Float16Array::from(vec![f16::NAN, f16::from_f32(7.0), f16::from_f32(8.0), f16::from_f32(8.0), f16::from_f32(10.0)]); + let expected = BooleanArray::from( vec![Some(true), Some(false), Some(false), Some(false), Some(false)], ); assert_eq!(eq_dyn_scalar(&array, f32::NAN).unwrap(), expected); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(true), Some(true), Some(true), Some(true), Some(true)], - ); - #[cfg(not(feature = "simd"))] let expected = BooleanArray::from( vec![Some(false), Some(true), Some(true), Some(true), Some(true)], ); assert_eq!(neq_dyn_scalar(&array, f32::NAN).unwrap(), expected); - let array: Float32Array = vec![f32::NAN, 7.0, 8.0, 8.0, 10.0] - .into_iter() - .map(Some) - .collect(); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] + let array = Float32Array::from(vec![f32::NAN, 7.0, 8.0, 8.0, 10.0]); 
let expected = BooleanArray::from( vec![Some(true), Some(false), Some(false), Some(false), Some(false)], ); assert_eq!(eq_dyn_scalar(&array, f32::NAN).unwrap(), expected); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(true), Some(true), Some(true), Some(true), Some(true)], - ); - #[cfg(not(feature = "simd"))] let expected = BooleanArray::from( vec![Some(false), Some(true), Some(true), Some(true), Some(true)], ); assert_eq!(neq_dyn_scalar(&array, f32::NAN).unwrap(), expected); - let array: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 10.0] - .into_iter() - .map(Some) - .collect(); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] + let array = Float64Array::from(vec![f64::NAN, 7.0, 8.0, 8.0, 10.0]); let expected = BooleanArray::from( vec![Some(true), Some(false), Some(false), Some(false), Some(false)], ); assert_eq!(eq_dyn_scalar(&array, f64::NAN).unwrap(), expected); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(true), Some(true), Some(true), Some(true), Some(true)], - ); - #[cfg(not(feature = "simd"))] let expected = BooleanArray::from( vec![Some(false), Some(true), Some(true), Some(true), Some(true)], ); @@ -5252,73 +3817,36 @@ mod tests { #[test] fn test_lt_dyn_scalar_lt_eq_dyn_scalar_float_nan() { - let array: Float16Array = vec![f16::NAN, f16::from_f32(7.0), f16::from_f32(8.0), f16::from_f32(8.0), f16::from_f32(10.0)] - .into_iter() - .map(Some) - .collect(); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] + let array = Float16Array::from(vec![f16::NAN, f16::from_f32(7.0), f16::from_f32(8.0), f16::from_f32(8.0), f16::from_f32(10.0)]); + let expected = BooleanArray::from( vec![Some(false), Some(true), Some(true), Some(true), Some(true)], ); assert_eq!(lt_dyn_scalar(&array, f16::NAN).unwrap(), expected); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] let expected = BooleanArray::from( vec![Some(true), Some(true), Some(true), Some(true), Some(true)], ); assert_eq!(lt_eq_dyn_scalar(&array, f16::NAN).unwrap(), expected); - let array: Float32Array = vec![f32::NAN, 7.0, 8.0, 8.0, 10.0] - .into_iter() - .map(Some) - .collect(); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] + let array = Float32Array::from(vec![f32::NAN, 7.0, 8.0, 8.0, 10.0]); + let expected = BooleanArray::from( vec![Some(false), Some(true), Some(true), Some(true), Some(true)], ); assert_eq!(lt_dyn_scalar(&array, f32::NAN).unwrap(), expected); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] let expected = BooleanArray::from( vec![Some(true), Some(true), Some(true), Some(true), Some(true)], ); assert_eq!(lt_eq_dyn_scalar(&array, f32::NAN).unwrap(), expected); - let array: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 10.0] - .into_iter() - .map(Some) - .collect(); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] + let array = 
Float64Array::from(vec![f64::NAN, 7.0, 8.0, 8.0, 10.0]); let expected = BooleanArray::from( vec![Some(false), Some(true), Some(true), Some(true), Some(true)], ); assert_eq!(lt_dyn_scalar(&array, f64::NAN).unwrap(), expected); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] let expected = BooleanArray::from( vec![Some(true), Some(true), Some(true), Some(true), Some(true)], ); @@ -5327,58 +3855,40 @@ mod tests { #[test] fn test_gt_dyn_scalar_gt_eq_dyn_scalar_float_nan() { - let array: Float16Array = vec![f16::NAN, f16::from_f32(7.0), f16::from_f32(8.0), f16::from_f32(8.0), f16::from_f32(10.0)] - .into_iter() - .map(Some) - .collect(); + let array = Float16Array::from(vec![ + f16::NAN, + f16::from_f32(7.0), + f16::from_f32(8.0), + f16::from_f32(8.0), + f16::from_f32(10.0), + ]); let expected = BooleanArray::from( vec![Some(false), Some(false), Some(false), Some(false), Some(false)], ); assert_eq!(gt_dyn_scalar(&array, f16::NAN).unwrap(), expected); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] let expected = BooleanArray::from( vec![Some(true), Some(false), Some(false), Some(false), Some(false)], ); assert_eq!(gt_eq_dyn_scalar(&array, f16::NAN).unwrap(), expected); - let array: Float32Array = vec![f32::NAN, 7.0, 8.0, 8.0, 10.0] - .into_iter() - .map(Some) - .collect(); + let array = Float32Array::from(vec![f32::NAN, 7.0, 8.0, 8.0, 10.0]); let expected = BooleanArray::from( vec![Some(false), Some(false), Some(false), Some(false), Some(false)], ); assert_eq!(gt_dyn_scalar(&array, f32::NAN).unwrap(), expected); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] let expected = BooleanArray::from( vec![Some(true), Some(false), Some(false), Some(false), Some(false)], ); assert_eq!(gt_eq_dyn_scalar(&array, f32::NAN).unwrap(), expected); - let array: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 10.0] - .into_iter() - .map(Some) - .collect(); + let array = Float64Array::from(vec![f64::NAN, 7.0, 8.0, 8.0, 10.0]); let expected = BooleanArray::from( vec![Some(false), Some(false), Some(false), Some(false), Some(false)], ); assert_eq!(gt_dyn_scalar(&array, f64::NAN).unwrap(), expected); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] let expected = BooleanArray::from( vec![Some(true), Some(false), Some(false), Some(false), Some(false)], ); @@ -5386,10 +3896,9 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_to_utf8_array() { - let test1 = vec!["a", "a", "b", "c"]; - let test2 = vec!["a", "b", "b", "d"]; + let test1 = ["a", "a", "b", "c"]; + let test2 = ["a", "b", "b", "d"]; let dict_array: DictionaryArray = test1 .iter() @@ -5427,10 +3936,9 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_lt_dyn_lt_eq_dyn_gt_dyn_gt_eq_dyn_dictionary_to_utf8_array() { - let test1 = vec!["abc", "abc", "b", "cde"]; - let test2 = vec!["abc", "b", "b", "def"]; + let test1 = ["abc", "abc", "b", "cde"]; + let test2 = ["abc", "b", "b", "def"]; let dict_array: DictionaryArray = test1 .iter() @@ -5492,7 +4000,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn 
test_eq_dyn_neq_dyn_dictionary_to_binary_array() { let values: BinaryArray = ["hello", "", "parquet"] .into_iter() @@ -5533,7 +4040,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_lt_dyn_lt_eq_dyn_gt_dyn_gt_eq_dyn_dictionary_to_binary_array() { let values: BinaryArray = ["hello", "", "parquet"] .into_iter() @@ -5598,7 +4104,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dict_non_dict_float_nan() { let array1: Float16Array = vec![f16::NAN, f16::from_f32(7.0), f16::from_f32(8.0), f16::from_f32(8.0), f16::from_f32(10.0)] .into_iter() @@ -5657,7 +4162,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_lt_dyn_lt_eq_dyn_dict_non_dict_float_nan() { let array1: Float16Array = vec![f16::NAN, f16::from_f32(7.0), f16::from_f32(8.0), f16::from_f32(8.0), f16::from_f32(11.0), f16::NAN] .into_iter() @@ -5715,7 +4219,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_gt_dyn_gt_eq_dyn_dict_non_dict_float_nan() { let array1: Float16Array = vec![f16::NAN, f16::from_f32(7.0), f16::from_f32(8.0), f16::from_f32(8.0), f16::from_f32(11.0), f16::NAN] .into_iter() @@ -5773,7 +4276,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_to_boolean_array() { let test1 = vec![Some(true), None, Some(false)]; let test2 = vec![Some(true), None, None, Some(true)]; @@ -5782,7 +4284,7 @@ mod tests { let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2]); let dict_array = DictionaryArray::new(keys, Arc::new(values)); - let array: BooleanArray = test2.iter().collect(); + let array = BooleanArray::from(test2); let result = eq_dyn(&dict_array, &array); assert_eq!( @@ -5810,7 +4312,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_lt_dyn_lt_eq_dyn_gt_dyn_gt_eq_dyn_dictionary_to_boolean_array() { let test1 = vec![Some(true), None, Some(false)]; let test2 = vec![Some(true), None, None, Some(true)]; @@ -5819,7 +4320,7 @@ mod tests { let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2]); let dict_array = DictionaryArray::new(keys, Arc::new(values)); - let array: BooleanArray = test2.iter().collect(); + let array = BooleanArray::from(test2); let result = lt_dyn(&dict_array, &array); assert_eq!( @@ -5871,7 +4372,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_cmp_dict_decimal128() { let values = Decimal128Array::from_iter_values([0, 1, 2, 3, 4, 5]); let keys = Int8Array::from_iter_values([1_i8, 2, 5, 4, 3, 0]); @@ -5908,7 +4408,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_cmp_dict_non_dict_decimal128() { let array1: Decimal128Array = Decimal128Array::from_iter_values([1, 2, 5, 4, 3, 0]); @@ -5944,7 +4443,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_cmp_dict_decimal256() { let values = Decimal256Array::from_iter_values( [0, 1, 2, 3, 4, 5].into_iter().map(i256::from_i128), @@ -5985,7 +4483,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_cmp_dict_non_dict_decimal256() { let array1: Decimal256Array = Decimal256Array::from_iter_values( [1, 2, 5, 4, 3, 0].into_iter().map(i256::from_i128), @@ -6291,7 +4788,6 @@ mod tests { } #[test] - #[cfg(not(feature = "simd"))] fn test_floating_zeros() { let a = Float32Array::from(vec![0.0_f32, -0.0]); let b = Float32Array::from(vec![-0.0_f32, 0.0]); @@ -6326,4 +4822,17 @@ mod tests { .to_string() .contains("Could not convert ToType with to_i128")); } + + #[test] + fn test_dictionary_nested_nulls() { + let keys = Int32Array::from(vec![0, 1, 2]); + let v1 
= Arc::new(Int32Array::from(vec![Some(0), None, Some(2)])); + let a = DictionaryArray::new(keys.clone(), v1); + let v2 = Arc::new(Int32Array::from(vec![None, Some(0), Some(2)])); + let b = DictionaryArray::new(keys, v2); + + let r = eq_dyn(&a, &b).unwrap(); + assert_eq!(r.null_count(), 2); + assert!(r.is_valid(2)); + } } diff --git a/arrow-ord/src/lib.rs b/arrow-ord/src/lib.rs index 62338c022384..8fe4ecbc05aa 100644 --- a/arrow-ord/src/lib.rs +++ b/arrow-ord/src/lib.rs @@ -43,7 +43,10 @@ //! ``` //! +pub mod cmp; +#[doc(hidden)] pub mod comparison; pub mod ord; pub mod partition; +pub mod rank; pub mod sort; diff --git a/arrow-ord/src/ord.rs b/arrow-ord/src/ord.rs index a33ead8ab041..4d6e3bde9152 100644 --- a/arrow-ord/src/ord.rs +++ b/arrow-ord/src/ord.rs @@ -21,114 +21,59 @@ use arrow_array::cast::AsArray; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::ArrowNativeType; -use arrow_schema::{ArrowError, DataType}; +use arrow_schema::ArrowError; use std::cmp::Ordering; /// Compare the values at two arbitrary indices in two arrays. pub type DynComparator = Box Ordering + Send + Sync>; -fn compare_primitives( +fn compare_primitive( left: &dyn Array, right: &dyn Array, ) -> DynComparator where T::Native: ArrowNativeTypeOp, { - let left: PrimitiveArray = PrimitiveArray::from(left.to_data()); - let right: PrimitiveArray = PrimitiveArray::from(right.to_data()); + let left = left.as_primitive::().clone(); + let right = right.as_primitive::().clone(); Box::new(move |i, j| left.value(i).compare(right.value(j))) } fn compare_boolean(left: &dyn Array, right: &dyn Array) -> DynComparator { - let left: BooleanArray = BooleanArray::from(left.to_data()); - let right: BooleanArray = BooleanArray::from(right.to_data()); + let left: BooleanArray = left.as_boolean().clone(); + let right: BooleanArray = right.as_boolean().clone(); Box::new(move |i, j| left.value(i).cmp(&right.value(j))) } -fn compare_string(left: &dyn Array, right: &dyn Array) -> DynComparator { - let left: StringArray = StringArray::from(left.to_data()); - let right: StringArray = StringArray::from(right.to_data()); +fn compare_bytes(left: &dyn Array, right: &dyn Array) -> DynComparator { + let left = left.as_bytes::().clone(); + let right = right.as_bytes::().clone(); - Box::new(move |i, j| left.value(i).cmp(right.value(j))) -} - -fn compare_dict_primitive(left: &dyn Array, right: &dyn Array) -> DynComparator -where - K: ArrowDictionaryKeyType, - V: ArrowPrimitiveType, - V::Native: ArrowNativeTypeOp, -{ - let left = left.as_dictionary::(); - let right = right.as_dictionary::(); - - let left_keys: PrimitiveArray = PrimitiveArray::from(left.keys().to_data()); - let right_keys: PrimitiveArray = PrimitiveArray::from(right.keys().to_data()); - let left_values: PrimitiveArray = left.values().to_data().into(); - let right_values: PrimitiveArray = right.values().to_data().into(); - - Box::new(move |i: usize, j: usize| { - let key_left = left_keys.value(i).as_usize(); - let key_right = right_keys.value(j).as_usize(); - let left = left_values.value(key_left); - let right = right_values.value(key_right); - left.compare(right) - }) -} - -fn compare_dict_string(left: &dyn Array, right: &dyn Array) -> DynComparator -where - T: ArrowDictionaryKeyType, -{ - let left = left.as_dictionary::(); - let right = right.as_dictionary::(); - - let left_keys: PrimitiveArray = PrimitiveArray::from(left.keys().to_data()); - let right_keys: PrimitiveArray = PrimitiveArray::from(right.keys().to_data()); - let left_values = 
StringArray::from(left.values().to_data()); - let right_values = StringArray::from(right.values().to_data()); - - Box::new(move |i: usize, j: usize| { - let key_left = left_keys.value(i).as_usize(); - let key_right = right_keys.value(j).as_usize(); - let left = left_values.value(key_left); - let right = right_values.value(key_right); - left.cmp(right) + Box::new(move |i, j| { + let l: &[u8] = left.value(i).as_ref(); + let r: &[u8] = right.value(j).as_ref(); + l.cmp(r) }) } -fn cmp_dict_primitive( - key_type: &DataType, +fn compare_dict( left: &dyn Array, right: &dyn Array, -) -> Result -where - VT: ArrowPrimitiveType, - VT::Native: ArrowNativeTypeOp, -{ - use DataType::*; - - Ok(match key_type { - UInt8 => compare_dict_primitive::(left, right), - UInt16 => compare_dict_primitive::(left, right), - UInt32 => compare_dict_primitive::(left, right), - UInt64 => compare_dict_primitive::(left, right), - Int8 => compare_dict_primitive::(left, right), - Int16 => compare_dict_primitive::(left, right), - Int32 => compare_dict_primitive::(left, right), - Int64 => compare_dict_primitive::(left, right), - t => { - return Err(ArrowError::InvalidArgumentError(format!( - "Dictionaries do not support keys of type {t:?}" - ))); - } - }) -} +) -> Result { + let left = left.as_dictionary::(); + let right = right.as_dictionary::(); + + let cmp = build_compare(left.values().as_ref(), right.values().as_ref())?; + let left_keys = left.keys().clone(); + let right_keys = right.keys().clone(); -macro_rules! cmp_dict_primitive_helper { - ($t:ty, $key_type_lhs:expr, $left:expr, $right:expr) => { - cmp_dict_primitive::<$t>($key_type_lhs, $left, $right)? - }; + // TODO: Handle value nulls (#2687) + Ok(Box::new(move |i, j| { + let l = left_keys.value(i).as_usize(); + let r = right_keys.value(j).as_usize(); + cmp(l, r) + })) } /// returns a comparison function that compares two values at two different positions @@ -145,7 +90,7 @@ macro_rules! cmp_dict_primitive_helper { /// let cmp = build_compare(&array1, &array2).unwrap(); /// /// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2) -/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1)); +/// assert_eq!(std::cmp::Ordering::Less, cmp(0, 1)); /// ``` // This is a factory of comparisons. // The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime. 
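`build_compare` keeps its public shape, returning a boxed `DynComparator` over two arrays, while its internals are collapsed into the `downcast_primitive!` / `downcast_integer!` dispatch shown in the following hunks. A brief usage sketch mirroring the doc example in this file:

```rust
use std::cmp::Ordering;
use arrow_array::Int32Array;
use arrow_ord::ord::build_compare;

fn main() {
    let a = Int32Array::from(vec![1, 5]);
    let b = Int32Array::from(vec![4, 2]);

    // The returned closure compares a[i] with b[j]
    let cmp = build_compare(&a, &b).unwrap();
    assert_eq!(cmp(0, 1), Ordering::Less);    // 1 < 2
    assert_eq!(cmp(1, 0), Ordering::Greater); // 5 > 4
}
```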
@@ -153,134 +98,47 @@ pub fn build_compare( left: &dyn Array, right: &dyn Array, ) -> Result { - use arrow_schema::{DataType::*, IntervalUnit::*, TimeUnit::*}; - Ok(match (left.data_type(), right.data_type()) { - (a, b) if a != b => { - return Err(ArrowError::InvalidArgumentError( - "Can't compare arrays of different types".to_string(), - )); - } - (Boolean, Boolean) => compare_boolean(left, right), - (UInt8, UInt8) => compare_primitives::(left, right), - (UInt16, UInt16) => compare_primitives::(left, right), - (UInt32, UInt32) => compare_primitives::(left, right), - (UInt64, UInt64) => compare_primitives::(left, right), - (Int8, Int8) => compare_primitives::(left, right), - (Int16, Int16) => compare_primitives::(left, right), - (Int32, Int32) => compare_primitives::(left, right), - (Int64, Int64) => compare_primitives::(left, right), - (Float16, Float16) => compare_primitives::(left, right), - (Float32, Float32) => compare_primitives::(left, right), - (Float64, Float64) => compare_primitives::(left, right), - (Decimal128(_, _), Decimal128(_, _)) => { - compare_primitives::(left, right) - } - (Decimal256(_, _), Decimal256(_, _)) => { - compare_primitives::(left, right) - } - (Date32, Date32) => compare_primitives::(left, right), - (Date64, Date64) => compare_primitives::(left, right), - (Time32(Second), Time32(Second)) => { - compare_primitives::(left, right) - } - (Time32(Millisecond), Time32(Millisecond)) => { - compare_primitives::(left, right) - } - (Time64(Microsecond), Time64(Microsecond)) => { - compare_primitives::(left, right) - } - (Time64(Nanosecond), Time64(Nanosecond)) => { - compare_primitives::(left, right) - } - (Timestamp(Second, _), Timestamp(Second, _)) => { - compare_primitives::(left, right) - } - (Timestamp(Millisecond, _), Timestamp(Millisecond, _)) => { - compare_primitives::(left, right) - } - (Timestamp(Microsecond, _), Timestamp(Microsecond, _)) => { - compare_primitives::(left, right) - } - (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => { - compare_primitives::(left, right) - } - (Interval(YearMonth), Interval(YearMonth)) => { - compare_primitives::(left, right) - } - (Interval(DayTime), Interval(DayTime)) => { - compare_primitives::(left, right) - } - (Interval(MonthDayNano), Interval(MonthDayNano)) => { - compare_primitives::(left, right) - } - (Duration(Second), Duration(Second)) => { - compare_primitives::(left, right) - } - (Duration(Millisecond), Duration(Millisecond)) => { - compare_primitives::(left, right) - } - (Duration(Microsecond), Duration(Microsecond)) => { - compare_primitives::(left, right) - } - (Duration(Nanosecond), Duration(Nanosecond)) => { - compare_primitives::(left, right) - } - (Utf8, Utf8) => compare_string(left, right), - (LargeUtf8, LargeUtf8) => compare_string(left, right), - ( - Dictionary(key_type_lhs, value_type_lhs), - Dictionary(key_type_rhs, value_type_rhs), - ) => { - if key_type_lhs != key_type_rhs || value_type_lhs != value_type_rhs { - return Err(ArrowError::InvalidArgumentError( - "Can't compare arrays of different types".to_string(), - )); - } - - let key_type_lhs = key_type_lhs.as_ref(); - downcast_primitive! 
{ - value_type_lhs.as_ref() => (cmp_dict_primitive_helper, key_type_lhs, left, right), - Utf8 => match key_type_lhs { - UInt8 => compare_dict_string::(left, right), - UInt16 => compare_dict_string::(left, right), - UInt32 => compare_dict_string::(left, right), - UInt64 => compare_dict_string::(left, right), - Int8 => compare_dict_string::(left, right), - Int16 => compare_dict_string::(left, right), - Int32 => compare_dict_string::(left, right), - Int64 => compare_dict_string::(left, right), - lhs => { - return Err(ArrowError::InvalidArgumentError(format!( - "Dictionaries do not support keys of type {lhs:?}" - ))); - } - }, - t => { - return Err(ArrowError::InvalidArgumentError(format!( - "Dictionaries of value data type {t:?} are not supported" - ))); - } - } - } + use arrow_schema::DataType::*; + macro_rules! primitive_helper { + ($t:ty, $left:expr, $right:expr) => { + Ok(compare_primitive::<$t>($left, $right)) + }; + } + downcast_primitive! { + left.data_type(), right.data_type() => (primitive_helper, left, right), + (Boolean, Boolean) => Ok(compare_boolean(left, right)), + (Utf8, Utf8) => Ok(compare_bytes::(left, right)), + (LargeUtf8, LargeUtf8) => Ok(compare_bytes::(left, right)), + (Binary, Binary) => Ok(compare_bytes::(left, right)), + (LargeBinary, LargeBinary) => Ok(compare_bytes::(left, right)), (FixedSizeBinary(_), FixedSizeBinary(_)) => { - let left: FixedSizeBinaryArray = left.to_data().into(); - let right: FixedSizeBinaryArray = right.to_data().into(); - - Box::new(move |i, j| left.value(i).cmp(right.value(j))) - } - (lhs, _) => { - return Err(ArrowError::InvalidArgumentError(format!( - "The data type type {lhs:?} has no natural order" - ))); - } - }) + let left = left.as_fixed_size_binary().clone(); + let right = right.as_fixed_size_binary().clone(); + Ok(Box::new(move |i, j| left.value(i).cmp(right.value(j)))) + }, + (Dictionary(l_key, _), Dictionary(r_key, _)) => { + macro_rules! dict_helper { + ($t:ty, $left:expr, $right:expr) => { + compare_dict::<$t>($left, $right) + }; + } + downcast_integer! 
{ + l_key.as_ref(), r_key.as_ref() => (dict_helper, left, right), + _ => unreachable!() + } + }, + (lhs, rhs) => Err(ArrowError::InvalidArgumentError(match lhs == rhs { + true => format!("The data type type {lhs:?} has no natural order"), + false => "Can't compare arrays of different types".to_string(), + })) + } } #[cfg(test)] pub mod tests { use super::*; use arrow_array::{FixedSizeBinaryArray, Float64Array, Int32Array}; - use arrow_buffer::i256; + use arrow_buffer::{i256, OffsetBuffer}; use half::f16; use std::cmp::Ordering; use std::sync::Arc; @@ -292,7 +150,7 @@ pub mod tests { let cmp = build_compare(&array, &array).unwrap(); - assert_eq!(Ordering::Less, (cmp)(0, 1)); + assert_eq!(Ordering::Less, cmp(0, 1)); } #[test] @@ -304,7 +162,7 @@ pub mod tests { let cmp = build_compare(&array1, &array2).unwrap(); - assert_eq!(Ordering::Less, (cmp)(0, 0)); + assert_eq!(Ordering::Less, cmp(0, 0)); } #[test] @@ -323,7 +181,7 @@ pub mod tests { let cmp = build_compare(&array1, &array2).unwrap(); - assert_eq!(Ordering::Less, (cmp)(0, 0)); + assert_eq!(Ordering::Less, cmp(0, 0)); } #[test] @@ -332,7 +190,7 @@ pub mod tests { let cmp = build_compare(&array, &array).unwrap(); - assert_eq!(Ordering::Less, (cmp)(0, 1)); + assert_eq!(Ordering::Less, cmp(0, 1)); } #[test] @@ -341,7 +199,7 @@ pub mod tests { let cmp = build_compare(&array, &array).unwrap(); - assert_eq!(Ordering::Less, (cmp)(0, 1)); + assert_eq!(Ordering::Less, cmp(0, 1)); } #[test] @@ -350,8 +208,8 @@ pub mod tests { let cmp = build_compare(&array, &array).unwrap(); - assert_eq!(Ordering::Less, (cmp)(0, 1)); - assert_eq!(Ordering::Equal, (cmp)(1, 1)); + assert_eq!(Ordering::Less, cmp(0, 1)); + assert_eq!(Ordering::Equal, cmp(1, 1)); } #[test] @@ -360,8 +218,8 @@ pub mod tests { let cmp = build_compare(&array, &array).unwrap(); - assert_eq!(Ordering::Less, (cmp)(0, 1)); - assert_eq!(Ordering::Greater, (cmp)(1, 0)); + assert_eq!(Ordering::Less, cmp(0, 1)); + assert_eq!(Ordering::Greater, cmp(1, 0)); } #[test] @@ -373,8 +231,8 @@ pub mod tests { .unwrap(); let cmp = build_compare(&array, &array).unwrap(); - assert_eq!(Ordering::Less, (cmp)(1, 0)); - assert_eq!(Ordering::Greater, (cmp)(0, 2)); + assert_eq!(Ordering::Less, cmp(1, 0)); + assert_eq!(Ordering::Greater, cmp(0, 2)); } #[test] @@ -390,8 +248,8 @@ pub mod tests { .unwrap(); let cmp = build_compare(&array, &array).unwrap(); - assert_eq!(Ordering::Less, (cmp)(1, 0)); - assert_eq!(Ordering::Greater, (cmp)(0, 2)); + assert_eq!(Ordering::Less, cmp(1, 0)); + assert_eq!(Ordering::Greater, cmp(0, 2)); } #[test] @@ -401,9 +259,9 @@ pub mod tests { let cmp = build_compare(&array, &array).unwrap(); - assert_eq!(Ordering::Less, (cmp)(0, 1)); - assert_eq!(Ordering::Equal, (cmp)(3, 4)); - assert_eq!(Ordering::Greater, (cmp)(2, 3)); + assert_eq!(Ordering::Less, cmp(0, 1)); + assert_eq!(Ordering::Equal, cmp(3, 4)); + assert_eq!(Ordering::Greater, cmp(2, 3)); } #[test] @@ -415,9 +273,9 @@ pub mod tests { let cmp = build_compare(&a1, &a2).unwrap(); - assert_eq!(Ordering::Less, (cmp)(0, 0)); - assert_eq!(Ordering::Equal, (cmp)(0, 3)); - assert_eq!(Ordering::Greater, (cmp)(1, 3)); + assert_eq!(Ordering::Less, cmp(0, 0)); + assert_eq!(Ordering::Equal, cmp(0, 3)); + assert_eq!(Ordering::Greater, cmp(1, 3)); } #[test] @@ -432,11 +290,11 @@ pub mod tests { let cmp = build_compare(&array1, &array2).unwrap(); - assert_eq!(Ordering::Less, (cmp)(0, 0)); - assert_eq!(Ordering::Less, (cmp)(0, 3)); - assert_eq!(Ordering::Equal, (cmp)(3, 3)); - assert_eq!(Ordering::Greater, (cmp)(3, 1)); - 
assert_eq!(Ordering::Greater, (cmp)(3, 2)); + assert_eq!(Ordering::Less, cmp(0, 0)); + assert_eq!(Ordering::Less, cmp(0, 3)); + assert_eq!(Ordering::Equal, cmp(3, 3)); + assert_eq!(Ordering::Greater, cmp(3, 1)); + assert_eq!(Ordering::Greater, cmp(3, 2)); } #[test] @@ -451,11 +309,11 @@ pub mod tests { let cmp = build_compare(&array1, &array2).unwrap(); - assert_eq!(Ordering::Less, (cmp)(0, 0)); - assert_eq!(Ordering::Less, (cmp)(0, 3)); - assert_eq!(Ordering::Equal, (cmp)(3, 3)); - assert_eq!(Ordering::Greater, (cmp)(3, 1)); - assert_eq!(Ordering::Greater, (cmp)(3, 2)); + assert_eq!(Ordering::Less, cmp(0, 0)); + assert_eq!(Ordering::Less, cmp(0, 3)); + assert_eq!(Ordering::Equal, cmp(3, 3)); + assert_eq!(Ordering::Greater, cmp(3, 1)); + assert_eq!(Ordering::Greater, cmp(3, 2)); } #[test] @@ -470,11 +328,11 @@ pub mod tests { let cmp = build_compare(&array1, &array2).unwrap(); - assert_eq!(Ordering::Less, (cmp)(0, 0)); - assert_eq!(Ordering::Less, (cmp)(0, 3)); - assert_eq!(Ordering::Equal, (cmp)(3, 3)); - assert_eq!(Ordering::Greater, (cmp)(3, 1)); - assert_eq!(Ordering::Greater, (cmp)(3, 2)); + assert_eq!(Ordering::Less, cmp(0, 0)); + assert_eq!(Ordering::Less, cmp(0, 3)); + assert_eq!(Ordering::Equal, cmp(3, 3)); + assert_eq!(Ordering::Greater, cmp(3, 1)); + assert_eq!(Ordering::Greater, cmp(3, 2)); } #[test] @@ -489,11 +347,11 @@ pub mod tests { let cmp = build_compare(&array1, &array2).unwrap(); - assert_eq!(Ordering::Less, (cmp)(0, 0)); - assert_eq!(Ordering::Less, (cmp)(0, 3)); - assert_eq!(Ordering::Equal, (cmp)(3, 3)); - assert_eq!(Ordering::Greater, (cmp)(3, 1)); - assert_eq!(Ordering::Greater, (cmp)(3, 2)); + assert_eq!(Ordering::Less, cmp(0, 0)); + assert_eq!(Ordering::Less, cmp(0, 3)); + assert_eq!(Ordering::Equal, cmp(3, 3)); + assert_eq!(Ordering::Greater, cmp(3, 1)); + assert_eq!(Ordering::Greater, cmp(3, 2)); } #[test] @@ -508,11 +366,11 @@ pub mod tests { let cmp = build_compare(&array1, &array2).unwrap(); - assert_eq!(Ordering::Less, (cmp)(0, 0)); - assert_eq!(Ordering::Less, (cmp)(0, 3)); - assert_eq!(Ordering::Equal, (cmp)(3, 3)); - assert_eq!(Ordering::Greater, (cmp)(3, 1)); - assert_eq!(Ordering::Greater, (cmp)(3, 2)); + assert_eq!(Ordering::Less, cmp(0, 0)); + assert_eq!(Ordering::Less, cmp(0, 3)); + assert_eq!(Ordering::Equal, cmp(3, 3)); + assert_eq!(Ordering::Greater, cmp(3, 1)); + assert_eq!(Ordering::Greater, cmp(3, 2)); } #[test] @@ -527,11 +385,11 @@ pub mod tests { let cmp = build_compare(&array1, &array2).unwrap(); - assert_eq!(Ordering::Less, (cmp)(0, 0)); - assert_eq!(Ordering::Less, (cmp)(0, 3)); - assert_eq!(Ordering::Equal, (cmp)(3, 3)); - assert_eq!(Ordering::Greater, (cmp)(3, 1)); - assert_eq!(Ordering::Greater, (cmp)(3, 2)); + assert_eq!(Ordering::Less, cmp(0, 0)); + assert_eq!(Ordering::Less, cmp(0, 3)); + assert_eq!(Ordering::Equal, cmp(3, 3)); + assert_eq!(Ordering::Greater, cmp(3, 1)); + assert_eq!(Ordering::Greater, cmp(3, 2)); } #[test] @@ -556,10 +414,28 @@ pub mod tests { let cmp = build_compare(&array1, &array2).unwrap(); - assert_eq!(Ordering::Less, (cmp)(0, 0)); - assert_eq!(Ordering::Less, (cmp)(0, 3)); - assert_eq!(Ordering::Equal, (cmp)(3, 3)); - assert_eq!(Ordering::Greater, (cmp)(3, 1)); - assert_eq!(Ordering::Greater, (cmp)(3, 2)); + assert_eq!(Ordering::Less, cmp(0, 0)); + assert_eq!(Ordering::Less, cmp(0, 3)); + assert_eq!(Ordering::Equal, cmp(3, 3)); + assert_eq!(Ordering::Greater, cmp(3, 1)); + assert_eq!(Ordering::Greater, cmp(3, 2)); + } + + fn test_bytes_impl() { + let offsets = OffsetBuffer::from_lengths([3, 3, 1]); + let a 
= GenericByteArray::::new(offsets, b"abcdefa".into(), None); + let cmp = build_compare(&a, &a).unwrap(); + + assert_eq!(Ordering::Less, cmp(0, 1)); + assert_eq!(Ordering::Greater, cmp(0, 2)); + assert_eq!(Ordering::Equal, cmp(1, 1)); + } + + #[test] + fn test_bytes() { + test_bytes_impl::(); + test_bytes_impl::(); + test_bytes_impl::(); + test_bytes_impl::(); } } diff --git a/arrow-ord/src/partition.rs b/arrow-ord/src/partition.rs index 26a030beb35e..80b25ee2afba 100644 --- a/arrow-ord/src/partition.rs +++ b/arrow-ord/src/partition.rs @@ -17,362 +17,302 @@ //! Defines partition kernel for `ArrayRef` -use crate::sort::{LexicographicalComparator, SortColumn}; -use arrow_schema::ArrowError; -use std::cmp::Ordering; use std::ops::Range; -/// Given a list of already sorted columns, find partition ranges that would partition -/// lexicographically equal values across columns. -/// -/// Here LexicographicalComparator is used in conjunction with binary -/// search so the columns *MUST* be pre-sorted already. -/// -/// The returned vec would be of size k where k is cardinality of the sorted values; Consecutive -/// values will be connected: (a, b) and (b, c), where start = 0 and end = n for the first and last -/// range. -pub fn lexicographical_partition_ranges( - columns: &[SortColumn], -) -> Result> + '_, ArrowError> { - LexicographicalPartitionIterator::try_new(columns) -} +use arrow_array::{Array, ArrayRef}; +use arrow_buffer::BooleanBuffer; +use arrow_schema::ArrowError; -struct LexicographicalPartitionIterator<'a> { - comparator: LexicographicalComparator<'a>, - num_rows: usize, - previous_partition_point: usize, - partition_point: usize, -} +use crate::cmp::distinct; +use crate::sort::SortColumn; -impl<'a> LexicographicalPartitionIterator<'a> { - fn try_new( - columns: &'a [SortColumn], - ) -> Result { - if columns.is_empty() { - return Err(ArrowError::InvalidArgumentError( - "Sort requires at least one column".to_string(), - )); - } - let num_rows = columns[0].values.len(); - if columns.iter().any(|item| item.values.len() != num_rows) { - return Err(ArrowError::ComputeError( - "Lexical sort columns have different row counts".to_string(), - )); +/// A computed set of partitions, see [`partition`] +#[derive(Debug, Clone)] +pub struct Partitions(Option); + +impl Partitions { + /// Returns the range of each partition + /// + /// Consecutive ranges will be contiguous: i.e [`(a, b)` and `(b, c)`], and + /// `start = 0` and `end = self.len()` for the first and last range respectively + pub fn ranges(&self) -> Vec> { + let boundaries = match &self.0 { + Some(boundaries) => boundaries, + None => return vec![], }; - let comparator = LexicographicalComparator::try_new(columns)?; - Ok(LexicographicalPartitionIterator { - comparator, - num_rows, - previous_partition_point: 0, - partition_point: 0, - }) + let mut out = vec![]; + let mut current = 0; + for idx in boundaries.set_indices() { + let t = current; + current = idx + 1; + out.push(t..current) + } + let last = boundaries.len() + 1; + if current != last { + out.push(current..last) + } + out } -} -/// Returns the next partition point of the range `start..end` according to the given comparator. -/// The return value is the index of the first element of the second partition, -/// and is guaranteed to be between `start..=end` (inclusive). -/// -/// The values corresponding to those indices are assumed to be partitioned according to the given comparator. 
-/// -/// Exponential search is to remedy for the case when array size and cardinality are both large. -/// In these cases the partition point would be near the beginning of the range and -/// plain binary search would be doing some unnecessary iterations on each call. -/// -/// see -#[inline] -fn exponential_search_next_partition_point( - start: usize, - end: usize, - comparator: &LexicographicalComparator<'_>, -) -> usize { - let target = start; - let mut bound = 1; - while bound + start < end - && comparator.compare(bound + start, target) != Ordering::Greater - { - bound *= 2; + /// Returns the number of partitions + pub fn len(&self) -> usize { + match &self.0 { + Some(b) => b.count_set_bits() + 1, + None => 0, + } } - // invariant after while loop: - // (start + bound / 2) <= target < min(end, start + bound + 1) - // where <= and < are defined by the comparator; - // note here we have right = min(end, start + bound + 1) because (start + bound) might - // actually be considered and must be included. - partition_point(start + bound / 2, end.min(start + bound + 1), |idx| { - comparator.compare(idx, target) != Ordering::Greater - }) + /// Returns true if this contains no partitions + pub fn is_empty(&self) -> bool { + self.0.is_none() + } } -/// Returns the partition point of the range `start..end` according to the given predicate. -/// The return value is the index of the first element of the second partition, -/// and is guaranteed to be between `start..=end` (inclusive). +/// Given a list of lexicographically sorted columns, computes the [`Partitions`], +/// where a partition consists of the set of consecutive rows with equal values +/// +/// Returns an error if no columns are specified or all columns do not +/// have the same number of rows. +/// +/// # Example: +/// +/// For example, given columns `x`, `y` and `z`, calling +/// `lexicographical_partition_ranges(values, (x, y))` will divide the +/// rows into ranges where the values of `(x, y)` are equal: /// -/// The algorithm is similar to a binary search. +/// ```text +/// ┌ ─ ┬───┬ ─ ─┌───┐─ ─ ┬───┬ ─ ─ ┐ +/// │ 1 │ │ 1 │ │ A │ Range: 0..1 (x=1, y=1) +/// ├ ─ ┼───┼ ─ ─├───┤─ ─ ┼───┼ ─ ─ ┤ +/// │ 1 │ │ 2 │ │ B │ +/// │ ├───┤ ├───┤ ├───┤ │ +/// │ 1 │ │ 2 │ │ C │ Range: 1..4 (x=1, y=2) +/// │ ├───┤ ├───┤ ├───┤ │ +/// │ 1 │ │ 2 │ │ D │ +/// ├ ─ ┼───┼ ─ ─├───┤─ ─ ┼───┼ ─ ─ ┤ +/// │ 2 │ │ 1 │ │ E │ Range: 4..5 (x=2, y=1) +/// ├ ─ ┼───┼ ─ ─├───┤─ ─ ┼───┼ ─ ─ ┤ +/// │ 3 │ │ 1 │ │ F │ Range: 5..6 (x=3, y=1) +/// └ ─ ┴───┴ ─ ─└───┘─ ─ ┴───┴ ─ ─ ┘ /// -/// The values corresponding to those indices are assumed to be partitioned according to the given predicate. 
+/// x y z partition(&[x, y]) +/// ``` /// -/// See [`slice::partition_point`] -#[inline] -fn partition_point bool>(start: usize, end: usize, pred: P) -> usize { - let mut left = start; - let mut right = end; - let mut size = right - left; - while left < right { - let mid = left + size / 2; +/// # Example Code +/// +/// ``` +/// # use std::{sync::Arc, ops::Range}; +/// # use arrow_array::{RecordBatch, Int64Array, StringArray, ArrayRef}; +/// # use arrow_ord::sort::{SortColumn, SortOptions}; +/// # use arrow_ord::partition::partition; +/// let batch = RecordBatch::try_from_iter(vec![ +/// ("x", Arc::new(Int64Array::from(vec![1, 1, 1, 1, 2, 3])) as ArrayRef), +/// ("y", Arc::new(Int64Array::from(vec![1, 2, 2, 2, 1, 1])) as ArrayRef), +/// ("z", Arc::new(StringArray::from(vec!["A", "B", "C", "D", "E", "F"])) as ArrayRef), +/// ]).unwrap(); +/// +/// // Partition on first two columns +/// let ranges = partition(&batch.columns()[..2]).unwrap().ranges(); +/// +/// let expected = vec![ +/// (0..1), +/// (1..4), +/// (4..5), +/// (5..6), +/// ]; +/// +/// assert_eq!(ranges, expected); +/// ``` +pub fn partition(columns: &[ArrayRef]) -> Result { + if columns.is_empty() { + return Err(ArrowError::InvalidArgumentError( + "Partition requires at least one column".to_string(), + )); + } + let num_rows = columns[0].len(); + if columns.iter().any(|item| item.len() != num_rows) { + return Err(ArrowError::InvalidArgumentError( + "Partition columns have different row counts".to_string(), + )); + }; - let less = pred(mid); + match num_rows { + 0 => return Ok(Partitions(None)), + 1 => return Ok(Partitions(Some(BooleanBuffer::new_unset(0)))), + _ => {} + } - if less { - left = mid + 1; - } else { - right = mid; - } + let acc = find_boundaries(&columns[0])?; + let acc = columns + .iter() + .skip(1) + .try_fold(acc, |acc, c| find_boundaries(c.as_ref()).map(|b| &acc | &b))?; - size = right - left; - } - left + Ok(Partitions(Some(acc))) } -impl<'a> Iterator for LexicographicalPartitionIterator<'a> { - type Item = Range; +/// Returns a mask with bits set whenever the value or nullability changes +fn find_boundaries(v: &dyn Array) -> Result { + let slice_len = v.len() - 1; + let v1 = v.slice(0, slice_len); + let v2 = v.slice(1, slice_len); + Ok(distinct(&v1, &v2)?.values().clone()) +} - fn next(&mut self) -> Option { - if self.partition_point < self.num_rows { - // invariant: - // in the range [0..previous_partition_point] all values are <= the value at [previous_partition_point] - // so in order to save time we can do binary search on the range [previous_partition_point..num_rows] - // and find the index where any value is greater than the value at [previous_partition_point] - self.partition_point = exponential_search_next_partition_point( - self.partition_point, - self.num_rows, - &self.comparator, - ); - let start = self.previous_partition_point; - let end = self.partition_point; - self.previous_partition_point = self.partition_point; - Some(Range { start, end }) - } else { - None - } - } +/// Given a list of already sorted columns, find partition ranges that would partition +/// lexicographically equal values across columns. +/// +/// The returned vec would be of size k where k is cardinality of the sorted values; Consecutive +/// values will be connected: (a, b) and (b, c), where start = 0 and end = n for the first and last +/// range. 
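Migrating from the deprecated `lexicographical_partition_ranges` (declared just below) to `partition` mostly means passing plain `ArrayRef`s instead of `SortColumn`s and collecting ranges from the returned `Partitions`. A sketch of the before/after, using illustrative data and the API as added in this change:

```rust
use std::sync::Arc;
use arrow_array::{ArrayRef, Int64Array};
use arrow_ord::partition::partition;

fn main() {
    // Two lexicographically sorted columns
    let x: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 2, 2, 2]));
    let y: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 2, 2, 3]));

    // Old: lexicographical_partition_ranges(&[SortColumn { values: x, .. }, ..])
    // New: rows with equal (x, y) values form one partition each
    let ranges = partition(&[x, y]).unwrap().ranges();
    assert_eq!(ranges, vec![0..1, 1..2, 2..4, 4..5]);
}
```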
+#[deprecated(note = "Use partition")] +pub fn lexicographical_partition_ranges( + columns: &[SortColumn], +) -> Result> + '_, ArrowError> { + let cols: Vec<_> = columns.iter().map(|x| x.values.clone()).collect(); + Ok(partition(&cols)?.ranges().into_iter()) } #[cfg(test)] mod tests { - use super::*; - use crate::sort::SortOptions; + use std::sync::Arc; + use arrow_array::*; use arrow_schema::DataType; - use std::sync::Arc; - #[test] - fn test_partition_point() { - let input = &[1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 4]; - { - let median = input[input.len() / 2]; - assert_eq!( - 9, - partition_point(0, input.len(), |i: usize| input[i].cmp(&median) - != Ordering::Greater) - ); - } - { - let search = input[9]; - assert_eq!( - 12, - partition_point(9, input.len(), |i: usize| input[i].cmp(&search) - != Ordering::Greater) - ); - } - { - let search = input[0]; - assert_eq!( - 3, - partition_point(0, 9, |i: usize| input[i].cmp(&search) - != Ordering::Greater) - ); - } - let input = &[1, 2, 2, 2, 2, 2, 2, 2, 9]; - { - let search = input[5]; - assert_eq!( - 8, - partition_point(5, 9, |i: usize| input[i].cmp(&search) - != Ordering::Greater) - ); - } - } + use super::*; #[test] - fn test_lexicographical_partition_ranges_empty() { - let input = vec![]; - assert!( - lexicographical_partition_ranges(&input).is_err(), - "lexicographical_partition_ranges should reject columns with empty rows" + fn test_partition_empty() { + let err = partition(&[]).unwrap_err(); + assert_eq!( + err.to_string(), + "Invalid argument error: Partition requires at least one column" ); } #[test] - fn test_lexicographical_partition_ranges_unaligned_rows() { + fn test_partition_unaligned_rows() { let input = vec![ - SortColumn { - values: Arc::new(Int64Array::from(vec![None, Some(-1)])) as ArrayRef, - options: None, - }, - SortColumn { - values: Arc::new(StringArray::from(vec![Some("foo")])) as ArrayRef, - options: None, - }, + Arc::new(Int64Array::from(vec![None, Some(-1)])) as _, + Arc::new(StringArray::from(vec![Some("foo")])) as _, ]; - assert!( - lexicographical_partition_ranges(&input).is_err(), - "lexicographical_partition_ranges should reject columns with different row counts" - ); + let err = partition(&input).unwrap_err(); + assert_eq!( + err.to_string(), + "Invalid argument error: Partition columns have different row counts" + ) } #[test] - fn test_lexicographical_partition_single_column() { - let input = vec![SortColumn { - values: Arc::new(Int64Array::from(vec![1, 2, 2, 2, 2, 2, 2, 2, 9])) - as ArrayRef, - options: Some(SortOptions { - descending: false, - nulls_first: true, - }), - }]; - let results = lexicographical_partition_ranges(&input).unwrap(); + fn test_partition_small() { + let results = partition(&[ + Arc::new(Int32Array::new(vec![].into(), None)) as _, + Arc::new(Int32Array::new(vec![].into(), None)) as _, + Arc::new(Int32Array::new(vec![].into(), None)) as _, + ]) + .unwrap(); + assert_eq!(results.len(), 0); + assert!(results.is_empty()); + + let results = partition(&[ + Arc::new(Int32Array::from(vec![1])) as _, + Arc::new(Int32Array::from(vec![1])) as _, + ]) + .unwrap() + .ranges(); + assert_eq!(results.len(), 1); + assert_eq!(results[0], 0..1); + } + + #[test] + fn test_partition_single_column() { + let a = Int64Array::from(vec![1, 2, 2, 2, 2, 2, 2, 2, 9]); + let input = vec![Arc::new(a) as _]; assert_eq!( - vec![(0_usize..1_usize), (1_usize..8_usize), (8_usize..9_usize)], - results.collect::>() + partition(&input).unwrap().ranges(), + vec![(0..1), (1..8), (8..9)], ); } #[test] - fn 
test_lexicographical_partition_all_equal_values() { - let input = vec![SortColumn { - values: Arc::new(Int64Array::from_value(1, 1000)) as ArrayRef, - options: Some(SortOptions { - descending: false, - nulls_first: true, - }), - }]; + fn test_partition_all_equal_values() { + let a = Int64Array::from_value(1, 1000); + let input = vec![Arc::new(a) as _]; + assert_eq!(partition(&input).unwrap().ranges(), vec![(0..1000)]); + } - let results = lexicographical_partition_ranges(&input).unwrap(); - assert_eq!(vec![(0_usize..1000_usize)], results.collect::>()); + #[test] + fn test_partition_all_null_values() { + let input = vec![ + new_null_array(&DataType::Int8, 1000), + new_null_array(&DataType::UInt16, 1000), + ]; + assert_eq!(partition(&input).unwrap().ranges(), vec![(0..1000)]); } #[test] - fn test_lexicographical_partition_all_null_values() { + fn test_partition_unique_column_1() { let input = vec![ - SortColumn { - values: new_null_array(&DataType::Int8, 1000), - options: Some(SortOptions { - descending: false, - nulls_first: true, - }), - }, - SortColumn { - values: new_null_array(&DataType::UInt16, 1000), - options: Some(SortOptions { - descending: false, - nulls_first: false, - }), - }, + Arc::new(Int64Array::from(vec![None, Some(-1)])) as _, + Arc::new(StringArray::from(vec![Some("foo"), Some("bar")])) as _, ]; - let results = lexicographical_partition_ranges(&input).unwrap(); - assert_eq!(vec![(0_usize..1000_usize)], results.collect::>()); + assert_eq!(partition(&input).unwrap().ranges(), vec![(0..1), (1..2)],); } #[test] - fn test_lexicographical_partition_unique_column_1() { + fn test_partition_unique_column_2() { let input = vec![ - SortColumn { - values: Arc::new(Int64Array::from(vec![None, Some(-1)])) as ArrayRef, - options: Some(SortOptions { - descending: false, - nulls_first: true, - }), - }, - SortColumn { - values: Arc::new(StringArray::from(vec![Some("foo"), Some("bar")])) - as ArrayRef, - options: Some(SortOptions { - descending: true, - nulls_first: true, - }), - }, + Arc::new(Int64Array::from(vec![None, Some(-1), Some(-1)])) as _, + Arc::new(StringArray::from(vec![ + Some("foo"), + Some("bar"), + Some("apple"), + ])) as _, ]; - let results = lexicographical_partition_ranges(&input).unwrap(); assert_eq!( - vec![(0_usize..1_usize), (1_usize..2_usize)], - results.collect::>() + partition(&input).unwrap().ranges(), + vec![(0..1), (1..2), (2..3),], ); } #[test] - fn test_lexicographical_partition_unique_column_2() { + fn test_partition_non_unique_column_1() { let input = vec![ - SortColumn { - values: Arc::new(Int64Array::from(vec![None, Some(-1), Some(-1)])) - as ArrayRef, - options: Some(SortOptions { - descending: false, - nulls_first: true, - }), - }, - SortColumn { - values: Arc::new(StringArray::from(vec![ - Some("foo"), - Some("bar"), - Some("apple"), - ])) as ArrayRef, - options: Some(SortOptions { - descending: true, - nulls_first: true, - }), - }, + Arc::new(Int64Array::from(vec![None, Some(-1), Some(-1), Some(1)])) as _, + Arc::new(StringArray::from(vec![ + Some("foo"), + Some("bar"), + Some("bar"), + Some("bar"), + ])) as _, ]; - let results = lexicographical_partition_ranges(&input).unwrap(); assert_eq!( - vec![(0_usize..1_usize), (1_usize..2_usize), (2_usize..3_usize),], - results.collect::>() + partition(&input).unwrap().ranges(), + vec![(0..1), (1..3), (3..4),], ); } #[test] - fn test_lexicographical_partition_non_unique_column_1() { + fn test_partition_masked_nulls() { let input = vec![ - SortColumn { - values: Arc::new(Int64Array::from(vec![ - None, - 
Some(-1), - Some(-1), - Some(1), - ])) as ArrayRef, - options: Some(SortOptions { - descending: false, - nulls_first: true, - }), - }, - SortColumn { - values: Arc::new(StringArray::from(vec![ - Some("foo"), - Some("bar"), - Some("bar"), - Some("bar"), - ])) as ArrayRef, - options: Some(SortOptions { - descending: true, - nulls_first: true, - }), - }, + Arc::new(Int64Array::new(vec![1; 9].into(), None)) as _, + Arc::new(Int64Array::new( + vec![1, 1, 2, 2, 2, 3, 3, 3, 3].into(), + Some( + vec![false, true, true, true, true, false, false, true, false].into(), + ), + )) as _, + Arc::new(Int64Array::new( + vec![1, 1, 2, 2, 2, 2, 2, 3, 7].into(), + Some(vec![true, true, true, true, false, true, true, true, false].into()), + )) as _, ]; - let results = lexicographical_partition_ranges(&input).unwrap(); + assert_eq!( - vec![(0_usize..1_usize), (1_usize..3_usize), (3_usize..4_usize),], - results.collect::>() + partition(&input).unwrap().ranges(), + vec![(0..1), (1..2), (2..4), (4..5), (5..7), (7..8), (8..9)], ); } } diff --git a/arrow-ord/src/rank.rs b/arrow-ord/src/rank.rs new file mode 100644 index 000000000000..1e79156a71a3 --- /dev/null +++ b/arrow-ord/src/rank.rs @@ -0,0 +1,195 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::cast::AsArray; +use arrow_array::types::*; +use arrow_array::{downcast_primitive_array, Array, ArrowNativeTypeOp, GenericByteArray}; +use arrow_buffer::NullBuffer; +use arrow_schema::{ArrowError, DataType, SortOptions}; +use std::cmp::Ordering; + +/// Assigns a rank to each value in `array` based on its position in the sorted order +/// +/// Where values are equal, they will be assigned the highest of their ranks, +/// leaving gaps in the overall rank assignment +/// +/// ``` +/// # use arrow_array::StringArray; +/// # use arrow_ord::rank::rank; +/// let array = StringArray::from(vec![Some("foo"), None, Some("foo"), None, Some("bar")]); +/// let ranks = rank(&array, None).unwrap(); +/// assert_eq!(ranks, &[5, 2, 5, 2, 3]); +/// ``` +pub fn rank( + array: &dyn Array, + options: Option, +) -> Result, ArrowError> { + let options = options.unwrap_or_default(); + let ranks = downcast_primitive_array! 
{ + array => primitive_rank(array.values(), array.nulls(), options), + DataType::Utf8 => bytes_rank(array.as_bytes::(), options), + DataType::LargeUtf8 => bytes_rank(array.as_bytes::(), options), + DataType::Binary => bytes_rank(array.as_bytes::(), options), + DataType::LargeBinary => bytes_rank(array.as_bytes::(), options), + d => return Err(ArrowError::ComputeError(format!("{d:?} not supported in rank"))) + }; + Ok(ranks) +} + +#[inline(never)] +fn primitive_rank( + values: &[T], + nulls: Option<&NullBuffer>, + options: SortOptions, +) -> Vec { + let len: u32 = values.len().try_into().unwrap(); + let to_sort = match nulls.filter(|n| n.null_count() > 0) { + Some(n) => n + .valid_indices() + .map(|idx| (values[idx], idx as u32)) + .collect(), + None => values.iter().copied().zip(0..len).collect(), + }; + rank_impl(values.len(), to_sort, options, T::compare, T::is_eq) +} + +#[inline(never)] +fn bytes_rank( + array: &GenericByteArray, + options: SortOptions, +) -> Vec { + let to_sort: Vec<(&[u8], u32)> = match array.nulls().filter(|n| n.null_count() > 0) { + Some(n) => n + .valid_indices() + .map(|idx| (array.value(idx).as_ref(), idx as u32)) + .collect(), + None => (0..array.len()) + .map(|idx| (array.value(idx).as_ref(), idx as u32)) + .collect(), + }; + rank_impl(array.len(), to_sort, options, Ord::cmp, PartialEq::eq) +} + +fn rank_impl( + len: usize, + mut valid: Vec<(T, u32)>, + options: SortOptions, + compare: C, + eq: E, +) -> Vec +where + T: Copy, + C: Fn(T, T) -> Ordering, + E: Fn(T, T) -> bool, +{ + // We can use an unstable sort as we combine equal values later + valid.sort_unstable_by(|a, b| compare(a.0, b.0)); + if options.descending { + valid.reverse(); + } + + let (mut valid_rank, null_rank) = match options.nulls_first { + true => (len as u32, (len - valid.len()) as u32), + false => (valid.len() as u32, len as u32), + }; + + let mut out: Vec<_> = vec![null_rank; len]; + if let Some(v) = valid.last() { + out[v.1 as usize] = valid_rank; + } + + let mut count = 1; // Number of values in rank + for w in valid.windows(2).rev() { + match eq(w[0].0, w[1].0) { + true => { + count += 1; + out[w[0].1 as usize] = valid_rank; + } + false => { + valid_rank -= count; + count = 1; + out[w[0].1 as usize] = valid_rank + } + } + } + + out +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::*; + + #[test] + fn test_primitive() { + let descending = SortOptions { + descending: true, + nulls_first: true, + }; + + let nulls_last = SortOptions { + descending: false, + nulls_first: false, + }; + + let nulls_last_descending = SortOptions { + descending: true, + nulls_first: false, + }; + + let a = Int32Array::from(vec![Some(1), Some(1), None, Some(3), Some(3), Some(4)]); + let res = rank(&a, None).unwrap(); + assert_eq!(res, &[3, 3, 1, 5, 5, 6]); + + let res = rank(&a, Some(descending)).unwrap(); + assert_eq!(res, &[6, 6, 1, 4, 4, 2]); + + let res = rank(&a, Some(nulls_last)).unwrap(); + assert_eq!(res, &[2, 2, 6, 4, 4, 5]); + + let res = rank(&a, Some(nulls_last_descending)).unwrap(); + assert_eq!(res, &[5, 5, 6, 3, 3, 1]); + + // Test with non-zero null values + let nulls = NullBuffer::from(vec![true, true, false, true, false, false]); + let a = Int32Array::new(vec![1, 4, 3, 4, 5, 5].into(), Some(nulls)); + let res = rank(&a, None).unwrap(); + assert_eq!(res, &[4, 6, 3, 6, 3, 3]); + } + + #[test] + fn test_bytes() { + let v = vec!["foo", "fo", "bar", "bar"]; + let values = StringArray::from(v.clone()); + let res = rank(&values, None).unwrap(); + assert_eq!(res, &[4, 3, 2, 2]); + + let 
values = LargeStringArray::from(v.clone()); + let res = rank(&values, None).unwrap(); + assert_eq!(res, &[4, 3, 2, 2]); + + let v: Vec<&[u8]> = vec![&[1, 2], &[0], &[1, 2, 3], &[1, 2]]; + let values = LargeBinaryArray::from(v.clone()); + let res = rank(&values, None).unwrap(); + assert_eq!(res, &[3, 1, 4, 3]); + + let values = BinaryArray::from(v); + let res = rank(&values, None).unwrap(); + assert_eq!(res, &[3, 1, 4, 3]); + } +} diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs index 1d96532598ca..a477d6c261b3 100644 --- a/arrow-ord/src/sort.rs +++ b/arrow-ord/src/sort.rs @@ -22,14 +22,15 @@ use arrow_array::builder::BufferBuilder; use arrow_array::cast::*; use arrow_array::types::*; use arrow_array::*; -use arrow_buffer::{ArrowNativeType, MutableBuffer, NullBuffer}; -use arrow_data::ArrayData; +use arrow_buffer::BooleanBufferBuilder; +use arrow_buffer::{ArrowNativeType, NullBuffer}; use arrow_data::ArrayDataBuilder; -use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit}; +use arrow_schema::{ArrowError, DataType}; use arrow_select::take::take; use std::cmp::Ordering; use std::sync::Arc; +use crate::rank::rank; pub use arrow_schema::SortOptions; /// Sort the `ArrayRef` using `SortOptions`. @@ -57,11 +58,74 @@ pub fn sort( values: &dyn Array, options: Option, ) -> Result { - if let DataType::RunEndEncoded(_, _) = values.data_type() { - return sort_run(values, options, None); + downcast_primitive_array!( + values => sort_native_type(values, options), + DataType::RunEndEncoded(_, _) => sort_run(values, options, None), + _ => { + let indices = sort_to_indices(values, options, None)?; + take(values, &indices, None) + } + ) +} + +fn sort_native_type( + primitive_values: &PrimitiveArray, + options: Option, +) -> Result +where + T: ArrowPrimitiveType, +{ + let sort_options = options.unwrap_or_default(); + + let mut mutable_buffer = vec![T::default_value(); primitive_values.len()]; + let mutable_slice = &mut mutable_buffer; + + let input_values = primitive_values.values().as_ref(); + + let nulls_count = primitive_values.null_count(); + let valid_count = primitive_values.len() - nulls_count; + + let null_bit_buffer = match nulls_count > 0 { + true => { + let mut validity_buffer = BooleanBufferBuilder::new(primitive_values.len()); + if sort_options.nulls_first { + validity_buffer.append_n(nulls_count, false); + validity_buffer.append_n(valid_count, true); + } else { + validity_buffer.append_n(valid_count, true); + validity_buffer.append_n(nulls_count, false); + } + Some(validity_buffer.finish().into()) + } + false => None, + }; + + if let Some(nulls) = primitive_values.nulls().filter(|n| n.null_count() > 0) { + let values_slice = match sort_options.nulls_first { + true => &mut mutable_slice[nulls_count..], + false => &mut mutable_slice[..valid_count], + }; + + for (write_index, index) in nulls.valid_indices().enumerate() { + values_slice[write_index] = primitive_values.value(index); + } + + values_slice.sort_unstable_by(|a, b| a.compare(*b)); + if sort_options.descending { + values_slice.reverse(); + } + } else { + mutable_slice.copy_from_slice(input_values); + mutable_slice.sort_unstable_by(|a, b| a.compare(*b)); + if sort_options.descending { + mutable_slice.reverse(); + } } - let indices = sort_to_indices(values, options, None)?; - take(values, &indices, None) + + Ok(Arc::new( + PrimitiveArray::::new(mutable_buffer.into(), null_bit_buffer) + .with_data_type(primitive_values.data_type().clone()), + )) } /// Sort the `ArrayRef` partially. 
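For primitive arrays, `sort` now takes a dedicated fast path (`sort_native_type`) that writes the sorted values and a rebuilt validity buffer directly, instead of going through `sort_to_indices` followed by `take`. A small usage sketch of the (unchanged) public behaviour, with sample values chosen purely for illustration and not taken from the patch:

```rust
use arrow_array::{cast::AsArray, types::Int32Type, Int32Array};
use arrow_ord::sort::{sort, SortOptions};

fn main() {
    let a = Int32Array::from(vec![Some(3), None, Some(1), Some(2)]);

    // Descending order, with nulls placed first.
    let opts = SortOptions {
        descending: true,
        nulls_first: true,
    };
    let sorted = sort(&a, Some(opts)).unwrap();

    // The result is a freshly built primitive array, not a take of the input.
    let sorted = sorted.as_primitive::<Int32Type>();
    let values: Vec<Option<i32>> = sorted.iter().collect();
    assert_eq!(values, vec![None, Some(3), Some(2), Some(1)]);
}
```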
@@ -117,13 +181,6 @@ where } } -fn cmp(l: T, r: T) -> Ordering -where - T: Ord, -{ - l.cmp(&r) -} - // partition indices into valid and null indices fn partition_validity(array: &dyn Array) -> (Vec, Vec) { match array.null_count() { @@ -140,223 +197,33 @@ fn partition_validity(array: &dyn Array) -> (Vec, Vec) { /// For floating point arrays any NaN values are considered to be greater than any other non-null value. /// `limit` is an option for [partial_sort]. pub fn sort_to_indices( - values: &dyn Array, + array: &dyn Array, options: Option, limit: Option, ) -> Result { let options = options.unwrap_or_default(); - let (v, n) = partition_validity(values); - - Ok(match values.data_type() { - DataType::Decimal128(_, _) => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Decimal256(_, _) => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Boolean => sort_boolean(values, v, n, &options, limit), - DataType::Int8 => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Int16 => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Int32 => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Int64 => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::UInt8 => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::UInt16 => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::UInt32 => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::UInt64 => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Float16 => sort_primitive::( - values, - v, - n, - |x, y| x.total_cmp(&y), - &options, - limit, - ), - DataType::Float32 => sort_primitive::( - values, - v, - n, - |x, y| x.total_cmp(&y), - &options, - limit, - ), - DataType::Float64 => sort_primitive::( - values, - v, - n, - |x, y| x.total_cmp(&y), - &options, - limit, - ), - DataType::Date32 => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Date64 => { - sort_primitive::(values, v, n, cmp, &options, limit) + let (v, n) = partition_validity(array); + + Ok(downcast_primitive_array! 
{ + array => sort_primitive(array, v, n, options, limit), + DataType::Boolean => sort_boolean(array.as_boolean(), v, n, options, limit), + DataType::Utf8 => sort_bytes(array.as_string::(), v, n, options, limit), + DataType::LargeUtf8 => sort_bytes(array.as_string::(), v, n, options, limit), + DataType::Binary => sort_bytes(array.as_binary::(), v, n, options, limit), + DataType::LargeBinary => sort_bytes(array.as_binary::(), v, n, options, limit), + DataType::FixedSizeBinary(_) => sort_fixed_size_binary(array.as_fixed_size_binary(), v, n, options, limit), + DataType::List(_) => sort_list(array.as_list::(), v, n, options, limit)?, + DataType::LargeList(_) => sort_list(array.as_list::(), v, n, options, limit)?, + DataType::FixedSizeList(_, _) => sort_fixed_size_list(array.as_fixed_size_list(), v, n, options, limit)?, + DataType::Dictionary(_, _) => downcast_dictionary_array!{ + array => sort_dictionary(array, v, n, options, limit)?, + _ => unreachable!() } - DataType::Time32(TimeUnit::Second) => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Time32(TimeUnit::Millisecond) => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Time64(TimeUnit::Microsecond) => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Time64(TimeUnit::Nanosecond) => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Timestamp(TimeUnit::Second, _) => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - sort_primitive::( - values, v, n, cmp, &options, limit, - ) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - sort_primitive::( - values, v, n, cmp, &options, limit, - ) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - sort_primitive::( - values, v, n, cmp, &options, limit, - ) - } - DataType::Interval(IntervalUnit::YearMonth) => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Interval(IntervalUnit::DayTime) => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - sort_primitive::( - values, v, n, cmp, &options, limit, - ) - } - DataType::Duration(TimeUnit::Second) => { - sort_primitive::(values, v, n, cmp, &options, limit) - } - DataType::Duration(TimeUnit::Millisecond) => { - sort_primitive::( - values, v, n, cmp, &options, limit, - ) - } - DataType::Duration(TimeUnit::Microsecond) => { - sort_primitive::( - values, v, n, cmp, &options, limit, - ) - } - DataType::Duration(TimeUnit::Nanosecond) => { - sort_primitive::( - values, v, n, cmp, &options, limit, - ) - } - DataType::Utf8 => sort_string::(values, v, n, &options, limit), - DataType::LargeUtf8 => sort_string::(values, v, n, &options, limit), - DataType::List(field) | DataType::FixedSizeList(field, _) => { - match field.data_type() { - DataType::Int8 => sort_list::(values, v, n, &options, limit), - DataType::Int16 => sort_list::(values, v, n, &options, limit), - DataType::Int32 => sort_list::(values, v, n, &options, limit), - DataType::Int64 => sort_list::(values, v, n, &options, limit), - DataType::UInt8 => sort_list::(values, v, n, &options, limit), - DataType::UInt16 => sort_list::(values, v, n, &options, limit), - DataType::UInt32 => sort_list::(values, v, n, &options, limit), - DataType::UInt64 => sort_list::(values, v, n, &options, limit), - DataType::Float16 => sort_list::(values, v, n, &options, limit), - DataType::Float32 => sort_list::(values, v, n, &options, limit), - DataType::Float64 => 
sort_list::(values, v, n, &options, limit), - t => { - return Err(ArrowError::ComputeError(format!( - "Sort not supported for list type {t:?}" - ))); - } - } - } - DataType::LargeList(field) => match field.data_type() { - DataType::Int8 => sort_list::(values, v, n, &options, limit), - DataType::Int16 => sort_list::(values, v, n, &options, limit), - DataType::Int32 => sort_list::(values, v, n, &options, limit), - DataType::Int64 => sort_list::(values, v, n, &options, limit), - DataType::UInt8 => sort_list::(values, v, n, &options, limit), - DataType::UInt16 => sort_list::(values, v, n, &options, limit), - DataType::UInt32 => sort_list::(values, v, n, &options, limit), - DataType::UInt64 => sort_list::(values, v, n, &options, limit), - DataType::Float16 => sort_list::(values, v, n, &options, limit), - DataType::Float32 => sort_list::(values, v, n, &options, limit), - DataType::Float64 => sort_list::(values, v, n, &options, limit), - t => { - return Err(ArrowError::ComputeError(format!( - "Sort not supported for list type {t:?}" - ))); - } - }, - DataType::Dictionary(_, _) => { - let value_null_first = if options.descending { - // When sorting dictionary in descending order, we take inverse of of null ordering - // when sorting the values. Because if `nulls_first` is true, null must be in front - // of non-null value. As we take the sorted order of value array to sort dictionary - // keys, these null values will be treated as smallest ones and be sorted to the end - // of sorted result. So we set `nulls_first` to false when sorting dictionary value - // array to make them as largest ones, then null values will be put at the beginning - // of sorted dictionary result. - !options.nulls_first - } else { - options.nulls_first - }; - let value_options = Some(SortOptions { - descending: false, - nulls_first: value_null_first, - }); - downcast_dictionary_array!( - values => match values.values().data_type() { - dt if DataType::is_primitive(dt) => { - let dict_values = values.values(); - let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?; - let value_indices_map = sorted_rank(&sorted_value_indices); - sort_primitive_dictionary::<_, _>(values, &value_indices_map, v, n, options, limit, cmp) - }, - DataType::Utf8 => { - let dict_values = values.values(); - let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?; - let value_indices_map = sorted_rank(&sorted_value_indices); - sort_string_dictionary::<_>(values, &value_indices_map, v, n, &options, limit) - }, - t => return Err(ArrowError::ComputeError(format!( - "Unsupported dictionary value type {t}" - ))), - }, - t => return Err(ArrowError::ComputeError(format!( - "Unsupported datatype {t}" - ))), - ) - } - DataType::Binary | DataType::FixedSizeBinary(_) => { - sort_binary::(values, v, n, &options, limit) - } - DataType::LargeBinary => sort_binary::(values, v, n, &options, limit), DataType::RunEndEncoded(run_ends_field, _) => match run_ends_field.data_type() { - DataType::Int16 => sort_run_to_indices::(values, &options, limit), - DataType::Int32 => sort_run_to_indices::(values, &options, limit), - DataType::Int64 => sort_run_to_indices::(values, &options, limit), + DataType::Int16 => sort_run_to_indices::(array, options, limit), + DataType::Int32 => sort_run_to_indices::(array, options, limit), + DataType::Int64 => sort_run_to_indices::(array, options, limit), dt => { return Err(ArrowError::ComputeError(format!( "Invalid run end data type: {dt}" @@ -371,238 +238,170 @@ pub fn sort_to_indices( }) } -/// 
Sort boolean values -/// -/// when a limit is present, the sort is pair-comparison based as k-select might be more efficient, -/// when the limit is absent, binary partition is used to speed up (which is linear). -/// -/// TODO maybe partition_validity call can be eliminated in this case -/// and [tri-color sort](https://en.wikipedia.org/wiki/Dutch_national_flag_problem) -/// can be used instead. fn sort_boolean( - values: &dyn Array, + values: &BooleanArray, value_indices: Vec, - mut null_indices: Vec, - options: &SortOptions, + null_indices: Vec, + options: SortOptions, limit: Option, ) -> UInt32Array { - let values = values - .as_any() - .downcast_ref::() - .expect("Unable to downcast to boolean array"); - let descending = options.descending; - - let valids_len = value_indices.len(); - let nulls_len = null_indices.len(); - - let mut len = values.len(); - let valids = if let Some(limit) = limit { - len = limit.min(len); - // create tuples that are used for sorting - let mut valids = value_indices - .into_iter() - .map(|index| (index, values.value(index as usize))) - .collect::>(); - - sort_valids(descending, &mut valids, &mut null_indices, len, cmp); - valids - } else { - // when limit is not present, we have a better way than sorting: we can just partition - // the vec into [false..., true...] or [true..., false...] when descending - // TODO when https://github.com/rust-lang/rust/issues/62543 is merged we can use partition_in_place - let (mut a, b): (Vec<_>, Vec<_>) = value_indices - .into_iter() - .map(|index| (index, values.value(index as usize))) - .partition(|(_, value)| *value == descending); - a.extend(b); - if descending { - null_indices.reverse(); - } - a - }; - - let nulls = null_indices; - - // collect results directly into a buffer instead of a vec to avoid another aligned allocation - let result_capacity = len * std::mem::size_of::(); - let mut result = MutableBuffer::new(result_capacity); - // sets len to capacity so we can access the whole buffer as a typed slice - result.resize(result_capacity, 0); - let result_slice: &mut [u32] = result.typed_data_mut(); - - if options.nulls_first { - let size = nulls_len.min(len); - result_slice[0..size].copy_from_slice(&nulls[0..size]); - if nulls_len < len { - insert_valid_values(result_slice, nulls_len, &valids[0..len - size]); - } - } else { - // nulls last - let size = valids.len().min(len); - insert_valid_values(result_slice, 0, &valids[0..size]); - if len > size { - result_slice[valids_len..].copy_from_slice(&nulls[0..(len - valids_len)]); - } - } - - let result_data = unsafe { - ArrayData::new_unchecked( - DataType::UInt32, - len, - Some(0), - None, - 0, - vec![result.into()], - vec![], - ) - }; + let mut valids = value_indices + .into_iter() + .map(|index| (index, values.value(index as usize))) + .collect::>(); + sort_impl(options, &mut valids, &null_indices, limit, |a, b| a.cmp(&b)).into() +} - UInt32Array::from(result_data) +fn sort_primitive( + values: &PrimitiveArray, + value_indices: Vec, + nulls: Vec, + options: SortOptions, + limit: Option, +) -> UInt32Array { + let mut valids = value_indices + .into_iter() + .map(|index| (index, values.value(index as usize))) + .collect::>(); + sort_impl(options, &mut valids, &nulls, limit, T::Native::compare).into() } -/// Sort primitive values -fn sort_primitive( - values: &dyn Array, +fn sort_bytes( + values: &GenericByteArray, value_indices: Vec, - null_indices: Vec, - cmp: F, - options: &SortOptions, + nulls: Vec, + options: SortOptions, limit: Option, -) -> UInt32Array -where - T: 
ArrowPrimitiveType, - T::Native: PartialOrd, - F: Fn(T::Native, T::Native) -> Ordering, -{ - // create tuples that are used for sorting - let valids = { - let values = values.as_primitive::(); - value_indices - .into_iter() - .map(|index| (index, values.value(index as usize))) - .collect::>() - }; - sort_primitive_inner(values.len(), null_indices, cmp, options, limit, valids) +) -> UInt32Array { + let mut valids = value_indices + .into_iter() + .map(|index| (index, values.value(index as usize).as_ref())) + .collect::>(); + + sort_impl(options, &mut valids, &nulls, limit, Ord::cmp).into() } -/// Given a list of indices that yield a sorted order, returns the ordered -/// rank of each index -/// -/// e.g. [2, 4, 3, 1, 0] -> [4, 3, 0, 2, 1] -fn sorted_rank(sorted_value_indices: &UInt32Array) -> Vec { - assert_eq!(sorted_value_indices.null_count(), 0); - let sorted_indices = sorted_value_indices.values(); - let mut out: Vec<_> = vec![0_u32; sorted_indices.len()]; - for (ix, val) in sorted_indices.iter().enumerate() { - out[*val as usize] = ix as u32; - } - out +fn sort_fixed_size_binary( + values: &FixedSizeBinaryArray, + value_indices: Vec, + nulls: Vec, + options: SortOptions, + limit: Option, +) -> UInt32Array { + let mut valids = value_indices + .iter() + .copied() + .map(|index| (index, values.value(index as usize))) + .collect::>(); + sort_impl(options, &mut valids, &nulls, limit, Ord::cmp).into() } -/// Sort dictionary encoded primitive values -fn sort_primitive_dictionary( - values: &DictionaryArray, - value_indices_map: &[u32], +fn sort_dictionary( + dict: &DictionaryArray, value_indices: Vec, null_indices: Vec, options: SortOptions, limit: Option, - cmp: F, -) -> UInt32Array -where - K: ArrowDictionaryKeyType, - F: Fn(u32, u32) -> Ordering, -{ - let keys: &PrimitiveArray = values.keys(); +) -> Result { + let keys: &PrimitiveArray = dict.keys(); + let rank = child_rank(dict.values().as_ref(), options)?; // create tuples that are used for sorting - let valids = value_indices + let mut valids = value_indices .into_iter() .map(|index| { let key: K::Native = keys.value(index as usize); - (index, value_indices_map[key.as_usize()]) + (index, rank[key.as_usize()]) }) .collect::>(); - sort_primitive_inner::<_, _>(keys.len(), null_indices, cmp, &options, limit, valids) + Ok(sort_impl(options, &mut valids, &null_indices, limit, |a, b| a.cmp(&b)).into()) } -// sort is instantiated a lot so we only compile this inner version for each native type -fn sort_primitive_inner( - value_len: usize, +fn sort_list( + array: &GenericListArray, + value_indices: Vec, null_indices: Vec, - cmp: F, - options: &SortOptions, + options: SortOptions, limit: Option, - mut valids: Vec<(u32, T)>, -) -> UInt32Array -where - T: ArrowNativeType, - T: PartialOrd, - F: Fn(T, T) -> Ordering, -{ - let mut nulls = null_indices; - - let valids_len = valids.len(); - let nulls_len = nulls.len(); - let mut len = value_len; +) -> Result { + let rank = child_rank(array.values().as_ref(), options)?; + let offsets = array.value_offsets(); + let mut valids = value_indices + .into_iter() + .map(|index| { + let end = offsets[index as usize + 1].as_usize(); + let start = offsets[index as usize].as_usize(); + (index, &rank[start..end]) + }) + .collect::>(); + Ok(sort_impl(options, &mut valids, &null_indices, limit, Ord::cmp).into()) +} - if let Some(limit) = limit { - len = limit.min(len); - } +fn sort_fixed_size_list( + array: &FixedSizeListArray, + value_indices: Vec, + null_indices: Vec, + options: SortOptions, + limit: Option, +) -> 
Result { + let rank = child_rank(array.values().as_ref(), options)?; + let size = array.value_length() as usize; + let mut valids = value_indices + .into_iter() + .map(|index| { + let start = index as usize * size; + (index, &rank[start..start + size]) + }) + .collect::>(); + Ok(sort_impl(options, &mut valids, &null_indices, limit, Ord::cmp).into()) +} - sort_valids(options.descending, &mut valids, &mut nulls, len, cmp); +#[inline(never)] +fn sort_impl( + options: SortOptions, + valids: &mut [(u32, T)], + nulls: &[u32], + limit: Option, + mut cmp: impl FnMut(T, T) -> Ordering, +) -> Vec { + let v_limit = match (limit, options.nulls_first) { + (Some(l), true) => l.saturating_sub(nulls.len()).min(valids.len()), + _ => valids.len(), + }; - // collect results directly into a buffer instead of a vec to avoid another aligned allocation - let result_capacity = len * std::mem::size_of::(); - let mut result = MutableBuffer::new(result_capacity); - // sets len to capacity so we can access the whole buffer as a typed slice - result.resize(result_capacity, 0); - let result_slice: &mut [u32] = result.typed_data_mut(); + match options.descending { + false => sort_unstable_by(valids, v_limit, |a, b| cmp(a.1, b.1)), + true => sort_unstable_by(valids, v_limit, |a, b| cmp(a.1, b.1).reverse()), + } - if options.nulls_first { - let size = nulls_len.min(len); - result_slice[0..size].copy_from_slice(&nulls[0..size]); - if nulls_len < len { - insert_valid_values(result_slice, nulls_len, &valids[0..len - size]); + let len = valids.len() + nulls.len(); + let limit = limit.unwrap_or(len).min(len); + let mut out = Vec::with_capacity(len); + match options.nulls_first { + true => { + out.extend_from_slice(&nulls[..nulls.len().min(limit)]); + let remaining = limit - out.len(); + out.extend(valids.iter().map(|x| x.0).take(remaining)); } - } else { - // nulls last - let size = valids.len().min(len); - insert_valid_values(result_slice, 0, &valids[0..size]); - if len > size { - result_slice[valids_len..].copy_from_slice(&nulls[0..(len - valids_len)]); + false => { + out.extend(valids.iter().map(|x| x.0).take(limit)); + let remaining = limit - out.len(); + out.extend_from_slice(&nulls[..remaining]) } } - - let result_data = unsafe { - ArrayData::new_unchecked( - DataType::UInt32, - len, - Some(0), - None, - 0, - vec![result.into()], - vec![], - ) - }; - - UInt32Array::from(result_data) + out } -// insert valid and nan values in the correct order depending on the descending flag -fn insert_valid_values(result_slice: &mut [u32], offset: usize, valids: &[(u32, T)]) { - let valids_len = valids.len(); - // helper to append the index part of the valid tuples - let append_valids = move |dst_slice: &mut [u32]| { - debug_assert_eq!(dst_slice.len(), valids_len); - dst_slice - .iter_mut() - .zip(valids.iter()) - .for_each(|(dst, src)| *dst = src.0) - }; - - append_valids(&mut result_slice[offset..offset + valids.len()]); +/// Computes the rank for a set of child values +fn child_rank(values: &dyn Array, options: SortOptions) -> Result, ArrowError> { + // If parent sort order is descending we need to invert the value of nulls_first so that + // when the parent is sorted based on the produced ranks, nulls are still ordered correctly + let value_options = Some(SortOptions { + descending: false, + nulls_first: options.nulls_first != options.descending, + }); + rank(values, value_options) } // Sort run array and return sorted run array. @@ -693,7 +492,7 @@ fn sort_run_downcasted( // encoded back to run array. 
fn sort_run_to_indices( values: &dyn Array, - options: &SortOptions, + options: SortOptions, limit: Option, ) -> UInt32Array { let run_array = values.as_any().downcast_ref::>().unwrap(); @@ -708,7 +507,7 @@ fn sort_run_to_indices( let consume_runs = |run_length, logical_start| { result.extend(logical_start as u32..(logical_start + run_length) as u32); }; - sort_run_inner(run_array, Some(*options), output_len, consume_runs); + sort_run_inner(run_array, Some(options), output_len, consume_runs); UInt32Array::from(result) } @@ -790,223 +589,6 @@ where (values_indices, run_values) } -/// Sort strings -fn sort_string( - values: &dyn Array, - value_indices: Vec, - null_indices: Vec, - options: &SortOptions, - limit: Option, -) -> UInt32Array { - let values = values - .as_any() - .downcast_ref::>() - .unwrap(); - - sort_string_helper( - values, - value_indices, - null_indices, - options, - limit, - |array, idx| array.value(idx as usize), - ) -} - -/// Sort dictionary encoded strings -fn sort_string_dictionary( - values: &DictionaryArray, - value_indices_map: &[u32], - value_indices: Vec, - null_indices: Vec, - options: &SortOptions, - limit: Option, -) -> UInt32Array { - let keys: &PrimitiveArray = values.keys(); - - // create tuples that are used for sorting - let valids = value_indices - .into_iter() - .map(|index| { - let key: T::Native = keys.value(index as usize); - (index, value_indices_map[key.as_usize()]) - }) - .collect::>(); - - sort_primitive_inner::<_, _>(keys.len(), null_indices, cmp, options, limit, valids) -} - -/// shared implementation between dictionary encoded and plain string arrays -#[inline] -fn sort_string_helper<'a, A: Array, F>( - values: &'a A, - value_indices: Vec, - null_indices: Vec, - options: &SortOptions, - limit: Option, - value_fn: F, -) -> UInt32Array -where - F: Fn(&'a A, u32) -> &str, -{ - let mut valids = value_indices - .into_iter() - .map(|index| (index, value_fn(values, index))) - .collect::>(); - let mut nulls = null_indices; - let descending = options.descending; - let mut len = values.len(); - - if let Some(limit) = limit { - len = limit.min(len); - } - - sort_valids(descending, &mut valids, &mut nulls, len, cmp); - // collect the order of valid tuplies - let mut valid_indices: Vec = valids.iter().map(|tuple| tuple.0).collect(); - - if options.nulls_first { - nulls.append(&mut valid_indices); - nulls.truncate(len); - UInt32Array::from(nulls) - } else { - // no need to sort nulls as they are in the correct order already - valid_indices.append(&mut nulls); - valid_indices.truncate(len); - UInt32Array::from(valid_indices) - } -} - -fn sort_list( - values: &dyn Array, - value_indices: Vec, - null_indices: Vec, - options: &SortOptions, - limit: Option, -) -> UInt32Array -where - S: OffsetSizeTrait, -{ - sort_list_inner::(values, value_indices, null_indices, options, limit) -} - -fn sort_list_inner( - values: &dyn Array, - value_indices: Vec, - mut null_indices: Vec, - options: &SortOptions, - limit: Option, -) -> UInt32Array -where - S: OffsetSizeTrait, -{ - let mut valids: Vec<(u32, ArrayRef)> = values - .as_any() - .downcast_ref::() - .map_or_else( - || { - let values = as_generic_list_array::(values); - value_indices - .iter() - .copied() - .map(|index| (index, values.value(index as usize))) - .collect() - }, - |values| { - value_indices - .iter() - .copied() - .map(|index| (index, values.value(index as usize))) - .collect() - }, - ); - - let mut len = values.len(); - let descending = options.descending; - - if let Some(limit) = limit { - len = 
limit.min(len); - } - sort_valids_array(descending, &mut valids, &mut null_indices, len); - - let mut valid_indices: Vec = valids.iter().map(|tuple| tuple.0).collect(); - if options.nulls_first { - null_indices.append(&mut valid_indices); - null_indices.truncate(len); - UInt32Array::from(null_indices) - } else { - valid_indices.append(&mut null_indices); - valid_indices.truncate(len); - UInt32Array::from(valid_indices) - } -} - -fn sort_binary( - values: &dyn Array, - value_indices: Vec, - mut null_indices: Vec, - options: &SortOptions, - limit: Option, -) -> UInt32Array -where - S: OffsetSizeTrait, -{ - let mut valids: Vec<(u32, &[u8])> = values - .as_any() - .downcast_ref::() - .map_or_else( - || { - let values = as_generic_binary_array::(values); - value_indices - .iter() - .copied() - .map(|index| (index, values.value(index as usize))) - .collect() - }, - |values| { - value_indices - .iter() - .copied() - .map(|index| (index, values.value(index as usize))) - .collect() - }, - ); - - let mut len = values.len(); - let descending = options.descending; - - if let Some(limit) = limit { - len = limit.min(len); - } - - sort_valids(descending, &mut valids, &mut null_indices, len, cmp); - - let mut valid_indices: Vec = valids.iter().map(|tuple| tuple.0).collect(); - if options.nulls_first { - null_indices.append(&mut valid_indices); - null_indices.truncate(len); - UInt32Array::from(null_indices) - } else { - valid_indices.append(&mut null_indices); - valid_indices.truncate(len); - UInt32Array::from(valid_indices) - } -} - -/// Compare two `Array`s based on the ordering defined in [build_compare] -fn cmp_array(a: &dyn Array, b: &dyn Array) -> Ordering { - let cmp_op = build_compare(a, b).unwrap(); - let length = a.len().max(b.len()); - - for i in 0..length { - let result = cmp_op(i, i); - if result != Ordering::Equal { - return result; - } - } - Ordering::Equal -} - /// One column to be used in lexicographical sort #[derive(Clone, Debug)] pub struct SortColumn { @@ -1125,23 +707,25 @@ pub fn partial_sort(v: &mut [T], limit: usize, mut is_less: F) where F: FnMut(&T, &T) -> Ordering, { - let (before, _mid, _after) = v.select_nth_unstable_by(limit, &mut is_less); - before.sort_unstable_by(is_less); + if let Some(n) = limit.checked_sub(1) { + let (before, _mid, _after) = v.select_nth_unstable_by(n, &mut is_less); + before.sort_unstable_by(is_less); + } } -type LexicographicalCompareItem<'a> = ( - Option<&'a NullBuffer>, // nulls - DynComparator, // comparator - SortOptions, // sort_option +type LexicographicalCompareItem = ( + Option, // nulls + DynComparator, // comparator + SortOptions, // sort_option ); /// A lexicographical comparator that wraps given array data (columns) and can lexicographically compare data /// at given two indices. The lifetime is the same at the data wrapped. -pub struct LexicographicalComparator<'a> { - compare_items: Vec>, +pub struct LexicographicalComparator { + compare_items: Vec, } -impl LexicographicalComparator<'_> { +impl LexicographicalComparator { /// lexicographically compare values at the wrapped columns with given indices. pub fn compare(&self, a_idx: usize, b_idx: usize) -> Ordering { for (nulls, comparator, sort_option) in &self.compare_items { @@ -1190,14 +774,14 @@ impl LexicographicalComparator<'_> { /// results with two indices. 
pub fn try_new( columns: &[SortColumn], - ) -> Result, ArrowError> { + ) -> Result { let compare_items = columns .iter() .map(|column| { // flatten and convert build comparators let values = column.values.as_ref(); Ok(( - values.nulls(), + values.logical_nulls(), build_compare(values, values)?, column.options.unwrap_or_default(), )) @@ -1207,49 +791,12 @@ impl LexicographicalComparator<'_> { } } -fn sort_valids( - descending: bool, - valids: &mut [(u32, T)], - nulls: &mut [U], - len: usize, - mut cmp: impl FnMut(T, T) -> Ordering, -) where - T: ?Sized + Copy, -{ - let valids_len = valids.len(); - if !descending { - sort_unstable_by(valids, len.min(valids_len), |a, b| cmp(a.1, b.1)); - } else { - sort_unstable_by(valids, len.min(valids_len), |a, b| cmp(a.1, b.1).reverse()); - // reverse to keep a stable ordering - nulls.reverse(); - } -} - -fn sort_valids_array( - descending: bool, - valids: &mut [(u32, ArrayRef)], - nulls: &mut [T], - len: usize, -) { - let valids_len = valids.len(); - if !descending { - sort_unstable_by(valids, len.min(valids_len), |a, b| { - cmp_array(a.1.as_ref(), b.1.as_ref()) - }); - } else { - sort_unstable_by(valids, len.min(valids_len), |a, b| { - cmp_array(a.1.as_ref(), b.1.as_ref()).reverse() - }); - // reverse to keep a stable ordering - nulls.reverse(); - } -} - #[cfg(test)] mod tests { use super::*; - use arrow_array::builder::PrimitiveRunBuilder; + use arrow_array::builder::{ + FixedSizeListBuilder, Int64Builder, ListBuilder, PrimitiveRunBuilder, + }; use arrow_buffer::i256; use half::f16; use rand::rngs::StdRng; @@ -1733,7 +1280,7 @@ mod tests { nulls_first: false, }), None, - vec![2, 1, 4, 3, 5, 0], // [2, 4, 1, 3, 5, 0] + vec![2, 1, 4, 3, 0, 5], ); test_sort_to_indices_primitive_arrays::( @@ -1743,7 +1290,7 @@ mod tests { nulls_first: false, }), None, - vec![2, 1, 4, 3, 5, 0], + vec![2, 1, 4, 3, 0, 5], ); test_sort_to_indices_primitive_arrays::( @@ -1753,7 +1300,7 @@ mod tests { nulls_first: false, }), None, - vec![2, 1, 4, 3, 5, 0], + vec![2, 1, 4, 3, 0, 5], ); test_sort_to_indices_primitive_arrays::( @@ -1763,7 +1310,7 @@ mod tests { nulls_first: false, }), None, - vec![2, 1, 4, 3, 5, 0], + vec![2, 1, 4, 3, 0, 5], ); test_sort_to_indices_primitive_arrays::( @@ -1780,7 +1327,7 @@ mod tests { nulls_first: false, }), None, - vec![2, 1, 4, 3, 5, 0], + vec![2, 1, 4, 3, 0, 5], ); test_sort_to_indices_primitive_arrays::( @@ -1797,7 +1344,7 @@ mod tests { nulls_first: false, }), None, - vec![2, 1, 4, 3, 5, 0], + vec![2, 1, 4, 3, 0, 5], ); test_sort_to_indices_primitive_arrays::( @@ -1807,7 +1354,7 @@ mod tests { nulls_first: false, }), None, - vec![2, 1, 4, 3, 5, 0], + vec![2, 1, 4, 3, 0, 5], ); // descending, nulls first @@ -1818,7 +1365,7 @@ mod tests { nulls_first: true, }), None, - vec![5, 0, 2, 1, 4, 3], // [5, 0, 2, 4, 1, 3] + vec![0, 5, 2, 1, 4, 3], // [5, 0, 2, 4, 1, 3] ); test_sort_to_indices_primitive_arrays::( @@ -1828,7 +1375,7 @@ mod tests { nulls_first: true, }), None, - vec![5, 0, 2, 1, 4, 3], // [5, 0, 2, 4, 1, 3] + vec![0, 5, 2, 1, 4, 3], // [5, 0, 2, 4, 1, 3] ); test_sort_to_indices_primitive_arrays::( @@ -1838,7 +1385,7 @@ mod tests { nulls_first: true, }), None, - vec![5, 0, 2, 1, 4, 3], + vec![0, 5, 2, 1, 4, 3], ); test_sort_to_indices_primitive_arrays::( @@ -1848,7 +1395,7 @@ mod tests { nulls_first: true, }), None, - vec![5, 0, 2, 1, 4, 3], + vec![0, 5, 2, 1, 4, 3], ); test_sort_to_indices_primitive_arrays::( @@ -1865,7 +1412,7 @@ mod tests { nulls_first: true, }), None, - vec![5, 0, 2, 1, 4, 3], + vec![0, 5, 2, 1, 4, 3], ); 
test_sort_to_indices_primitive_arrays::( @@ -1875,7 +1422,7 @@ mod tests { nulls_first: true, }), None, - vec![5, 0, 2, 1, 4, 3], + vec![0, 5, 2, 1, 4, 3], ); test_sort_to_indices_primitive_arrays::( @@ -1885,7 +1432,7 @@ mod tests { nulls_first: true, }), None, - vec![5, 0, 2, 1, 4, 3], + vec![0, 5, 2, 1, 4, 3], ); // valid values less than limit with extra nulls @@ -1962,7 +1509,7 @@ mod tests { nulls_first: false, }), None, - vec![2, 3, 1, 4, 5, 0], + vec![2, 3, 1, 4, 0, 5], ); // boolean, descending, nulls first @@ -1973,7 +1520,7 @@ mod tests { nulls_first: true, }), None, - vec![5, 0, 2, 3, 1, 4], + vec![0, 5, 2, 3, 1, 4], ); // boolean, descending, nulls first, limit @@ -1984,7 +1531,7 @@ mod tests { nulls_first: true, }), Some(3), - vec![5, 0, 2], + vec![0, 5, 2], ); // valid values less than limit with extra nulls @@ -2047,7 +1594,7 @@ mod tests { nulls_first: false, }), None, - vec![1, 5, 3, 2, 4, 6, 0], + vec![1, 5, 3, 2, 4, 0, 6], ); // decimal null_first and descending test_sort_to_indices_decimal128_array( @@ -2057,7 +1604,7 @@ mod tests { nulls_first: true, }), None, - vec![6, 0, 1, 5, 3, 2, 4], + vec![0, 6, 1, 5, 3, 2, 4], ); // decimal null_first test_sort_to_indices_decimal128_array( @@ -2094,7 +1641,7 @@ mod tests { nulls_first: true, }), Some(3), - vec![6, 0, 1], + vec![0, 6, 1], ); // limit null_first test_sort_to_indices_decimal128_array( @@ -2110,48 +1657,46 @@ mod tests { #[test] fn test_sort_indices_decimal256() { + let data = vec![ + None, + Some(i256::from_i128(5)), + Some(i256::from_i128(2)), + Some(i256::from_i128(3)), + Some(i256::from_i128(1)), + Some(i256::from_i128(4)), + None, + ]; + // decimal default test_sort_to_indices_decimal256_array( - vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), + data.clone(), None, None, vec![0, 6, 4, 2, 3, 5, 1], ); // decimal descending test_sort_to_indices_decimal256_array( - vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), + data.clone(), Some(SortOptions { descending: true, nulls_first: false, }), None, - vec![1, 5, 3, 2, 4, 6, 0], + vec![1, 5, 3, 2, 4, 0, 6], ); // decimal null_first and descending test_sort_to_indices_decimal256_array( - vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), + data.clone(), Some(SortOptions { descending: true, nulls_first: true, }), None, - vec![6, 0, 1, 5, 3, 2, 4], + vec![0, 6, 1, 5, 3, 2, 4], ); // decimal null_first test_sort_to_indices_decimal256_array( - vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), + data.clone(), Some(SortOptions { descending: false, nulls_first: true, @@ -2160,21 +1705,10 @@ mod tests { vec![0, 6, 4, 2, 3, 5, 1], ); // limit - test_sort_to_indices_decimal256_array( - vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), - None, - Some(3), - vec![0, 6, 4], - ); + test_sort_to_indices_decimal256_array(data.clone(), None, Some(3), vec![0, 6, 4]); // limit descending test_sort_to_indices_decimal256_array( - vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), + data.clone(), Some(SortOptions { descending: true, nulls_first: false, @@ -2184,23 +1718,17 @@ mod tests { ); // limit descending null_first test_sort_to_indices_decimal256_array( - vec![None, 
Some(5), Some(2), Some(3), Some(1), Some(4), None] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), + data.clone(), Some(SortOptions { descending: true, nulls_first: true, }), Some(3), - vec![6, 0, 1], + vec![0, 6, 1], ); // limit null_first test_sort_to_indices_decimal256_array( - vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), + data, Some(SortOptions { descending: false, nulls_first: true, @@ -2212,14 +1740,15 @@ mod tests { #[test] fn test_sort_indices_decimal256_max_min() { + let data = vec![ + None, + Some(i256::MIN), + Some(i256::from_i128(1)), + Some(i256::MAX), + Some(i256::from_i128(-1)), + ]; test_sort_to_indices_decimal256_array( - vec![ - None, - Some(i256::MIN), - Some(i256::from_i128(1)), - Some(i256::MAX), - Some(i256::from_i128(-1)), - ], + data.clone(), Some(SortOptions { descending: false, nulls_first: true, @@ -2229,13 +1758,7 @@ mod tests { ); test_sort_to_indices_decimal256_array( - vec![ - None, - Some(i256::MIN), - Some(i256::from_i128(1)), - Some(i256::MAX), - Some(i256::from_i128(-1)), - ], + data.clone(), Some(SortOptions { descending: true, nulls_first: true, @@ -2245,13 +1768,7 @@ mod tests { ); test_sort_to_indices_decimal256_array( - vec![ - None, - Some(i256::MIN), - Some(i256::from_i128(1)), - Some(i256::MAX), - Some(i256::from_i128(-1)), - ], + data.clone(), Some(SortOptions { descending: false, nulls_first: true, @@ -2261,13 +1778,7 @@ mod tests { ); test_sort_to_indices_decimal256_array( - vec![ - None, - Some(i256::MIN), - Some(i256::from_i128(1)), - Some(i256::MAX), - Some(i256::from_i128(-1)), - ], + data.clone(), Some(SortOptions { descending: true, nulls_first: true, @@ -2357,124 +1868,109 @@ mod tests { #[test] fn test_sort_decimal256() { + let data = vec![ + None, + Some(i256::from_i128(5)), + Some(i256::from_i128(2)), + Some(i256::from_i128(3)), + Some(i256::from_i128(1)), + Some(i256::from_i128(4)), + None, + ]; // decimal default test_sort_decimal256_array( - vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), + data.clone(), None, None, - vec![None, None, Some(1), Some(2), Some(3), Some(4), Some(5)] + [None, None, Some(1), Some(2), Some(3), Some(4), Some(5)] .iter() .map(|v| v.map(i256::from_i128)) .collect(), ); // decimal descending test_sort_decimal256_array( - vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), + data.clone(), Some(SortOptions { descending: true, nulls_first: false, }), None, - vec![Some(5), Some(4), Some(3), Some(2), Some(1), None, None] + [Some(5), Some(4), Some(3), Some(2), Some(1), None, None] .iter() .map(|v| v.map(i256::from_i128)) .collect(), ); // decimal null_first and descending test_sort_decimal256_array( - vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), + data.clone(), Some(SortOptions { descending: true, nulls_first: true, }), None, - vec![None, None, Some(5), Some(4), Some(3), Some(2), Some(1)] + [None, None, Some(5), Some(4), Some(3), Some(2), Some(1)] .iter() .map(|v| v.map(i256::from_i128)) .collect(), ); // decimal null_first test_sort_decimal256_array( - vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), + data.clone(), Some(SortOptions { descending: false, nulls_first: true, }), None, - vec![None, None, Some(1), Some(2), Some(3), Some(4), 
Some(5)] + [None, None, Some(1), Some(2), Some(3), Some(4), Some(5)] .iter() .map(|v| v.map(i256::from_i128)) .collect(), ); // limit test_sort_decimal256_array( - vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), + data.clone(), None, Some(3), - vec![None, None, Some(1)] + [None, None, Some(1)] .iter() .map(|v| v.map(i256::from_i128)) .collect(), ); // limit descending test_sort_decimal256_array( - vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), + data.clone(), Some(SortOptions { descending: true, nulls_first: false, }), Some(3), - vec![Some(5), Some(4), Some(3)] + [Some(5), Some(4), Some(3)] .iter() .map(|v| v.map(i256::from_i128)) .collect(), ); // limit descending null_first test_sort_decimal256_array( - vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), + data.clone(), Some(SortOptions { descending: true, nulls_first: true, }), Some(3), - vec![None, None, Some(5)] + [None, None, Some(5)] .iter() .map(|v| v.map(i256::from_i128)) .collect(), ); // limit null_first test_sort_decimal256_array( - vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), + data, Some(SortOptions { descending: false, nulls_first: true, }), Some(3), - vec![None, None, Some(1)] + [None, None, Some(1)] .iter() .map(|v| v.map(i256::from_i128)) .collect(), @@ -2915,7 +2411,7 @@ mod tests { nulls_first: false, }), None, - vec![2, 4, 1, 5, 3, 0], + vec![2, 4, 1, 5, 0, 3], ); test_sort_to_indices_string_arrays( @@ -2949,7 +2445,7 @@ mod tests { nulls_first: true, }), None, - vec![3, 0, 2, 4, 1, 5], + vec![0, 3, 2, 4, 1, 5], ); test_sort_to_indices_string_arrays( @@ -2966,7 +2462,7 @@ mod tests { nulls_first: true, }), Some(3), - vec![3, 0, 2], + vec![0, 3, 2], ); // valid values less than limit with extra nulls @@ -4465,4 +3961,57 @@ mod tests { vec![None, None, None, Some(5.1), Some(5.1), Some(3.0), Some(1.2)], ); } + + #[test] + fn test_lexicographic_comparator_null_dict_values() { + let values = Int32Array::new( + vec![1, 2, 3, 4].into(), + Some(NullBuffer::from(vec![true, false, false, true])), + ); + let keys = Int32Array::new( + vec![0, 1, 53, 3].into(), + Some(NullBuffer::from(vec![true, true, false, true])), + ); + // [1, NULL, NULL, 4] + let dict = DictionaryArray::new(keys, Arc::new(values)); + + let comparator = LexicographicalComparator::try_new(&[SortColumn { + values: Arc::new(dict), + options: None, + }]) + .unwrap(); + // 1.cmp(NULL) + assert_eq!(comparator.compare(0, 1), Ordering::Greater); + // NULL.cmp(NULL) + assert_eq!(comparator.compare(2, 1), Ordering::Equal); + // NULL.cmp(4) + assert_eq!(comparator.compare(2, 3), Ordering::Less); + } + + #[test] + fn sort_list_equal() { + let a = { + let mut builder = FixedSizeListBuilder::new(Int64Builder::new(), 2); + for value in [[1, 5], [0, 3], [1, 3]] { + builder.values().append_slice(&value); + builder.append(true); + } + builder.finish() + }; + + let sort_indices = sort_to_indices(&a, None, None).unwrap(); + assert_eq!(sort_indices.values(), &[1, 2, 0]); + + let a = { + let mut builder = ListBuilder::new(Int64Builder::new()); + for value in [[1, 5], [0, 3], [1, 3]] { + builder.values().append_slice(&value); + builder.append(true); + } + builder.finish() + }; + + let sort_indices = sort_to_indices(&a, None, None).unwrap(); + assert_eq!(sort_indices.values(), &[1, 2, 0]); + } } 
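With `sort_to_indices` now ranking child values (`child_rank`) and comparing rank slices, list and fixed-size-list columns can be sorted regardless of their child type. A brief standalone sketch mirroring the `sort_list_equal` test added above (illustrative only, not part of the patch itself):

```rust
use arrow_array::builder::{Int64Builder, ListBuilder};
use arrow_ord::sort::sort_to_indices;

fn main() {
    // [[1, 5], [0, 3], [1, 3]] orders lexicographically as [0, 3] < [1, 3] < [1, 5]
    let mut builder = ListBuilder::new(Int64Builder::new());
    for value in [[1, 5], [0, 3], [1, 3]] {
        builder.values().append_slice(&value);
        builder.append(true);
    }
    let list = builder.finish();

    let indices = sort_to_indices(&list, None, None).unwrap();
    assert_eq!(indices.values(), &[1, 2, 0]);
}
```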
diff --git a/arrow-pyarrow-integration-testing/Cargo.toml b/arrow-pyarrow-integration-testing/Cargo.toml index 50987b03ca9e..8c60c086c29a 100644 --- a/arrow-pyarrow-integration-testing/Cargo.toml +++ b/arrow-pyarrow-integration-testing/Cargo.toml @@ -34,4 +34,4 @@ crate-type = ["cdylib"] [dependencies] arrow = { path = "../arrow", features = ["pyarrow"] } -pyo3 = { version = "0.19", features = ["extension-module"] } +pyo3 = { version = "0.20", features = ["extension-module"] } diff --git a/arrow-pyarrow-integration-testing/pyproject.toml b/arrow-pyarrow-integration-testing/pyproject.toml index d75f8de1ac4c..d85db24c2e18 100644 --- a/arrow-pyarrow-integration-testing/pyproject.toml +++ b/arrow-pyarrow-integration-testing/pyproject.toml @@ -16,7 +16,7 @@ # under the License. [build-system] -requires = ["maturin"] +requires = ["maturin>=1.0,<2.0"] build-backend = "maturin" dependencies = ["pyarrow>=1"] diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs index 730409b3777e..a53447b53c31 100644 --- a/arrow-pyarrow-integration-testing/src/lib.rs +++ b/arrow-pyarrow-integration-testing/src/lib.rs @@ -21,6 +21,8 @@ use std::sync::Arc; use arrow::array::new_empty_array; +use arrow::record_batch::{RecordBatchIterator, RecordBatchReader}; +use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::wrap_pyfunction; @@ -49,7 +51,7 @@ fn double(array: &PyAny, py: Python) -> PyResult { .ok_or_else(|| ArrowError::ParseError("Expects an int64".to_string())) .map_err(to_py_err)?; - let array = kernels::arithmetic::add(array, array).map_err(to_py_err)?; + let array = kernels::numeric::add(array, array).map_err(to_py_err)?; // export array.to_data().to_pyarrow(py) @@ -140,6 +142,31 @@ fn round_trip_record_batch_reader( Ok(obj) } +#[pyfunction] +fn reader_return_errors(obj: PyArrowType) -> PyResult<()> { + // This makes sure we can correctly consume a RBR and return the error, + // ensuring the error can live beyond the lifetime of the RBR. 
+ let batches = obj.0.collect::, ArrowError>>(); + match batches { + Ok(_) => Ok(()), + Err(err) => Err(PyValueError::new_err(err.to_string())), + } +} + +#[pyfunction] +fn boxed_reader_roundtrip( + obj: PyArrowType, +) -> PyArrowType> { + let schema = obj.0.schema(); + let batches = obj + .0 + .collect::, ArrowError>>() + .unwrap(); + let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema); + let reader: Box = Box::new(reader); + PyArrowType(reader) +} + #[pymodule] fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(double))?; @@ -153,5 +180,7 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> m.add_wrapped(wrap_pyfunction!(round_trip_array))?; m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?; m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?; + m.add_wrapped(wrap_pyfunction!(reader_return_errors))?; + m.add_wrapped(wrap_pyfunction!(boxed_reader_roundtrip))?; Ok(()) } diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py index a7c6b34a4474..1748fd3ffb6b 100644 --- a/arrow-pyarrow-integration-testing/tests/test_sql.py +++ b/arrow-pyarrow-integration-testing/tests/test_sql.py @@ -393,6 +393,23 @@ def test_sparse_union_python(): del a del b +def test_tensor_array(): + tensor_type = pa.fixed_shape_tensor(pa.float32(), [2, 3]) + inner = pa.array([float(x) for x in range(1, 7)] + [None] * 12, pa.float32()) + storage = pa.FixedSizeListArray.from_arrays(inner, 6) + f32_array = pa.ExtensionArray.from_storage(tensor_type, storage) + + # Round-tripping as an array gives back storage type, because arrow-rs has + # no notion of extension types. + b = rust.round_trip_array(f32_array) + assert b == f32_array.storage + + batch = pa.record_batch([f32_array], ["tensor"]) + b = rust.round_trip_record_batch(batch) + assert b == batch + + del b + def test_record_batch_reader(): """ Python -> Rust -> Python @@ -409,6 +426,33 @@ def test_record_batch_reader(): got_batches = list(b) assert got_batches == batches + # Also try the boxed reader variant + a = pa.RecordBatchReader.from_batches(schema, batches) + b = rust.boxed_reader_roundtrip(a) + assert b.schema == schema + got_batches = list(b) + assert got_batches == batches + +def test_record_batch_reader_error(): + schema = pa.schema([('ints', pa.list_(pa.int32()))]) + + def iter_batches(): + yield pa.record_batch([[[1], [2, 42]]], schema) + raise ValueError("test error") + + reader = pa.RecordBatchReader.from_batches(schema, iter_batches()) + + with pytest.raises(ValueError, match="test error"): + rust.reader_return_errors(reader) + + # Due to a long-standing oversight, PyArrow allows binary values in schema + # metadata that are not valid UTF-8. This is not allowed in Rust, but we + # make sure we error and not panic here. + schema = schema.with_metadata({"key": b"\xff"}) + reader = pa.RecordBatchReader.from_batches(schema, iter_batches()) + with pytest.raises(ValueError, match="invalid utf-8"): + rust.round_trip_record_batch_reader(reader) + def test_reject_other_classes(): # Arbitrary type that is not a PyArrow type not_pyarrow = ["hello"] diff --git a/arrow-row/src/dictionary.rs b/arrow-row/src/dictionary.rs deleted file mode 100644 index 6c3ee9e18ced..000000000000 --- a/arrow-row/src/dictionary.rs +++ /dev/null @@ -1,296 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::fixed::{FixedLengthEncoding, FromSlice}; -use crate::interner::{Interned, OrderPreservingInterner}; -use crate::{null_sentinel, Row, Rows}; -use arrow_array::builder::*; -use arrow_array::cast::*; -use arrow_array::types::*; -use arrow_array::*; -use arrow_buffer::{ArrowNativeType, MutableBuffer, ToByteSlice}; -use arrow_data::{ArrayData, ArrayDataBuilder}; -use arrow_schema::{ArrowError, DataType, SortOptions}; -use std::collections::hash_map::Entry; -use std::collections::HashMap; - -/// Computes the dictionary mapping for the given dictionary values -pub fn compute_dictionary_mapping( - interner: &mut OrderPreservingInterner, - values: &ArrayRef, -) -> Vec> { - downcast_primitive_array! { - values => interner - .intern(values.iter().map(|x| x.map(|x| x.encode()))), - DataType::Binary => { - let iter = as_generic_binary_array::(values).iter(); - interner.intern(iter) - } - DataType::LargeBinary => { - let iter = as_generic_binary_array::(values).iter(); - interner.intern(iter) - } - DataType::Utf8 => { - let iter = values.as_string::().iter().map(|x| x.map(|x| x.as_bytes())); - interner.intern(iter) - } - DataType::LargeUtf8 => { - let iter = values.as_string::().iter().map(|x| x.map(|x| x.as_bytes())); - interner.intern(iter) - } - _ => unreachable!(), - } -} - -/// Encode dictionary values not preserving the dictionary encoding -pub fn encode_dictionary_values( - data: &mut [u8], - offsets: &mut [usize], - column: &DictionaryArray, - values: &Rows, - null: &Row<'_>, -) { - for (offset, k) in offsets.iter_mut().skip(1).zip(column.keys()) { - let row = match k { - Some(k) => values.row(k.as_usize()).data, - None => null.data, - }; - let end_offset = *offset + row.len(); - data[*offset..end_offset].copy_from_slice(row); - *offset = end_offset; - } -} - -/// Dictionary types are encoded as -/// -/// - single `0_u8` if null -/// - the bytes of the corresponding normalized key including the null terminator -pub fn encode_dictionary( - data: &mut [u8], - offsets: &mut [usize], - column: &DictionaryArray, - normalized_keys: &[Option<&[u8]>], - opts: SortOptions, -) { - for (offset, k) in offsets.iter_mut().skip(1).zip(column.keys()) { - match k.and_then(|k| normalized_keys[k.as_usize()]) { - Some(normalized_key) => { - let end_offset = *offset + 1 + normalized_key.len(); - data[*offset] = 1; - data[*offset + 1..end_offset].copy_from_slice(normalized_key); - // Negate if descending - if opts.descending { - data[*offset..end_offset].iter_mut().for_each(|v| *v = !*v) - } - *offset = end_offset; - } - None => { - data[*offset] = null_sentinel(opts); - *offset += 1; - } - } - } -} - -macro_rules! 
decode_primitive_helper { - ($t:ty, $values: ident, $data_type:ident) => { - decode_primitive::<$t>(&$values, $data_type.clone()) - }; -} - -/// Decodes a string array from `rows` with the provided `options` -/// -/// # Safety -/// -/// `interner` must contain valid data for the provided `value_type` -pub unsafe fn decode_dictionary( - interner: &OrderPreservingInterner, - value_type: &DataType, - options: SortOptions, - rows: &mut [&[u8]], -) -> Result, ArrowError> { - let len = rows.len(); - let mut dictionary: HashMap = HashMap::with_capacity(len); - - let null_sentinel = null_sentinel(options); - - // If descending, the null terminator will have been negated - let null_terminator = match options.descending { - true => 0xFF, - false => 0_u8, - }; - - let mut null_builder = BooleanBufferBuilder::new(len); - let mut keys = BufferBuilder::::new(len); - let mut values = Vec::with_capacity(len); - let mut null_count = 0; - let mut key_scratch = Vec::new(); - - for row in rows { - if row[0] == null_sentinel { - null_builder.append(false); - null_count += 1; - *row = &row[1..]; - keys.append(K::Native::default()); - continue; - } - - let key_offset = row - .iter() - .skip(1) - .position(|x| *x == null_terminator) - .unwrap(); - - // Extract the normalized key including the null terminator - let key = &row[1..key_offset + 2]; - *row = &row[key_offset + 2..]; - - let interned = match options.descending { - true => { - // If options.descending the normalized key will have been - // negated we must first reverse this - key_scratch.clear(); - key_scratch.extend_from_slice(key); - key_scratch.iter_mut().for_each(|o| *o = !*o); - interner.lookup(&key_scratch).unwrap() - } - false => interner.lookup(key).unwrap(), - }; - - let k = match dictionary.entry(interned) { - Entry::Vacant(v) => { - let k = values.len(); - values.push(interner.value(interned)); - let key = K::Native::from_usize(k) - .ok_or(ArrowError::DictionaryKeyOverflowError)?; - *v.insert(key) - } - Entry::Occupied(o) => *o.get(), - }; - - keys.append(k); - null_builder.append(true); - } - - let child = downcast_primitive! 
{ - value_type => (decode_primitive_helper, values, value_type), - DataType::Null => NullArray::new(values.len()).into_data(), - DataType::Boolean => decode_bool(&values), - DataType::Utf8 => decode_string::(&values), - DataType::LargeUtf8 => decode_string::(&values), - DataType::Binary => decode_binary::(&values), - DataType::LargeBinary => decode_binary::(&values), - _ => unreachable!(), - }; - - let data_type = - DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(value_type.clone())); - - let builder = ArrayDataBuilder::new(data_type) - .len(len) - .null_bit_buffer(Some(null_builder.into())) - .null_count(null_count) - .add_buffer(keys.finish()) - .add_child_data(child); - - Ok(DictionaryArray::from(builder.build_unchecked())) -} - -/// Decodes a binary array from dictionary values -/// -/// # Safety -/// -/// Values must be valid UTF-8 -fn decode_binary(values: &[&[u8]]) -> ArrayData { - let capacity = values.iter().map(|x| x.len()).sum(); - let mut builder = GenericBinaryBuilder::::with_capacity(values.len(), capacity); - for v in values { - builder.append_value(v) - } - builder.finish().into_data() -} - -/// Decodes a string array from dictionary values -/// -/// # Safety -/// -/// Values must be valid UTF-8 -unsafe fn decode_string(values: &[&[u8]]) -> ArrayData { - let d = match O::IS_LARGE { - true => DataType::LargeUtf8, - false => DataType::Utf8, - }; - - decode_binary::(values) - .into_builder() - .data_type(d) - .build_unchecked() -} - -/// Decodes a boolean array from dictionary values -fn decode_bool(values: &[&[u8]]) -> ArrayData { - let mut builder = BooleanBufferBuilder::new(values.len()); - for value in values { - builder.append(bool::decode([value[0]])) - } - - let builder = ArrayDataBuilder::new(DataType::Boolean) - .len(values.len()) - .add_buffer(builder.into()); - - // SAFETY: Buffers correct length - unsafe { builder.build_unchecked() } -} - -/// Decodes a fixed length type array from dictionary values -/// -/// # Safety -/// -/// `data_type` must be appropriate native type for `T` -unsafe fn decode_fixed( - values: &[&[u8]], - data_type: DataType, -) -> ArrayData { - let mut buffer = MutableBuffer::new(std::mem::size_of::() * values.len()); - - for value in values { - let value = T::Encoded::from_slice(value, false); - buffer.push(T::decode(value)) - } - - let builder = ArrayDataBuilder::new(data_type) - .len(values.len()) - .add_buffer(buffer.into()); - - // SAFETY: Buffers correct length - builder.build_unchecked() -} - -/// Decodes a `PrimitiveArray` from dictionary values -fn decode_primitive( - values: &[&[u8]], - data_type: DataType, -) -> ArrayData -where - T::Native: FixedLengthEncoding, -{ - assert!(PrimitiveArray::::is_compatible(&data_type)); - - // SAFETY: - // Validated data type above - unsafe { decode_fixed::(values, data_type) } -} diff --git a/arrow-row/src/interner.rs b/arrow-row/src/interner.rs deleted file mode 100644 index 1c71b6a55217..000000000000 --- a/arrow-row/src/interner.rs +++ /dev/null @@ -1,430 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use hashbrown::hash_map::RawEntryMut; -use hashbrown::HashMap; -use std::num::NonZeroU32; -use std::ops::Index; - -/// An interned value -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -pub struct Interned(NonZeroU32); // We use NonZeroU32 so that `Option` is 32 bits - -/// A byte array interner that generates normalized keys that are sorted with respect -/// to the interned values, e.g. `inter(a) < intern(b) => a < b` -#[derive(Debug, Default)] -pub struct OrderPreservingInterner { - /// Provides a lookup from [`Interned`] to the normalized key - keys: InternBuffer, - /// Provides a lookup from [`Interned`] to the normalized value - values: InternBuffer, - /// Key allocation data structure - bucket: Box, - - // A hash table used to perform faster re-keying, and detect duplicates - hasher: ahash::RandomState, - lookup: HashMap, -} - -impl OrderPreservingInterner { - /// Interns an iterator of values returning a list of [`Interned`] which can be - /// used with [`Self::normalized_key`] to retrieve the normalized keys with a - /// lifetime not tied to the mutable borrow passed to this method - pub fn intern(&mut self, input: I) -> Vec> - where - I: IntoIterator>, - V: AsRef<[u8]>, - { - let iter = input.into_iter(); - let capacity = iter.size_hint().0; - let mut out = Vec::with_capacity(capacity); - - // (index in output, hash value, value) - let mut to_intern: Vec<(usize, u64, V)> = Vec::with_capacity(capacity); - let mut to_intern_len = 0; - - for (idx, item) in iter.enumerate() { - let value: V = match item { - Some(value) => value, - None => { - out.push(None); - continue; - } - }; - - let v = value.as_ref(); - let hash = self.hasher.hash_one(v); - let entry = self - .lookup - .raw_entry_mut() - .from_hash(hash, |a| &self.values[*a] == v); - - match entry { - RawEntryMut::Occupied(o) => out.push(Some(*o.key())), - RawEntryMut::Vacant(_) => { - // Push placeholder - out.push(None); - to_intern_len += v.len(); - to_intern.push((idx, hash, value)); - } - }; - } - - to_intern.sort_unstable_by(|(_, _, a), (_, _, b)| a.as_ref().cmp(b.as_ref())); - - self.keys.offsets.reserve(to_intern.len()); - self.keys.values.reserve(to_intern.len()); // Approximation - self.values.offsets.reserve(to_intern.len()); - self.values.values.reserve(to_intern_len); - - for (idx, hash, value) in to_intern { - let val = value.as_ref(); - - let entry = self - .lookup - .raw_entry_mut() - .from_hash(hash, |a| &self.values[*a] == val); - - match entry { - RawEntryMut::Occupied(o) => { - out[idx] = Some(*o.key()); - } - RawEntryMut::Vacant(v) => { - let val = value.as_ref(); - self.bucket - .insert(&mut self.values, val, &mut self.keys.values); - self.keys.values.push(0); - let interned = self.keys.append(); - - let hasher = &mut self.hasher; - let values = &self.values; - v.insert_with_hasher(hash, interned, (), |key| { - hasher.hash_one(&values[*key]) - }); - out[idx] = Some(interned); - } - } - } - - out - } - - /// Returns a null-terminated byte array that can be compared against other normalized_key - /// returned by this instance, to establish ordering of the interned values - pub fn 
normalized_key(&self, key: Interned) -> &[u8] { - &self.keys[key] - } - - /// Converts a normalized key returned by [`Self::normalized_key`] to [`Interned`] - /// returning `None` if it cannot be found - pub fn lookup(&self, normalized_key: &[u8]) -> Option { - let len = normalized_key.len(); - if len <= 1 { - return None; - } - - let mut bucket = self.bucket.as_ref(); - if len > 2 { - for v in normalized_key.iter().take(len - 2) { - if *v == 255 { - bucket = bucket.next.as_ref()?; - } else { - let bucket_idx = v.checked_sub(1)?; - bucket = bucket.slots.get(bucket_idx as usize)?.child.as_ref()?; - } - } - } - - let slot_idx = normalized_key[len - 2].checked_sub(2)?; - Some(bucket.slots.get(slot_idx as usize)?.value) - } - - /// Returns the interned value for a given [`Interned`] - pub fn value(&self, key: Interned) -> &[u8] { - self.values.index(key) - } - - /// Returns the size of this instance in bytes including self - pub fn size(&self) -> usize { - std::mem::size_of::() - + self.keys.buffer_size() - + self.values.buffer_size() - + self.bucket.size() - + self.lookup.capacity() * std::mem::size_of::() - } -} - -/// A buffer of `[u8]` indexed by `[Interned]` -#[derive(Debug)] -struct InternBuffer { - /// Raw values - values: Vec, - /// The ith value is `&values[offsets[i]..offsets[i+1]]` - offsets: Vec, -} - -impl Default for InternBuffer { - fn default() -> Self { - Self { - values: Default::default(), - offsets: vec![0], - } - } -} - -impl InternBuffer { - /// Insert `data` returning the corresponding [`Interned`] - fn insert(&mut self, data: &[u8]) -> Interned { - self.values.extend_from_slice(data); - self.append() - } - - /// Appends the next value based on data written to `self.values` - /// returning the corresponding [`Interned`] - fn append(&mut self) -> Interned { - let idx: u32 = self.offsets.len().try_into().unwrap(); - let key = Interned(NonZeroU32::new(idx).unwrap()); - self.offsets.push(self.values.len()); - key - } - - /// Returns the byte size of the associated buffers - fn buffer_size(&self) -> usize { - self.values.capacity() + self.offsets.capacity() * std::mem::size_of::() - } -} - -impl Index for InternBuffer { - type Output = [u8]; - - fn index(&self, key: Interned) -> &Self::Output { - let index = key.0.get() as usize; - let end = self.offsets[index]; - let start = self.offsets[index - 1]; - // SAFETY: - // self.values is never reduced in size and values appended - // to self.offsets are always less than self.values at the time - unsafe { self.values.get_unchecked(start..end) } - } -} - -/// A slot corresponds to a single byte-value in the generated normalized key -/// -/// It may contain a value, if not the first slot, and may contain a child [`Bucket`] representing -/// the next byte in the generated normalized key -#[derive(Debug, Clone)] -struct Slot { - value: Interned, - /// Child values less than `self.value` if any - child: Option>, -} - -/// Bucket is the root of the data-structure used to allocate normalized keys -/// -/// In particular it needs to generate keys that -/// -/// * Contain no `0` bytes other than the null terminator -/// * Compare lexicographically in the same manner as the encoded `data` -/// -/// The data structure consists of 254 slots, each of which can store a value. -/// Additionally each slot may contain a child bucket, containing values smaller -/// than the value within the slot. 
-/// -/// Each bucket also may contain a child bucket, containing values greater than -/// all values in the current bucket -/// -/// # Allocation Strategy -/// -/// The contiguous slice of slots containing values is searched to find the insertion -/// point for the new value, according to the sort order. -/// -/// If the insertion position exceeds 254, the number of slots, the value is inserted -/// into the child bucket of the current bucket. -/// -/// If the insertion position already contains a value, the value is inserted into the -/// child bucket of that slot. -/// -/// If the slot is not occupied, the value is inserted into that slot. -/// -/// The final key consists of the slot indexes visited incremented by 1, -/// with the final value incremented by 2, followed by a null terminator. -/// -/// Consider the case of the integers `[8, 6, 5, 7]` inserted in that order -/// -/// ```ignore -/// 8: &[2, 0] -/// 6: &[1, 2, 0] -/// 5: &[1, 1, 2, 0] -/// 7: &[1, 3, 0] -/// ``` -/// -/// Note: this allocation strategy is optimised for interning values in sorted order -/// -#[derive(Debug, Clone)] -struct Bucket { - slots: Vec, - /// Bucket containing values larger than all of `slots` - next: Option>, -} - -impl Default for Bucket { - fn default() -> Self { - Self { - slots: Vec::with_capacity(254), - next: None, - } - } -} - -impl Bucket { - /// Insert `data` into this bucket or one of its children, appending the - /// normalized key to `out` as it is constructed - /// - /// # Panics - /// - /// Panics if the value already exists - fn insert(&mut self, values_buf: &mut InternBuffer, data: &[u8], out: &mut Vec) { - let slots_len = self.slots.len() as u8; - // We optimise the case of inserting a value directly after those already inserted - // as [`OrderPreservingInterner::intern`] sorts values prior to interning them - match self.slots.last() { - Some(slot) => { - if &values_buf[slot.value] < data { - if slots_len == 254 { - out.push(255); - self.next - .get_or_insert_with(Default::default) - .insert(values_buf, data, out) - } else { - out.push(slots_len + 2); - let value = values_buf.insert(data); - self.slots.push(Slot { value, child: None }); - } - } else { - // Find insertion point - match self - .slots - .binary_search_by(|slot| values_buf[slot.value].cmp(data)) - { - Ok(_) => unreachable!("value already exists"), - Err(idx) => { - out.push(idx as u8 + 1); - self.slots[idx] - .child - .get_or_insert_with(Default::default) - .insert(values_buf, data, out) - } - } - } - } - None => { - out.push(2); - let value = values_buf.insert(data); - self.slots.push(Slot { value, child: None }) - } - } - } - - /// Returns the size of this instance in bytes - fn size(&self) -> usize { - std::mem::size_of::() - + self.slots.capacity() * std::mem::size_of::() - + self.next.as_ref().map(|x| x.size()).unwrap_or_default() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use rand::prelude::*; - - // Clippy isn't smart enough to understand dropping mutability - #[allow(clippy::needless_collect)] - fn test_intern_values(values: &[u64]) { - let mut interner = OrderPreservingInterner::default(); - - // Intern a single value at a time to check ordering - let interned: Vec<_> = values - .iter() - .flat_map(|v| interner.intern([Some(&v.to_be_bytes())])) - .map(Option::unwrap) - .collect(); - - for (value, interned) in values.iter().zip(&interned) { - assert_eq!(interner.value(*interned), &value.to_be_bytes()); - } - - let normalized_keys: Vec<_> = interned - .iter() - .map(|x| interner.normalized_key(*x)) - 
.collect(); - - for (interned, normalized) in interned.iter().zip(&normalized_keys) { - assert_eq!(*interned, interner.lookup(normalized).unwrap()); - } - - for (i, a) in normalized_keys.iter().enumerate() { - for (j, b) in normalized_keys.iter().enumerate() { - let interned_cmp = a.cmp(b); - let values_cmp = values[i].cmp(&values[j]); - assert_eq!( - interned_cmp, values_cmp, - "({:?} vs {:?}) vs ({} vs {})", - a, b, values[i], values[j] - ) - } - } - } - - #[test] - #[cfg_attr(miri, ignore)] - fn test_interner() { - test_intern_values(&[8, 6, 5, 7]); - - let mut values: Vec<_> = (0_u64..2000).collect(); - test_intern_values(&values); - - let mut rng = thread_rng(); - values.shuffle(&mut rng); - test_intern_values(&values); - } - - #[test] - fn test_intern_duplicates() { - // Unsorted with duplicates - let values = vec![0_u8, 1, 8, 4, 1, 0]; - let mut interner = OrderPreservingInterner::default(); - - let interned = interner.intern(values.iter().map(std::slice::from_ref).map(Some)); - let interned: Vec<_> = interned.into_iter().map(Option::unwrap).collect(); - - assert_eq!(interned[0], interned[5]); - assert_eq!(interned[1], interned[4]); - assert!( - interner.normalized_key(interned[0]) < interner.normalized_key(interned[1]) - ); - assert!( - interner.normalized_key(interned[1]) < interner.normalized_key(interned[2]) - ); - assert!( - interner.normalized_key(interned[1]) < interner.normalized_key(interned[3]) - ); - assert!( - interner.normalized_key(interned[3]) < interner.normalized_key(interned[2]) - ); - } -} diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index e8c5ff708d55..1fb4e1de7ac2 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -21,7 +21,7 @@ //! using [`memcmp`] under the hood, or used in [non-comparison sorts] such as [radix sort]. //! This makes the row format ideal for implementing efficient multi-column sorting, //! grouping, aggregation, windowing and more, as described in more detail -//! [here](https://arrow.apache.org/blog/2022/11/07/multi-column-sorts-in-arrow-rust-part-1/). +//! [in this blog post](https://arrow.apache.org/blog/2022/11/07/multi-column-sorts-in-arrow-rust-part-1/). //! //! For example, given three input [`Array`], [`RowConverter`] creates byte //! sequences that [compare] the same as when using [`lexsort`]. @@ -61,7 +61,7 @@ //! let arrays = vec![a1, a2]; //! //! // Convert arrays to rows -//! let mut converter = RowConverter::new(vec![ +//! let converter = RowConverter::new(vec![ //! SortField::new(DataType::Int32), //! SortField::new(DataType::Utf8), //! ]).unwrap(); @@ -109,7 +109,7 @@ //! .iter() //! .map(|a| SortField::new(a.data_type().clone())) //! .collect(); -//! let mut converter = RowConverter::new(fields).unwrap(); +//! let converter = RowConverter::new(fields).unwrap(); //! let rows = converter.convert_columns(arrays).unwrap(); //! let mut sort: Vec<_> = rows.iter().enumerate().collect(); //! 
sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b)); @@ -130,22 +130,16 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use arrow_array::cast::*; +use arrow_array::types::ArrowDictionaryKeyType; use arrow_array::*; use arrow_buffer::ArrowNativeType; use arrow_data::ArrayDataBuilder; use arrow_schema::*; -use crate::dictionary::{ - compute_dictionary_mapping, decode_dictionary, encode_dictionary, - encode_dictionary_values, -}; use crate::fixed::{decode_bool, decode_fixed_size_binary, decode_primitive}; -use crate::interner::OrderPreservingInterner; use crate::variable::{decode_binary, decode_string}; -mod dictionary; mod fixed; -mod interner; mod list; mod variable; @@ -232,13 +226,13 @@ mod variable; /// A non-null, non-empty byte array is encoded as `2_u8` followed by the byte array /// encoded using a block based scheme described below. /// -/// The byte array is broken up into 32-byte blocks, each block is written in turn +/// The byte array is broken up into fixed-width blocks, each block is written in turn /// to the output, followed by `0xFF_u8`. The final block is padded to 32-bytes /// with `0_u8` and written to the output, followed by the un-padded length in bytes -/// of this final block as a `u8`. +/// of this final block as a `u8`. The first 4 blocks have a length of 8, with subsequent +/// blocks using a length of 32, this is to reduce space amplification for small strings. /// -/// Note the following example encodings use a block size of 4 bytes, -/// as opposed to 32 bytes for brevity: +/// Note the following example encodings use a block size of 4 bytes for brevity: /// /// ```text /// ┌───┬───┬───┬───┬───┬───┐ @@ -271,53 +265,7 @@ mod variable; /// /// ## Dictionary Encoding /// -/// [`RowConverter`] needs to support converting dictionary encoded arrays with unsorted, and -/// potentially distinct dictionaries. One simple mechanism to avoid this would be to reverse -/// the dictionary encoding, and encode the array values directly, however, this would lose -/// the benefits of dictionary encoding to reduce memory and CPU consumption. -/// -/// As such the [`RowConverter`] creates an order-preserving mapping -/// for each dictionary encoded column, which allows new dictionary -/// values to be added whilst preserving the sort order. -/// -/// A null dictionary value is encoded as `0_u8`. 
-/// -/// A non-null dictionary value is encoded as `1_u8` followed by a null-terminated byte array -/// key determined by the order-preserving dictionary encoding -/// -/// ```text -/// ┌──────────┐ ┌─────┐ -/// │ "Bar" │ ───────────────▶│ 01 │ -/// └──────────┘ └─────┘ -/// ┌──────────┐ ┌─────┬─────┐ -/// │"Fabulous"│ ───────────────▶│ 01 │ 02 │ -/// └──────────┘ └─────┴─────┘ -/// ┌──────────┐ ┌─────┐ -/// │ "Soup" │ ───────────────▶│ 05 │ -/// └──────────┘ └─────┘ -/// ┌──────────┐ ┌─────┐ -/// │ "ZZ" │ ───────────────▶│ 07 │ -/// └──────────┘ └─────┘ -/// -/// Example Order Preserving Mapping -/// ``` -/// Using the map above, the corresponding row format will be -/// -/// ```text -/// ┌─────┬─────┬─────┬─────┐ -/// "Fabulous" │ 01 │ 03 │ 05 │ 00 │ -/// └─────┴─────┴─────┴─────┘ -/// -/// ┌─────┬─────┬─────┐ -/// "ZZ" │ 01 │ 07 │ 00 │ -/// └─────┴─────┴─────┘ -/// -/// ┌─────┐ -/// NULL │ 00 │ -/// └─────┘ -/// -/// Input Row Format -/// ``` +/// Dictionaries are hydrated to their underlying values /// /// ## Struct Encoding /// @@ -426,15 +374,9 @@ pub struct RowConverter { enum Codec { /// No additional codec state is necessary Stateless, - /// The interner used to encode dictionary values - /// - /// Used when preserving the dictionary encoding - Dictionary(OrderPreservingInterner), /// A row converter for the dictionary values /// and the encoding of a row containing only nulls - /// - /// Used when not preserving dictionary encoding - DictionaryValues(RowConverter, OwnedRow), + Dictionary(RowConverter, OwnedRow), /// A row converter for the child fields /// and the encoding of a row containing only nulls Struct(RowConverter, OwnedRow), @@ -445,25 +387,22 @@ enum Codec { impl Codec { fn new(sort_field: &SortField) -> Result { match &sort_field.data_type { - DataType::Dictionary(_, values) => match sort_field.preserve_dictionaries { - true => Ok(Self::Dictionary(Default::default())), - false => { - let sort_field = SortField::new_with_options( - values.as_ref().clone(), - sort_field.options, - ); + DataType::Dictionary(_, values) => { + let sort_field = SortField::new_with_options( + values.as_ref().clone(), + sort_field.options, + ); - let mut converter = RowConverter::new(vec![sort_field])?; - let null_array = new_null_array(values.as_ref(), 1); - let nulls = converter.convert_columns(&[null_array])?; + let converter = RowConverter::new(vec![sort_field])?; + let null_array = new_null_array(values.as_ref(), 1); + let nulls = converter.convert_columns(&[null_array])?; - let owned = OwnedRow { - data: nulls.buffer.into(), - config: nulls.config, - }; - Ok(Self::DictionaryValues(converter, owned)) - } - }, + let owned = OwnedRow { + data: nulls.buffer.into(), + config: nulls.config, + }; + Ok(Self::Dictionary(converter, owned)) + } d if !d.is_nested() => Ok(Self::Stateless), DataType::List(f) | DataType::LargeList(f) => { // The encoded contents will be inverted if descending is set to true @@ -490,7 +429,7 @@ impl Codec { }) .collect(); - let mut converter = RowConverter::new(sort_fields)?; + let converter = RowConverter::new(sort_fields)?; let nulls: Vec<_> = f.iter().map(|x| new_null_array(x.data_type(), 1)).collect(); @@ -509,32 +448,13 @@ impl Codec { } } - fn encoder(&mut self, array: &dyn Array) -> Result, ArrowError> { + fn encoder(&self, array: &dyn Array) -> Result, ArrowError> { match self { Codec::Stateless => Ok(Encoder::Stateless), - Codec::Dictionary(interner) => { - let values = downcast_dictionary_array! 
{ - array => array.values(), - _ => unreachable!() - }; - - let mapping = compute_dictionary_mapping(interner, values) - .into_iter() - .map(|maybe_interned| { - maybe_interned.map(|interned| interner.normalized_key(interned)) - }) - .collect(); - - Ok(Encoder::Dictionary(mapping)) - } - Codec::DictionaryValues(converter, nulls) => { - let values = downcast_dictionary_array! { - array => array.values(), - _ => unreachable!() - }; - - let rows = converter.convert_columns(&[values.clone()])?; - Ok(Encoder::DictionaryValues(rows, nulls.row())) + Codec::Dictionary(converter, nulls) => { + let values = array.as_any_dictionary().values().clone(); + let rows = converter.convert_columns(&[values])?; + Ok(Encoder::Dictionary(rows, nulls.row())) } Codec::Struct(converter, null) => { let v = as_struct_array(array); @@ -556,10 +476,7 @@ impl Codec { fn size(&self) -> usize { match self { Codec::Stateless => 0, - Codec::Dictionary(interner) => interner.size(), - Codec::DictionaryValues(converter, nulls) => { - converter.size() + nulls.data.len() - } + Codec::Dictionary(converter, nulls) => converter.size() + nulls.data.len(), Codec::Struct(converter, nulls) => converter.size() + nulls.data.len(), Codec::List(converter) => converter.size(), } @@ -570,10 +487,8 @@ impl Codec { enum Encoder<'a> { /// No additional encoder state is necessary Stateless, - /// The mapping from dictionary keys to normalized keys - Dictionary(Vec>), /// The encoding of the child array and the encoding of a null row - DictionaryValues(Rows, Row<'a>), + Dictionary(Rows, Row<'a>), /// The row encoding of the child arrays and the encoding of a null row /// /// It is necessary to encode to a temporary [`Rows`] to avoid serializing @@ -591,8 +506,6 @@ pub struct SortField { options: SortOptions, /// Data type data_type: DataType, - /// Preserve dictionaries - preserve_dictionaries: bool, } impl SortField { @@ -603,30 +516,7 @@ impl SortField { /// Create a new column with the given data type and [`SortOptions`] pub fn new_with_options(data_type: DataType, options: SortOptions) -> Self { - Self { - options, - data_type, - preserve_dictionaries: true, - } - } - - /// By default dictionaries are preserved as described on [`RowConverter`] - /// - /// However, this process requires maintaining and incrementally updating - /// an order-preserving mapping of dictionary values. This is relatively expensive - /// computationally but reduces the size of the encoded rows, minimising memory - /// usage and potentially yielding faster comparisons. - /// - /// Some applications may wish to instead trade-off space efficiency, for improved - /// encoding performance, by instead encoding dictionary values directly - /// - /// When `preserve_dictionaries` is true, fields will instead be encoded as their - /// underlying value, reversing any dictionary encoding - pub fn preserve_dictionaries(self, preserve_dictionaries: bool) -> Self { - Self { - preserve_dictionaries, - ..self - } + Self { options, data_type } } /// Return size of this instance in bytes. 
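
The changes above replace the order-preserving interner codec with one that converts dictionary values through a nested `RowConverter`, and `SortField::preserve_dictionaries` is removed along with it. The following sketch, which is not part of this patch and assumes only the public `arrow-row` API exercised by the tests later in this diff, shows what that means for callers:

```rust
use std::sync::Arc;
use arrow_array::{types::Int32Type, Array, ArrayRef, DictionaryArray};
use arrow_row::{RowConverter, SortField};
use arrow_schema::DataType;

fn main() {
    // A dictionary-encoded string column; positions 0 and 3 share the value "foo".
    let dict =
        DictionaryArray::<Int32Type>::from_iter([Some("foo"), Some("bar"), None, Some("foo")]);
    let converter =
        RowConverter::new(vec![SortField::new(dict.data_type().clone())]).unwrap();

    let rows = converter
        .convert_columns(&[Arc::new(dict) as ArrayRef])
        .unwrap();
    // Rows now encode the hydrated values, so equal logical values produce equal
    // rows regardless of which dictionary key they were assigned.
    assert_eq!(rows.row(0), rows.row(3));

    // Decoding returns the underlying value type rather than a dictionary.
    let back = converter.convert_rows(&rows).unwrap();
    assert_eq!(back[0].data_type(), &DataType::Utf8);
}
```

As the removed `preserve_dictionaries` documentation noted, encoding the underlying values trades larger encoded rows for not having to maintain an incrementally updated order-preserving mapping.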
@@ -679,7 +569,53 @@ impl RowConverter { /// # Panics /// /// Panics if the schema of `columns` does not match that provided to [`RowConverter::new`] - pub fn convert_columns(&mut self, columns: &[ArrayRef]) -> Result { + pub fn convert_columns(&self, columns: &[ArrayRef]) -> Result { + let num_rows = columns.first().map(|x| x.len()).unwrap_or(0); + let mut rows = self.empty_rows(num_rows, 0); + self.append(&mut rows, columns)?; + Ok(rows) + } + + /// Convert [`ArrayRef`] columns appending to an existing [`Rows`] + /// + /// See [`Row`] for information on when [`Row`] can be compared + /// + /// # Panics + /// + /// Panics if + /// * The schema of `columns` does not match that provided to [`RowConverter::new`] + /// * The provided [`Rows`] were not created by this [`RowConverter`] + /// + /// ``` + /// # use std::sync::Arc; + /// # use std::collections::HashSet; + /// # use arrow_array::cast::AsArray; + /// # use arrow_array::StringArray; + /// # use arrow_row::{Row, RowConverter, SortField}; + /// # use arrow_schema::DataType; + /// # + /// let converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap(); + /// let a1 = StringArray::from(vec!["hello", "world"]); + /// let a2 = StringArray::from(vec!["a", "a", "hello"]); + /// + /// let mut rows = converter.empty_rows(5, 128); + /// converter.append(&mut rows, &[Arc::new(a1)]).unwrap(); + /// converter.append(&mut rows, &[Arc::new(a2)]).unwrap(); + /// + /// let back = converter.convert_rows(&rows).unwrap(); + /// let values: Vec<_> = back[0].as_string::().iter().map(Option::unwrap).collect(); + /// assert_eq!(&values, &["hello", "world", "a", "a", "hello"]); + /// ``` + pub fn append( + &self, + rows: &mut Rows, + columns: &[ArrayRef], + ) -> Result<(), ArrowError> { + assert!( + Arc::ptr_eq(&rows.config.fields, &self.fields), + "rows were not produced by this RowConverter" + ); + if columns.len() != self.fields.len() { return Err(ArrowError::InvalidArgumentError(format!( "Incorrect number of arrays provided to RowConverter, expected {} got {}", @@ -690,7 +626,7 @@ impl RowConverter { let encoders = columns .iter() - .zip(&mut self.codecs) + .zip(&self.codecs) .zip(self.fields.iter()) .map(|((column, codec), field)| { if !column.data_type().equals_datatype(&field.data_type) { @@ -704,12 +640,35 @@ impl RowConverter { }) .collect::, _>>()?; - let config = RowConfig { - fields: Arc::clone(&self.fields), - // Don't need to validate UTF-8 as came from arrow array - validate_utf8: false, - }; - let mut rows = new_empty_rows(columns, &encoders, config); + let write_offset = rows.num_rows(); + let lengths = row_lengths(columns, &encoders); + + // We initialize the offsets shifted down by one row index. + // + // As the rows are appended to the offsets will be incremented to match + // + // For example, consider the case of 3 rows of length 3, 4, and 6 respectively. + // The offsets would be initialized to `0, 0, 3, 7` + // + // Writing the first row entirely would yield `0, 3, 3, 7` + // The second, `0, 3, 7, 7` + // The third, `0, 3, 7, 13` + // + // This would be the final offsets for reading + // + // In this way offsets tracks the position during writing whilst eventually serving + // as identifying the offsets of the written rows + rows.offsets.reserve(lengths.len()); + let mut cur_offset = rows.offsets[write_offset]; + for l in lengths { + rows.offsets.push(cur_offset); + cur_offset = cur_offset.checked_add(l).expect("overflow"); + } + + // Note this will not zero out any trailing data in `rows.buffer`, + // e.g. 
resulting from a call to `Rows::clear`, relying instead on the + // encoders not assuming a zero-initialized buffer + rows.buffer.resize(cur_offset, 0); for ((column, field), encoder) in columns.iter().zip(self.fields.iter()).zip(encoders) @@ -717,7 +676,7 @@ impl RowConverter { // We encode a column at a time to minimise dispatch overheads encode_column( &mut rows.buffer, - &mut rows.offsets, + &mut rows.offsets[write_offset..], column.as_ref(), field.options, &encoder, @@ -731,7 +690,7 @@ impl RowConverter { .for_each(|w| assert!(w[0] <= w[1], "offsets should be monotonic")); } - Ok(rows) + Ok(()) } /// Convert [`Rows`] columns into [`ArrayRef`] @@ -775,7 +734,7 @@ impl RowConverter { /// # use arrow_row::{Row, RowConverter, SortField}; /// # use arrow_schema::DataType; /// # - /// let mut converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap(); + /// let converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap(); /// let array = StringArray::from(vec!["hello", "world", "a", "a", "hello"]); /// /// // Convert to row format and deduplicate @@ -899,6 +858,7 @@ impl Rows { self.offsets.push(self.buffer.len()) } + /// Returns the row at index `row` pub fn row(&self, row: usize) -> Row<'_> { let end = self.offsets[row + 1]; let start = self.offsets[row]; @@ -908,10 +868,18 @@ impl Rows { } } + /// Sets the length of this [`Rows`] to 0 + pub fn clear(&mut self) { + self.offsets.truncate(1); + self.buffer.clear(); + } + + /// Returns the number of [`Row`] in this [`Rows`] pub fn num_rows(&self) -> usize { self.offsets.len() - 1 } + /// Returns an iterator over the [`Row`] in this [`Rows`] pub fn iter(&self) -> RowsIter<'_> { self.into_iter() } @@ -1021,7 +989,7 @@ impl<'a> Eq for Row<'a> {} impl<'a> PartialOrd for Row<'a> { #[inline] fn partial_cmp(&self, other: &Self) -> Option { - self.data.partial_cmp(other.data) + Some(self.cmp(other)) } } @@ -1081,7 +1049,7 @@ impl Eq for OwnedRow {} impl PartialOrd for OwnedRow { #[inline] fn partial_cmp(&self, other: &Self) -> Option { - self.row().partial_cmp(&other.row()) + Some(self.cmp(other)) } } @@ -1116,7 +1084,7 @@ fn null_sentinel(options: SortOptions) -> u8 { } /// Computes the length of each encoded [`Rows`] and returns an empty [`Rows`] -fn new_empty_rows(cols: &[ArrayRef], encoders: &[Encoder], config: RowConfig) -> Rows { +fn row_lengths(cols: &[ArrayRef], encoders: &[Encoder]) -> Vec { use fixed::FixedLengthEncoding; let num_rows = cols.first().map(|x| x.len()).unwrap_or(0); @@ -1156,20 +1124,7 @@ fn new_empty_rows(cols: &[ArrayRef], encoders: &[Encoder], config: RowConfig) -> _ => unreachable!(), } } - Encoder::Dictionary(dict) => { - downcast_dictionary_array! { - array => { - for (v, length) in array.keys().iter().zip(lengths.iter_mut()) { - match v.and_then(|v| dict[v as usize]) { - Some(k) => *length += k.len() + 1, - None => *length += 1, - } - } - } - _ => unreachable!(), - } - } - Encoder::DictionaryValues(values, null) => { + Encoder::Dictionary(values, null) => { downcast_dictionary_array! { array => { for (v, length) in array.keys().iter().zip(lengths.iter_mut()) { @@ -1203,37 +1158,7 @@ fn new_empty_rows(cols: &[ArrayRef], encoders: &[Encoder], config: RowConfig) -> } } - let mut offsets = Vec::with_capacity(num_rows + 1); - offsets.push(0); - - // We initialize the offsets shifted down by one row index. - // - // As the rows are appended to the offsets will be incremented to match - // - // For example, consider the case of 3 rows of length 3, 4, and 6 respectively. 
- // The offsets would be initialized to `0, 0, 3, 7` - // - // Writing the first row entirely would yield `0, 3, 3, 7` - // The second, `0, 3, 7, 7` - // The third, `0, 3, 7, 13` - // - // This would be the final offsets for reading - // - // In this way offsets tracks the position during writing whilst eventually serving - // as identifying the offsets of the written rows - let mut cur_offset = 0_usize; - for l in lengths { - offsets.push(cur_offset); - cur_offset = cur_offset.checked_add(l).expect("overflow"); - } - - let buffer = vec![0_u8; cur_offset]; - - Rows { - buffer, - offsets, - config, - } + lengths } /// Encodes a column to the provided [`Rows`] incrementing the offsets as it progresses @@ -1275,13 +1200,7 @@ fn encode_column( _ => unreachable!(), } } - Encoder::Dictionary(dict) => { - downcast_dictionary_array! { - column => encode_dictionary(data, offsets, column, dict, opts), - _ => unreachable!() - } - } - Encoder::DictionaryValues(values, nulls) => { + Encoder::Dictionary(values, nulls) => { downcast_dictionary_array! { column => encode_dictionary_values(data, offsets, column, values, nulls), _ => unreachable!() @@ -1317,18 +1236,31 @@ fn encode_column( } } +/// Encode dictionary values not preserving the dictionary encoding +pub fn encode_dictionary_values( + data: &mut [u8], + offsets: &mut [usize], + column: &DictionaryArray, + values: &Rows, + null: &Row<'_>, +) { + for (offset, k) in offsets.iter_mut().skip(1).zip(column.keys()) { + let row = match k { + Some(k) => values.row(k.as_usize()).data, + None => null.data, + }; + let end_offset = *offset + row.len(); + data[*offset..end_offset].copy_from_slice(row); + *offset = end_offset; + } +} + macro_rules! decode_primitive_helper { ($t:ty, $rows:ident, $data_type:ident, $options:ident) => { Arc::new(decode_primitive::<$t>($rows, $data_type, $options)) }; } -macro_rules! decode_dictionary_helper { - ($t:ty, $interner:ident, $v:ident, $options:ident, $rows:ident) => { - Arc::new(decode_dictionary::<$t>($interner, $v, $options, $rows)?) - }; -} - /// Decodes a the provided `field` from `rows` /// /// # Safety @@ -1354,20 +1286,11 @@ unsafe fn decode_column( DataType::FixedSizeBinary(size) => Arc::new(decode_fixed_size_binary(rows, size, options)), DataType::Utf8 => Arc::new(decode_string::(rows, options, validate_utf8)), DataType::LargeUtf8 => Arc::new(decode_string::(rows, options, validate_utf8)), + DataType::Dictionary(_, _) => todo!(), _ => unreachable!() } } - Codec::Dictionary(interner) => { - let (k, v) = match &field.data_type { - DataType::Dictionary(k, v) => (k.as_ref(), v.as_ref()), - _ => unreachable!(), - }; - downcast_integer! 
{ - k => (decode_dictionary_helper, interner, v, options, rows), - _ => unreachable!() - } - } - Codec::DictionaryValues(converter, _) => { + Codec::Dictionary(converter, _) => { let cols = converter.convert_raw(rows, validate_utf8)?; cols.into_iter().next().unwrap() } @@ -1439,7 +1362,7 @@ mod tests { ])) as ArrayRef, ]; - let mut converter = RowConverter::new(vec![ + let converter = RowConverter::new(vec![ SortField::new(DataType::Int16), SortField::new(DataType::Float32), ]) @@ -1481,9 +1404,10 @@ mod tests { #[test] fn test_decimal128() { - let mut converter = RowConverter::new(vec![SortField::new( - DataType::Decimal128(DECIMAL128_MAX_PRECISION, 7), - )]) + let converter = RowConverter::new(vec![SortField::new(DataType::Decimal128( + DECIMAL128_MAX_PRECISION, + 7, + ))]) .unwrap(); let col = Arc::new( Decimal128Array::from_iter([ @@ -1510,9 +1434,10 @@ mod tests { #[test] fn test_decimal256() { - let mut converter = RowConverter::new(vec![SortField::new( - DataType::Decimal256(DECIMAL256_MAX_PRECISION, 7), - )]) + let converter = RowConverter::new(vec![SortField::new(DataType::Decimal256( + DECIMAL256_MAX_PRECISION, + 7, + ))]) .unwrap(); let col = Arc::new( Decimal256Array::from_iter([ @@ -1541,7 +1466,7 @@ mod tests { #[test] fn test_bool() { - let mut converter = + let converter = RowConverter::new(vec![SortField::new(DataType::Boolean)]).unwrap(); let col = Arc::new(BooleanArray::from_iter([None, Some(false), Some(true)])) @@ -1555,7 +1480,7 @@ mod tests { let cols = converter.convert_rows(&rows).unwrap(); assert_eq!(&cols[0], &col); - let mut converter = RowConverter::new(vec![SortField::new_with_options( + let converter = RowConverter::new(vec![SortField::new_with_options( DataType::Boolean, SortOptions { descending: true, @@ -1578,7 +1503,7 @@ mod tests { .with_timezone("+01:00".to_string()); let d = a.data_type().clone(); - let mut converter = + let converter = RowConverter::new(vec![SortField::new(a.data_type().clone())]).unwrap(); let rows = converter.convert_columns(&[Arc::new(a) as _]).unwrap(); let back = converter.convert_rows(&rows).unwrap(); @@ -1595,30 +1520,24 @@ mod tests { // Construct dictionary with a timezone let dict = a.finish(); let values = TimestampNanosecondArray::from(dict.values().to_data()); - let dict_with_tz = dict.with_values(&values.with_timezone("+02:00")); - let d = DataType::Dictionary( - Box::new(DataType::Int32), - Box::new(DataType::Timestamp( - TimeUnit::Nanosecond, - Some("+02:00".into()), - )), - ); + let dict_with_tz = dict.with_values(Arc::new(values.with_timezone("+02:00"))); + let v = DataType::Timestamp(TimeUnit::Nanosecond, Some("+02:00".into())); + let d = DataType::Dictionary(Box::new(DataType::Int32), Box::new(v.clone())); assert_eq!(dict_with_tz.data_type(), &d); - let mut converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap(); + let converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap(); let rows = converter .convert_columns(&[Arc::new(dict_with_tz) as _]) .unwrap(); let back = converter.convert_rows(&rows).unwrap(); assert_eq!(back.len(), 1); - assert_eq!(back[0].data_type(), &d); + assert_eq!(back[0].data_type(), &v); } #[test] fn test_null_encoding() { let col = Arc::new(NullArray::new(10)); - let mut converter = - RowConverter::new(vec![SortField::new(DataType::Null)]).unwrap(); + let converter = RowConverter::new(vec![SortField::new(DataType::Null)]).unwrap(); let rows = converter.convert_columns(&[col]).unwrap(); assert_eq!(rows.num_rows(), 10); assert_eq!(rows.row(1).data.len(), 0); @@ 
-1634,8 +1553,7 @@ mod tests { Some(""), ])) as ArrayRef; - let mut converter = - RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap(); + let converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap(); let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap(); assert!(rows.row(1) < rows.row(0)); @@ -1650,17 +1568,23 @@ mod tests { None, Some(vec![0_u8; 0]), Some(vec![0_u8; 6]), + Some(vec![0_u8; variable::MINI_BLOCK_SIZE]), + Some(vec![0_u8; variable::MINI_BLOCK_SIZE + 1]), Some(vec![0_u8; variable::BLOCK_SIZE]), Some(vec![0_u8; variable::BLOCK_SIZE + 1]), Some(vec![1_u8; 6]), + Some(vec![1_u8; variable::MINI_BLOCK_SIZE]), + Some(vec![1_u8; variable::MINI_BLOCK_SIZE + 1]), Some(vec![1_u8; variable::BLOCK_SIZE]), Some(vec![1_u8; variable::BLOCK_SIZE + 1]), Some(vec![0xFF_u8; 6]), + Some(vec![0xFF_u8; variable::MINI_BLOCK_SIZE]), + Some(vec![0xFF_u8; variable::MINI_BLOCK_SIZE + 1]), Some(vec![0xFF_u8; variable::BLOCK_SIZE]), Some(vec![0xFF_u8; variable::BLOCK_SIZE + 1]), ])) as ArrayRef; - let mut converter = + let converter = RowConverter::new(vec![SortField::new(DataType::Binary)]).unwrap(); let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap(); @@ -1680,7 +1604,7 @@ mod tests { let cols = converter.convert_rows(&rows).unwrap(); assert_eq!(&cols[0], &col); - let mut converter = RowConverter::new(vec![SortField::new_with_options( + let converter = RowConverter::new(vec![SortField::new_with_options( DataType::Binary, SortOptions { descending: true, @@ -1708,9 +1632,9 @@ mod tests { } /// If `exact` is false performs a logical comparison between a and dictionary-encoded b - fn dictionary_eq(exact: bool, a: &dyn Array, b: &dyn Array) { + fn dictionary_eq(a: &dyn Array, b: &dyn Array) { match b.data_type() { - DataType::Dictionary(_, v) if !exact => { + DataType::Dictionary(_, v) => { assert_eq!(a.data_type(), v.as_ref()); let b = arrow_cast::cast(b, v).unwrap(); assert_eq!(a, b.as_ref()) @@ -1721,11 +1645,6 @@ mod tests { #[test] fn test_string_dictionary() { - test_string_dictionary_impl(false); - test_string_dictionary_impl(true); - } - - fn test_string_dictionary_impl(preserve: bool) { let a = Arc::new(DictionaryArray::::from_iter([ Some("foo"), Some("hello"), @@ -1737,8 +1656,8 @@ mod tests { Some("hello"), ])) as ArrayRef; - let field = SortField::new(a.data_type().clone()).preserve_dictionaries(preserve); - let mut converter = RowConverter::new(vec![field]).unwrap(); + let field = SortField::new(a.data_type().clone()); + let converter = RowConverter::new(vec![field]).unwrap(); let rows_a = converter.convert_columns(&[Arc::clone(&a)]).unwrap(); assert!(rows_a.row(3) < rows_a.row(5)); @@ -1751,7 +1670,7 @@ mod tests { assert_eq!(rows_a.row(1), rows_a.row(7)); let cols = converter.convert_rows(&rows_a).unwrap(); - dictionary_eq(preserve, &cols[0], &a); + dictionary_eq(&cols[0], &a); let b = Arc::new(DictionaryArray::::from_iter([ Some("hello"), @@ -1765,16 +1684,15 @@ mod tests { assert!(rows_b.row(2) < rows_a.row(0)); let cols = converter.convert_rows(&rows_b).unwrap(); - dictionary_eq(preserve, &cols[0], &b); + dictionary_eq(&cols[0], &b); - let mut converter = RowConverter::new(vec![SortField::new_with_options( + let converter = RowConverter::new(vec![SortField::new_with_options( a.data_type().clone(), SortOptions { descending: true, nulls_first: false, }, - ) - .preserve_dictionaries(preserve)]) + )]) .unwrap(); let rows_c = converter.convert_columns(&[Arc::clone(&a)]).unwrap(); @@ -1784,16 +1702,15 @@ mod tests { assert!(rows_c.row(3) 
> rows_c.row(0)); let cols = converter.convert_rows(&rows_c).unwrap(); - dictionary_eq(preserve, &cols[0], &a); + dictionary_eq(&cols[0], &a); - let mut converter = RowConverter::new(vec![SortField::new_with_options( + let converter = RowConverter::new(vec![SortField::new_with_options( a.data_type().clone(), SortOptions { descending: true, nulls_first: true, }, - ) - .preserve_dictionaries(preserve)]) + )]) .unwrap(); let rows_c = converter.convert_columns(&[Arc::clone(&a)]).unwrap(); @@ -1803,7 +1720,7 @@ mod tests { assert!(rows_c.row(3) < rows_c.row(0)); let cols = converter.convert_rows(&rows_c).unwrap(); - dictionary_eq(preserve, &cols[0], &a); + dictionary_eq(&cols[0], &a); } #[test] @@ -1816,7 +1733,7 @@ mod tests { let s1 = Arc::new(StructArray::from(vec![(a_f, a), (u_f, u)])) as ArrayRef; let sort_fields = vec![SortField::new(s1.data_type().clone())]; - let mut converter = RowConverter::new(sort_fields).unwrap(); + let converter = RowConverter::new(sort_fields).unwrap(); let r1 = converter.convert_columns(&[Arc::clone(&s1)]).unwrap(); for (a, b) in r1.iter().zip(r1.iter().skip(1)) { @@ -1865,16 +1782,14 @@ mod tests { let data_type = a.data_type().clone(); let columns = [Arc::new(a) as ArrayRef]; - for preserve in [true, false] { - let field = SortField::new(data_type.clone()).preserve_dictionaries(preserve); - let mut converter = RowConverter::new(vec![field]).unwrap(); - let rows = converter.convert_columns(&columns).unwrap(); - assert!(rows.row(0) < rows.row(1)); - assert!(rows.row(2) < rows.row(0)); - assert!(rows.row(3) < rows.row(2)); - assert!(rows.row(6) < rows.row(2)); - assert!(rows.row(3) < rows.row(6)); - } + let field = SortField::new(data_type.clone()); + let converter = RowConverter::new(vec![field]).unwrap(); + let rows = converter.convert_columns(&columns).unwrap(); + assert!(rows.row(0) < rows.row(1)); + assert!(rows.row(2) < rows.row(0)); + assert!(rows.row(3) < rows.row(2)); + assert!(rows.row(6) < rows.row(2)); + assert!(rows.row(3) < rows.row(6)); } #[test] @@ -1895,22 +1810,20 @@ mod tests { .unwrap(); let columns = [Arc::new(DictionaryArray::::from(data)) as ArrayRef]; - for preserve in [true, false] { - let field = SortField::new(data_type.clone()).preserve_dictionaries(preserve); - let mut converter = RowConverter::new(vec![field]).unwrap(); - let rows = converter.convert_columns(&columns).unwrap(); - - assert_eq!(rows.row(0), rows.row(1)); - assert_eq!(rows.row(3), rows.row(4)); - assert_eq!(rows.row(4), rows.row(5)); - assert!(rows.row(3) < rows.row(0)); - } + let field = SortField::new(data_type.clone()); + let converter = RowConverter::new(vec![field]).unwrap(); + let rows = converter.convert_columns(&columns).unwrap(); + + assert_eq!(rows.row(0), rows.row(1)); + assert_eq!(rows.row(3), rows.row(4)); + assert_eq!(rows.row(4), rows.row(5)); + assert!(rows.row(3) < rows.row(0)); } #[test] #[should_panic(expected = "Encountered non UTF-8 data")] fn test_invalid_utf8() { - let mut converter = + let converter = RowConverter::new(vec![SortField::new(DataType::Binary)]).unwrap(); let array = Arc::new(BinaryArray::from_iter_values([&[0xFF]])) as _; let rows = converter.convert_columns(&[array]).unwrap(); @@ -1927,8 +1840,7 @@ mod tests { #[should_panic(expected = "rows were not produced by this RowConverter")] fn test_different_converter() { let values = Arc::new(Int32Array::from_iter([Some(1), Some(-1)])); - let mut converter = - RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap(); + let converter = 
RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap(); let rows = converter.convert_columns(&[values]).unwrap(); let converter = RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap(); @@ -1959,7 +1871,7 @@ mod tests { let list = Arc::new(builder.finish()) as ArrayRef; let d = list.data_type().clone(); - let mut converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap(); + let converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap(); let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap(); assert!(rows.row(0) > rows.row(1)); // [32, 52, 32] > [32, 52, 12] @@ -1979,7 +1891,7 @@ mod tests { nulls_first: false, }; let field = SortField::new_with_options(d.clone(), options); - let mut converter = RowConverter::new(vec![field]).unwrap(); + let converter = RowConverter::new(vec![field]).unwrap(); let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap(); assert!(rows.row(0) > rows.row(1)); // [32, 52, 32] > [32, 52, 12] @@ -1999,7 +1911,7 @@ mod tests { nulls_first: false, }; let field = SortField::new_with_options(d.clone(), options); - let mut converter = RowConverter::new(vec![field]).unwrap(); + let converter = RowConverter::new(vec![field]).unwrap(); let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap(); assert!(rows.row(0) < rows.row(1)); // [32, 52, 32] < [32, 52, 12] @@ -2019,7 +1931,7 @@ mod tests { nulls_first: true, }; let field = SortField::new_with_options(d, options); - let mut converter = RowConverter::new(vec![field]).unwrap(); + let converter = RowConverter::new(vec![field]).unwrap(); let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap(); assert!(rows.row(0) < rows.row(1)); // [32, 52, 32] < [32, 52, 12] @@ -2083,7 +1995,7 @@ mod tests { nulls_first: true, }; let field = SortField::new_with_options(d.clone(), options); - let mut converter = RowConverter::new(vec![field]).unwrap(); + let converter = RowConverter::new(vec![field]).unwrap(); let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap(); assert!(rows.row(0) > rows.row(1)); @@ -2102,7 +2014,7 @@ mod tests { nulls_first: true, }; let field = SortField::new_with_options(d.clone(), options); - let mut converter = RowConverter::new(vec![field]).unwrap(); + let converter = RowConverter::new(vec![field]).unwrap(); let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap(); assert!(rows.row(0) > rows.row(1)); @@ -2121,7 +2033,7 @@ mod tests { nulls_first: false, }; let field = SortField::new_with_options(d, options); - let mut converter = RowConverter::new(vec![field]).unwrap(); + let converter = RowConverter::new(vec![field]).unwrap(); let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap(); assert!(rows.row(0) < rows.row(1)); @@ -2148,35 +2060,6 @@ mod tests { test_nested_list::(); } - #[test] - fn test_dictionary_preserving() { - let mut dict = StringDictionaryBuilder::::new(); - dict.append_value("foo"); - dict.append_value("foo"); - dict.append_value("bar"); - dict.append_value("bar"); - dict.append_value("bar"); - dict.append_value("bar"); - - let array = Arc::new(dict.finish()) as ArrayRef; - let preserve = SortField::new(array.data_type().clone()); - let non_preserve = preserve.clone().preserve_dictionaries(false); - - let mut c1 = RowConverter::new(vec![preserve]).unwrap(); - let r1 = c1.convert_columns(&[array.clone()]).unwrap(); - - let mut c2 = RowConverter::new(vec![non_preserve]).unwrap(); - let r2 = c2.convert_columns(&[array.clone()]).unwrap(); - - for r in r1.iter() { - 
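// With dictionary preservation removed above, `RowConverter::convert_columns`
// and `convert_rows` take `&self`, so a converter can be bound immutably and
// reused. A minimal usage sketch assuming the arrow-row, arrow-array and
// arrow-schema crates at this revision; it mirrors the simplified tests above.
use std::sync::Arc;

use arrow_array::{ArrayRef, StringArray};
use arrow_row::{RowConverter, SortField};
use arrow_schema::DataType;

fn main() {
    // No `mut` binding needed any more
    let converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap();
    let col = Arc::new(StringArray::from(vec![Some("hello"), None, Some("arrow")])) as ArrayRef;

    let rows = converter.convert_columns(&[col.clone()]).unwrap();
    // Default sort options: ascending with nulls first, so the null sorts lowest
    assert!(rows.row(1) < rows.row(0));

    // Rows round-trip back to the original array
    let back = converter.convert_rows(&rows).unwrap();
    assert_eq!(&back[0], &col);
}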
assert_eq!(r.data.len(), 3); - } - - for r in r2.iter() { - assert_eq!(r.data.len(), 34); - } - } - fn generate_primitive_array(len: usize, valid_percent: f64) -> PrimitiveArray where K: ArrowPrimitiveType, @@ -2332,21 +2215,15 @@ mod tests { }) .collect(); - let preserve: Vec<_> = (0..num_columns).map(|_| rng.gen_bool(0.5)).collect(); - let comparator = LexicographicalComparator::try_new(&sort_columns).unwrap(); let columns = options .into_iter() .zip(&arrays) - .zip(&preserve) - .map(|((o, a), p)| { - SortField::new_with_options(a.data_type().clone(), o) - .preserve_dictionaries(*p) - }) + .map(|(o, a)| SortField::new_with_options(a.data_type().clone(), o)) .collect(); - let mut converter = RowConverter::new(columns).unwrap(); + let converter = RowConverter::new(columns).unwrap(); let rows = converter.convert_columns(&arrays).unwrap(); for i in 0..len { @@ -2369,10 +2246,66 @@ mod tests { } let back = converter.convert_rows(&rows).unwrap(); - for ((actual, expected), preserve) in back.iter().zip(&arrays).zip(preserve) { + for (actual, expected) in back.iter().zip(&arrays) { actual.to_data().validate_full().unwrap(); - dictionary_eq(preserve, actual, expected) + dictionary_eq(actual, expected) } } } + + #[test] + fn test_clear() { + let converter = RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap(); + let mut rows = converter.empty_rows(3, 128); + + let first = Int32Array::from(vec![None, Some(2), Some(4)]); + let second = Int32Array::from(vec![Some(2), None, Some(4)]); + let arrays = vec![Arc::new(first) as ArrayRef, Arc::new(second) as ArrayRef]; + + for array in arrays.iter() { + rows.clear(); + converter.append(&mut rows, &[array.clone()]).unwrap(); + let back = converter.convert_rows(&rows).unwrap(); + assert_eq!(&back[0], array); + } + + let mut rows_expected = converter.empty_rows(3, 128); + converter.append(&mut rows_expected, &arrays[1..]).unwrap(); + + for (i, (actual, expected)) in rows.iter().zip(rows_expected.iter()).enumerate() { + assert_eq!( + actual, expected, + "For row {}: expected {:?}, actual: {:?}", + i, expected, actual + ); + } + } + + #[test] + fn test_append_codec_dictionary_binary() { + use DataType::*; + // Dictionary RowConverter + let converter = RowConverter::new(vec![SortField::new(Dictionary( + Box::new(Int32), + Box::new(Binary), + ))]) + .unwrap(); + let mut rows = converter.empty_rows(4, 128); + + let keys = Int32Array::from_iter_values([0, 1, 2, 3]); + let values = BinaryArray::from(vec![ + Some("a".as_bytes()), + Some(b"b"), + Some(b"c"), + Some(b"d"), + ]); + let dict_array = DictionaryArray::new(keys, Arc::new(values)); + + rows.clear(); + let array = Arc::new(dict_array) as ArrayRef; + converter.append(&mut rows, &[array.clone()]).unwrap(); + let back = converter.convert_rows(&rows).unwrap(); + + dictionary_eq(&back[0], &array); + } } diff --git a/arrow-row/src/variable.rs b/arrow-row/src/variable.rs index e9f6160bf43c..6c9c4c43bca3 100644 --- a/arrow-row/src/variable.rs +++ b/arrow-row/src/variable.rs @@ -26,6 +26,14 @@ use arrow_schema::{DataType, SortOptions}; /// The block size of the variable length encoding pub const BLOCK_SIZE: usize = 32; +/// The first block is split into `MINI_BLOCK_COUNT` mini-blocks +/// +/// This helps to reduce the space amplification for small strings +pub const MINI_BLOCK_COUNT: usize = 4; + +/// The mini block size +pub const MINI_BLOCK_SIZE: usize = BLOCK_SIZE / MINI_BLOCK_COUNT; + /// The continuation token pub const BLOCK_CONTINUATION: u8 = 0xFF; @@ -45,7 +53,12 @@ pub fn encoded_len(a: 
Option<&[u8]>) -> usize { #[inline] pub fn padded_length(a: Option) -> usize { match a { - Some(a) => 1 + ceil(a, BLOCK_SIZE) * (BLOCK_SIZE + 1), + Some(a) if a <= BLOCK_SIZE => { + 1 + ceil(a, MINI_BLOCK_SIZE) * (MINI_BLOCK_SIZE + 1) + } + // Each miniblock ends with a 1 byte continuation, therefore add + // `(MINI_BLOCK_COUNT - 1)` additional bytes over non-miniblock size + Some(a) => MINI_BLOCK_COUNT + ceil(a, BLOCK_SIZE) * (BLOCK_SIZE + 1), None => 1, } } @@ -82,44 +95,23 @@ pub fn encode_one(out: &mut [u8], val: Option<&[u8]>, opts: SortOptions) -> usiz 1 } Some(val) => { - let block_count = ceil(val.len(), BLOCK_SIZE); - let end_offset = 1 + block_count * (BLOCK_SIZE + 1); - let to_write = &mut out[..end_offset]; - // Write `2_u8` to demarcate as non-empty, non-null string - to_write[0] = NON_EMPTY_SENTINEL; - - let chunks = val.chunks_exact(BLOCK_SIZE); - let remainder = chunks.remainder(); - for (input, output) in chunks - .clone() - .zip(to_write[1..].chunks_exact_mut(BLOCK_SIZE + 1)) - { - let input: &[u8; BLOCK_SIZE] = input.try_into().unwrap(); - let out_block: &mut [u8; BLOCK_SIZE] = - (&mut output[..BLOCK_SIZE]).try_into().unwrap(); - - *out_block = *input; - - // Indicate that there are further blocks to follow - output[BLOCK_SIZE] = BLOCK_CONTINUATION; - } + out[0] = NON_EMPTY_SENTINEL; - if !remainder.is_empty() { - let start_offset = 1 + (block_count - 1) * (BLOCK_SIZE + 1); - to_write[start_offset..start_offset + remainder.len()] - .copy_from_slice(remainder); - *to_write.last_mut().unwrap() = remainder.len() as u8; + let len = if val.len() <= BLOCK_SIZE { + 1 + encode_blocks::(&mut out[1..], val) } else { - // We must overwrite the continuation marker written by the loop above - *to_write.last_mut().unwrap() = BLOCK_SIZE as u8; - } + let (initial, rem) = val.split_at(BLOCK_SIZE); + let offset = encode_blocks::(&mut out[1..], initial); + out[offset] = BLOCK_CONTINUATION; + 1 + offset + encode_blocks::(&mut out[1 + offset..], rem) + }; if opts.descending { // Invert bits - to_write.iter_mut().for_each(|v| *v = !*v) + out[..len].iter_mut().for_each(|v| *v = !*v) } - end_offset + len } None => { out[0] = null_sentinel(opts); @@ -128,8 +120,37 @@ pub fn encode_one(out: &mut [u8], val: Option<&[u8]>, opts: SortOptions) -> usiz } } -/// Returns the number of bytes of encoded data -fn decoded_len(row: &[u8], options: SortOptions) -> usize { +/// Writes `val` in `SIZE` blocks with the appropriate continuation tokens +#[inline] +fn encode_blocks(out: &mut [u8], val: &[u8]) -> usize { + let block_count = ceil(val.len(), SIZE); + let end_offset = block_count * (SIZE + 1); + let to_write = &mut out[..end_offset]; + + let chunks = val.chunks_exact(SIZE); + let remainder = chunks.remainder(); + for (input, output) in chunks.clone().zip(to_write.chunks_exact_mut(SIZE + 1)) { + let input: &[u8; SIZE] = input.try_into().unwrap(); + let out_block: &mut [u8; SIZE] = (&mut output[..SIZE]).try_into().unwrap(); + + *out_block = *input; + + // Indicate that there are further blocks to follow + output[SIZE] = BLOCK_CONTINUATION; + } + + if !remainder.is_empty() { + let start_offset = (block_count - 1) * (SIZE + 1); + to_write[start_offset..start_offset + remainder.len()].copy_from_slice(remainder); + *to_write.last_mut().unwrap() = remainder.len() as u8; + } else { + // We must overwrite the continuation marker written by the loop above + *to_write.last_mut().unwrap() = SIZE as u8; + } + end_offset +} + +fn decode_blocks(row: &[u8], options: SortOptions, mut f: impl FnMut(&[u8])) -> usize { let 
(non_empty_sentinel, continuation) = match options.descending { true => (!NON_EMPTY_SENTINEL, !BLOCK_CONTINUATION), false => (NON_EMPTY_SENTINEL, BLOCK_CONTINUATION), @@ -137,26 +158,44 @@ fn decoded_len(row: &[u8], options: SortOptions) -> usize { if row[0] != non_empty_sentinel { // Empty or null string - return 0; + return 1; } - let mut str_len = 0; + // Extracts the block length from the sentinel + let block_len = |sentinel: u8| match options.descending { + true => !sentinel as usize, + false => sentinel as usize, + }; + let mut idx = 1; + for _ in 0..MINI_BLOCK_COUNT { + let sentinel = row[idx + MINI_BLOCK_SIZE]; + if sentinel != continuation { + f(&row[idx..idx + block_len(sentinel)]); + return idx + MINI_BLOCK_SIZE + 1; + } + f(&row[idx..idx + MINI_BLOCK_SIZE]); + idx += MINI_BLOCK_SIZE + 1; + } + loop { let sentinel = row[idx + BLOCK_SIZE]; - if sentinel == continuation { - idx += BLOCK_SIZE + 1; - str_len += BLOCK_SIZE; - continue; + if sentinel != continuation { + f(&row[idx..idx + block_len(sentinel)]); + return idx + BLOCK_SIZE + 1; } - let block_len = match options.descending { - true => !sentinel, - false => sentinel, - }; - return str_len + block_len as usize; + f(&row[idx..idx + BLOCK_SIZE]); + idx += BLOCK_SIZE + 1; } } +/// Returns the number of bytes of encoded data +fn decoded_len(row: &[u8], options: SortOptions) -> usize { + let mut len = 0; + decode_blocks(row, options, |block| len += block.len()); + len +} + /// Decodes a binary array from `rows` with the provided `options` pub fn decode_binary( rows: &mut [&[u8]], @@ -176,22 +215,8 @@ pub fn decode_binary( let mut values = MutableBuffer::new(values_capacity); for row in rows { - let str_length = decoded_len(row, options); - let mut to_read = str_length; - let mut offset = 1; - while to_read >= BLOCK_SIZE { - to_read -= BLOCK_SIZE; - - values.extend_from_slice(&row[offset..offset + BLOCK_SIZE]); - offset += BLOCK_SIZE + 1; - } - - if to_read != 0 { - values.extend_from_slice(&row[offset..offset + to_read]); - offset += BLOCK_SIZE + 1; - } + let offset = decode_blocks(row, options, |b| values.extend_from_slice(b)); *row = &row[offset..]; - offsets.append(I::from_usize(values.len()).expect("offset overflow")) } diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs index edd1dd09620e..4f8c8a18bd17 100644 --- a/arrow-schema/src/datatype.rs +++ b/arrow-schema/src/datatype.rs @@ -18,7 +18,7 @@ use std::fmt; use std::sync::Arc; -use crate::{FieldRef, Fields, UnionFields}; +use crate::{Field, FieldRef, Fields, UnionFields}; /// The set of datatypes that are supported by this implementation of Apache Arrow. 
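// The encoding above splits the first 32-byte block into four 8-byte
// mini-blocks so short values no longer pay a whole block of padding. A
// standalone sketch of the resulting length formula; the constants mirror
// variable.rs and `ceil` is a local helper, not the crate's.
const BLOCK_SIZE: usize = 32;
const MINI_BLOCK_COUNT: usize = 4;
const MINI_BLOCK_SIZE: usize = BLOCK_SIZE / MINI_BLOCK_COUNT;

fn ceil(a: usize, b: usize) -> usize {
    (a + b - 1) / b
}

/// Padded length of a non-null, non-empty value of `len` bytes
fn padded_length(len: usize) -> usize {
    if len <= BLOCK_SIZE {
        // 1 sentinel byte, then 8-byte mini-blocks each followed by a token byte
        1 + ceil(len, MINI_BLOCK_SIZE) * (MINI_BLOCK_SIZE + 1)
    } else {
        // Longer values pay (MINI_BLOCK_COUNT - 1) extra continuation bytes
        MINI_BLOCK_COUNT + ceil(len, BLOCK_SIZE) * (BLOCK_SIZE + 1)
    }
}

fn main() {
    // A 3-byte string used to take 1 + 1 * 33 = 34 bytes; now only 10
    assert_eq!(padded_length(3), 10);
    // Values longer than a block pay a small fixed overhead over the old layout
    assert_eq!(padded_length(100), 4 + 4 * 33);
}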
/// @@ -576,6 +576,11 @@ impl DataType { _ => self == other, } } + + /// Create a List DataType default name is "item" + pub fn new_list(data_type: DataType, nullable: bool) -> Self { + DataType::List(Arc::new(Field::new("item", data_type, nullable))) + } } /// The maximum precision for [DataType::Decimal128] values diff --git a/arrow-schema/src/error.rs b/arrow-schema/src/error.rs index cd236c0871a6..8ea533db89af 100644 --- a/arrow-schema/src/error.rs +++ b/arrow-schema/src/error.rs @@ -35,7 +35,8 @@ pub enum ArrowError { DivideByZero, CsvError(String), JsonError(String), - IoError(String), + IoError(String, std::io::Error), + IpcError(String), InvalidArgumentError(String), ParquetError(String), /// Error during import or export to/from the C Data Interface @@ -53,7 +54,7 @@ impl ArrowError { impl From for ArrowError { fn from(error: std::io::Error) -> Self { - ArrowError::IoError(error.to_string()) + ArrowError::IoError(error.to_string(), error) } } @@ -65,7 +66,7 @@ impl From for ArrowError { impl From> for ArrowError { fn from(error: std::io::IntoInnerError) -> Self { - ArrowError::IoError(error.to_string()) + ArrowError::IoError(error.to_string(), error.into()) } } @@ -84,7 +85,8 @@ impl Display for ArrowError { ArrowError::DivideByZero => write!(f, "Divide by zero error"), ArrowError::CsvError(desc) => write!(f, "Csv error: {desc}"), ArrowError::JsonError(desc) => write!(f, "Json error: {desc}"), - ArrowError::IoError(desc) => write!(f, "Io error: {desc}"), + ArrowError::IoError(desc, _) => write!(f, "Io error: {desc}"), + ArrowError::IpcError(desc) => write!(f, "Ipc error: {desc}"), ArrowError::InvalidArgumentError(desc) => { write!(f, "Invalid argument error: {desc}") } diff --git a/arrow-schema/src/ffi.rs b/arrow-schema/src/ffi.rs index cd3c207a56c5..a17dbe769f2e 100644 --- a/arrow-schema/src/ffi.rs +++ b/arrow-schema/src/ffi.rs @@ -833,7 +833,7 @@ mod tests { // Construct a map array from the above two let map_data_type = - DataType::Map(Arc::new(Field::new("entries", entry_struct, true)), true); + DataType::Map(Arc::new(Field::new("entries", entry_struct, false)), true); let arrow_schema = FFI_ArrowSchema::try_from(map_data_type).unwrap(); assert!(arrow_schema.map_keys_sorted()); diff --git a/arrow-schema/src/field.rs b/arrow-schema/src/field.rs index f38e1e26ad26..b50778c785fb 100644 --- a/arrow-schema/src/field.rs +++ b/arrow-schema/src/field.rs @@ -170,7 +170,7 @@ impl Field { /// Create a new [`Field`] with [`DataType::Struct`] /// - /// - `name`: the name of the [`DataType::List`] field + /// - `name`: the name of the [`DataType::Struct`] field /// - `fields`: the description of each struct element /// - `nullable`: if the [`DataType::Struct`] array is nullable pub fn new_struct( @@ -186,8 +186,6 @@ impl Field { /// - `name`: the name of the [`DataType::List`] field /// - `value`: the description of each list element /// - `nullable`: if the [`DataType::List`] array is nullable - /// - /// Uses "item" as the name of the child field, this can be overridden with [`Self::new`] pub fn new_list( name: impl Into, value: impl Into, @@ -463,7 +461,10 @@ impl Field { )); } }, - DataType::Null + DataType::Null => { + self.nullable = true; + self.data_type = from.data_type.clone(); + } | DataType::Boolean | DataType::Int8 | DataType::Int16 @@ -496,7 +497,9 @@ impl Field { | DataType::LargeUtf8 | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => { - if self.data_type != from.data_type { + if from.data_type == DataType::Null { + self.nullable = true; + } else if self.data_type 
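// A hedged sketch of what the reworked `IoError` variant above enables: the
// original `std::io::Error` now travels alongside its message, so callers can
// inspect the `ErrorKind` rather than parsing strings. Assumes the
// arrow-schema crate at this revision.
use std::io::{Error, ErrorKind};

use arrow_schema::ArrowError;

fn is_missing_file(err: &ArrowError) -> bool {
    matches!(err, ArrowError::IoError(_, source) if source.kind() == ErrorKind::NotFound)
}

fn main() {
    let io = Error::new(ErrorKind::NotFound, "no such file");
    let err = ArrowError::from(io); // the From impl keeps the source error
    assert!(is_missing_file(&err));
}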
!= from.data_type { return Err(ArrowError::SchemaError( format!("Fail to merge schema field '{}' because the from data_type = {} does not equal {}", self.name, from.data_type, self.data_type) @@ -582,6 +585,21 @@ mod test { assert_eq!("Schema error: Fail to merge schema field 'c1' because the from data_type = Float32 does not equal Int64", result); } + #[test] + fn test_merge_with_null() { + let mut field1 = Field::new("c1", DataType::Null, true); + field1 + .try_merge(&Field::new("c1", DataType::Float32, false)) + .expect("should widen type to nullable float"); + assert_eq!(Field::new("c1", DataType::Float32, true), field1); + + let mut field2 = Field::new("c2", DataType::Utf8, false); + field2 + .try_merge(&Field::new("c2", DataType::Null, true)) + .expect("should widen type to nullable utf8"); + assert_eq!(Field::new("c2", DataType::Utf8, true), field2); + } + #[test] fn test_fields_with_dict_id() { let dict1 = Field::new_dict( diff --git a/arrow-select/Cargo.toml b/arrow-select/Cargo.toml index ff8a212c7b52..023788799c94 100644 --- a/arrow-select/Cargo.toml +++ b/arrow-select/Cargo.toml @@ -39,6 +39,7 @@ arrow-data = { workspace = true } arrow-schema = { workspace = true } arrow-array = { workspace = true } num = { version = "0.4", default-features = false, features = ["std"] } +ahash = { version = "0.8", default-features = false} [features] default = [] diff --git a/arrow-select/src/concat.rs b/arrow-select/src/concat.rs index 0bf4c97ff827..a6dcca24eace 100644 --- a/arrow-select/src/concat.rs +++ b/arrow-select/src/concat.rs @@ -30,20 +30,20 @@ //! assert_eq!(arr.len(), 3); //! ``` +use crate::dictionary::{merge_dictionary_values, should_merge_dictionary_values}; +use arrow_array::cast::AsArray; use arrow_array::types::*; use arrow_array::*; -use arrow_buffer::ArrowNativeType; +use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer}; use arrow_data::transform::{Capacities, MutableArrayData}; use arrow_schema::{ArrowError, DataType, SchemaRef}; +use std::sync::Arc; fn binary_capacity(arrays: &[&dyn Array]) -> Capacities { let mut item_capacity = 0; let mut bytes_capacity = 0; for array in arrays { - let a = array - .as_any() - .downcast_ref::>() - .unwrap(); + let a = array.as_bytes::(); // Guaranteed to always have at least one element let offsets = a.value_offsets(); @@ -54,6 +54,59 @@ fn binary_capacity(arrays: &[&dyn Array]) -> Capacities { Capacities::Binary(item_capacity, Some(bytes_capacity)) } +fn concat_dictionaries( + arrays: &[&dyn Array], +) -> Result { + let mut output_len = 0; + let dictionaries: Vec<_> = arrays + .iter() + .map(|x| x.as_dictionary::()) + .inspect(|d| output_len += d.len()) + .collect(); + + if !should_merge_dictionary_values::(&dictionaries, output_len) { + return concat_fallback(arrays, Capacities::Array(output_len)); + } + + let merged = merge_dictionary_values(&dictionaries, None)?; + + // Recompute keys + let mut key_values = Vec::with_capacity(output_len); + + let mut has_nulls = false; + for (d, mapping) in dictionaries.iter().zip(merged.key_mappings) { + has_nulls |= d.null_count() != 0; + for key in d.keys().values() { + // Use get to safely handle nulls + key_values.push(mapping.get(key.as_usize()).copied().unwrap_or_default()) + } + } + + let nulls = has_nulls.then(|| { + let mut nulls = BooleanBufferBuilder::new(output_len); + for d in &dictionaries { + match d.nulls() { + Some(n) => nulls.append_buffer(n.inner()), + None => nulls.append_n(d.len(), true), + } + } + NullBuffer::new(nulls.finish()) + }); + + let keys = 
PrimitiveArray::::new(key_values.into(), nulls); + // Sanity check + assert_eq!(keys.len(), output_len); + + let array = unsafe { DictionaryArray::new_unchecked(keys, merged.values) }; + Ok(Arc::new(array)) +} + +macro_rules! dict_helper { + ($t:ty, $arrays:expr) => { + return Ok(Arc::new(concat_dictionaries::<$t>($arrays)?) as _) + }; +} + /// Concatenate multiple [Array] of the same type into a single [ArrayRef]. pub fn concat(arrays: &[&dyn Array]) -> Result { if arrays.is_empty() { @@ -78,9 +131,23 @@ pub fn concat(arrays: &[&dyn Array]) -> Result { DataType::LargeUtf8 => binary_capacity::(arrays), DataType::Binary => binary_capacity::(arrays), DataType::LargeBinary => binary_capacity::(arrays), + DataType::Dictionary(k, _) => downcast_integer! { + k.as_ref() => (dict_helper, arrays), + _ => unreachable!("illegal dictionary key type {k}") + }, _ => Capacities::Array(arrays.iter().map(|a| a.len()).sum()), }; + concat_fallback(arrays, capacity) +} + +/// Concatenates arrays using MutableArrayData +/// +/// This will naively concatenate dictionaries +fn concat_fallback( + arrays: &[&dyn Array], + capacity: Capacities, +) -> Result { let array_data: Vec<_> = arrays.iter().map(|a| a.to_data()).collect::>(); let array_data = array_data.iter().collect(); let mut mutable = MutableArrayData::with_capacities(array_data, false, capacity); @@ -92,29 +159,28 @@ pub fn concat(arrays: &[&dyn Array]) -> Result { Ok(make_array(mutable.freeze())) } -/// Concatenates `batches` together into a single record batch. +/// Concatenates `batches` together into a single [`RecordBatch`]. +/// +/// The output batch has the specified `schemas`; The schema of the +/// input are ignored. +/// +/// Returns an error if the types of underlying arrays are different. pub fn concat_batches<'a>( schema: &SchemaRef, input_batches: impl IntoIterator, ) -> Result { + // When schema is empty, sum the number of the rows of all batches + if schema.fields().is_empty() { + let num_rows: usize = input_batches.into_iter().map(RecordBatch::num_rows).sum(); + let mut options = RecordBatchOptions::default(); + options.row_count = Some(num_rows); + return RecordBatch::try_new_with_options(schema.clone(), vec![], &options); + } + let batches: Vec<&RecordBatch> = input_batches.into_iter().collect(); if batches.is_empty() { return Ok(RecordBatch::new_empty(schema.clone())); } - if let Some((i, _)) = batches - .iter() - .enumerate() - .find(|&(_, batch)| batch.schema() != *schema) - { - return Err(ArrowError::InvalidArgumentError(format!( - "batches[{i}] schema is different with argument schema. 
- batches[{i}] schema: {:?}, - argument schema: {:?} - ", - batches[i].schema(), - *schema - ))); - } let field_num = schema.fields().len(); let mut arrays = Vec::with_capacity(field_num); for i in 0..field_num { @@ -132,6 +198,7 @@ pub fn concat_batches<'a>( #[cfg(test)] mod tests { use super::*; + use arrow_array::builder::StringDictionaryBuilder; use arrow_array::cast::AsArray; use arrow_schema::{Field, Schema}; use std::sync::Arc; @@ -142,6 +209,21 @@ mod tests { assert!(re.is_err()); } + #[test] + fn test_concat_batches_no_columns() { + // Test concat using empty schema / batches without columns + let schema = Arc::new(Schema::empty()); + + let mut options = RecordBatchOptions::default(); + options.row_count = Some(100); + let batch = + RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap(); + // put in 2 batches of 100 rows each + let re = concat_batches(&schema, &[batch.clone(), batch]).unwrap(); + + assert_eq!(re.num_rows(), 200); + } + #[test] fn test_concat_one_element_vec() { let arr = Arc::new(PrimitiveArray::::from(vec![ @@ -315,10 +397,7 @@ mod tests { let array_result = concat(&[&list1_array, &list2_array, &list3_array]).unwrap(); - let expected = list1 - .into_iter() - .chain(list2.into_iter()) - .chain(list3.into_iter()); + let expected = list1.into_iter().chain(list2).chain(list3); let array_expected = ListArray::from_iter_primitive::(expected); assert_eq!(array_result.as_ref(), &array_expected as &dyn Array); @@ -448,29 +527,10 @@ mod tests { } fn collect_string_dictionary( - dictionary: &DictionaryArray, - ) -> Vec> { - let values = dictionary.values(); - let values = values.as_any().downcast_ref::().unwrap(); - - dictionary - .keys() - .iter() - .map(|key| key.map(|key| values.value(key as _).to_string())) - .collect() - } - - fn concat_dictionary( - input_1: DictionaryArray, - input_2: DictionaryArray, - ) -> Vec> { - let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap(); - let concat = concat - .as_any() - .downcast_ref::>() - .unwrap(); - - collect_string_dictionary(concat) + array: &DictionaryArray, + ) -> Vec> { + let concrete = array.downcast_dict::().unwrap(); + concrete.into_iter().collect() } #[test] @@ -489,11 +549,19 @@ mod tests { "E", ] .into_iter() - .map(|x| Some(x.to_string())) + .map(Some) .collect(); - let concat = concat_dictionary(input_1, input_2); - assert_eq!(concat, expected); + let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap(); + let dictionary = concat.as_dictionary::(); + let actual = collect_string_dictionary(dictionary); + assert_eq!(actual, expected); + + // Should have concatenated inputs together + assert_eq!( + dictionary.values().len(), + input_1.values().len() + input_2.values().len(), + ) } #[test] @@ -503,16 +571,45 @@ mod tests { .into_iter() .collect(); let input_2: DictionaryArray = vec![None].into_iter().collect(); - let expected = vec![ - Some("foo".to_string()), - Some("bar".to_string()), - None, - Some("fiz".to_string()), - None, - ]; + let expected = vec![Some("foo"), Some("bar"), None, Some("fiz"), None]; + + let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap(); + let dictionary = concat.as_dictionary::(); + let actual = collect_string_dictionary(dictionary); + assert_eq!(actual, expected); - let concat = concat_dictionary(input_1, input_2); - assert_eq!(concat, expected); + // Should have concatenated inputs together + assert_eq!( + dictionary.values().len(), + input_1.values().len() + input_2.values().len(), + ) + } + + #[test] + fn test_string_dictionary_merge() { + let 
mut builder = StringDictionaryBuilder::::new(); + for i in 0..20 { + builder.append(&i.to_string()).unwrap(); + } + let input_1 = builder.finish(); + + let mut builder = StringDictionaryBuilder::::new(); + for i in 0..30 { + builder.append(&i.to_string()).unwrap(); + } + let input_2 = builder.finish(); + + let expected: Vec<_> = (0..20).chain(0..30).map(|x| x.to_string()).collect(); + let expected: Vec<_> = expected.iter().map(|x| Some(x.as_str())).collect(); + + let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap(); + let dictionary = concat.as_dictionary::(); + let actual = collect_string_dictionary(dictionary); + assert_eq!(actual, expected); + + // Should have merged inputs together + // Not 30 as this is done on a best-effort basis + assert_eq!(dictionary.values().len(), 33) } #[test] @@ -536,7 +633,7 @@ mod tests { fn test_dictionary_concat_reuse() { let array: DictionaryArray = vec!["a", "a", "b", "c"].into_iter().collect(); - let copy: DictionaryArray = array.to_data().into(); + let copy: DictionaryArray = array.clone(); // dictionary is "a", "b", "c" assert_eq!( @@ -547,11 +644,7 @@ mod tests { // concatenate it with itself let combined = concat(&[© as _, &array as _]).unwrap(); - - let combined = combined - .as_any() - .downcast_ref::>() - .unwrap(); + let combined = combined.as_dictionary::(); assert_eq!( combined.values(), @@ -625,36 +718,45 @@ mod tests { } #[test] - fn concat_record_batches_of_different_schemas() { - let schema1 = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Utf8, false), - ])); - let schema2 = Arc::new(Schema::new(vec![ - Field::new("c", DataType::Int32, false), - Field::new("d", DataType::Utf8, false), - ])); + fn concat_record_batches_of_different_schemas_but_compatible_data() { + let schema1 = + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + // column names differ + let schema2 = + Arc::new(Schema::new(vec![Field::new("c", DataType::Int32, false)])); let batch1 = RecordBatch::try_new( schema1.clone(), - vec![ - Arc::new(Int32Array::from(vec![1, 2])), - Arc::new(StringArray::from(vec!["a", "b"])), - ], + vec![Arc::new(Int32Array::from(vec![1, 2]))], + ) + .unwrap(); + let batch2 = + RecordBatch::try_new(schema2, vec![Arc::new(Int32Array::from(vec![3, 4]))]) + .unwrap(); + // concat_batches simply uses the schema provided + let batch = concat_batches(&schema1, [&batch1, &batch2]).unwrap(); + assert_eq!(batch.schema().as_ref(), schema1.as_ref()); + assert_eq!(4, batch.num_rows()); + } + + #[test] + fn concat_record_batches_of_different_schemas_incompatible_data() { + let schema1 = + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + // column names differ + let schema2 = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, false)])); + let batch1 = RecordBatch::try_new( + schema1.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2]))], ) .unwrap(); let batch2 = RecordBatch::try_new( schema2, - vec![ - Arc::new(Int32Array::from(vec![3, 4])), - Arc::new(StringArray::from(vec!["c", "d"])), - ], + vec![Arc::new(StringArray::from(vec!["foo", "bar"]))], ) .unwrap(); + let error = concat_batches(&schema1, [&batch1, &batch2]).unwrap_err(); - assert_eq!( - error.to_string(), - "Invalid argument error: batches[1] schema is different with argument schema.\n batches[1] schema: Schema { fields: [Field { name: \"c\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"d\", data_type: Utf8, nullable: false, 
dict_id: 0, dict_is_ordered: false, metadata: {} }], metadata: {} },\n argument schema: Schema { fields: [Field { name: \"a\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"b\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }], metadata: {} }\n " - ); + assert_eq!(error.to_string(), "Invalid argument error: It is not possible to concatenate arrays of different data types."); } #[test] @@ -718,4 +820,16 @@ mod tests { assert_eq!(data.buffers()[1].len(), 200); assert_eq!(data.buffers()[1].capacity(), 256); // Nearest multiple of 64 } + + #[test] + fn concat_sparse_nulls() { + let values = StringArray::from_iter_values((0..100).map(|x| x.to_string())); + let keys = Int32Array::from(vec![1; 10]); + let dict_a = DictionaryArray::new(keys, Arc::new(values)); + let values = StringArray::new_null(0); + let keys = Int32Array::new_null(10); + let dict_b = DictionaryArray::new(keys, Arc::new(values)); + let array = concat(&[&dict_a, &dict_b]).unwrap(); + assert_eq!(array.null_count(), 10); + } } diff --git a/arrow-select/src/dictionary.rs b/arrow-select/src/dictionary.rs new file mode 100644 index 000000000000..330196ae33f4 --- /dev/null +++ b/arrow-select/src/dictionary.rs @@ -0,0 +1,347 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::interleave::interleave; +use ahash::RandomState; +use arrow_array::builder::BooleanBufferBuilder; +use arrow_array::cast::AsArray; +use arrow_array::types::{ + ArrowDictionaryKeyType, BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, + Utf8Type, +}; +use arrow_array::{Array, ArrayRef, DictionaryArray, GenericByteArray}; +use arrow_buffer::{ArrowNativeType, BooleanBuffer, ScalarBuffer}; +use arrow_schema::{ArrowError, DataType}; + +/// A best effort interner that maintains a fixed number of buckets +/// and interns keys based on their hash value +/// +/// Hash collisions will result in replacement +struct Interner<'a, V> { + state: RandomState, + buckets: Vec>, + shift: u32, +} + +impl<'a, V> Interner<'a, V> { + /// Capacity controls the number of unique buckets allocated within the Interner + /// + /// A larger capacity reduces the probability of hash collisions, and should be set + /// based on an approximation of the upper bound of unique values + fn new(capacity: usize) -> Self { + // Add additional buckets to help reduce collisions + let shift = (capacity as u64 + 128).leading_zeros(); + let num_buckets = (u64::MAX >> shift) as usize; + let buckets = (0..num_buckets.saturating_add(1)).map(|_| None).collect(); + Self { + // A fixed seed to ensure deterministic behaviour + state: RandomState::with_seeds(0, 0, 0, 0), + buckets, + shift, + } + } + + fn intern Result, E>( + &mut self, + new: &'a [u8], + f: F, + ) -> Result<&V, E> { + let hash = self.state.hash_one(new); + let bucket_idx = hash >> self.shift; + Ok(match &mut self.buckets[bucket_idx as usize] { + Some((current, v)) => { + if *current != new { + *v = f()?; + *current = new; + } + v + } + slot => &slot.insert((new, f()?)).1, + }) + } +} + +pub struct MergedDictionaries { + /// Provides `key_mappings[`array_idx`][`old_key`] -> new_key` + pub key_mappings: Vec>, + /// The new values + pub values: ArrayRef, +} + +/// Performs a cheap, pointer-based comparison of two byte array +/// +/// See [`ScalarBuffer::ptr_eq`] +fn bytes_ptr_eq(a: &dyn Array, b: &dyn Array) -> bool { + match (a.as_bytes_opt::(), b.as_bytes_opt::()) { + (Some(a), Some(b)) => { + let values_eq = + a.values().ptr_eq(b.values()) && a.offsets().ptr_eq(b.offsets()); + match (a.nulls(), b.nulls()) { + (Some(a), Some(b)) => values_eq && a.inner().ptr_eq(b.inner()), + (None, None) => values_eq, + _ => false, + } + } + _ => false, + } +} + +/// A type-erased function that compares two array for pointer equality +type PtrEq = dyn Fn(&dyn Array, &dyn Array) -> bool; + +/// A weak heuristic of whether to merge dictionary values that aims to only +/// perform the expensive merge computation when it is likely to yield at least +/// some return over the naive approach used by MutableArrayData +/// +/// `len` is the total length of the merged output +pub fn should_merge_dictionary_values( + dictionaries: &[&DictionaryArray], + len: usize, +) -> bool { + use DataType::*; + let first_values = dictionaries[0].values().as_ref(); + let ptr_eq: Box = match first_values.data_type() { + Utf8 => Box::new(bytes_ptr_eq::), + LargeUtf8 => Box::new(bytes_ptr_eq::), + Binary => Box::new(bytes_ptr_eq::), + LargeBinary => Box::new(bytes_ptr_eq::), + _ => return false, + }; + + let mut single_dictionary = true; + let mut total_values = first_values.len(); + for dict in dictionaries.iter().skip(1) { + let values = dict.values().as_ref(); + total_values += values.len(); + if single_dictionary { + single_dictionary = ptr_eq(first_values, values) + } + } + + let overflow = 
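// The interner above trades exactness for speed: byte slices hash into a fixed
// power-of-two table sized from `capacity + 128`, and a collision simply
// replaces the previous entry, which is why merged dictionary values are only
// unique on a best-effort basis. A standalone check of the sizing arithmetic,
// illustrative only and not the crate's API.
fn bucket_count(capacity: usize) -> usize {
    let shift = (capacity as u64 + 128).leading_zeros();
    (u64::MAX >> shift) as usize + 1
}

fn main() {
    assert_eq!(bucket_count(100), 256); // 100 + 128 = 228 rounds up to 2^8
    assert_eq!(bucket_count(1000), 2048); // 1128 rounds up to 2^11
}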
K::Native::from_usize(total_values).is_none(); + let values_exceed_length = total_values >= len; + + !single_dictionary && (overflow || values_exceed_length) +} + +/// Given an array of dictionaries and an optional key mask compute a values array +/// containing referenced values, along with mappings from the [`DictionaryArray`] +/// keys to the new keys within this values array. Best-effort will be made to ensure +/// that the dictionary values are unique +/// +/// This method is meant to be very fast and the output dictionary values +/// may not be unique, unlike `GenericByteDictionaryBuilder` which is slower +/// but produces unique values +pub fn merge_dictionary_values( + dictionaries: &[&DictionaryArray], + masks: Option<&[BooleanBuffer]>, +) -> Result, ArrowError> { + let mut num_values = 0; + + let mut values_arrays = Vec::with_capacity(dictionaries.len()); + let mut value_slices = Vec::with_capacity(dictionaries.len()); + + for (idx, dictionary) in dictionaries.iter().enumerate() { + let mask = masks.and_then(|m| m.get(idx)); + let key_mask = match (dictionary.logical_nulls(), mask) { + (Some(n), None) => Some(n.into_inner()), + (None, Some(n)) => Some(n.clone()), + (Some(n), Some(m)) => Some(n.inner() & m), + (None, None) => None, + }; + let keys = dictionary.keys().values(); + let values = dictionary.values().as_ref(); + let values_mask = compute_values_mask(keys, key_mask.as_ref(), values.len()); + + let masked_values = get_masked_values(values, &values_mask); + num_values += masked_values.len(); + value_slices.push(masked_values); + values_arrays.push(values) + } + + // Map from value to new index + let mut interner = Interner::new(num_values); + // Interleave indices for new values array + let mut indices = Vec::with_capacity(num_values); + + // Compute the mapping for each dictionary + let key_mappings = dictionaries + .iter() + .enumerate() + .zip(value_slices) + .map(|((dictionary_idx, dictionary), values)| { + let zero = K::Native::from_usize(0).unwrap(); + let mut mapping = vec![zero; dictionary.values().len()]; + + for (value_idx, value) in values { + mapping[value_idx] = *interner.intern(value, || { + match K::Native::from_usize(indices.len()) { + Some(idx) => { + indices.push((dictionary_idx, value_idx)); + Ok(idx) + } + None => Err(ArrowError::DictionaryKeyOverflowError), + } + })?; + } + Ok(mapping) + }) + .collect::, ArrowError>>()?; + + Ok(MergedDictionaries { + key_mappings, + values: interleave(&values_arrays, &indices)?, + }) +} + +/// Return a mask identifying the values that are referenced by keys in `dictionary` +/// at the positions indicated by `selection` +fn compute_values_mask( + keys: &ScalarBuffer, + mask: Option<&BooleanBuffer>, + max_key: usize, +) -> BooleanBuffer { + let mut builder = BooleanBufferBuilder::new(max_key); + builder.advance(max_key); + + match mask { + Some(n) => n + .set_indices() + .for_each(|idx| builder.set_bit(keys[idx].as_usize(), true)), + None => keys + .iter() + .for_each(|k| builder.set_bit(k.as_usize(), true)), + } + builder.finish() +} + +/// Return a Vec containing for each set index in `mask`, the index and byte value of that index +fn get_masked_values<'a>( + array: &'a dyn Array, + mask: &BooleanBuffer, +) -> Vec<(usize, &'a [u8])> { + match array.data_type() { + DataType::Utf8 => masked_bytes(array.as_string::(), mask), + DataType::LargeUtf8 => masked_bytes(array.as_string::(), mask), + DataType::Binary => masked_bytes(array.as_binary::(), mask), + DataType::LargeBinary => masked_bytes(array.as_binary::(), mask), + _ 
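// A plain-Rust sketch (not the crate API) of the heuristic above: merging is
// only attempted for byte-like value types, and only when the inputs do not
// already share a single dictionary and naive concatenation would either
// overflow the key type or produce at least as many dictionary values as
// output rows.
fn should_merge(
    shares_single_dictionary: bool,
    total_values: usize,
    output_len: usize,
    max_key: usize, // e.g. i32::MAX as usize for Int32 keys
) -> bool {
    let overflow = total_values > max_key;
    let values_exceed_length = total_values >= output_len;
    !shares_single_dictionary && (overflow || values_exceed_length)
}

fn main() {
    // Two distinct 20- and 30-value dictionaries concatenated into 50 rows:
    // 50 >= 50, so merging pays off (cf. test_string_dictionary_merge above)
    assert!(should_merge(false, 50, 50, i32::MAX as usize));
    // Concatenating an array with a copy that shares its dictionary is left alone
    assert!(!should_merge(true, 8, 100, i32::MAX as usize));
}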
=> unimplemented!(), + } +} + +/// Compute [`get_masked_values`] for a [`GenericByteArray`] +/// +/// Note: this does not check the null mask and will return values contained in null slots +fn masked_bytes<'a, T: ByteArrayType>( + array: &'a GenericByteArray, + mask: &BooleanBuffer, +) -> Vec<(usize, &'a [u8])> { + let mut out = Vec::with_capacity(mask.count_set_bits()); + for idx in mask.set_indices() { + out.push((idx, array.value(idx).as_ref())) + } + out +} + +#[cfg(test)] +mod tests { + use crate::dictionary::merge_dictionary_values; + use arrow_array::cast::as_string_array; + use arrow_array::types::Int32Type; + use arrow_array::{DictionaryArray, Int32Array, StringArray}; + use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer}; + use std::sync::Arc; + + #[test] + fn test_merge_strings() { + let a = + DictionaryArray::::from_iter(["a", "b", "a", "b", "d", "c", "e"]); + let b = DictionaryArray::::from_iter(["c", "f", "c", "d", "a", "d"]); + let merged = merge_dictionary_values(&[&a, &b], None).unwrap(); + + let values = as_string_array(merged.values.as_ref()); + let actual: Vec<_> = values.iter().map(Option::unwrap).collect(); + assert_eq!(&actual, &["a", "b", "d", "c", "e", "f"]); + + assert_eq!(merged.key_mappings.len(), 2); + assert_eq!(&merged.key_mappings[0], &[0, 1, 2, 3, 4]); + assert_eq!(&merged.key_mappings[1], &[3, 5, 2, 0]); + + let a_slice = a.slice(1, 4); + let merged = merge_dictionary_values(&[&a_slice, &b], None).unwrap(); + + let values = as_string_array(merged.values.as_ref()); + let actual: Vec<_> = values.iter().map(Option::unwrap).collect(); + assert_eq!(&actual, &["a", "b", "d", "c", "f"]); + + assert_eq!(merged.key_mappings.len(), 2); + assert_eq!(&merged.key_mappings[0], &[0, 1, 2, 0, 0]); + assert_eq!(&merged.key_mappings[1], &[3, 4, 2, 0]); + + // Mask out only ["b", "b", "d"] from a + let a_mask = + BooleanBuffer::from_iter([false, true, false, true, true, false, false]); + let b_mask = BooleanBuffer::new_set(b.len()); + let merged = merge_dictionary_values(&[&a, &b], Some(&[a_mask, b_mask])).unwrap(); + + let values = as_string_array(merged.values.as_ref()); + let actual: Vec<_> = values.iter().map(Option::unwrap).collect(); + assert_eq!(&actual, &["b", "d", "c", "f", "a"]); + + assert_eq!(merged.key_mappings.len(), 2); + assert_eq!(&merged.key_mappings[0], &[0, 0, 1, 0, 0]); + assert_eq!(&merged.key_mappings[1], &[2, 3, 1, 4]); + } + + #[test] + fn test_merge_nulls() { + let buffer = Buffer::from("helloworldbingohelloworld"); + let offsets = OffsetBuffer::from_lengths([5, 5, 5, 5, 5]); + let nulls = NullBuffer::from(vec![true, false, true, true, true]); + let values = StringArray::new(offsets, buffer, Some(nulls)); + + let key_values = vec![1, 2, 3, 1, 8, 2, 3]; + let key_nulls = + NullBuffer::from(vec![true, true, false, true, false, true, true]); + let keys = Int32Array::new(key_values.into(), Some(key_nulls)); + let a = DictionaryArray::new(keys, Arc::new(values)); + // [NULL, "bingo", NULL, NULL, NULL, "bingo", "hello"] + + let b = DictionaryArray::new( + Int32Array::new_null(10), + Arc::new(StringArray::new_null(0)), + ); + + let merged = merge_dictionary_values(&[&a, &b], None).unwrap(); + let expected = StringArray::from(vec!["bingo", "hello"]); + assert_eq!(merged.values.as_ref(), &expected); + assert_eq!(merged.key_mappings.len(), 2); + assert_eq!(&merged.key_mappings[0], &[0, 0, 0, 1, 0]); + assert_eq!(&merged.key_mappings[1], &[]); + } + + #[test] + fn test_merge_keys_smaller() { + let values = StringArray::from_iter_values(["a", 
"b"]); + let keys = Int32Array::from_iter_values([1]); + let a = DictionaryArray::new(keys, Arc::new(values)); + + let merged = merge_dictionary_values(&[&a], None).unwrap(); + let expected = StringArray::from(vec!["b"]); + assert_eq!(merged.values.as_ref(), &expected); + } +} diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs index c89491944a21..1afb8197bab6 100644 --- a/arrow-select/src/filter.rs +++ b/arrow-select/src/filter.rs @@ -187,8 +187,8 @@ pub fn filter_record_batch( .iter() .map(|a| filter_array(a, &filter)) .collect::, _>>()?; - - RecordBatch::try_new(record_batch.schema(), filtered_arrays) + let options = RecordBatchOptions::default().with_row_count(Some(filter.count())); + RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options) } /// A builder to construct [`FilterPredicate`] @@ -301,6 +301,11 @@ impl FilterPredicate { pub fn filter(&self, values: &dyn Array) -> Result { filter_array(values, self) } + + /// Number of rows being selected based on this [`FilterPredicate`] + pub fn count(&self) -> usize { + self.count + } } fn filter_array( @@ -321,16 +326,6 @@ fn filter_array( // actually filter _ => downcast_primitive_array! { values => Ok(Arc::new(filter_primitive(values, predicate))), - DataType::Decimal128(p, s) => { - let values = values.as_any().downcast_ref::().unwrap(); - let filtered = filter_primitive(values, predicate); - Ok(Arc::new(filtered.with_precision_and_scale(*p, *s).unwrap())) - } - DataType::Decimal256(p, s) => { - let values = values.as_any().downcast_ref::().unwrap(); - let filtered = filter_primitive(values, predicate); - Ok(Arc::new(filtered.with_precision_and_scale(*p, *s).unwrap())) - } DataType::Boolean => { let values = values.as_any().downcast_ref::().unwrap(); Ok(Arc::new(filter_boolean(values, predicate))) @@ -873,7 +868,7 @@ mod tests { #[test] fn test_filter_dictionary_array() { - let values = vec![Some("hello"), None, Some("world"), Some("!")]; + let values = [Some("hello"), None, Some("world"), Some("!")]; let a: Int8DictionaryArray = values.iter().copied().collect(); let b = BooleanArray::from(vec![false, true, true, false]); let c = filter(&a, &b).unwrap(); @@ -987,6 +982,21 @@ mod tests { assert_eq!(out.as_ref(), &a.slice(0, 2)); } + #[test] + fn test_filter_record_batch_no_columns() { + let pred = BooleanArray::from(vec![Some(true), Some(true), None]); + let options = RecordBatchOptions::default().with_row_count(Some(100)); + let record_batch = RecordBatch::try_new_with_options( + Arc::new(Schema::empty()), + vec![], + &options, + ) + .unwrap(); + let out = filter_record_batch(&record_batch, &pred).unwrap(); + + assert_eq!(out.num_rows(), 2); + } + #[test] fn test_fast_path() { let a: PrimitiveArray = diff --git a/arrow-select/src/interleave.rs b/arrow-select/src/interleave.rs index c0d2026808af..a0f41666513b 100644 --- a/arrow-select/src/interleave.rs +++ b/arrow-select/src/interleave.rs @@ -15,12 +15,15 @@ // specific language governing permissions and limitations // under the License. 
-use arrow_array::builder::{BooleanBufferBuilder, BufferBuilder}; +use crate::dictionary::{merge_dictionary_values, should_merge_dictionary_values}; +use arrow_array::builder::{BooleanBufferBuilder, BufferBuilder, PrimitiveBuilder}; +use arrow_array::cast::AsArray; use arrow_array::types::*; use arrow_array::*; -use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer}; +use arrow_buffer::{ + ArrowNativeType, MutableBuffer, NullBuffer, NullBufferBuilder, OffsetBuffer, +}; use arrow_data::transform::MutableArrayData; -use arrow_data::ArrayDataBuilder; use arrow_schema::{ArrowError, DataType}; use std::sync::Arc; @@ -30,6 +33,12 @@ macro_rules! primitive_helper { }; } +macro_rules! dict_helper { + ($t:ty, $values:expr, $indices:expr) => { + Ok(Arc::new(interleave_dictionaries::<$t>($values, $indices)?) as _) + }; +} + /// /// Takes elements by index from a list of [`Array`], creating a new [`Array`] from those values. /// @@ -87,6 +96,10 @@ pub fn interleave( DataType::LargeUtf8 => interleave_bytes::(values, indices), DataType::Binary => interleave_bytes::(values, indices), DataType::LargeBinary => interleave_bytes::(values, indices), + DataType::Dictionary(k, _) => downcast_integer! { + k.as_ref() => (dict_helper, values, indices), + _ => unreachable!("illegal dictionary key type {k}") + }, _ => interleave_fallback(values, indices) } } @@ -97,10 +110,8 @@ pub fn interleave( struct Interleave<'a, T> { /// The input arrays downcast to T arrays: Vec<&'a T>, - /// The number of nulls in the interleaved output - null_count: usize, /// The null buffer of the interleaved output - nulls: Option, + nulls: Option, } impl<'a, T: Array + 'static> Interleave<'a, T> { @@ -114,22 +125,19 @@ impl<'a, T: Array + 'static> Interleave<'a, T> { }) .collect(); - let mut null_count = 0; - let nulls = has_nulls.then(|| { - let mut builder = BooleanBufferBuilder::new(indices.len()); - for (a, b) in indices { - let v = arrays[*a].is_valid(*b); - null_count += !v as usize; - builder.append(v) + let nulls = match has_nulls { + true => { + let mut builder = NullBufferBuilder::new(indices.len()); + for (a, b) in indices { + let v = arrays[*a].is_valid(*b); + builder.append(v) + } + builder.finish() } - builder.into() - }); + false => None, + }; - Self { - arrays, - null_count, - nulls, - } + Self { arrays, nulls } } } @@ -140,20 +148,14 @@ fn interleave_primitive( ) -> Result { let interleaved = Interleave::<'_, PrimitiveArray>::new(values, indices); - let mut values = BufferBuilder::::new(indices.len()); + let mut values = Vec::with_capacity(indices.len()); for (a, b) in indices { let v = interleaved.arrays[*a].value(*b); - values.append(v) + values.push(v) } - let builder = ArrayDataBuilder::new(data_type.clone()) - .len(indices.len()) - .add_buffer(values.finish()) - .null_bit_buffer(interleaved.nulls) - .null_count(interleaved.null_count); - - let data = unsafe { builder.build_unchecked() }; - Ok(Arc::new(PrimitiveArray::::from(data))) + let array = PrimitiveArray::::new(values.into(), interleaved.nulls); + Ok(Arc::new(array.with_data_type(data_type.clone()))) } fn interleave_bytes( @@ -177,15 +179,55 @@ fn interleave_bytes( values.extend_from_slice(interleaved.arrays[*a].value(*b).as_ref()); } - let builder = ArrayDataBuilder::new(T::DATA_TYPE) - .len(indices.len()) - .add_buffer(offsets.finish()) - .add_buffer(values.into()) - .null_bit_buffer(interleaved.nulls) - .null_count(interleaved.null_count); + // Safety: safe by construction + let array = unsafe { + let offsets = 
OffsetBuffer::new_unchecked(offsets.finish().into()); + GenericByteArray::::new_unchecked(offsets, values.into(), interleaved.nulls) + }; + Ok(Arc::new(array)) +} + +fn interleave_dictionaries( + arrays: &[&dyn Array], + indices: &[(usize, usize)], +) -> Result { + let dictionaries: Vec<_> = arrays.iter().map(|x| x.as_dictionary::()).collect(); + if !should_merge_dictionary_values::(&dictionaries, indices.len()) { + return interleave_fallback(arrays, indices); + } + + let masks: Vec<_> = dictionaries + .iter() + .enumerate() + .map(|(a_idx, dictionary)| { + let mut key_mask = BooleanBufferBuilder::new_from_buffer( + MutableBuffer::new_null(dictionary.len()), + dictionary.len(), + ); + + for (_, key_idx) in indices.iter().filter(|(a, _)| *a == a_idx) { + key_mask.set_bit(*key_idx, true); + } + key_mask.finish() + }) + .collect(); + + let merged = merge_dictionary_values(&dictionaries, Some(&masks))?; - let data = unsafe { builder.build_unchecked() }; - Ok(Arc::new(GenericByteArray::::from(data))) + // Recompute keys + let mut keys = PrimitiveBuilder::::with_capacity(indices.len()); + for (a, b) in indices { + let old_keys: &PrimitiveArray = dictionaries[*a].keys(); + match old_keys.is_valid(*b) { + true => { + let old_key = old_keys.values()[*b]; + keys.append_value(merged.key_mappings[*a][old_key.as_usize()]) + } + false => keys.append_null(), + } + } + let array = unsafe { DictionaryArray::new_unchecked(keys.finish(), merged.values) }; + Ok(Arc::new(array)) } /// Fallback implementation of interleave using [`MutableArrayData`] @@ -280,6 +322,32 @@ mod tests { ) } + #[test] + fn test_interleave_dictionary() { + let a = DictionaryArray::::from_iter(["a", "b", "c", "a", "b"]); + let b = DictionaryArray::::from_iter(["a", "c", "a", "c", "a"]); + + // Should not recompute dictionary + let values = + interleave(&[&a, &b], &[(0, 2), (0, 2), (0, 2), (1, 0), (1, 1), (0, 1)]) + .unwrap(); + let v = values.as_dictionary::(); + assert_eq!(v.values().len(), 5); + + let vc = v.downcast_dict::().unwrap(); + let collected: Vec<_> = vc.into_iter().map(Option::unwrap).collect(); + assert_eq!(&collected, &["c", "c", "c", "a", "c", "b"]); + + // Should recompute dictionary + let values = interleave(&[&a, &b], &[(0, 2), (0, 2), (1, 1)]).unwrap(); + let v = values.as_dictionary::(); + assert_eq!(v.values().len(), 1); + + let vc = v.downcast_dict::().unwrap(); + let collected: Vec<_> = vc.into_iter().map(Option::unwrap).collect(); + assert_eq!(&collected, &["c", "c", "c"]); + } + #[test] fn test_lists() { // [[1, 2], null, [3]] @@ -323,4 +391,25 @@ mod tests { assert_eq!(v, &expected); } + + #[test] + fn interleave_sparse_nulls() { + let values = StringArray::from_iter_values((0..100).map(|x| x.to_string())); + let keys = Int32Array::from_iter_values(0..10); + let dict_a = DictionaryArray::new(keys, Arc::new(values)); + let values = StringArray::new_null(0); + let keys = Int32Array::new_null(10); + let dict_b = DictionaryArray::new(keys, Arc::new(values)); + + let indices = &[(0, 0), (0, 1), (0, 2), (1, 0)]; + let array = interleave(&[&dict_a, &dict_b], indices).unwrap(); + + let expected = DictionaryArray::::from_iter(vec![ + Some("0"), + Some("1"), + Some("2"), + None, + ]); + assert_eq!(array.as_ref(), &expected) + } } diff --git a/arrow-select/src/lib.rs b/arrow-select/src/lib.rs index c468e20a511e..82f57a6af42b 100644 --- a/arrow-select/src/lib.rs +++ b/arrow-select/src/lib.rs @@ -18,6 +18,7 @@ //! 
Arrow selection kernels pub mod concat; +mod dictionary; pub mod filter; pub mod interleave; pub mod nullif; diff --git a/arrow-select/src/nullif.rs b/arrow-select/src/nullif.rs index ab68e8c2f097..f0bcb73cccb9 100644 --- a/arrow-select/src/nullif.rs +++ b/arrow-select/src/nullif.rs @@ -18,7 +18,7 @@ use arrow_array::{make_array, Array, ArrayRef, BooleanArray}; use arrow_buffer::buffer::{bitwise_bin_op_helper, bitwise_unary_op_helper}; use arrow_buffer::{BooleanBuffer, NullBuffer}; -use arrow_schema::ArrowError; +use arrow_schema::{ArrowError, DataType}; /// Copies original array, setting validity bit to false if a secondary comparison /// boolean array is set to true @@ -35,7 +35,7 @@ pub fn nullif(left: &dyn Array, right: &BooleanArray) -> Result( +pub fn take( values: &dyn Array, - indices: &PrimitiveArray, + indices: &dyn Array, options: Option, ) -> Result { - take_impl(values, indices, options) + let options = options.unwrap_or_default(); + macro_rules! helper { + ($t:ty, $values:expr, $indices:expr, $options:expr) => {{ + let indices = indices.as_primitive::<$t>(); + if $options.check_bounds { + check_bounds($values.len(), indices)?; + } + let indices = indices.to_indices(); + take_impl($values, &indices) + }}; + } + downcast_integer! { + indices.data_type() => (helper, values, indices, options), + d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}"))) + } } +/// Verifies that the non-null values of `indices` are all `< len` +fn check_bounds( + len: usize, + indices: &PrimitiveArray, +) -> Result<(), ArrowError> { + if indices.null_count() > 0 { + indices.iter().flatten().try_for_each(|index| { + let ix = index.to_usize().ok_or_else(|| { + ArrowError::ComputeError("Cast to usize failed".to_string()) + })?; + if ix >= len { + return Err(ArrowError::ComputeError( + format!("Array index out of bounds, cannot get item at index {ix} from {len} entries")) + ); + } + Ok(()) + }) + } else { + indices.values().iter().try_for_each(|index| { + let ix = index.to_usize().ok_or_else(|| { + ArrowError::ComputeError("Cast to usize failed".to_string()) + })?; + if ix >= len { + return Err(ArrowError::ComputeError( + format!("Array index out of bounds, cannot get item at index {ix} from {len} entries")) + ); + } + Ok(()) + }) + } +} + +#[inline(never)] fn take_impl( values: &dyn Array, indices: &PrimitiveArray, - options: Option, ) -> Result { - let options = options.unwrap_or_default(); - if options.check_bounds { - let len = values.len(); - if indices.null_count() > 0 { - indices.iter().flatten().try_for_each(|index| { - let ix = index.to_usize().ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; - if ix >= len { - return Err(ArrowError::ComputeError( - format!("Array index out of bounds, cannot get item at index {ix} from {len} entries")) - ); - } - Ok(()) - })?; - } else { - indices.values().iter().try_for_each(|index| { - let ix = index.to_usize().ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; - if ix >= len { - return Err(ArrowError::ComputeError( - format!("Array index out of bounds, cannot get item at index {ix} from {len} entries")) - ); - } - Ok(()) - })? - } - } - downcast_primitive_array! 
{ values => Ok(Arc::new(take_primitive(values, indices)?)), DataType::Boolean => { @@ -156,7 +172,7 @@ fn take_impl( let arrays = array .columns() .iter() - .map(|a| take_impl(a.as_ref(), indices, Some(options.clone()))) + .map(|a| take_impl(a.as_ref(), indices)) .collect::, _>>()?; let fields: Vec<(FieldRef, ArrayRef)> = fields.iter().cloned().zip(arrays).collect(); @@ -207,6 +223,21 @@ fn take_impl( Ok(new_null_array(&DataType::Null, indices.len())) } } + DataType::Union(fields, UnionMode::Sparse) => { + let mut field_type_ids = Vec::with_capacity(fields.len()); + let mut children = Vec::with_capacity(fields.len()); + let values = values.as_any().downcast_ref::().unwrap(); + let type_ids = take_native(values.type_ids(), indices).into_inner(); + for (type_id, field) in fields.iter() { + let values = values.child(type_id); + let values = take_impl(values, indices)?; + let field = (**field).clone(); + children.push((field, values)); + field_type_ids.push(type_id); + } + let array = UnionArray::try_new(field_type_ids.as_slice(), type_ids, None, children)?; + Ok(Arc::new(array)) + } t => unimplemented!("Take not supported for data type {:?}", t) } } @@ -331,94 +362,70 @@ fn take_bytes( let data_len = indices.len(); let bytes_offset = (data_len + 1) * std::mem::size_of::(); - let mut offsets_buffer = MutableBuffer::from_len_zeroed(bytes_offset); + let mut offsets = MutableBuffer::new(bytes_offset); + offsets.push(T::Offset::default()); - let offsets = offsets_buffer.typed_data_mut(); let mut values = MutableBuffer::new(0); - let mut length_so_far = T::Offset::from_usize(0).unwrap(); - offsets[0] = length_so_far; let nulls; if array.null_count() == 0 && indices.null_count() == 0 { - for (i, offset) in offsets.iter_mut().skip(1).enumerate() { - let index = indices.value(i).to_usize().ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; - - let s = array.value(index); - - let s: &[u8] = s.as_ref(); - length_so_far += T::Offset::from_usize(s.len()).unwrap(); + offsets.extend(indices.values().iter().map(|index| { + let s: &[u8] = array.value(index.as_usize()).as_ref(); values.extend_from_slice(s); - *offset = length_so_far; - } + T::Offset::usize_as(values.len()) + })); nulls = None } else if indices.null_count() == 0 { let num_bytes = bit_util::ceil(data_len, 8); let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true); let null_slice = null_buf.as_slice_mut(); - - for (i, offset) in offsets.iter_mut().skip(1).enumerate() { - let index = indices.value(i).to_usize().ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; - + offsets.extend(indices.values().iter().enumerate().map(|(i, index)| { + let index = index.as_usize(); if array.is_valid(index) { let s: &[u8] = array.value(index).as_ref(); - - length_so_far += T::Offset::from_usize(s.len()).unwrap(); values.extend_from_slice(s.as_ref()); } else { bit_util::unset_bit(null_slice, i); } - *offset = length_so_far; - } + T::Offset::usize_as(values.len()) + })); nulls = Some(null_buf.into()); } else if array.null_count() == 0 { - for (i, offset) in offsets.iter_mut().skip(1).enumerate() { + offsets.extend(indices.values().iter().enumerate().map(|(i, index)| { if indices.is_valid(i) { - let index = indices.value(i).to_usize().ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; - - let s: &[u8] = array.value(index).as_ref(); - - length_so_far += T::Offset::from_usize(s.len()).unwrap(); + let s: &[u8] = 
array.value(index.as_usize()).as_ref(); values.extend_from_slice(s); } - *offset = length_so_far; - } + T::Offset::usize_as(values.len()) + })); nulls = indices.nulls().map(|b| b.inner().sliced()); } else { let num_bytes = bit_util::ceil(data_len, 8); let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true); let null_slice = null_buf.as_slice_mut(); - - for (i, offset) in offsets.iter_mut().skip(1).enumerate() { - let index = indices.value(i).to_usize().ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; - - if array.is_valid(index) && indices.is_valid(i) { + offsets.extend(indices.values().iter().enumerate().map(|(i, index)| { + // check index is valid before using index. The value in + // NULL index slots may not be within bounds of array + let index = index.as_usize(); + if indices.is_valid(i) && array.is_valid(index) { let s: &[u8] = array.value(index).as_ref(); - - length_so_far += T::Offset::from_usize(s.len()).unwrap(); values.extend_from_slice(s); } else { // set null bit bit_util::unset_bit(null_slice, i); } - *offset = length_so_far; - } - + T::Offset::usize_as(values.len()) + })); nulls = Some(null_buf.into()) } + T::Offset::from_usize(values.len()).expect("offset overflow"); + let array_data = ArrayData::builder(T::DATA_TYPE) .len(data_len) - .add_buffer(offsets_buffer.into()) + .add_buffer(offsets.into()) .add_buffer(values.into()) .null_bit_buffer(nulls); @@ -447,7 +454,7 @@ where let (list_indices, offsets, null_buf) = take_value_indices_from_list::(values, indices)?; - let taken = take_impl::(values.values().as_ref(), &list_indices, None)?; + let taken = take_impl::(values.values().as_ref(), &list_indices)?; let value_offsets = Buffer::from_vec(offsets); // create a new list with taken data and computed null information let list_data = ArrayDataBuilder::new(values.data_type().clone()) @@ -473,7 +480,7 @@ fn take_fixed_size_list( length: ::Native, ) -> Result { let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?; - let taken = take_impl::(values.values().as_ref(), &list_indices, None)?; + let taken = take_impl::(values.values().as_ref(), &list_indices)?; // determine null count and null buffer, which are a function of `values` and `indices` let num_bytes = bit_util::ceil(indices.len(), 8); @@ -700,6 +707,65 @@ where Ok(PrimitiveArray::::from(values)) } +/// To avoid generating take implementations for every index type, instead we +/// only generate for UInt32 and UInt64 and coerce inputs to these types +trait ToIndices { + type T: ArrowPrimitiveType; + + fn to_indices(&self) -> PrimitiveArray; +} + +macro_rules! to_indices_reinterpret { + ($t:ty, $o:ty) => { + impl ToIndices for PrimitiveArray<$t> { + type T = $o; + + fn to_indices(&self) -> PrimitiveArray<$o> { + let cast = + ScalarBuffer::new(self.values().inner().clone(), 0, self.len()); + PrimitiveArray::new(cast, self.nulls().cloned()) + } + } + }; +} + +macro_rules! to_indices_identity { + ($t:ty) => { + impl ToIndices for PrimitiveArray<$t> { + type T = $t; + + fn to_indices(&self) -> PrimitiveArray<$t> { + self.clone() + } + } + }; +} + +macro_rules! 
to_indices_widening { + ($t:ty, $o:ty) => { + impl ToIndices for PrimitiveArray<$t> { + type T = UInt32Type; + + fn to_indices(&self) -> PrimitiveArray<$o> { + let cast = self.values().iter().copied().map(|x| x as _).collect(); + PrimitiveArray::new(cast, self.nulls().cloned()) + } + } + }; +} + +to_indices_widening!(UInt8Type, UInt32Type); +to_indices_widening!(Int8Type, UInt32Type); + +to_indices_widening!(UInt16Type, UInt32Type); +to_indices_widening!(Int16Type, UInt32Type); + +to_indices_identity!(UInt32Type); +to_indices_reinterpret!(Int32Type, UInt32Type); + +to_indices_identity!(UInt64Type); +to_indices_reinterpret!(Int64Type, UInt64Type); + #[cfg(test)] mod tests { use super::*; @@ -791,7 +857,7 @@ mod tests { { let output = PrimitiveArray::::from(data); let expected = PrimitiveArray::::from(expected_data); - let output = take_impl(&output, index, options).unwrap(); + let output = take(&output, index, options).unwrap(); let output = output.as_any().downcast_ref::>().unwrap(); assert_eq!(output, &expected) } @@ -1102,7 +1168,7 @@ mod tests { 1_639_715_368_000_000_000, ]) .with_timezone("UTC".to_string()); - let result = take_impl(&input, &index, None).unwrap(); + let result = take(&input, &index, None).unwrap(); match result.data_type() { DataType::Timestamp(TimeUnit::Nanosecond, tz) => { assert_eq!(tz.clone(), Some("UTC".into())) @@ -1937,6 +2003,7 @@ mod tests { #[test] fn test_take_null_indices() { + // Build indices with values that are out of bounds, but masked by null mask let indices = Int32Array::new( vec![1, 2, 400, 400].into(), Some(NullBuffer::from(vec![true, true, false, false])), @@ -1949,4 +2016,53 @@ mod tests { .collect::>(); assert_eq!(&values, &[Some(23), Some(4), None, None]) } + + #[test] + fn test_take_bytes_null_indices() { + let indices = Int32Array::new( + vec![0, 1, 400, 400].into(), + Some(NullBuffer::from_iter(vec![true, true, false, false])), + ); + let values = StringArray::from(vec![Some("foo"), None]); + let r = take(&values, &indices, None).unwrap(); + let values = r.as_string::().iter().collect::>(); + assert_eq!(&values, &[Some("foo"), None, None, None]) + } + + #[test] + fn test_take_union() { + let structs = create_test_struct(vec![ + Some((Some(true), Some(42))), + Some((Some(false), Some(28))), + Some((Some(false), Some(19))), + Some((Some(true), Some(31))), + None, + ]); + let strings = + StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]); + let type_ids = Buffer::from_slice_ref(vec![1i8; 5]); + + let children: Vec<(Field, Arc)> = vec![ + ( + Field::new("f1", structs.data_type().clone(), true), + Arc::new(structs), + ), + ( + Field::new("f2", strings.data_type().clone(), true), + Arc::new(strings), + ), + ]; + let array = UnionArray::try_new(&[0, 1], type_ids, None, children).unwrap(); + + let indices = vec![0, 3, 1, 0, 2, 4]; + let index = UInt32Array::from(indices.clone()); + let actual = take(&array, &index, None).unwrap(); + let actual = actual.as_any().downcast_ref::().unwrap(); + let strings = actual.child(1); + let strings = strings.as_any().downcast_ref::().unwrap(); + + let actual = strings.iter().collect::>(); + let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")]; + assert_eq!(expected, actual); + } } diff --git a/arrow-string/Cargo.toml b/arrow-string/Cargo.toml index 0f88ffbac923..1ae7af8bdf41 100644 --- a/arrow-string/Cargo.toml +++ b/arrow-string/Cargo.toml @@ -40,11 +40,5 @@ arrow-schema = { workspace = true } arrow-array = { workspace = true } arrow-select = { workspace = true } regex = 
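For context on the reworked `take` entry point above: the index array is now passed as `&dyn Array` and coerced internally through `ToIndices`, so callers no longer supply a generic index type parameter. A minimal usage sketch, assuming the `arrow-array` and `arrow-select` crates at this revision (the sample data is illustrative, not from the PR):

```rust
use arrow_array::cast::AsArray;
use arrow_array::{Array, Int64Array, StringArray, UInt32Array};
use arrow_select::take::take;

fn main() {
    let values = StringArray::from(vec![Some("a"), None, Some("c"), Some("d")]);

    // Unsigned 32-bit indices use the identity path of `ToIndices`
    let indices = UInt32Array::from(vec![3, 0, 0]);
    let taken = take(&values, &indices, None).unwrap();
    assert_eq!(taken.as_string::<i32>().value(0), "d");

    // Signed 64-bit indices are reinterpreted as u64; no turbofish needed
    let indices = Int64Array::from(vec![1, 2]);
    let taken = take(&values, &indices, None).unwrap();
    assert!(taken.is_null(0)); // values[1] is null
    assert_eq!(taken.as_string::<i32>().value(1), "c");
}
```

Bounds checking remains opt-in through `TakeOptions`, and, as the new tests note, out-of-bounds values that sit behind a null in the index array are tolerated.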
{ version = "1.7.0", default-features = false, features = ["std", "unicode", "perf"] } -regex-syntax = { version = "0.7.1", default-features = false, features = ["unicode"] } +regex-syntax = { version = "0.8.0", default-features = false, features = ["unicode"] } num = { version = "0.4", default-features = false, features = ["std"] } - -[package.metadata.docs.rs] -all-features = true - -[features] -dyn_cmp_dict = [] diff --git a/arrow-string/src/length.rs b/arrow-string/src/length.rs index 90efdd7b67cc..ab5fbb0c6425 100644 --- a/arrow-string/src/length.rs +++ b/arrow-string/src/length.rs @@ -17,161 +17,74 @@ //! Defines kernel for length of string arrays and binary arrays -use arrow_array::types::*; use arrow_array::*; -use arrow_buffer::Buffer; -use arrow_data::ArrayData; +use arrow_array::{cast::AsArray, types::*}; +use arrow_buffer::{ArrowNativeType, NullBuffer, OffsetBuffer}; use arrow_schema::{ArrowError, DataType}; use std::sync::Arc; -macro_rules! unary_offsets { - ($array: expr, $data_type: expr, $op: expr) => {{ - let slice = $array.value_offsets(); - - let lengths = slice.windows(2).map(|offset| $op(offset[1] - offset[0])); - - // JUSTIFICATION - // Benefit - // ~60% speedup - // Soundness - // `values` come from a slice iterator with a known size. - let buffer = unsafe { Buffer::from_trusted_len_iter(lengths) }; - - let null_bit_buffer = $array.nulls().map(|b| b.inner().sliced()); - - let data = unsafe { - ArrayData::new_unchecked( - $data_type, - $array.len(), - None, - null_bit_buffer, - 0, - vec![buffer], - vec![], - ) - }; - make_array(data) - }}; +fn length_impl( + offsets: &OffsetBuffer, + nulls: Option<&NullBuffer>, +) -> ArrayRef { + let v: Vec<_> = offsets + .windows(2) + .map(|w| w[1].sub_wrapping(w[0])) + .collect(); + Arc::new(PrimitiveArray::
<P>
::new(v.into(), nulls.cloned())) } -macro_rules! kernel_dict { - ($array: ident, $kernel: expr, $kt: ident, $($t: ident: $gt: ident), *) => { - match $kt.as_ref() { - $(&DataType::$t => { - let dict = $array - .as_any() - .downcast_ref::>() - .unwrap_or_else(|| { - panic!("Expect 'DictionaryArray<{}>' but got array of data type {:?}", - stringify!($gt), $array.data_type()) - }); - let values = $kernel(dict.values())?; - let result = DictionaryArray::try_new(dict.keys().clone(), values)?; - Ok(Arc::new(result)) - }, - )* - t => panic!("Unsupported dictionary key type: {}", t) - } - } -} - -fn length_list(array: &dyn Array) -> ArrayRef -where - O: OffsetSizeTrait, - T: ArrowPrimitiveType, - T::Native: OffsetSizeTrait, -{ - let array = array - .as_any() - .downcast_ref::>() - .unwrap(); - unary_offsets!(array, T::DATA_TYPE, |x| x) -} - -fn length_binary(array: &dyn Array) -> ArrayRef -where - O: OffsetSizeTrait, - T: ArrowPrimitiveType, - T::Native: OffsetSizeTrait, -{ - let array = array - .as_any() - .downcast_ref::>() - .unwrap(); - unary_offsets!(array, T::DATA_TYPE, |x| x) -} - -fn length_string(array: &dyn Array) -> ArrayRef -where - O: OffsetSizeTrait, - T: ArrowPrimitiveType, - T::Native: OffsetSizeTrait, -{ - let array = array - .as_any() - .downcast_ref::>() - .unwrap(); - unary_offsets!(array, T::DATA_TYPE, |x| x) -} - -fn bit_length_binary(array: &dyn Array) -> ArrayRef -where - O: OffsetSizeTrait, - T: ArrowPrimitiveType, - T::Native: OffsetSizeTrait, -{ - let array = array - .as_any() - .downcast_ref::>() - .unwrap(); - let bits_in_bytes = O::from_usize(8).unwrap(); - unary_offsets!(array, T::DATA_TYPE, |x| x * bits_in_bytes) -} - -fn bit_length_string(array: &dyn Array) -> ArrayRef -where - O: OffsetSizeTrait, - T: ArrowPrimitiveType, - T::Native: OffsetSizeTrait, -{ - let array = array - .as_any() - .downcast_ref::>() - .unwrap(); - let bits_in_bytes = O::from_usize(8).unwrap(); - unary_offsets!(array, T::DATA_TYPE, |x| x * bits_in_bytes) +fn bit_length_impl( + offsets: &OffsetBuffer, + nulls: Option<&NullBuffer>, +) -> ArrayRef { + let bits = P::Native::usize_as(8); + let c = |w: &[P::Native]| w[1].sub_wrapping(w[0]).mul_wrapping(bits); + let v: Vec<_> = offsets.windows(2).map(c).collect(); + Arc::new(PrimitiveArray::
<P>
::new(v.into(), nulls.cloned())) } /// Returns an array of Int32/Int64 denoting the length of each value in the array. /// For list array, length is the number of elements in each list. /// For string array and binary array, length is the number of bytes of each value. /// -/// * this only accepts ListArray/LargeListArray, StringArray/LargeStringArray and BinaryArray/LargeBinaryArray, +/// * this only accepts ListArray/LargeListArray, StringArray/LargeStringArray, BinaryArray/LargeBinaryArray, and FixedSizeListArray, /// or DictionaryArray with above Arrays as values /// * length of null is null. pub fn length(array: &dyn Array) -> Result { + if let Some(d) = array.as_any_dictionary_opt() { + let lengths = length(d.values().as_ref())?; + return Ok(d.with_values(lengths)); + } + match array.data_type() { - DataType::Dictionary(kt, _) => { - kernel_dict!( - array, - |a| { length(a) }, - kt, - Int8: Int8Type, - Int16: Int16Type, - Int32: Int32Type, - Int64: Int64Type, - UInt8: UInt8Type, - UInt16: UInt16Type, - UInt32: UInt32Type, - UInt64: UInt64Type - ) + DataType::List(_) => { + let list = array.as_list::(); + Ok(length_impl::(list.offsets(), list.nulls())) + } + DataType::LargeList(_) => { + let list = array.as_list::(); + Ok(length_impl::(list.offsets(), list.nulls())) + } + DataType::Utf8 => { + let list = array.as_string::(); + Ok(length_impl::(list.offsets(), list.nulls())) + } + DataType::LargeUtf8 => { + let list = array.as_string::(); + Ok(length_impl::(list.offsets(), list.nulls())) + } + DataType::Binary => { + let list = array.as_binary::(); + Ok(length_impl::(list.offsets(), list.nulls())) + } + DataType::LargeBinary => { + let list = array.as_binary::(); + Ok(length_impl::(list.offsets(), list.nulls())) } - DataType::List(_) => Ok(length_list::(array)), - DataType::LargeList(_) => Ok(length_list::(array)), - DataType::Utf8 => Ok(length_string::(array)), - DataType::LargeUtf8 => Ok(length_string::(array)), - DataType::Binary => Ok(length_binary::(array)), - DataType::LargeBinary => Ok(length_binary::(array)), + DataType::FixedSizeBinary(len) | DataType::FixedSizeList(_, len) => Ok(Arc::new( + Int32Array::new(vec![*len; array.len()].into(), array.nulls().cloned()), + )), other => Err(ArrowError::ComputeError(format!( "length not supported for {other:?}" ))), @@ -185,26 +98,40 @@ pub fn length(array: &dyn Array) -> Result { /// * bit_length of null is null. 
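A short sketch of the rewritten length kernels from the caller's side, assuming the `arrow-array` and `arrow-string` crates at this revision (values are illustrative). Lengths are computed directly from adjacent offsets, and dictionary inputs are handled by recursing into the values and re-wrapping the result with the original keys:

```rust
use arrow_array::{Int32Array, StringArray};
use arrow_string::length::{bit_length, length};

fn main() {
    let strings = StringArray::from(vec![Some("hello"), None, Some("")]);

    // Byte length per value, taken from the offset buffer; nulls stay null
    let len = length(&strings).unwrap();
    assert_eq!(len.as_ref(), &Int32Array::from(vec![Some(5), None, Some(0)]));

    // bit_length walks the same offsets and multiplies by 8
    let bits = bit_length(&strings).unwrap();
    assert_eq!(bits.as_ref(), &Int32Array::from(vec![Some(40), None, Some(0)]));
}
```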
/// * bit_length is in number of bits pub fn bit_length(array: &dyn Array) -> Result { + if let Some(d) = array.as_any_dictionary_opt() { + let lengths = bit_length(d.values().as_ref())?; + return Ok(d.with_values(lengths)); + } + match array.data_type() { - DataType::Dictionary(kt, _) => { - kernel_dict!( - array, - |a| { bit_length(a) }, - kt, - Int8: Int8Type, - Int16: Int16Type, - Int32: Int32Type, - Int64: Int64Type, - UInt8: UInt8Type, - UInt16: UInt16Type, - UInt32: UInt32Type, - UInt64: UInt64Type - ) + DataType::List(_) => { + let list = array.as_list::(); + Ok(bit_length_impl::(list.offsets(), list.nulls())) + } + DataType::LargeList(_) => { + let list = array.as_list::(); + Ok(bit_length_impl::(list.offsets(), list.nulls())) + } + DataType::Utf8 => { + let list = array.as_string::(); + Ok(bit_length_impl::(list.offsets(), list.nulls())) + } + DataType::LargeUtf8 => { + let list = array.as_string::(); + Ok(bit_length_impl::(list.offsets(), list.nulls())) } - DataType::Utf8 => Ok(bit_length_string::(array)), - DataType::LargeUtf8 => Ok(bit_length_string::(array)), - DataType::Binary => Ok(bit_length_binary::(array)), - DataType::LargeBinary => Ok(bit_length_binary::(array)), + DataType::Binary => { + let list = array.as_binary::(); + Ok(bit_length_impl::(list.offsets(), list.nulls())) + } + DataType::LargeBinary => { + let list = array.as_binary::(); + Ok(bit_length_impl::(list.offsets(), list.nulls())) + } + DataType::FixedSizeBinary(len) => Ok(Arc::new(Int32Array::new( + vec![*len * 8; array.len()].into(), + array.nulls().cloned(), + ))), other => Err(ArrowError::ComputeError(format!( "bit_length not supported for {other:?}" ))), @@ -215,19 +142,15 @@ pub fn bit_length(array: &dyn Array) -> Result { mod tests { use super::*; use arrow_array::cast::AsArray; - - fn double_vec(v: Vec) -> Vec { - [&v[..], &v[..]].concat() - } + use arrow_buffer::{Buffer, NullBuffer}; + use arrow_data::ArrayData; + use arrow_schema::Field; fn length_cases_string() -> Vec<(Vec<&'static str>, usize, Vec)> { // a large array - let mut values = vec!["one", "on", "o", ""]; - let mut expected = vec![3, 2, 1, 0]; - for _ in 0..10 { - values = double_vec(values); - expected = double_vec(expected); - } + let values = ["one", "on", "o", ""]; + let values = values.into_iter().cycle().take(4096).collect(); + let expected = [3, 2, 1, 0].into_iter().cycle().take(4096).collect(); vec![ (vec!["hello", " ", "world"], 3, vec![5, 1, 5]), @@ -261,7 +184,6 @@ mod tests { } #[test] - #[cfg_attr(miri, ignore)] // running forever fn length_test_string() { length_cases_string() .into_iter() @@ -277,7 +199,6 @@ mod tests { } #[test] - #[cfg_attr(miri, ignore)] // running forever fn length_test_large_string() { length_cases_string() .into_iter() @@ -448,12 +369,9 @@ mod tests { fn bit_length_cases() -> Vec<(Vec<&'static str>, usize, Vec)> { // a large array - let mut values = vec!["one", "on", "o", ""]; - let mut expected = vec![24, 16, 8, 0]; - for _ in 0..10 { - values = double_vec(values); - expected = double_vec(expected); - } + let values = ["one", "on", "o", ""]; + let values = values.into_iter().cycle().take(4096).collect(); + let expected = [24, 16, 8, 0].into_iter().cycle().take(4096).collect(); vec![ (vec!["hello", " ", "world", "!"], 4, vec![40, 8, 40, 8]), @@ -464,7 +382,6 @@ mod tests { } #[test] - #[cfg_attr(miri, ignore)] // error: this test uses too much memory to run on CI fn bit_length_test_string() { bit_length_cases() .into_iter() @@ -480,7 +397,6 @@ mod tests { } #[test] - #[cfg_attr(miri, ignore)] // 
error: this test uses too much memory to run on CI fn bit_length_test_large_string() { bit_length_cases() .into_iter() @@ -696,4 +612,44 @@ mod tests { assert_eq!(expected[i], actual[i],); } } + + #[test] + fn test_fixed_size_list_length() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(9) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8])) + .build() + .unwrap(); + let list_data_type = DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Int32, false)), + 3, + ); + let nulls = NullBuffer::from(vec![true, false, true]); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_child_data(value_data) + .nulls(Some(nulls)) + .build() + .unwrap(); + let list_array = FixedSizeListArray::from(list_data); + + let lengths = length(&list_array).unwrap(); + let lengths = lengths.as_primitive::(); + + assert_eq!(lengths.len(), 3); + assert_eq!(lengths.value(0), 3); + assert!(lengths.is_null(1)); + assert_eq!(lengths.value(2), 3); + } + + #[test] + fn test_fixed_size_binary() { + let array = FixedSizeBinaryArray::new(4, [0; 16].into(), None); + let result = length(&array).unwrap(); + assert_eq!(result.as_ref(), &Int32Array::from(vec![4; 4])); + + let result = bit_length(&array).unwrap(); + assert_eq!(result.as_ref(), &Int32Array::from(vec![32; 4])); + } } diff --git a/arrow-string/src/lib.rs b/arrow-string/src/lib.rs index 4bd4d282656c..4444b37a7742 100644 --- a/arrow-string/src/lib.rs +++ b/arrow-string/src/lib.rs @@ -20,5 +20,6 @@ pub mod concat_elements; pub mod length; pub mod like; +mod predicate; pub mod regexp; pub mod substring; diff --git a/arrow-string/src/like.rs b/arrow-string/src/like.rs index 6b4aea7e8e64..4478c4e4f7ef 100644 --- a/arrow-string/src/like.rs +++ b/arrow-string/src/like.rs @@ -15,227 +15,37 @@ // specific language governing permissions and limitations // under the License. -use arrow_array::builder::BooleanBufferBuilder; -use arrow_array::cast::*; +use crate::predicate::Predicate; +use arrow_array::cast::AsArray; use arrow_array::*; -use arrow_buffer::NullBuffer; -use arrow_data::ArrayDataBuilder; use arrow_schema::*; use arrow_select::take::take; -use regex::Regex; -use std::collections::HashMap; - -/// Helper function to perform boolean lambda function on values from two array accessors, this -/// version does not attempt to use SIMD. -/// -/// Duplicated from `arrow_ord::comparison` -fn compare_op( - left: T, - right: S, - op: F, -) -> Result -where - F: Fn(T::Item, S::Item) -> bool, -{ - if left.len() != right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform comparison operation on arrays of different length" - .to_string(), - )); - } - - Ok(BooleanArray::from_binary(left, right, op)) -} - -/// Helper function to perform boolean lambda function on values from array accessor, this -/// version does not attempt to use SIMD. -/// -/// Duplicated from `arrow_ord::comparison` -fn compare_op_scalar( - left: T, - op: F, -) -> Result -where - F: Fn(T::Item) -> bool, -{ - Ok(BooleanArray::from_unary(left, op)) -} - -macro_rules! dyn_function { - ($sql:tt, $fn_name:tt, $fn_utf8:tt, $fn_dict:tt) => { -#[doc = concat!("Perform SQL `", $sql ,"` operation on [`StringArray`] /")] -/// [`LargeStringArray`], or [`DictionaryArray`] with values -/// [`StringArray`]/[`LargeStringArray`]. -/// -/// See the documentation on [`like_utf8`] for more details. 
-pub fn $fn_name(left: &dyn Array, right: &dyn Array) -> Result { - match (left.data_type(), right.data_type()) { - (DataType::Utf8, DataType::Utf8) => { - let left = left.as_string::(); - let right = right.as_string::(); - $fn_utf8(left, right) - } - (DataType::LargeUtf8, DataType::LargeUtf8) => { - let left = left.as_string::(); - let right = right.as_string::(); - $fn_utf8(left, right) - } - #[cfg(feature = "dyn_cmp_dict")] - (DataType::Dictionary(_, _), DataType::Dictionary(_, _)) => { - downcast_dictionary_array!( - left => { - let right = as_dictionary_array(right); - $fn_dict(left, right) - } - t => Err(ArrowError::ComputeError(format!( - "Should be DictionaryArray but got: {}", t - ))) - ) - } - _ => { - Err(ArrowError::ComputeError(format!( - "{} only supports Utf8, LargeUtf8 or DictionaryArray (with feature `dyn_cmp_dict`) with Utf8 or LargeUtf8 values", - stringify!($fn_name) - ))) - } - } -} - - } -} -dyn_function!("left LIKE right", like_dyn, like_utf8, like_dict); -dyn_function!("left NOT LIKE right", nlike_dyn, nlike_utf8, nlike_dict); -dyn_function!("left ILIKE right", ilike_dyn, ilike_utf8, ilike_dict); -dyn_function!("left NOT ILIKE right", nilike_dyn, nilike_utf8, nilike_dict); -dyn_function!( - "STARTSWITH(left, right)", - starts_with_dyn, - starts_with_utf8, - starts_with_dict -); -dyn_function!( - "ENDSWITH(left, right)", - ends_with_dyn, - ends_with_utf8, - ends_with_dict -); -dyn_function!( - "CONTAINS(left, right)", - contains_dyn, - contains_utf8, - contains_dict -); - -macro_rules! scalar_dyn_function { - ($sql:tt, $fn_name:tt, $fn_scalar:tt) => { -#[doc = concat!("Perform SQL `", $sql ,"` operation on [`StringArray`] /")] -/// [`LargeStringArray`], or [`DictionaryArray`] with values -/// [`StringArray`]/[`LargeStringArray`] and a scalar. -/// -/// See the documentation on [`like_utf8`] for more details. -pub fn $fn_name( - left: &dyn Array, - right: &str, -) -> Result { - match left.data_type() { - DataType::Utf8 => { - let left = left.as_string::(); - $fn_scalar(left, right) - } - DataType::LargeUtf8 => { - let left = left.as_string::(); - $fn_scalar(left, right) - } - DataType::Dictionary(_, _) => { - downcast_dictionary_array!( - left => { - let dict_comparison = $fn_name(left.values().as_ref(), right)?; - // TODO: Use take_boolean (#2967) - let array = take(&dict_comparison, left.keys(), None)?; - Ok(BooleanArray::from(array.to_data())) - } - t => Err(ArrowError::ComputeError(format!( - "Should be DictionaryArray but got: {}", t - ))) - ) - } - _ => { - Err(ArrowError::ComputeError(format!( - "{} only supports Utf8, LargeUtf8 or DictionaryArray with Utf8 or LargeUtf8 values", - stringify!($fn_name) - ))) - } - } +use std::sync::Arc; + +#[derive(Debug)] +enum Op { + Like(bool), + ILike(bool), + Contains, + StartsWith, + EndsWith, } - } -} -scalar_dyn_function!("left LIKE right", like_utf8_scalar_dyn, like_scalar); -scalar_dyn_function!("left NOT LIKE right", nlike_utf8_scalar_dyn, nlike_scalar); -scalar_dyn_function!("left ILIKE right", ilike_utf8_scalar_dyn, ilike_scalar); -scalar_dyn_function!( - "left NOT ILIKE right", - nilike_utf8_scalar_dyn, - nilike_scalar -); -scalar_dyn_function!( - "STARTSWITH(left, right)", - starts_with_utf8_scalar_dyn, - starts_with_scalar -); -scalar_dyn_function!( - "ENDSWITH(left, right)", - ends_with_utf8_scalar_dyn, - ends_with_scalar -); -scalar_dyn_function!( - "CONTAINS(left, right)", - contains_utf8_scalar_dyn, - contains_scalar -); - -macro_rules! 
dict_function { - ($sql:tt, $fn_name:tt, $fn_impl:tt) => { - -#[doc = concat!("Perform SQL `", $sql ,"` operation on [`DictionaryArray`] with values")] -/// [`StringArray`]/[`LargeStringArray`]. -/// -/// See the documentation on [`like_utf8`] for more details. -#[cfg(feature = "dyn_cmp_dict")] -fn $fn_name( - left: &DictionaryArray, - right: &DictionaryArray, -) -> Result { - match (left.value_type(), right.value_type()) { - (DataType::Utf8, DataType::Utf8) => { - let left = left.downcast_dict::>().unwrap(); - let right = right.downcast_dict::>().unwrap(); - $fn_impl(left, right) +impl std::fmt::Display for Op { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Op::Like(false) => write!(f, "LIKE"), + Op::Like(true) => write!(f, "NLIKE"), + Op::ILike(false) => write!(f, "ILIKE"), + Op::ILike(true) => write!(f, "NILIKE"), + Op::Contains => write!(f, "CONTAINS"), + Op::StartsWith => write!(f, "STARTS_WITH"), + Op::EndsWith => write!(f, "ENDS_WITH"), } - (DataType::LargeUtf8, DataType::LargeUtf8) => { - let left = left.downcast_dict::>().unwrap(); - let right = right.downcast_dict::>().unwrap(); - - $fn_impl(left, right) - } - _ => Err(ArrowError::ComputeError(format!( - "{} only supports DictionaryArray with Utf8 or LargeUtf8 values", - stringify!($fn_name) - ))), - } -} } } -dict_function!("left LIKE right", like_dict, like); -dict_function!("left NOT LIKE right", nlike_dict, nlike); -dict_function!("left ILIKE right", ilike_dict, ilike); -dict_function!("left NOT ILIKE right", nilike_dict, nilike); -dict_function!("STARTSWITH(left, right)", starts_with_dict, starts_with); -dict_function!("ENDSWITH(left, right)", ends_with_dict, ends_with); -dict_function!("CONTAINS(left, right)", contains_dict, contains); - -/// Perform SQL `left LIKE right` operation on [`StringArray`] / [`LargeStringArray`]. 
+/// Perform SQL `left LIKE right` /// /// There are two wildcards supported with the LIKE operator: /// @@ -244,487 +54,334 @@ dict_function!("CONTAINS(left, right)", contains_dict, contains); /// /// For example: /// ``` -/// use arrow_array::{StringArray, BooleanArray}; -/// use arrow_string::like::like_utf8; -/// +/// # use arrow_array::{StringArray, BooleanArray}; +/// # use arrow_string::like::like; +/// # /// let strings = StringArray::from(vec!["Arrow", "Arrow", "Arrow", "Ar"]); /// let patterns = StringArray::from(vec!["A%", "B%", "A.", "A_"]); /// -/// let result = like_utf8(&strings, &patterns).unwrap(); +/// let result = like(&strings, &patterns).unwrap(); /// assert_eq!(result, BooleanArray::from(vec![true, false, false, true])); /// ``` -pub fn like_utf8( - left: &GenericStringArray, - right: &GenericStringArray, -) -> Result { - like(left, right) -} - -#[inline] -fn like<'a, S: ArrayAccessor>( - left: S, - right: S, -) -> Result { - regex_like(left, right, false, |re_pattern| { - Regex::new(&format!("^{re_pattern}$")).map_err(|e| { - ArrowError::ComputeError(format!( - "Unable to build regex from LIKE pattern: {e}" - )) - }) - }) +pub fn like(left: &dyn Datum, right: &dyn Datum) -> Result { + like_op(Op::Like(false), left, right) } -#[inline] -fn like_scalar_op<'a, F: Fn(bool) -> bool, L: ArrayAccessor>( - left: L, - right: &str, - op: F, -) -> Result { - if !right.contains(is_like_pattern) { - // fast path, can use equals - Ok(BooleanArray::from_unary(left, |item| op(item == right))) - } else if right.ends_with('%') - && !right.ends_with("\\%") - && !right[..right.len() - 1].contains(is_like_pattern) - { - // fast path, can use starts_with - let starts_with = &right[..right.len() - 1]; - - Ok(BooleanArray::from_unary(left, |item| { - op(item.starts_with(starts_with)) - })) - } else if right.starts_with('%') && !right[1..].contains(is_like_pattern) { - // fast path, can use ends_with - let ends_with = &right[1..]; - - Ok(BooleanArray::from_unary(left, |item| { - op(item.ends_with(ends_with)) - })) - } else if right.starts_with('%') - && right.ends_with('%') - && !right.ends_with("\\%") - && !right[1..right.len() - 1].contains(is_like_pattern) - { - let contains = &right[1..right.len() - 1]; - - Ok(BooleanArray::from_unary(left, |item| { - op(item.contains(contains)) - })) - } else { - let re_pattern = replace_like_wildcards(right)?; - let re = Regex::new(&format!("^{re_pattern}$")).map_err(|e| { - ArrowError::ComputeError(format!( - "Unable to build regex from LIKE pattern: {e}" - )) - })?; - - Ok(BooleanArray::from_unary(left, |item| op(re.is_match(item)))) - } -} - -#[inline] -fn like_scalar<'a, L: ArrayAccessor>( - left: L, - right: &str, -) -> Result { - like_scalar_op(left, right, |x| x) -} - -/// Perform SQL `left LIKE right` operation on [`StringArray`] / -/// [`LargeStringArray`] and a scalar. +/// Perform SQL `left ILIKE right` /// -/// See the documentation on [`like_utf8`] for more details. -pub fn like_utf8_scalar( - left: &GenericStringArray, - right: &str, -) -> Result { - like_scalar(left, right) -} - -/// Transforms a like `pattern` to a regex compatible pattern. To achieve that, it does: +/// This is a case-insensitive version of [`like`] /// -/// 1. Replace like wildcards for regex expressions as the pattern will be evaluated using regex match: `%` => `.*` and `_` => `.` -/// 2. Escape regex meta characters to match them and not be evaluated as regex special chars. For example: `.` => `\\.` -/// 3. 
Replace escaped like wildcards removing the escape characters to be able to match it as a regex. For example: `\\%` => `%` -fn replace_like_wildcards(pattern: &str) -> Result { - let mut result = String::new(); - let pattern = String::from(pattern); - let mut chars_iter = pattern.chars().peekable(); - while let Some(c) = chars_iter.next() { - if c == '\\' { - let next = chars_iter.peek(); - match next { - Some(next) if is_like_pattern(*next) => { - result.push(*next); - // Skipping the next char as it is already appended - chars_iter.next(); - } - _ => { - result.push('\\'); - result.push('\\'); - } - } - } else if regex_syntax::is_meta_character(c) { - result.push('\\'); - result.push(c); - } else if c == '%' { - result.push_str(".*"); - } else if c == '_' { - result.push('.'); - } else { - result.push(c); - } - } - Ok(result) +/// Note: this only implements loose matching as defined by the Unicode standard. For example, +/// the `ff` ligature is not equivalent to `FF` and `ß` is not equivalent to `SS` +pub fn ilike(left: &dyn Datum, right: &dyn Datum) -> Result { + like_op(Op::ILike(false), left, right) } -/// Perform SQL `left NOT LIKE right` operation on [`StringArray`] / -/// [`LargeStringArray`]. +/// Perform SQL `left NOT LIKE right` /// -/// See the documentation on [`like_utf8`] for more details. -pub fn nlike_utf8( - left: &GenericStringArray, - right: &GenericStringArray, -) -> Result { - nlike(left, right) -} - -#[inline] -fn nlike<'a, S: ArrayAccessor>( - left: S, - right: S, -) -> Result { - regex_like(left, right, true, |re_pattern| { - Regex::new(&format!("^{re_pattern}$")).map_err(|e| { - ArrowError::ComputeError(format!( - "Unable to build regex from LIKE pattern: {e}" - )) - }) - }) +/// See the documentation on [`like`] for more details +pub fn nlike(left: &dyn Datum, right: &dyn Datum) -> Result { + like_op(Op::Like(true), left, right) } -#[inline] -fn nlike_scalar<'a, L: ArrayAccessor>( - left: L, - right: &str, -) -> Result { - like_scalar_op(left, right, |x| !x) +/// Perform SQL `left NOT ILIKE right` +/// +/// See the documentation on [`ilike`] for more details +pub fn nilike(left: &dyn Datum, right: &dyn Datum) -> Result { + like_op(Op::ILike(true), left, right) } -/// Perform SQL `left NOT LIKE right` operation on [`StringArray`] / -/// [`LargeStringArray`] and a scalar. -/// -/// See the documentation on [`like_utf8`] for more details. -pub fn nlike_utf8_scalar( - left: &GenericStringArray, - right: &str, +/// Perform SQL `STARTSWITH(left, right)` +pub fn starts_with( + left: &dyn Datum, + right: &dyn Datum, ) -> Result { - nlike_scalar(left, right) + like_op(Op::StartsWith, left, right) } -/// Perform SQL `left ILIKE right` operation on [`StringArray`] / -/// [`LargeStringArray`]. -/// -/// Case insensitive version of [`like_utf8`] -/// -/// Note: this only implements loose matching as defined by the Unicode standard. 
For example, -/// the `ff` ligature is not equivalent to `FF` and `ß` is not equivalent to `SS` -pub fn ilike_utf8( - left: &GenericStringArray, - right: &GenericStringArray, +/// Perform SQL `ENDSWITH(left, right)` +pub fn ends_with( + left: &dyn Datum, + right: &dyn Datum, ) -> Result { - ilike(left, right) + like_op(Op::EndsWith, left, right) } -#[inline] -fn ilike<'a, S: ArrayAccessor>( - left: S, - right: S, -) -> Result { - regex_like(left, right, false, |re_pattern| { - Regex::new(&format!("(?i)^{re_pattern}$")).map_err(|e| { - ArrowError::ComputeError(format!( - "Unable to build regex from ILIKE pattern: {e}" - )) - }) - }) +/// Perform SQL `CONTAINS(left, right)` +pub fn contains(left: &dyn Datum, right: &dyn Datum) -> Result { + like_op(Op::Contains, left, right) } -#[inline] -fn ilike_scalar_op bool>( - left: &GenericStringArray, - right: &str, - op: F, -) -> Result { - // If not ASCII faster to use case insensitive regex than using to_uppercase - if right.is_ascii() && left.is_ascii() { - if !right.contains(is_like_pattern) { - return Ok(BooleanArray::from_unary(left, |item| { - op(item.eq_ignore_ascii_case(right)) - })); - } else if right.ends_with('%') - && !right.ends_with("\\%") - && !right[..right.len() - 1].contains(is_like_pattern) - { - // fast path, can use starts_with - let start_str = &right[..right.len() - 1]; - return Ok(BooleanArray::from_unary(left, |item| { - let end = item.len().min(start_str.len()); - let result = item.is_char_boundary(end) - && start_str.eq_ignore_ascii_case(&item[..end]); - op(result) - })); - } else if right.starts_with('%') && !right[1..].contains(is_like_pattern) { - // fast path, can use ends_with - let ends_str = &right[1..]; - return Ok(BooleanArray::from_unary(left, |item| { - let start = item.len().saturating_sub(ends_str.len()); - let result = item.is_char_boundary(start) - && ends_str.eq_ignore_ascii_case(&item[start..]); - op(result) - })); - } +fn like_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + use arrow_schema::DataType::*; + let (l, l_s) = lhs.get(); + let (r, r_s) = rhs.get(); + + if l.len() != r.len() && !l_s && !r_s { + return Err(ArrowError::InvalidArgumentError(format!( + "Cannot compare arrays of different lengths, got {} vs {}", + l.len(), + r.len() + ))); } - let re_pattern = replace_like_wildcards(right)?; - let re = Regex::new(&format!("(?i)^{re_pattern}$")).map_err(|e| { - ArrowError::ComputeError(format!("Unable to build regex from ILIKE pattern: {e}")) - })?; + let l_v = l.as_any_dictionary_opt(); + let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l); - Ok(BooleanArray::from_unary(left, |item| op(re.is_match(item)))) -} + let r_v = r.as_any_dictionary_opt(); + let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r); -#[inline] -fn ilike_scalar( - left: &GenericStringArray, - right: &str, -) -> Result { - ilike_scalar_op(left, right, |x| x) + match (l.data_type(), r.data_type()) { + (Utf8, Utf8) => { + apply::(op, l.as_string(), l_s, l_v, r.as_string(), r_s, r_v) + } + (LargeUtf8, LargeUtf8) => { + apply::(op, l.as_string(), l_s, l_v, r.as_string(), r_s, r_v) + } + (l_t, r_t) => Err(ArrowError::InvalidArgumentError(format!( + "Invalid string operation: {l_t} {op} {r_t}" + ))), + } } -/// Perform SQL `left ILIKE right` operation on [`StringArray`] / -/// [`LargeStringArray`] and a scalar. -/// -/// See the documentation on [`ilike_utf8`] for more details. 
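A minimal sketch of the new `Datum`-based entry points, assuming the `arrow-array` and `arrow-string` crates at this revision (data and pattern are illustrative). A one-element array wrapped in `Scalar` broadcasts against the other side, which is what replaces the old `*_utf8_scalar` kernels:

```rust
use arrow_array::{BooleanArray, StringArray};
use arrow_string::like::{ilike, like};

fn main() {
    let names = StringArray::from(vec!["Arrow", "arrow", "parquet"]);
    let pattern = StringArray::new_scalar("arr%");

    assert_eq!(
        like(&names, &pattern).unwrap(),
        BooleanArray::from(vec![false, true, false])
    );

    // The case-insensitive variant shares the same signature
    assert_eq!(
        ilike(&names, &pattern).unwrap(),
        BooleanArray::from(vec![true, true, false])
    );
}
```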
-pub fn ilike_utf8_scalar( - left: &GenericStringArray, - right: &str, +fn apply( + op: Op, + l: &GenericStringArray, + l_s: bool, + l_v: Option<&dyn AnyDictionaryArray>, + r: &GenericStringArray, + r_s: bool, + r_v: Option<&dyn AnyDictionaryArray>, ) -> Result { - ilike_scalar(left, right) + let l_len = l_v.map(|l| l.len()).unwrap_or(l.len()); + if r_s { + let idx = match r_v { + Some(dict) if dict.null_count() != 0 => { + return Ok(BooleanArray::new_null(l_len)) + } + Some(dict) => dict.normalized_keys()[0], + None => 0, + }; + if r.is_null(idx) { + return Ok(BooleanArray::new_null(l_len)); + } + op_scalar(op, l, l_v, r.value(idx)) + } else { + match (l_s, l_v, r_v) { + (true, None, None) => { + let v = l.is_valid(0).then(|| l.value(0)); + op_binary(op, std::iter::repeat(v), r.iter()) + } + (true, Some(l_v), None) => { + let idx = l_v.is_valid(0).then(|| l_v.normalized_keys()[0]); + let v = idx.and_then(|idx| l.is_valid(idx).then(|| l.value(idx))); + op_binary(op, std::iter::repeat(v), r.iter()) + } + (true, None, Some(r_v)) => { + let v = l.is_valid(0).then(|| l.value(0)); + op_binary(op, std::iter::repeat(v), vectored_iter(r, r_v)) + } + (true, Some(l_v), Some(r_v)) => { + let idx = l_v.is_valid(0).then(|| l_v.normalized_keys()[0]); + let v = idx.and_then(|idx| l.is_valid(idx).then(|| l.value(idx))); + op_binary(op, std::iter::repeat(v), vectored_iter(r, r_v)) + } + (false, None, None) => op_binary(op, l.iter(), r.iter()), + (false, Some(l_v), None) => op_binary(op, vectored_iter(l, l_v), r.iter()), + (false, None, Some(r_v)) => op_binary(op, l.iter(), vectored_iter(r, r_v)), + (false, Some(l_v), Some(r_v)) => { + op_binary(op, vectored_iter(l, l_v), vectored_iter(r, r_v)) + } + } + } } -/// Perform SQL `left NOT ILIKE right` operation on [`StringArray`] / -/// [`LargeStringArray`]. -/// -/// See the documentation on [`ilike_utf8`] for more details. -pub fn nilike_utf8( - left: &GenericStringArray, - right: &GenericStringArray, +#[inline(never)] +fn op_scalar( + op: Op, + l: &GenericStringArray, + l_v: Option<&dyn AnyDictionaryArray>, + r: &str, ) -> Result { - nilike(left, right) -} + let r = match op { + Op::Like(neg) => Predicate::like(r)?.evaluate_array(l, neg), + Op::ILike(neg) => Predicate::ilike(r, l.is_ascii())?.evaluate_array(l, neg), + Op::Contains => Predicate::Contains(r).evaluate_array(l, false), + Op::StartsWith => Predicate::StartsWith(r).evaluate_array(l, false), + Op::EndsWith => Predicate::EndsWith(r).evaluate_array(l, false), + }; -#[inline] -fn nilike<'a, S: ArrayAccessor>( - left: S, - right: S, -) -> Result { - regex_like(left, right, true, |re_pattern| { - Regex::new(&format!("(?i)^{re_pattern}$")).map_err(|e| { - ArrowError::ComputeError(format!( - "Unable to build regex from ILIKE pattern: {e}" - )) - }) + Ok(match l_v { + Some(v) => take(&r, v.keys(), None)?.as_boolean().clone(), + None => r, }) } -#[inline] -fn nilike_scalar( - left: &GenericStringArray, - right: &str, -) -> Result { - ilike_scalar_op(left, right, |x| !x) +fn vectored_iter<'a, O: OffsetSizeTrait>( + a: &'a GenericStringArray, + a_v: &'a dyn AnyDictionaryArray, +) -> impl Iterator> + 'a { + let nulls = a_v.nulls(); + let keys = a_v.normalized_keys(); + keys.into_iter().enumerate().map(move |(idx, key)| { + if nulls.map(|n| n.is_null(idx)).unwrap_or_default() || a.is_null(key) { + return None; + } + Some(a.value(key)) + }) } -/// Perform SQL `left NOT ILIKE right` operation on [`StringArray`] / -/// [`LargeStringArray`] and a scalar. 
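Dictionary-encoded inputs go through the same entry points; as sketched below with illustrative data, the predicate is evaluated once per dictionary value and the boolean result is then gathered back through the keys via `take`:

```rust
use arrow_array::types::Int8Type;
use arrow_array::{BooleanArray, DictionaryArray, StringArray};
use arrow_string::like::like;

fn main() {
    let dict: DictionaryArray<Int8Type> =
        vec![Some("Air"), None, Some("Airplane"), Some("Air")]
            .into_iter()
            .collect();
    let pattern = StringArray::new_scalar("Air%");

    // Only the distinct values are matched; nulls propagate through the keys
    assert_eq!(
        like(&dict, &pattern).unwrap(),
        BooleanArray::from(vec![Some(true), None, Some(true), Some(true)])
    );
}
```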
-/// -/// See the documentation on [`ilike_utf8`] for more details. -pub fn nilike_utf8_scalar( - left: &GenericStringArray, - right: &str, +#[inline(never)] +fn op_binary<'a>( + op: Op, + l: impl Iterator>, + r: impl Iterator>, ) -> Result { - nilike_scalar(left, right) -} - -fn is_like_pattern(c: char) -> bool { - c == '%' || c == '_' -} - -/// Evaluate regex `op(left)` matching `right` on [`StringArray`] / [`LargeStringArray`] -/// -/// If `negate_regex` is true, the regex expression will be negated. (for example, with `not like`) -fn regex_like<'a, S: ArrayAccessor, F>( - left: S, - right: S, - negate_regex: bool, - op: F, -) -> Result -where - F: Fn(&str) -> Result, -{ - let mut map = HashMap::new(); - if left.len() != right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform comparison operation on arrays of different length" - .to_string(), - )); + match op { + Op::Like(neg) => binary_predicate(l, r, neg, Predicate::like), + Op::ILike(neg) => binary_predicate(l, r, neg, |s| Predicate::ilike(s, false)), + Op::Contains => Ok(l.zip(r).map(|(l, r)| Some(l?.contains(r?))).collect()), + Op::StartsWith => Ok(l.zip(r).map(|(l, r)| Some(l?.starts_with(r?))).collect()), + Op::EndsWith => Ok(l.zip(r).map(|(l, r)| Some(l?.ends_with(r?))).collect()), } - - let nulls = NullBuffer::union(left.nulls(), right.nulls()); - - let mut result = BooleanBufferBuilder::new(left.len()); - for i in 0..left.len() { - let haystack = left.value(i); - let pat = right.value(i); - let re = if let Some(ref regex) = map.get(pat) { - regex - } else { - let re_pattern = replace_like_wildcards(pat)?; - let re = op(&re_pattern)?; - map.insert(pat, re); - map.get(pat).unwrap() - }; - - result.append(if negate_regex { - !re.is_match(haystack) - } else { - re.is_match(haystack) - }); - } - - let data = unsafe { - ArrayDataBuilder::new(DataType::Boolean) - .len(left.len()) - .nulls(nulls) - .buffers(vec![result.into()]) - .build_unchecked() - }; - Ok(BooleanArray::from(data)) -} - -/// Perform SQL `STARTSWITH(left, right)` operation on [`StringArray`] / [`LargeStringArray`]. -/// -/// See the documentation on [`like_utf8`] for more details. -pub fn starts_with_utf8( - left: &GenericStringArray, - right: &GenericStringArray, -) -> Result { - starts_with(left, right) -} - -#[inline] -fn starts_with<'a, S: ArrayAccessor>( - left: S, - right: S, -) -> Result { - compare_op(left, right, |l, r| l.starts_with(r)) -} - -#[inline] -fn starts_with_scalar<'a, L: ArrayAccessor>( - left: L, - right: &str, -) -> Result { - compare_op_scalar(left, |item| item.starts_with(right)) } -/// Perform SQL `STARTSWITH(left, right)` operation on [`StringArray`] / -/// [`LargeStringArray`] and a scalar. -/// -/// See the documentation on [`like_utf8`] for more details. -pub fn starts_with_utf8_scalar( - left: &GenericStringArray, - right: &str, +fn binary_predicate<'a>( + l: impl Iterator>, + r: impl Iterator>, + neg: bool, + f: impl Fn(&'a str) -> Result, ArrowError>, ) -> Result { - starts_with_scalar(left, right) + let mut previous = None; + l.zip(r) + .map(|(l, r)| match (l, r) { + (Some(l), Some(r)) => { + let p: &Predicate = match previous { + Some((expr, ref predicate)) if expr == r => predicate, + _ => &previous.insert((r, f(r)?)).1, + }; + Ok(Some(p.evaluate(l) != neg)) + } + _ => Ok(None), + }) + .collect() } -/// Perform SQL `ENDSWITH(left, right)` operation on [`StringArray`] / [`LargeStringArray`]. -/// -/// See the documentation on [`like_utf8`] for more details. 
-pub fn ends_with_utf8( - left: &GenericStringArray, - right: &GenericStringArray, -) -> Result { - ends_with(left, right) -} +// Deprecated kernels -#[inline] -fn ends_with<'a, S: ArrayAccessor>( - left: S, - right: S, -) -> Result { - compare_op(left, right, |l, r| l.ends_with(r)) +fn make_scalar(data_type: &DataType, scalar: &str) -> Result { + match data_type { + DataType::Utf8 => Ok(Arc::new(StringArray::from_iter_values([scalar]))), + DataType::LargeUtf8 => Ok(Arc::new(LargeStringArray::from_iter_values([scalar]))), + DataType::Dictionary(_, v) => make_scalar(v.as_ref(), scalar), + d => Err(ArrowError::InvalidArgumentError(format!( + "Unsupported string scalar data type {d:?}", + ))), + } } -#[inline] -fn ends_with_scalar<'a, L: ArrayAccessor>( - left: L, - right: &str, -) -> Result { - compare_op_scalar(left, |item| item.ends_with(right)) -} +macro_rules! legacy_kernels { + ($fn_datum:ident, $fn_array:ident, $fn_scalar:ident, $fn_array_dyn:ident, $fn_scalar_dyn:ident, $deprecation:expr) => { + #[doc(hidden)] + #[deprecated(note = $deprecation)] + pub fn $fn_array( + left: &GenericStringArray, + right: &GenericStringArray, + ) -> Result { + $fn_datum(left, right) + } -/// Perform SQL `ENDSWITH(left, right)` operation on [`StringArray`] / -/// [`LargeStringArray`] and a scalar. -/// -/// See the documentation on [`like_utf8`] for more details. -pub fn ends_with_utf8_scalar( - left: &GenericStringArray, - right: &str, -) -> Result { - ends_with_scalar(left, right) -} + #[doc(hidden)] + #[deprecated(note = $deprecation)] + pub fn $fn_scalar( + left: &GenericStringArray, + right: &str, + ) -> Result { + let scalar = GenericStringArray::::from_iter_values([right]); + $fn_datum(left, &Scalar::new(&scalar)) + } -/// Perform SQL `CONTAINS(left, right)` operation on [`StringArray`] / [`LargeStringArray`]. -/// -/// See the documentation on [`like_utf8`] for more details. 
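The remaining predicates follow the same `Datum` pattern; a small sketch with illustrative paths, under the same crate assumptions as above:

```rust
use arrow_array::{BooleanArray, StringArray};
use arrow_string::like::{contains, ends_with, starts_with};

fn main() {
    let paths = StringArray::from(vec!["data/part-0.parquet", "logs/app.log"]);

    assert_eq!(
        ends_with(&paths, &StringArray::new_scalar(".parquet")).unwrap(),
        BooleanArray::from(vec![true, false])
    );
    assert_eq!(
        starts_with(&paths, &StringArray::new_scalar("logs/")).unwrap(),
        BooleanArray::from(vec![false, true])
    );
    assert_eq!(
        contains(&paths, &StringArray::new_scalar("part")).unwrap(),
        BooleanArray::from(vec![true, false])
    );
}
```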
-pub fn contains_utf8( - left: &GenericStringArray, - right: &GenericStringArray, -) -> Result { - contains(left, right) -} + #[doc(hidden)] + #[deprecated(note = $deprecation)] + pub fn $fn_array_dyn( + left: &dyn Array, + right: &dyn Array, + ) -> Result { + $fn_datum(&left, &right) + } -#[inline] -fn contains<'a, S: ArrayAccessor>( - left: S, - right: S, -) -> Result { - compare_op(left, right, |l, r| l.contains(r)) + #[doc(hidden)] + #[deprecated(note = $deprecation)] + pub fn $fn_scalar_dyn( + left: &dyn Array, + right: &str, + ) -> Result { + let scalar = make_scalar(left.data_type(), right)?; + $fn_datum(&left, &Scalar::new(&scalar)) + } + }; } -#[inline] -fn contains_scalar<'a, L: ArrayAccessor>( - left: L, - right: &str, -) -> Result { - compare_op_scalar(left, |item| item.contains(right)) -} +legacy_kernels!( + like, + like_utf8, + like_utf8_scalar, + like_dyn, + like_utf8_scalar_dyn, + "Use arrow_string::like::like" +); +legacy_kernels!( + ilike, + ilike_utf8, + ilike_utf8_scalar, + ilike_dyn, + ilike_utf8_scalar_dyn, + "Use arrow_string::like::ilike" +); +legacy_kernels!( + nlike, + nlike_utf8, + nlike_utf8_scalar, + nlike_dyn, + nlike_utf8_scalar_dyn, + "Use arrow_string::like::nlike" +); +legacy_kernels!( + nilike, + nilike_utf8, + nilike_utf8_scalar, + nilike_dyn, + nilike_utf8_scalar_dyn, + "Use arrow_string::like::nilike" +); +legacy_kernels!( + contains, + contains_utf8, + contains_utf8_scalar, + contains_dyn, + contains_utf8_scalar_dyn, + "Use arrow_string::like::contains" +); +legacy_kernels!( + starts_with, + starts_with_utf8, + starts_with_utf8_scalar, + starts_with_dyn, + starts_with_utf8_scalar_dyn, + "Use arrow_string::like::starts_with" +); -/// Perform SQL `CONTAINS(left, right)` operation on [`StringArray`] / -/// [`LargeStringArray`] and a scalar. -/// -/// See the documentation on [`like_utf8`] for more details. -pub fn contains_utf8_scalar( - left: &GenericStringArray, - right: &str, -) -> Result { - contains_scalar(left, right) -} +legacy_kernels!( + ends_with, + ends_with_utf8, + ends_with_utf8_scalar, + ends_with_dyn, + ends_with_utf8_scalar_dyn, + "Use arrow_string::like::ends_with" +); #[cfg(test)] +#[allow(deprecated)] mod tests { use super::*; use arrow_array::types::Int8Type; @@ -733,15 +390,11 @@ mod tests { ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { #[test] fn $test_name() { + let expected = BooleanArray::from($expected); let left = StringArray::from($left); let right = StringArray::from($right); let res = $op(&left, &right).unwrap(); - let expected = $expected; - assert_eq!(expected.len(), res.len()); - for i in 0..res.len() { - let v = res.value(i); - assert_eq!(v, expected[i]); - } + assert_eq!(res, expected); } }; } @@ -749,17 +402,12 @@ mod tests { macro_rules! 
test_dict_utf8 { ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { #[test] - #[cfg(feature = "dyn_cmp_dict")] fn $test_name() { + let expected = BooleanArray::from($expected); let left: DictionaryArray = $left.into_iter().collect(); let right: DictionaryArray = $right.into_iter().collect(); let res = $op(&left, &right).unwrap(); - let expected = $expected; - assert_eq!(expected.len(), res.len()); - for i in 0..res.len() { - let v = res.value(i); - assert_eq!(v, expected[i]); - } + assert_eq!(res, expected); } }; } @@ -768,37 +416,15 @@ mod tests { ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { #[test] fn $test_name() { + let expected = BooleanArray::from($expected); + let left = StringArray::from($left); let res = $op(&left, $right).unwrap(); - let expected = $expected; - assert_eq!(expected.len(), res.len()); - for i in 0..res.len() { - let v = res.value(i); - assert_eq!( - v, - expected[i], - "unexpected result when comparing {} at position {} to {} ", - left.value(i), - i, - $right - ); - } + assert_eq!(res, expected); let left = LargeStringArray::from($left); let res = $op(&left, $right).unwrap(); - let expected = $expected; - assert_eq!(expected.len(), res.len()); - for i in 0..res.len() { - let v = res.value(i); - assert_eq!( - v, - expected[i], - "unexpected result when comparing {} at position {} to {} ", - left.value(i), - i, - $right - ); - } + assert_eq!(res, expected); } }; ($test_name:ident, $test_name_dyn:ident, $left:expr, $right:expr, $op:expr, $op_dyn:expr, $expected:expr) => { @@ -950,7 +576,7 @@ mod tests { test_utf8!( test_utf8_scalar_ilike_regex, vec!["%%%"], - vec![r#"\%_\%"#], + vec![r"\%_\%"], ilike_utf8, vec![true] ); @@ -958,39 +584,11 @@ mod tests { test_dict_utf8!( test_utf8_scalar_ilike_regex_dict, vec!["%%%"], - vec![r#"\%_\%"#], + vec![r"\%_\%"], ilike_dyn, vec![true] ); - #[test] - fn test_replace_like_wildcards() { - let a_eq = "_%"; - let expected = "..*"; - assert_eq!(replace_like_wildcards(a_eq).unwrap(), expected); - } - - #[test] - fn test_replace_like_wildcards_leave_like_meta_chars() { - let a_eq = "\\%\\_"; - let expected = "%_"; - assert_eq!(replace_like_wildcards(a_eq).unwrap(), expected); - } - - #[test] - fn test_replace_like_wildcards_with_multiple_escape_chars() { - let a_eq = "\\\\%"; - let expected = "\\\\%"; - assert_eq!(replace_like_wildcards(a_eq).unwrap(), expected); - } - - #[test] - fn test_replace_like_wildcards_escape_regex_meta_char() { - let a_eq = "."; - let expected = "\\."; - assert_eq!(replace_like_wildcards(a_eq).unwrap(), expected); - } - test_utf8!( test_utf8_array_nlike, vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrows", "arrow"], @@ -1368,6 +966,7 @@ mod tests { Some("Air"), None, Some("Air"), + Some("bbbbb\nAir"), ]; let dict_array: DictionaryArray = data.into_iter().collect(); @@ -1380,7 +979,8 @@ mod tests { Some(false), Some(true), None, - Some(true) + Some(true), + Some(false), ]), ); @@ -1392,7 +992,8 @@ mod tests { Some(false), Some(true), None, - Some(true) + Some(true), + Some(false), ]), ); @@ -1404,7 +1005,8 @@ mod tests { Some(true), Some(false), None, - Some(false) + Some(false), + Some(false), ]), ); @@ -1416,7 +1018,8 @@ mod tests { Some(true), Some(false), None, - Some(false) + Some(false), + Some(false), ]), ); @@ -1428,7 +1031,8 @@ mod tests { Some(true), Some(true), None, - Some(true) + Some(true), + Some(true), ]), ); @@ -1440,7 +1044,8 @@ mod tests { Some(true), Some(true), None, - Some(true) + Some(true), + Some(true), ]), ); @@ -1452,7 
+1057,8 @@ mod tests { Some(false), Some(true), None, - Some(true) + Some(true), + Some(true), ]), ); @@ -1464,7 +1070,8 @@ mod tests { Some(false), Some(true), None, - Some(true) + Some(true), + Some(true), ]), ); @@ -1476,7 +1083,8 @@ mod tests { Some(true), Some(false), None, - Some(false) + Some(false), + Some(false), ]), ); @@ -1488,7 +1096,8 @@ mod tests { Some(true), Some(false), None, - Some(false) + Some(false), + Some(false), ]), ); } @@ -1502,6 +1111,7 @@ mod tests { Some("Air"), None, Some("Air"), + Some("bbbbb\nAir"), ]; let dict_array: DictionaryArray = data.into_iter().collect(); @@ -1514,7 +1124,8 @@ mod tests { Some(true), Some(false), None, - Some(false) + Some(false), + Some(true), ]), ); @@ -1526,7 +1137,8 @@ mod tests { Some(true), Some(false), None, - Some(false) + Some(false), + Some(true), ]), ); @@ -1538,7 +1150,8 @@ mod tests { Some(false), Some(true), None, - Some(true) + Some(true), + Some(true), ]), ); @@ -1550,7 +1163,8 @@ mod tests { Some(false), Some(true), None, - Some(true) + Some(true), + Some(true), ]), ); @@ -1562,7 +1176,8 @@ mod tests { Some(false), Some(false), None, - Some(false) + Some(false), + Some(false), ]), ); @@ -1574,7 +1189,8 @@ mod tests { Some(false), Some(false), None, - Some(false) + Some(false), + Some(false), ]), ); @@ -1586,7 +1202,8 @@ mod tests { Some(true), Some(false), None, - Some(false) + Some(false), + Some(false), ]), ); @@ -1598,7 +1215,8 @@ mod tests { Some(true), Some(false), None, - Some(false) + Some(false), + Some(false), ]), ); @@ -1610,7 +1228,8 @@ mod tests { Some(false), Some(true), None, - Some(true) + Some(true), + Some(true), ]), ); @@ -1622,7 +1241,8 @@ mod tests { Some(false), Some(true), None, - Some(true) + Some(true), + Some(true), ]), ); } @@ -1636,6 +1256,7 @@ mod tests { Some("Air"), None, Some("Air"), + Some("bbbbb\nAir"), ]; let dict_array: DictionaryArray = data.into_iter().collect(); @@ -1648,7 +1269,8 @@ mod tests { Some(false), Some(true), None, - Some(true) + Some(true), + Some(false), ]), ); @@ -1660,7 +1282,8 @@ mod tests { Some(false), Some(true), None, - Some(true) + Some(true), + Some(false), ]), ); @@ -1672,7 +1295,8 @@ mod tests { Some(true), Some(false), None, - Some(false) + Some(false), + Some(false), ]), ); @@ -1684,7 +1308,8 @@ mod tests { Some(true), Some(false), None, - Some(false) + Some(false), + Some(false), ]), ); @@ -1696,7 +1321,8 @@ mod tests { Some(true), Some(true), None, - Some(true) + Some(true), + Some(true), ]), ); @@ -1708,7 +1334,8 @@ mod tests { Some(true), Some(true), None, - Some(true) + Some(true), + Some(true), ]), ); @@ -1720,7 +1347,8 @@ mod tests { Some(false), Some(true), None, - Some(true) + Some(true), + Some(true), ]), ); @@ -1732,7 +1360,8 @@ mod tests { Some(false), Some(true), None, - Some(true) + Some(true), + Some(true), ]), ); @@ -1744,7 +1373,8 @@ mod tests { Some(true), Some(true), None, - Some(true) + Some(true), + Some(true), ]), ); @@ -1756,7 +1386,8 @@ mod tests { Some(true), Some(true), None, - Some(true) + Some(true), + Some(true), ]), ); } @@ -1770,6 +1401,7 @@ mod tests { Some("Air"), None, Some("Air"), + Some("bbbbb\nAir"), ]; let dict_array: DictionaryArray = data.into_iter().collect(); @@ -1782,7 +1414,8 @@ mod tests { Some(true), Some(false), None, - Some(false) + Some(false), + Some(true), ]), ); @@ -1794,7 +1427,8 @@ mod tests { Some(true), Some(false), None, - Some(false) + Some(false), + Some(true), ]), ); @@ -1806,7 +1440,8 @@ mod tests { Some(false), Some(true), None, - Some(true) + Some(true), + Some(true), ]), ); @@ -1818,7 
+1453,8 @@ mod tests { Some(false), Some(true), None, - Some(true) + Some(true), + Some(true), ]), ); @@ -1830,7 +1466,8 @@ mod tests { Some(false), Some(false), None, - Some(false) + Some(false), + Some(false), ]), ); @@ -1842,7 +1479,8 @@ mod tests { Some(false), Some(false), None, - Some(false) + Some(false), + Some(false), ]), ); @@ -1854,7 +1492,8 @@ mod tests { Some(true), Some(false), None, - Some(false) + Some(false), + Some(false), ]), ); @@ -1866,7 +1505,8 @@ mod tests { Some(true), Some(false), None, - Some(false) + Some(false), + Some(false), ]), ); @@ -1878,7 +1518,8 @@ mod tests { Some(false), Some(false), None, - Some(false) + Some(false), + Some(false), ]), ); @@ -1890,8 +1531,40 @@ mod tests { Some(false), Some(false), None, - Some(false) + Some(false), + Some(false), ]), ); } + + #[test] + fn like_scalar_null() { + let a = StringArray::new_scalar("a"); + let b = Scalar::new(StringArray::new_null(1)); + let r = like(&a, &b).unwrap(); + assert_eq!(r.len(), 1); + assert_eq!(r.null_count(), 1); + assert!(r.is_null(0)); + + let a = StringArray::from_iter_values(["a"]); + let b = Scalar::new(StringArray::new_null(1)); + let r = like(&a, &b).unwrap(); + assert_eq!(r.len(), 1); + assert_eq!(r.null_count(), 1); + assert!(r.is_null(0)); + + let a = StringArray::from_iter_values(["a"]); + let b = StringArray::new_null(1); + let r = like(&a, &b).unwrap(); + assert_eq!(r.len(), 1); + assert_eq!(r.null_count(), 1); + assert!(r.is_null(0)); + + let a = StringArray::new_scalar("a"); + let b = StringArray::new_null(1); + let r = like(&a, &b).unwrap(); + assert_eq!(r.len(), 1); + assert_eq!(r.null_count(), 1); + assert!(r.is_null(0)); + } } diff --git a/arrow-string/src/predicate.rs b/arrow-string/src/predicate.rs new file mode 100644 index 000000000000..162e3c75027d --- /dev/null +++ b/arrow-string/src/predicate.rs @@ -0,0 +1,229 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_array::{BooleanArray, GenericStringArray, OffsetSizeTrait}; +use arrow_schema::ArrowError; +use regex::{Regex, RegexBuilder}; + +/// A string based predicate +pub enum Predicate<'a> { + Eq(&'a str), + Contains(&'a str), + StartsWith(&'a str), + EndsWith(&'a str), + + /// Equality ignoring ASCII case + IEqAscii(&'a str), + /// Starts with ignoring ASCII case + IStartsWithAscii(&'a str), + /// Ends with ignoring ASCII case + IEndsWithAscii(&'a str), + + Regex(Regex), +} + +impl<'a> Predicate<'a> { + /// Create a predicate for the given like pattern + pub fn like(pattern: &'a str) -> Result { + if !pattern.contains(is_like_pattern) { + Ok(Self::Eq(pattern)) + } else if pattern.ends_with('%') + && !pattern.ends_with("\\%") + && !pattern[..pattern.len() - 1].contains(is_like_pattern) + { + Ok(Self::StartsWith(&pattern[..pattern.len() - 1])) + } else if pattern.starts_with('%') && !pattern[1..].contains(is_like_pattern) { + Ok(Self::EndsWith(&pattern[1..])) + } else if pattern.starts_with('%') + && pattern.ends_with('%') + && !pattern.ends_with("\\%") + && !pattern[1..pattern.len() - 1].contains(is_like_pattern) + { + Ok(Self::Contains(&pattern[1..pattern.len() - 1])) + } else { + Ok(Self::Regex(regex_like(pattern, false)?)) + } + } + + /// Create a predicate for the given ilike pattern + pub fn ilike(pattern: &'a str, is_ascii: bool) -> Result { + if is_ascii && pattern.is_ascii() { + if !pattern.contains(is_like_pattern) { + return Ok(Self::IEqAscii(pattern)); + } else if pattern.ends_with('%') + && !pattern.ends_with("\\%") + && !pattern[..pattern.len() - 1].contains(is_like_pattern) + { + return Ok(Self::IStartsWithAscii(&pattern[..pattern.len() - 1])); + } else if pattern.starts_with('%') && !pattern[1..].contains(is_like_pattern) + { + return Ok(Self::IEndsWithAscii(&pattern[1..])); + } + } + Ok(Self::Regex(regex_like(pattern, true)?)) + } + + /// Evaluate this predicate against the given haystack + pub fn evaluate(&self, haystack: &str) -> bool { + match self { + Predicate::Eq(v) => *v == haystack, + Predicate::IEqAscii(v) => haystack.eq_ignore_ascii_case(v), + Predicate::Contains(v) => haystack.contains(v), + Predicate::StartsWith(v) => haystack.starts_with(v), + Predicate::IStartsWithAscii(v) => starts_with_ignore_ascii_case(haystack, v), + Predicate::EndsWith(v) => haystack.ends_with(v), + Predicate::IEndsWithAscii(v) => ends_with_ignore_ascii_case(haystack, v), + Predicate::Regex(v) => v.is_match(haystack), + } + } + + /// Evaluate this predicate against the elements of `array` + /// + /// If `negate` is true the result of the predicate will be negated + #[inline(never)] + pub fn evaluate_array( + &self, + array: &GenericStringArray, + negate: bool, + ) -> BooleanArray { + match self { + Predicate::Eq(v) => BooleanArray::from_unary(array, |haystack| { + (haystack.len() == v.len() && haystack == *v) != negate + }), + Predicate::IEqAscii(v) => BooleanArray::from_unary(array, |haystack| { + haystack.eq_ignore_ascii_case(v) != negate + }), + Predicate::Contains(v) => { + BooleanArray::from_unary(array, |haystack| haystack.contains(v) != negate) + } + Predicate::StartsWith(v) => BooleanArray::from_unary(array, |haystack| { + haystack.starts_with(v) != negate + }), + Predicate::IStartsWithAscii(v) => { + BooleanArray::from_unary(array, |haystack| { + starts_with_ignore_ascii_case(haystack, v) != negate + }) + } + Predicate::EndsWith(v) => BooleanArray::from_unary(array, |haystack| { + haystack.ends_with(v) != negate + }), + Predicate::IEndsWithAscii(v) => 
                BooleanArray::from_unary(array, |haystack| {
+                    ends_with_ignore_ascii_case(haystack, v) != negate
+                })
+            }
+            Predicate::Regex(v) => {
+                BooleanArray::from_unary(array, |haystack| v.is_match(haystack) != negate)
+            }
+        }
+    }
+}
+
+fn starts_with_ignore_ascii_case(haystack: &str, needle: &str) -> bool {
+    let end = haystack.len().min(needle.len());
+    haystack.is_char_boundary(end) && needle.eq_ignore_ascii_case(&haystack[..end])
+}
+
+fn ends_with_ignore_ascii_case(haystack: &str, needle: &str) -> bool {
+    let start = haystack.len().saturating_sub(needle.len());
+    haystack.is_char_boundary(start) && needle.eq_ignore_ascii_case(&haystack[start..])
+}
+
+/// Transforms a like `pattern` to a regex compatible pattern. To achieve that, it does:
+///
+/// 1. Replace like wildcards for regex expressions as the pattern will be evaluated using regex match: `%` => `.*` and `_` => `.`
+/// 2. Escape regex meta characters to match them and not be evaluated as regex special chars. For example: `.` => `\\.`
+/// 3. Replace escaped like wildcards removing the escape characters to be able to match it as a regex. For example: `\\%` => `%`
+fn regex_like(pattern: &str, case_insensitive: bool) -> Result<Regex, ArrowError> {
+    let mut result = String::with_capacity(pattern.len() * 2);
+    result.push('^');
+    let mut chars_iter = pattern.chars().peekable();
+    while let Some(c) = chars_iter.next() {
+        if c == '\\' {
+            let next = chars_iter.peek();
+            match next {
+                Some(next) if is_like_pattern(*next) => {
+                    result.push(*next);
+                    // Skipping the next char as it is already appended
+                    chars_iter.next();
+                }
+                _ => {
+                    result.push('\\');
+                    result.push('\\');
+                }
+            }
+        } else if regex_syntax::is_meta_character(c) {
+            result.push('\\');
+            result.push(c);
+        } else if c == '%' {
+            result.push_str(".*");
+        } else if c == '_' {
+            result.push('.');
+        } else {
+            result.push(c);
+        }
+    }
+    result.push('$');
+    RegexBuilder::new(&result)
+        .case_insensitive(case_insensitive)
+        .dot_matches_new_line(true)
+        .build()
+        .map_err(|e| {
+            ArrowError::InvalidArgumentError(format!(
+                "Unable to build regex from LIKE pattern: {e}"
+            ))
+        })
+}
+
+fn is_like_pattern(c: char) -> bool {
+    c == '%' || c == '_'
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_replace_like_wildcards() {
+        let a_eq = "_%";
+        let expected = "^..*$";
+        let r = regex_like(a_eq, false).unwrap();
+        assert_eq!(r.to_string(), expected);
+    }
+
+    #[test]
+    fn test_replace_like_wildcards_leave_like_meta_chars() {
+        let a_eq = "\\%\\_";
+        let expected = "^%_$";
+        let r = regex_like(a_eq, false).unwrap();
+        assert_eq!(r.to_string(), expected);
+    }
+
+    #[test]
+    fn test_replace_like_wildcards_with_multiple_escape_chars() {
+        let a_eq = "\\\\%";
+        let expected = "^\\\\%$";
+        let r = regex_like(a_eq, false).unwrap();
+        assert_eq!(r.to_string(), expected);
+    }
+
+    #[test]
+    fn test_replace_like_wildcards_escape_regex_meta_char() {
+        let a_eq = ".";
+        let expected = "^\\.$";
+        let r = regex_like(a_eq, false).unwrap();
+        assert_eq!(r.to_string(), expected);
+    }
+}
diff --git a/arrow-string/src/regexp.rs b/arrow-string/src/regexp.rs
index e28564bdae95..af4d66f97fd0 100644
--- a/arrow-string/src/regexp.rs
+++ b/arrow-string/src/regexp.rs
@@ -398,7 +398,7 @@ mod tests {
         vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"],
         vec!["^ar", "^AR", "ow$", "OW$", "foo", ""],
         regexp_is_match_utf8,
-        vec![true, false, true, false, false, true]
+        [true, false, true, false, false, true]
     );
     test_flag_utf8!(
         test_utf8_array_regexp_is_match_insensitive,
vec!["^ar", "^AR", "ow$", "OW$", "foo", ""], vec!["i"; 6], regexp_is_match_utf8, - vec![true, true, true, true, false, true] + [true, true, true, true, false, true] ); test_flag_utf8_scalar!( @@ -414,14 +414,14 @@ mod tests { vec!["arrow", "ARROW", "parquet", "PARQUET"], "^ar", regexp_is_match_utf8_scalar, - vec![true, false, false, false] + [true, false, false, false] ); test_flag_utf8_scalar!( test_utf8_array_regexp_is_match_empty_scalar, vec!["arrow", "ARROW", "parquet", "PARQUET"], "", regexp_is_match_utf8_scalar, - vec![true, true, true, true] + [true, true, true, true] ); test_flag_utf8_scalar!( test_utf8_array_regexp_is_match_insensitive_scalar, @@ -429,6 +429,6 @@ mod tests { "^ar", "i", regexp_is_match_utf8_scalar, - vec![true, true, false, false] + [true, true, false, false] ); } diff --git a/arrow-string/src/substring.rs b/arrow-string/src/substring.rs index 1075d106911e..dc0dfdcbb4ad 100644 --- a/arrow-string/src/substring.rs +++ b/arrow-string/src/substring.rs @@ -347,8 +347,7 @@ fn fixed_size_binary_substring( // build value buffer let num_of_elements = array.len(); - let values = array.value_data(); - let data = values.as_slice(); + let data = array.value_data(); let mut new_values = MutableBuffer::new(num_of_elements * (new_len as usize)); (0..num_of_elements) .map(|idx| { diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml index bc126a2f4c2d..37f03a05b3fa 100644 --- a/arrow/Cargo.toml +++ b/arrow/Cargo.toml @@ -31,7 +31,7 @@ include = [ "Cargo.toml", ] edition = { workspace = true } -rust-version = { workspace = true } +rust-version = "1.70.0" [lib] name = "arrow" @@ -60,10 +60,10 @@ arrow-select = { workspace = true } arrow-string = { workspace = true } rand = { version = "0.8", default-features = false, features = ["std", "std_rng"], optional = true } -pyo3 = { version = "0.19", default-features = false, optional = true } +pyo3 = { version = "0.20", default-features = false, optional = true } [package.metadata.docs.rs] -features = ["prettyprint", "ipc_compression", "dyn_cmp_dict", "ffi", "pyarrow"] +features = ["prettyprint", "ipc_compression", "ffi", "pyarrow"] [features] default = ["csv", "ipc", "json"] @@ -71,7 +71,7 @@ ipc_compression = ["ipc", "arrow-ipc/lz4", "arrow-ipc/zstd"] csv = ["arrow-csv"] ipc = ["arrow-ipc"] json = ["arrow-json"] -simd = ["arrow-array/simd", "arrow-ord/simd", "arrow-arith/simd"] +simd = ["arrow-array/simd", "arrow-arith/simd"] prettyprint = ["arrow-cast/prettyprint"] # The test utils feature enables code used in benchmarks and tests but # not the core arrow code itself. 
Be aware that `rand` must be kept as @@ -85,13 +85,10 @@ pyarrow = ["pyo3", "ffi"] force_validate = ["arrow-data/force_validate"] # Enable ffi support ffi = ["arrow-schema/ffi", "arrow-data/ffi"] -# Enable dyn-comparison of dictionary arrays with other arrays -# Note: this does not impact comparison against scalars -dyn_cmp_dict = ["arrow-string/dyn_cmp_dict", "arrow-ord/dyn_cmp_dict"] chrono-tz = ["arrow-array/chrono-tz"] [dev-dependencies] -chrono = { version = "0.4.23", default-features = false, features = ["clock"] } +chrono = { workspace = true } criterion = { version = "0.5", default-features = false } half = { version = "2.1", default-features = false } rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } @@ -295,3 +292,7 @@ required-features = ["chrono-tz", "prettyprint"] [[test]] name = "timezone" required-features = ["chrono-tz"] + +[[test]] +name = "arithmetic" +required-features = ["chrono-tz"] diff --git a/arrow/README.md b/arrow/README.md index fb2119e3bc15..6a91bc951cc1 100644 --- a/arrow/README.md +++ b/arrow/README.md @@ -54,7 +54,6 @@ The `arrow` crate provides the following features which may be enabled in your ` - `chrono-tz` - support of parsing timezone using [chrono-tz](https://docs.rs/chrono-tz/0.6.0/chrono_tz/) - `ffi` - bindings for the Arrow C [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html) - `pyarrow` - bindings for pyo3 to call arrow-rs from python -- `dyn_cmp_dict` - enables comparison of dictionary arrays within dyn comparison kernels ## Arrow Feature Status diff --git a/arrow/benches/arithmetic_kernels.rs b/arrow/benches/arithmetic_kernels.rs index 4ed197783b07..e982b0eb4b5f 100644 --- a/arrow/benches/arithmetic_kernels.rs +++ b/arrow/benches/arithmetic_kernels.rs @@ -15,65 +15,61 @@ // specific language governing permissions and limitations // under the License. 
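The benchmark rewrite that follows drops the old `arithmetic` kernels in favour of the `numeric` kernels, where the plain names (`add`, `sub`, `mul`, `div`, `rem`) are checked, the `_wrapping` variants replace the old unchecked forms, and a scalar operand is passed as a `Scalar` datum. A rough sketch of the new call shape, with illustrative values only:

```rust
use arrow::array::Int32Array;
use arrow::compute::kernels::numeric::{add, add_wrapping};

fn main() {
    let a = Int32Array::from(vec![1, 2, 3]);
    let b = Int32Array::from(vec![10, 20, 30]);

    // Checked addition: errors on overflow instead of wrapping.
    let sum = add(&a, &b).unwrap();
    assert_eq!(sum.len(), 3);

    // Array-with-scalar addition replaces the old `add_scalar` kernel.
    let plus_one = add_wrapping(&a, &Int32Array::new_scalar(1)).unwrap();
    assert_eq!(plus_one.len(), 3);
}
```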
-#[macro_use] -extern crate criterion; -use criterion::Criterion; -use rand::Rng; +use criterion::*; extern crate arrow; +use arrow::compute::kernels::numeric::*; use arrow::datatypes::Float32Type; use arrow::util::bench_util::*; -use arrow::{compute::kernels::arithmetic::*, util::test_util::seedable_rng}; +use arrow_array::Scalar; fn add_benchmark(c: &mut Criterion) { const BATCH_SIZE: usize = 64 * 1024; for null_density in [0., 0.1, 0.5, 0.9, 1.0] { let arr_a = create_primitive_array::(BATCH_SIZE, null_density); let arr_b = create_primitive_array::(BATCH_SIZE, null_density); - let scalar = seedable_rng().gen(); + let scalar_a = create_primitive_array::(1, 0.); + let scalar = Scalar::new(&scalar_a); c.bench_function(&format!("add({null_density})"), |b| { - b.iter(|| criterion::black_box(add(&arr_a, &arr_b).unwrap())) + b.iter(|| criterion::black_box(add_wrapping(&arr_a, &arr_b).unwrap())) }); c.bench_function(&format!("add_checked({null_density})"), |b| { - b.iter(|| criterion::black_box(add_checked(&arr_a, &arr_b).unwrap())) + b.iter(|| criterion::black_box(add(&arr_a, &arr_b).unwrap())) }); c.bench_function(&format!("add_scalar({null_density})"), |b| { - b.iter(|| criterion::black_box(add_scalar(&arr_a, scalar).unwrap())) + b.iter(|| criterion::black_box(add_wrapping(&arr_a, &scalar).unwrap())) }); c.bench_function(&format!("subtract({null_density})"), |b| { - b.iter(|| criterion::black_box(subtract(&arr_a, &arr_b).unwrap())) + b.iter(|| criterion::black_box(sub_wrapping(&arr_a, &arr_b).unwrap())) }); c.bench_function(&format!("subtract_checked({null_density})"), |b| { - b.iter(|| criterion::black_box(subtract_checked(&arr_a, &arr_b).unwrap())) + b.iter(|| criterion::black_box(sub(&arr_a, &arr_b).unwrap())) }); c.bench_function(&format!("subtract_scalar({null_density})"), |b| { - b.iter(|| criterion::black_box(subtract_scalar(&arr_a, scalar).unwrap())) + b.iter(|| criterion::black_box(sub_wrapping(&arr_a, &scalar).unwrap())) }); c.bench_function(&format!("multiply({null_density})"), |b| { - b.iter(|| criterion::black_box(multiply(&arr_a, &arr_b).unwrap())) + b.iter(|| criterion::black_box(mul_wrapping(&arr_a, &arr_b).unwrap())) }); c.bench_function(&format!("multiply_checked({null_density})"), |b| { - b.iter(|| criterion::black_box(multiply_checked(&arr_a, &arr_b).unwrap())) + b.iter(|| criterion::black_box(mul(&arr_a, &arr_b).unwrap())) }); c.bench_function(&format!("multiply_scalar({null_density})"), |b| { - b.iter(|| criterion::black_box(multiply_scalar(&arr_a, scalar).unwrap())) + b.iter(|| criterion::black_box(mul_wrapping(&arr_a, &scalar).unwrap())) }); c.bench_function(&format!("divide({null_density})"), |b| { - b.iter(|| criterion::black_box(divide(&arr_a, &arr_b).unwrap())) - }); - c.bench_function(&format!("divide_checked({null_density})"), |b| { - b.iter(|| criterion::black_box(divide_checked(&arr_a, &arr_b).unwrap())) + b.iter(|| criterion::black_box(div(&arr_a, &arr_b).unwrap())) }); c.bench_function(&format!("divide_scalar({null_density})"), |b| { - b.iter(|| criterion::black_box(divide_scalar(&arr_a, scalar).unwrap())) + b.iter(|| criterion::black_box(div(&arr_a, &scalar).unwrap())) }); c.bench_function(&format!("modulo({null_density})"), |b| { - b.iter(|| criterion::black_box(modulus(&arr_a, &arr_b).unwrap())) + b.iter(|| criterion::black_box(rem(&arr_a, &arr_b).unwrap())) }); c.bench_function(&format!("modulo_scalar({null_density})"), |b| { - b.iter(|| criterion::black_box(modulus_scalar(&arr_a, scalar).unwrap())) + b.iter(|| criterion::black_box(rem(&arr_a, 
&scalar).unwrap())) }); } } diff --git a/arrow/benches/comparison_kernels.rs b/arrow/benches/comparison_kernels.rs index 73db3ffed368..02de70c5d79d 100644 --- a/arrow/benches/comparison_kernels.rs +++ b/arrow/benches/comparison_kernels.rs @@ -21,78 +21,30 @@ use criterion::Criterion; extern crate arrow; -use arrow::compute::*; -use arrow::datatypes::{ArrowNativeTypeOp, ArrowNumericType, IntervalMonthDayNanoType}; +use arrow::compute::kernels::cmp::*; +use arrow::datatypes::IntervalMonthDayNanoType; use arrow::util::bench_util::*; use arrow::{array::*, datatypes::Float32Type, datatypes::Int32Type}; +use arrow_array::Scalar; +use arrow_string::like::*; +use arrow_string::regexp::regexp_is_match_utf8_scalar; const SIZE: usize = 65536; -fn bench_eq(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) -where - T: ArrowNumericType, - ::Native: ArrowNativeTypeOp, -{ - eq(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap(); -} - -fn bench_neq(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) -where - T: ArrowNumericType, - ::Native: ArrowNativeTypeOp, -{ - neq(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap(); -} - -fn bench_lt(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) -where - T: ArrowNumericType, - ::Native: ArrowNativeTypeOp, -{ - lt(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap(); -} - -fn bench_lt_eq(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) -where - T: ArrowNumericType, - ::Native: ArrowNativeTypeOp, -{ - lt_eq(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap(); -} - -fn bench_gt(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) -where - T: ArrowNumericType, - ::Native: ArrowNativeTypeOp, -{ - gt(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap(); -} - -fn bench_gt_eq(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) -where - T: ArrowNumericType, - ::Native: ArrowNativeTypeOp, -{ - gt_eq(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap(); -} - fn bench_like_utf8_scalar(arr_a: &StringArray, value_b: &str) { - like_utf8_scalar(criterion::black_box(arr_a), criterion::black_box(value_b)).unwrap(); + like(arr_a, &StringArray::new_scalar(value_b)).unwrap(); } fn bench_nlike_utf8_scalar(arr_a: &StringArray, value_b: &str) { - nlike_utf8_scalar(criterion::black_box(arr_a), criterion::black_box(value_b)) - .unwrap(); + nlike(arr_a, &StringArray::new_scalar(value_b)).unwrap(); } fn bench_ilike_utf8_scalar(arr_a: &StringArray, value_b: &str) { - ilike_utf8_scalar(criterion::black_box(arr_a), criterion::black_box(value_b)) - .unwrap(); + ilike(arr_a, &StringArray::new_scalar(value_b)).unwrap(); } fn bench_nilike_utf8_scalar(arr_a: &StringArray, value_b: &str) { - nilike_utf8_scalar(criterion::black_box(arr_a), criterion::black_box(value_b)) - .unwrap(); + nilike(arr_a, &StringArray::new_scalar(value_b)).unwrap(); } fn bench_regexp_is_match_utf8_scalar(arr_a: &StringArray, value_b: &str) { @@ -104,27 +56,6 @@ fn bench_regexp_is_match_utf8_scalar(arr_a: &StringArray, value_b: &str) { .unwrap(); } -#[cfg(not(feature = "dyn_cmp_dict"))] -fn dyn_cmp_dict_benchmarks(_c: &mut Criterion) {} - -#[cfg(feature = "dyn_cmp_dict")] -fn dyn_cmp_dict_benchmarks(c: &mut Criterion) { - let strings = create_string_array::(20, 0.); - let dict_arr_a = create_dict_from_values::(SIZE, 0., &strings); - let dict_arr_b = create_dict_from_values::(SIZE, 0., &strings); - - c.bench_function("eq dictionary[10] string[4])", |b| { - b.iter(|| { - cmp_dict_utf8::<_, i32, _>( - criterion::black_box(&dict_arr_a), - 
criterion::black_box(&dict_arr_b), - |a, b| a == b, - ) - .unwrap() - }) - }); -} - fn add_benchmark(c: &mut Criterion) { let arr_a = create_primitive_array_with_seed::(SIZE, 0.0, 42); let arr_b = create_primitive_array_with_seed::(SIZE, 0.0, 43); @@ -135,105 +66,79 @@ fn add_benchmark(c: &mut Criterion) { create_primitive_array_with_seed::(SIZE, 0.0, 43); let arr_string = create_string_array::(SIZE, 0.0); + let scalar = Float32Array::from(vec![1.0]); - c.bench_function("eq Float32", |b| b.iter(|| bench_eq(&arr_a, &arr_b))); + c.bench_function("eq Float32", |b| b.iter(|| eq(&arr_a, &arr_b))); c.bench_function("eq scalar Float32", |b| { - b.iter(|| { - eq_scalar(criterion::black_box(&arr_a), criterion::black_box(1.0)).unwrap() - }) + b.iter(|| eq(&arr_a, &Scalar::new(&scalar)).unwrap()) }); - c.bench_function("neq Float32", |b| b.iter(|| bench_neq(&arr_a, &arr_b))); + c.bench_function("neq Float32", |b| b.iter(|| neq(&arr_a, &arr_b))); c.bench_function("neq scalar Float32", |b| { - b.iter(|| { - neq_scalar(criterion::black_box(&arr_a), criterion::black_box(1.0)).unwrap() - }) + b.iter(|| neq(&arr_a, &Scalar::new(&scalar)).unwrap()) }); - c.bench_function("lt Float32", |b| b.iter(|| bench_lt(&arr_a, &arr_b))); + c.bench_function("lt Float32", |b| b.iter(|| lt(&arr_a, &arr_b))); c.bench_function("lt scalar Float32", |b| { - b.iter(|| { - lt_scalar(criterion::black_box(&arr_a), criterion::black_box(1.0)).unwrap() - }) + b.iter(|| lt(&arr_a, &Scalar::new(&scalar)).unwrap()) }); - c.bench_function("lt_eq Float32", |b| b.iter(|| bench_lt_eq(&arr_a, &arr_b))); + c.bench_function("lt_eq Float32", |b| b.iter(|| lt_eq(&arr_a, &arr_b))); c.bench_function("lt_eq scalar Float32", |b| { - b.iter(|| { - lt_eq_scalar(criterion::black_box(&arr_a), criterion::black_box(1.0)).unwrap() - }) + b.iter(|| lt_eq(&arr_a, &Scalar::new(&scalar)).unwrap()) }); - c.bench_function("gt Float32", |b| b.iter(|| bench_gt(&arr_a, &arr_b))); + c.bench_function("gt Float32", |b| b.iter(|| gt(&arr_a, &arr_b))); c.bench_function("gt scalar Float32", |b| { - b.iter(|| { - gt_scalar(criterion::black_box(&arr_a), criterion::black_box(1.0)).unwrap() - }) + b.iter(|| gt(&arr_a, &Scalar::new(&scalar)).unwrap()) }); - c.bench_function("gt_eq Float32", |b| b.iter(|| bench_gt_eq(&arr_a, &arr_b))); + c.bench_function("gt_eq Float32", |b| b.iter(|| gt_eq(&arr_a, &arr_b))); c.bench_function("gt_eq scalar Float32", |b| { - b.iter(|| { - gt_eq_scalar(criterion::black_box(&arr_a), criterion::black_box(1.0)).unwrap() - }) + b.iter(|| gt_eq(&arr_a, &Scalar::new(&scalar)).unwrap()) }); let arr_a = create_primitive_array_with_seed::(SIZE, 0.0, 42); let arr_b = create_primitive_array_with_seed::(SIZE, 0.0, 43); + let scalar = Int32Array::new_scalar(1); - c.bench_function("eq Int32", |b| b.iter(|| bench_eq(&arr_a, &arr_b))); + c.bench_function("eq Int32", |b| b.iter(|| eq(&arr_a, &arr_b))); c.bench_function("eq scalar Int32", |b| { - b.iter(|| { - eq_scalar(criterion::black_box(&arr_a), criterion::black_box(1)).unwrap() - }) + b.iter(|| eq(&arr_a, &scalar).unwrap()) }); - c.bench_function("neq Int32", |b| b.iter(|| bench_neq(&arr_a, &arr_b))); + c.bench_function("neq Int32", |b| b.iter(|| neq(&arr_a, &arr_b))); c.bench_function("neq scalar Int32", |b| { - b.iter(|| { - neq_scalar(criterion::black_box(&arr_a), criterion::black_box(1)).unwrap() - }) + b.iter(|| neq(&arr_a, &scalar).unwrap()) }); - c.bench_function("lt Int32", |b| b.iter(|| bench_lt(&arr_a, &arr_b))); + c.bench_function("lt Int32", |b| b.iter(|| lt(&arr_a, &arr_b))); 
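As the rewritten comparison benchmarks above and below show, the kernels in `arrow::compute::kernels::cmp` now take a `Datum` on either side, so the old `eq_scalar`/`eq_dyn_utf8_scalar` style entry points become a plain `eq` call with a `Scalar` argument. A small sketch of the pattern, with illustrative values:

```rust
use arrow::array::Int32Array;
use arrow::compute::kernels::cmp::eq;

fn main() {
    let a = Int32Array::from(vec![1, 2, 3]);
    let b = Int32Array::from(vec![1, 5, 3]);

    // Array against array.
    let mask = eq(&a, &b).unwrap();
    assert!(mask.value(0) && !mask.value(1));

    // Array against a scalar, replacing the old `eq_scalar` kernel.
    let mask = eq(&a, &Int32Array::new_scalar(2)).unwrap();
    assert!(!mask.value(0) && mask.value(1));
}
```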
c.bench_function("lt scalar Int32", |b| { - b.iter(|| { - lt_scalar(criterion::black_box(&arr_a), criterion::black_box(1)).unwrap() - }) + b.iter(|| lt(&arr_a, &scalar).unwrap()) }); - c.bench_function("lt_eq Int32", |b| b.iter(|| bench_lt_eq(&arr_a, &arr_b))); + c.bench_function("lt_eq Int32", |b| b.iter(|| lt_eq(&arr_a, &arr_b))); c.bench_function("lt_eq scalar Int32", |b| { - b.iter(|| { - lt_eq_scalar(criterion::black_box(&arr_a), criterion::black_box(1)).unwrap() - }) + b.iter(|| lt_eq(&arr_a, &scalar).unwrap()) }); - c.bench_function("gt Int32", |b| b.iter(|| bench_gt(&arr_a, &arr_b))); + c.bench_function("gt Int32", |b| b.iter(|| gt(&arr_a, &arr_b))); c.bench_function("gt scalar Int32", |b| { - b.iter(|| { - gt_scalar(criterion::black_box(&arr_a), criterion::black_box(1)).unwrap() - }) + b.iter(|| gt(&arr_a, &scalar).unwrap()) }); - c.bench_function("gt_eq Int32", |b| b.iter(|| bench_gt_eq(&arr_a, &arr_b))); + c.bench_function("gt_eq Int32", |b| b.iter(|| gt_eq(&arr_a, &arr_b))); c.bench_function("gt_eq scalar Int32", |b| { - b.iter(|| { - gt_eq_scalar(criterion::black_box(&arr_a), criterion::black_box(1)).unwrap() - }) + b.iter(|| gt_eq(&arr_a, &scalar).unwrap()) }); c.bench_function("eq MonthDayNano", |b| { - b.iter(|| bench_eq(&arr_month_day_nano_a, &arr_month_day_nano_b)) + b.iter(|| eq(&arr_month_day_nano_a, &arr_month_day_nano_b)) }); + let scalar = IntervalMonthDayNanoArray::new_scalar(123); + c.bench_function("eq scalar MonthDayNano", |b| { - b.iter(|| { - eq_scalar( - criterion::black_box(&arr_month_day_nano_a), - criterion::black_box(123), - ) - .unwrap() - }) + b.iter(|| eq(&arr_month_day_nano_b, &scalar).unwrap()) }); c.bench_function("like_utf8 scalar equals", |b| { @@ -326,25 +231,32 @@ fn add_benchmark(c: &mut Criterion) { let strings = create_string_array::(20, 0.); let dict_arr_a = create_dict_from_values::(SIZE, 0., &strings); + let scalar = StringArray::from(vec!["test"]); c.bench_function("eq_dyn_utf8_scalar dictionary[10] string[4])", |b| { - b.iter(|| eq_dyn_utf8_scalar(&dict_arr_a, "test")) + b.iter(|| eq(&dict_arr_a, &Scalar::new(&scalar))) }); c.bench_function( "gt_eq_dyn_utf8_scalar scalar dictionary[10] string[4])", - |b| b.iter(|| gt_eq_dyn_utf8_scalar(&dict_arr_a, "test")), + |b| b.iter(|| gt_eq(&dict_arr_a, &Scalar::new(&scalar))), ); c.bench_function("like_utf8_scalar_dyn dictionary[10] string[4])", |b| { - b.iter(|| like_utf8_scalar_dyn(&dict_arr_a, "test")) + b.iter(|| like(&dict_arr_a, &StringArray::new_scalar("test"))) }); c.bench_function("ilike_utf8_scalar_dyn dictionary[10] string[4])", |b| { - b.iter(|| ilike_utf8_scalar_dyn(&dict_arr_a, "test")) + b.iter(|| ilike(&dict_arr_a, &StringArray::new_scalar("test"))) }); - dyn_cmp_dict_benchmarks(c); + let strings = create_string_array::(20, 0.); + let dict_arr_a = create_dict_from_values::(SIZE, 0., &strings); + let dict_arr_b = create_dict_from_values::(SIZE, 0., &strings); + + c.bench_function("eq dictionary[10] string[4])", |b| { + b.iter(|| eq(&dict_arr_a, &dict_arr_b).unwrap()) + }); } criterion_group!(benches, add_benchmark); diff --git a/arrow/benches/concatenate_kernel.rs b/arrow/benches/concatenate_kernel.rs index 3fff2abd179c..2f5b654394e4 100644 --- a/arrow/benches/concatenate_kernel.rs +++ b/arrow/benches/concatenate_kernel.rs @@ -60,6 +60,28 @@ fn add_benchmark(c: &mut Criterion) { c.bench_function("concat str nulls 1024", |b| { b.iter(|| bench_concat(&v1, &v2)) }); + + let v1 = create_string_array_with_len::(10, 0.0, 20); + let v1 = create_dict_from_values::(1024, 0.0, &v1); + let v2 = 
create_string_array_with_len::(10, 0.0, 20); + let v2 = create_dict_from_values::(1024, 0.0, &v2); + c.bench_function("concat str_dict 1024", |b| { + b.iter(|| bench_concat(&v1, &v2)) + }); + + let v1 = create_string_array_with_len::(1024, 0.0, 20); + let v1 = create_sparse_dict_from_values::(1024, 0.0, &v1, 10..20); + let v2 = create_string_array_with_len::(1024, 0.0, 20); + let v2 = create_sparse_dict_from_values::(1024, 0.0, &v2, 30..40); + c.bench_function("concat str_dict_sparse 1024", |b| { + b.iter(|| bench_concat(&v1, &v2)) + }); + + let v1 = create_string_array::(1024, 0.5); + let v2 = create_string_array::(1024, 0.5); + c.bench_function("concat str nulls 1024", |b| { + b.iter(|| bench_concat(&v1, &v2)) + }); } criterion_group!(benches, add_benchmark); diff --git a/arrow/benches/csv_reader.rs b/arrow/benches/csv_reader.rs index c2491a5a0b04..5a91dfe0a6ff 100644 --- a/arrow/benches/csv_reader.rs +++ b/arrow/benches/csv_reader.rs @@ -18,15 +18,18 @@ extern crate arrow; extern crate criterion; +use std::io::Cursor; +use std::sync::Arc; + use criterion::*; +use rand::Rng; use arrow::array::*; use arrow::csv; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; use arrow::util::bench_util::{create_primitive_array, create_string_array_with_len}; -use std::io::Cursor; -use std::sync::Arc; +use arrow::util::test_util::seedable_rng; fn do_bench(c: &mut Criterion, name: &str, cols: Vec) { let batch = RecordBatch::try_from_iter(cols.into_iter().map(|a| ("col", a))).unwrap(); @@ -42,7 +45,7 @@ fn do_bench(c: &mut Criterion, name: &str, cols: Vec) { let cursor = Cursor::new(buf.as_slice()); let reader = csv::ReaderBuilder::new(batch.schema()) .with_batch_size(batch_size) - .has_header(true) + .with_header(true) .build_buffered(cursor) .unwrap(); @@ -55,18 +58,49 @@ fn do_bench(c: &mut Criterion, name: &str, cols: Vec) { } fn criterion_benchmark(c: &mut Criterion) { - let cols = vec![Arc::new(create_primitive_array::(4096, 0.)) as ArrayRef]; + let mut rng = seedable_rng(); + + let values = Int32Array::from_iter_values((0..4096).map(|_| rng.gen_range(0..1024))); + let cols = vec![Arc::new(values) as ArrayRef]; + do_bench(c, "4096 i32_small(0)", cols); + + let values = Int32Array::from_iter_values((0..4096).map(|_| rng.gen())); + let cols = vec![Arc::new(values) as ArrayRef]; + do_bench(c, "4096 i32(0)", cols); + + let values = UInt64Array::from_iter_values((0..4096).map(|_| rng.gen_range(0..1024))); + let cols = vec![Arc::new(values) as ArrayRef]; + do_bench(c, "4096 u64_small(0)", cols); + + let values = UInt64Array::from_iter_values((0..4096).map(|_| rng.gen())); + let cols = vec![Arc::new(values) as ArrayRef]; do_bench(c, "4096 u64(0)", cols); - let cols = vec![Arc::new(create_primitive_array::(4096, 0.)) as ArrayRef]; + let values = + Int64Array::from_iter_values((0..4096).map(|_| rng.gen_range(0..1024) - 512)); + let cols = vec![Arc::new(values) as ArrayRef]; + do_bench(c, "4096 i64_small(0)", cols); + + let values = Int64Array::from_iter_values((0..4096).map(|_| rng.gen())); + let cols = vec![Arc::new(values) as ArrayRef]; do_bench(c, "4096 i64(0)", cols); - let cols = - vec![Arc::new(create_primitive_array::(4096, 0.)) as ArrayRef]; + let cols = vec![Arc::new(Float32Array::from_iter_values( + (0..4096).map(|_| rng.gen_range(0..1024000) as f32 / 1000.), + )) as _]; + do_bench(c, "4096 f32_small(0)", cols); + + let values = Float32Array::from_iter_values((0..4096).map(|_| rng.gen())); + let cols = vec![Arc::new(values) as ArrayRef]; do_bench(c, "4096 f32(0)", cols); - let cols = - 
vec![Arc::new(create_primitive_array::(4096, 0.)) as ArrayRef]; + let cols = vec![Arc::new(Float64Array::from_iter_values( + (0..4096).map(|_| rng.gen_range(0..1024000) as f64 / 1000.), + )) as _]; + do_bench(c, "4096 f64_small(0)", cols); + + let values = Float64Array::from_iter_values((0..4096).map(|_| rng.gen())); + let cols = vec![Arc::new(values) as ArrayRef]; do_bench(c, "4096 f64(0)", cols); let cols = diff --git a/arrow/benches/equal.rs b/arrow/benches/equal.rs index 2f4e2fada9e9..4e99bf3071c9 100644 --- a/arrow/benches/equal.rs +++ b/arrow/benches/equal.rs @@ -20,7 +20,6 @@ #[macro_use] extern crate criterion; -use arrow::compute::eq_utf8_scalar; use criterion::Criterion; extern crate arrow; @@ -32,10 +31,6 @@ fn bench_equal>(arr_a: &A) { criterion::black_box(arr_a == arr_a); } -fn bench_equal_utf8_scalar(arr_a: &GenericStringArray, right: &str) { - criterion::black_box(eq_utf8_scalar(arr_a, right).unwrap()); -} - fn add_benchmark(c: &mut Criterion) { let arr_a = create_primitive_array::(512, 0.0); c.bench_function("equal_512", |b| b.iter(|| bench_equal(&arr_a))); @@ -49,11 +44,6 @@ fn add_benchmark(c: &mut Criterion) { let arr_a = create_string_array::(512, 0.0); c.bench_function("equal_string_512", |b| b.iter(|| bench_equal(&arr_a))); - let arr_a = create_string_array::(512, 0.0); - c.bench_function("equal_string_scalar_empty_512", |b| { - b.iter(|| bench_equal_utf8_scalar(&arr_a, "")) - }); - let arr_a_nulls = create_string_array::(512, 0.5); c.bench_function("equal_string_nulls_512", |b| { b.iter(|| bench_equal(&arr_a_nulls)) diff --git a/arrow/benches/interleave_kernels.rs b/arrow/benches/interleave_kernels.rs index 2bb430e40b0f..454d9140809c 100644 --- a/arrow/benches/interleave_kernels.rs +++ b/arrow/benches/interleave_kernels.rs @@ -37,14 +37,21 @@ fn do_bench( base: &dyn Array, slices: &[Range], ) { - let mut rng = seedable_rng(); - let arrays: Vec<_> = slices .iter() .map(|r| base.slice(r.start, r.end - r.start)) .collect(); let values: Vec<_> = arrays.iter().map(|x| x.as_ref()).collect(); + bench_values( + c, + &format!("interleave {prefix} {len} {slices:?}"), + len, + &values, + ); +} +fn bench_values(c: &mut Criterion, name: &str, len: usize, values: &[&dyn Array]) { + let mut rng = seedable_rng(); let indices: Vec<_> = (0..len) .map(|_| { let array_idx = rng.gen_range(0..values.len()); @@ -53,8 +60,8 @@ fn do_bench( }) .collect(); - c.bench_function(&format!("interleave {prefix} {len} {slices:?}"), |b| { - b.iter(|| criterion::black_box(interleave(&values, &indices).unwrap())) + c.bench_function(name, |b| { + b.iter(|| criterion::black_box(interleave(values, &indices).unwrap())) }); } @@ -63,12 +70,20 @@ fn add_benchmark(c: &mut Criterion) { let i32_opt = create_primitive_array::(1024, 0.5); let string = create_string_array_with_len::(1024, 0., 20); let string_opt = create_string_array_with_len::(1024, 0.5, 20); + let values = create_string_array_with_len::(10, 0.0, 20); + let dict = create_dict_from_values::(1024, 0.0, &values); + + let values = create_string_array_with_len::(1024, 0.0, 20); + let sparse_dict = + create_sparse_dict_from_values::(1024, 0.0, &values, 10..20); let cases: &[(&str, &dyn Array)] = &[ ("i32(0.0)", &i32), ("i32(0.5)", &i32_opt), ("str(20, 0.0)", &string), ("str(20, 0.5)", &string_opt), + ("dict(20, 0.0)", &dict), + ("dict_sparse(20, 0.0)", &sparse_dict), ]; for (prefix, base) in cases { @@ -83,6 +98,15 @@ fn add_benchmark(c: &mut Criterion) { do_bench(c, prefix, *len, *base, slice); } } + + for len in [100, 1024, 2048] { + bench_values( + 
c, + &format!("interleave dict_distinct {len}"), + 100, + &[&dict, &sparse_dict], + ); + } } criterion_group!(benches, add_benchmark); diff --git a/arrow/benches/lexsort.rs b/arrow/benches/lexsort.rs index 30dab9a74667..25b2279be8d6 100644 --- a/arrow/benches/lexsort.rs +++ b/arrow/benches/lexsort.rs @@ -100,7 +100,7 @@ fn do_bench(c: &mut Criterion, columns: &[Column], len: usize) { .iter() .map(|a| SortField::new(a.data_type().clone())) .collect(); - let mut converter = RowConverter::new(fields).unwrap(); + let converter = RowConverter::new(fields).unwrap(); let rows = converter.convert_columns(&arrays).unwrap(); let mut sort: Vec<_> = rows.iter().enumerate().collect(); sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b)); diff --git a/arrow/benches/partition_kernels.rs b/arrow/benches/partition_kernels.rs index ae55fbdad22c..85cafbe47a11 100644 --- a/arrow/benches/partition_kernels.rs +++ b/arrow/benches/partition_kernels.rs @@ -20,13 +20,13 @@ extern crate criterion; use criterion::Criterion; use std::sync::Arc; extern crate arrow; -use arrow::compute::kernels::partition::lexicographical_partition_ranges; use arrow::compute::kernels::sort::{lexsort, SortColumn}; use arrow::util::bench_util::*; use arrow::{ array::*, datatypes::{ArrowPrimitiveType, Float64Type, UInt8Type}, }; +use arrow_ord::partition::partition; use rand::distributions::{Distribution, Standard}; use std::iter; @@ -40,19 +40,7 @@ where } fn bench_partition(sorted_columns: &[ArrayRef]) { - let columns = sorted_columns - .iter() - .map(|arr| SortColumn { - values: arr.clone(), - options: None, - }) - .collect::>(); - - criterion::black_box( - lexicographical_partition_ranges(&columns) - .unwrap() - .collect::>(), - ); + criterion::black_box(partition(sorted_columns).unwrap().ranges()); } fn create_sorted_low_cardinality_data(length: usize) -> Vec { @@ -109,37 +97,34 @@ fn create_sorted_data(pow: u32, with_nulls: bool) -> Vec { fn add_benchmark(c: &mut Criterion) { let sorted_columns = create_sorted_data(10, false); - c.bench_function("lexicographical_partition_ranges(u8) 2^10", |b| { + c.bench_function("partition(u8) 2^10", |b| { b.iter(|| bench_partition(&sorted_columns)) }); let sorted_columns = create_sorted_data(12, false); - c.bench_function("lexicographical_partition_ranges(u8) 2^12", |b| { + c.bench_function("partition(u8) 2^12", |b| { b.iter(|| bench_partition(&sorted_columns)) }); let sorted_columns = create_sorted_data(10, true); - c.bench_function( - "lexicographical_partition_ranges(u8) 2^10 with nulls", - |b| b.iter(|| bench_partition(&sorted_columns)), - ); + c.bench_function("partition(u8) 2^10 with nulls", |b| { + b.iter(|| bench_partition(&sorted_columns)) + }); let sorted_columns = create_sorted_data(12, true); - c.bench_function( - "lexicographical_partition_ranges(u8) 2^12 with nulls", - |b| b.iter(|| bench_partition(&sorted_columns)), - ); + c.bench_function("partition(u8) 2^12 with nulls", |b| { + b.iter(|| bench_partition(&sorted_columns)) + }); let sorted_columns = create_sorted_float_data(10, false); - c.bench_function("lexicographical_partition_ranges(f64) 2^10", |b| { + c.bench_function("partition(f64) 2^10", |b| { b.iter(|| bench_partition(&sorted_columns)) }); let sorted_columns = create_sorted_low_cardinality_data(1024); - c.bench_function( - "lexicographical_partition_ranges(low cardinality) 1024", - |b| b.iter(|| bench_partition(&sorted_columns)), - ); + c.bench_function("partition(low cardinality) 1024", |b| { + b.iter(|| bench_partition(&sorted_columns)) + }); } criterion_group!(benches, 
add_benchmark); diff --git a/arrow/benches/row_format.rs b/arrow/benches/row_format.rs index 12ce71764f7e..bde117e3ec3e 100644 --- a/arrow/benches/row_format.rs +++ b/arrow/benches/row_format.rs @@ -23,35 +23,28 @@ use arrow::array::ArrayRef; use arrow::datatypes::{Int64Type, UInt64Type}; use arrow::row::{RowConverter, SortField}; use arrow::util::bench_util::{ - create_primitive_array, create_string_array_with_len, create_string_dict_array, + create_dict_from_values, create_primitive_array, create_string_array_with_len, + create_string_dict_array, }; use arrow_array::types::Int32Type; use arrow_array::Array; use criterion::{black_box, Criterion}; use std::sync::Arc; -fn do_bench( - c: &mut Criterion, - name: &str, - cols: Vec, - preserve_dictionaries: bool, -) { +fn do_bench(c: &mut Criterion, name: &str, cols: Vec) { let fields: Vec<_> = cols .iter() - .map(|x| { - SortField::new(x.data_type().clone()) - .preserve_dictionaries(preserve_dictionaries) - }) + .map(|x| SortField::new(x.data_type().clone())) .collect(); c.bench_function(&format!("convert_columns {name}"), |b| { b.iter(|| { - let mut converter = RowConverter::new(fields.clone()).unwrap(); + let converter = RowConverter::new(fields.clone()).unwrap(); black_box(converter.convert_columns(&cols).unwrap()) }); }); - let mut converter = RowConverter::new(fields).unwrap(); + let converter = RowConverter::new(fields).unwrap(); let rows = converter.convert_columns(&cols).unwrap(); // using a pre-prepared row converter should be faster than the first time c.bench_function(&format!("convert_columns_prepared {name}"), |b| { @@ -65,46 +58,57 @@ fn do_bench( fn row_bench(c: &mut Criterion) { let cols = vec![Arc::new(create_primitive_array::(4096, 0.)) as ArrayRef]; - do_bench(c, "4096 u64(0)", cols, true); + do_bench(c, "4096 u64(0)", cols); let cols = vec![Arc::new(create_primitive_array::(4096, 0.)) as ArrayRef]; - do_bench(c, "4096 i64(0)", cols, true); + do_bench(c, "4096 i64(0)", cols); let cols = vec![Arc::new(create_string_array_with_len::(4096, 0., 10)) as ArrayRef]; - do_bench(c, "4096 string(10, 0)", cols, true); + do_bench(c, "4096 string(10, 0)", cols); let cols = vec![Arc::new(create_string_array_with_len::(4096, 0., 30)) as ArrayRef]; - do_bench(c, "4096 string(30, 0)", cols, true); + do_bench(c, "4096 string(30, 0)", cols); let cols = vec![Arc::new(create_string_array_with_len::(4096, 0., 100)) as ArrayRef]; - do_bench(c, "4096 string(100, 0)", cols, true); + do_bench(c, "4096 string(100, 0)", cols); let cols = vec![Arc::new(create_string_array_with_len::(4096, 0.5, 100)) as ArrayRef]; - do_bench(c, "4096 string(100, 0.5)", cols, true); + do_bench(c, "4096 string(100, 0.5)", cols); let cols = vec![Arc::new(create_string_dict_array::(4096, 0., 10)) as ArrayRef]; - do_bench(c, "4096 string_dictionary(10, 0)", cols, true); + do_bench(c, "4096 string_dictionary(10, 0)", cols); let cols = vec![Arc::new(create_string_dict_array::(4096, 0., 30)) as ArrayRef]; - do_bench(c, "4096 string_dictionary(30, 0)", cols, true); + do_bench(c, "4096 string_dictionary(30, 0)", cols); let cols = vec![Arc::new(create_string_dict_array::(4096, 0., 100)) as ArrayRef]; - do_bench(c, "4096 string_dictionary(100, 0)", cols.clone(), true); - let name = "4096 string_dictionary_non_preserving(100, 0)"; - do_bench(c, name, cols, false); + do_bench(c, "4096 string_dictionary(100, 0)", cols.clone()); let cols = vec![Arc::new(create_string_dict_array::(4096, 0.5, 100)) as ArrayRef]; - do_bench(c, "4096 string_dictionary(100, 0.5)", cols.clone(), true); - let 
name = "4096 string_dictionary_non_preserving(100, 0.5)"; - do_bench(c, name, cols, false); + do_bench(c, "4096 string_dictionary(100, 0.5)", cols.clone()); + + let values = create_string_array_with_len::(10, 0., 10); + let dict = create_dict_from_values::(4096, 0., &values); + let cols = vec![Arc::new(dict) as ArrayRef]; + do_bench(c, "4096 string_dictionary_low_cardinality(10, 0)", cols); + + let values = create_string_array_with_len::(10, 0., 30); + let dict = create_dict_from_values::(4096, 0., &values); + let cols = vec![Arc::new(dict) as ArrayRef]; + do_bench(c, "4096 string_dictionary_low_cardinality(30, 0)", cols); + + let values = create_string_array_with_len::(10, 0., 100); + let dict = create_dict_from_values::(4096, 0., &values); + let cols = vec![Arc::new(dict) as ArrayRef]; + do_bench(c, "4096 string_dictionary_low_cardinality(100, 0)", cols); let cols = vec![ Arc::new(create_string_array_with_len::(4096, 0.5, 20)) as ArrayRef, @@ -116,7 +120,6 @@ fn row_bench(c: &mut Criterion) { c, "4096 string(20, 0.5), string(30, 0), string(100, 0), i64(0)", cols, - false, ); let cols = vec![ @@ -125,7 +128,7 @@ fn row_bench(c: &mut Criterion) { Arc::new(create_string_dict_array::(4096, 0., 100)) as ArrayRef, Arc::new(create_primitive_array::(4096, 0.)) as ArrayRef, ]; - do_bench(c, "4096 4096 string_dictionary(20, 0.5), string_dictionary(30, 0), string_dictionary(100, 0), i64(0)", cols, false); + do_bench(c, "4096 4096 string_dictionary(20, 0.5), string_dictionary(30, 0), string_dictionary(100, 0), i64(0)", cols); } criterion_group!(benches, row_bench); diff --git a/arrow/benches/sort_kernel.rs b/arrow/benches/sort_kernel.rs index 43a9a84d9a74..63e10e0528ba 100644 --- a/arrow/benches/sort_kernel.rs +++ b/arrow/benches/sort_kernel.rs @@ -17,17 +17,17 @@ #[macro_use] extern crate criterion; -use criterion::Criterion; +use criterion::{black_box, Criterion}; use std::sync::Arc; extern crate arrow; -use arrow::compute::kernels::sort::{lexsort, SortColumn}; -use arrow::compute::{sort_limit, sort_to_indices}; +use arrow::compute::{lexsort, sort, sort_to_indices, SortColumn}; use arrow::datatypes::{Int16Type, Int32Type}; use arrow::util::bench_util::*; use arrow::{array::*, datatypes::Float32Type}; +use arrow_ord::rank::rank; fn create_f32_array(size: usize, with_nulls: bool) -> ArrayRef { let null_density = if with_nulls { 0.5 } else { 0.0 }; @@ -42,7 +42,11 @@ fn create_bool_array(size: usize, with_nulls: bool) -> ArrayRef { Arc::new(array) } -fn bench_sort(array_a: &ArrayRef, array_b: &ArrayRef, limit: Option) { +fn bench_sort(array: &dyn Array) { + black_box(sort(array, None).unwrap()); +} + +fn bench_lexsort(array_a: &ArrayRef, array_b: &ArrayRef, limit: Option) { let columns = vec![ SortColumn { values: array_a.clone(), @@ -54,118 +58,182 @@ fn bench_sort(array_a: &ArrayRef, array_b: &ArrayRef, limit: Option) { }, ]; - criterion::black_box(lexsort(&columns, limit).unwrap()); -} - -fn bench_sort_to_indices(array: &ArrayRef, limit: Option) { - criterion::black_box(sort_to_indices(array, None, limit).unwrap()); + black_box(lexsort(&columns, limit).unwrap()); } -fn bench_sort_run(array: &ArrayRef, limit: Option) { - criterion::black_box(sort_limit(array, None, limit).unwrap()); +fn bench_sort_to_indices(array: &dyn Array, limit: Option) { + black_box(sort_to_indices(array, None, limit).unwrap()); } fn add_benchmark(c: &mut Criterion) { - let arr_a = create_f32_array(2u64.pow(10) as usize, false); - let arr_b = create_f32_array(2u64.pow(10) as usize, false); - - c.bench_function("sort 2^10", 
|b| b.iter(|| bench_sort(&arr_a, &arr_b, None))); + let arr = create_primitive_array::(2usize.pow(10), 0.0); + c.bench_function("sort i32 2^10", |b| b.iter(|| bench_sort(&arr))); + c.bench_function("sort i32 to indices 2^10", |b| { + b.iter(|| bench_sort_to_indices(&arr, None)) + }); - let arr_a = create_f32_array(2u64.pow(12) as usize, false); - let arr_b = create_f32_array(2u64.pow(12) as usize, false); + let arr = create_primitive_array::(2usize.pow(12), 0.0); + c.bench_function("sort i32 2^12", |b| b.iter(|| bench_sort(&arr))); + c.bench_function("sort i32 to indices 2^12", |b| { + b.iter(|| bench_sort_to_indices(&arr, None)) + }); - c.bench_function("sort 2^12", |b| b.iter(|| bench_sort(&arr_a, &arr_b, None))); + let arr = create_primitive_array::(2usize.pow(10), 0.5); + c.bench_function("sort i32 nulls 2^10", |b| b.iter(|| bench_sort(&arr))); + c.bench_function("sort i32 nulls to indices 2^10", |b| { + b.iter(|| bench_sort_to_indices(&arr, None)) + }); - let arr_a = create_f32_array(2u64.pow(10) as usize, true); - let arr_b = create_f32_array(2u64.pow(10) as usize, true); + let arr = create_primitive_array::(2usize.pow(12), 0.5); + c.bench_function("sort i32 nulls 2^12", |b| b.iter(|| bench_sort(&arr))); + c.bench_function("sort i32 nulls to indices 2^12", |b| { + b.iter(|| bench_sort_to_indices(&arr, None)) + }); - c.bench_function("sort nulls 2^10", |b| { - b.iter(|| bench_sort(&arr_a, &arr_b, None)) + let arr = create_f32_array(2_usize.pow(12), false); + c.bench_function("sort f32 2^12", |b| b.iter(|| bench_sort(&arr))); + c.bench_function("sort f32 to indices 2^12", |b| { + b.iter(|| bench_sort_to_indices(&arr, None)) }); - let arr_a = create_f32_array(2u64.pow(12) as usize, true); - let arr_b = create_f32_array(2u64.pow(12) as usize, true); + let arr = create_f32_array(2usize.pow(12), true); + c.bench_function("sort f32 nulls 2^12", |b| b.iter(|| bench_sort(&arr))); + c.bench_function("sort f32 nulls to indices 2^12", |b| { + b.iter(|| bench_sort_to_indices(&arr, None)) + }); - c.bench_function("sort nulls 2^12", |b| { - b.iter(|| bench_sort(&arr_a, &arr_b, None)) + let arr = create_string_array_with_len::(2usize.pow(12), 0.0, 10); + c.bench_function("sort string[10] to indices 2^12", |b| { + b.iter(|| bench_sort_to_indices(&arr, None)) }); - let arr_a = create_bool_array(2u64.pow(12) as usize, false); - let arr_b = create_bool_array(2u64.pow(12) as usize, false); - c.bench_function("bool sort 2^12", |b| { - b.iter(|| bench_sort(&arr_a, &arr_b, None)) + let arr = create_string_array_with_len::(2usize.pow(12), 0.5, 10); + c.bench_function("sort string[10] nulls to indices 2^12", |b| { + b.iter(|| bench_sort_to_indices(&arr, None)) }); - let arr_a = create_bool_array(2u64.pow(12) as usize, true); - let arr_b = create_bool_array(2u64.pow(12) as usize, true); - c.bench_function("bool sort nulls 2^12", |b| { - b.iter(|| bench_sort(&arr_a, &arr_b, None)) + let arr = create_string_dict_array::(2usize.pow(12), 0.0, 10); + c.bench_function("sort string[10] dict to indices 2^12", |b| { + b.iter(|| bench_sort_to_indices(&arr, None)) }); - let dict_arr = Arc::new(create_string_dict_array::( - 2u64.pow(12) as usize, - 0.0, - 1, - )) as ArrayRef; - c.bench_function("dict string 2^12", |b| { - b.iter(|| bench_sort_to_indices(&dict_arr, None)) + let arr = create_string_dict_array::(2usize.pow(12), 0.5, 10); + c.bench_function("sort string[10] dict nulls to indices 2^12", |b| { + b.iter(|| bench_sort_to_indices(&arr, None)) }); - let run_encoded_array = Arc::new(create_primitive_run_array::( - 
2u64.pow(12) as usize, - 2u64.pow(10) as usize, - )) as ArrayRef; + let run_encoded_array = create_primitive_run_array::( + 2usize.pow(12), + 2usize.pow(10), + ); + + c.bench_function("sort primitive run 2^12", |b| { + b.iter(|| bench_sort(&run_encoded_array)) + }); c.bench_function("sort primitive run to indices 2^12", |b| { b.iter(|| bench_sort_to_indices(&run_encoded_array, None)) }); - c.bench_function("sort primitive run to run 2^12", |b| { - b.iter(|| bench_sort_run(&run_encoded_array, None)) - }); - - // with limit - { - let arr_a = create_f32_array(2u64.pow(12) as usize, false); - let arr_b = create_f32_array(2u64.pow(12) as usize, false); - c.bench_function("sort 2^12 limit 10", |b| { - b.iter(|| bench_sort(&arr_a, &arr_b, Some(10))) - }); - - let arr_a = create_f32_array(2u64.pow(12) as usize, false); - let arr_b = create_f32_array(2u64.pow(12) as usize, false); - c.bench_function("sort 2^12 limit 100", |b| { - b.iter(|| bench_sort(&arr_a, &arr_b, Some(100))) - }); - - let arr_a = create_f32_array(2u64.pow(12) as usize, false); - let arr_b = create_f32_array(2u64.pow(12) as usize, false); - c.bench_function("sort 2^12 limit 1000", |b| { - b.iter(|| bench_sort(&arr_a, &arr_b, Some(1000))) - }); - - let arr_a = create_f32_array(2u64.pow(12) as usize, false); - let arr_b = create_f32_array(2u64.pow(12) as usize, false); - c.bench_function("sort 2^12 limit 2^12", |b| { - b.iter(|| bench_sort(&arr_a, &arr_b, Some(2u64.pow(12) as usize))) - }); - - let arr_a = create_f32_array(2u64.pow(12) as usize, true); - let arr_b = create_f32_array(2u64.pow(12) as usize, true); - - c.bench_function("sort nulls 2^12 limit 10", |b| { - b.iter(|| bench_sort(&arr_a, &arr_b, Some(10))) - }); - c.bench_function("sort nulls 2^12 limit 100", |b| { - b.iter(|| bench_sort(&arr_a, &arr_b, Some(100))) - }); - c.bench_function("sort nulls 2^12 limit 1000", |b| { - b.iter(|| bench_sort(&arr_a, &arr_b, Some(1000))) - }); - c.bench_function("sort nulls 2^12 limit 2^12", |b| { - b.iter(|| bench_sort(&arr_a, &arr_b, Some(2u64.pow(12) as usize))) - }); - } + let arr_a = create_f32_array(2usize.pow(10), false); + let arr_b = create_f32_array(2usize.pow(10), false); + + c.bench_function("lexsort (f32, f32) 2^10", |b| { + b.iter(|| bench_lexsort(&arr_a, &arr_b, None)) + }); + + let arr_a = create_f32_array(2usize.pow(12), false); + let arr_b = create_f32_array(2usize.pow(12), false); + + c.bench_function("lexsort (f32, f32) 2^12", |b| { + b.iter(|| bench_lexsort(&arr_a, &arr_b, None)) + }); + + let arr_a = create_f32_array(2usize.pow(10), true); + let arr_b = create_f32_array(2usize.pow(10), true); + + c.bench_function("lexsort (f32, f32) nulls 2^10", |b| { + b.iter(|| bench_lexsort(&arr_a, &arr_b, None)) + }); + + let arr_a = create_f32_array(2usize.pow(12), true); + let arr_b = create_f32_array(2usize.pow(12), true); + + c.bench_function("lexsort (f32, f32) nulls 2^12", |b| { + b.iter(|| bench_lexsort(&arr_a, &arr_b, None)) + }); + + let arr_a = create_bool_array(2usize.pow(12), false); + let arr_b = create_bool_array(2usize.pow(12), false); + c.bench_function("lexsort (bool, bool) 2^12", |b| { + b.iter(|| bench_lexsort(&arr_a, &arr_b, None)) + }); + + let arr_a = create_bool_array(2usize.pow(12), true); + let arr_b = create_bool_array(2usize.pow(12), true); + c.bench_function("lexsort (bool, bool) nulls 2^12", |b| { + b.iter(|| bench_lexsort(&arr_a, &arr_b, None)) + }); + + let arr_a = create_f32_array(2usize.pow(12), false); + let arr_b = create_f32_array(2usize.pow(12), false); + c.bench_function("lexsort (f32, 
f32) 2^12 limit 10", |b| { + b.iter(|| bench_lexsort(&arr_a, &arr_b, Some(10))) + }); + + let arr_a = create_f32_array(2usize.pow(12), false); + let arr_b = create_f32_array(2usize.pow(12), false); + c.bench_function("lexsort (f32, f32) 2^12 limit 100", |b| { + b.iter(|| bench_lexsort(&arr_a, &arr_b, Some(100))) + }); + + let arr_a = create_f32_array(2usize.pow(12), false); + let arr_b = create_f32_array(2usize.pow(12), false); + c.bench_function("lexsort (f32, f32) 2^12 limit 1000", |b| { + b.iter(|| bench_lexsort(&arr_a, &arr_b, Some(1000))) + }); + + let arr_a = create_f32_array(2usize.pow(12), false); + let arr_b = create_f32_array(2usize.pow(12), false); + c.bench_function("lexsort (f32, f32) 2^12 limit 2^12", |b| { + b.iter(|| bench_lexsort(&arr_a, &arr_b, Some(2usize.pow(12)))) + }); + + let arr_a = create_f32_array(2usize.pow(12), true); + let arr_b = create_f32_array(2usize.pow(12), true); + + c.bench_function("lexsort (f32, f32) nulls 2^12 limit 10", |b| { + b.iter(|| bench_lexsort(&arr_a, &arr_b, Some(10))) + }); + c.bench_function("lexsort (f32, f32) nulls 2^12 limit 100", |b| { + b.iter(|| bench_lexsort(&arr_a, &arr_b, Some(100))) + }); + c.bench_function("lexsort (f32, f32) nulls 2^12 limit 1000", |b| { + b.iter(|| bench_lexsort(&arr_a, &arr_b, Some(1000))) + }); + c.bench_function("lexsort (f32, f32) nulls 2^12 limit 2^12", |b| { + b.iter(|| bench_lexsort(&arr_a, &arr_b, Some(2usize.pow(12)))) + }); + + let arr = create_f32_array(2usize.pow(12), false); + c.bench_function("rank f32 2^12", |b| { + b.iter(|| black_box(rank(&arr, None).unwrap())) + }); + + let arr = create_f32_array(2usize.pow(12), true); + c.bench_function("rank f32 nulls 2^12", |b| { + b.iter(|| black_box(rank(&arr, None).unwrap())) + }); + + let arr = create_string_array_with_len::(2usize.pow(12), 0.0, 10); + c.bench_function("rank string[10] 2^12", |b| { + b.iter(|| black_box(rank(&arr, None).unwrap())) + }); + + let arr = create_string_array_with_len::(2usize.pow(12), 0.5, 10); + c.bench_function("rank string[10] nulls 2^12", |b| { + b.iter(|| black_box(rank(&arr, None).unwrap())) + }); } criterion_group!(benches, add_benchmark); diff --git a/arrow/examples/dynamic_types.rs b/arrow/examples/dynamic_types.rs index 8ec473c76d56..21edb235aaa7 100644 --- a/arrow/examples/dynamic_types.rs +++ b/arrow/examples/dynamic_types.rs @@ -23,7 +23,6 @@ extern crate arrow; use arrow::array::*; use arrow::datatypes::*; use arrow::error::Result; -use arrow::record_batch::*; #[cfg(feature = "prettyprint")] use arrow::util::pretty::print_batches; diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs index ff3a170c698a..fa01f4c4c15b 100644 --- a/arrow/src/array/mod.rs +++ b/arrow/src/array/mod.rs @@ -23,10 +23,10 @@ mod ffi; // --------------------- Array & ArrayData --------------------- -pub use arrow_array::array::*; pub use arrow_array::builder::*; pub use arrow_array::cast::*; pub use arrow_array::iterator::*; +pub use arrow_array::*; pub use arrow_data::{ layout, ArrayData, ArrayDataBuilder, ArrayDataRef, BufferSpec, DataTypeLayout, }; diff --git a/arrow/src/compute/kernels/mod.rs b/arrow/src/compute/kernels.rs similarity index 89% rename from arrow/src/compute/kernels/mod.rs rename to arrow/src/compute/kernels.rs index d9c948c607bd..35ad80e009cc 100644 --- a/arrow/src/compute/kernels/mod.rs +++ b/arrow/src/compute/kernels.rs @@ -17,12 +17,12 @@ //! 
Computation kernels on Arrow Arrays -pub mod limit; - -pub use arrow_arith::{aggregate, arithmetic, arity, bitwise, boolean, temporal}; +pub use arrow_arith::{ + aggregate, arithmetic, arity, bitwise, boolean, numeric, temporal, +}; pub use arrow_cast::cast; pub use arrow_cast::parse as cast_utils; -pub use arrow_ord::{partition, sort}; +pub use arrow_ord::{cmp, partition, rank, sort}; pub use arrow_select::{concat, filter, interleave, nullif, take, window, zip}; pub use arrow_string::{concat_elements, length, regexp, substring}; diff --git a/arrow/src/compute/kernels/limit.rs b/arrow/src/compute/kernels/limit.rs deleted file mode 100644 index 097b8e949443..000000000000 --- a/arrow/src/compute/kernels/limit.rs +++ /dev/null @@ -1,208 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines miscellaneous array kernels. - -use crate::array::ArrayRef; - -/// Returns the array, taking only the number of elements specified -/// -/// Limit performs a zero-copy slice of the array, and is a convenience method on slice -/// where: -/// * it performs a bounds-check on the array -/// * it slices from offset 0 -#[deprecated(note = "Use Array::slice")] -pub fn limit(array: &ArrayRef, num_elements: usize) -> ArrayRef { - let lim = num_elements.min(array.len()); - array.slice(0, lim) -} - -#[cfg(test)] -#[allow(deprecated)] -mod tests { - use super::*; - use crate::array::*; - use crate::buffer::Buffer; - use crate::datatypes::{DataType, Field}; - use crate::util::bit_util; - - use std::sync::Arc; - - #[test] - fn test_limit_array() { - let a: ArrayRef = Arc::new(Int32Array::from(vec![5, 6, 7, 8, 9])); - let b = limit(&a, 3); - let c = b.as_ref().as_any().downcast_ref::().unwrap(); - assert_eq!(3, c.len()); - assert_eq!(5, c.value(0)); - assert_eq!(6, c.value(1)); - assert_eq!(7, c.value(2)); - } - - #[test] - fn test_limit_string_array() { - let a: ArrayRef = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"])); - let b = limit(&a, 2); - let c = b.as_ref().as_any().downcast_ref::().unwrap(); - assert_eq!(2, c.len()); - assert_eq!("hello", c.value(0)); - assert_eq!(" ", c.value(1)); - } - - #[test] - fn test_limit_array_with_null() { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, Some(5)])); - let b = limit(&a, 1); - let c = b.as_ref().as_any().downcast_ref::().unwrap(); - assert_eq!(1, c.len()); - assert!(c.is_null(0)); - } - - #[test] - fn test_limit_array_with_limit_too_large() { - let a = Int32Array::from(vec![5, 6, 7, 8, 9]); - let a_ref: ArrayRef = Arc::new(a); - let b = limit(&a_ref, 6); - let c = b.as_ref().as_any().downcast_ref::().unwrap(); - - assert_eq!(5, c.len()); - assert_eq!(5, c.value(0)); - assert_eq!(6, c.value(1)); - assert_eq!(7, c.value(2)); - assert_eq!(8, c.value(3)); - assert_eq!(9, c.value(4)); - } 
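The `limit` kernel deleted in this hunk was already deprecated with a pointer to `Array::slice`; the replacement is a zero-copy slice from offset zero, clamped to the array length. A minimal sketch of the equivalent call:

```rust
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, Int32Array};

fn main() {
    let a: ArrayRef = Arc::new(Int32Array::from(vec![5, 6, 7, 8, 9]));

    // Equivalent of the removed `limit(&a, 3)`: bounds-checked and zero-copy.
    let n = 3usize.min(a.len());
    let b = a.slice(0, n);
    assert_eq!(b.len(), 3);
}
```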
- - #[test] - fn test_list_array_limit() { - // adapted from crate::array::test::test_list_array_slice - // Construct a value array - let value_data = ArrayData::builder(DataType::Int32) - .len(10) - .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) - .build() - .unwrap(); - - // Construct a buffer for value offsets, for the nested array: - // [[0, 1], null, [2, 3], null, [4, 5], null, [6, 7, 8], null, [9]] - let value_offsets = Buffer::from_slice_ref([0, 2, 2, 4, 4, 6, 6, 9, 9, 10]); - // 01010101 00000001 - let mut null_bits: [u8; 2] = [0; 2]; - bit_util::set_bit(&mut null_bits, 0); - bit_util::set_bit(&mut null_bits, 2); - bit_util::set_bit(&mut null_bits, 4); - bit_util::set_bit(&mut null_bits, 6); - bit_util::set_bit(&mut null_bits, 8); - - // Construct a list array from the above two - let list_data_type = - DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); - let list_data = ArrayData::builder(list_data_type) - .len(9) - .add_buffer(value_offsets) - .add_child_data(value_data) - .null_bit_buffer(Some(Buffer::from(null_bits))) - .build() - .unwrap(); - let list_array: ArrayRef = Arc::new(ListArray::from(list_data)); - - let limit_array = limit(&list_array, 6); - assert_eq!(6, limit_array.len()); - assert_eq!(0, limit_array.offset()); - assert_eq!(3, limit_array.null_count()); - - // Check offset and length for each non-null value. - let limit_array: &ListArray = - limit_array.as_any().downcast_ref::().unwrap(); - - for i in 0..limit_array.len() { - let offset = limit_array.value_offsets()[i]; - let length = limit_array.value_length(i); - if i % 2 == 0 { - assert_eq!(2, length); - assert_eq!(i as i32, offset); - } else { - assert_eq!(0, length); - } - } - } - - #[test] - fn test_struct_array_limit() { - // adapted from crate::array::test::test_struct_array_slice - let boolean_data = ArrayData::builder(DataType::Boolean) - .len(5) - .add_buffer(Buffer::from([0b00010000])) - .null_bit_buffer(Some(Buffer::from([0b00010001]))) - .build() - .unwrap(); - let int_data = ArrayData::builder(DataType::Int32) - .len(5) - .add_buffer(Buffer::from_slice_ref([0, 28, 42, 0, 0])) - .null_bit_buffer(Some(Buffer::from([0b00000110]))) - .build() - .unwrap(); - - let field_types = vec![ - Field::new("a", DataType::Boolean, true), - Field::new("b", DataType::Int32, true), - ]; - let struct_array_data = ArrayData::builder(DataType::Struct(field_types.into())) - .len(5) - .add_child_data(boolean_data.clone()) - .add_child_data(int_data.clone()) - .null_bit_buffer(Some(Buffer::from([0b00010111]))) - .build() - .unwrap(); - let struct_array = StructArray::from(struct_array_data); - - assert_eq!(5, struct_array.len()); - assert_eq!(1, struct_array.null_count()); - assert_eq!(boolean_data, struct_array.column(0).to_data()); - assert_eq!(int_data, struct_array.column(1).to_data()); - - let array: ArrayRef = Arc::new(struct_array); - - let sliced_array = limit(&array, 3); - let sliced_array = sliced_array.as_any().downcast_ref::().unwrap(); - assert_eq!(3, sliced_array.len()); - assert_eq!(0, sliced_array.offset()); - assert_eq!(0, sliced_array.null_count()); - assert!(sliced_array.is_valid(0)); - assert!(sliced_array.is_valid(1)); - assert!(sliced_array.is_valid(2)); - - let sliced_c0 = sliced_array.column(0); - let sliced_c0 = sliced_c0.as_any().downcast_ref::().unwrap(); - assert_eq!(3, sliced_c0.len()); - assert_eq!(0, sliced_c0.offset()); - assert_eq!(2, sliced_c0.null_count()); - assert!(sliced_c0.is_valid(0)); - assert!(sliced_c0.is_null(1)); - 
assert!(sliced_c0.is_null(2)); - assert!(!sliced_c0.value(0)); - - let sliced_c1 = sliced_array.column(1); - let sliced_c1 = sliced_c1.as_any().downcast_ref::().unwrap(); - assert_eq!(3, sliced_c1.len()); - assert_eq!(0, sliced_c1.offset()); - assert_eq!(1, sliced_c1.null_count()); - assert!(sliced_c1.is_null(0)); - assert_eq!(28, sliced_c1.value(1)); - assert_eq!(42, sliced_c1.value(2)); - } -} diff --git a/arrow/src/compute/mod.rs b/arrow/src/compute/mod.rs index c9fd525e85a4..47a9d149aadb 100644 --- a/arrow/src/compute/mod.rs +++ b/arrow/src/compute/mod.rs @@ -28,9 +28,9 @@ pub use self::kernels::comparison::*; pub use self::kernels::concat::*; pub use self::kernels::filter::*; pub use self::kernels::interleave::*; -pub use self::kernels::limit::*; pub use self::kernels::nullif::*; pub use self::kernels::partition::*; +pub use self::kernels::rank::*; pub use self::kernels::regexp::*; pub use self::kernels::sort::*; pub use self::kernels::take::*; diff --git a/arrow/src/datatypes/mod.rs b/arrow/src/datatypes/mod.rs index 840e98ab0ded..bc5b7d500b18 100644 --- a/arrow/src/datatypes/mod.rs +++ b/arrow/src/datatypes/mod.rs @@ -18,9 +18,9 @@ //! Defines the logical data types of Arrow arrays. //! //! The most important things you might be looking for are: -//! * [`Schema`](crate::datatypes::Schema) to describe a schema. -//! * [`Field`](crate::datatypes::Field) to describe one field within a schema. -//! * [`DataType`](crate::datatypes::DataType) to describe the type of a field. +//! * [`Schema`] to describe a schema. +//! * [`Field`] to describe one field within a schema. +//! * [`DataType`] to describe the type of a field. pub use arrow_array::types::*; pub use arrow_array::{ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType}; diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs index 12aa1309c552..7fbbaa7a3907 100644 --- a/arrow/src/ffi.rs +++ b/arrow/src/ffi.rs @@ -31,10 +31,11 @@ //! # use std::sync::Arc; //! # use arrow::array::{Int32Array, Array, ArrayData, make_array}; //! # use arrow::error::Result; -//! # use arrow::compute::kernels::arithmetic; +//! # use arrow_arith::numeric::add; //! # use arrow::ffi::{to_ffi, from_ffi}; //! # fn main() -> Result<()> { //! // create an array natively +//! //! let array = Int32Array::from(vec![Some(1), None, Some(3)]); //! let data = array.into_data(); //! @@ -46,10 +47,10 @@ //! let array = Int32Array::from(data); //! //! // perform some operation -//! let array = arithmetic::add(&array, &array)?; +//! let array = add(&array, &array)?; //! //! // verify -//! assert_eq!(array, Int32Array::from(vec![Some(2), None, Some(6)])); +//! assert_eq!(array.as_ref(), &Int32Array::from(vec![Some(2), None, Some(6)])); //! # //! # Ok(()) //! # } @@ -105,6 +106,8 @@ To export an array, create an `ArrowArray` using [ArrowArray::try_new]. use std::{mem::size_of, ptr::NonNull, sync::Arc}; +pub use arrow_data::ffi::FFI_ArrowArray; +pub use arrow_schema::ffi::{FFI_ArrowSchema, Flags}; use arrow_schema::UnionMode; use crate::array::{layout, ArrayData}; @@ -113,9 +116,6 @@ use crate::datatypes::DataType; use crate::error::{ArrowError, Result}; use crate::util::bit_util; -pub use arrow_data::ffi::FFI_ArrowArray; -pub use arrow_schema::ffi::{FFI_ArrowSchema, Flags}; - // returns the number of bits that buffer `i` (in the C data interface) is expected to have. 
// This is set by the Arrow specification fn bit_width(data_type: &DataType, i: usize) -> Result { @@ -412,7 +412,16 @@ impl<'a> ArrowArray<'a> { #[cfg(test)] mod tests { - use super::*; + use std::collections::HashMap; + use std::convert::TryFrom; + use std::mem::ManuallyDrop; + use std::ptr::addr_of_mut; + + use arrow_array::builder::UnionBuilder; + use arrow_array::cast::AsArray; + use arrow_array::types::{Float64Type, Int32Type}; + use arrow_array::{StructArray, UnionArray}; + use crate::array::{ make_array, Array, ArrayData, BooleanArray, Decimal128Array, DictionaryArray, DurationSecondArray, FixedSizeBinaryArray, FixedSizeListArray, @@ -421,14 +430,8 @@ mod tests { }; use crate::compute::kernels; use crate::datatypes::{Field, Int8Type}; - use arrow_array::builder::UnionBuilder; - use arrow_array::cast::AsArray; - use arrow_array::types::{Float64Type, Int32Type}; - use arrow_array::{StructArray, UnionArray}; - use std::collections::HashMap; - use std::convert::TryFrom; - use std::mem::ManuallyDrop; - use std::ptr::addr_of_mut; + + use super::*; #[test] fn test_round_trip() { @@ -440,10 +443,10 @@ mod tests { // (simulate consumer) import it let array = Int32Array::from(from_ffi(array, &schema).unwrap()); - let array = kernels::arithmetic::add(&array, &array).unwrap(); + let array = kernels::numeric::add(&array, &array).unwrap(); // verify - assert_eq!(array, Int32Array::from(vec![2, 4, 6])); + assert_eq!(array.as_ref(), &Int32Array::from(vec![2, 4, 6])); } #[test] @@ -491,10 +494,10 @@ mod tests { let array = array.as_any().downcast_ref::().unwrap(); assert_eq!(array, &Int32Array::from(vec![Some(2), None])); - let array = kernels::arithmetic::add(array, array).unwrap(); + let array = kernels::numeric::add(array, array).unwrap(); // verify - assert_eq!(array, Int32Array::from(vec![Some(4), None])); + assert_eq!(array.as_ref(), &Int32Array::from(vec![Some(4), None])); // (drop/release) Ok(()) @@ -946,10 +949,10 @@ mod tests { // perform some operation let array = array.as_any().downcast_ref::().unwrap(); - let array = kernels::arithmetic::add(array, array).unwrap(); + let array = kernels::numeric::add(array, array).unwrap(); // verify - assert_eq!(array, Int32Array::from(vec![2, 4, 6])); + assert_eq!(array.as_ref(), &Int32Array::from(vec![2, 4, 6])); Ok(()) } diff --git a/arrow/src/ffi_stream.rs b/arrow/src/ffi_stream.rs index 5fb1c107350a..865a8d0e0a29 100644 --- a/arrow/src/ffi_stream.rs +++ b/arrow/src/ffi_stream.rs @@ -54,6 +54,7 @@ //! } //! ``` +use std::ffi::CStr; use std::ptr::addr_of; use std::{ convert::TryFrom, @@ -119,8 +120,8 @@ unsafe extern "C" fn release_stream(stream: *mut FFI_ArrowArrayStream) { } struct StreamPrivateData { - batch_reader: Box, - last_error: String, + batch_reader: Box, + last_error: Option, } // The callback used to get array schema @@ -142,8 +143,12 @@ unsafe extern "C" fn get_next( // The callback used to get the error from last operation on the `FFI_ArrowArrayStream` unsafe extern "C" fn get_last_error(stream: *mut FFI_ArrowArrayStream) -> *const c_char { let mut ffi_stream = ExportedArrayStream { stream }; - let last_error = ffi_stream.get_last_error(); - CString::new(last_error.as_str()).unwrap().into_raw() + // The consumer should not take ownership of this string, we should return + // a const pointer to it. 
+ match ffi_stream.get_last_error() { + Some(err_string) => err_string.as_ptr(), + None => std::ptr::null(), + } } impl Drop for FFI_ArrowArrayStream { @@ -157,10 +162,10 @@ impl Drop for FFI_ArrowArrayStream { impl FFI_ArrowArrayStream { /// Creates a new [`FFI_ArrowArrayStream`]. - pub fn new(batch_reader: Box) -> Self { + pub fn new(batch_reader: Box) -> Self { let private_data = Box::new(StreamPrivateData { batch_reader, - last_error: String::new(), + last_error: None, }); Self { @@ -194,7 +199,7 @@ impl ExportedArrayStream { } pub fn get_schema(&mut self, out: *mut FFI_ArrowSchema) -> i32 { - let mut private_data = self.get_private_data(); + let private_data = self.get_private_data(); let reader = &private_data.batch_reader; let schema = FFI_ArrowSchema::try_from(reader.schema().as_ref()); @@ -206,14 +211,17 @@ impl ExportedArrayStream { 0 } Err(ref err) => { - private_data.last_error = err.to_string(); + private_data.last_error = Some( + CString::new(err.to_string()) + .expect("Error string has a null byte in it."), + ); get_error_code(err) } } } pub fn get_next(&mut self, out: *mut FFI_ArrowArray) -> i32 { - let mut private_data = self.get_private_data(); + let private_data = self.get_private_data(); let reader = &mut private_data.batch_reader; match reader.next() { @@ -231,15 +239,18 @@ impl ExportedArrayStream { 0 } else { let err = &next_batch.unwrap_err(); - private_data.last_error = err.to_string(); + private_data.last_error = Some( + CString::new(err.to_string()) + .expect("Error string has a null byte in it."), + ); get_error_code(err) } } } } - pub fn get_last_error(&mut self) -> &String { - &self.get_private_data().last_error + pub fn get_last_error(&mut self) -> Option<&CString> { + self.get_private_data().last_error.as_ref() } } @@ -247,7 +258,7 @@ fn get_error_code(err: &ArrowError) -> i32 { match err { ArrowError::NotYetImplemented(_) => ENOSYS, ArrowError::MemoryError(_) => ENOMEM, - ArrowError::IoError(_) => EIO, + ArrowError::IoError(_, _) => EIO, _ => EINVAL, } } @@ -270,7 +281,7 @@ fn get_stream_schema(stream_ptr: *mut FFI_ArrowArrayStream) -> Result let ret_code = unsafe { (*stream_ptr).get_schema.unwrap()(stream_ptr, &mut schema) }; if ret_code == 0 { - let schema = Schema::try_from(&schema).unwrap(); + let schema = Schema::try_from(&schema)?; Ok(Arc::new(schema)) } else { Err(ArrowError::CDataInterface(format!( @@ -312,19 +323,15 @@ impl ArrowArrayStreamReader { /// Get the last error from `ArrowArrayStreamReader` fn get_stream_last_error(&mut self) -> Option { - self.stream.get_last_error?; - - let error_str = unsafe { - let c_str = - self.stream.get_last_error.unwrap()(&mut self.stream) as *mut c_char; - CString::from_raw(c_str).into_string() - }; + let get_last_error = self.stream.get_last_error?; - if let Err(err) = error_str { - Some(err.to_string()) - } else { - Some(error_str.unwrap()) + let error_str = unsafe { get_last_error(&mut self.stream) }; + if error_str.is_null() { + return None; } + + let error_str = unsafe { CStr::from_ptr(error_str) }; + Some(error_str.to_string_lossy().to_string()) } } @@ -371,7 +378,7 @@ impl RecordBatchReader for ArrowArrayStreamReader { /// Assumes that the pointer represents valid C Stream Interfaces, both in memory /// representation and lifetime via the `release` mechanism. 
pub unsafe fn export_reader_into_raw( - reader: Box, + reader: Box, out_stream: *mut FFI_ArrowArrayStream, ) { let stream = FFI_ArrowArrayStream::new(reader); @@ -381,6 +388,8 @@ pub unsafe fn export_reader_into_raw( #[cfg(test)] mod tests { + use arrow_schema::DataType; + use super::*; use crate::array::Int32Array; @@ -388,13 +397,13 @@ mod tests { struct TestRecordBatchReader { schema: SchemaRef, - iter: Box>>, + iter: Box> + Send>, } impl TestRecordBatchReader { pub fn new( schema: SchemaRef, - iter: Box>>, + iter: Box> + Send>, ) -> Box { Box::new(TestRecordBatchReader { schema, iter }) } @@ -503,4 +512,32 @@ mod tests { _test_round_trip_import(vec![array.clone(), array.clone(), array]) } + + #[test] + fn test_error_import() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + + let iter = + Box::new(vec![Err(ArrowError::MemoryError("".to_string()))].into_iter()); + + let reader = TestRecordBatchReader::new(schema.clone(), iter); + + // Import through `FFI_ArrowArrayStream` as `ArrowArrayStreamReader` + let stream = FFI_ArrowArrayStream::new(reader); + let stream_reader = ArrowArrayStreamReader::try_new(stream).unwrap(); + + let imported_schema = stream_reader.schema(); + assert_eq!(imported_schema, schema); + + let mut produced_batches = vec![]; + for batch in stream_reader { + produced_batches.push(batch); + } + + // The results should outlive the lifetime of the stream itself. + assert_eq!(produced_batches.len(), 1); + assert!(produced_batches[0].is_err()); + + Ok(()) + } } diff --git a/arrow/src/lib.rs b/arrow/src/lib.rs index bf39bae530b9..f4d0585fa6b5 100644 --- a/arrow/src/lib.rs +++ b/arrow/src/lib.rs @@ -160,7 +160,7 @@ //! //! # Compute Kernels //! -//! The [`compute`](compute) module provides optimised implementations of many common operations, +//! The [`compute`] module provides optimised implementations of many common operations, //! for example the `parse_strings` operation above could also be implemented as follows: //! //! ``` @@ -184,11 +184,11 @@ //! //! This module also implements many common vertical operations: //! -//! * All mathematical binary operators, such as [`subtract`](compute::kernels::arithmetic::subtract) +//! * All mathematical binary operators, such as [`sub`](compute::kernels::numeric::sub) //! * All boolean binary operators such as [`equality`](compute::kernels::comparison::eq) //! * [`cast`](compute::kernels::cast::cast) //! * [`filter`](compute::kernels::filter::filter) -//! * [`take`](compute::kernels::take::take) and [`limit`](compute::kernels::limit::limit) +//! * [`take`](compute::kernels::take::take) //! * [`sort`](compute::kernels::sort::sort) //! * some string operators such as [`substring`](compute::kernels::substring::substring) and [`length`](compute::kernels::length::length) //! @@ -375,7 +375,8 @@ pub mod pyarrow; pub mod record_batch { pub use arrow_array::{ - RecordBatch, RecordBatchOptions, RecordBatchReader, RecordBatchWriter, + RecordBatch, RecordBatchIterator, RecordBatchOptions, RecordBatchReader, + RecordBatchWriter, }; } pub use arrow_array::temporal_conversions; diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs index 54a247d53e6d..ab0ea8ef8d74 100644 --- a/arrow/src/pyarrow.rs +++ b/arrow/src/pyarrow.rs @@ -15,22 +15,58 @@ // specific language governing permissions and limitations // under the License. -//! Pass Arrow objects from and to Python, using Arrow's +//! Pass Arrow objects from and to PyArrow, using Arrow's //! 
[C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html) //! and [pyo3](https://docs.rs/pyo3/latest/pyo3/). //! For underlying implementation, see the [ffi] module. +//! +//! One can use these to write Python functions that take and return PyArrow +//! objects, with automatic conversion to corresponding arrow-rs types. +//! +//! ```ignore +//! #[pyfunction] +//! fn double_array(array: PyArrowType) -> PyResult> { +//! let array = array.0; // Extract from PyArrowType wrapper +//! let array: Arc = make_array(array); // Convert ArrayData to ArrayRef +//! let array: &Int32Array = array.as_any().downcast_ref() +//! .ok_or_else(|| PyValueError::new_err("expected int32 array"))?; +//! let array: Int32Array = array.iter().map(|x| x.map(|x| x * 2)).collect(); +//! Ok(PyArrowType(array.into_data())) +//! } +//! ``` +//! +//! | pyarrow type | arrow-rs type | +//! |-----------------------------|--------------------------------------------------------------------| +//! | `pyarrow.DataType` | [DataType] | +//! | `pyarrow.Field` | [Field] | +//! | `pyarrow.Schema` | [Schema] | +//! | `pyarrow.Array` | [ArrayData] | +//! | `pyarrow.RecordBatch` | [RecordBatch] | +//! | `pyarrow.RecordBatchReader` | [ArrowArrayStreamReader] / `Box` (1) | +//! +//! (1) `pyarrow.RecordBatchReader` can be imported as [ArrowArrayStreamReader]. Either +//! [ArrowArrayStreamReader] or `Box` can be exported +//! as `pyarrow.RecordBatchReader`. (`Box` is typically +//! easier to create.) +//! +//! PyArrow has the notion of chunked arrays and tables, but arrow-rs doesn't +//! have these same concepts. A chunked table is instead represented with +//! `Vec`. A `pyarrow.Table` can be imported to Rust by calling +//! [pyarrow.Table.to_reader()](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_reader) +//! and then importing the reader as a [ArrowArrayStreamReader]. use std::convert::{From, TryFrom}; use std::ptr::{addr_of, addr_of_mut}; use std::sync::Arc; +use arrow_array::{RecordBatchIterator, RecordBatchReader}; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::ffi::Py_uintptr_t; use pyo3::import_exception; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyList, PyTuple}; +use pyo3::types::{PyList, PyTuple}; -use crate::array::{make_array, Array, ArrayData}; +use crate::array::{make_array, ArrayData}; use crate::datatypes::{DataType, Field, Schema}; use crate::error::ArrowError; use crate::ffi; @@ -234,28 +270,16 @@ impl FromPyArrow for RecordBatch { impl ToPyArrow for RecordBatch { fn to_pyarrow(&self, py: Python) -> PyResult { - let mut py_arrays = vec![]; - - let schema = self.schema(); - let columns = self.columns().iter(); - - for array in columns { - py_arrays.push(array.to_data().to_pyarrow(py)?); - } - - let py_schema = schema.to_pyarrow(py)?; - - let module = py.import("pyarrow")?; - let class = module.getattr("RecordBatch")?; - let args = (py_arrays,); - let kwargs = PyDict::new(py); - kwargs.set_item("schema", py_schema)?; - let record = class.call_method("from_arrays", args, Some(kwargs))?; - - Ok(PyObject::from(record)) + // Workaround apache/arrow#37669 by returning RecordBatchIterator + let reader = + RecordBatchIterator::new(vec![Ok(self.clone())], self.schema().clone()); + let reader: Box = Box::new(reader); + let py_reader = reader.into_pyarrow(py)?; + py_reader.call_method0(py, "read_next_batch") } } +/// Supports conversion from `pyarrow.RecordBatchReader` to [ArrowArrayStreamReader]. 
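Editor's note: as a rough illustration of the pyarrow conversions documented above (a sketch only; the function name is hypothetical and assumes the `pyarrow` feature plus `pyo3` are enabled), a Python-callable function can accept any `pyarrow.RecordBatchReader` — including one obtained from `pyarrow.Table.to_reader()` — and iterate it as a stream of `RecordBatch`es on the Rust side:

```rust
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::PyArrowType;
use arrow::record_batch::RecordBatch;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

/// Count the total number of rows in a pyarrow.RecordBatchReader
/// (hypothetical example; chunked data arrives as a stream of batches).
#[pyfunction]
fn total_rows(reader: PyArrowType<ArrowArrayStreamReader>) -> PyResult<usize> {
    let mut rows = 0;
    for batch in reader.0 {
        let batch: RecordBatch =
            batch.map_err(|e| PyValueError::new_err(e.to_string()))?;
        rows += batch.num_rows();
    }
    Ok(rows)
}
```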
impl FromPyArrow for ArrowArrayStreamReader { fn from_pyarrow(value: &PyAny) -> PyResult { validate_class("RecordBatchReader", value)?; @@ -277,10 +301,13 @@ impl FromPyArrow for ArrowArrayStreamReader { } } -impl IntoPyArrow for ArrowArrayStreamReader { +/// Convert a [`RecordBatchReader`] into a `pyarrow.RecordBatchReader`. +impl IntoPyArrow for Box { + // We can't implement `ToPyArrow` for `T: RecordBatchReader + Send` because + // there is already a blanket implementation for `T: ToPyArrow`. fn into_pyarrow(self, py: Python) -> PyResult { let mut stream = FFI_ArrowArrayStream::empty(); - unsafe { export_reader_into_raw(Box::new(self), &mut stream) }; + unsafe { export_reader_into_raw(self, &mut stream) }; let stream_ptr = (&mut stream) as *mut FFI_ArrowArrayStream; let module = py.import("pyarrow")?; @@ -292,18 +319,27 @@ impl IntoPyArrow for ArrowArrayStreamReader { } } -/// A newtype wrapper around a `T: PyArrowConvert` that implements -/// [`FromPyObject`] and [`IntoPy`] allowing usage with pyo3 macros +/// Convert a [`ArrowArrayStreamReader`] into a `pyarrow.RecordBatchReader`. +impl IntoPyArrow for ArrowArrayStreamReader { + fn into_pyarrow(self, py: Python) -> PyResult { + let boxed: Box = Box::new(self); + boxed.into_pyarrow(py) + } +} + +/// A newtype wrapper. When wrapped around a type `T: FromPyArrow`, it +/// implements `FromPyObject` for the PyArrow objects. When wrapped around a +/// `T: IntoPyArrow`, it implements `IntoPy` for the wrapped type. #[derive(Debug)] -pub struct PyArrowType(pub T); +pub struct PyArrowType(pub T); -impl<'source, T: FromPyArrow + IntoPyArrow> FromPyObject<'source> for PyArrowType { +impl<'source, T: FromPyArrow> FromPyObject<'source> for PyArrowType { fn extract(value: &'source PyAny) -> PyResult { Ok(Self(T::from_pyarrow(value)?)) } } -impl IntoPy for PyArrowType { +impl IntoPy for PyArrowType { fn into_py(self, py: Python) -> PyObject { match self.0.into_pyarrow(py) { Ok(obj) => obj, @@ -312,7 +348,7 @@ impl IntoPy for PyArrowType { } } -impl From for PyArrowType { +impl From for PyArrowType { fn from(s: T) -> Self { Self(s) } diff --git a/arrow/src/util/bench_util.rs b/arrow/src/util/bench_util.rs index 9bdc24783736..5e5f4c6ee118 100644 --- a/arrow/src/util/bench_util.rs +++ b/arrow/src/util/bench_util.rs @@ -29,6 +29,7 @@ use rand::{ distributions::{Alphanumeric, Distribution, Standard}, prelude::StdRng, }; +use std::ops::Range; /// Creates an random (but fixed-seeded) array of a given size and null density pub fn create_primitive_array(size: usize, null_density: f32) -> PrimitiveArray @@ -268,6 +269,24 @@ pub fn create_dict_from_values( null_density: f32, values: &dyn Array, ) -> DictionaryArray +where + K: ArrowDictionaryKeyType, + Standard: Distribution, + K::Native: SampleUniform, +{ + let min_key = K::Native::from_usize(0).unwrap(); + let max_key = K::Native::from_usize(values.len()).unwrap(); + create_sparse_dict_from_values(size, null_density, values, min_key..max_key) +} + +/// Creates a random (but fixed-seeded) dictionary array of a given size and null density +/// with the provided values array and key range +pub fn create_sparse_dict_from_values( + size: usize, + null_density: f32, + values: &dyn Array, + key_range: Range, +) -> DictionaryArray where K: ArrowDictionaryKeyType, Standard: Distribution, @@ -279,9 +298,9 @@ where Box::new(values.data_type().clone()), ); - let min_key = K::Native::from_usize(0).unwrap(); - let max_key = K::Native::from_usize(values.len()).unwrap(); - let keys: Buffer = (0..size).map(|_| 
rng.gen_range(min_key..max_key)).collect(); + let keys: Buffer = (0..size) + .map(|_| rng.gen_range(key_range.clone())) + .collect(); let nulls: Option = (null_density != 0.) .then(|| (0..size).map(|_| rng.gen_bool(null_density as _)).collect()); diff --git a/arrow/tests/arithmetic.rs b/arrow/tests/arithmetic.rs new file mode 100644 index 000000000000..982420902cc3 --- /dev/null +++ b/arrow/tests/arithmetic.rs @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_arith::numeric::{add, sub}; +use arrow_arith::temporal::hour; +use arrow_array::cast::AsArray; +use arrow_array::temporal_conversions::as_datetime_with_timezone; +use arrow_array::timezone::Tz; +use arrow_array::types::*; +use arrow_array::*; +use chrono::{DateTime, TimeZone}; + +#[test] +fn test_temporal_array_timestamp_hour_with_timezone_using_chrono_tz() { + let a = TimestampSecondArray::from(vec![60 * 60 * 10]) + .with_timezone("Asia/Kolkata".to_string()); + let b = hour(&a).unwrap(); + assert_eq!(15, b.value(0)); +} + +#[test] +fn test_temporal_array_timestamp_hour_with_dst_timezone_using_chrono_tz() { + // + // 1635577147 converts to 2021-10-30 17:59:07 in time zone Australia/Sydney (AEDT) + // The offset (difference to UTC) is +11:00. Note that daylight savings is in effect on 2021-10-30. + // When daylight savings is not in effect, Australia/Sydney has an offset difference of +10:00. 
+ + let a = TimestampMillisecondArray::from(vec![Some(1635577147000)]) + .with_timezone("Australia/Sydney".to_string()); + let b = hour(&a).unwrap(); + assert_eq!(17, b.value(0)); +} + +fn test_timestamp_with_timezone_impl(tz_str: &str) { + let tz: Tz = tz_str.parse().unwrap(); + + let transform_array = |x: &dyn Array| -> Vec> { + x.as_primitive::() + .values() + .into_iter() + .map(|x| as_datetime_with_timezone::(*x, tz).unwrap()) + .collect() + }; + + let values = vec![ + tz.with_ymd_and_hms(1970, 1, 28, 23, 0, 0) + .unwrap() + .naive_utc(), + tz.with_ymd_and_hms(1970, 1, 1, 0, 0, 0) + .unwrap() + .naive_utc(), + tz.with_ymd_and_hms(2010, 4, 1, 4, 0, 20) + .unwrap() + .naive_utc(), + tz.with_ymd_and_hms(1960, 1, 30, 4, 23, 20) + .unwrap() + .naive_utc(), + tz.with_ymd_and_hms(2023, 3, 25, 14, 0, 0) + .unwrap() + .naive_utc(), + ] + .into_iter() + .map(|x| T::make_value(x).unwrap()) + .collect(); + + let a = PrimitiveArray::::new(values, None).with_timezone(tz_str); + + // IntervalYearMonth + let b = IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(0, 1), + IntervalYearMonthType::make_value(5, 34), + IntervalYearMonthType::make_value(-2, 4), + IntervalYearMonthType::make_value(7, -4), + IntervalYearMonthType::make_value(0, 1), + ]); + let r1 = add(&a, &b).unwrap(); + assert_eq!( + &transform_array(r1.as_ref()), + &[ + tz.with_ymd_and_hms(1970, 2, 28, 23, 0, 0).unwrap(), + tz.with_ymd_and_hms(1977, 11, 1, 0, 0, 0).unwrap(), + tz.with_ymd_and_hms(2008, 8, 1, 4, 0, 20).unwrap(), + tz.with_ymd_and_hms(1966, 9, 30, 4, 23, 20).unwrap(), + tz.with_ymd_and_hms(2023, 4, 25, 14, 0, 0).unwrap(), + ] + ); + + let r2 = sub(&r1, &b).unwrap(); + assert_eq!(r2.as_ref(), &a); + + // IntervalDayTime + let b = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(0, 0), + IntervalDayTimeType::make_value(5, 454000), + IntervalDayTimeType::make_value(-34, 0), + IntervalDayTimeType::make_value(7, -4000), + IntervalDayTimeType::make_value(1, 0), + ]); + let r3 = add(&a, &b).unwrap(); + assert_eq!( + &transform_array(r3.as_ref()), + &[ + tz.with_ymd_and_hms(1970, 1, 28, 23, 0, 0).unwrap(), + tz.with_ymd_and_hms(1970, 1, 6, 0, 7, 34).unwrap(), + tz.with_ymd_and_hms(2010, 2, 26, 4, 0, 20).unwrap(), + tz.with_ymd_and_hms(1960, 2, 6, 4, 23, 16).unwrap(), + tz.with_ymd_and_hms(2023, 3, 26, 14, 0, 0).unwrap(), + ] + ); + + let r4 = sub(&r3, &b).unwrap(); + assert_eq!(r4.as_ref(), &a); + + // IntervalMonthDayNano + let b = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(1, 0, 0), + IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000), + IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000), + IntervalMonthDayNanoType::make_value(5, 2, 493_000_000_000), + IntervalMonthDayNanoType::make_value(1, 0, 0), + ]); + let r5 = add(&a, &b).unwrap(); + assert_eq!( + &transform_array(r5.as_ref()), + &[ + tz.with_ymd_and_hms(1970, 2, 28, 23, 0, 0).unwrap(), + tz.with_ymd_and_hms(1998, 10, 4, 23, 59, 17).unwrap(), + tz.with_ymd_and_hms(1960, 9, 29, 4, 0, 33).unwrap(), + tz.with_ymd_and_hms(1960, 7, 2, 4, 31, 33).unwrap(), + tz.with_ymd_and_hms(2023, 4, 25, 14, 0, 0).unwrap(), + ] + ); + + let r6 = sub(&r5, &b).unwrap(); + assert_eq!( + &transform_array(r6.as_ref()), + &[ + tz.with_ymd_and_hms(1970, 1, 28, 23, 0, 0).unwrap(), + tz.with_ymd_and_hms(1970, 1, 2, 0, 0, 0).unwrap(), + tz.with_ymd_and_hms(2010, 4, 2, 4, 0, 20).unwrap(), + tz.with_ymd_and_hms(1960, 1, 31, 4, 23, 20).unwrap(), + tz.with_ymd_and_hms(2023, 3, 25, 14, 0, 0).unwrap(), + ] + ); +} 
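Editor's note: the polymorphic `add` / `sub` kernels in `arrow_arith::numeric` used by these tests accept any compatible pair of `Datum`s and return a type-erased `ArrayRef`, which is why the assertions above go through `as_ref()` or a downcast. A minimal standalone sketch of the same idea, without time zones and with values chosen for illustration:

```rust
use arrow_arith::numeric::add;
use arrow_array::cast::AsArray;
use arrow_array::types::TimestampSecondType;
use arrow_array::{IntervalYearMonthArray, TimestampSecondArray};

fn main() {
    // 1970-01-01T00:00:00 and 1970-01-02T00:00:00, in seconds since the epoch
    let timestamps = TimestampSecondArray::from(vec![0, 86_400]);
    // Add one calendar month to each value
    let months = IntervalYearMonthArray::from(vec![1, 1]);

    let result = add(&timestamps, &months).unwrap();
    let result = result.as_primitive::<TimestampSecondType>();

    // January has 31 days, so both values move forward by 31 * 86_400 seconds
    assert_eq!(result.value(0), 31 * 86_400);
    assert_eq!(result.value(1), 32 * 86_400);
}
```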
+ +#[test] +fn test_timestamp_with_offset_timezone() { + let timezones = ["+00:00", "+01:00", "-01:00", "+03:30"]; + for timezone in timezones { + test_timestamp_with_timezone_impl::(timezone); + test_timestamp_with_timezone_impl::(timezone); + test_timestamp_with_timezone_impl::(timezone); + test_timestamp_with_timezone_impl::(timezone); + } +} + +#[test] +fn test_timestamp_with_timezone() { + let timezones = [ + "Europe/Paris", + "Europe/London", + "Africa/Bamako", + "America/Dominica", + "Asia/Seoul", + "Asia/Shanghai", + ]; + for timezone in timezones { + test_timestamp_with_timezone_impl::(timezone); + test_timestamp_with_timezone_impl::(timezone); + test_timestamp_with_timezone_impl::(timezone); + test_timestamp_with_timezone_impl::(timezone); + } +} diff --git a/arrow/tests/array_equal.rs b/arrow/tests/array_equal.rs index 83a280db67b8..317287c102f2 100644 --- a/arrow/tests/array_equal.rs +++ b/arrow/tests/array_equal.rs @@ -399,7 +399,7 @@ fn test_empty_offsets_list_equal() { true, )))) .len(0) - .add_buffer(Buffer::from(vec![0i32, 2, 3, 4, 6, 7, 8].to_byte_slice())) + .add_buffer(Buffer::from([0i32, 2, 3, 4, 6, 7, 8].to_byte_slice())) .add_child_data(Int32Array::from(vec![1, 2, -1, -2, 3, 4, -3, -4]).into_data()) .null_bit_buffer(Some(Buffer::from(vec![0b00001001]))) .build() @@ -437,7 +437,7 @@ fn test_list_null() { true, )))) .len(6) - .add_buffer(Buffer::from(vec![0i32, 2, 3, 4, 6, 7, 8].to_byte_slice())) + .add_buffer(Buffer::from([0i32, 2, 3, 4, 6, 7, 8].to_byte_slice())) .add_child_data(c_values.into_data()) .null_bit_buffer(Some(Buffer::from(vec![0b00001001]))) .build() @@ -460,7 +460,7 @@ fn test_list_null() { true, )))) .len(6) - .add_buffer(Buffer::from(vec![0i32, 2, 3, 4, 6, 7, 8].to_byte_slice())) + .add_buffer(Buffer::from([0i32, 2, 3, 4, 6, 7, 8].to_byte_slice())) .add_child_data(d_values.into_data()) .null_bit_buffer(Some(Buffer::from(vec![0b00001001]))) .build() @@ -1295,3 +1295,25 @@ fn test_struct_equal_slice() { test_equal(&a, &b, true); } + +#[test] +fn test_list_excess_children_equal() { + let mut a = ListBuilder::new(FixedSizeBinaryBuilder::new(5)); + a.values().append_value(b"11111").unwrap(); // Masked value + a.append_null(); + a.values().append_value(b"22222").unwrap(); + a.values().append_null(); + a.append(true); + let a = a.finish(); + + let mut b = ListBuilder::new(FixedSizeBinaryBuilder::new(5)); + b.append_null(); + b.values().append_value(b"22222").unwrap(); + b.values().append_null(); + b.append(true); + let b = b.finish(); + + assert_eq!(a.value_offsets(), &[0, 1, 3]); + assert_eq!(b.value_offsets(), &[0, 0, 2]); + assert_eq!(a, b); +} diff --git a/arrow/tests/array_transform.rs b/arrow/tests/array_transform.rs index ebbadc00aecd..15141eb208e4 100644 --- a/arrow/tests/array_transform.rs +++ b/arrow/tests/array_transform.rs @@ -19,7 +19,7 @@ use arrow::array::{ Array, ArrayRef, BooleanArray, Decimal128Array, DictionaryArray, FixedSizeBinaryArray, Int16Array, Int32Array, Int64Array, Int64Builder, ListArray, ListBuilder, MapBuilder, NullArray, StringArray, StringBuilder, - StringDictionaryBuilder, StructArray, UInt8Array, + StringDictionaryBuilder, StructArray, UInt8Array, UnionArray, }; use arrow::datatypes::Int16Type; use arrow_buffer::Buffer; @@ -488,6 +488,63 @@ fn test_struct_many() { assert_eq!(array, expected) } +#[test] +fn test_union_dense() { + // Input data + let strings: ArrayRef = Arc::new(StringArray::from(vec![ + Some("joe"), + Some("mark"), + Some("doe"), + ])); + let ints: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + 
Some(2), + Some(3), + Some(4), + Some(5), + ])); + let offsets = Buffer::from_slice_ref([0, 0, 1, 1, 2, 2, 3, 4i32]); + let type_ids = Buffer::from_slice_ref([42, 84, 42, 84, 84, 42, 84, 84i8]); + + let array = UnionArray::try_new( + &[84, 42], + type_ids, + Some(offsets), + vec![ + (Field::new("int", DataType::Int32, false), ints), + (Field::new("string", DataType::Utf8, false), strings), + ], + ) + .unwrap() + .into_data(); + let arrays = vec![&array]; + let mut mutable = MutableArrayData::new(arrays, false, 0); + + // Slice it by `MutableArrayData` + mutable.extend(0, 4, 7); + let data = mutable.freeze(); + let array = UnionArray::from(data); + + // Expected data + let strings: ArrayRef = Arc::new(StringArray::from(vec![Some("doe")])); + let ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(3), Some(4)])); + let offsets = Buffer::from_slice_ref([0, 0, 1i32]); + let type_ids = Buffer::from_slice_ref([84, 42, 84i8]); + + let expected = UnionArray::try_new( + &[84, 42], + type_ids, + Some(offsets), + vec![ + (Field::new("int", DataType::Int32, false), ints), + (Field::new("string", DataType::Utf8, false), strings), + ], + ) + .unwrap(); + + assert_eq!(array.to_data(), expected.to_data()); +} + #[test] fn test_binary_fixed_sized_offsets() { let array = FixedSizeBinaryArray::try_from_iter( diff --git a/arrow/tests/array_validation.rs b/arrow/tests/array_validation.rs index 0d3652a0473a..fa80db1860cd 100644 --- a/arrow/tests/array_validation.rs +++ b/arrow/tests/array_validation.rs @@ -56,7 +56,9 @@ fn test_bad_number_of_buffers() { } #[test] -#[should_panic(expected = "integer overflow computing min buffer size")] +#[should_panic( + expected = "Need at least 18446744073709551615 bytes in buffers[0] in array of type Int64, but got 8" +)] fn test_fixed_width_overflow() { let buffer = Buffer::from_slice_ref([0i32, 2i32]); ArrayData::try_new(DataType::Int64, usize::MAX, None, 0, vec![buffer], vec![]) diff --git a/arrow/tests/csv.rs b/arrow/tests/csv.rs index 3ee319101757..a79b6b44c2d3 100644 --- a/arrow/tests/csv.rs +++ b/arrow/tests/csv.rs @@ -53,48 +53,6 @@ fn test_export_csv_timestamps() { } drop(writer); - let left = "c1,c2 -2019-04-18T20:54:47.378000000+10:00,2019-04-18T10:54:47.378000000 -2021-10-30T17:59:07.000000000+11:00,2021-10-30T06:59:07.000000000\n"; - let right = String::from_utf8(sw).unwrap(); - assert_eq!(left, right); -} - -#[test] -fn test_export_csv_timestamps_using_rfc3339() { - let schema = Schema::new(vec![ - Field::new( - "c1", - DataType::Timestamp(TimeUnit::Millisecond, Some("Australia/Sydney".into())), - true, - ), - Field::new("c2", DataType::Timestamp(TimeUnit::Millisecond, None), true), - ]); - - let c1 = TimestampMillisecondArray::from( - // 1555584887 converts to 2019-04-18, 20:54:47 in time zone Australia/Sydney (AEST). - // The offset (difference to UTC) is +10:00. - // 1635577147 converts to 2021-10-30 17:59:07 in time zone Australia/Sydney (AEDT) - // The offset (difference to UTC) is +11:00. Note that daylight savings is in effect on 2021-10-30. 
- // - vec![Some(1555584887378), Some(1635577147000)], - ) - .with_timezone("Australia/Sydney"); - let c2 = - TimestampMillisecondArray::from(vec![Some(1555584887378), Some(1635577147000)]); - let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)]).unwrap(); - - let mut sw = Vec::new(); - let mut writer = arrow_csv::WriterBuilder::new() - .with_rfc3339() - .build(&mut sw); - let batches = vec![&batch]; - for batch in batches { - writer.write(batch).unwrap(); - } - drop(writer); - let left = "c1,c2 2019-04-18T20:54:47.378+10:00,2019-04-18T10:54:47.378 2021-10-30T17:59:07+11:00,2021-10-30T06:59:07\n"; diff --git a/dev/release/README.md b/dev/release/README.md index 30b3a4a8a569..177f33bcbb4d 100644 --- a/dev/release/README.md +++ b/dev/release/README.md @@ -258,6 +258,7 @@ Rust Arrow Crates: (cd arrow-ipc && cargo publish) (cd arrow-csv && cargo publish) (cd arrow-json && cargo publish) +(cd arrow-avro && cargo publish) (cd arrow-ord && cargo publish) (cd arrow-arith && cargo publish) (cd arrow-string && cargo publish) diff --git a/dev/release/update_change_log.sh b/dev/release/update_change_log.sh index 6b4b0a56c4bc..c1627ebb8cf2 100755 --- a/dev/release/update_change_log.sh +++ b/dev/release/update_change_log.sh @@ -29,8 +29,8 @@ set -e -SINCE_TAG="42.0.0" -FUTURE_RELEASE="43.0.0" +SINCE_TAG="47.0.0" +FUTURE_RELEASE="48.0.0" SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" SOURCE_TOP_DIR="$(cd "${SOURCE_DIR}/../../" && pwd)" diff --git a/object_store/CHANGELOG-old.md b/object_store/CHANGELOG-old.md index 3880205bc05e..a0ced7c8d21e 100644 --- a/object_store/CHANGELOG-old.md +++ b/object_store/CHANGELOG-old.md @@ -19,6 +19,53 @@ # Historical Changelog +## [object_store_0.7.0](https://github.com/apache/arrow-rs/tree/object_store_0.7.0) (2023-08-15) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.6.1...object_store_0.7.0) + +**Breaking changes:** + +- Add range and ObjectMeta to GetResult \(\#4352\) \(\#4495\) [\#4677](https://github.com/apache/arrow-rs/pull/4677) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Add AzureConfigKey::ContainerName [\#4629](https://github.com/apache/arrow-rs/issues/4629) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: multipart ranges for HTTP [\#4612](https://github.com/apache/arrow-rs/issues/4612) +- Make object\_store::multipart public [\#4569](https://github.com/apache/arrow-rs/issues/4569) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: Export `ClientConfigKey` and make the `HttpBuilder` more consistent with other builders [\#4515](https://github.com/apache/arrow-rs/issues/4515) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store/InMemory: Make `clone()` non-async [\#4496](https://github.com/apache/arrow-rs/issues/4496) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Add Range to GetResult::File [\#4352](https://github.com/apache/arrow-rs/issues/4352) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Support copy\_if\_not\_exists for Cloudflare R2 \(S3 API\) [\#4190](https://github.com/apache/arrow-rs/issues/4190) + +**Fixed bugs:** + +- object\_store documentation is broken [\#4683](https://github.com/apache/arrow-rs/issues/4683) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- 
Exports are not sufficient for configuring some object stores, for example minio running locally [\#4530](https://github.com/apache/arrow-rs/issues/4530) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: Uploading empty file to S3 results in "411 Length Required" [\#4514](https://github.com/apache/arrow-rs/issues/4514) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- GCP doesn't fetch public objects [\#4417](https://github.com/apache/arrow-rs/issues/4417) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Closed issues:** + +- \[object\_store\] when Create a AmazonS3 instance work with MinIO without set endpoint got error MissingRegion [\#4617](https://github.com/apache/arrow-rs/issues/4617) +- AWS Profile credentials no longer working in object\_store 0.6.1 [\#4556](https://github.com/apache/arrow-rs/issues/4556) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Merged pull requests:** + +- Add AzureConfigKey::ContainerName \(\#4629\) [\#4686](https://github.com/apache/arrow-rs/pull/4686) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Fix MSRV CI [\#4671](https://github.com/apache/arrow-rs/pull/4671) ([tustvold](https://github.com/tustvold)) +- Use Config System for Object Store Integration Tests [\#4628](https://github.com/apache/arrow-rs/pull/4628) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Prepare arrow 45 [\#4590](https://github.com/apache/arrow-rs/pull/4590) ([tustvold](https://github.com/tustvold)) +- Add Support for Microsoft Fabric / OneLake [\#4573](https://github.com/apache/arrow-rs/pull/4573) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([vmuddassir-msft](https://github.com/vmuddassir-msft)) +- Cleanup multipart upload trait [\#4572](https://github.com/apache/arrow-rs/pull/4572) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Make object\_store::multipart public [\#4570](https://github.com/apache/arrow-rs/pull/4570) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([yjshen](https://github.com/yjshen)) +- Handle empty S3 payloads \(\#4514\) [\#4518](https://github.com/apache/arrow-rs/pull/4518) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- object\_store: Export `ClientConfigKey` and add `HttpBuilder::with_config` [\#4516](https://github.com/apache/arrow-rs/pull/4516) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([thehabbos007](https://github.com/thehabbos007)) +- object\_store: Implement `ObjectStore` for `Arc` [\#4502](https://github.com/apache/arrow-rs/pull/4502) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Turbo87](https://github.com/Turbo87)) +- object\_store/InMemory: Add `fork()` fn and deprecate `clone()` fn [\#4499](https://github.com/apache/arrow-rs/pull/4499) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([Turbo87](https://github.com/Turbo87)) +- Bump actions/deploy-pages from 1 to 2 [\#4449](https://github.com/apache/arrow-rs/pull/4449) ([dependabot[bot]](https://github.com/apps/dependabot)) +- gcp: Exclude authorization header when bearer empty [\#4418](https://github.com/apache/arrow-rs/pull/4418) 
[[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([vrongmeal](https://github.com/vrongmeal)) +- Support copy\_if\_not\_exists for Cloudflare R2 \(\#4190\) [\#4239](https://github.com/apache/arrow-rs/pull/4239) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) + ## [object_store_0.6.0](https://github.com/apache/arrow-rs/tree/object_store_0.6.0) (2023-05-18) [Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.5.6...object_store_0.6.0) diff --git a/object_store/CHANGELOG.md b/object_store/CHANGELOG.md index fe25e23fb768..1f069ce41eac 100644 --- a/object_store/CHANGELOG.md +++ b/object_store/CHANGELOG.md @@ -19,30 +19,49 @@ # Changelog -## [object_store_0.6.1](https://github.com/apache/arrow-rs/tree/object_store_0.6.1) (2023-06-02) +## [object_store_0.7.1](https://github.com/apache/arrow-rs/tree/object_store_0.7.1) (2023-09-26) -[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.6.0...object_store_0.6.1) +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.7.0...object_store_0.7.1) **Implemented enhancements:** -- Support multipart upload in R2 [\#4304](https://github.com/apache/arrow-rs/issues/4304) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Automatically Cleanup LocalFileSystem Temporary Files [\#4778](https://github.com/apache/arrow-rs/issues/4778) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object-store: Expose an async reader API for object store [\#4762](https://github.com/apache/arrow-rs/issues/4762) +- Improve proxy support by using reqwest::Proxy as configuration [\#4713](https://github.com/apache/arrow-rs/issues/4713) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] **Fixed bugs:** -- Default ObjectStore::get\_range Doesn't Apply Range to GetResult::File [\#4350](https://github.com/apache/arrow-rs/issues/4350) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object-store: http shouldn't perform range requests unless `accept-ranges: bytes` header is present [\#4839](https://github.com/apache/arrow-rs/issues/4839) +- object-store: http-store fails when url doesn't have last-modified header on 0.7.0 [\#4831](https://github.com/apache/arrow-rs/issues/4831) +- object-store fails to compile for `wasm32-unknown-unknown` with `http` feature [\#4776](https://github.com/apache/arrow-rs/issues/4776) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object-store: could not find `header` in `client` for `http` feature [\#4775](https://github.com/apache/arrow-rs/issues/4775) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- LocalFileSystem Copy and Rename Don't Create Intermediate Directories [\#4760](https://github.com/apache/arrow-rs/issues/4760) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- LocalFileSystem Copy is not Atomic [\#4758](https://github.com/apache/arrow-rs/issues/4758) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] **Closed issues:** -- \[object\_store - AmazonS3Builder\] incorrect metadata\_endpoint set in `from_env` in an ECS environment [\#4283](https://github.com/apache/arrow-rs/issues/4283) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store Azure Government Cloud functionality? 
[\#4853](https://github.com/apache/arrow-rs/issues/4853) **Merged pull requests:** -- Fix ObjectStore::get\_range for GetResult::File \(\#4350\) [\#4351](https://github.com/apache/arrow-rs/pull/4351) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) -- Don't exclude FIFO files from LocalFileSystem [\#4345](https://github.com/apache/arrow-rs/pull/4345) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) -- Fix support for ECS IAM credentials [\#4310](https://github.com/apache/arrow-rs/pull/4310) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) -- feat: use exactly equal parts in multipart upload [\#4305](https://github.com/apache/arrow-rs/pull/4305) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([wjones127](https://github.com/wjones127)) -- Set ECS specific metadata endpoint [\#4288](https://github.com/apache/arrow-rs/pull/4288) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([jfuechsl](https://github.com/jfuechsl)) -- Prepare 40.0.0 release [\#4245](https://github.com/apache/arrow-rs/pull/4245) ([tustvold](https://github.com/tustvold)) -- feat: support bulk deletes in object\_store [\#4060](https://github.com/apache/arrow-rs/pull/4060) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([wjones127](https://github.com/wjones127)) +- Add ObjectStore BufReader \(\#4762\) [\#4857](https://github.com/apache/arrow-rs/pull/4857) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Allow overriding azure endpoint [\#4854](https://github.com/apache/arrow-rs/pull/4854) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Minor: Improve object\_store docs.rs landing page [\#4849](https://github.com/apache/arrow-rs/pull/4849) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alamb](https://github.com/alamb)) +- Error if Remote Ignores HTTP Range Header [\#4841](https://github.com/apache/arrow-rs/pull/4841) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([universalmind303](https://github.com/universalmind303)) +- Perform HEAD request for HttpStore::head [\#4837](https://github.com/apache/arrow-rs/pull/4837) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- fix: object store http header last modified [\#4834](https://github.com/apache/arrow-rs/pull/4834) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([universalmind303](https://github.com/universalmind303)) +- Prepare arrow 47.0.0 [\#4827](https://github.com/apache/arrow-rs/pull/4827) ([tustvold](https://github.com/tustvold)) +- ObjectStore Wasm32 Fixes \(\#4775\) \(\#4776\) [\#4796](https://github.com/apache/arrow-rs/pull/4796) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Best effort cleanup of staged upload files \(\#4778\) [\#4792](https://github.com/apache/arrow-rs/pull/4792) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Relaxing type bounds on coalesce\_ranges and collect\_bytes [\#4787](https://github.com/apache/arrow-rs/pull/4787) 
[[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([sumerman](https://github.com/sumerman)) +- Update object\_store chrono deprecations [\#4786](https://github.com/apache/arrow-rs/pull/4786) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Make coalesce\_ranges and collect\_bytes available for crate users [\#4784](https://github.com/apache/arrow-rs/pull/4784) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([sumerman](https://github.com/sumerman)) +- Bump actions/checkout from 3 to 4 [\#4767](https://github.com/apache/arrow-rs/pull/4767) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Make ObjectStore::copy Atomic and Automatically Create Parent Directories \(\#4758\) \(\#4760\) [\#4759](https://github.com/apache/arrow-rs/pull/4759) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Update nix requirement from 0.26.1 to 0.27.1 in /object\_store [\#4744](https://github.com/apache/arrow-rs/pull/4744) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([viirya](https://github.com/viirya)) +- Add `with_proxy_ca_certificate` and `with_proxy_excludes` [\#4714](https://github.com/apache/arrow-rs/pull/4714) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([gordonwang0](https://github.com/gordonwang0)) +- Update object\_store Dependencies and Configure Dependabot [\#4700](https://github.com/apache/arrow-rs/pull/4700) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) + + \* *This Changelog was automatically generated by [github_changelog_generator](https://github.com/github-changelog-generator/github-changelog-generator)* diff --git a/object_store/Cargo.toml b/object_store/Cargo.toml index 5e2009d07013..7928648d170f 100644 --- a/object_store/Cargo.toml +++ b/object_store/Cargo.toml @@ -17,13 +17,14 @@ [package] name = "object_store" -version = "0.6.1" +version = "0.7.1" edition = "2021" license = "MIT/Apache-2.0" readme = "README.md" description = "A generic object store interface for uniformly interacting with AWS S3, Google Cloud Storage, Azure Blob Storage and local files." 
keywords = ["object", "storage", "cloud"] repository = "https://github.com/apache/arrow-rs/tree/master/object_store" +rust-version = "1.62.1" [package.metadata.docs.rs] all-features = true @@ -31,10 +32,10 @@ all-features = true [dependencies] # In alphabetical order async-trait = "0.1.53" bytes = "1.0" -chrono = { version = "0.4.23", default-features = false, features = ["clock"] } +chrono = { version = "0.4.31", default-features = false, features = ["clock"] } futures = "0.3" humantime = "2.1" -itertools = "0.10.1" +itertools = "0.11.0" parking_lot = { version = "0.12" } percent-encoding = "2.1" snafu = "0.7" @@ -45,12 +46,12 @@ walkdir = "2" # Cloud storage support base64 = { version = "0.21", default-features = false, features = ["std"], optional = true } hyper = { version = "0.14", default-features = false, optional = true } -quick-xml = { version = "0.28.0", features = ["serialize", "overlapped-lists"], optional = true } +quick-xml = { version = "0.30.0", features = ["serialize", "overlapped-lists"], optional = true } serde = { version = "1.0", default-features = false, features = ["derive"], optional = true } serde_json = { version = "1.0", default-features = false, optional = true } rand = { version = "0.8", default-features = false, features = ["std", "std_rng"], optional = true } reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"], optional = true } -ring = { version = "0.16", default-features = false, features = ["std"], optional = true } +ring = { version = "0.17", default-features = false, features = ["std"], optional = true } rustls-pemfile = { version = "1.0", default-features = false, optional = true } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] @@ -60,7 +61,7 @@ tokio = { version = "1.25.0", features = ["sync", "macros", "rt", "time", "io-ut tokio = { version = "1.25.0", features = ["sync", "macros", "rt", "time", "io-util"] } [target.'cfg(target_family="unix")'.dev-dependencies] -nix = "0.26.1" +nix = { version = "0.27.1", features = ["fs"] } [features] cloud = ["serde", "serde_json", "quick-xml", "hyper", "reqwest", "reqwest/json", "reqwest/stream", "chrono/serde", "base64", "rand", "ring"] @@ -70,7 +71,6 @@ aws = ["cloud"] http = ["cloud"] [dev-dependencies] # In alphabetical order -dotenv = "0.15.0" tempfile = "3.1.0" futures-test = "0.3" rand = "0.8" diff --git a/object_store/README.md b/object_store/README.md index 5b47a65c124f..fd09ec7205af 100644 --- a/object_store/README.md +++ b/object_store/README.md @@ -39,7 +39,7 @@ See [docs.rs](https://docs.rs/object_store) for usage instructions ## Support for `wasm32-unknown-unknown` target -It's possible to build `object_store` for the `wasm32-unknown-unknown` target, however the cloud storage features `aws`, `azure`, and `gcp` are not supported. +It's possible to build `object_store` for the `wasm32-unknown-unknown` target, however the cloud storage features `aws`, `azure`, `gcp`, and `http` are not supported. 
``` cargo build -p object_store --target wasm32-unknown-unknown diff --git a/object_store/dev/release/update_change_log.sh b/object_store/dev/release/update_change_log.sh index 3e9f8bdba859..aeec3caf4f57 100755 --- a/object_store/dev/release/update_change_log.sh +++ b/object_store/dev/release/update_change_log.sh @@ -29,8 +29,8 @@ set -e -SINCE_TAG="object_store_0.6.0" -FUTURE_RELEASE="object_store_0.6.1" +SINCE_TAG="object_store_0.7.0" +FUTURE_RELEASE="object_store_0.7.1" SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" SOURCE_TOP_DIR="$(cd "${SOURCE_DIR}/../../" && pwd)" diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs index 0c2493651000..8199510d0489 100644 --- a/object_store/src/aws/client.rs +++ b/object_store/src/aws/client.rs @@ -17,13 +17,16 @@ use crate::aws::checksum::Checksum; use crate::aws::credential::{AwsCredential, CredentialExt}; -use crate::aws::{AwsCredentialProvider, STORE, STRICT_PATH_ENCODE_SET}; +use crate::aws::{ + AwsCredentialProvider, S3CopyIfNotExists, STORE, STRICT_ENCODE_SET, + STRICT_PATH_ENCODE_SET, +}; use crate::client::get::GetClient; use crate::client::list::ListClient; use crate::client::list_response::ListResponse; use crate::client::retry::RetryExt; use crate::client::GetOptionsExt; -use crate::multipart::UploadPart; +use crate::multipart::PartId; use crate::path::DELIMITER; use crate::{ ClientOptions, GetOptions, ListResult, MultipartId, Path, Result, RetryConfig, @@ -33,11 +36,15 @@ use base64::prelude::BASE64_STANDARD; use base64::Engine; use bytes::{Buf, Bytes}; use itertools::Itertools; -use percent_encoding::{utf8_percent_encode, PercentEncode}; +use percent_encoding::{percent_encode, utf8_percent_encode, PercentEncode}; use quick_xml::events::{self as xml_events}; -use reqwest::{header::CONTENT_TYPE, Client as ReqwestClient, Method, Response}; +use reqwest::{ + header::{CONTENT_LENGTH, CONTENT_TYPE}, + Client as ReqwestClient, Method, Response, StatusCode, +}; use serde::{Deserialize, Serialize}; use snafu::{ResultExt, Snafu}; +use std::collections::HashMap; use std::sync::Arc; /// A specialized `Error` for object store-related errors @@ -202,11 +209,13 @@ pub struct S3Config { pub retry_config: RetryConfig, pub client_options: ClientOptions, pub sign_payload: bool, + pub skip_signature: bool, pub checksum: Option, + pub copy_if_not_exists: Option, } impl S3Config { - fn path_url(&self, path: &Path) -> String { + pub(crate) fn path_url(&self, path: &Path) -> String { format!("{}/{}", self.bucket_endpoint, encode_path(path)) } } @@ -217,6 +226,8 @@ pub(crate) struct S3Client { client: ReqwestClient, } +const TAGGING_HEADER: &str = "x-amz-tagging"; + impl S3Client { pub fn new(config: S3Config) -> Result { let client = config.client_options.client()?; @@ -228,42 +239,62 @@ impl S3Client { &self.config } - async fn get_credential(&self) -> Result> { - self.config.credentials.get_credential().await + async fn get_credential(&self) -> Result>> { + Ok(match self.config.skip_signature { + false => Some(self.config.credentials.get_credential().await?), + true => None, + }) } /// Make an S3 PUT request pub async fn put_request( &self, path: &Path, - bytes: Option, + bytes: Bytes, query: &T, + tags: Option<&HashMap>, ) -> Result { let credential = self.get_credential().await?; let url = self.config.path_url(path); let mut builder = self.client.request(Method::PUT, url); let mut payload_sha256 = None; - if let Some(bytes) = bytes { - if let Some(checksum) = self.config().checksum { - let digest = 
checksum.digest(&bytes); - builder = builder - .header(checksum.header_name(), BASE64_STANDARD.encode(&digest)); - if checksum == Checksum::SHA256 { - payload_sha256 = Some(digest); - } + if let Some(checksum) = self.config().checksum { + let digest = checksum.digest(&bytes); + builder = + builder.header(checksum.header_name(), BASE64_STANDARD.encode(&digest)); + if checksum == Checksum::SHA256 { + payload_sha256 = Some(digest); } - builder = builder.body(bytes); } + builder = match bytes.is_empty() { + true => builder.header(CONTENT_LENGTH, 0), // Handle empty uploads (#4514) + false => builder.body(bytes), + }; + if let Some(value) = self.config().client_options.get_content_type(path) { builder = builder.header(CONTENT_TYPE, value); } + if let Some(tags) = tags { + let tags = tags + .iter() + .map(|(key, value)| { + let key = + percent_encode(key.as_bytes(), &STRICT_ENCODE_SET).to_string(); + let value = + percent_encode(value.as_bytes(), &STRICT_ENCODE_SET).to_string(); + format!("{key}={value}") + }) + .join("&"); + builder = builder.header(TAGGING_HEADER, tags); + } + let response = builder .query(query) .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, @@ -291,7 +322,7 @@ impl S3Client { .request(Method::DELETE, url) .query(query) .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, @@ -382,7 +413,7 @@ impl S3Client { .header(CONTENT_TYPE, "application/xml") .body(body) .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, @@ -419,16 +450,39 @@ impl S3Client { } /// Make an S3 Copy request - pub async fn copy_request(&self, from: &Path, to: &Path) -> Result<()> { + pub async fn copy_request( + &self, + from: &Path, + to: &Path, + overwrite: bool, + ) -> Result<()> { let credential = self.get_credential().await?; let url = self.config.path_url(to); let source = format!("{}/{}", self.config.bucket, encode_path(from)); - self.client + let mut builder = self + .client .request(Method::PUT, url) - .header("x-amz-copy-source", source) + .header("x-amz-copy-source", source); + + if !overwrite { + match &self.config.copy_if_not_exists { + Some(S3CopyIfNotExists::Header(k, v)) => { + builder = builder.header(k, v); + } + None => { + return Err(crate::Error::NotSupported { + source: "S3 does not support copy-if-not-exists" + .to_string() + .into(), + }) + } + } + } + + builder .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, @@ -436,8 +490,16 @@ impl S3Client { ) .send_retry(&self.config.retry_config) .await - .context(CopyRequestSnafu { - path: from.as_ref(), + .map_err(|source| match source.status() { + Some(StatusCode::PRECONDITION_FAILED) => crate::Error::AlreadyExists { + source: Box::new(source), + path: to.to_string(), + }, + _ => Error::CopyRequest { + source, + path: from.to_string(), + } + .into(), })?; Ok(()) @@ -451,7 +513,7 @@ impl S3Client { .client .request(Method::POST, url) .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, @@ -474,7 +536,7 @@ impl S3Client { &self, location: &Path, upload_id: &str, - parts: Vec, + parts: Vec, ) -> Result<()> { let parts = parts .into_iter() @@ -496,7 +558,7 @@ impl S3Client { .query(&[("uploadId", upload_id)]) .body(body) .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", 
self.config.sign_payload, @@ -515,15 +577,10 @@ impl GetClient for S3Client { const STORE: &'static str = STORE; /// Make an S3 GET request - async fn get_request( - &self, - path: &Path, - options: GetOptions, - head: bool, - ) -> Result { + async fn get_request(&self, path: &Path, options: GetOptions) -> Result { let credential = self.get_credential().await?; let url = self.config.path_url(path); - let method = match head { + let method = match options.head { true => Method::HEAD, false => Method::GET, }; @@ -533,7 +590,7 @@ impl GetClient for S3Client { let response = builder .with_get_options(options) .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, @@ -587,7 +644,7 @@ impl ListClient for S3Client { .request(Method::GET, &url) .query(&query) .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, diff --git a/object_store/src/aws/copy.rs b/object_store/src/aws/copy.rs new file mode 100644 index 000000000000..da4e2809be1a --- /dev/null +++ b/object_store/src/aws/copy.rs @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::config::Parse; + +/// Configure how to provide [`ObjectStore::copy_if_not_exists`] for +/// [`AmazonS3`]. +/// +/// [`ObjectStore::copy_if_not_exists`]: crate::ObjectStore::copy_if_not_exists +/// [`AmazonS3`]: super::AmazonS3 +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum S3CopyIfNotExists { + /// Some S3-compatible stores, such as Cloudflare R2, support copy if not exists + /// semantics through custom headers. 
+    ///
+    /// If set, [`ObjectStore::copy_if_not_exists`] will perform a normal copy operation
+    /// with the provided header pair, and expect the store to fail with `412 Precondition Failed`
+    /// if the destination file already exists
+    ///
+    /// Encoded as `header:<HEADER_NAME>:<HEADER_VALUE>` ignoring whitespace
+    ///
+    /// For example `header: cf-copy-destination-if-none-match: *`, would set
+    /// the header `cf-copy-destination-if-none-match` to `*`
+    ///
+    /// [`ObjectStore::copy_if_not_exists`]: crate::ObjectStore::copy_if_not_exists
+    Header(String, String),
+}
+
+impl std::fmt::Display for S3CopyIfNotExists {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::Header(k, v) => write!(f, "header: {}: {}", k, v),
+        }
+    }
+}
+
+impl S3CopyIfNotExists {
+    fn from_str(s: &str) -> Option<Self> {
+        let (variant, value) = s.split_once(':')?;
+        match variant.trim() {
+            "header" => {
+                let (k, v) = value.split_once(':')?;
+                Some(Self::Header(k.trim().to_string(), v.trim().to_string()))
+            }
+            _ => None,
+        }
+    }
+}
+
+impl Parse for S3CopyIfNotExists {
+    fn parse(v: &str) -> crate::Result<Self> {
+        Self::from_str(v).ok_or_else(|| crate::Error::Generic {
+            store: "Config",
+            source: format!("Failed to parse \"{v}\" as S3CopyIfNotExists").into(),
+        })
+    }
+}
diff --git a/object_store/src/aws/credential.rs b/object_store/src/aws/credential.rs
index be0ffa578d13..e0c5de5fe784 100644
--- a/object_store/src/aws/credential.rs
+++ b/object_store/src/aws/credential.rs
@@ -30,7 +30,7 @@ use reqwest::{Client, Method, Request, RequestBuilder, StatusCode};
 use serde::Deserialize;
 use std::collections::BTreeMap;
 use std::sync::Arc;
-use std::time::Instant;
+use std::time::{Duration, Instant};
 use tracing::warn;
 use url::Url;
 
@@ -89,6 +89,7 @@ const DATE_HEADER: &str = "x-amz-date";
 const HASH_HEADER: &str = "x-amz-content-sha256";
 const TOKEN_HEADER: &str = "x-amz-security-token";
 const AUTH_HEADER: &str = "authorization";
+const ALGORITHM: &str = "AWS4-HMAC-SHA256";
 
 impl<'a> AwsAuthorizer<'a> {
     /// Create a new [`AwsAuthorizer`]
@@ -154,21 +155,110 @@ impl<'a> AwsAuthorizer<'a> {
         let header_digest = HeaderValue::from_str(&digest).unwrap();
         request.headers_mut().insert(HASH_HEADER, header_digest);
 
-        // Each path segment must be URI-encoded twice (except for Amazon S3 which only gets URI-encoded once).
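As an aside on the configuration format accepted by `S3CopyIfNotExists::from_str` above, the following self-contained sketch mirrors the `header:<HEADER_NAME>:<HEADER_VALUE>` parsing; `parse_copy_header` is a hypothetical helper introduced only for illustration:

```rust
/// Hypothetical helper mirroring `S3CopyIfNotExists::from_str` above:
/// accepts `header:<HEADER_NAME>:<HEADER_VALUE>`, trimming surrounding whitespace.
fn parse_copy_header(s: &str) -> Option<(String, String)> {
    let (variant, value) = s.split_once(':')?;
    match variant.trim() {
        "header" => {
            let (k, v) = value.split_once(':')?;
            Some((k.trim().to_string(), v.trim().to_string()))
        }
        _ => None,
    }
}

fn main() {
    // The Cloudflare R2 example from the doc comment above
    let parsed = parse_copy_header("header: cf-copy-destination-if-none-match: *");
    assert_eq!(
        parsed,
        Some((
            "cf-copy-destination-if-none-match".to_string(),
            "*".to_string()
        ))
    );

    // Anything else is rejected, which `Parse::parse` surfaces as a config error
    assert!(parse_copy_header("dynamodb:my-table").is_none());
}
```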
+ let (signed_headers, canonical_headers) = canonicalize_headers(request.headers()); + + let scope = self.scope(date); + + let string_to_sign = self.string_to_sign( + date, + &scope, + request.method(), + request.url(), + &canonical_headers, + &signed_headers, + &digest, + ); + + // sign the string + let signature = + self.credential + .sign(&string_to_sign, date, self.region, self.service); + + // build the actual auth header + let authorisation = format!( + "{} Credential={}/{}, SignedHeaders={}, Signature={}", + ALGORITHM, self.credential.key_id, scope, signed_headers, signature + ); + + let authorization_val = HeaderValue::from_str(&authorisation).unwrap(); + request.headers_mut().insert(AUTH_HEADER, authorization_val); + } + + pub(crate) fn sign(&self, method: Method, url: &mut Url, expires_in: Duration) { + let date = self.date.unwrap_or_else(Utc::now); + let scope = self.scope(date); + + // https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html + url.query_pairs_mut() + .append_pair("X-Amz-Algorithm", ALGORITHM) + .append_pair( + "X-Amz-Credential", + &format!("{}/{}", self.credential.key_id, scope), + ) + .append_pair("X-Amz-Date", &date.format("%Y%m%dT%H%M%SZ").to_string()) + .append_pair("X-Amz-Expires", &expires_in.as_secs().to_string()) + .append_pair("X-Amz-SignedHeaders", "host"); + + // For S3, you must include the X-Amz-Security-Token query parameter in the URL if + // using credentials sourced from the STS service. + if let Some(ref token) = self.credential.token { + url.query_pairs_mut() + .append_pair("X-Amz-Security-Token", token); + } + + // We don't have a payload; the user is going to send the payload directly themselves. + let digest = UNSIGNED_PAYLOAD; + + let host = &url[url::Position::BeforeHost..url::Position::AfterPort].to_string(); + let mut headers = HeaderMap::new(); + let host_val = HeaderValue::from_str(host).unwrap(); + headers.insert("host", host_val); + + let (signed_headers, canonical_headers) = canonicalize_headers(&headers); + + let string_to_sign = self.string_to_sign( + date, + &scope, + &method, + url, + &canonical_headers, + &signed_headers, + digest, + ); + + let signature = + self.credential + .sign(&string_to_sign, date, self.region, self.service); + + url.query_pairs_mut() + .append_pair("X-Amz-Signature", &signature); + } + + #[allow(clippy::too_many_arguments)] + fn string_to_sign( + &self, + date: DateTime, + scope: &str, + request_method: &Method, + url: &Url, + canonical_headers: &str, + signed_headers: &str, + digest: &str, + ) -> String { + // Each path segment must be URI-encoded twice (except for Amazon S3 which only gets + // URI-encoded once). 
// see https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html let canonical_uri = match self.service { - "s3" => request.url().path().to_string(), - _ => utf8_percent_encode(request.url().path(), &STRICT_PATH_ENCODE_SET) - .to_string(), + "s3" => url.path().to_string(), + _ => utf8_percent_encode(url.path(), &STRICT_PATH_ENCODE_SET).to_string(), }; - let (signed_headers, canonical_headers) = canonicalize_headers(request.headers()); - let canonical_query = canonicalize_query(request.url()); + let canonical_query = canonicalize_query(url); // https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html let canonical_request = format!( "{}\n{}\n{}\n{}\n{}\n{}", - request.method().as_str(), + request_method.as_str(), canonical_uri, canonical_query, canonical_headers, @@ -177,33 +267,23 @@ impl<'a> AwsAuthorizer<'a> { ); let hashed_canonical_request = hex_digest(canonical_request.as_bytes()); - let scope = format!( - "{}/{}/{}/aws4_request", - date.format("%Y%m%d"), - self.region, - self.service - ); - let string_to_sign = format!( - "AWS4-HMAC-SHA256\n{}\n{}\n{}", + format!( + "{}\n{}\n{}\n{}", + ALGORITHM, date.format("%Y%m%dT%H%M%SZ"), scope, hashed_canonical_request - ); - - // sign the string - let signature = - self.credential - .sign(&string_to_sign, date, self.region, self.service); - - // build the actual auth header - let authorisation = format!( - "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}", - self.credential.key_id, scope, signed_headers, signature - ); + ) + } - let authorization_val = HeaderValue::from_str(&authorisation).unwrap(); - request.headers_mut().insert(AUTH_HEADER, authorization_val); + fn scope(&self, date: DateTime) -> String { + format!( + "{}/{}/{}/aws4_request", + date.format("%Y%m%d"), + self.region, + self.service + ) } } @@ -211,7 +291,7 @@ pub trait CredentialExt { /// Sign a request fn with_aws_sigv4( self, - credential: &AwsCredential, + credential: Option<&AwsCredential>, region: &str, service: &str, sign_payload: bool, @@ -222,20 +302,25 @@ pub trait CredentialExt { impl CredentialExt for RequestBuilder { fn with_aws_sigv4( self, - credential: &AwsCredential, + credential: Option<&AwsCredential>, region: &str, service: &str, sign_payload: bool, payload_sha256: Option<&[u8]>, ) -> Self { - let (client, request) = self.build_split(); - let mut request = request.expect("request valid"); + match credential { + Some(credential) => { + let (client, request) = self.build_split(); + let mut request = request.expect("request valid"); - AwsAuthorizer::new(credential, service, region) - .with_sign_payload(sign_payload) - .authorize(&mut request, payload_sha256); + AwsAuthorizer::new(credential, service, region) + .with_sign_payload(sign_payload) + .authorize(&mut request, payload_sha256); - Self::from_parts(client, request) + Self::from_parts(client, request) + } + None => self, + } } } @@ -667,7 +752,46 @@ mod tests { }; authorizer.authorize(&mut request, None); - assert_eq!(request.headers().get(AUTH_HEADER).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=653c3d8ea261fd826207df58bc2bb69fbb5003e9eb3c0ef06e4a51f2a81d8699") + assert_eq!(request.headers().get(AUTH_HEADER).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, 
Signature=653c3d8ea261fd826207df58bc2bb69fbb5003e9eb3c0ef06e4a51f2a81d8699"); + } + + #[test] + fn signed_get_url() { + // Values from https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html + let credential = AwsCredential { + key_id: "AKIAIOSFODNN7EXAMPLE".to_string(), + secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(), + token: None, + }; + + let date = DateTime::parse_from_rfc3339("2013-05-24T00:00:00Z") + .unwrap() + .with_timezone(&Utc); + + let authorizer = AwsAuthorizer { + date: Some(date), + credential: &credential, + service: "s3", + region: "us-east-1", + sign_payload: false, + }; + + let mut url = + Url::parse("https://examplebucket.s3.amazonaws.com/test.txt").unwrap(); + authorizer.sign(Method::GET, &mut url, Duration::from_secs(86400)); + + assert_eq!( + url, + Url::parse( + "https://examplebucket.s3.amazonaws.com/test.txt?\ + X-Amz-Algorithm=AWS4-HMAC-SHA256&\ + X-Amz-Credential=AKIAIOSFODNN7EXAMPLE%2F20130524%2Fus-east-1%2Fs3%2Faws4_request&\ + X-Amz-Date=20130524T000000Z&\ + X-Amz-Expires=86400&\ + X-Amz-SignedHeaders=host&\ + X-Amz-Signature=aeeed9bbccd4d02ee5c0109b86d86835f995330da4c265957d157751f604d404" + ).unwrap() + ); } #[test] diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index 8a486f986792..4de29786431c 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -36,15 +36,14 @@ use bytes::Bytes; use futures::stream::BoxStream; use futures::{StreamExt, TryStreamExt}; use itertools::Itertools; +use reqwest::Method; use serde::{Deserialize, Serialize}; use snafu::{ensure, OptionExt, ResultExt, Snafu}; -use std::str::FromStr; -use std::sync::Arc; +use std::{str::FromStr, sync::Arc, time::Duration}; use tokio::io::AsyncWrite; use tracing::info; use url::Url; -pub use crate::aws::checksum::Checksum; use crate::aws::client::{S3Client, S3Config}; use crate::aws::credential::{ InstanceCredentialProvider, TaskCredentialProvider, WebIdentityProvider, @@ -56,16 +55,21 @@ use crate::client::{ TokenCredentialProvider, }; use crate::config::ConfigValue; -use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart}; +use crate::multipart::{PartId, PutPart, WriteMultiPart}; +use crate::signer::Signer; use crate::{ ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, - ObjectStore, Path, Result, RetryConfig, + ObjectStore, Path, PutOptions, Result, RetryConfig, }; mod checksum; mod client; +mod copy; mod credential; +pub use checksum::Checksum; +pub use copy::S3CopyIfNotExists; + // http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html // // Do not URI-encode any of the unreserved characters that RFC 3986 defines: @@ -206,12 +210,92 @@ impl AmazonS3 { pub fn credentials(&self) -> &AwsCredentialProvider { &self.client.config().credentials } + + /// Create a full URL to the resource specified by `path` with this instance's configuration. + fn path_url(&self, path: &Path) -> String { + self.client.config().path_url(path) + } +} + +#[async_trait] +impl Signer for AmazonS3 { + /// Create a URL containing the relevant [AWS SigV4] query parameters that authorize a request + /// via `method` to the resource at `path` valid for the duration specified in `expires_in`. + /// + /// [AWS SigV4]: https://docs.aws.amazon.com/IAM/latest/UserGuide/create-signed-request.html + /// + /// # Example + /// + /// This example returns a URL that will enable a user to upload a file to + /// "some-folder/some-file.txt" in the next hour. 
+ /// + /// ``` + /// # async fn example() -> Result<(), Box> { + /// # use object_store::{aws::AmazonS3Builder, path::Path, signer::Signer}; + /// # use reqwest::Method; + /// # use std::time::Duration; + /// # + /// let region = "us-east-1"; + /// let s3 = AmazonS3Builder::new() + /// .with_region(region) + /// .with_bucket_name("my-bucket") + /// .with_access_key_id("my-access-key-id") + /// .with_secret_access_key("my-secret-access-key") + /// .build()?; + /// + /// let url = s3.signed_url( + /// Method::PUT, + /// &Path::from("some-folder/some-file.txt"), + /// Duration::from_secs(60 * 60) + /// ).await?; + /// # Ok(()) + /// # } + /// ``` + async fn signed_url( + &self, + method: Method, + path: &Path, + expires_in: Duration, + ) -> Result { + let credential = self.credentials().get_credential().await?; + let authorizer = + AwsAuthorizer::new(&credential, "s3", &self.client.config().region); + + let path_url = self.path_url(path); + let mut url = + Url::parse(&path_url).context(UnableToParseUrlSnafu { url: path_url })?; + + authorizer.sign(method, &mut url, expires_in); + + Ok(url) + } } #[async_trait] impl ObjectStore for AmazonS3 { async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { - self.client.put_request(location, Some(bytes), &()).await?; + self.client + .put_request(location, bytes, &(), None) + .await?; + Ok(()) + } + + async fn put_opts( + &self, + location: &Path, + bytes: Bytes, + options: PutOptions, + ) -> Result<()> { + if options.tags.is_empty() { + self.client + .put_request(location, bytes, &(), None) + .await?; + } else { + self.client + .put_request(location, bytes, &(), Some(&options.tags)) + .await?; + } + Ok(()) } @@ -227,7 +311,7 @@ impl ObjectStore for AmazonS3 { client: Arc::clone(&self.client), }; - Ok((id, Box::new(CloudMultiPartUpload::new(upload, 8)))) + Ok((id, Box::new(WriteMultiPart::new(upload, 8)))) } async fn abort_multipart( @@ -244,10 +328,6 @@ impl ObjectStore for AmazonS3 { self.client.get_opts(location, options).await } - async fn head(&self, location: &Path) -> Result { - self.client.head(location).await - } - async fn delete(&self, location: &Path) -> Result<()> { self.client.delete_request(location, &()).await } @@ -272,19 +352,16 @@ impl ObjectStore for AmazonS3 { .boxed() } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { - self.client.list(prefix).await + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { + self.client.list(prefix) } - async fn list_with_offset( + fn list_with_offset( &self, prefix: Option<&Path>, offset: &Path, - ) -> Result>> { - self.client.list_with_offset(prefix, offset).await + ) -> BoxStream<'_, Result> { + self.client.list_with_offset(prefix, offset) } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { @@ -292,12 +369,11 @@ impl ObjectStore for AmazonS3 { } async fn copy(&self, from: &Path, to: &Path) -> Result<()> { - self.client.copy_request(from, to).await + self.client.copy_request(from, to, true).await } - async fn copy_if_not_exists(&self, _source: &Path, _dest: &Path) -> Result<()> { - // Will need dynamodb_lock - Err(crate::Error::NotImplemented) + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + self.client.copy_request(from, to, false).await } } @@ -308,12 +384,8 @@ struct S3MultiPartUpload { } #[async_trait] -impl CloudMultiPartUploadImpl for S3MultiPartUpload { - async fn put_multipart_part( - &self, - buf: Vec, - part_idx: usize, - ) -> Result { +impl PutPart for S3MultiPartUpload { + async fn 
put_part(&self, buf: Vec, part_idx: usize) -> Result { use reqwest::header::ETAG; let part = (part_idx + 1).to_string(); @@ -321,31 +393,22 @@ impl CloudMultiPartUploadImpl for S3MultiPartUpload { .client .put_request( &self.location, - Some(buf.into()), + buf.into(), &[("partNumber", &part), ("uploadId", &self.upload_id)], + None, ) .await?; - let etag = response - .headers() - .get(ETAG) - .context(MissingEtagSnafu) - .map_err(crate::Error::from)?; + let etag = response.headers().get(ETAG).context(MissingEtagSnafu)?; - let etag = etag - .to_str() - .context(BadHeaderSnafu) - .map_err(crate::Error::from)?; + let etag = etag.to_str().context(BadHeaderSnafu)?; - Ok(UploadPart { + Ok(PartId { content_id: etag.to_string(), }) } - async fn complete( - &self, - completed_parts: Vec, - ) -> Result<(), std::io::Error> { + async fn complete(&self, completed_parts: Vec) -> Result<()> { self.client .complete_multipart(&self.location, &self.upload_id, completed_parts) .await?; @@ -404,6 +467,10 @@ pub struct AmazonS3Builder { client_options: ClientOptions, /// Credentials credentials: Option, + /// Skip signing requests + skip_signature: ConfigValue, + /// Copy if not exists + copy_if_not_exists: Option>, } /// Configuration keys for [`AmazonS3Builder`] @@ -535,6 +602,14 @@ pub enum AmazonS3ConfigKey { /// ContainerCredentialsRelativeUri, + /// Configure how to provide [`ObjectStore::copy_if_not_exists`] + /// + /// See [`S3CopyIfNotExists`] + CopyIfNotExists, + + /// Skip signing request + SkipSignature, + /// Client options Client(ClientConfigKey), } @@ -557,6 +632,8 @@ impl AsRef for AmazonS3ConfigKey { Self::ContainerCredentialsRelativeUri => { "aws_container_credentials_relative_uri" } + Self::SkipSignature => "aws_skip_signature", + Self::CopyIfNotExists => "copy_if_not_exists", Self::Client(opt) => opt.as_ref(), } } @@ -590,6 +667,8 @@ impl FromStr for AmazonS3ConfigKey { "aws_container_credentials_relative_uri" => { Ok(Self::ContainerCredentialsRelativeUri) } + "aws_skip_signature" | "skip_signature" => Ok(Self::SkipSignature), + "copy_if_not_exists" => Ok(Self::CopyIfNotExists), // Backwards compatibility "aws_allow_http" => Ok(Self::Client(ClientConfigKey::AllowHttp)), _ => match s.parse() { @@ -700,6 +779,10 @@ impl AmazonS3Builder { AmazonS3ConfigKey::Client(key) => { self.client_options = self.client_options.with_config(key, value) } + AmazonS3ConfigKey::SkipSignature => self.skip_signature.parse(value), + AmazonS3ConfigKey::CopyIfNotExists => { + self.copy_if_not_exists = Some(ConfigValue::Deferred(value.into())) + } }; self } @@ -767,6 +850,10 @@ impl AmazonS3Builder { AmazonS3ConfigKey::ContainerCredentialsRelativeUri => { self.container_credentials_relative_uri.clone() } + AmazonS3ConfigKey::SkipSignature => Some(self.skip_signature.to_string()), + AmazonS3ConfigKey::CopyIfNotExists => { + self.copy_if_not_exists.as_ref().map(ToString::to_string) + } } } @@ -918,6 +1005,14 @@ impl AmazonS3Builder { self } + /// If enabled, [`AmazonS3`] will not fetch credentials and will not sign requests + /// + /// This can be useful when interacting with public S3 buckets that deny authorized requests + pub fn with_skip_signature(mut self, skip_signature: bool) -> Self { + self.skip_signature = skip_signature.into(); + self + } + /// Sets the [checksum algorithm] which has to be used for object integrity check during upload. 
/// /// [checksum algorithm]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html @@ -943,12 +1038,35 @@ impl AmazonS3Builder { self } + /// Set a trusted proxy CA certificate + pub fn with_proxy_ca_certificate( + mut self, + proxy_ca_certificate: impl Into, + ) -> Self { + self.client_options = self + .client_options + .with_proxy_ca_certificate(proxy_ca_certificate); + self + } + + /// Set a list of hosts to exclude from proxy connections + pub fn with_proxy_excludes(mut self, proxy_excludes: impl Into) -> Self { + self.client_options = self.client_options.with_proxy_excludes(proxy_excludes); + self + } + /// Sets the client options, overriding any already set pub fn with_client_options(mut self, options: ClientOptions) -> Self { self.client_options = options; self } + /// Configure how to provide [`ObjectStore::copy_if_not_exists`] + pub fn with_copy_if_not_exists(mut self, config: S3CopyIfNotExists) -> Self { + self.copy_if_not_exists = Some(config.into()); + self + } + /// Create a [`AmazonS3`] instance from the provided values, /// consuming `self`. pub fn build(mut self) -> Result { @@ -959,6 +1077,7 @@ impl AmazonS3Builder { let bucket = self.bucket_name.context(MissingBucketNameSnafu)?; let region = self.region.context(MissingRegionSnafu)?; let checksum = self.checksum_algorithm.map(|x| x.get()).transpose()?; + let copy_if_not_exists = self.copy_if_not_exists.map(|x| x.get()).transpose()?; let credentials = if let Some(credentials) = self.credentials { credentials @@ -1030,8 +1149,7 @@ impl AmazonS3Builder { Arc::new(TokenCredentialProvider::new( token, - // The instance metadata endpoint is access over HTTP - self.client_options.clone().with_allow_http(true).client()?, + self.client_options.metadata_client()?, self.retry_config.clone(), )) as _ }; @@ -1063,7 +1181,9 @@ impl AmazonS3Builder { retry_config: self.retry_config, client_options: self.client_options, sign_payload: !self.unsigned_payload.get()?, + skip_signature: self.skip_signature.get()?, checksum, + copy_if_not_exists, }; let client = Arc::new(S3Client::new(config)?); @@ -1076,158 +1196,15 @@ impl AmazonS3Builder { mod tests { use super::*; use crate::tests::{ - get_nonexistent_object, get_opts, list_uses_directories_correctly, - list_with_delimiter, put_get_delete_list_opts, rename_and_copy, stream_get, + copy_if_not_exists, get_nonexistent_object, get_opts, + list_uses_directories_correctly, list_with_delimiter, put_get_delete_list_opts, + rename_and_copy, stream_get, }; use bytes::Bytes; use std::collections::HashMap; - use std::env; const NON_EXISTENT_NAME: &str = "nonexistentname"; - // Helper macro to skip tests if TEST_INTEGRATION and the AWS - // environment variables are not set. Returns a configured - // AmazonS3Builder - macro_rules! 
maybe_skip_integration { - () => {{ - dotenv::dotenv().ok(); - - let required_vars = [ - "OBJECT_STORE_AWS_DEFAULT_REGION", - "OBJECT_STORE_BUCKET", - "OBJECT_STORE_AWS_ACCESS_KEY_ID", - "OBJECT_STORE_AWS_SECRET_ACCESS_KEY", - ]; - let unset_vars: Vec<_> = required_vars - .iter() - .filter_map(|&name| match env::var(name) { - Ok(_) => None, - Err(_) => Some(name), - }) - .collect(); - let unset_var_names = unset_vars.join(", "); - - let force = env::var("TEST_INTEGRATION"); - - if force.is_ok() && !unset_var_names.is_empty() { - panic!( - "TEST_INTEGRATION is set, \ - but variable(s) {} need to be set", - unset_var_names - ); - } else if force.is_err() { - eprintln!( - "skipping AWS integration test - set {}TEST_INTEGRATION to run", - if unset_var_names.is_empty() { - String::new() - } else { - format!("{} and ", unset_var_names) - } - ); - return; - } else { - let config = AmazonS3Builder::new() - .with_access_key_id( - env::var("OBJECT_STORE_AWS_ACCESS_KEY_ID") - .expect("already checked OBJECT_STORE_AWS_ACCESS_KEY_ID"), - ) - .with_secret_access_key( - env::var("OBJECT_STORE_AWS_SECRET_ACCESS_KEY") - .expect("already checked OBJECT_STORE_AWS_SECRET_ACCESS_KEY"), - ) - .with_region( - env::var("OBJECT_STORE_AWS_DEFAULT_REGION") - .expect("already checked OBJECT_STORE_AWS_DEFAULT_REGION"), - ) - .with_bucket_name( - env::var("OBJECT_STORE_BUCKET") - .expect("already checked OBJECT_STORE_BUCKET"), - ) - .with_allow_http(true); - - let config = if let Ok(endpoint) = env::var("OBJECT_STORE_AWS_ENDPOINT") { - config.with_endpoint(endpoint) - } else { - config - }; - - let config = if let Ok(token) = env::var("OBJECT_STORE_AWS_SESSION_TOKEN") - { - config.with_token(token) - } else { - config - }; - - let config = if let Ok(virtual_hosted_style_request) = - env::var("OBJECT_STORE_VIRTUAL_HOSTED_STYLE_REQUEST") - { - config.with_virtual_hosted_style_request( - virtual_hosted_style_request.trim().parse().unwrap(), - ) - } else { - config - }; - - config - } - }}; - } - - #[test] - fn s3_test_config_from_env() { - let aws_access_key_id = env::var("AWS_ACCESS_KEY_ID") - .unwrap_or_else(|_| "object_store:fake_access_key_id".into()); - let aws_secret_access_key = env::var("AWS_SECRET_ACCESS_KEY") - .unwrap_or_else(|_| "object_store:fake_secret_key".into()); - - let aws_default_region = env::var("AWS_DEFAULT_REGION") - .unwrap_or_else(|_| "object_store:fake_default_region".into()); - - let aws_endpoint = env::var("AWS_ENDPOINT") - .unwrap_or_else(|_| "object_store:fake_endpoint".into()); - let aws_session_token = env::var("AWS_SESSION_TOKEN") - .unwrap_or_else(|_| "object_store:fake_session_token".into()); - - let container_creds_relative_uri = - env::var("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") - .unwrap_or_else(|_| "/object_store/fake_credentials_uri".into()); - - // required - env::set_var("AWS_ACCESS_KEY_ID", &aws_access_key_id); - env::set_var("AWS_SECRET_ACCESS_KEY", &aws_secret_access_key); - env::set_var("AWS_DEFAULT_REGION", &aws_default_region); - - // optional - env::set_var("AWS_ENDPOINT", &aws_endpoint); - env::set_var("AWS_SESSION_TOKEN", &aws_session_token); - env::set_var( - "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", - &container_creds_relative_uri, - ); - env::set_var("AWS_UNSIGNED_PAYLOAD", "true"); - env::set_var("AWS_CHECKSUM_ALGORITHM", "sha256"); - - let builder = AmazonS3Builder::from_env(); - assert_eq!(builder.access_key_id.unwrap(), aws_access_key_id.as_str()); - assert_eq!( - builder.secret_access_key.unwrap(), - aws_secret_access_key.as_str() - ); - 
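The removed `maybe_skip_integration!` macro and `s3_test_config_from_env` test above exercised environment-driven configuration; the updated integration tests below instead rely on `AmazonS3Builder::from_env()` plus the shared `crate::test_util::maybe_skip_integration!`. A minimal sketch of that pattern, with placeholder values:

```rust
use object_store::aws::AmazonS3Builder;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder values; in the integration tests these come from the real environment
    std::env::set_var("AWS_ACCESS_KEY_ID", "fake_access_key_id");
    std::env::set_var("AWS_SECRET_ACCESS_KEY", "fake_secret_key");
    std::env::set_var("AWS_DEFAULT_REGION", "us-east-1");

    // `from_env` picks these up, so the hand-rolled env-var checks removed here
    // are no longer needed in each test
    let store = AmazonS3Builder::from_env()
        .with_bucket_name("my-bucket")
        .build()?;

    let _ = store;
    Ok(())
}
```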
assert_eq!(builder.region.unwrap(), aws_default_region); - - assert_eq!(builder.endpoint.unwrap(), aws_endpoint); - assert_eq!(builder.token.unwrap(), aws_session_token); - assert_eq!( - builder.container_credentials_relative_uri.unwrap(), - container_creds_relative_uri - ); - assert_eq!( - builder.checksum_algorithm.unwrap().get().unwrap(), - Checksum::SHA256 - ); - assert!(builder.unsigned_payload.get().unwrap()); - } - #[test] fn s3_test_config_from_map() { let aws_access_key_id = "object_store:fake_access_key_id".to_string(); @@ -1318,8 +1295,11 @@ mod tests { #[tokio::test] async fn s3_test() { - let config = maybe_skip_integration!(); + crate::test_util::maybe_skip_integration!(); + let config = AmazonS3Builder::from_env(); + let is_local = matches!(&config.endpoint, Some(e) if e.starts_with("http://")); + let test_not_exists = config.copy_if_not_exists.is_some(); let integration = config.build().unwrap(); // Localstack doesn't support listing with spaces https://github.com/localstack/localstack/issues/6328 @@ -1329,15 +1309,19 @@ mod tests { list_with_delimiter(&integration).await; rename_and_copy(&integration).await; stream_get(&integration).await; + if test_not_exists { + copy_if_not_exists(&integration).await; + } // run integration test with unsigned payload enabled - let config = maybe_skip_integration!().with_unsigned_payload(true); + let config = AmazonS3Builder::from_env().with_unsigned_payload(true); let is_local = matches!(&config.endpoint, Some(e) if e.starts_with("http://")); let integration = config.build().unwrap(); put_get_delete_list_opts(&integration, is_local).await; // run integration test with checksum set to sha256 - let config = maybe_skip_integration!().with_checksum_algorithm(Checksum::SHA256); + let config = + AmazonS3Builder::from_env().with_checksum_algorithm(Checksum::SHA256); let is_local = matches!(&config.endpoint, Some(e) if e.starts_with("http://")); let integration = config.build().unwrap(); put_get_delete_list_opts(&integration, is_local).await; @@ -1345,8 +1329,8 @@ mod tests { #[tokio::test] async fn s3_test_get_nonexistent_location() { - let config = maybe_skip_integration!(); - let integration = config.build().unwrap(); + crate::test_util::maybe_skip_integration!(); + let integration = AmazonS3Builder::from_env().build().unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); @@ -1358,7 +1342,8 @@ mod tests { #[tokio::test] async fn s3_test_get_nonexistent_bucket() { - let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME); + crate::test_util::maybe_skip_integration!(); + let config = AmazonS3Builder::from_env().with_bucket_name(NON_EXISTENT_NAME); let integration = config.build().unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); @@ -1369,8 +1354,8 @@ mod tests { #[tokio::test] async fn s3_test_put_nonexistent_bucket() { - let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME); - + crate::test_util::maybe_skip_integration!(); + let config = AmazonS3Builder::from_env().with_bucket_name(NON_EXISTENT_NAME); let integration = config.build().unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); @@ -1382,8 +1367,8 @@ mod tests { #[tokio::test] async fn s3_test_delete_nonexistent_location() { - let config = maybe_skip_integration!(); - let integration = config.build().unwrap(); + crate::test_util::maybe_skip_integration!(); + let integration = AmazonS3Builder::from_env().build().unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); @@ -1392,7 +1377,8 @@ mod tests { 
#[tokio::test] async fn s3_test_delete_nonexistent_bucket() { - let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME); + crate::test_util::maybe_skip_integration!(); + let config = AmazonS3Builder::from_env().with_bucket_name(NON_EXISTENT_NAME); let integration = config.build().unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); @@ -1555,4 +1541,30 @@ mod s3_resolve_bucket_region_tests { assert!(result.is_err()); } + + #[tokio::test] + #[ignore = "Tests shouldn't call use remote services by default"] + async fn test_disable_creds() { + // https://registry.opendata.aws/daylight-osm/ + let v1 = AmazonS3Builder::new() + .with_bucket_name("daylight-map-distribution") + .with_region("us-west-1") + .with_access_key_id("local") + .with_secret_access_key("development") + .build() + .unwrap(); + + let prefix = Path::from("release"); + + v1.list_with_delimiter(Some(&prefix)).await.unwrap_err(); + + let v2 = AmazonS3Builder::new() + .with_bucket_name("daylight-map-distribution") + .with_region("us-west-1") + .with_skip_signature(true) + .build() + .unwrap(); + + v2.list_with_delimiter(Some(&prefix)).await.unwrap(); + } } diff --git a/object_store/src/azure/client.rs b/object_store/src/azure/client.rs index 5ed6f2443f32..f65388b61a80 100644 --- a/object_store/src/azure/client.rs +++ b/object_store/src/azure/client.rs @@ -264,15 +264,10 @@ impl GetClient for AzureClient { /// Make an Azure GET request /// /// - async fn get_request( - &self, - path: &Path, - options: GetOptions, - head: bool, - ) -> Result { + async fn get_request(&self, path: &Path, options: GetOptions) -> Result { let credential = self.get_credential().await?; let url = self.config.path_url(path); - let method = match head { + let method = match options.head { true => Method::HEAD, false => Method::GET, }; @@ -372,7 +367,7 @@ struct ListResultInternal { } fn to_list_result(value: ListResultInternal, prefix: Option<&str>) -> Result { - let prefix = prefix.map(Path::from).unwrap_or_else(Path::default); + let prefix = prefix.map(Path::from).unwrap_or_default(); let common_prefixes = value .blobs .blob_prefix @@ -387,7 +382,7 @@ fn to_list_result(value: ListResultInternal, prefix: Option<&str>) -> Result 0 && obj.location.as_ref().len() > prefix.as_ref().len() { diff --git a/object_store/src/azure/credential.rs b/object_store/src/azure/credential.rs index fd75389249b0..8dc61365fa6e 100644 --- a/object_store/src/azure/credential.rs +++ b/object_store/src/azure/credential.rs @@ -234,11 +234,9 @@ fn string_to_sign(h: &HeaderMap, u: &Url, method: &Method, account: &str) -> Str fn canonicalize_header(headers: &HeaderMap) -> String { let mut names = headers .iter() - .filter_map(|(k, _)| { - (k.as_str().starts_with("x-ms")) - // TODO remove unwraps - .then(|| (k.as_str(), headers.get(k).unwrap().to_str().unwrap())) - }) + .filter(|&(k, _)| (k.as_str().starts_with("x-ms"))) + // TODO remove unwraps + .map(|(k, _)| (k.as_str(), headers.get(k).unwrap().to_str().unwrap())) .collect::>(); names.sort_unstable(); diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs index d2735038321b..8f9a54ea80e9 100644 --- a/object_store/src/azure/mod.rs +++ b/object_store/src/azure/mod.rs @@ -28,7 +28,7 @@ //! after 7 days. 
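The `test_disable_creds` test above demonstrates the new `skip_signature` option against a publicly readable bucket. A standalone version of the same flow, for reference; it performs a real network request and assumes a tokio runtime with the `macros` feature:

```rust
use object_store::aws::AmazonS3Builder;
use object_store::path::Path;
use object_store::ObjectStore;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Mirrors `test_disable_creds`: the Daylight OSM bucket allows anonymous reads,
    // so skipping credential lookup and request signing lets the listing succeed
    let store = AmazonS3Builder::new()
        .with_bucket_name("daylight-map-distribution")
        .with_region("us-west-1")
        .with_skip_signature(true)
        .build()?;

    let prefix = Path::from("release");
    let listing = store.list_with_delimiter(Some(&prefix)).await?;
    println!("found {} common prefixes", listing.common_prefixes.len());
    Ok(())
}
```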
use self::client::{BlockId, BlockList}; use crate::{ - multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart}, + multipart::{PartId, PutPart, WriteMultiPart}, path::Path, ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result, RetryConfig, @@ -42,7 +42,6 @@ use percent_encoding::percent_decode_str; use serde::{Deserialize, Serialize}; use snafu::{OptionExt, ResultExt, Snafu}; use std::fmt::{Debug, Formatter}; -use std::io; use std::str::FromStr; use std::sync::Arc; use tokio::io::AsyncWrite; @@ -173,7 +172,7 @@ impl std::fmt::Display for MicrosoftAzure { impl ObjectStore for MicrosoftAzure { async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { self.client - .put_request(location, Some(bytes), false, &()) + .put_request(location, bytes, false, &()) .await?; Ok(()) } @@ -186,7 +185,7 @@ impl ObjectStore for MicrosoftAzure { client: Arc::clone(&self.client), location: location.to_owned(), }; - Ok((String::new(), Box::new(CloudMultiPartUpload::new(inner, 8)))) + Ok((String::new(), Box::new(WriteMultiPart::new(inner, 8)))) } async fn abort_multipart( @@ -203,19 +202,12 @@ impl ObjectStore for MicrosoftAzure { self.client.get_opts(location, options).await } - async fn head(&self, location: &Path) -> Result { - self.client.head(location).await - } - async fn delete(&self, location: &Path) -> Result<()> { self.client.delete_request(location, &()).await } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { - self.client.list(prefix).await + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { + self.client.list(prefix) } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { @@ -243,12 +235,8 @@ struct AzureMultiPartUpload { } #[async_trait] -impl CloudMultiPartUploadImpl for AzureMultiPartUpload { - async fn put_multipart_part( - &self, - buf: Vec, - part_idx: usize, - ) -> Result { +impl PutPart for AzureMultiPartUpload { + async fn put_part(&self, buf: Vec, part_idx: usize) -> Result { let content_id = format!("{part_idx:20}"); let block_id: BlockId = content_id.clone().into(); @@ -264,10 +252,10 @@ impl CloudMultiPartUploadImpl for AzureMultiPartUpload { ) .await?; - Ok(UploadPart { content_id }) + Ok(PartId { content_id }) } - async fn complete(&self, completed_parts: Vec) -> Result<(), io::Error> { + async fn complete(&self, completed_parts: Vec) -> Result<()> { let blocks = completed_parts .into_iter() .map(|part| BlockId::from(part.content_id)) @@ -330,6 +318,8 @@ pub struct MicrosoftAzureBuilder { url: Option, /// When set to true, azurite storage emulator has to be used use_emulator: ConfigValue, + /// Storage endpoint + endpoint: Option, /// Msi endpoint for acquiring managed identity token msi_endpoint: Option, /// Object id for use with managed identity authentication @@ -346,6 +336,10 @@ pub struct MicrosoftAzureBuilder { client_options: ClientOptions, /// Credentials credentials: Option, + /// When set to true, fabric url scheme will be used + /// + /// i.e. 
https://{account_name}.dfs.fabric.microsoft.com + use_fabric_endpoint: ConfigValue, } /// Configuration keys for [`MicrosoftAzureBuilder`] @@ -435,6 +429,21 @@ pub enum AzureConfigKey { /// - `use_emulator` UseEmulator, + /// Override the endpoint used to communicate with blob storage + /// + /// Supported keys: + /// - `azure_storage_endpoint` + /// - `azure_endpoint` + /// - `endpoint` + Endpoint, + + /// Use object store with url scheme account.dfs.fabric.microsoft.com + /// + /// Supported keys: + /// - `azure_use_fabric_endpoint` + /// - `use_fabric_endpoint` + UseFabricEndpoint, + /// Endpoint to request a imds managed identity token /// /// Supported keys: @@ -472,6 +481,13 @@ pub enum AzureConfigKey { /// - `use_azure_cli` UseAzureCli, + /// Container name + /// + /// Supported keys: + /// - `azure_container_name` + /// - `container_name` + ContainerName, + /// Client options Client(ClientConfigKey), } @@ -487,11 +503,14 @@ impl AsRef for AzureConfigKey { Self::SasKey => "azure_storage_sas_key", Self::Token => "azure_storage_token", Self::UseEmulator => "azure_storage_use_emulator", + Self::UseFabricEndpoint => "azure_use_fabric_endpoint", + Self::Endpoint => "azure_storage_endpoint", Self::MsiEndpoint => "azure_msi_endpoint", Self::ObjectId => "azure_object_id", Self::MsiResourceId => "azure_msi_resource_id", Self::FederatedTokenFile => "azure_federated_token_file", Self::UseAzureCli => "azure_use_azure_cli", + Self::ContainerName => "azure_container_name", Self::Client(key) => key.as_ref(), } } @@ -527,6 +546,9 @@ impl FromStr for AzureConfigKey { | "sas_token" => Ok(Self::SasKey), "azure_storage_token" | "bearer_token" | "token" => Ok(Self::Token), "azure_storage_use_emulator" | "use_emulator" => Ok(Self::UseEmulator), + "azure_storage_endpoint" | "azure_endpoint" | "endpoint" => { + Ok(Self::Endpoint) + } "azure_msi_endpoint" | "azure_identity_endpoint" | "identity_endpoint" @@ -536,7 +558,11 @@ impl FromStr for AzureConfigKey { "azure_federated_token_file" | "federated_token_file" => { Ok(Self::FederatedTokenFile) } + "azure_use_fabric_endpoint" | "use_fabric_endpoint" => { + Ok(Self::UseFabricEndpoint) + } "azure_use_azure_cli" | "use_azure_cli" => Ok(Self::UseAzureCli), + "azure_container_name" | "container_name" => Ok(Self::ContainerName), // Backwards compatibility "azure_allow_http" => Ok(Self::Client(ClientConfigKey::AllowHttp)), _ => match s.parse() { @@ -605,11 +631,16 @@ impl MicrosoftAzureBuilder { /// /// - `abfs[s]:///` (according to [fsspec](https://github.com/fsspec/adlfs)) /// - `abfs[s]://@.dfs.core.windows.net/` + /// - `abfs[s]://@.dfs.fabric.microsoft.com/` /// - `az:///` (according to [fsspec](https://github.com/fsspec/adlfs)) /// - `adl:///` (according to [fsspec](https://github.com/fsspec/adlfs)) /// - `azure:///` (custom) /// - `https://.dfs.core.windows.net` /// - `https://.blob.core.windows.net` + /// - `https://.dfs.fabric.microsoft.com` + /// - `https://.dfs.fabric.microsoft.com/` + /// - `https://.blob.fabric.microsoft.com` + /// - `https://.blob.fabric.microsoft.com/` /// /// Note: Settings derived from the URL will override any others set on this builder /// @@ -644,9 +675,12 @@ impl MicrosoftAzureBuilder { } AzureConfigKey::UseAzureCli => self.use_azure_cli.parse(value), AzureConfigKey::UseEmulator => self.use_emulator.parse(value), + AzureConfigKey::Endpoint => self.endpoint = Some(value.into()), + AzureConfigKey::UseFabricEndpoint => self.use_fabric_endpoint.parse(value), AzureConfigKey::Client(key) => { self.client_options = 
self.client_options.with_config(key, value) } + AzureConfigKey::ContainerName => self.container_name = Some(value.into()), }; self } @@ -697,12 +731,17 @@ impl MicrosoftAzureBuilder { AzureConfigKey::SasKey => self.sas_key.clone(), AzureConfigKey::Token => self.bearer_token.clone(), AzureConfigKey::UseEmulator => Some(self.use_emulator.to_string()), + AzureConfigKey::UseFabricEndpoint => { + Some(self.use_fabric_endpoint.to_string()) + } + AzureConfigKey::Endpoint => self.endpoint.clone(), AzureConfigKey::MsiEndpoint => self.msi_endpoint.clone(), AzureConfigKey::ObjectId => self.object_id.clone(), AzureConfigKey::MsiResourceId => self.msi_resource_id.clone(), AzureConfigKey::FederatedTokenFile => self.federated_token_file.clone(), AzureConfigKey::UseAzureCli => Some(self.use_azure_cli.to_string()), AzureConfigKey::Client(key) => self.client_options.get_config_value(key), + AzureConfigKey::ContainerName => self.container_name.clone(), } } @@ -729,6 +768,10 @@ impl MicrosoftAzureBuilder { } else if let Some(a) = host.strip_suffix(".dfs.core.windows.net") { self.container_name = Some(validate(parsed.username())?); self.account_name = Some(validate(a)?); + } else if let Some(a) = host.strip_suffix(".dfs.fabric.microsoft.com") { + self.container_name = Some(validate(parsed.username())?); + self.account_name = Some(validate(a)?); + self.use_fabric_endpoint = true.into(); } else { return Err(UrlNotRecognisedSnafu { url }.build().into()); } @@ -738,6 +781,21 @@ impl MicrosoftAzureBuilder { | Some((a, "blob.core.windows.net")) => { self.account_name = Some(validate(a)?); } + Some((a, "dfs.fabric.microsoft.com")) + | Some((a, "blob.fabric.microsoft.com")) => { + self.account_name = Some(validate(a)?); + // Attempt to infer the container name from the URL + // - https://onelake.dfs.fabric.microsoft.com///Files/test.csv + // - https://onelake.dfs.fabric.microsoft.com//.// + // + // See + if let Some(workspace) = parsed.path_segments().unwrap().next() { + if !workspace.is_empty() { + self.container_name = Some(workspace.to_string()) + } + } + self.use_fabric_endpoint = true.into(); + } _ => return Err(UrlNotRecognisedSnafu { url }.build().into()), }, scheme => return Err(UnknownUrlSchemeSnafu { scheme }.build().into()), @@ -824,6 +882,24 @@ impl MicrosoftAzureBuilder { self } + /// Override the endpoint used to communicate with blob storage + /// + /// Defaults to `https://{account}.blob.core.windows.net` + pub fn with_endpoint(mut self, endpoint: String) -> Self { + self.endpoint = Some(endpoint); + self + } + + /// Set if Microsoft Fabric url scheme should be used (defaults to false) + /// When disabled the url scheme used is `https://{account}.blob.core.windows.net` + /// When enabled the url scheme used is `https://{account}.dfs.fabric.microsoft.com` + /// + /// Note: [`Self::with_endpoint`] will take precedence over this option + pub fn with_use_fabric_endpoint(mut self, use_fabric_endpoint: bool) -> Self { + self.use_fabric_endpoint = use_fabric_endpoint.into(); + self + } + /// Sets what protocol is allowed. 
If `allow_http` is : /// * false (default): Only HTTPS are allowed /// * true: HTTP and HTTPS are allowed @@ -852,6 +928,23 @@ impl MicrosoftAzureBuilder { self } + /// Set a trusted proxy CA certificate + pub fn with_proxy_ca_certificate( + mut self, + proxy_ca_certificate: impl Into, + ) -> Self { + self.client_options = self + .client_options + .with_proxy_ca_certificate(proxy_ca_certificate); + self + } + + /// Set a list of hosts to exclude from proxy connections + pub fn with_proxy_excludes(mut self, proxy_excludes: impl Into) -> Self { + self.client_options = self.client_options.with_proxy_excludes(proxy_excludes); + self + } + /// Sets the client options, overriding any already set pub fn with_client_options(mut self, options: ClientOptions) -> Self { self.client_options = options; @@ -890,6 +983,7 @@ impl MicrosoftAzureBuilder { } let container = self.container_name.ok_or(Error::MissingContainerName {})?; + let static_creds = |credential: AzureCredential| -> AzureCredentialProvider { Arc::new(StaticCredentialProvider::new(credential)) }; @@ -911,7 +1005,16 @@ impl MicrosoftAzureBuilder { (true, url, credential, account_name) } else { let account_name = self.account_name.ok_or(Error::MissingAccount {})?; - let account_url = format!("https://{}.blob.core.windows.net", &account_name); + let account_url = match self.endpoint { + Some(account_url) => account_url, + None => match self.use_fabric_endpoint.get()? { + true => { + format!("https://{}.blob.fabric.microsoft.com", &account_name) + } + false => format!("https://{}.blob.core.windows.net", &account_name), + }, + }; + let url = Url::parse(&account_url) .context(UnableToParseUrlSnafu { url: account_url })?; @@ -964,7 +1067,7 @@ impl MicrosoftAzureBuilder { ); Arc::new(TokenCredentialProvider::new( msi_credential, - self.client_options.clone().with_allow_http(true).client()?, + self.client_options.metadata_client()?, self.retry_config.clone(), )) as _ }; @@ -1026,107 +1129,16 @@ mod tests { use super::*; use crate::tests::{ copy_if_not_exists, get_opts, list_uses_directories_correctly, - list_with_delimiter, put_get_delete_list, put_get_delete_list_opts, - rename_and_copy, stream_get, + list_with_delimiter, put_get_delete_list_opts, rename_and_copy, stream_get, }; use std::collections::HashMap; - use std::env; - - // Helper macro to skip tests if TEST_INTEGRATION and the Azure environment - // variables are not set. - macro_rules! 
maybe_skip_integration { - () => {{ - dotenv::dotenv().ok(); - - let use_emulator = std::env::var("AZURE_USE_EMULATOR").is_ok(); - - let mut required_vars = vec!["OBJECT_STORE_BUCKET"]; - if !use_emulator { - required_vars.push("AZURE_STORAGE_ACCOUNT"); - required_vars.push("AZURE_STORAGE_ACCESS_KEY"); - } - let unset_vars: Vec<_> = required_vars - .iter() - .filter_map(|&name| match env::var(name) { - Ok(_) => None, - Err(_) => Some(name), - }) - .collect(); - let unset_var_names = unset_vars.join(", "); - - let force = std::env::var("TEST_INTEGRATION"); - - if force.is_ok() && !unset_var_names.is_empty() { - panic!( - "TEST_INTEGRATION is set, \ - but variable(s) {} need to be set", - unset_var_names - ) - } else if force.is_err() { - eprintln!( - "skipping Azure integration test - set {}TEST_INTEGRATION to run", - if unset_var_names.is_empty() { - String::new() - } else { - format!("{} and ", unset_var_names) - } - ); - return; - } else { - let builder = MicrosoftAzureBuilder::new() - .with_container_name( - env::var("OBJECT_STORE_BUCKET") - .expect("already checked OBJECT_STORE_BUCKET"), - ) - .with_use_emulator(use_emulator); - if !use_emulator { - builder - .with_account( - env::var("AZURE_STORAGE_ACCOUNT").unwrap_or_default(), - ) - .with_access_key( - env::var("AZURE_STORAGE_ACCESS_KEY").unwrap_or_default(), - ) - } else { - builder - } - } - }}; - } #[tokio::test] async fn azure_blob_test() { - let integration = maybe_skip_integration!().build().unwrap(); - put_get_delete_list_opts(&integration, false).await; - get_opts(&integration).await; - list_uses_directories_correctly(&integration).await; - list_with_delimiter(&integration).await; - rename_and_copy(&integration).await; - copy_if_not_exists(&integration).await; - stream_get(&integration).await; - } - - // test for running integration test against actual blob service with service principal - // credentials. 
To run make sure all environment variables are set and remove the ignore - #[tokio::test] - #[ignore] - async fn azure_blob_test_sp() { - dotenv::dotenv().ok(); - let builder = MicrosoftAzureBuilder::new() - .with_account( - env::var("AZURE_STORAGE_ACCOUNT") - .expect("must be set AZURE_STORAGE_ACCOUNT"), - ) - .with_container_name( - env::var("OBJECT_STORE_BUCKET").expect("must be set OBJECT_STORE_BUCKET"), - ) - .with_access_key( - env::var("AZURE_STORAGE_ACCESS_KEY") - .expect("must be set AZURE_STORAGE_CLIENT_ID"), - ); - let integration = builder.build().unwrap(); + crate::test_util::maybe_skip_integration!(); + let integration = MicrosoftAzureBuilder::from_env().build().unwrap(); - put_get_delete_list(&integration).await; + put_get_delete_list_opts(&integration, false).await; get_opts(&integration).await; list_uses_directories_correctly(&integration).await; list_with_delimiter(&integration).await; @@ -1143,6 +1155,15 @@ mod tests { .unwrap(); assert_eq!(builder.account_name, Some("account".to_string())); assert_eq!(builder.container_name, Some("file_system".to_string())); + assert!(!builder.use_fabric_endpoint.get().unwrap()); + + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("abfss://file_system@account.dfs.fabric.microsoft.com/") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert_eq!(builder.container_name, Some("file_system".to_string())); + assert!(builder.use_fabric_endpoint.get().unwrap()); let mut builder = MicrosoftAzureBuilder::new(); builder.parse_url("abfs://container/path").unwrap(); @@ -1161,12 +1182,46 @@ mod tests { .parse_url("https://account.dfs.core.windows.net/") .unwrap(); assert_eq!(builder.account_name, Some("account".to_string())); + assert!(!builder.use_fabric_endpoint.get().unwrap()); let mut builder = MicrosoftAzureBuilder::new(); builder .parse_url("https://account.blob.core.windows.net/") .unwrap(); assert_eq!(builder.account_name, Some("account".to_string())); + assert!(!builder.use_fabric_endpoint.get().unwrap()); + + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("https://account.dfs.fabric.microsoft.com/") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert_eq!(builder.container_name, None); + assert!(builder.use_fabric_endpoint.get().unwrap()); + + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("https://account.dfs.fabric.microsoft.com/container") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert_eq!(builder.container_name.as_deref(), Some("container")); + assert!(builder.use_fabric_endpoint.get().unwrap()); + + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("https://account.blob.fabric.microsoft.com/") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert_eq!(builder.container_name, None); + assert!(builder.use_fabric_endpoint.get().unwrap()); + + let mut builder = MicrosoftAzureBuilder::new(); + builder + .parse_url("https://account.blob.fabric.microsoft.com/container") + .unwrap(); + assert_eq!(builder.account_name, Some("account".to_string())); + assert_eq!(builder.container_name.as_deref(), Some("container")); + assert!(builder.use_fabric_endpoint.get().unwrap()); let err_cases = [ "mailto://account.blob.core.windows.net/", diff --git a/object_store/src/buffered.rs b/object_store/src/buffered.rs new file mode 100644 index 000000000000..bdc3f4c772b9 --- /dev/null +++ b/object_store/src/buffered.rs @@ -0,0 +1,293 
@@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for performing tokio-style buffered IO + +use crate::path::Path; +use crate::{ObjectMeta, ObjectStore}; +use bytes::Bytes; +use futures::future::{BoxFuture, FutureExt}; +use futures::ready; +use std::cmp::Ordering; +use std::io::{Error, ErrorKind, SeekFrom}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, ReadBuf}; + +/// The default buffer size used by [`BufReader`] +pub const DEFAULT_BUFFER_SIZE: usize = 1024 * 1024; + +/// An async-buffered reader compatible with the tokio IO traits +/// +/// Internally this maintains a buffer of the requested size, and uses [`ObjectStore::get_range`] +/// to populate its internal buffer once depleted. This buffer is cleared on seek. +/// +/// Whilst simple, this interface will typically be outperformed by the native [`ObjectStore`] +/// methods that better map to the network APIs. This is because most object stores have +/// very [high first-byte latencies], on the order of 100-200ms, and so avoiding unnecessary +/// round-trips is critical to throughput. +/// +/// Systems looking to sequentially scan a file should instead consider using [`ObjectStore::get`], +/// or [`ObjectStore::get_opts`], or [`ObjectStore::get_range`] to read a particular range. +/// +/// Systems looking to read multiple ranges of a file should instead consider using +/// [`ObjectStore::get_ranges`], which will optimise the vectored IO. 
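A short usage sketch of the new `BufReader`, modelled on the tests later in this file; it assumes the module is exported as `object_store::buffered` and uses the in-memory store purely for illustration:

```rust
use std::sync::Arc;

use object_store::buffered::BufReader;
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::ObjectStore;
use tokio::io::AsyncReadExt;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // In-memory store and path stand in for a real object store
    let store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
    let path = Path::from("data/example.bin");
    store.put(&path, vec![0u8; 4096].into()).await?;

    // A HEAD request supplies the ObjectMeta the reader needs to bound its seeks
    let meta = store.head(&path).await?;

    // Each refill issues a single ranged GET of up to `capacity` bytes
    let mut reader = BufReader::with_capacity(Arc::clone(&store), &meta, 1024);
    let mut contents = Vec::new();
    reader.read_to_end(&mut contents).await?;
    assert_eq!(contents.len(), 4096);
    Ok(())
}
```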
+/// +/// [high first-byte latencies]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/optimizing-performance.html +pub struct BufReader { + /// The object store to fetch data from + store: Arc, + /// The size of the object + size: u64, + /// The path to the object + path: Path, + /// The current position in the object + cursor: u64, + /// The number of bytes to read in a single request + capacity: usize, + /// The buffered data if any + buffer: Buffer, +} + +impl std::fmt::Debug for BufReader { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BufReader") + .field("path", &self.path) + .field("size", &self.size) + .field("capacity", &self.capacity) + .finish() + } +} + +enum Buffer { + Empty, + Pending(BoxFuture<'static, std::io::Result>), + Ready(Bytes), +} + +impl BufReader { + /// Create a new [`BufReader`] from the provided [`ObjectMeta`] and [`ObjectStore`] + pub fn new(store: Arc, meta: &ObjectMeta) -> Self { + Self::with_capacity(store, meta, DEFAULT_BUFFER_SIZE) + } + + /// Create a new [`BufReader`] from the provided [`ObjectMeta`], [`ObjectStore`], and `capacity` + pub fn with_capacity( + store: Arc, + meta: &ObjectMeta, + capacity: usize, + ) -> Self { + Self { + path: meta.location.clone(), + size: meta.size as _, + store, + capacity, + cursor: 0, + buffer: Buffer::Empty, + } + } + + fn poll_fill_buf_impl( + &mut self, + cx: &mut Context<'_>, + amnt: usize, + ) -> Poll> { + let buf = &mut self.buffer; + loop { + match buf { + Buffer::Empty => { + let store = Arc::clone(&self.store); + let path = self.path.clone(); + let start = self.cursor.min(self.size) as _; + let end = self.cursor.saturating_add(amnt as u64).min(self.size) as _; + + if start == end { + return Poll::Ready(Ok(&[])); + } + + *buf = Buffer::Pending(Box::pin(async move { + Ok(store.get_range(&path, start..end).await?) + })) + } + Buffer::Pending(fut) => match ready!(fut.poll_unpin(cx)) { + Ok(b) => *buf = Buffer::Ready(b), + Err(e) => return Poll::Ready(Err(e)), + }, + Buffer::Ready(r) => return Poll::Ready(Ok(r)), + } + } + } +} + +impl AsyncSeek for BufReader { + fn start_seek(mut self: Pin<&mut Self>, position: SeekFrom) -> std::io::Result<()> { + self.cursor = match position { + SeekFrom::Start(offset) => offset, + SeekFrom::End(offset) => { + checked_add_signed(self.size,offset).ok_or_else(|| Error::new(ErrorKind::InvalidInput, format!("Seeking {offset} from end of {} byte file would result in overflow", self.size)))? + } + SeekFrom::Current(offset) => { + checked_add_signed(self.cursor, offset).ok_or_else(|| Error::new(ErrorKind::InvalidInput, format!("Seeking {offset} from current offset of {} would result in overflow", self.cursor)))? 
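The overflow checks in `start_seek` above rely on the `checked_add_signed` helper defined further down in this file, a port of `u64::checked_add_signed` (which requires Rust 1.66). A standalone sanity check of that logic, copied from the definition below:

```rust
/// Copy of the helper defined later in buffered.rs: add a signed offset to a u64,
/// returning None on overflow in either direction.
fn checked_add_signed(a: u64, rhs: i64) -> Option<u64> {
    let (res, overflowed) = a.overflowing_add(rhs as u64);
    let overflow = overflowed ^ (rhs < 0);
    (!overflow).then_some(res)
}

fn main() {
    // Seeking backwards within bounds wraps the unsigned add, which is then un-flagged
    assert_eq!(checked_add_signed(4096, -76), Some(4020));
    // Seeking before the start of the file overflows
    assert_eq!(checked_add_signed(0, -1), None);
    // Seeking forwards past u64::MAX overflows
    assert_eq!(checked_add_signed(u64::MAX, 1), None);
    // SeekFrom::Current(i64::MIN) from a small cursor overflows, matching the test below
    assert_eq!(checked_add_signed(4096, i64::MIN), None);
}
```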
+            }
+        };
+        self.buffer = Buffer::Empty;
+        Ok(())
+    }
+
+    fn poll_complete(
+        self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+    ) -> Poll<std::io::Result<u64>> {
+        Poll::Ready(Ok(self.cursor))
+    }
+}
+
+impl AsyncRead for BufReader {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        out: &mut ReadBuf<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        // Request the larger of the internal buffer capacity and the space remaining in `out`
+        let to_read = out.remaining().max(self.capacity);
+        let r = match ready!(self.poll_fill_buf_impl(cx, to_read)) {
+            Ok(buf) => {
+                let to_consume = out.remaining().min(buf.len());
+                out.put_slice(&buf[..to_consume]);
+                self.consume(to_consume);
+                Ok(())
+            }
+            Err(e) => Err(e),
+        };
+        Poll::Ready(r)
+    }
+}
+
+impl AsyncBufRead for BufReader {
+    fn poll_fill_buf(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<std::io::Result<&[u8]>> {
+        let capacity = self.capacity;
+        self.get_mut().poll_fill_buf_impl(cx, capacity)
+    }
+
+    fn consume(mut self: Pin<&mut Self>, amt: usize) {
+        match &mut self.buffer {
+            Buffer::Empty => assert_eq!(amt, 0, "cannot consume from empty buffer"),
+            Buffer::Ready(b) => match b.len().cmp(&amt) {
+                Ordering::Less => panic!("{amt} exceeds buffer size of {}", b.len()),
+                Ordering::Greater => *b = b.slice(amt..),
+                Ordering::Equal => self.buffer = Buffer::Empty,
+            },
+            Buffer::Pending(_) => panic!("cannot consume from pending buffer"),
+        }
+        self.cursor += amt as u64;
+    }
+}
+
+/// Port of the standardised `u64::checked_add_signed`, as it requires Rust 1.66
+///
+///
+#[inline]
+fn checked_add_signed(a: u64, rhs: i64) -> Option<u64> {
+    let (res, overflowed) = a.overflowing_add(rhs as _);
+    let overflow = overflowed ^ (rhs < 0);
+    (!overflow).then_some(res)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::memory::InMemory;
+    use crate::path::Path;
+    use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncSeekExt};
+
+    #[tokio::test]
+    async fn test_buf_reader() {
+        let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
+
+        let existent = Path::from("exists.txt");
+        const BYTES: usize = 4096;
+
+        let data: Bytes = b"12345678".iter().cycle().copied().take(BYTES).collect();
+        store.put(&existent, data.clone()).await.unwrap();
+
+        let meta = store.head(&existent).await.unwrap();
+
+        let mut reader = BufReader::new(Arc::clone(&store), &meta);
+        let mut out = Vec::with_capacity(BYTES);
+        let read = reader.read_to_end(&mut out).await.unwrap();
+
+        assert_eq!(read, BYTES);
+        assert_eq!(&out, &data);
+
+        let err = reader.seek(SeekFrom::Current(i64::MIN)).await.unwrap_err();
+        assert_eq!(err.to_string(), "Seeking -9223372036854775808 from current offset of 4096 would result in overflow");
+
+        reader.rewind().await.unwrap();
+
+        let err = reader.seek(SeekFrom::Current(-1)).await.unwrap_err();
+        assert_eq!(
+            err.to_string(),
+            "Seeking -1 from current offset of 0 would result in overflow"
+        );
+
+        // Seeking beyond the bounds of the file is permitted but should return no data
+        reader.seek(SeekFrom::Start(u64::MAX)).await.unwrap();
+        let buf = reader.fill_buf().await.unwrap();
+        assert!(buf.is_empty());
+
+        let err = reader.seek(SeekFrom::Current(1)).await.unwrap_err();
+        assert_eq!(err.to_string(), "Seeking 1 from current offset of 18446744073709551615 would result in overflow");
+
+        for capacity in [200, 1024, 4096, DEFAULT_BUFFER_SIZE] {
+            let store = Arc::clone(&store);
+            let mut reader = BufReader::with_capacity(store, &meta, capacity);
+
+            let mut bytes_read = 0;
+            loop {
+                let buf = reader.fill_buf().await.unwrap();
+                if buf.is_empty() {
+                    assert_eq!(bytes_read, BYTES);
+                    break;
+                }
+                assert!(buf.starts_with(b"12345678"));
+                bytes_read += 8;
+
reader.consume(8); + } + + let mut buf = Vec::with_capacity(76); + reader.seek(SeekFrom::Current(-76)).await.unwrap(); + reader.read_to_end(&mut buf).await.unwrap(); + assert_eq!(&buf, &data[BYTES - 76..]); + + reader.rewind().await.unwrap(); + let buffer = reader.fill_buf().await.unwrap(); + assert_eq!(buffer, &data[..capacity.min(BYTES)]); + + reader.seek(SeekFrom::Start(325)).await.unwrap(); + let buffer = reader.fill_buf().await.unwrap(); + assert_eq!(buffer, &data[325..(325 + capacity).min(BYTES)]); + + reader.seek(SeekFrom::End(0)).await.unwrap(); + let buffer = reader.fill_buf().await.unwrap(); + assert!(buffer.is_empty()); + } + } +} diff --git a/object_store/src/chunked.rs b/object_store/src/chunked.rs index c639d7e89812..d3e02b412725 100644 --- a/object_store/src/chunked.rs +++ b/object_store/src/chunked.rs @@ -18,7 +18,6 @@ //! A [`ChunkedStore`] that can be used to test streaming behaviour use std::fmt::{Debug, Display, Formatter}; -use std::io::{BufReader, Read}; use std::ops::Range; use std::sync::Arc; @@ -29,8 +28,9 @@ use futures::StreamExt; use tokio::io::AsyncWrite; use crate::path::Path; -use crate::util::maybe_spawn_blocking; -use crate::{GetOptions, GetResult, ListResult, ObjectMeta, ObjectStore}; +use crate::{ + GetOptions, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, +}; use crate::{MultipartId, Result}; /// Wraps a [`ObjectStore`] and makes its get response return chunks @@ -82,77 +82,57 @@ impl ObjectStore for ChunkedStore { } async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { - match self.inner.get_opts(location, options).await? { - GetResult::File(std_file, ..) => { - let reader = BufReader::new(std_file); - let chunk_size = self.chunk_size; - Ok(GetResult::Stream( - futures::stream::try_unfold(reader, move |mut reader| async move { - let (r, out, reader) = maybe_spawn_blocking(move || { - let mut out = Vec::with_capacity(chunk_size); - let r = (&mut reader) - .take(chunk_size as u64) - .read_to_end(&mut out) - .map_err(|err| crate::Error::Generic { - store: "ChunkedStore", - source: Box::new(err), - })?; - Ok((r, out, reader)) - }) - .await?; - - match r { - 0 => Ok(None), - _ => Ok(Some((out.into(), reader))), - } - }) - .boxed(), - )) + let r = self.inner.get_opts(location, options).await?; + let stream = match r.payload { + GetResultPayload::File(file, path) => { + crate::local::chunked_stream(file, path, r.range.clone(), self.chunk_size) } - GetResult::Stream(stream) => { + GetResultPayload::Stream(stream) => { let buffer = BytesMut::new(); - Ok(GetResult::Stream( - futures::stream::unfold( - (stream, buffer, false, self.chunk_size), - |(mut stream, mut buffer, mut exhausted, chunk_size)| async move { - // Keep accumulating bytes until we reach capacity as long as - // the stream can provide them: - if exhausted { - return None; - } - while buffer.len() < chunk_size { - match stream.next().await { - None => { - exhausted = true; - let slice = buffer.split_off(0).freeze(); - return Some(( - Ok(slice), - (stream, buffer, exhausted, chunk_size), - )); - } - Some(Ok(bytes)) => { - buffer.put(bytes); - } - Some(Err(e)) => { - return Some(( - Err(crate::Error::Generic { - store: "ChunkedStore", - source: Box::new(e), - }), - (stream, buffer, exhausted, chunk_size), - )) - } - }; - } - // Return the chunked values as the next value in the stream - let slice = buffer.split_to(chunk_size).freeze(); - Some((Ok(slice), (stream, buffer, exhausted, chunk_size))) - }, - ) - .boxed(), - )) + futures::stream::unfold( + (stream, 
buffer, false, self.chunk_size), + |(mut stream, mut buffer, mut exhausted, chunk_size)| async move { + // Keep accumulating bytes until we reach capacity as long as + // the stream can provide them: + if exhausted { + return None; + } + while buffer.len() < chunk_size { + match stream.next().await { + None => { + exhausted = true; + let slice = buffer.split_off(0).freeze(); + return Some(( + Ok(slice), + (stream, buffer, exhausted, chunk_size), + )); + } + Some(Ok(bytes)) => { + buffer.put(bytes); + } + Some(Err(e)) => { + return Some(( + Err(crate::Error::Generic { + store: "ChunkedStore", + source: Box::new(e), + }), + (stream, buffer, exhausted, chunk_size), + )) + } + }; + } + // Return the chunked values as the next value in the stream + let slice = buffer.split_to(chunk_size).freeze(); + Some((Ok(slice), (stream, buffer, exhausted, chunk_size))) + }, + ) + .boxed() } - } + }; + Ok(GetResult { + payload: GetResultPayload::Stream(stream), + ..r + }) } async fn get_range(&self, location: &Path, range: Range) -> Result { @@ -167,19 +147,16 @@ impl ObjectStore for ChunkedStore { self.inner.delete(location).await } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { - self.inner.list(prefix).await + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { + self.inner.list(prefix) } - async fn list_with_offset( + fn list_with_offset( &self, prefix: Option<&Path>, offset: &Path, - ) -> Result>> { - self.inner.list_with_offset(prefix, offset).await + ) -> BoxStream<'_, Result> { + self.inner.list_with_offset(prefix, offset) } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { @@ -217,8 +194,8 @@ mod tests { for chunk_size in [10, 20, 31] { let store = ChunkedStore::new(Arc::clone(&store), chunk_size); - let mut s = match store.get(&location).await.unwrap() { - GetResult::Stream(s) => s, + let mut s = match store.get(&location).await.unwrap().payload { + GetResultPayload::Stream(s) => s, _ => unreachable!(), }; diff --git a/object_store/src/client/get.rs b/object_store/src/client/get.rs index 3c66a72d82ed..7f68b6d1225f 100644 --- a/object_store/src/client/get.rs +++ b/object_store/src/client/get.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. 
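// A short sketch of consuming the reworked GetResult that the chunked.rs and get.rs
// changes above introduce; illustrative only, not part of the patch. It assumes
// GetResultPayload is re-exported from the crate root, as the new use statements show.
use object_store::path::Path;
use object_store::{GetResultPayload, ObjectStore};

async fn describe_get(store: &dyn ObjectStore, location: &Path) -> object_store::Result<()> {
    let result = store.get(location).await?;
    println!("fetched range {:?} of a {} byte object", result.range, result.meta.size);

    match result.payload {
        // Local filesystem stores hand back the file directly instead of a byte stream
        GetResultPayload::File(_file, path) => println!("local file at {}", path.display()),
        GetResultPayload::Stream(_stream) => println!("remote byte stream"),
    }

    // The convenience accessor collects either payload variant into Bytes
    let bytes = store.get(location).await?.bytes().await?;
    println!("read {} bytes", bytes.len());
    Ok(())
}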
-use crate::client::header::header_meta; +use crate::client::header::{header_meta, HeaderConfig}; use crate::path::Path; -use crate::Result; -use crate::{Error, GetOptions, GetResult, ObjectMeta}; +use crate::{Error, GetOptions, GetResult}; +use crate::{GetResultPayload, Result}; use async_trait::async_trait; use futures::{StreamExt, TryStreamExt}; use reqwest::Response; @@ -28,26 +28,34 @@ use reqwest::Response; pub trait GetClient: Send + Sync + 'static { const STORE: &'static str; - async fn get_request( - &self, - path: &Path, - options: GetOptions, - head: bool, - ) -> Result; + /// Configure the [`HeaderConfig`] for this client + const HEADER_CONFIG: HeaderConfig = HeaderConfig { + etag_required: true, + last_modified_required: true, + }; + + async fn get_request(&self, path: &Path, options: GetOptions) -> Result; } /// Extension trait for [`GetClient`] that adds common retrieval functionality #[async_trait] pub trait GetClientExt { async fn get_opts(&self, location: &Path, options: GetOptions) -> Result; - - async fn head(&self, location: &Path) -> Result; } #[async_trait] impl GetClientExt for T { async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { - let response = self.get_request(location, options, false).await?; + let range = options.range.clone(); + let response = self.get_request(location, options).await?; + let meta = + header_meta(location, response.headers(), T::HEADER_CONFIG).map_err(|e| { + Error::Generic { + store: T::STORE, + source: Box::new(e), + } + })?; + let stream = response .bytes_stream() .map_err(|source| Error::Generic { @@ -56,15 +64,10 @@ impl GetClientExt for T { }) .boxed(); - Ok(GetResult::Stream(stream)) - } - - async fn head(&self, location: &Path) -> Result { - let options = GetOptions::default(); - let response = self.get_request(location, options, true).await?; - header_meta(location, response.headers()).map_err(|e| Error::Generic { - store: T::STORE, - source: Box::new(e), + Ok(GetResult { + range: range.unwrap_or(0..meta.size), + payload: GetResultPayload::Stream(stream), + meta, }) } } diff --git a/object_store/src/client/header.rs b/object_store/src/client/header.rs index cc4f16eaa599..6499eff5aebe 100644 --- a/object_store/src/client/header.rs +++ b/object_store/src/client/header.rs @@ -19,11 +19,24 @@ use crate::path::Path; use crate::ObjectMeta; -use chrono::{DateTime, Utc}; +use chrono::{DateTime, TimeZone, Utc}; use hyper::header::{CONTENT_LENGTH, ETAG, LAST_MODIFIED}; use hyper::HeaderMap; use snafu::{OptionExt, ResultExt, Snafu}; +#[derive(Debug, Copy, Clone)] +/// Configuration for header extraction +pub struct HeaderConfig { + /// Whether to require an ETag header when extracting [`ObjectMeta`] from headers. + /// + /// Defaults to `true` + pub etag_required: bool, + /// Whether to require a Last-Modified header when extracting [`ObjectMeta`] from headers. 
+ /// + /// Defaults to `true` + pub last_modified_required: bool, +} + #[derive(Debug, Snafu)] pub enum Error { #[snafu(display("ETag Header missing from response"))] @@ -52,32 +65,44 @@ pub enum Error { } /// Extracts [`ObjectMeta`] from the provided [`HeaderMap`] -pub fn header_meta(location: &Path, headers: &HeaderMap) -> Result { - let last_modified = headers - .get(LAST_MODIFIED) - .context(MissingLastModifiedSnafu)?; +pub fn header_meta( + location: &Path, + headers: &HeaderMap, + cfg: HeaderConfig, +) -> Result { + let last_modified = match headers.get(LAST_MODIFIED) { + Some(last_modified) => { + let last_modified = last_modified.to_str().context(BadHeaderSnafu)?; + DateTime::parse_from_rfc2822(last_modified) + .context(InvalidLastModifiedSnafu { last_modified })? + .with_timezone(&Utc) + } + None if cfg.last_modified_required => return Err(Error::MissingLastModified), + None => Utc.timestamp_nanos(0), + }; + + let e_tag = match headers.get(ETAG) { + Some(e_tag) => { + let e_tag = e_tag.to_str().context(BadHeaderSnafu)?; + Some(e_tag.to_string()) + } + None if cfg.etag_required => return Err(Error::MissingEtag), + None => None, + }; let content_length = headers .get(CONTENT_LENGTH) .context(MissingContentLengthSnafu)?; - let last_modified = last_modified.to_str().context(BadHeaderSnafu)?; - let last_modified = DateTime::parse_from_rfc2822(last_modified) - .context(InvalidLastModifiedSnafu { last_modified })? - .with_timezone(&Utc); - let content_length = content_length.to_str().context(BadHeaderSnafu)?; let content_length = content_length .parse() .context(InvalidContentLengthSnafu { content_length })?; - let e_tag = headers.get(ETAG).context(MissingEtagSnafu)?; - let e_tag = e_tag.to_str().context(BadHeaderSnafu)?; - Ok(ObjectMeta { location: location.clone(), last_modified, size: content_length, - e_tag: Some(e_tag.to_string()), + e_tag, }) } diff --git a/object_store/src/client/list.rs b/object_store/src/client/list.rs index b2dbee27f14d..371894dfeb71 100644 --- a/object_store/src/client/list.rs +++ b/object_store/src/client/list.rs @@ -46,16 +46,13 @@ pub trait ListClientExt { offset: Option<&Path>, ) -> BoxStream<'_, Result>; - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>>; + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result>; - async fn list_with_offset( + fn list_with_offset( &self, prefix: Option<&Path>, offset: &Path, - ) -> Result>>; + ) -> BoxStream<'_, Result>; async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result; } @@ -90,31 +87,22 @@ impl ListClientExt for T { .boxed() } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { - let stream = self - .list_paginated(prefix, false, None) + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { + self.list_paginated(prefix, false, None) .map_ok(|r| futures::stream::iter(r.objects.into_iter().map(Ok))) .try_flatten() - .boxed(); - - Ok(stream) + .boxed() } - async fn list_with_offset( + fn list_with_offset( &self, prefix: Option<&Path>, offset: &Path, - ) -> Result>> { - let stream = self - .list_paginated(prefix, false, Some(offset)) + ) -> BoxStream<'_, Result> { + self.list_paginated(prefix, false, Some(offset)) .map_ok(|r| futures::stream::iter(r.objects.into_iter().map(Ok))) .try_flatten() - .boxed(); - - Ok(stream) + .boxed() } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs index 5f3a042be46a..137da2b37594 100644 --- a/object_store/src/client/mod.rs 
+++ b/object_store/src/client/mod.rs @@ -18,6 +18,7 @@ //! Generic utilities reqwest based ObjectStore implementations pub mod backoff; + #[cfg(test)] pub mod mock_server; @@ -26,7 +27,6 @@ pub mod retry; #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] pub mod pagination; -#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] pub mod get; #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] @@ -35,7 +35,6 @@ pub mod list; #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] pub mod token; -#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] pub mod header; #[cfg(any(feature = "aws", feature = "gcp"))] @@ -48,7 +47,7 @@ use std::sync::Arc; use std::time::Duration; use reqwest::header::{HeaderMap, HeaderValue}; -use reqwest::{Client, ClientBuilder, Proxy, RequestBuilder}; +use reqwest::{Client, ClientBuilder, NoProxy, Proxy, RequestBuilder}; use serde::{Deserialize, Serialize}; use crate::config::{fmt_duration, ConfigValue}; @@ -103,6 +102,10 @@ pub enum ClientConfigKey { PoolMaxIdlePerHost, /// HTTP proxy to use for requests ProxyUrl, + /// PEM-formatted CA certificate for proxy connections + ProxyCaCertificate, + /// List of hosts that bypass proxy + ProxyExcludes, /// Request timeout /// /// The timeout is applied from when the request starts connecting until the @@ -127,6 +130,8 @@ impl AsRef for ClientConfigKey { Self::PoolIdleTimeout => "pool_idle_timeout", Self::PoolMaxIdlePerHost => "pool_max_idle_per_host", Self::ProxyUrl => "proxy_url", + Self::ProxyCaCertificate => "proxy_ca_certificate", + Self::ProxyExcludes => "proxy_excludes", Self::Timeout => "timeout", Self::UserAgent => "user_agent", } @@ -161,13 +166,15 @@ impl FromStr for ClientConfigKey { } /// HTTP client configuration for remote object stores -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone)] pub struct ClientOptions { user_agent: Option>, content_type_map: HashMap, default_content_type: Option, default_headers: Option, proxy_url: Option, + proxy_ca_certificate: Option, + proxy_excludes: Option, allow_http: ConfigValue, allow_insecure: ConfigValue, timeout: Option>, @@ -181,6 +188,35 @@ pub struct ClientOptions { http2_only: ConfigValue, } +impl Default for ClientOptions { + fn default() -> Self { + // Defaults based on + // + // + // Which recommend a connection timeout of 3.1s and a request timeout of 2s + Self { + user_agent: None, + content_type_map: Default::default(), + default_content_type: None, + default_headers: None, + proxy_url: None, + proxy_ca_certificate: None, + proxy_excludes: None, + allow_http: Default::default(), + allow_insecure: Default::default(), + timeout: Some(Duration::from_secs(5).into()), + connect_timeout: Some(Duration::from_secs(5).into()), + pool_idle_timeout: None, + pool_max_idle_per_host: None, + http2_keep_alive_interval: None, + http2_keep_alive_timeout: None, + http2_keep_alive_while_idle: Default::default(), + http1_only: Default::default(), + http2_only: Default::default(), + } + } +} + impl ClientOptions { /// Create a new [`ClientOptions`] with default values pub fn new() -> Self { @@ -216,6 +252,10 @@ impl ClientOptions { self.pool_max_idle_per_host = Some(ConfigValue::Deferred(value.into())) } ClientConfigKey::ProxyUrl => self.proxy_url = Some(value.into()), + ClientConfigKey::ProxyCaCertificate => { + self.proxy_ca_certificate = Some(value.into()) + } + ClientConfigKey::ProxyExcludes => self.proxy_excludes = Some(value.into()), ClientConfigKey::Timeout => { self.timeout = 
Some(ConfigValue::Deferred(value.into())) } @@ -255,6 +295,8 @@ impl ClientOptions { self.pool_max_idle_per_host.as_ref().map(|v| v.to_string()) } ClientConfigKey::ProxyUrl => self.proxy_url.clone(), + ClientConfigKey::ProxyCaCertificate => self.proxy_ca_certificate.clone(), + ClientConfigKey::ProxyExcludes => self.proxy_excludes.clone(), ClientConfigKey::Timeout => self.timeout.as_ref().map(fmt_duration), ClientConfigKey::UserAgent => self .user_agent @@ -329,27 +371,62 @@ impl ClientOptions { self } - /// Set an HTTP proxy to use for requests + /// Set a proxy URL to use for requests pub fn with_proxy_url(mut self, proxy_url: impl Into) -> Self { self.proxy_url = Some(proxy_url.into()); self } + /// Set a trusted proxy CA certificate + pub fn with_proxy_ca_certificate( + mut self, + proxy_ca_certificate: impl Into, + ) -> Self { + self.proxy_ca_certificate = Some(proxy_ca_certificate.into()); + self + } + + /// Set a list of hosts to exclude from proxy connections + pub fn with_proxy_excludes(mut self, proxy_excludes: impl Into) -> Self { + self.proxy_excludes = Some(proxy_excludes.into()); + self + } + /// Set a request timeout /// /// The timeout is applied from when the request starts connecting until the /// response body has finished + /// + /// Default is 5 seconds pub fn with_timeout(mut self, timeout: Duration) -> Self { self.timeout = Some(ConfigValue::Parsed(timeout)); self } + /// Disables the request timeout + /// + /// See [`Self::with_timeout`] + pub fn with_timeout_disabled(mut self) -> Self { + self.timeout = None; + self + } + /// Set a timeout for only the connect phase of a Client + /// + /// Default is 5 seconds pub fn with_connect_timeout(mut self, timeout: Duration) -> Self { self.connect_timeout = Some(ConfigValue::Parsed(timeout)); self } + /// Disables the connection timeout + /// + /// See [`Self::with_connect_timeout`] + pub fn with_connect_timeout_disabled(mut self) -> Self { + self.timeout = None; + self + } + /// Set the pool max idle timeout /// /// This is the length of time an idle connection will be kept alive @@ -416,7 +493,20 @@ impl ClientOptions { } } - pub(crate) fn client(&self) -> super::Result { + /// Create a [`Client`] with overrides optimised for metadata endpoint access + /// + /// In particular: + /// * Allows HTTP as metadata endpoints do not use TLS + /// * Configures a low connection timeout to provide quick feedback if not present + #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] + pub(crate) fn metadata_client(&self) -> Result { + self.clone() + .with_allow_http(true) + .with_connect_timeout(Duration::from_secs(1)) + .client() + } + + pub(crate) fn client(&self) -> Result { let mut builder = ClientBuilder::new(); match &self.user_agent { @@ -429,7 +519,22 @@ impl ClientOptions { } if let Some(proxy) = &self.proxy_url { - let proxy = Proxy::all(proxy).map_err(map_client_error)?; + let mut proxy = Proxy::all(proxy).map_err(map_client_error)?; + + if let Some(certificate) = &self.proxy_ca_certificate { + let certificate = + reqwest::tls::Certificate::from_pem(certificate.as_bytes()) + .map_err(map_client_error)?; + + builder = builder.add_root_certificate(certificate); + } + + if let Some(proxy_excludes) = &self.proxy_excludes { + let no_proxy = NoProxy::from_string(proxy_excludes); + + proxy = proxy.no_proxy(no_proxy); + } + builder = builder.proxy(proxy); } @@ -531,6 +636,7 @@ pub struct StaticCredentialProvider { } impl StaticCredentialProvider { + /// A [`CredentialProvider`] for a static credential of type `T` pub fn 
new(credential: T) -> Self { Self { credential: Arc::new(credential), diff --git a/object_store/src/client/retry.rs b/object_store/src/client/retry.rs index 39a913142e09..e4d246c87a2a 100644 --- a/object_store/src/client/retry.rs +++ b/object_store/src/client/retry.rs @@ -23,46 +23,50 @@ use futures::FutureExt; use reqwest::header::LOCATION; use reqwest::{Response, StatusCode}; use snafu::Error as SnafuError; +use snafu::Snafu; use std::time::{Duration, Instant}; use tracing::info; /// Retry request error -#[derive(Debug)] -pub struct Error { - retries: usize, - message: String, - source: Option, - status: Option, -} - -impl std::fmt::Display for Error { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "response error \"{}\", after {} retries", - self.message, self.retries - )?; - if let Some(source) = &self.source { - write!(f, ": {source}")?; - } - Ok(()) - } -} - -impl std::error::Error for Error { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.source.as_ref().map(|e| e as _) - } +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display("Received redirect without LOCATION, this normally indicates an incorrectly configured region"))] + BareRedirect, + + #[snafu(display("Client error with status {status}: {}", body.as_deref().unwrap_or("No Body")))] + Client { + status: StatusCode, + body: Option, + }, + + #[snafu(display("Error after {retries} retries: {source}"))] + Reqwest { + retries: usize, + source: reqwest::Error, + }, } impl Error { /// Returns the status code associated with this error if any pub fn status(&self) -> Option { - self.status + match self { + Self::BareRedirect => None, + Self::Client { status, .. } => Some(*status), + Self::Reqwest { source, .. } => source.status(), + } + } + + /// Returns the error body if any + pub fn body(&self) -> Option<&str> { + match self { + Self::Client { body, .. } => body.as_deref(), + Self::BareRedirect => None, + Self::Reqwest { .. } => None, + } } pub fn error(self, store: &'static str, path: String) -> crate::Error { - match self.status { + match self.status() { Some(StatusCode::NOT_FOUND) => crate::Error::NotFound { path, source: Box::new(self), @@ -86,16 +90,19 @@ impl Error { impl From for std::io::Error { fn from(err: Error) -> Self { use std::io::ErrorKind; - match (&err.source, err.status()) { - (Some(source), _) if source.is_builder() || source.is_request() => { - Self::new(ErrorKind::InvalidInput, err) - } - (_, Some(StatusCode::NOT_FOUND)) => Self::new(ErrorKind::NotFound, err), - (_, Some(StatusCode::BAD_REQUEST)) => Self::new(ErrorKind::InvalidInput, err), - (Some(source), None) if source.is_timeout() => { + match &err { + Error::Client { + status: StatusCode::NOT_FOUND, + .. + } => Self::new(ErrorKind::NotFound, err), + Error::Client { + status: StatusCode::BAD_REQUEST, + .. + } => Self::new(ErrorKind::InvalidInput, err), + Error::Reqwest { source, .. } if source.is_timeout() => { Self::new(ErrorKind::TimedOut, err) } - (Some(source), None) if source.is_connect() => { + Error::Reqwest { source, .. 
} if source.is_connect() => { Self::new(ErrorKind::NotConnected, err) } _ => Self::new(ErrorKind::Other, err), @@ -169,27 +176,21 @@ impl RetryExt for reqwest::RequestBuilder { Ok(r) => match r.error_for_status_ref() { Ok(_) if r.status().is_success() => return Ok(r), Ok(r) if r.status() == StatusCode::NOT_MODIFIED => { - return Err(Error{ - message: "not modified".to_string(), - retries, - status: Some(r.status()), - source: None, + return Err(Error::Client { + body: None, + status: StatusCode::NOT_MODIFIED, }) } Ok(r) => { let is_bare_redirect = r.status().is_redirection() && !r.headers().contains_key(LOCATION); - let message = match is_bare_redirect { - true => "Received redirect without LOCATION, this normally indicates an incorrectly configured region".to_string(), + return match is_bare_redirect { + true => Err(Error::BareRedirect), // Not actually sure if this is reachable, but here for completeness - false => format!("request unsuccessful: {}", r.status()), - }; - - return Err(Error{ - message, - retries, - status: Some(r.status()), - source: None, - }) + false => Err(Error::Client { + body: None, + status: r.status(), + }) + } } Err(e) => { let status = r.status(); @@ -198,23 +199,26 @@ impl RetryExt for reqwest::RequestBuilder { || now.elapsed() > retry_timeout || !status.is_server_error() { - // Get the response message if returned a client error - let message = match status.is_client_error() { + return Err(match status.is_client_error() { true => match r.text().await { - Ok(message) if !message.is_empty() => message, - Ok(_) => "No Body".to_string(), - Err(e) => format!("error getting response body: {e}") + Ok(body) => { + Error::Client { + body: Some(body).filter(|b| !b.is_empty()), + status, + } + } + Err(e) => { + Error::Reqwest { + retries, + source: e, + } + } } - false => status.to_string(), - }; - - return Err(Error{ - message, - retries, - status: Some(status), - source: Some(e), - }) - + false => Error::Reqwest { + retries, + source: e, + } + }); } let sleep = backoff.next(); @@ -238,16 +242,14 @@ impl RetryExt for reqwest::RequestBuilder { || now.elapsed() > retry_timeout || !do_retry { - return Err(Error{ + return Err(Error::Reqwest { retries, - message: "request error".to_string(), - status: e.status(), - source: Some(e), + source: e, }) } let sleep = backoff.next(); retries += 1; - info!("Encountered request error ({}) backing off for {} seconds, retry {} of {}", e, sleep.as_secs_f32(), retries, max_retries); + info!("Encountered transport error ({}) backing off for {} seconds, retry {} of {}", e, sleep.as_secs_f32(), retries, max_retries); tokio::time::sleep(sleep).await; } } @@ -260,7 +262,7 @@ impl RetryExt for reqwest::RequestBuilder { #[cfg(test)] mod tests { use crate::client::mock_server::MockServer; - use crate::client::retry::RetryExt; + use crate::client::retry::{Error, RetryExt}; use crate::RetryConfig; use hyper::header::LOCATION; use hyper::{Body, Response}; @@ -294,8 +296,11 @@ mod tests { let e = do_request().await.unwrap_err(); assert_eq!(e.status().unwrap(), StatusCode::BAD_REQUEST); - assert_eq!(e.retries, 0); - assert_eq!(&e.message, "cupcakes"); + assert_eq!(e.body(), Some("cupcakes")); + assert_eq!( + e.to_string(), + "Client error with status 400 Bad Request: cupcakes" + ); // Handles client errors with no payload mock.push( @@ -307,8 +312,11 @@ mod tests { let e = do_request().await.unwrap_err(); assert_eq!(e.status().unwrap(), StatusCode::BAD_REQUEST); - assert_eq!(e.retries, 0); - assert_eq!(&e.message, "No Body"); + assert_eq!(e.body(), 
None); + assert_eq!( + e.to_string(), + "Client error with status 400 Bad Request: No Body" + ); // Should retry server error request mock.push( @@ -381,7 +389,8 @@ mod tests { ); let e = do_request().await.unwrap_err(); - assert_eq!(e.message, "Received redirect without LOCATION, this normally indicates an incorrectly configured region"); + assert!(matches!(e, Error::BareRedirect)); + assert_eq!(e.to_string(), "Received redirect without LOCATION, this normally indicates an incorrectly configured region"); // Gives up after the retrying the specified number of times for _ in 0..=retry.max_retries { @@ -393,22 +402,23 @@ mod tests { ); } - let e = do_request().await.unwrap_err(); - assert_eq!(e.retries, retry.max_retries); - assert_eq!(e.message, "502 Bad Gateway"); + let e = do_request().await.unwrap_err().to_string(); + assert!(e.starts_with("Error after 2 retries: HTTP status server error (502 Bad Gateway) for url"), "{e}"); // Panic results in an incomplete message error in the client mock.push_fn(|_| panic!()); let r = do_request().await.unwrap(); assert_eq!(r.status(), StatusCode::OK); - // Gives up after retrying mulitiple panics + // Gives up after retrying multiple panics for _ in 0..=retry.max_retries { mock.push_fn(|_| panic!()); } - let e = do_request().await.unwrap_err(); - assert_eq!(e.retries, retry.max_retries); - assert_eq!(e.message, "request error"); + let e = do_request().await.unwrap_err().to_string(); + assert!( + e.starts_with("Error after 2 retries: error sending request for url"), + "{e}" + ); // Shutdown mock.shutdown().await diff --git a/object_store/src/gcp/credential.rs b/object_store/src/gcp/credential.rs index 205b805947cc..87f8e244f21c 100644 --- a/object_store/src/gcp/credential.rs +++ b/object_store/src/gcp/credential.rs @@ -17,10 +17,8 @@ use crate::client::retry::RetryExt; use crate::client::token::TemporaryToken; -use crate::client::{TokenCredentialProvider, TokenProvider}; -use crate::gcp::credential::Error::UnsupportedCredentialsType; -use crate::gcp::{GcpCredentialProvider, STORE}; -use crate::ClientOptions; +use crate::client::TokenProvider; +use crate::gcp::STORE; use crate::RetryConfig; use async_trait::async_trait; use base64::prelude::BASE64_URL_SAFE_NO_PAD; @@ -28,6 +26,7 @@ use base64::Engine; use futures::TryFutureExt; use reqwest::{Client, Method}; use ring::signature::RsaKeyPair; +use serde::Deserialize; use snafu::{ResultExt, Snafu}; use std::env; use std::fs::File; @@ -37,6 +36,10 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use tracing::info; +pub const DEFAULT_SCOPE: &str = "https://www.googleapis.com/auth/devstorage.full_control"; + +pub const DEFAULT_GCS_BASE_URL: &str = "https://storage.googleapis.com"; + #[derive(Debug, Snafu)] pub enum Error { #[snafu(display("Unable to open service account file from {}: {}", path.display(), source))] @@ -68,9 +71,6 @@ pub enum Error { #[snafu(display("Error getting token response body: {}", source))] TokenResponseBody { source: reqwest::Error }, - - #[snafu(display("Unsupported ApplicationCredentials type: {}", type_))] - UnsupportedCredentialsType { type_: String }, } impl From for crate::Error { @@ -92,48 +92,48 @@ pub struct GcpCredential { pub type Result = std::result::Result; #[derive(Debug, Default, serde::Serialize)] -pub struct JwtHeader { +pub struct JwtHeader<'a> { /// The type of JWS: it can only be "JWT" here /// /// Defined in [RFC7515#4.1.9](https://tools.ietf.org/html/rfc7515#section-4.1.9). 
#[serde(skip_serializing_if = "Option::is_none")] - pub typ: Option, + pub typ: Option<&'a str>, /// The algorithm used /// /// Defined in [RFC7515#4.1.1](https://tools.ietf.org/html/rfc7515#section-4.1.1). - pub alg: String, + pub alg: &'a str, /// Content type /// /// Defined in [RFC7519#5.2](https://tools.ietf.org/html/rfc7519#section-5.2). #[serde(skip_serializing_if = "Option::is_none")] - pub cty: Option, + pub cty: Option<&'a str>, /// JSON Key URL /// /// Defined in [RFC7515#4.1.2](https://tools.ietf.org/html/rfc7515#section-4.1.2). #[serde(skip_serializing_if = "Option::is_none")] - pub jku: Option, + pub jku: Option<&'a str>, /// Key ID /// /// Defined in [RFC7515#4.1.4](https://tools.ietf.org/html/rfc7515#section-4.1.4). #[serde(skip_serializing_if = "Option::is_none")] - pub kid: Option, + pub kid: Option<&'a str>, /// X.509 URL /// /// Defined in [RFC7515#4.1.5](https://tools.ietf.org/html/rfc7515#section-4.1.5). #[serde(skip_serializing_if = "Option::is_none")] - pub x5u: Option, + pub x5u: Option<&'a str>, /// X.509 certificate thumbprint /// /// Defined in [RFC7515#4.1.7](https://tools.ietf.org/html/rfc7515#section-4.1.7). #[serde(skip_serializing_if = "Option::is_none")] - pub x5t: Option, + pub x5t: Option<&'a str>, } #[derive(serde::Serialize)] struct TokenClaims<'a> { iss: &'a str, + sub: &'a str, scope: &'a str, - aud: &'a str, exp: u64, iat: u64, } @@ -144,28 +144,32 @@ struct TokenResponse { expires_in: u64, } -/// Encapsulates the logic to perform an OAuth token challenge +/// Self-signed JWT (JSON Web Token). +/// +/// # References +/// - #[derive(Debug)] -pub struct OAuthProvider { +pub struct SelfSignedJwt { issuer: String, scope: String, - audience: String, key_pair: RsaKeyPair, jwt_header: String, random: ring::rand::SystemRandom, } -impl OAuthProvider { - /// Create a new [`OAuthProvider`] +impl SelfSignedJwt { + /// Create a new [`SelfSignedJwt`] pub fn new( + key_id: String, issuer: String, private_key_pem: String, scope: String, - audience: String, ) -> Result { let key_pair = decode_first_rsa_key(private_key_pem)?; let jwt_header = b64_encode_obj(&JwtHeader { - alg: "RS256".to_string(), + alg: "RS256", + typ: Some("JWT"), + kid: Some(&key_id), ..Default::default() })?; @@ -173,7 +177,6 @@ impl OAuthProvider { issuer, key_pair, scope, - audience, jwt_header, random: ring::rand::SystemRandom::new(), }) @@ -181,29 +184,29 @@ impl OAuthProvider { } #[async_trait] -impl TokenProvider for OAuthProvider { +impl TokenProvider for SelfSignedJwt { type Credential = GcpCredential; /// Fetch a fresh token async fn fetch_token( &self, - client: &Client, - retry: &RetryConfig, + _client: &Client, + _retry: &RetryConfig, ) -> crate::Result>> { let now = seconds_since_epoch(); let exp = now + 3600; let claims = TokenClaims { iss: &self.issuer, + sub: &self.issuer, scope: &self.scope, - aud: &self.audience, - exp, iat: now, + exp, }; let claim_str = b64_encode_obj(&claims)?; let message = [self.jwt_header.as_ref(), claim_str.as_ref()].join("."); - let mut sig_bytes = vec![0; self.key_pair.public_modulus_len()]; + let mut sig_bytes = vec![0; self.key_pair.public().modulus_len()]; self.key_pair .sign( &ring::signature::RSA_PKCS1_SHA256, @@ -214,28 +217,11 @@ impl TokenProvider for OAuthProvider { .context(SignSnafu)?; let signature = BASE64_URL_SAFE_NO_PAD.encode(sig_bytes); - let jwt = [message, signature].join("."); - - let body = [ - ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"), - ("assertion", &jwt), - ]; - - let response: TokenResponse = client - 
.request(Method::POST, &self.audience) - .form(&body) - .send_retry(retry) - .await - .context(TokenRequestSnafu)? - .json() - .await - .context(TokenResponseBodySnafu)?; + let bearer = [message, signature].join("."); Ok(TemporaryToken { - token: Arc::new(GcpCredential { - bearer: response.access_token, - }), - expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)), + token: Arc::new(GcpCredential { bearer }), + expiry: Some(Instant::now() + Duration::from_secs(3600)), }) } } @@ -259,29 +245,24 @@ pub struct ServiceAccountCredentials { /// The private key in RSA format. pub private_key: String, + /// The private key ID + pub private_key_id: String, + /// The email address associated with the service account. pub client_email: String, /// Base URL for GCS - #[serde(default = "default_gcs_base_url")] - pub gcs_base_url: String, + #[serde(default)] + pub gcs_base_url: Option, /// Disable oauth and use empty tokens. - #[serde(default = "default_disable_oauth")] + #[serde(default)] pub disable_oauth: bool, } -pub fn default_gcs_base_url() -> String { - "https://storage.googleapis.com".to_owned() -} - -pub fn default_disable_oauth() -> bool { - false -} - impl ServiceAccountCredentials { /// Create a new [`ServiceAccountCredentials`] from a file. - pub fn from_file>(path: P) -> Result { + pub fn from_file>(path: P) -> Result { read_credentials_file(path) } @@ -290,17 +271,20 @@ impl ServiceAccountCredentials { serde_json::from_str(key).context(DecodeCredentialsSnafu) } - /// Create an [`OAuthProvider`] from this credentials struct. - pub fn oauth_provider( - self, - scope: &str, - audience: &str, - ) -> crate::Result { - Ok(OAuthProvider::new( + /// Create a [`SelfSignedJwt`] from this credentials struct. + /// + /// We use a scope of [`DEFAULT_SCOPE`] as opposed to an audience + /// as GCS appears to not support audience + /// + /// # References + /// - + /// - + pub fn token_provider(self) -> crate::Result { + Ok(SelfSignedJwt::new( + self.private_key_id, self.client_email, self.private_key, - scope.to_string(), - audience.to_string(), + DEFAULT_SCOPE.to_string(), )?) } } @@ -337,25 +321,13 @@ fn b64_encode_obj(obj: &T) -> Result { /// /// #[derive(Debug, Default)] -pub struct InstanceCredentialProvider { - audience: String, -} - -impl InstanceCredentialProvider { - /// Create a new [`InstanceCredentialProvider`], we need to control the client in order to enable http access so save the options. - pub fn new>(audience: T) -> Self { - Self { - audience: audience.into(), - } - } -} +pub struct InstanceCredentialProvider {} /// Make a request to the metadata server to fetch a token, using a a given hostname. async fn make_metadata_request( client: &Client, hostname: &str, retry: &RetryConfig, - audience: &str, ) -> crate::Result { let url = format!( "http://{hostname}/computeMetadata/v1/instance/service-accounts/default/token" @@ -363,7 +335,7 @@ async fn make_metadata_request( let response: TokenResponse = client .request(Method::GET, url) .header("Metadata-Flavor", "Google") - .query(&[("audience", audience)]) + .query(&[("audience", "https://www.googleapis.com/oauth2/v4/token")]) .send_retry(retry) .await .context(TokenRequestSnafu)? 
@@ -388,12 +360,9 @@ impl TokenProvider for InstanceCredentialProvider { const METADATA_HOST: &str = "metadata"; info!("fetching token from metadata server"); - let response = - make_metadata_request(client, METADATA_HOST, retry, &self.audience) - .or_else(|_| { - make_metadata_request(client, METADATA_IP, retry, &self.audience) - }) - .await?; + let response = make_metadata_request(client, METADATA_HOST, retry) + .or_else(|_| make_metadata_request(client, METADATA_IP, retry)) + .await?; let token = TemporaryToken { token: Arc::new(GcpCredential { bearer: response.access_token, @@ -404,62 +373,36 @@ impl TokenProvider for InstanceCredentialProvider { } } -/// ApplicationDefaultCredentials -/// -pub fn application_default_credentials( - path: Option<&str>, - client: &ClientOptions, - retry: &RetryConfig, -) -> crate::Result> { - let file = match ApplicationDefaultCredentialsFile::read(path)? { - Some(x) => x, - None => return Ok(None), - }; - - match file.type_.as_str() { - // - "authorized_user" => { - let token = AuthorizedUserCredentials { - client_id: file.client_id, - client_secret: file.client_secret, - refresh_token: file.refresh_token, - }; - - Ok(Some(Arc::new(TokenCredentialProvider::new( - token, - client.client()?, - retry.clone(), - )))) - } - type_ => Err(UnsupportedCredentialsType { - type_: type_.to_string(), - } - .into()), - } -} - /// A deserialized `application_default_credentials.json`-file. -/// +/// +/// # References +/// - +/// - #[derive(serde::Deserialize)] -struct ApplicationDefaultCredentialsFile { - #[serde(default)] - client_id: String, - #[serde(default)] - client_secret: String, - #[serde(default)] - refresh_token: String, - #[serde(rename = "type")] - type_: String, +#[serde(tag = "type")] +pub enum ApplicationDefaultCredentials { + /// Service Account. + /// + /// # References + /// - + #[serde(rename = "service_account")] + ServiceAccount(ServiceAccountCredentials), + /// Authorized user via "gcloud CLI Integration". + /// + /// # References + /// - + #[serde(rename = "authorized_user")] + AuthorizedUser(AuthorizedUserCredentials), } -impl ApplicationDefaultCredentialsFile { +impl ApplicationDefaultCredentials { const CREDENTIALS_PATH: &'static str = ".config/gcloud/application_default_credentials.json"; // Create a new application default credential in the following situations: // 1. a file is passed in and the type matches. // 2. without argument if the well-known configuration file is present. - fn read(path: Option<&str>) -> Result, Error> { + pub fn read(path: Option<&str>) -> Result, Error> { if let Some(path) = path { return read_credentials_file::(path).map(Some); } @@ -478,8 +421,8 @@ impl ApplicationDefaultCredentialsFile { const DEFAULT_TOKEN_GCP_URI: &str = "https://accounts.google.com/o/oauth2/token"; /// -#[derive(Debug)] -struct AuthorizedUserCredentials { +#[derive(Debug, Deserialize)] +pub struct AuthorizedUserCredentials { client_id: String, client_secret: String, refresh_token: String, diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs index d4d370373d0d..513e396cbae6 100644 --- a/object_store/src/gcp/mod.rs +++ b/object_store/src/gcp/mod.rs @@ -29,7 +29,6 @@ //! to abort the upload and drop those unneeded parts. In addition, you may wish to //! consider implementing automatic clean up of unused parts that are older than one //! week. 
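// A hedged sketch of the clean-up pattern the gcp module documentation above
// recommends, written against the put_multipart / abort_multipart API used elsewhere
// in this diff; the function and its error handling are illustrative, not from the patch.
use object_store::path::Path;
use object_store::ObjectStore;
use tokio::io::AsyncWriteExt;

async fn upload_or_abort(
    store: &dyn ObjectStore,
    location: &Path,
    data: &[u8],
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let (multipart_id, mut writer) = store.put_multipart(location).await?;

    let write = async {
        writer.write_all(data).await?;
        writer.shutdown().await
    }
    .await;

    if let Err(e) = write {
        // Abort so the store does not retain the already-uploaded, now unneeded parts
        store.abort_multipart(location, &multipart_id).await?;
        return Err(e.into());
    }
    Ok(())
}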
-use std::io; use std::str::FromStr; use std::sync::Arc; @@ -52,16 +51,13 @@ use crate::client::{ TokenCredentialProvider, }; use crate::{ - multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart}, + multipart::{PartId, PutPart, WriteMultiPart}, path::{Path, DELIMITER}, ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result, RetryConfig, }; -use credential::{ - application_default_credentials, default_gcs_base_url, InstanceCredentialProvider, - ServiceAccountCredentials, -}; +use credential::{InstanceCredentialProvider, ServiceAccountCredentials}; mod credential; @@ -69,6 +65,7 @@ const STORE: &str = "GCS"; /// [`CredentialProvider`] for [`GoogleCloudStorage`] pub type GcpCredentialProvider = Arc>; +use crate::gcp::credential::{ApplicationDefaultCredentials, DEFAULT_GCS_BASE_URL}; pub use credential::GcpCredential; #[derive(Debug, Snafu)] @@ -117,6 +114,15 @@ enum Error { #[snafu(display("Error getting put response body: {}", source))] PutResponseBody { source: reqwest::Error }, + #[snafu(display("Got invalid put response: {}", source))] + InvalidPutResponse { source: quick_xml::de::DeError }, + + #[snafu(display("Error performing post request {}: {}", path, source))] + PostRequest { + source: crate::client::retry::Error, + path: String, + }, + #[snafu(display("Error decoding object size: {}", source))] InvalidSize { source: std::num::ParseIntError }, @@ -148,6 +154,12 @@ enum Error { #[snafu(display("Configuration key: '{}' is not known.", key))] UnknownConfigurationKey { key: String }, + + #[snafu(display("ETag Header missing from response"))] + MissingEtag, + + #[snafu(display("Received header containing non-ASCII data"))] + BadHeader { source: header::ToStrError }, } impl From for super::Error { @@ -283,14 +295,9 @@ impl GoogleCloudStorageClient { })?; let data = response.bytes().await.context(PutResponseBodySnafu)?; - let result: InitiateMultipartUploadResult = quick_xml::de::from_reader( - data.as_ref().reader(), - ) - .context(InvalidXMLResponseSnafu { - method: "POST".to_string(), - url, - data, - })?; + let result: InitiateMultipartUploadResult = + quick_xml::de::from_reader(data.as_ref().reader()) + .context(InvalidPutResponseSnafu)?; Ok(result.upload_id) } @@ -380,16 +387,11 @@ impl GetClient for GoogleCloudStorageClient { const STORE: &'static str = STORE; /// Perform a get request - async fn get_request( - &self, - path: &Path, - options: GetOptions, - head: bool, - ) -> Result { + async fn get_request(&self, path: &Path, options: GetOptions) -> Result { let credential = self.get_credential().await?; let url = self.object_url(path); - let method = match head { + let method = match options.head { true => Method::HEAD, false => Method::GET, }; @@ -472,24 +474,16 @@ struct GCSMultipartUpload { } #[async_trait] -impl CloudMultiPartUploadImpl for GCSMultipartUpload { +impl PutPart for GCSMultipartUpload { /// Upload an object part - async fn put_multipart_part( - &self, - buf: Vec, - part_idx: usize, - ) -> Result { + async fn put_part(&self, buf: Vec, part_idx: usize) -> Result { let upload_id = self.multipart_id.clone(); let url = format!( "{}/{}/{}", self.client.base_url, self.client.bucket_name_encoded, self.encoded_path ); - let credential = self - .client - .get_credential() - .await - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + let credential = self.client.get_credential().await?; let response = self .client @@ -504,26 +498,24 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload { 
.header(header::CONTENT_LENGTH, format!("{}", buf.len())) .body(buf) .send_retry(&self.client.retry_config) - .await?; + .await + .context(PutRequestSnafu { + path: &self.encoded_path, + })?; let content_id = response .headers() .get("ETag") - .ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidData, - "response headers missing ETag", - ) - })? + .context(MissingEtagSnafu)? .to_str() - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))? + .context(BadHeaderSnafu)? .to_string(); - Ok(UploadPart { content_id }) + Ok(PartId { content_id }) } /// Complete a multipart upload - async fn complete(&self, completed_parts: Vec) -> Result<(), io::Error> { + async fn complete(&self, completed_parts: Vec) -> Result<()> { let upload_id = self.multipart_id.clone(); let url = format!( "{}/{}/{}", @@ -539,16 +531,11 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload { }) .collect(); - let credential = self - .client - .get_credential() - .await - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; - + let credential = self.client.get_credential().await?; let upload_info = CompleteMultipartUpload { parts }; let data = quick_xml::se::to_string(&upload_info) - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))? + .context(InvalidPutResponseSnafu)? // We cannot disable the escaping that transforms "/" to ""e;" :( // https://github.com/tafia/quick-xml/issues/362 // https://github.com/tafia/quick-xml/issues/350 @@ -561,7 +548,10 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload { .query(&[("uploadId", upload_id)]) .body(data) .send_retry(&self.client.retry_config) - .await?; + .await + .context(PostRequestSnafu { + path: &self.encoded_path, + })?; Ok(()) } @@ -588,7 +578,7 @@ impl ObjectStore for GoogleCloudStorage { multipart_id: upload_id.clone(), }; - Ok((upload_id, Box::new(CloudMultiPartUpload::new(inner, 8)))) + Ok((upload_id, Box::new(WriteMultiPart::new(inner, 8)))) } async fn abort_multipart( @@ -607,19 +597,12 @@ impl ObjectStore for GoogleCloudStorage { self.client.get_opts(location, options).await } - async fn head(&self, location: &Path) -> Result { - self.client.head(location).await - } - async fn delete(&self, location: &Path) -> Result<()> { self.client.delete_request(location).await } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { - self.client.list(prefix).await + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { + self.client.list(prefix) } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { @@ -995,6 +978,23 @@ impl GoogleCloudStorageBuilder { self } + /// Set a trusted proxy CA certificate + pub fn with_proxy_ca_certificate( + mut self, + proxy_ca_certificate: impl Into, + ) -> Self { + self.client_options = self + .client_options + .with_proxy_ca_certificate(proxy_ca_certificate); + self + } + + /// Set a list of hosts to exclude from proxy connections + pub fn with_proxy_excludes(mut self, proxy_excludes: impl Into) -> Self { + self.client_options = self.client_options.with_proxy_excludes(proxy_excludes); + self + } + /// Sets the client options, overriding any already set pub fn with_client_options(mut self, options: ClientOptions) -> Self { self.client_options = options; @@ -1029,10 +1029,8 @@ impl GoogleCloudStorageBuilder { }; // Then try to initialize from the application credentials file, or the environment. 
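// A sketch of configuring the builder whose credential resolution the build()
// function here reworks (explicit service account, then application default
// credentials, then the instance metadata server); bucket name and key path are
// placeholders, not values from the patch.
use object_store::gcp::{GoogleCloudStorage, GoogleCloudStorageBuilder};

fn build_gcs() -> object_store::Result<GoogleCloudStorage> {
    GoogleCloudStorageBuilder::new()
        .with_bucket_name("my-bucket")
        // If omitted, the builder falls back to application default credentials
        // and finally to the instance metadata server
        .with_service_account_path("/path/to/service-account.json")
        .build()
}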
- let application_default_credentials = application_default_credentials( + let application_default_credentials = ApplicationDefaultCredentials::read( self.application_credentials_path.as_deref(), - &self.client_options, - &self.retry_config, )?; let disable_oauth = service_account_credentials @@ -1040,14 +1038,10 @@ impl GoogleCloudStorageBuilder { .map(|c| c.disable_oauth) .unwrap_or(false); - let gcs_base_url = service_account_credentials + let gcs_base_url: String = service_account_credentials .as_ref() - .map(|c| c.gcs_base_url.clone()) - .unwrap_or_else(default_gcs_base_url); - - // TODO: https://cloud.google.com/storage/docs/authentication#oauth-scopes - let scope = "https://www.googleapis.com/auth/devstorage.full_control"; - let audience = "https://www.googleapis.com/oauth2/v4/token"; + .and_then(|c| c.gcs_base_url.clone()) + .unwrap_or_else(|| DEFAULT_GCS_BASE_URL.to_string()); let credentials = if let Some(credentials) = self.credentials { credentials @@ -1057,16 +1051,31 @@ impl GoogleCloudStorageBuilder { })) as _ } else if let Some(credentials) = service_account_credentials { Arc::new(TokenCredentialProvider::new( - credentials.oauth_provider(scope, audience)?, + credentials.token_provider()?, self.client_options.client()?, self.retry_config.clone(), )) as _ } else if let Some(credentials) = application_default_credentials { - credentials + match credentials { + ApplicationDefaultCredentials::AuthorizedUser(token) => { + Arc::new(TokenCredentialProvider::new( + token, + self.client_options.client()?, + self.retry_config.clone(), + )) as _ + } + ApplicationDefaultCredentials::ServiceAccount(token) => { + Arc::new(TokenCredentialProvider::new( + token.token_provider()?, + self.client_options.client()?, + self.retry_config.clone(), + )) as _ + } + } } else { Arc::new(TokenCredentialProvider::new( - InstanceCredentialProvider::new(audience), - self.client_options.clone().with_allow_http(true).client()?, + InstanceCredentialProvider::default(), + self.client_options.metadata_client()?, self.retry_config.clone(), )) as _ }; @@ -1093,7 +1102,6 @@ impl GoogleCloudStorageBuilder { mod test { use bytes::Bytes; use std::collections::HashMap; - use std::env; use std::io::Write; use tempfile::NamedTempFile; @@ -1101,65 +1109,19 @@ mod test { use super::*; - const FAKE_KEY: &str = r#"{"private_key": "private_key", "client_email":"client_email", "disable_oauth":true}"#; + const FAKE_KEY: &str = r#"{"private_key": "private_key", "private_key_id": "private_key_id", "client_email":"client_email", "disable_oauth":true}"#; const NON_EXISTENT_NAME: &str = "nonexistentname"; - // Helper macro to skip tests if TEST_INTEGRATION and the GCP environment variables are not set. - macro_rules! 
maybe_skip_integration { - () => {{ - dotenv::dotenv().ok(); - - let required_vars = ["OBJECT_STORE_BUCKET", "GOOGLE_SERVICE_ACCOUNT"]; - let unset_vars: Vec<_> = required_vars - .iter() - .filter_map(|&name| match env::var(name) { - Ok(_) => None, - Err(_) => Some(name), - }) - .collect(); - let unset_var_names = unset_vars.join(", "); - - let force = std::env::var("TEST_INTEGRATION"); - - if force.is_ok() && !unset_var_names.is_empty() { - panic!( - "TEST_INTEGRATION is set, \ - but variable(s) {} need to be set", - unset_var_names - ) - } else if force.is_err() { - eprintln!( - "skipping Google Cloud integration test - set {}TEST_INTEGRATION to run", - if unset_var_names.is_empty() { - String::new() - } else { - format!("{} and ", unset_var_names) - } - ); - return; - } else { - GoogleCloudStorageBuilder::new() - .with_bucket_name( - env::var("OBJECT_STORE_BUCKET") - .expect("already checked OBJECT_STORE_BUCKET") - ) - .with_service_account_path( - env::var("GOOGLE_SERVICE_ACCOUNT") - .expect("already checked GOOGLE_SERVICE_ACCOUNT") - ) - } - }}; - } - #[tokio::test] async fn gcs_test() { - let integration = maybe_skip_integration!().build().unwrap(); + crate::test_util::maybe_skip_integration!(); + let integration = GoogleCloudStorageBuilder::from_env().build().unwrap(); put_get_delete_list(&integration).await; list_uses_directories_correctly(&integration).await; list_with_delimiter(&integration).await; rename_and_copy(&integration).await; - if integration.client.base_url == default_gcs_base_url() { + if integration.client.base_url == DEFAULT_GCS_BASE_URL { // Fake GCS server doesn't currently honor ifGenerationMatch // https://github.com/fsouza/fake-gcs-server/issues/994 copy_if_not_exists(&integration).await; @@ -1173,7 +1135,8 @@ mod test { #[tokio::test] async fn gcs_test_get_nonexistent_location() { - let integration = maybe_skip_integration!().build().unwrap(); + crate::test_util::maybe_skip_integration!(); + let integration = GoogleCloudStorageBuilder::from_env().build().unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); @@ -1187,10 +1150,9 @@ mod test { #[tokio::test] async fn gcs_test_get_nonexistent_bucket() { - let integration = maybe_skip_integration!() - .with_bucket_name(NON_EXISTENT_NAME) - .build() - .unwrap(); + crate::test_util::maybe_skip_integration!(); + let config = GoogleCloudStorageBuilder::from_env(); + let integration = config.with_bucket_name(NON_EXISTENT_NAME).build().unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); @@ -1206,7 +1168,8 @@ mod test { #[tokio::test] async fn gcs_test_delete_nonexistent_location() { - let integration = maybe_skip_integration!().build().unwrap(); + crate::test_util::maybe_skip_integration!(); + let integration = GoogleCloudStorageBuilder::from_env().build().unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); @@ -1219,10 +1182,9 @@ mod test { #[tokio::test] async fn gcs_test_delete_nonexistent_bucket() { - let integration = maybe_skip_integration!() - .with_bucket_name(NON_EXISTENT_NAME) - .build() - .unwrap(); + crate::test_util::maybe_skip_integration!(); + let config = GoogleCloudStorageBuilder::from_env(); + let integration = config.with_bucket_name(NON_EXISTENT_NAME).build().unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); @@ -1235,10 +1197,9 @@ mod test { #[tokio::test] async fn gcs_test_put_nonexistent_bucket() { - let integration = maybe_skip_integration!() - .with_bucket_name(NON_EXISTENT_NAME) - .build() - .unwrap(); + crate::test_util::maybe_skip_integration!(); + let 
config = GoogleCloudStorageBuilder::from_env(); + let integration = config.with_bucket_name(NON_EXISTENT_NAME).build().unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); let data = Bytes::from("arbitrary data"); @@ -1249,7 +1210,7 @@ mod test { .unwrap_err() .to_string(); assert!( - err.contains("HTTP status client error (404 Not Found)"), + err.contains("Client error with status 404 Not Found"), "{}", err ) diff --git a/object_store/src/http/client.rs b/object_store/src/http/client.rs index 1d3df34db9d1..b2a6ac0aa34a 100644 --- a/object_store/src/http/client.rs +++ b/object_store/src/http/client.rs @@ -15,11 +15,14 @@ // specific language governing permissions and limitations // under the License. +use crate::client::get::GetClient; +use crate::client::header::HeaderConfig; use crate::client::retry::{self, RetryConfig, RetryExt}; use crate::client::GetOptionsExt; use crate::path::{Path, DELIMITER}; use crate::util::deserialize_rfc1123; use crate::{ClientOptions, GetOptions, ObjectMeta, Result}; +use async_trait::async_trait; use bytes::{Buf, Bytes}; use chrono::{DateTime, Utc}; use percent_encoding::percent_decode_str; @@ -37,6 +40,9 @@ enum Error { #[snafu(display("Request error: {}", source))] Reqwest { source: reqwest::Error }, + #[snafu(display("Range request not supported by {}", href))] + RangeNotSupported { href: String }, + #[snafu(display("Error decoding PROPFIND response: {}", source))] InvalidPropFind { source: quick_xml::de::DeError }, @@ -235,11 +241,63 @@ impl Client { Ok(()) } - pub async fn get(&self, location: &Path, options: GetOptions) -> Result { - let url = self.path_url(location); - let builder = self.client.get(url); + pub async fn copy(&self, from: &Path, to: &Path, overwrite: bool) -> Result<()> { + let mut retry = false; + loop { + let method = Method::from_bytes(b"COPY").unwrap(); + + let mut builder = self + .client + .request(method, self.path_url(from)) + .header("Destination", self.path_url(to).as_str()); + + if !overwrite { + builder = builder.header("Overwrite", "F"); + } - builder + return match builder.send_retry(&self.retry_config).await { + Ok(_) => Ok(()), + Err(source) => Err(match source.status() { + Some(StatusCode::PRECONDITION_FAILED) if !overwrite => { + crate::Error::AlreadyExists { + path: to.to_string(), + source: Box::new(source), + } + } + // Some implementations return 404 instead of 409 + Some(StatusCode::CONFLICT | StatusCode::NOT_FOUND) if !retry => { + retry = true; + self.create_parent_directories(to).await?; + continue; + } + _ => Error::Request { source }.into(), + }), + }; + } + } +} + +#[async_trait] +impl GetClient for Client { + const STORE: &'static str = "HTTP"; + + /// Override the [`HeaderConfig`] to be less strict to support a + /// broader range of HTTP servers (#4831) + const HEADER_CONFIG: HeaderConfig = HeaderConfig { + etag_required: false, + last_modified_required: false, + }; + + async fn get_request(&self, path: &Path, options: GetOptions) -> Result { + let url = self.path_url(path); + let method = match options.head { + true => Method::HEAD, + false => Method::GET, + }; + let has_range = options.range.is_some(); + let builder = self.client.request(method, url); + + let res = builder .with_get_options(options) .send_retry(&self.retry_config) .await @@ -248,40 +306,23 @@ impl Client { Some(StatusCode::NOT_FOUND | StatusCode::METHOD_NOT_ALLOWED) => { crate::Error::NotFound { source: Box::new(source), - path: location.to_string(), + path: path.to_string(), } } _ => Error::Request { source }.into(), - }) - } 
- - pub async fn copy(&self, from: &Path, to: &Path, overwrite: bool) -> Result<()> { - let from = self.path_url(from); - let to = self.path_url(to); - let method = Method::from_bytes(b"COPY").unwrap(); - - let mut builder = self - .client - .request(method, from) - .header("Destination", to.as_str()); + })?; - if !overwrite { - builder = builder.header("Overwrite", "F"); + // We expect a 206 Partial Content response if a range was requested + // a 200 OK response would indicate the server did not fulfill the request + if has_range && res.status() != StatusCode::PARTIAL_CONTENT { + return Err(crate::Error::NotSupported { + source: Box::new(Error::RangeNotSupported { + href: path.to_string(), + }), + }); } - match builder.send_retry(&self.retry_config).await { - Ok(_) => Ok(()), - Err(e) - if !overwrite - && matches!(e.status(), Some(StatusCode::PRECONDITION_FAILED)) => - { - Err(crate::Error::AlreadyExists { - path: to.to_string(), - source: Box::new(e), - }) - } - Err(source) => Err(Error::Request { source }.into()), - } + Ok(res) } } diff --git a/object_store/src/http/mod.rs b/object_store/src/http/mod.rs index 124b7da2f7e7..2fd7850b6bbf 100644 --- a/object_store/src/http/mod.rs +++ b/object_store/src/http/mod.rs @@ -17,7 +17,7 @@ //! An object store implementation for generic HTTP servers //! -//! This follows [rfc2518] commonly known called [WebDAV] +//! This follows [rfc2518] commonly known as [WebDAV] //! //! Basic get support will work out of the box with most HTTP servers, //! even those that don't explicitly support [rfc2518] @@ -40,11 +40,12 @@ use snafu::{OptionExt, ResultExt, Snafu}; use tokio::io::AsyncWrite; use url::Url; +use crate::client::get::GetClientExt; use crate::http::client::Client; use crate::path::Path; use crate::{ - ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, - ObjectStore, Result, RetryConfig, + ClientConfigKey, ClientOptions, GetOptions, GetResult, ListResult, MultipartId, + ObjectMeta, ObjectStore, Result, RetryConfig, }; mod client; @@ -60,6 +61,11 @@ enum Error { url: String, }, + #[snafu(display("Unable to extract metadata from headers: {}", source))] + Metadata { + source: crate::client::header::Error, + }, + #[snafu(display("Request error: {}", source))] Reqwest { source: reqwest::Error }, } @@ -109,48 +115,20 @@ impl ObjectStore for HttpStore { } async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { - let response = self.client.get(location, options).await?; - let stream = response - .bytes_stream() - .map_err(|source| Error::Reqwest { source }.into()) - .boxed(); - - Ok(GetResult::Stream(stream)) - } - - async fn head(&self, location: &Path) -> Result { - let status = self.client.list(Some(location), "0").await?; - match status.response.len() { - 1 => { - let response = status.response.into_iter().next().unwrap(); - response.check_ok()?; - match response.is_dir() { - true => Err(crate::Error::NotFound { - path: location.to_string(), - source: "Is directory".to_string().into(), - }), - false => response.object_meta(self.client.base_url()), - } - } - x => Err(crate::Error::NotFound { - path: location.to_string(), - source: format!("Expected 1 result, got {x}").into(), - }), - } + self.client.get_opts(location, options).await } async fn delete(&self, location: &Path) -> Result<()> { self.client.delete(location).await } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { let prefix_len = prefix.map(|p| 
p.as_ref().len()).unwrap_or_default(); - let status = self.client.list(prefix, "infinity").await?; - Ok(futures::stream::iter( - status + let prefix = prefix.cloned(); + futures::stream::once(async move { + let status = self.client.list(prefix.as_ref(), "infinity").await?; + + let iter = status .response .into_iter() .filter(|r| !r.is_dir()) @@ -159,9 +137,12 @@ impl ObjectStore for HttpStore { response.object_meta(self.client.base_url()) }) // Filter out exact prefix matches - .filter_ok(move |r| r.location.as_ref().len() > prefix_len), - ) - .boxed()) + .filter_ok(move |r| r.location.as_ref().len() > prefix_len); + + Ok::<_, crate::Error>(futures::stream::iter(iter)) + }) + .try_flatten() + .boxed() } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { @@ -231,6 +212,12 @@ impl HttpBuilder { self } + /// Set individual client configuration without overriding the entire config + pub fn with_config(mut self, key: ClientConfigKey, value: impl Into) -> Self { + self.client_options = self.client_options.with_config(key, value); + self + } + /// Sets the client options, overriding any already set pub fn with_client_options(mut self, options: ClientOptions) -> Self { self.client_options = options; @@ -256,12 +243,7 @@ mod tests { #[tokio::test] async fn http_test() { - dotenv::dotenv().ok(); - let force = std::env::var("TEST_INTEGRATION"); - if force.is_err() { - eprintln!("skipping HTTP integration test - set TEST_INTEGRATION to run"); - return; - } + crate::test_util::maybe_skip_integration!(); let url = std::env::var("HTTP_URL").expect("HTTP_URL must be set"); let options = ClientOptions::new().with_allow_http(true); let integration = HttpBuilder::new() diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs index 864cabc4a8c0..c5f014b9ea56 100644 --- a/object_store/src/lib.rs +++ b/object_store/src/lib.rs @@ -28,29 +28,56 @@ //! # object_store //! -//! This crate provides a uniform API for interacting with object storage services and -//! local files via the the [`ObjectStore`] trait. +//! This crate provides a uniform API for interacting with object +//! storage services and local files via the [`ObjectStore`] +//! trait. //! -//! # Create an [`ObjectStore`] implementation: +//! Using this crate, the same binary and code can run in multiple +//! clouds and local test environments, via a simple runtime +//! configuration change. +//! +//! # Highlights +//! +//! 1. A focused, easy to use, idiomatic, well documented, high +//! performance, `async` API. +//! +//! 2. Production quality, leading this crate to be used in large +//! scale production systems, such as [crates.io] and [InfluxDB IOx]. +//! +//! 3. Stable and predictable governance via the [Apache Arrow] project. +//! +//! Originally developed for [InfluxDB IOx] and subsequently donated +//! to [Apache Arrow]. +//! +//! [Apache Arrow]: https://arrow.apache.org/ +//! [InfluxDB IOx]: https://github.com/influxdata/influxdb_iox/ +//! [crates.io]: https://github.com/rust-lang/crates.io +//! +//! # Available [`ObjectStore`] Implementations +//! +//! By default, this crate provides the following implementations: +//! +//! * Memory: [`InMemory`](memory::InMemory) +//! * Local filesystem: [`LocalFileSystem`](local::LocalFileSystem) +//! +//! Feature flags are used to enable support for other implementations: //! 
#![cfg_attr( feature = "gcp", - doc = "* [Google Cloud Storage](https://cloud.google.com/storage/): [`GoogleCloudStorageBuilder`](gcp::GoogleCloudStorageBuilder)" + doc = "* `gcp`: [Google Cloud Storage](https://cloud.google.com/storage/) support. See [`GoogleCloudStorageBuilder`](gcp::GoogleCloudStorageBuilder)" )] #![cfg_attr( feature = "aws", - doc = "* [Amazon S3](https://aws.amazon.com/s3/): [`AmazonS3Builder`](aws::AmazonS3Builder)" + doc = "* `aws`: [Amazon S3](https://aws.amazon.com/s3/). See [`AmazonS3Builder`](aws::AmazonS3Builder)" )] #![cfg_attr( feature = "azure", - doc = "* [Azure Blob Storage](https://azure.microsoft.com/en-gb/services/storage/blobs/): [`MicrosoftAzureBuilder`](azure::MicrosoftAzureBuilder)" + doc = "* `azure`: [Azure Blob Storage](https://azure.microsoft.com/en-gb/services/storage/blobs/). See [`MicrosoftAzureBuilder`](azure::MicrosoftAzureBuilder)" )] #![cfg_attr( feature = "http", - doc = "* [HTTP Storage](https://datatracker.ietf.org/doc/html/rfc2518): [`HttpBuilder`](http::HttpBuilder)" + doc = "* `http`: [HTTP/WebDAV Storage](https://datatracker.ietf.org/doc/html/rfc2518). See [`HttpBuilder`](http::HttpBuilder)" )] -//! * In Memory: [`InMemory`](memory::InMemory) -//! * Local filesystem: [`LocalFileSystem`](local::LocalFileSystem) //! //! # Adapters //! @@ -68,18 +95,18 @@ //! //! ``` //! # use object_store::local::LocalFileSystem; +//! # use std::sync::Arc; +//! # use object_store::{path::Path, ObjectStore}; +//! # use futures::stream::StreamExt; //! # // use LocalFileSystem for example -//! # fn get_object_store() -> LocalFileSystem { -//! # LocalFileSystem::new_with_prefix("/tmp").unwrap() +//! # fn get_object_store() -> Arc { +//! # Arc::new(LocalFileSystem::new()) //! # } -//! +//! # //! # async fn example() { -//! use std::sync::Arc; -//! use object_store::{path::Path, ObjectStore}; -//! use futures::stream::StreamExt; -//! +//! # //! // create an ObjectStore -//! let object_store: Arc = Arc::new(get_object_store()); +//! let object_store: Arc = get_object_store(); //! //! // Recursively list all files below the 'data' path. //! // 1. On AWS S3 this would be the 'data/' prefix @@ -87,21 +114,12 @@ //! let prefix: Path = "data".try_into().unwrap(); //! //! // Get an `async` stream of Metadata objects: -//! let list_stream = object_store -//! .list(Some(&prefix)) -//! .await -//! .expect("Error listing files"); +//! let mut list_stream = object_store.list(Some(&prefix)); //! -//! // Print a line about each object based on its metadata -//! // using for_each from `StreamExt` trait. -//! list_stream -//! .for_each(move |meta| { -//! async { -//! let meta = meta.expect("Error listing"); -//! println!("Name: {}, size: {}", meta.location, meta.size); -//! } -//! }) -//! .await; +//! // Print a line about each object +//! while let Some(meta) = list_stream.next().await.transpose().unwrap() { +//! println!("Name: {}, size: {}", meta.location, meta.size); +//! } //! # } //! ``` //! @@ -120,19 +138,18 @@ //! from remote storage or files in the local filesystem as a stream. //! //! ``` +//! # use futures::TryStreamExt; //! # use object_store::local::LocalFileSystem; -//! # // use LocalFileSystem for example -//! # fn get_object_store() -> LocalFileSystem { -//! # LocalFileSystem::new_with_prefix("/tmp").unwrap() +//! # use std::sync::Arc; +//! # use object_store::{path::Path, ObjectStore}; +//! # fn get_object_store() -> Arc { +//! # Arc::new(LocalFileSystem::new()) //! # } -//! +//! # //! # async fn example() { -//! use std::sync::Arc; -//! 
use object_store::{path::Path, ObjectStore}; -//! use futures::stream::StreamExt; -//! +//! # //! // create an ObjectStore -//! let object_store: Arc = Arc::new(get_object_store()); +//! let object_store: Arc = get_object_store(); //! //! // Retrieve a specific file //! let path: Path = "data/file01.parquet".try_into().unwrap(); @@ -144,16 +161,11 @@ //! .unwrap() //! .into_stream(); //! -//! // Count the '0's using `map` from `StreamExt` trait +//! // Count the '0's using `try_fold` from `TryStreamExt` trait //! let num_zeros = stream -//! .map(|bytes| { -//! let bytes = bytes.unwrap(); -//! bytes.iter().filter(|b| **b == 0).count() -//! }) -//! .collect::>() -//! .await -//! .into_iter() -//! .sum::(); +//! .try_fold(0, |acc, bytes| async move { +//! Ok(acc + bytes.iter().filter(|b| **b == 0).count()) +//! }).await.unwrap(); //! //! println!("Num zeros in {} is {}", path, num_zeros); //! # } @@ -169,22 +181,19 @@ //! //! ``` //! # use object_store::local::LocalFileSystem; -//! # fn get_object_store() -> LocalFileSystem { -//! # LocalFileSystem::new_with_prefix("/tmp").unwrap() +//! # use object_store::ObjectStore; +//! # use std::sync::Arc; +//! # use bytes::Bytes; +//! # use object_store::path::Path; +//! # fn get_object_store() -> Arc { +//! # Arc::new(LocalFileSystem::new()) //! # } //! # async fn put() { -//! use object_store::ObjectStore; -//! use std::sync::Arc; -//! use bytes::Bytes; -//! use object_store::path::Path; -//! -//! let object_store: Arc = Arc::new(get_object_store()); +//! # +//! let object_store: Arc = get_object_store(); //! let path: Path = "data/file1".try_into().unwrap(); //! let bytes = Bytes::from_static(b"hello"); -//! object_store -//! .put(&path, bytes) -//! .await -//! .unwrap(); +//! object_store.put(&path, bytes).await.unwrap(); //! # } //! ``` //! @@ -193,22 +202,20 @@ //! //! ``` //! # use object_store::local::LocalFileSystem; -//! # fn get_object_store() -> LocalFileSystem { -//! # LocalFileSystem::new_with_prefix("/tmp").unwrap() +//! # use object_store::ObjectStore; +//! # use std::sync::Arc; +//! # use bytes::Bytes; +//! # use tokio::io::AsyncWriteExt; +//! # use object_store::path::Path; +//! # fn get_object_store() -> Arc { +//! # Arc::new(LocalFileSystem::new()) //! # } //! # async fn multi_upload() { -//! use object_store::ObjectStore; -//! use std::sync::Arc; -//! use bytes::Bytes; -//! use tokio::io::AsyncWriteExt; -//! use object_store::path::Path; -//! -//! let object_store: Arc = Arc::new(get_object_store()); +//! # +//! let object_store: Arc = get_object_store(); //! let path: Path = "data/large_file".try_into().unwrap(); -//! let (_id, mut writer) = object_store -//! .put_multipart(&path) -//! .await -//! .unwrap(); +//! let (_id, mut writer) = object_store.put_multipart(&path).await.unwrap(); +//! //! let bytes = Bytes::from_static(b"hello"); //! writer.write_all(&bytes).await.unwrap(); //! 
writer.flush().await.unwrap(); @@ -218,14 +225,15 @@ #[cfg(all( target_arch = "wasm32", - any(feature = "gcp", feature = "aws", feature = "azure",) + any(feature = "gcp", feature = "aws", feature = "azure", feature = "http") ))] -compile_error!("Features 'gcp', 'aws', 'azure' are not supported on wasm."); +compile_error!("Features 'gcp', 'aws', 'azure', 'http' are not supported on wasm."); #[cfg(feature = "aws")] pub mod aws; #[cfg(feature = "azure")] pub mod azure; +pub mod buffered; #[cfg(not(target_arch = "wasm32"))] pub mod chunked; pub mod delimited; @@ -239,19 +247,26 @@ pub mod local; pub mod memory; pub mod path; pub mod prefix; +#[cfg(feature = "cloud")] +pub mod signer; pub mod throttle; -#[cfg(any(feature = "gcp", feature = "aws", feature = "azure", feature = "http"))] +#[cfg(feature = "cloud")] mod client; #[cfg(any(feature = "gcp", feature = "aws", feature = "azure", feature = "http"))] -pub use client::{backoff::BackoffConfig, retry::RetryConfig, CredentialProvider}; +pub use client::{ + backoff::BackoffConfig, retry::RetryConfig, ClientConfigKey, ClientOptions, + CredentialProvider, StaticCredentialProvider, +}; -#[cfg(any(feature = "gcp", feature = "aws", feature = "azure", feature = "http"))] +use std::collections::HashMap; + +#[cfg(feature = "cloud")] mod config; -#[cfg(any(feature = "azure", feature = "aws", feature = "gcp"))] -mod multipart; +#[cfg(feature = "cloud")] +pub mod multipart; mod parse; mod util; @@ -260,7 +275,7 @@ pub use parse::{parse_url, parse_url_opts}; use crate::path::Path; #[cfg(not(target_arch = "wasm32"))] use crate::util::maybe_spawn_blocking; -use crate::util::{coalesce_ranges, collect_bytes, OBJECT_STORE_COALESCE_DEFAULT}; +pub use crate::util::{coalesce_ranges, collect_bytes, OBJECT_STORE_COALESCE_DEFAULT}; use async_trait::async_trait; use bytes::Bytes; use chrono::{DateTime, Utc}; @@ -270,11 +285,9 @@ use std::fmt::{Debug, Formatter}; #[cfg(not(target_arch = "wasm32"))] use std::io::{Read, Seek, SeekFrom}; use std::ops::Range; +use std::sync::Arc; use tokio::io::AsyncWrite; -#[cfg(any(feature = "azure", feature = "aws", feature = "gcp", feature = "http"))] -pub use client::ClientOptions; - /// An alias for a dynamically dispatched object store implementation. pub type DynObjectStore = dyn ObjectStore; @@ -291,6 +304,25 @@ pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static { /// should be able to observe a partially written object async fn put(&self, location: &Path, bytes: Bytes) -> Result<()>; + /// Save the provided bytes to the specified location + /// + /// The operation is guaranteed to be atomic, it will either successfully + /// write the entirety of `bytes` to `location`, or fail. No clients + /// should be able to observe a partially written object + /// + /// If the specified `options` include key-value metadata, this will be stored + /// along with the object depending on the capabilities of the underlying implementation. + /// + /// For example, when using an AWS S3 `ObjectStore` the `tags` will be saved as object tags in S3 + async fn put_opts( + &self, + location: &Path, + bytes: Bytes, + _options: PutOptions, + ) -> Result<()> { + self.put(location, bytes).await + } + /// Get a multi-part upload that allows writing data in chunks /// /// Most cloud-based uploads will buffer and upload parts in parallel. 
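The new `put_opts` entry point above defaults to plain `put`, so existing stores keep working while implementations that understand tagging (such as S3) may persist the supplied key/value metadata. A hedged usage sketch, assuming the `PutOptions`/`put_opts` API lands exactly as written in this diff and that `tokio` and the `bytes` crate are available:

```rust
use std::collections::HashMap;
use std::sync::Arc;

use bytes::Bytes;
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::{ObjectStore, PutOptions};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Any ObjectStore works; InMemory keeps the example self-contained.
    let store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
    let path = Path::from("data/tagged_file");

    // Key/value metadata: stores that support tagging may persist it,
    // while the default trait implementation simply falls back to `put`.
    let mut tags = HashMap::new();
    tags.insert("team".to_string(), "analytics".to_string());

    store
        .put_opts(&path, Bytes::from_static(b"hello"), PutOptions { tags })
        .await?;
    Ok(())
}
```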
@@ -351,8 +383,6 @@ pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static { } /// Perform a get request with options - /// - /// Note: options.range will be ignored if [`GetResult::File`] async fn get_opts(&self, location: &Path, options: GetOptions) -> Result; /// Return the bytes that are stored at the specified location @@ -362,17 +392,7 @@ pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static { range: Some(range.clone()), ..Default::default() }; - // Temporary until GetResult::File supports range (#4352) - match self.get_opts(location, options).await? { - GetResult::Stream(s) => collect_bytes(s, None).await, - #[cfg(not(target_arch = "wasm32"))] - GetResult::File(mut file, path) => { - maybe_spawn_blocking(move || local::read_range(&mut file, &path, range)) - .await - } - #[cfg(target_arch = "wasm32")] - _ => unimplemented!("File IO not implemented on wasm32."), - } + self.get_opts(location, options).await?.bytes().await } /// Return the bytes that are stored at the specified location @@ -391,7 +411,13 @@ pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static { } /// Return the metadata for the specified location - async fn head(&self, location: &Path) -> Result; + async fn head(&self, location: &Path) -> Result { + let options = GetOptions { + head: true, + ..Default::default() + }; + Ok(self.get_opts(location, options).await?.meta) + } /// Delete the object at the specified location. async fn delete(&self, location: &Path) -> Result<()>; @@ -414,23 +440,22 @@ pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static { /// return Ok. If it is an error, it will be [`Error::NotFound`]. /// /// ``` + /// # use futures::{StreamExt, TryStreamExt}; /// # use object_store::local::LocalFileSystem; /// # async fn example() -> Result<(), Box> { /// # let root = tempfile::TempDir::new().unwrap(); /// # let store = LocalFileSystem::new_with_prefix(root.path()).unwrap(); - /// use object_store::{ObjectStore, ObjectMeta}; - /// use object_store::path::Path; - /// use futures::{StreamExt, TryStreamExt}; - /// use bytes::Bytes; - /// + /// # use object_store::{ObjectStore, ObjectMeta}; + /// # use object_store::path::Path; + /// # use futures::{StreamExt, TryStreamExt}; + /// # use bytes::Bytes; + /// # /// // Create two objects /// store.put(&Path::from("foo"), Bytes::from("foo")).await?; /// store.put(&Path::from("bar"), Bytes::from("bar")).await?; /// /// // List object - /// let locations = store.list(None).await? - /// .map(|meta: Result| meta.map(|m| m.location)) - /// .boxed(); + /// let locations = store.list(None).map_ok(|m| m.location).boxed(); /// /// // Delete them /// store.delete_stream(locations).try_collect::>().await?; @@ -459,10 +484,7 @@ pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static { /// `foo/bar_baz/x`. 
/// /// Note: the order of returned [`ObjectMeta`] is not guaranteed - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>>; + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result>; /// List all the objects with the given prefix and a location greater than `offset` /// @@ -470,18 +492,15 @@ pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static { /// the number of network requests required /// /// Note: the order of returned [`ObjectMeta`] is not guaranteed - async fn list_with_offset( + fn list_with_offset( &self, prefix: Option<&Path>, offset: &Path, - ) -> Result>> { + ) -> BoxStream<'_, Result> { let offset = offset.clone(); - let stream = self - .list(prefix) - .await? + self.list(prefix) .try_filter(move |f| futures::future::ready(f.location > offset)) - .boxed(); - Ok(stream) + .boxed() } /// List objects with the given prefix and an implementation specific @@ -526,105 +545,120 @@ pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static { } } -#[async_trait] -impl ObjectStore for Box { - async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { - self.as_ref().put(location, bytes).await - } +macro_rules! as_ref_impl { + ($type:ty) => { + #[async_trait] + impl ObjectStore for $type { + async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { + self.as_ref().put(location, bytes).await + } - async fn put_multipart( - &self, - location: &Path, - ) -> Result<(MultipartId, Box)> { - self.as_ref().put_multipart(location).await - } + async fn put_multipart( + &self, + location: &Path, + ) -> Result<(MultipartId, Box)> { + self.as_ref().put_multipart(location).await + } - async fn abort_multipart( - &self, - location: &Path, - multipart_id: &MultipartId, - ) -> Result<()> { - self.as_ref().abort_multipart(location, multipart_id).await - } + async fn abort_multipart( + &self, + location: &Path, + multipart_id: &MultipartId, + ) -> Result<()> { + self.as_ref().abort_multipart(location, multipart_id).await + } - async fn append( - &self, - location: &Path, - ) -> Result> { - self.as_ref().append(location).await - } + async fn append( + &self, + location: &Path, + ) -> Result> { + self.as_ref().append(location).await + } - async fn get(&self, location: &Path) -> Result { - self.as_ref().get(location).await - } + async fn get(&self, location: &Path) -> Result { + self.as_ref().get(location).await + } - async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { - self.as_ref().get_opts(location, options).await - } + async fn get_opts( + &self, + location: &Path, + options: GetOptions, + ) -> Result { + self.as_ref().get_opts(location, options).await + } - async fn get_range(&self, location: &Path, range: Range) -> Result { - self.as_ref().get_range(location, range).await - } + async fn get_range( + &self, + location: &Path, + range: Range, + ) -> Result { + self.as_ref().get_range(location, range).await + } - async fn get_ranges( - &self, - location: &Path, - ranges: &[Range], - ) -> Result> { - self.as_ref().get_ranges(location, ranges).await - } + async fn get_ranges( + &self, + location: &Path, + ranges: &[Range], + ) -> Result> { + self.as_ref().get_ranges(location, ranges).await + } - async fn head(&self, location: &Path) -> Result { - self.as_ref().head(location).await - } + async fn head(&self, location: &Path) -> Result { + self.as_ref().head(location).await + } - async fn delete(&self, location: &Path) -> Result<()> { - self.as_ref().delete(location).await - } + async fn delete(&self, location: 
&Path) -> Result<()> { + self.as_ref().delete(location).await + } - fn delete_stream<'a>( - &'a self, - locations: BoxStream<'a, Result>, - ) -> BoxStream<'a, Result> { - self.as_ref().delete_stream(locations) - } + fn delete_stream<'a>( + &'a self, + locations: BoxStream<'a, Result>, + ) -> BoxStream<'a, Result> { + self.as_ref().delete_stream(locations) + } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { - self.as_ref().list(prefix).await - } + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { + self.as_ref().list(prefix) + } - async fn list_with_offset( - &self, - prefix: Option<&Path>, - offset: &Path, - ) -> Result>> { - self.as_ref().list_with_offset(prefix, offset).await - } + fn list_with_offset( + &self, + prefix: Option<&Path>, + offset: &Path, + ) -> BoxStream<'_, Result> { + self.as_ref().list_with_offset(prefix, offset) + } - async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { - self.as_ref().list_with_delimiter(prefix).await - } + async fn list_with_delimiter( + &self, + prefix: Option<&Path>, + ) -> Result { + self.as_ref().list_with_delimiter(prefix).await + } - async fn copy(&self, from: &Path, to: &Path) -> Result<()> { - self.as_ref().copy(from, to).await - } + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + self.as_ref().copy(from, to).await + } - async fn rename(&self, from: &Path, to: &Path) -> Result<()> { - self.as_ref().rename(from, to).await - } + async fn rename(&self, from: &Path, to: &Path) -> Result<()> { + self.as_ref().rename(from, to).await + } - async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { - self.as_ref().copy_if_not_exists(from, to).await - } + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + self.as_ref().copy_if_not_exists(from, to).await + } - async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { - self.as_ref().rename_if_not_exists(from, to).await - } + async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + self.as_ref().rename_if_not_exists(from, to).await + } + } + }; } +as_ref_impl!(Arc); +as_ref_impl!(Box); + /// Result of a list call that includes objects, prefixes (directories) and a /// token for the next set of results. Individual result sets may be limited to /// 1,000 objects based on the underlying object storage's limitations. 
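The `as_ref_impl!` macro above replaces the hand-written delegating impl for `Box<dyn ObjectStore>` and additionally covers `Arc<dyn ObjectStore>`, so shared handles satisfy the trait bound directly. A small sketch of what that enables, under the assumption that `tokio` is available and the blanket impls land as shown here:

```rust
use std::sync::Arc;

use bytes::Bytes;
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::{ObjectStore, Result};

/// Generic helper that only knows about the `ObjectStore` trait.
async fn object_size(store: impl ObjectStore, location: &Path) -> Result<usize> {
    Ok(store.head(location).await?.size)
}

#[tokio::main]
async fn main() -> Result<()> {
    let store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
    let path = Path::from("example");
    store.put(&path, Bytes::from_static(b"payload")).await?;

    // Thanks to the delegating impls, the Arc itself satisfies the trait
    // bound; no manual unwrapping or newtype is needed.
    let size = object_size(Arc::clone(&store), &path).await?;
    assert_eq!(size, 7);
    Ok(())
}
```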
@@ -655,12 +689,28 @@ pub struct GetOptions { /// Request will succeed if the `ObjectMeta::e_tag` matches /// otherwise returning [`Error::Precondition`] /// - /// + /// See + /// + /// Examples: + /// + /// ```text + /// If-Match: "xyzzy" + /// If-Match: "xyzzy", "r2d2xxxx", "c3piozzzz" + /// If-Match: * + /// ``` pub if_match: Option, /// Request will succeed if the `ObjectMeta::e_tag` does not match /// otherwise returning [`Error::NotModified`] /// - /// + /// See + /// + /// Examples: + /// + /// ```text + /// If-None-Match: "xyzzy" + /// If-None-Match: "xyzzy", "r2d2xxxx", "c3piozzzz" + /// If-None-Match: * + /// ``` pub if_none_match: Option, /// Request will succeed if the object has been modified since /// @@ -679,29 +729,56 @@ pub struct GetOptions { /// /// pub range: Option>, + /// Request transfer of no content + /// + /// + pub head: bool, +} + +/// Options for a put request, such as tags +#[derive(Debug, Default)] +pub struct PutOptions { + /// Key/Value metadata associated with the object + pub tags: HashMap, } impl GetOptions { /// Returns an error if the modification conditions on this request are not satisfied - fn check_modified( - &self, - location: &Path, - last_modified: DateTime, - ) -> Result<()> { - if let Some(date) = self.if_modified_since { - if last_modified <= date { - return Err(Error::NotModified { - path: location.to_string(), - source: format!("{} >= {}", date, last_modified).into(), + /// + /// + fn check_preconditions(&self, meta: &ObjectMeta) -> Result<()> { + // The use of the invalid etag "*" means no ETag is equivalent to never matching + let etag = meta.e_tag.as_deref().unwrap_or("*"); + let last_modified = meta.last_modified; + + if let Some(m) = &self.if_match { + if m != "*" && m.split(',').map(str::trim).all(|x| x != etag) { + return Err(Error::Precondition { + path: meta.location.to_string(), + source: format!("{etag} does not match {m}").into(), }); } - } - - if let Some(date) = self.if_unmodified_since { + } else if let Some(date) = self.if_unmodified_since { if last_modified > date { return Err(Error::Precondition { - path: location.to_string(), - source: format!("{} < {}", date, last_modified).into(), + path: meta.location.to_string(), + source: format!("{date} < {last_modified}").into(), + }); + } + } + + if let Some(m) = &self.if_none_match { + if m == "*" || m.split(',').map(str::trim).any(|x| x == etag) { + return Err(Error::NotModified { + path: meta.location.to_string(), + source: format!("{etag} matches {m}").into(), + }); + } + } else if let Some(date) = self.if_modified_since { + if last_modified <= date { + return Err(Error::NotModified { + path: meta.location.to_string(), + source: format!("{date} >= {last_modified}").into(), }); } } @@ -710,21 +787,32 @@ impl GetOptions { } /// Result for a get request +#[derive(Debug)] +pub struct GetResult { + /// The [`GetResultPayload`] + pub payload: GetResultPayload, + /// The [`ObjectMeta`] for this object + pub meta: ObjectMeta, + /// The range of bytes returned by this request + pub range: Range, +} + +/// The kind of a [`GetResult`] /// /// This special cases the case of a local file, as some systems may /// be able to optimise the case of a file already present on local disk -pub enum GetResult { - /// A file and its path on the local filesystem +pub enum GetResultPayload { + /// The file, path File(std::fs::File, std::path::PathBuf), - /// An asynchronous stream + /// An opaque stream of bytes Stream(BoxStream<'static, Result>), } -impl Debug for GetResult { +impl Debug for 
GetResultPayload { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - Self::File(_, _) => write!(f, "GetResult(File)"), - Self::Stream(_) => write!(f, "GetResult(Stream)"), + Self::File(_, _) => write!(f, "GetResultPayload(File)"), + Self::Stream(_) => write!(f, "GetResultPayload(Stream)"), } } } @@ -732,32 +820,31 @@ impl Debug for GetResult { impl GetResult { /// Collects the data into a [`Bytes`] pub async fn bytes(self) -> Result { - match self { + let len = self.range.end - self.range.start; + match self.payload { #[cfg(not(target_arch = "wasm32"))] - Self::File(mut file, path) => { + GetResultPayload::File(mut file, path) => { maybe_spawn_blocking(move || { - let len = file.seek(SeekFrom::End(0)).map_err(|source| { - local::Error::Seek { + file.seek(SeekFrom::Start(self.range.start as _)).map_err( + |source| local::Error::Seek { source, path: path.clone(), - } - })?; - - file.rewind().map_err(|source| local::Error::Seek { - source, - path: path.clone(), - })?; + }, + )?; - let mut buffer = Vec::with_capacity(len as usize); - file.read_to_end(&mut buffer).map_err(|source| { - local::Error::UnableToReadBytes { source, path } - })?; + let mut buffer = Vec::with_capacity(len); + file.take(len as _) + .read_to_end(&mut buffer) + .map_err(|source| local::Error::UnableToReadBytes { + source, + path, + })?; Ok(buffer.into()) }) .await } - Self::Stream(s) => collect_bytes(s, None).await, + GetResultPayload::Stream(s) => collect_bytes(s, Some(len)).await, #[cfg(target_arch = "wasm32")] _ => unimplemented!("File IO not implemented on wasm32."), } @@ -765,8 +852,8 @@ impl GetResult { /// Converts this into a byte stream /// - /// If the result is [`Self::File`] will perform chunked reads of the file, otherwise - /// will return the [`Self::Stream`]. + /// If the `self.kind` is [`GetResultPayload::File`] will perform chunked reads of the file, + /// otherwise will return the [`GetResultPayload::Stream`]. /// /// # Tokio Compatibility /// @@ -778,36 +865,13 @@ impl GetResult { /// If not called from a tokio context, this will perform IO on the current thread with /// no additional complexity or overheads pub fn into_stream(self) -> BoxStream<'static, Result> { - match self { + match self.payload { #[cfg(not(target_arch = "wasm32"))] - Self::File(file, path) => { + GetResultPayload::File(file, path) => { const CHUNK_SIZE: usize = 8 * 1024; - - futures::stream::try_unfold( - (file, path, false), - |(mut file, path, finished)| { - maybe_spawn_blocking(move || { - if finished { - return Ok(None); - } - - let mut buffer = Vec::with_capacity(CHUNK_SIZE); - let read = file - .by_ref() - .take(CHUNK_SIZE as u64) - .read_to_end(&mut buffer) - .map_err(|e| local::Error::UnableToReadBytes { - source: e, - path: path.clone(), - })?; - - Ok(Some((buffer.into(), (file, path, read != CHUNK_SIZE)))) - }) - }, - ) - .boxed() + local::chunked_stream(file, path, self.range, CHUNK_SIZE) } - Self::Stream(s) => s, + GetResultPayload::Stream(s) => s, #[cfg(target_arch = "wasm32")] _ => unimplemented!("File IO not implemented on wasm32."), } @@ -891,13 +955,22 @@ mod test_util { use super::*; use futures::TryStreamExt; + macro_rules! maybe_skip_integration { + () => { + if std::env::var("TEST_INTEGRATION").is_err() { + eprintln!("Skipping integration test - set TEST_INTEGRATION"); + return; + } + }; + } + pub(crate) use maybe_skip_integration; + pub async fn flatten_list_stream( storage: &DynObjectStore, prefix: Option<&Path>, ) -> Result> { storage .list(prefix) - .await? 
.map_ok(|meta| meta.location) .try_collect::>() .await @@ -908,8 +981,8 @@ mod test_util { mod tests { use super::*; use crate::test_util::flatten_list_stream; - use bytes::{BufMut, BytesMut}; - use itertools::Itertools; + use chrono::TimeZone; + use rand::{thread_rng, Rng}; use tokio::io::AsyncWriteExt; pub(crate) async fn put_get_delete_list(storage: &DynObjectStore) { @@ -1080,8 +1153,24 @@ mod tests { files.sort_unstable(); assert_eq!(files, vec![emoji_file.clone(), dst.clone()]); + let dst2 = Path::from("new/nested/foo.parquet"); + storage.copy(&emoji_file, &dst2).await.unwrap(); + let mut files = flatten_list_stream(storage, None).await.unwrap(); + files.sort_unstable(); + assert_eq!(files, vec![emoji_file.clone(), dst.clone(), dst2.clone()]); + + let dst3 = Path::from("new/nested2/bar.parquet"); + storage.rename(&dst, &dst3).await.unwrap(); + let mut files = flatten_list_stream(storage, None).await.unwrap(); + files.sort_unstable(); + assert_eq!(files, vec![emoji_file.clone(), dst2.clone(), dst3.clone()]); + + let err = storage.head(&dst).await.unwrap_err(); + assert!(matches!(err, Error::NotFound { .. })); + storage.delete(&emoji_file).await.unwrap(); - storage.delete(&dst).await.unwrap(); + storage.delete(&dst3).await.unwrap(); + storage.delete(&dst2).await.unwrap(); let files = flatten_list_stream(storage, Some(&emoji_prefix)) .await .unwrap(); @@ -1172,11 +1261,7 @@ mod tests { ]; for (prefix, offset) in cases { - let s = storage - .list_with_offset(prefix.as_ref(), &offset) - .await - .unwrap(); - + let s = storage.list_with_offset(prefix.as_ref(), &offset); let mut actual: Vec<_> = s.map_ok(|x| x.location).try_collect().await.unwrap(); @@ -1239,6 +1324,15 @@ mod tests { } delete_fixtures(storage).await; + + let path = Path::from("empty"); + storage.put(&path, Bytes::new()).await.unwrap(); + let meta = storage.head(&path).await.unwrap(); + assert_eq!(meta.size, 0); + let data = storage.get(&path).await.unwrap().bytes().await.unwrap(); + assert_eq!(data.len(), 0); + + storage.delete(&path).await.unwrap(); } pub(crate) async fn get_opts(storage: &dyn ObjectStore) { @@ -1291,56 +1385,55 @@ mod tests { Err(e) => panic!("{e}"), } - if let Some(tag) = meta.e_tag { - let options = GetOptions { - if_match: Some(tag.clone()), - ..GetOptions::default() - }; - storage.get_opts(&path, options).await.unwrap(); - - let options = GetOptions { - if_match: Some("invalid".to_string()), - ..GetOptions::default() - }; - let err = storage.get_opts(&path, options).await.unwrap_err(); - assert!(matches!(err, Error::Precondition { .. }), "{err}"); - - let options = GetOptions { - if_none_match: Some(tag.clone()), - ..GetOptions::default() - }; - let err = storage.get_opts(&path, options).await.unwrap_err(); - assert!(matches!(err, Error::NotModified { .. }), "{err}"); - - let options = GetOptions { - if_none_match: Some("invalid".to_string()), - ..GetOptions::default() - }; - storage.get_opts(&path, options).await.unwrap(); - } + let tag = meta.e_tag.unwrap(); + let options = GetOptions { + if_match: Some(tag.clone()), + ..GetOptions::default() + }; + storage.get_opts(&path, options).await.unwrap(); + + let options = GetOptions { + if_match: Some("invalid".to_string()), + ..GetOptions::default() + }; + let err = storage.get_opts(&path, options).await.unwrap_err(); + assert!(matches!(err, Error::Precondition { .. 
}), "{err}"); + + let options = GetOptions { + if_none_match: Some(tag.clone()), + ..GetOptions::default() + }; + let err = storage.get_opts(&path, options).await.unwrap_err(); + assert!(matches!(err, Error::NotModified { .. }), "{err}"); + + let options = GetOptions { + if_none_match: Some("invalid".to_string()), + ..GetOptions::default() + }; + storage.get_opts(&path, options).await.unwrap(); } - fn get_random_bytes(len: usize) -> Bytes { - use rand::Rng; - let mut rng = rand::thread_rng(); - let mut bytes = BytesMut::with_capacity(len); - for _ in 0..len { - bytes.put_u8(rng.gen()); + /// Returns a chunk of length `chunk_length` + fn get_chunk(chunk_length: usize) -> Bytes { + let mut data = vec![0_u8; chunk_length]; + let mut rng = thread_rng(); + // Set a random selection of bytes + for _ in 0..1000 { + data[rng.gen_range(0..chunk_length)] = rng.gen(); } - bytes.freeze() + data.into() } - fn get_vec_of_bytes(chunk_length: usize, num_chunks: usize) -> Vec { - std::iter::repeat(get_random_bytes(chunk_length)) - .take(num_chunks) - .collect() + /// Returns `num_chunks` of length `chunks` + fn get_chunks(chunk_length: usize, num_chunks: usize) -> Vec { + (0..num_chunks).map(|_| get_chunk(chunk_length)).collect() } pub(crate) async fn stream_get(storage: &DynObjectStore) { let location = Path::from("test_dir/test_upload_file.txt"); // Can write to storage - let data = get_vec_of_bytes(5_000, 10); + let data = get_chunks(5_000, 10); let bytes_expected = data.concat(); let (_, mut writer) = storage.put_multipart(&location).await.unwrap(); for chunk in &data { @@ -1367,7 +1460,7 @@ mod tests { // Can overwrite some storage // Sizes chosen to ensure we write three parts - let data = (0..7).map(|_| get_random_bytes(3_200_000)).collect_vec(); + let data = get_chunks(3_200_000, 7); let bytes_expected = data.concat(); let (_, mut writer) = storage.put_multipart(&location).await.unwrap(); for chunk in &data { @@ -1571,7 +1664,7 @@ mod tests { pub(crate) async fn copy_if_not_exists(storage: &DynObjectStore) { // Create two objects let path1 = Path::from("test1"); - let path2 = Path::from("test2"); + let path2 = Path::from("not_exists_nested/test2"); let contents1 = Bytes::from("cats"); let contents2 = Bytes::from("dogs"); @@ -1600,12 +1693,7 @@ mod tests { } async fn delete_fixtures(storage: &DynObjectStore) { - let paths = storage - .list(None) - .await - .unwrap() - .map_ok(|meta| meta.location) - .boxed(); + let paths = storage.list(None).map_ok(|meta| meta.location).boxed(); storage .delete_stream(paths) .try_collect::>() @@ -1614,23 +1702,101 @@ mod tests { } /// Test that the returned stream does not borrow the lifetime of Path - async fn list_store<'a, 'b>( + fn list_store<'a>( store: &'a dyn ObjectStore, - path_str: &'b str, - ) -> super::Result>> { + path_str: &str, + ) -> BoxStream<'a, Result> { let path = Path::from(path_str); - store.list(Some(&path)).await + store.list(Some(&path)) } #[tokio::test] async fn test_list_lifetimes() { let store = memory::InMemory::new(); - let mut stream = list_store(&store, "path").await.unwrap(); + let mut stream = list_store(&store, "path"); assert!(stream.next().await.is_none()); } - // Tests TODO: - // GET nonexisting location (in_memory/file) - // DELETE nonexisting location - // PUT overwriting + #[test] + fn test_preconditions() { + let mut meta = ObjectMeta { + location: Path::from("test"), + last_modified: Utc.timestamp_nanos(100), + size: 100, + e_tag: Some("123".to_string()), + }; + + let mut options = GetOptions::default(); + 
options.check_preconditions(&meta).unwrap(); + + options.if_modified_since = Some(Utc.timestamp_nanos(50)); + options.check_preconditions(&meta).unwrap(); + + options.if_modified_since = Some(Utc.timestamp_nanos(100)); + options.check_preconditions(&meta).unwrap_err(); + + options.if_modified_since = Some(Utc.timestamp_nanos(101)); + options.check_preconditions(&meta).unwrap_err(); + + options = GetOptions::default(); + + options.if_unmodified_since = Some(Utc.timestamp_nanos(50)); + options.check_preconditions(&meta).unwrap_err(); + + options.if_unmodified_since = Some(Utc.timestamp_nanos(100)); + options.check_preconditions(&meta).unwrap(); + + options.if_unmodified_since = Some(Utc.timestamp_nanos(101)); + options.check_preconditions(&meta).unwrap(); + + options = GetOptions::default(); + + options.if_match = Some("123".to_string()); + options.check_preconditions(&meta).unwrap(); + + options.if_match = Some("123,354".to_string()); + options.check_preconditions(&meta).unwrap(); + + options.if_match = Some("354, 123,".to_string()); + options.check_preconditions(&meta).unwrap(); + + options.if_match = Some("354".to_string()); + options.check_preconditions(&meta).unwrap_err(); + + options.if_match = Some("*".to_string()); + options.check_preconditions(&meta).unwrap(); + + // If-Match takes precedence + options.if_unmodified_since = Some(Utc.timestamp_nanos(200)); + options.check_preconditions(&meta).unwrap(); + + options = GetOptions::default(); + + options.if_none_match = Some("123".to_string()); + options.check_preconditions(&meta).unwrap_err(); + + options.if_none_match = Some("*".to_string()); + options.check_preconditions(&meta).unwrap_err(); + + options.if_none_match = Some("1232".to_string()); + options.check_preconditions(&meta).unwrap(); + + options.if_none_match = Some("23, 123".to_string()); + options.check_preconditions(&meta).unwrap_err(); + + // If-None-Match takes precedence + options.if_modified_since = Some(Utc.timestamp_nanos(10)); + options.check_preconditions(&meta).unwrap_err(); + + // Check missing ETag + meta.e_tag = None; + options = GetOptions::default(); + + options.if_none_match = Some("*".to_string()); // Fails if any file exists + options.check_preconditions(&meta).unwrap_err(); + + options = GetOptions::default(); + options.if_match = Some("*".to_string()); // Passes if file exists + options.check_preconditions(&meta).unwrap(); + } } diff --git a/object_store/src/limit.rs b/object_store/src/limit.rs index 630fd145b72c..06cb2f25c7e6 100644 --- a/object_store/src/limit.rs +++ b/object_store/src/limit.rs @@ -18,12 +18,12 @@ //! 
An object store that limits the maximum concurrency of the wrapped implementation use crate::{ - BoxStream, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, - Path, Result, StreamExt, + BoxStream, GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, + ObjectMeta, ObjectStore, Path, Result, StreamExt, PutOptions, }; use async_trait::async_trait; use bytes::Bytes; -use futures::Stream; +use futures::{FutureExt, Stream}; use std::io::{Error, IoSlice}; use std::ops::Range; use std::pin::Pin; @@ -77,6 +77,16 @@ impl ObjectStore for LimitStore { self.inner.put(location, bytes).await } + async fn put_opts( + &self, + location: &Path, + bytes: Bytes, + options: PutOptions, + ) -> Result<()> { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.put_opts(location, bytes, options).await + } + async fn put_multipart( &self, location: &Path, @@ -106,22 +116,14 @@ impl ObjectStore for LimitStore { async fn get(&self, location: &Path) -> Result { let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap(); - match self.inner.get(location).await? { - r @ GetResult::File(_, _) => Ok(r), - GetResult::Stream(s) => { - Ok(GetResult::Stream(PermitWrapper::new(s, permit).boxed())) - } - } + let r = self.inner.get(location).await?; + Ok(permit_get_result(r, permit)) } async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap(); - match self.inner.get_opts(location, options).await? { - r @ GetResult::File(_, _) => Ok(r), - GetResult::Stream(s) => { - Ok(GetResult::Stream(PermitWrapper::new(s, permit).boxed())) - } - } + let r = self.inner.get_opts(location, options).await?; + Ok(permit_get_result(r, permit)) } async fn get_range(&self, location: &Path, range: Range) -> Result { @@ -155,23 +157,31 @@ impl ObjectStore for LimitStore { self.inner.delete_stream(locations) } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { - let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap(); - let s = self.inner.list(prefix).await?; - Ok(PermitWrapper::new(s, permit).boxed()) + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { + let prefix = prefix.cloned(); + let fut = Arc::clone(&self.semaphore) + .acquire_owned() + .map(move |permit| { + let s = self.inner.list(prefix.as_ref()); + PermitWrapper::new(s, permit.unwrap()) + }); + fut.into_stream().flatten().boxed() } - async fn list_with_offset( + fn list_with_offset( &self, prefix: Option<&Path>, offset: &Path, - ) -> Result>> { - let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap(); - let s = self.inner.list_with_offset(prefix, offset).await?; - Ok(PermitWrapper::new(s, permit).boxed()) + ) -> BoxStream<'_, Result> { + let prefix = prefix.cloned(); + let offset = offset.clone(); + let fut = Arc::clone(&self.semaphore) + .acquire_owned() + .map(move |permit| { + let s = self.inner.list_with_offset(prefix.as_ref(), &offset); + PermitWrapper::new(s, permit.unwrap()) + }); + fut.into_stream().flatten().boxed() } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { @@ -200,6 +210,16 @@ impl ObjectStore for LimitStore { } } +fn permit_get_result(r: GetResult, permit: OwnedSemaphorePermit) -> GetResult { + let payload = match r.payload { + v @ GetResultPayload::File(_, _) => v, + GetResultPayload::Stream(s) => { + GetResultPayload::Stream(PermitWrapper::new(s, permit).boxed()) + } + }; + GetResult { payload, ..r } +} + /// Combines an 
[`OwnedSemaphorePermit`] with some other type struct PermitWrapper { inner: T, @@ -270,6 +290,8 @@ mod tests { use crate::memory::InMemory; use crate::tests::*; use crate::ObjectStore; + use futures::stream::StreamExt; + use std::pin::Pin; use std::time::Duration; use tokio::time::timeout; @@ -288,19 +310,21 @@ mod tests { let mut streams = Vec::with_capacity(max_requests); for _ in 0..max_requests { - let stream = integration.list(None).await.unwrap(); + let mut stream = integration.list(None).peekable(); + Pin::new(&mut stream).peek().await; // Ensure semaphore is acquired streams.push(stream); } let t = Duration::from_millis(20); // Expect to not be able to make another request - assert!(timeout(t, integration.list(None)).await.is_err()); + let fut = integration.list(None).collect::>(); + assert!(timeout(t, fut).await.is_err()); // Drop one of the streams streams.pop(); // Can now make another request - integration.list(None).await.unwrap(); + integration.list(None).collect::>().await; } } diff --git a/object_store/src/local.rs b/object_store/src/local.rs index ffff6a5739d5..38467c3a9e7c 100644 --- a/object_store/src/local.rs +++ b/object_store/src/local.rs @@ -19,16 +19,18 @@ use crate::{ maybe_spawn_blocking, path::{absolute_path_to_url, Path}, - GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result, + GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, ObjectMeta, + ObjectStore, Result, }; use async_trait::async_trait; use bytes::Bytes; use chrono::{DateTime, Utc}; use futures::future::BoxFuture; -use futures::FutureExt; +use futures::ready; use futures::{stream::BoxStream, StreamExt}; -use snafu::{ensure, OptionExt, ResultExt, Snafu}; -use std::fs::{metadata, symlink_metadata, File, OpenOptions}; +use futures::{FutureExt, TryStreamExt}; +use snafu::{ensure, ResultExt, Snafu}; +use std::fs::{metadata, symlink_metadata, File, Metadata, OpenOptions}; use std::io::{ErrorKind, Read, Seek, SeekFrom, Write}; use std::ops::Range; use std::pin::Pin; @@ -77,10 +79,10 @@ pub(crate) enum Error { path: PathBuf, }, - #[snafu(display("Unable to create file {}: {}", path.display(), err))] + #[snafu(display("Unable to create file {}: {}", path.display(), source))] UnableToCreateFile { + source: io::Error, path: PathBuf, - err: io::Error, }, #[snafu(display("Unable to delete file {}: {}", path.display(), source))] @@ -273,13 +275,15 @@ impl ObjectStore for LocalFileSystem { maybe_spawn_blocking(move || { let (mut file, suffix) = new_staged_upload(&path)?; let staging_path = staged_upload_path(&path, &suffix); - file.write_all(&bytes) - .context(UnableToCopyDataToFileSnafu)?; - - std::fs::rename(staging_path, path).context(UnableToRenameFileSnafu)?; - - Ok(()) + .context(UnableToCopyDataToFileSnafu) + .and_then(|_| { + std::fs::rename(&staging_path, &path).context(UnableToRenameFileSnafu) + }) + .map_err(|e| { + let _ = std::fs::remove_file(&staging_path); // Attempt to cleanup + e.into() + }) }) .await } @@ -303,12 +307,14 @@ impl ObjectStore for LocalFileSystem { multipart_id: &MultipartId, ) -> Result<()> { let dest = self.config.path_to_filesystem(location)?; - let staging_path: PathBuf = staged_upload_path(&dest, multipart_id); + let path: PathBuf = staged_upload_path(&dest, multipart_id); - maybe_spawn_blocking(move || { - std::fs::remove_file(&staging_path) - .context(UnableToDeleteFileSnafu { path: staging_path })?; - Ok(()) + maybe_spawn_blocking(move || match std::fs::remove_file(&path) { + Ok(_) => Ok(()), + Err(source) => match source.kind() { + 
ErrorKind::NotFound => Ok(()), // Already deleted + _ => Err(Error::UnableToDeleteFile { path, source }.into()), + }, }) .await } @@ -317,7 +323,6 @@ impl ObjectStore for LocalFileSystem { &self, location: &Path, ) -> Result> { - #[cfg(not(target_arch = "wasm32"))] // Get the path to the file from the configuration. let path = self.config.path_to_filesystem(location)?; loop { @@ -335,12 +340,13 @@ impl ObjectStore for LocalFileSystem { // If the file was successfully opened, return it wrapped in a boxed `AsyncWrite` trait object. Ok(file) => return Ok(Box::new(file)), // If the error is that the file was not found, attempt to create the file and any necessary parent directories. - Err(err) if err.kind() == ErrorKind::NotFound => { + Err(source) if source.kind() == ErrorKind::NotFound => { // Get the path to the parent directory of the file. - let parent = path - .parent() - // If the parent directory does not exist, return a `UnableToCreateFileSnafu` error. - .context(UnableToCreateFileSnafu { path: &path, err })?; + let parent = + path.parent().ok_or_else(|| Error::UnableToCreateFile { + path: path.to_path_buf(), + source, + })?; // Create the parent directory and any necessary ancestors. tokio::fs::create_dir_all(parent) @@ -356,32 +362,21 @@ impl ObjectStore for LocalFileSystem { } } } - #[cfg(target_arch = "wasm32")] - Err(super::Error::NotImplemented) } async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { - if options.if_match.is_some() || options.if_none_match.is_some() { - return Err(super::Error::NotSupported { - source: "ETags not supported by LocalFileSystem".to_string().into(), - }); - } - let location = location.clone(); let path = self.config.path_to_filesystem(&location)?; maybe_spawn_blocking(move || { - let file = open_file(&path)?; - if options.if_unmodified_since.is_some() - || options.if_modified_since.is_some() - { - let metadata = file.metadata().map_err(|e| Error::Metadata { - source: e.into(), - path: location.to_string(), - })?; - options.check_modified(&location, last_modified(&metadata))?; - } - - Ok(GetResult::File(file, path)) + let (file, metadata) = open_file(&path)?; + let meta = convert_metadata(metadata, location)?; + options.check_preconditions(&meta)?; + + Ok(GetResult { + payload: GetResultPayload::File(file, path), + range: options.range.unwrap_or(0..meta.size), + meta, + }) }) .await } @@ -389,7 +384,7 @@ impl ObjectStore for LocalFileSystem { async fn get_range(&self, location: &Path, range: Range) -> Result { let path = self.config.path_to_filesystem(location)?; maybe_spawn_blocking(move || { - let mut file = open_file(&path)?; + let (mut file, _) = open_file(&path)?; read_range(&mut file, &path, range) }) .await @@ -404,7 +399,7 @@ impl ObjectStore for LocalFileSystem { let ranges = ranges.to_vec(); maybe_spawn_blocking(move || { // Vectored IO might be faster - let mut file = open_file(&path)?; + let (mut file, _) = open_file(&path)?; ranges .into_iter() .map(|r| read_range(&mut file, &path, r)) @@ -413,35 +408,6 @@ impl ObjectStore for LocalFileSystem { .await } - async fn head(&self, location: &Path) -> Result { - let path = self.config.path_to_filesystem(location)?; - let location = location.clone(); - - maybe_spawn_blocking(move || { - let metadata = match metadata(&path) { - Err(e) => Err(match e.kind() { - ErrorKind::NotFound => Error::NotFound { - path: path.clone(), - source: e, - }, - _ => Error::Metadata { - source: e.into(), - path: location.to_string(), - }, - }), - Ok(m) => match !m.is_dir() { - true => Ok(m), - 
false => Err(Error::NotFound { - path, - source: io::Error::new(ErrorKind::NotFound, "is directory"), - }), - }, - }?; - convert_metadata(metadata, location) - }) - .await - } - async fn delete(&self, location: &Path) -> Result<()> { let path = self.config.path_to_filesystem(location)?; maybe_spawn_blocking(move || match std::fs::remove_file(&path) { @@ -454,14 +420,14 @@ impl ObjectStore for LocalFileSystem { .await } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { let config = Arc::clone(&self.config); let root_path = match prefix { - Some(prefix) => config.path_to_filesystem(prefix)?, + Some(prefix) => match config.path_to_filesystem(prefix) { + Ok(path) => path, + Err(e) => return futures::future::ready(Err(e)).into_stream().boxed(), + }, None => self.config.root.to_file_path().unwrap(), }; @@ -491,36 +457,34 @@ impl ObjectStore for LocalFileSystem { // If no tokio context, return iterator directly as no // need to perform chunked spawn_blocking reads if tokio::runtime::Handle::try_current().is_err() { - return Ok(futures::stream::iter(s).boxed()); + return futures::stream::iter(s).boxed(); } // Otherwise list in batches of CHUNK_SIZE const CHUNK_SIZE: usize = 1024; let buffer = VecDeque::with_capacity(CHUNK_SIZE); - let stream = - futures::stream::try_unfold((s, buffer), |(mut s, mut buffer)| async move { - if buffer.is_empty() { - (s, buffer) = tokio::task::spawn_blocking(move || { - for _ in 0..CHUNK_SIZE { - match s.next() { - Some(r) => buffer.push_back(r), - None => break, - } + futures::stream::try_unfold((s, buffer), |(mut s, mut buffer)| async move { + if buffer.is_empty() { + (s, buffer) = tokio::task::spawn_blocking(move || { + for _ in 0..CHUNK_SIZE { + match s.next() { + Some(r) => buffer.push_back(r), + None => break, } - (s, buffer) - }) - .await?; - } - - match buffer.pop_front() { - Some(Err(e)) => Err(e), - Some(Ok(meta)) => Ok(Some((meta, (s, buffer)))), - None => Ok(None), - } - }); + } + (s, buffer) + }) + .await?; + } - Ok(stream.boxed()) + match buffer.pop_front() { + Some(Err(e)) => Err(e), + Some(Ok(meta)) => Ok(Some((meta, (s, buffer)))), + None => Ok(None), + } + }) + .boxed() } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { @@ -581,10 +545,28 @@ impl ObjectStore for LocalFileSystem { async fn copy(&self, from: &Path, to: &Path) -> Result<()> { let from = self.config.path_to_filesystem(from)?; let to = self.config.path_to_filesystem(to)?; - - maybe_spawn_blocking(move || { - std::fs::copy(&from, &to).context(UnableToCopyFileSnafu { from, to })?; - Ok(()) + let mut id = 0; + // In order to make this atomic we: + // + // - hard link to a hidden temporary file + // - atomically rename this temporary file into place + // + // This is necessary because hard_link returns an error if the destination already exists + maybe_spawn_blocking(move || loop { + let staged = staged_upload_path(&to, &id.to_string()); + match std::fs::hard_link(&from, &staged) { + Ok(_) => { + return std::fs::rename(&staged, &to).map_err(|source| { + let _ = std::fs::remove_file(&staged); // Attempt to clean up + Error::UnableToCopyFile { from, to, source }.into() + }); + } + Err(source) => match source.kind() { + ErrorKind::AlreadyExists => id += 1, + ErrorKind::NotFound => create_parent_dirs(&to, source)?, + _ => return Err(Error::UnableToCopyFile { from, to, source }.into()), + }, + } }) .await } @@ -592,9 +574,14 @@ impl ObjectStore for LocalFileSystem { async fn rename(&self, 
from: &Path, to: &Path) -> Result<()> { let from = self.config.path_to_filesystem(from)?; let to = self.config.path_to_filesystem(to)?; - maybe_spawn_blocking(move || { - std::fs::rename(&from, &to).context(UnableToCopyFileSnafu { from, to })?; - Ok(()) + maybe_spawn_blocking(move || loop { + match std::fs::rename(&from, &to) { + Ok(_) => return Ok(()), + Err(source) => match source.kind() { + ErrorKind::NotFound => create_parent_dirs(&to, source)?, + _ => return Err(Error::UnableToCopyFile { from, to, source }.into()), + }, + } }) .await } @@ -603,25 +590,37 @@ impl ObjectStore for LocalFileSystem { let from = self.config.path_to_filesystem(from)?; let to = self.config.path_to_filesystem(to)?; - maybe_spawn_blocking(move || { - std::fs::hard_link(&from, &to).map_err(|err| match err.kind() { - io::ErrorKind::AlreadyExists => Error::AlreadyExists { - path: to.to_str().unwrap().to_string(), - source: err, - } - .into(), - _ => Error::UnableToCopyFile { - from, - to, - source: err, - } - .into(), - }) + maybe_spawn_blocking(move || loop { + match std::fs::hard_link(&from, &to) { + Ok(_) => return Ok(()), + Err(source) => match source.kind() { + ErrorKind::AlreadyExists => { + return Err(Error::AlreadyExists { + path: to.to_str().unwrap().to_string(), + source, + } + .into()) + } + ErrorKind::NotFound => create_parent_dirs(&to, source)?, + _ => return Err(Error::UnableToCopyFile { from, to, source }.into()), + }, + } }) .await } } +/// Creates the parent directories of `path` or returns an error based on `source` if no parent +fn create_parent_dirs(path: &std::path::Path, source: io::Error) -> Result<()> { + let parent = path.parent().ok_or_else(|| Error::UnableToCreateFile { + path: path.to_path_buf(), + source, + })?; + + std::fs::create_dir_all(parent).context(UnableToCreateDirSnafu { path: parent })?; + Ok(()) +} + /// Generates a unique file path `{base}#{suffix}`, returning the opened `File` and `suffix` /// /// Creates any directories if necessary @@ -633,20 +632,11 @@ fn new_staged_upload(base: &std::path::Path) -> Result<(File, String)> { let mut options = OpenOptions::new(); match options.read(true).write(true).create_new(true).open(&path) { Ok(f) => return Ok((f, suffix)), - Err(e) if e.kind() == ErrorKind::AlreadyExists => { - multipart_id += 1; - } - Err(err) if err.kind() == ErrorKind::NotFound => { - let parent = path - .parent() - .context(UnableToCreateFileSnafu { path: &path, err })?; - - std::fs::create_dir_all(parent) - .context(UnableToCreateDirSnafu { path: parent })?; - - continue; - } - Err(source) => return Err(Error::UnableToOpenFile { source, path }.into()), + Err(source) => match source.kind() { + ErrorKind::AlreadyExists => multipart_id += 1, + ErrorKind::NotFound => create_parent_dirs(&path, source)?, + _ => return Err(Error::UnableToOpenFile { source, path }.into()), + }, } } } @@ -661,12 +651,9 @@ fn staged_upload_path(dest: &std::path::Path, suffix: &str) -> PathBuf { enum LocalUploadState { /// Upload is ready to send new data - Idle(Arc), + Idle(Arc), /// In the middle of a write - Writing( - Arc, - BoxFuture<'static, Result>, - ), + Writing(Arc, BoxFuture<'static, Result>), /// In the middle of syncing data and closing file. /// /// Future will contain last reference to file, so it will call drop on completion. 
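The rewritten `copy` above avoids exposing a half-written destination by hard-linking into a hidden staging path and then renaming it into place. A minimal standalone sketch of that pattern using only `std::fs` (the staging name and temp-dir paths are placeholders; the crate's version additionally creates missing parent directories and bumps the staging suffix on collisions):

```rust
use std::fs;
use std::io::ErrorKind;
use std::path::Path;

/// Copy `from` to `to` by hard-linking to a staging path and renaming it
/// into place, so readers never observe a partially copied file.
fn atomic_copy(from: &Path, to: &Path) -> std::io::Result<()> {
    let staged = to.with_extension("staged-copy"); // placeholder staging name
    match fs::hard_link(from, &staged) {
        Ok(()) => {}
        Err(e) if e.kind() == ErrorKind::AlreadyExists => {
            // A previous attempt left a staging file behind; replace it.
            fs::remove_file(&staged)?;
            fs::hard_link(from, &staged)?;
        }
        Err(e) => return Err(e),
    }
    // rename is atomic on POSIX filesystems, so `to` flips in one step.
    fs::rename(&staged, to).map_err(|e| {
        let _ = fs::remove_file(&staged); // best-effort cleanup
        e
    })
}

fn main() -> std::io::Result<()> {
    let dir = std::env::temp_dir();
    let from = dir.join("atomic_copy_src.txt");
    let to = dir.join("atomic_copy_dst.txt");
    fs::write(&from, b"hello")?;
    atomic_copy(&from, &to)?;
    assert_eq!(fs::read(&to)?, fs::read(&from)?);
    Ok(())
}
```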
@@ -684,11 +671,7 @@ struct LocalUpload { } impl LocalUpload { - pub fn new( - dest: PathBuf, - multipart_id: MultipartId, - file: Arc, - ) -> Self { + pub fn new(dest: PathBuf, multipart_id: MultipartId, file: Arc) -> Self { Self { inner_state: LocalUploadState::Idle(file), dest, @@ -702,14 +685,13 @@ impl AsyncWrite for LocalUpload { mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8], - ) -> std::task::Poll> { - let invalid_state = - |condition: &str| -> std::task::Poll> { - Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("Tried to write to file {condition}."), - ))) - }; + ) -> Poll> { + let invalid_state = |condition: &str| -> Poll> { + Poll::Ready(Err(io::Error::new( + ErrorKind::InvalidInput, + format!("Tried to write to file {condition}."), + ))) + }; if let Ok(runtime) = tokio::runtime::Handle::try_current() { let mut data: Vec = buf.to_vec(); @@ -728,7 +710,7 @@ impl AsyncWrite for LocalUpload { .spawn_blocking(move || (&*file2).write_all(&data)) .map(move |res| match res { Err(err) => { - Err(io::Error::new(io::ErrorKind::Other, err)) + Err(io::Error::new(ErrorKind::Other, err)) } Ok(res) => res.map(move |_| data_len), }), @@ -736,16 +718,9 @@ impl AsyncWrite for LocalUpload { ); } LocalUploadState::Writing(file, inner_write) => { - match inner_write.poll_unpin(cx) { - Poll::Ready(res) => { - self.inner_state = - LocalUploadState::Idle(Arc::clone(file)); - return Poll::Ready(res); - } - Poll::Pending => { - return Poll::Pending; - } - } + let res = ready!(inner_write.poll_unpin(cx)); + self.inner_state = LocalUploadState::Idle(Arc::clone(file)); + return Poll::Ready(res); } LocalUploadState::ShuttingDown(_) => { return invalid_state("when writer is shutting down"); @@ -771,14 +746,14 @@ impl AsyncWrite for LocalUpload { fn poll_flush( self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + ) -> Poll> { Poll::Ready(Ok(())) } fn poll_shutdown( mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + ) -> Poll> { if let Ok(runtime) = tokio::runtime::Handle::try_current() { loop { match &mut self.inner_state { @@ -825,13 +800,11 @@ impl AsyncWrite for LocalUpload { "Tried to commit a file where a write is in progress.", ))); } - LocalUploadState::Committing(fut) => match fut.poll_unpin(cx) { - Poll::Ready(res) => { - self.inner_state = LocalUploadState::Complete; - return Poll::Ready(res); - } - Poll::Pending => return Poll::Pending, - }, + LocalUploadState::Committing(fut) => { + let res = ready!(fut.poll_unpin(cx)); + self.inner_state = LocalUploadState::Complete; + return Poll::Ready(res); + } LocalUploadState::Complete => { return Poll::Ready(Err(io::Error::new( io::ErrorKind::Other, @@ -847,23 +820,86 @@ impl AsyncWrite for LocalUpload { let file = Arc::clone(file); self.inner_state = LocalUploadState::Complete; file.sync_all()?; - std::mem::drop(file); + drop(file); std::fs::rename(staging_path, &self.dest)?; Poll::Ready(Ok(())) } _ => { // If we are running on this thread, then only possible states are Idle and Complete. 
- Poll::Ready(Err(io::Error::new( - io::ErrorKind::Other, - "Already complete", - ))) + Poll::Ready(Err(io::Error::new(ErrorKind::Other, "Already complete"))) } } } } } -pub(crate) fn read_range(file: &mut File, path: &PathBuf, range: Range) -> Result { +impl Drop for LocalUpload { + fn drop(&mut self) { + match self.inner_state { + LocalUploadState::Complete => (), + _ => { + self.inner_state = LocalUploadState::Complete; + let path = staged_upload_path(&self.dest, &self.multipart_id); + // Try to cleanup intermediate file ignoring any error + match tokio::runtime::Handle::try_current() { + Ok(r) => drop(r.spawn_blocking(move || std::fs::remove_file(path))), + Err(_) => drop(std::fs::remove_file(path)), + }; + } + } + } +} + +pub(crate) fn chunked_stream( + mut file: File, + path: PathBuf, + range: Range, + chunk_size: usize, +) -> BoxStream<'static, Result> { + futures::stream::once(async move { + let (file, path) = maybe_spawn_blocking(move || { + file.seek(SeekFrom::Start(range.start as _)) + .map_err(|source| Error::Seek { + source, + path: path.clone(), + })?; + Ok((file, path)) + }) + .await?; + + let stream = futures::stream::try_unfold( + (file, path, range.end - range.start), + move |(mut file, path, remaining)| { + maybe_spawn_blocking(move || { + if remaining == 0 { + return Ok(None); + } + + let to_read = remaining.min(chunk_size); + let mut buffer = Vec::with_capacity(to_read); + let read = (&mut file) + .take(to_read as u64) + .read_to_end(&mut buffer) + .map_err(|e| Error::UnableToReadBytes { + source: e, + path: path.clone(), + })?; + + Ok(Some((buffer.into(), (file, path, remaining - read)))) + }) + }, + ); + Ok::<_, super::Error>(stream) + }) + .try_flatten() + .boxed() +} + +pub(crate) fn read_range( + file: &mut File, + path: &PathBuf, + range: Range, +) -> Result { let to_read = range.end - range.start; file.seek(SeekFrom::Start(range.start as u64)) .context(SeekSnafu { path })?; @@ -885,8 +921,8 @@ pub(crate) fn read_range(file: &mut File, path: &PathBuf, range: Range) - Ok(buf.into()) } -fn open_file(path: &PathBuf) -> Result { - let file = match File::open(path).and_then(|f| Ok((f.metadata()?, f))) { +fn open_file(path: &PathBuf) -> Result<(File, Metadata)> { + let ret = match File::open(path).and_then(|f| Ok((f.metadata()?, f))) { Err(e) => Err(match e.kind() { ErrorKind::NotFound => Error::NotFound { path: path.clone(), @@ -898,14 +934,14 @@ fn open_file(path: &PathBuf) -> Result { }, }), Ok((metadata, file)) => match !metadata.is_dir() { - true => Ok(file), + true => Ok((file, metadata)), false => Err(Error::NotFound { path: path.clone(), source: io::Error::new(ErrorKind::NotFound, "is directory"), }), }, }?; - Ok(file) + Ok(ret) } fn convert_entry(entry: DirEntry, location: Path) -> Result { @@ -916,32 +952,52 @@ fn convert_entry(entry: DirEntry, location: Path) -> Result { convert_metadata(metadata, location) } -fn last_modified(metadata: &std::fs::Metadata) -> DateTime { +fn last_modified(metadata: &Metadata) -> DateTime { metadata .modified() .expect("Modified file time should be supported on this platform") .into() } -fn convert_metadata(metadata: std::fs::Metadata, location: Path) -> Result { +fn convert_metadata(metadata: Metadata, location: Path) -> Result { let last_modified = last_modified(&metadata); let size = usize::try_from(metadata.len()).context(FileSizeOverflowedUsizeSnafu { path: location.as_ref(), })?; + let inode = get_inode(&metadata); + let mtime = last_modified.timestamp_micros(); + + // Use an ETag scheme based on that used by many 
popular HTTP servers + // + // + let etag = format!("{inode:x}-{mtime:x}-{size:x}"); Ok(ObjectMeta { location, last_modified, size, - e_tag: None, + e_tag: Some(etag), }) } +#[cfg(unix)] +/// We include the inode when available to yield an ETag more resistant to collisions +/// and as used by popular web servers such as [Apache](https://httpd.apache.org/docs/2.2/mod/core.html#fileetag) +fn get_inode(metadata: &Metadata) -> u64 { + std::os::unix::fs::MetadataExt::ino(metadata) +} + +#[cfg(not(unix))] +/// On platforms where an inode isn't available, fallback to just relying on size and mtime +fn get_inode(metadata: &Metadata) -> u64 { + 0 +} + /// Convert walkdir results and converts not-found errors into `None`. /// Convert broken symlinks to `None`. fn convert_walkdir_result( - res: std::result::Result, -) -> Result> { + res: std::result::Result, +) -> Result> { match res { Ok(entry) => { // To check for broken symlink: call symlink_metadata() - it does not traverse symlinks); @@ -970,7 +1026,7 @@ fn convert_walkdir_result( Err(walkdir_err) => match walkdir_err.io_error() { Some(io_err) => match io_err.kind() { - io::ErrorKind::NotFound => Ok(None), + ErrorKind::NotFound => Ok(None), _ => Err(Error::UnableToWalkDir { source: walkdir_err, } @@ -1080,21 +1136,14 @@ mod tests { let store = LocalFileSystem::new_with_prefix(root.path()).unwrap(); - // `list` must fail - match store.list(None).await { - Err(_) => { - // ok, error found - } - Ok(mut stream) => { - let mut any_err = false; - while let Some(res) = stream.next().await { - if res.is_err() { - any_err = true; - } - } - assert!(any_err); + let mut stream = store.list(None); + let mut any_err = false; + while let Some(res) = stream.next().await { + if res.is_err() { + any_err = true; } } + assert!(any_err); // `list_with_delimiter assert!(store.list_with_delimiter(None).await.is_err()); @@ -1168,13 +1217,7 @@ mod tests { prefix: Option<&Path>, expected: &[&str], ) { - let result: Vec<_> = integration - .list(prefix) - .await - .unwrap() - .try_collect() - .await - .unwrap(); + let result: Vec<_> = integration.list(prefix).try_collect().await.unwrap(); let mut strings: Vec<_> = result.iter().map(|x| x.location.as_ref()).collect(); strings.sort_unstable(); @@ -1370,8 +1413,7 @@ mod tests { std::fs::write(temp_dir.path().join(filename), "foo").unwrap(); - let list_stream = integration.list(None).await.unwrap(); - let res: Vec<_> = list_stream.try_collect().await.unwrap(); + let res: Vec<_> = integration.list(None).try_collect().await.unwrap(); assert_eq!(res.len(), 1); assert_eq!(res[0].location.as_ref(), filename); @@ -1398,6 +1440,7 @@ mod not_wasm_tests { use crate::local::LocalFileSystem; use crate::{ObjectStore, Path}; use bytes::Bytes; + use std::time::Duration; use tempfile::TempDir; use tokio::io::AsyncWriteExt; @@ -1415,6 +1458,8 @@ mod not_wasm_tests { writer.write_all(data.as_ref()).await.unwrap(); + writer.flush().await.unwrap(); + let read_data = integration .get(&location) .await @@ -1466,11 +1511,13 @@ mod not_wasm_tests { for d in &data { writer.write_all(d).await.unwrap(); } + writer.flush().await.unwrap(); let mut writer = integration.append(&location).await.unwrap(); for d in &data { writer.write_all(d).await.unwrap(); } + writer.flush().await.unwrap(); let read_data = integration .get(&location) @@ -1482,6 +1529,25 @@ mod not_wasm_tests { let expected_data = Bytes::from("arbitrarydatagnzarbitrarydatagnz"); assert_eq!(&*read_data, expected_data); } + + #[tokio::test] + async fn test_cleanup_intermediate_files() { + let 
root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + + let location = Path::from("some_file"); + let (_, mut writer) = integration.put_multipart(&location).await.unwrap(); + writer.write_all(b"hello").await.unwrap(); + + let file_count = std::fs::read_dir(root.path()).unwrap().count(); + assert_eq!(file_count, 1); + drop(writer); + + tokio::time::sleep(Duration::from_millis(1)).await; + + let file_count = std::fs::read_dir(root.path()).unwrap().count(); + assert_eq!(file_count, 0); + } } #[cfg(target_family = "unix")] @@ -1502,15 +1568,15 @@ mod unix_test { let path = root.path().join(filename); unistd::mkfifo(&path, stat::Mode::S_IRWXU).unwrap(); - let location = Path::from(filename); - integration.head(&location).await.unwrap(); - // Need to open read and write side in parallel let spawned = tokio::task::spawn_blocking(|| { - OpenOptions::new().write(true).open(path).unwrap(); + OpenOptions::new().write(true).open(path).unwrap() }); + let location = Path::from(filename); + integration.head(&location).await.unwrap(); integration.get(&location).await.unwrap(); + spawned.await.unwrap(); } } diff --git a/object_store/src/memory.rs b/object_store/src/memory.rs index 82d485997e88..00b330b5eb94 100644 --- a/object_store/src/memory.rs +++ b/object_store/src/memory.rs @@ -16,7 +16,9 @@ // under the License. //! An in-memory object store implementation -use crate::{path::Path, GetResult, ListResult, ObjectMeta, ObjectStore, Result}; +use crate::{ + path::Path, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, Result, +}; use crate::{GetOptions, MultipartId}; use async_trait::async_trait; use bytes::Bytes; @@ -33,9 +35,6 @@ use std::sync::Arc; use std::task::Poll; use tokio::io::AsyncWrite; -type Entry = (Bytes, DateTime); -type StorageType = Arc>>; - /// A specialized `Error` for in-memory object store-related errors #[derive(Debug, Snafu)] #[allow(missing_docs)] @@ -43,11 +42,13 @@ enum Error { #[snafu(display("No data in memory found. Location: {path}"))] NoDataInMemory { path: String }, - #[snafu(display("Out of range"))] - OutOfRange, + #[snafu(display( + "Requested range {}..{} is out of bounds for object with length {}", range.start, range.end, len + ))] + OutOfRange { range: Range, len: usize }, - #[snafu(display("Bad range"))] - BadRange, + #[snafu(display("Invalid range: {}..{}", range.start, range.end))] + BadRange { range: Range }, #[snafu(display("Object already exists at that location: {path}"))] AlreadyExists { path: String }, @@ -76,7 +77,41 @@ impl From for super::Error { /// storage provider. 
#[derive(Debug, Default)] pub struct InMemory { - storage: StorageType, + storage: SharedStorage, +} + +#[derive(Debug, Clone)] +struct Entry { + data: Bytes, + last_modified: DateTime, + e_tag: usize, +} + +impl Entry { + fn new(data: Bytes, last_modified: DateTime, e_tag: usize) -> Self { + Self { + data, + last_modified, + e_tag, + } + } +} + +#[derive(Debug, Default, Clone)] +struct Storage { + next_etag: usize, + map: BTreeMap, +} + +type SharedStorage = Arc>; + +impl Storage { + fn insert(&mut self, location: &Path, bytes: Bytes) { + let etag = self.next_etag; + self.next_etag += 1; + let entry = Entry::new(bytes, Utc::now(), etag); + self.map.insert(location.clone(), entry); + } } impl std::fmt::Display for InMemory { @@ -88,9 +123,7 @@ impl std::fmt::Display for InMemory { #[async_trait] impl ObjectStore for InMemory { async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { - self.storage - .write() - .insert(location.clone(), (bytes, Utc::now())); + self.storage.write().insert(location, bytes); Ok(()) } @@ -124,29 +157,38 @@ impl ObjectStore for InMemory { Ok(Box::new(InMemoryAppend { location: location.clone(), data: Vec::::new(), - storage: StorageType::clone(&self.storage), + storage: SharedStorage::clone(&self.storage), })) } async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { - if options.if_match.is_some() || options.if_none_match.is_some() { - return Err(super::Error::NotSupported { - source: "ETags not supported by InMemory".to_string().into(), - }); - } - let (data, last_modified) = self.entry(location).await?; - options.check_modified(location, last_modified)?; + let entry = self.entry(location).await?; + let e_tag = entry.e_tag.to_string(); + let meta = ObjectMeta { + location: location.clone(), + last_modified: entry.last_modified, + size: entry.data.len(), + e_tag: Some(e_tag), + }; + options.check_preconditions(&meta)?; + + let (range, data) = match options.range { + Some(range) => { + let len = entry.data.len(); + ensure!(range.end <= len, OutOfRangeSnafu { range, len }); + ensure!(range.start <= range.end, BadRangeSnafu { range }); + (range.clone(), entry.data.slice(range)) + } + None => (0..entry.data.len(), entry.data), + }; let stream = futures::stream::once(futures::future::ready(Ok(data))); - Ok(GetResult::Stream(stream.boxed())) - } - - async fn get_range(&self, location: &Path, range: Range) -> Result { - let data = self.entry(location).await?; - ensure!(range.end <= data.0.len(), OutOfRangeSnafu); - ensure!(range.start <= range.end, BadRangeSnafu); - Ok(data.0.slice(range)) + Ok(GetResult { + payload: GetResultPayload::Stream(stream.boxed()), + meta, + range, + }) } async fn get_ranges( @@ -154,13 +196,18 @@ impl ObjectStore for InMemory { location: &Path, ranges: &[Range], ) -> Result> { - let data = self.entry(location).await?; + let entry = self.entry(location).await?; ranges .iter() .map(|range| { - ensure!(range.end <= data.0.len(), OutOfRangeSnafu); - ensure!(range.start <= range.end, BadRangeSnafu); - Ok(data.0.slice(range.clone())) + let range = range.clone(); + let len = entry.data.len(); + ensure!( + range.end <= entry.data.len(), + OutOfRangeSnafu { range, len } + ); + ensure!(range.start <= range.end, BadRangeSnafu { range }); + Ok(entry.data.slice(range)) }) .collect() } @@ -170,26 +217,24 @@ impl ObjectStore for InMemory { Ok(ObjectMeta { location: location.clone(), - last_modified: entry.1, - size: entry.0.len(), - e_tag: None, + last_modified: entry.last_modified, + size: entry.data.len(), + e_tag: 
Some(entry.e_tag.to_string()), }) } async fn delete(&self, location: &Path) -> Result<()> { - self.storage.write().remove(location); + self.storage.write().map.remove(location); Ok(()) } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { let root = Path::default(); let prefix = prefix.unwrap_or(&root); let storage = self.storage.read(); let values: Vec<_> = storage + .map .range((prefix)..) .take_while(|(key, _)| key.as_ref().starts_with(prefix.as_ref())) .filter(|(key, _)| { @@ -201,14 +246,14 @@ impl ObjectStore for InMemory { .map(|(key, value)| { Ok(ObjectMeta { location: key.clone(), - last_modified: value.1, - size: value.0.len(), - e_tag: None, + last_modified: value.last_modified, + size: value.data.len(), + e_tag: Some(value.e_tag.to_string()), }) }) .collect(); - Ok(futures::stream::iter(values).boxed()) + futures::stream::iter(values).boxed() } /// The memory implementation returns all results, as opposed to the cloud @@ -223,7 +268,7 @@ impl ObjectStore for InMemory { // Only objects in this base level should be returned in the // response. Otherwise, we just collect the common prefixes. let mut objects = vec![]; - for (k, v) in self.storage.read().range((prefix)..) { + for (k, v) in self.storage.read().map.range((prefix)..) { if !k.as_ref().starts_with(prefix.as_ref()) { break; } @@ -245,9 +290,9 @@ impl ObjectStore for InMemory { } else { let object = ObjectMeta { location: k.clone(), - last_modified: v.1, - size: v.0.len(), - e_tag: None, + last_modified: v.last_modified, + size: v.data.len(), + e_tag: Some(v.e_tag.to_string()), }; objects.push(object); } @@ -260,23 +305,21 @@ impl ObjectStore for InMemory { } async fn copy(&self, from: &Path, to: &Path) -> Result<()> { - let data = self.entry(from).await?; - self.storage - .write() - .insert(to.clone(), (data.0, Utc::now())); + let entry = self.entry(from).await?; + self.storage.write().insert(to, entry.data); Ok(()) } async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { - let data = self.entry(from).await?; + let entry = self.entry(from).await?; let mut storage = self.storage.write(); - if storage.contains_key(to) { + if storage.map.contains_key(to) { return Err(Error::AlreadyExists { path: to.to_string(), } .into()); } - storage.insert(to.clone(), (data.0, Utc::now())); + storage.insert(to, entry.data); Ok(()) } } @@ -287,19 +330,24 @@ impl InMemory { Self::default() } - /// Creates a clone of the store - pub async fn clone(&self) -> Self { + /// Creates a fork of the store, with the current content copied into the + /// new store. 
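With both `LocalFileSystem` and `InMemory` now reporting an ETag (an inode/mtime/size hash for the former, a monotonic counter for the latter), `get_opts` can enforce conditional requests through `check_preconditions`. A hedged usage sketch against the in-memory store, assuming `GetOptions` derives `Default` and exposes the `if_none_match` field referenced in this diff:

```rust
use object_store::memory::InMemory;
use object_store::{path::Path, GetOptions, ObjectStore};

#[tokio::main]
async fn main() -> object_store::Result<()> {
    let store = InMemory::new();
    let location = Path::from("data/file.bin");
    store.put(&location, bytes::Bytes::from_static(b"hello")).await?;

    // First fetch: remember the ETag the store reports.
    let etag = store.head(&location).await?.e_tag;

    // Revalidation: ask for the object only if it changed since we cached it.
    let options = GetOptions {
        if_none_match: etag,
        ..Default::default()
    };
    match store.get_opts(&location, options).await {
        Ok(result) => println!("object changed, new etag {:?}", result.meta.e_tag),
        Err(e) => println!("not modified (or other error): {e}"),
    }
    Ok(())
}
```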
+ pub fn fork(&self) -> Self { let storage = self.storage.read(); - let storage = storage.clone(); + let storage = Arc::new(RwLock::new(storage.clone())); + Self { storage } + } - Self { - storage: Arc::new(RwLock::new(storage)), - } + /// Creates a clone of the store + #[deprecated(note = "Use fork() instead")] + pub async fn clone(&self) -> Self { + self.fork() } - async fn entry(&self, location: &Path) -> Result<(Bytes, DateTime)> { + async fn entry(&self, location: &Path) -> Result { let storage = self.storage.read(); let value = storage + .map .get(location) .cloned() .context(NoDataInMemorySnafu { @@ -313,7 +361,7 @@ impl InMemory { struct InMemoryUpload { location: Path, data: Vec, - storage: StorageType, + storage: Arc>, } impl AsyncWrite for InMemoryUpload { @@ -321,7 +369,7 @@ impl AsyncWrite for InMemoryUpload { mut self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, buf: &[u8], - ) -> std::task::Poll> { + ) -> Poll> { self.data.extend_from_slice(buf); Poll::Ready(Ok(buf.len())) } @@ -329,18 +377,16 @@ impl AsyncWrite for InMemoryUpload { fn poll_flush( self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + ) -> Poll> { Poll::Ready(Ok(())) } fn poll_shutdown( mut self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + ) -> Poll> { let data = Bytes::from(std::mem::take(&mut self.data)); - self.storage - .write() - .insert(self.location.clone(), (data, Utc::now())); + self.storage.write().insert(&self.location, data); Poll::Ready(Ok(())) } } @@ -348,7 +394,7 @@ impl AsyncWrite for InMemoryUpload { struct InMemoryAppend { location: Path, data: Vec, - storage: StorageType, + storage: Arc>, } impl AsyncWrite for InMemoryAppend { @@ -356,7 +402,7 @@ impl AsyncWrite for InMemoryAppend { mut self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, buf: &[u8], - ) -> std::task::Poll> { + ) -> Poll> { self.data.extend_from_slice(buf); Poll::Ready(Ok(buf.len())) } @@ -364,20 +410,18 @@ impl AsyncWrite for InMemoryAppend { fn poll_flush( mut self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let storage = StorageType::clone(&self.storage); + ) -> Poll> { + let storage = Arc::clone(&self.storage); let mut writer = storage.write(); - if let Some((bytes, _)) = writer.remove(&self.location) { + if let Some(entry) = writer.map.remove(&self.location) { let buf = std::mem::take(&mut self.data); - let concat = Bytes::from_iter(bytes.into_iter().chain(buf.into_iter())); - writer.insert(self.location.clone(), (concat, Utc::now())); + let concat = Bytes::from_iter(entry.data.into_iter().chain(buf)); + writer.insert(&self.location, concat); } else { - writer.insert( - self.location.clone(), - (Bytes::from(std::mem::take(&mut self.data)), Utc::now()), - ); + let data = Bytes::from(std::mem::take(&mut self.data)); + writer.insert(&self.location, data); }; Poll::Ready(Ok(())) } @@ -411,6 +455,32 @@ mod tests { stream_get(&integration).await; } + #[tokio::test] + async fn box_test() { + let integration: Box = Box::new(InMemory::new()); + + put_get_delete_list(&integration).await; + get_opts(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + copy_if_not_exists(&integration).await; + stream_get(&integration).await; + } + + #[tokio::test] + async fn arc_test() { + let integration: Arc = Arc::new(InMemory::new()); + + put_get_delete_list(&integration).await; + get_opts(&integration).await; + 
list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + copy_if_not_exists(&integration).await; + stream_get(&integration).await; + } + #[tokio::test] async fn unknown_length() { let integration = InMemory::new(); diff --git a/object_store/src/multipart.rs b/object_store/src/multipart.rs index 26580307053e..d4c911fceab4 100644 --- a/object_store/src/multipart.rs +++ b/object_store/src/multipart.rs @@ -15,6 +15,12 @@ // specific language governing permissions and limitations // under the License. +//! Cloud Multipart Upload +//! +//! This crate provides an asynchronous interface for multipart file uploads to cloud storage services. +//! It's designed to offer efficient, non-blocking operations, +//! especially useful when dealing with large files or high-throughput systems. + use async_trait::async_trait; use futures::{stream::FuturesUnordered, Future, StreamExt}; use std::{io, pin::Pin, sync::Arc, task::Poll}; @@ -25,37 +31,33 @@ use crate::Result; type BoxedTryFuture = Pin> + Send>>; /// A trait that can be implemented by cloud-based object stores -/// and used in combination with [`CloudMultiPartUpload`] to provide +/// and used in combination with [`WriteMultiPart`] to provide /// multipart upload support #[async_trait] -pub(crate) trait CloudMultiPartUploadImpl: 'static { +pub trait PutPart: Send + Sync + 'static { /// Upload a single part - async fn put_multipart_part( - &self, - buf: Vec, - part_idx: usize, - ) -> Result; + async fn put_part(&self, buf: Vec, part_idx: usize) -> Result; /// Complete the upload with the provided parts /// /// `completed_parts` is in order of part number - async fn complete(&self, completed_parts: Vec) -> Result<(), io::Error>; + async fn complete(&self, completed_parts: Vec) -> Result<()>; } +/// Represents a part of a file that has been successfully uploaded in a multipart upload process. #[derive(Debug, Clone)] -pub(crate) struct UploadPart { +pub struct PartId { + /// Id of this part pub content_id: String, } -pub(crate) struct CloudMultiPartUpload -where - T: CloudMultiPartUploadImpl, -{ +/// Wrapper around a [`PutPart`] that implements [`AsyncWrite`] +pub struct WriteMultiPart { inner: Arc, /// A list of completed parts, in sequential order. - completed_parts: Vec>, + completed_parts: Vec>, /// Part upload tasks currently running - tasks: FuturesUnordered>, + tasks: FuturesUnordered>, /// Maximum number of upload tasks to run concurrently max_concurrency: usize, /// Buffer that will be sent in next upload. @@ -71,10 +73,8 @@ where completion_task: Option>, } -impl CloudMultiPartUpload -where - T: CloudMultiPartUploadImpl, -{ +impl WriteMultiPart { + /// Create a new multipart upload with the implementation and the given maximum concurrency pub fn new(inner: T, max_concurrency: usize) -> Self { Self { inner: Arc::new(inner), @@ -103,7 +103,8 @@ where to_copy } - pub fn poll_tasks( + /// Poll current tasks + fn poll_tasks( mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Result<(), io::Error> { @@ -119,12 +120,7 @@ where } Ok(()) } -} -impl CloudMultiPartUpload -where - T: CloudMultiPartUploadImpl + Send + Sync, -{ // The `poll_flush` function will only flush the in-progress tasks. // The `final_flush` method called during `poll_shutdown` will flush // the `current_buffer` along with in-progress tasks. 
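`CloudMultiPartUploadImpl` and `CloudMultiPartUpload` are renamed and made public above as `PutPart` and `WriteMultiPart`. A sketch of plugging a hypothetical backend into that trait; the buffer and part types omitted in the hunk above are assumed to be `Vec<u8>` and `PartId`, and the public path is assumed to be `object_store::multipart`:

```rust
use async_trait::async_trait;
use object_store::multipart::{PartId, PutPart};
use object_store::Result;

#[derive(Debug)]
struct MyBackendUpload {
    upload_id: String,
}

#[async_trait]
impl PutPart for MyBackendUpload {
    async fn put_part(&self, buf: Vec<u8>, part_idx: usize) -> Result<PartId> {
        // Upload `buf` as part number `part_idx` to the service here; return the
        // identifier the service assigns so it can be referenced on completion.
        let _ = buf;
        Ok(PartId {
            content_id: format!("{}-{part_idx}", self.upload_id),
        })
    }

    async fn complete(&self, completed_parts: Vec<PartId>) -> Result<()> {
        // Ask the service to assemble the parts, which arrive in part-number order.
        let _ = completed_parts;
        Ok(())
    }
}

// `WriteMultiPart::new(MyBackendUpload { upload_id: "id".into() }, 8)` then exposes
// this as a plain `AsyncWrite`, keeping up to 8 part uploads in flight at once.
```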
@@ -142,7 +138,7 @@ where let inner = Arc::clone(&self.inner); let part_idx = self.current_part_idx; self.tasks.push(Box::pin(async move { - let upload_part = inner.put_multipart_part(out_buffer, part_idx).await?; + let upload_part = inner.put_part(out_buffer, part_idx).await?; Ok((part_idx, upload_part)) })); } @@ -158,10 +154,7 @@ where } } -impl AsyncWrite for CloudMultiPartUpload -where - T: CloudMultiPartUploadImpl + Send + Sync, -{ +impl AsyncWrite for WriteMultiPart { fn poll_write( mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -188,7 +181,7 @@ where let inner = Arc::clone(&self.inner); let part_idx = self.current_part_idx; self.tasks.push(Box::pin(async move { - let upload_part = inner.put_multipart_part(out_buffer, part_idx).await?; + let upload_part = inner.put_part(out_buffer, part_idx).await?; Ok((part_idx, upload_part)) })); self.current_part_idx += 1; @@ -257,3 +250,16 @@ where Pin::new(completion_task).poll(cx) } } + +impl std::fmt::Debug for WriteMultiPart { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("WriteMultiPart") + .field("completed_parts", &self.completed_parts) + .field("tasks", &self.tasks) + .field("max_concurrency", &self.max_concurrency) + .field("current_buffer", &self.current_buffer) + .field("part_size", &self.part_size) + .field("current_part_idx", &self.current_part_idx) + .finish() + } +} diff --git a/object_store/src/parse.rs b/object_store/src/parse.rs index 7b89e58e10e7..2e72a710ac75 100644 --- a/object_store/src/parse.rs +++ b/object_store/src/parse.rs @@ -47,12 +47,12 @@ impl From for super::Error { } } -/// Recognises various URL formats, identifying the relevant [`ObjectStore`](crate::ObjectStore) +/// Recognises various URL formats, identifying the relevant [`ObjectStore`] #[derive(Debug, Eq, PartialEq)] enum ObjectStoreScheme { - /// Url corresponding to [`LocalFileSystem`](crate::local::LocalFileSystem) + /// Url corresponding to [`LocalFileSystem`] Local, - /// Url corresponding to [`InMemory`](crate::memory::InMemory) + /// Url corresponding to [`InMemory`] Memory, /// Url corresponding to [`AmazonS3`](crate::aws::AmazonS3) AmazonS3, @@ -104,7 +104,7 @@ impl ObjectStoreScheme { } } -#[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))] +#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] macro_rules! 
builder_opts { ($builder:ty, $url:expr, $options:expr) => {{ let builder = $options.into_iter().fold( diff --git a/object_store/src/prefix.rs b/object_store/src/prefix.rs index 39585f73b692..5e7ea324677b 100644 --- a/object_store/src/prefix.rs +++ b/object_store/src/prefix.rs @@ -23,7 +23,8 @@ use tokio::io::AsyncWrite; use crate::path::Path; use crate::{ - GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result, + GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, PutOptions, + Result, }; #[doc(hidden)] @@ -84,6 +85,16 @@ impl ObjectStore for PrefixStore { self.inner.put(&full_path, bytes).await } + async fn put_opts( + &self, + location: &Path, + bytes: Bytes, + options: PutOptions, + ) -> Result<()> { + let full_path = self.full_path(location); + self.inner.put_opts(&full_path, bytes, options).await + } + async fn put_multipart( &self, location: &Path, @@ -144,24 +155,21 @@ impl ObjectStore for PrefixStore { self.inner.delete(&full_path).await } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { let prefix = self.full_path(prefix.unwrap_or(&Path::default())); - let s = self.inner.list(Some(&prefix)).await?; - Ok(s.map_ok(|meta| self.strip_meta(meta)).boxed()) + let s = self.inner.list(Some(&prefix)); + s.map_ok(|meta| self.strip_meta(meta)).boxed() } - async fn list_with_offset( + fn list_with_offset( &self, prefix: Option<&Path>, offset: &Path, - ) -> Result>> { + ) -> BoxStream<'_, Result> { let offset = self.full_path(offset); let prefix = self.full_path(prefix.unwrap_or(&Path::default())); - let s = self.inner.list_with_offset(Some(&prefix), &offset).await?; - Ok(s.map_ok(|meta| self.strip_meta(meta)).boxed()) + let s = self.inner.list_with_offset(Some(&prefix), &offset); + s.map_ok(|meta| self.strip_meta(meta)).boxed() } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { diff --git a/object_store/src/signer.rs b/object_store/src/signer.rs new file mode 100644 index 000000000000..f1f35debe053 --- /dev/null +++ b/object_store/src/signer.rs @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Abstraction of signed URL generation for those object store implementations that support it + +use crate::{path::Path, Result}; +use async_trait::async_trait; +use reqwest::Method; +use std::{fmt, time::Duration}; +use url::Url; + +/// Universal API to presigned URLs generated from multiple object store services. Not supported by +/// all object store services. 
+#[async_trait] +pub trait Signer: Send + Sync + fmt::Debug + 'static { + /// Given the intended [`Method`] and [`Path`] to use and the desired length of time for which + /// the URL should be valid, return a signed [`Url`] created with the object store + /// implementation's credentials such that the URL can be handed to something that doesn't have + /// access to the object store's credentials, to allow limited access to the object store. + async fn signed_url( + &self, + method: Method, + path: &Path, + expires_in: Duration, + ) -> Result; +} diff --git a/object_store/src/throttle.rs b/object_store/src/throttle.rs index fb90afcec9fb..f716a11f8a05 100644 --- a/object_store/src/throttle.rs +++ b/object_store/src/throttle.rs @@ -20,7 +20,9 @@ use parking_lot::Mutex; use std::ops::Range; use std::{convert::TryInto, sync::Arc}; -use crate::{path::Path, GetResult, ListResult, ObjectMeta, ObjectStore, Result}; +use crate::{ + path::Path, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, Result, +}; use crate::{GetOptions, MultipartId}; use async_trait::async_trait; use bytes::Bytes; @@ -231,29 +233,30 @@ impl ObjectStore for ThrottledStore { self.inner.delete(location).await } - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { - sleep(self.config().wait_list_per_call).await; - - // need to copy to avoid moving / referencing `self` - let wait_list_per_entry = self.config().wait_list_per_entry; - let stream = self.inner.list(prefix).await?; - Ok(throttle_stream(stream, move |_| wait_list_per_entry)) + fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result> { + let stream = self.inner.list(prefix); + futures::stream::once(async move { + let wait_list_per_entry = self.config().wait_list_per_entry; + sleep(self.config().wait_list_per_call).await; + throttle_stream(stream, move |_| wait_list_per_entry) + }) + .flatten() + .boxed() } - async fn list_with_offset( + fn list_with_offset( &self, prefix: Option<&Path>, offset: &Path, - ) -> Result>> { - sleep(self.config().wait_list_per_call).await; - - // need to copy to avoid moving / referencing `self` - let wait_list_per_entry = self.config().wait_list_per_entry; - let stream = self.inner.list_with_offset(prefix, offset).await?; - Ok(throttle_stream(stream, move |_| wait_list_per_entry)) + ) -> BoxStream<'_, Result> { + let stream = self.inner.list_with_offset(prefix, offset); + futures::stream::once(async move { + let wait_list_per_entry = self.config().wait_list_per_entry; + sleep(self.config().wait_list_per_call).await; + throttle_stream(stream, move |_| wait_list_per_entry) + }) + .flatten() + .boxed() } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { @@ -301,15 +304,20 @@ fn usize_to_u32_saturate(x: usize) -> u32 { } fn throttle_get(result: GetResult, wait_get_per_byte: Duration) -> GetResult { - let s = match result { - GetResult::Stream(s) => s, - GetResult::File(_, _) => unimplemented!(), + let s = match result.payload { + GetResultPayload::Stream(s) => s, + GetResultPayload::File(_, _) => unimplemented!(), }; - GetResult::Stream(throttle_stream(s, move |bytes| { + let stream = throttle_stream(s, move |bytes| { let bytes_len: u32 = usize_to_u32_saturate(bytes.len()); wait_get_per_byte * bytes_len - })) + }); + + GetResult { + payload: GetResultPayload::Stream(stream), + ..result + } } fn throttle_stream( @@ -330,7 +338,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::{memory::InMemory, tests::*}; + use crate::{memory::InMemory, tests::*, GetResultPayload}; use 
bytes::Bytes; use futures::TryStreamExt; use tokio::time::Duration; @@ -504,13 +512,7 @@ mod tests { let prefix = Path::from("foo"); // clean up store - let entries: Vec<_> = store - .list(Some(&prefix)) - .await - .unwrap() - .try_collect() - .await - .unwrap(); + let entries: Vec<_> = store.list(Some(&prefix)).try_collect().await.unwrap(); for entry in entries { store.delete(&entry.location).await.unwrap(); @@ -550,9 +552,9 @@ mod tests { let res = store.get(&path).await; if n_bytes.is_some() { // need to consume bytes to provoke sleep times - let s = match res.unwrap() { - GetResult::Stream(s) => s, - GetResult::File(_, _) => unimplemented!(), + let s = match res.unwrap().payload { + GetResultPayload::Stream(s) => s, + GetResultPayload::File(_, _) => unimplemented!(), }; s.map_ok(|b| bytes::BytesMut::from(&b[..])) @@ -576,8 +578,6 @@ mod tests { let t0 = Instant::now(); store .list(Some(&prefix)) - .await - .unwrap() .try_collect::>() .await .unwrap(); diff --git a/object_store/src/util.rs b/object_store/src/util.rs index 79ca4bb7a834..764582a67f95 100644 --- a/object_store/src/util.rs +++ b/object_store/src/util.rs @@ -32,8 +32,9 @@ where D: serde::Deserializer<'de>, { let s: String = serde::Deserialize::deserialize(deserializer)?; - chrono::TimeZone::datetime_from_str(&chrono::Utc, &s, RFC1123_FMT) - .map_err(serde::de::Error::custom) + let naive = chrono::NaiveDateTime::parse_from_str(&s, RFC1123_FMT) + .map_err(serde::de::Error::custom)?; + Ok(chrono::TimeZone::from_utc_datetime(&chrono::Utc, &naive)) } #[cfg(any(feature = "aws", feature = "azure"))] @@ -46,9 +47,13 @@ pub(crate) fn hmac_sha256( } /// Collect a stream into [`Bytes`] avoiding copying in the event of a single chunk -pub async fn collect_bytes(mut stream: S, size_hint: Option) -> Result +pub async fn collect_bytes( + mut stream: S, + size_hint: Option, +) -> Result where - S: Stream> + Send + Unpin, + E: Send, + S: Stream> + Send + Unpin, { let first = stream.next().await.transpose()?.unwrap_or_default(); @@ -98,14 +103,15 @@ pub const OBJECT_STORE_COALESCE_PARALLEL: usize = 10; /// * Combine ranges less than `coalesce` bytes apart into a single call to `fetch` /// * Make multiple `fetch` requests in parallel (up to maximum of 10) /// -pub async fn coalesce_ranges( +pub async fn coalesce_ranges( ranges: &[std::ops::Range], fetch: F, coalesce: usize, -) -> Result> +) -> Result, E> where F: Send + FnMut(std::ops::Range) -> Fut, - Fut: std::future::Future> + Send, + E: Send, + Fut: std::future::Future> + Send, { let fetch_ranges = merge_ranges(ranges, coalesce); @@ -172,6 +178,8 @@ fn merge_ranges( #[cfg(test)] mod tests { + use crate::Error; + use super::*; use rand::{thread_rng, Rng}; use std::ops::Range; @@ -184,7 +192,7 @@ mod tests { let src: Vec<_> = (0..max).map(|x| x as u8).collect(); let mut fetches = vec![]; - let coalesced = coalesce_ranges( + let coalesced = coalesce_ranges::<_, Error, _>( &ranges, |range| { fetches.push(range.clone()); @@ -207,7 +215,7 @@ mod tests { let fetches = do_fetch(vec![], 0).await; assert!(fetches.is_empty()); - let fetches = do_fetch(vec![0..3], 0).await; + let fetches = do_fetch(vec![0..3; 1], 0).await; assert_eq!(fetches, vec![0..3]); let fetches = do_fetch(vec![0..2, 3..5], 0).await; diff --git a/object_store/tests/get_range_file.rs b/object_store/tests/get_range_file.rs index f926e3b07f2a..25c469260675 100644 --- a/object_store/tests/get_range_file.rs +++ b/object_store/tests/get_range_file.rs @@ -75,10 +75,7 @@ impl ObjectStore for MyStore { todo!() } - async fn list( - &self, 
- _: Option<&Path>, - ) -> object_store::Result>> { + fn list(&self, _: Option<&Path>) -> BoxStream<'_, object_store::Result> { todo!() } diff --git a/parquet/CONTRIBUTING.md b/parquet/CONTRIBUTING.md index 903126d9f4f8..5670eef08101 100644 --- a/parquet/CONTRIBUTING.md +++ b/parquet/CONTRIBUTING.md @@ -62,10 +62,6 @@ To compile and view in the browser, run `cargo doc --no-deps --open`. ## Update Parquet Format -To generate the parquet format (thrift definitions) code run from the repository root run - -``` -$ docker run -v $(pwd):/thrift/src -it archlinux pacman -Sy --noconfirm thrift && wget https://raw.githubusercontent.com/apache/parquet-format/apache-parquet-format-2.9.0/src/main/thrift/parquet.thrift -O /tmp/parquet.thrift && thrift --gen rs /tmp/parquet.thrift && sed -i '/use thrift::server::TProcessor;/d' parquet.rs && mv parquet.rs parquet/src/format.rs -``` +To generate the parquet format (thrift definitions) code run [`./regen.sh`](./regen.sh). You may need to manually patch up doc comments that contain unescaped `[]` diff --git a/parquet/Cargo.toml b/parquet/Cargo.toml index 52b0f049752c..659e2c0ee3a7 100644 --- a/parquet/Cargo.toml +++ b/parquet/Cargo.toml @@ -26,7 +26,7 @@ authors = { workspace = true } keywords = ["arrow", "parquet", "hadoop"] readme = "README.md" edition = { workspace = true } -rust-version = { workspace = true } +rust-version = "1.70.0" [target.'cfg(target_arch = "wasm32")'.dependencies] ahash = { version = "0.8", default-features = false, features = ["compile-time-rng"] } @@ -44,16 +44,16 @@ arrow-schema = { workspace = true, optional = true } arrow-select = { workspace = true, optional = true } arrow-ipc = { workspace = true, optional = true } # Intentionally not a path dependency as object_store is released separately -object_store = { version = "0.6", default-features = false, optional = true } +object_store = { version = "0.7", default-features = false, optional = true } bytes = { version = "1.1", default-features = false, features = ["std"] } thrift = { version = "0.17", default-features = false } snap = { version = "1.0", default-features = false, optional = true } brotli = { version = "3.3", default-features = false, features = ["std"], optional = true } flate2 = { version = "1.0", default-features = false, features = ["rust_backend"], optional = true } -lz4 = { version = "1.23", default-features = false, optional = true } -zstd = { version = "0.12.0", optional = true, default-features = false } -chrono = { version = "0.4.23", default-features = false, features = ["alloc"] } +lz4_flex = { version = "0.11", default-features = false, features = ["std", "frame"], optional = true } +zstd = { version = "0.13.0", optional = true, default-features = false } +chrono = { workspace = true } num = { version = "0.4", default-features = false } num-bigint = { version = "0.4", default-features = false } base64 = { version = "0.21", default-features = false, features = ["std", ], optional = true } @@ -74,8 +74,8 @@ snap = { version = "1.0", default-features = false } tempfile = { version = "3.0", default-features = false } brotli = { version = "3.3", default-features = false, features = ["std"] } flate2 = { version = "1.0", default-features = false, features = ["rust_backend"] } -lz4 = { version = "1.23", default-features = false } -zstd = { version = "0.12", default-features = false } +lz4_flex = { version = "0.11", default-features = false, features = ["std", "frame"] } +zstd = { version = "0.13", default-features = false } serde_json = { version = "1.0", 
features = ["std"], default-features = false } arrow = { workspace = true, features = ["ipc", "test_utils", "prettyprint", "json"] } tokio = { version = "1.0", default-features = false, features = ["macros", "rt", "io-util", "fs"] } @@ -86,6 +86,8 @@ all-features = true [features] default = ["arrow", "snap", "brotli", "flate2", "lz4", "zstd", "base64"] +# Enable lz4 +lz4 = ["lz4_flex"] # Enable arrow reader/writer APIs arrow = ["base64", "arrow-array", "arrow-buffer", "arrow-cast", "arrow-data", "arrow-schema", "arrow-select", "arrow-ipc"] # Enable CLI tools @@ -166,5 +168,15 @@ name = "arrow_reader" required-features = ["arrow", "test_common", "experimental"] harness = false +[[bench]] +name = "compression" +required-features = ["experimental", "default"] +harness = false + + +[[bench]] +name = "metadata" +harness = false + [lib] bench = false diff --git a/parquet/benches/compression.rs b/parquet/benches/compression.rs new file mode 100644 index 000000000000..ce4f9aead751 --- /dev/null +++ b/parquet/benches/compression.rs @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use criterion::*; +use parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel}; +use parquet::compression::create_codec; +use rand::distributions::Alphanumeric; +use rand::prelude::*; + +fn do_bench(c: &mut Criterion, name: &str, uncompressed: &[u8]) { + let codecs = [ + Compression::BROTLI(BrotliLevel::default()), + Compression::GZIP(GzipLevel::default()), + Compression::LZ4, + Compression::LZ4_RAW, + Compression::SNAPPY, + Compression::GZIP(GzipLevel::default()), + Compression::ZSTD(ZstdLevel::default()), + ]; + + for compression in codecs { + let mut codec = create_codec(compression, &Default::default()) + .unwrap() + .unwrap(); + + c.bench_function(&format!("compress {compression} - {name}"), |b| { + b.iter(|| { + let mut out = Vec::new(); + codec.compress(uncompressed, &mut out).unwrap(); + out + }); + }); + + let mut compressed = Vec::new(); + codec.compress(uncompressed, &mut compressed).unwrap(); + println!( + "{compression} compressed {} bytes of {name} to {} bytes", + uncompressed.len(), + compressed.len() + ); + + c.bench_function(&format!("decompress {compression} - {name}"), |b| { + b.iter(|| { + let mut out = Vec::new(); + codec + .decompress( + black_box(&compressed), + &mut out, + Some(uncompressed.len()), + ) + .unwrap(); + out + }); + }); + } +} + +fn criterion_benchmark(c: &mut Criterion) { + let mut rng = StdRng::seed_from_u64(42); + let rng = &mut rng; + const DATA_SIZE: usize = 1024 * 1024; + + let uncompressed: Vec<_> = rng.sample_iter(&Alphanumeric).take(DATA_SIZE).collect(); + do_bench(c, "alphanumeric", &uncompressed); + + // Create a collection of 64 words + let words: Vec> = (0..64) + .map(|_| { + let len = rng.gen_range(1..12); + rng.sample_iter(&Alphanumeric).take(len).collect() + }) + .collect(); + + // Build data by concatenating these words randomly together + let mut uncompressed = Vec::with_capacity(DATA_SIZE); + while uncompressed.len() < DATA_SIZE { + let word = &words[rng.gen_range(0..words.len())]; + uncompressed + .extend_from_slice(&word[..word.len().min(DATA_SIZE - uncompressed.len())]) + } + assert_eq!(uncompressed.len(), DATA_SIZE); + + do_bench(c, "words", &uncompressed); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/parquet/benches/metadata.rs b/parquet/benches/metadata.rs new file mode 100644 index 000000000000..c817385f6ba9 --- /dev/null +++ b/parquet/benches/metadata.rs @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use bytes::Bytes; +use criterion::*; +use parquet::file::reader::SerializedFileReader; +use parquet::file::serialized_reader::ReadOptionsBuilder; + +fn criterion_benchmark(c: &mut Criterion) { + // Read file into memory to isolate filesystem performance + let file = "../parquet-testing/data/alltypes_tiny_pages.parquet"; + let data = std::fs::read(file).unwrap(); + let data = Bytes::from(data); + + c.bench_function("open(default)", |b| { + b.iter(|| SerializedFileReader::new(data.clone()).unwrap()) + }); + + c.bench_function("open(page index)", |b| { + b.iter(|| { + let options = ReadOptionsBuilder::new().with_page_index().build(); + SerializedFileReader::new_with_options(data.clone(), options).unwrap() + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/parquet/examples/async_read_parquet.rs b/parquet/examples/async_read_parquet.rs index f600cd0d11e3..e59cad8055cb 100644 --- a/parquet/examples/async_read_parquet.rs +++ b/parquet/examples/async_read_parquet.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. +use arrow::compute::kernels::cmp::eq; use arrow::util::pretty::print_batches; +use arrow_array::{Int32Array, Scalar}; use futures::TryStreamExt; use parquet::arrow::arrow_reader::{ArrowPredicateFn, RowFilter}; use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask}; @@ -44,9 +46,10 @@ async fn main() -> Result<()> { // Highlight: set `RowFilter`, it'll push down filter predicates to skip IO and decode. // For more specific usage: please refer to https://github.com/apache/arrow-datafusion/blob/master/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs. + let scalar = Int32Array::from(vec![1]); let filter = ArrowPredicateFn::new( ProjectionMask::roots(file_metadata.schema_descr(), [0]), - |record_batch| arrow::compute::eq_dyn_scalar(record_batch.column(0), 1), + move |record_batch| eq(record_batch.column(0), &Scalar::new(&scalar)), ); let row_filter = RowFilter::new(vec![Box::new(filter)]); builder = builder.with_row_filter(row_filter); diff --git a/parquet/regen.sh b/parquet/regen.sh new file mode 100755 index 000000000000..b8c3549e2324 --- /dev/null +++ b/parquet/regen.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +REVISION=aeae80660c1d0c97314e9da837de1abdebd49c37 + +SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]:-$0}")" && pwd)" + +docker run -v $SOURCE_DIR:/thrift/src -it archlinux pacman -Sy --noconfirm thrift && \ + wget https://raw.githubusercontent.com/apache/parquet-format/$REVISION/src/main/thrift/parquet.thrift -O /tmp/parquet.thrift && \ + thrift --gen rs /tmp/parquet.thrift && \ + echo "Removing TProcessor" && \ + sed -i '/use thrift::server::TProcessor;/d' parquet.rs && \ + echo "Replacing TSerializable" && \ + sed -i 's/impl TSerializable for/impl crate::thrift::TSerializable for/g' parquet.rs && \ + echo "Rewriting write_to_out_protocol" && \ + sed -i 's/fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol)/fn write_to_out_protocol(\&self, o_prot: \&mut T)/g' parquet.rs && \ + echo "Rewriting read_from_in_protocol" && \ + sed -i 's/fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol)/fn read_from_in_protocol(i_prot: \&mut T)/g' parquet.rs && \ + mv parquet.rs src/format.rs diff --git a/parquet/src/arrow/array_reader/byte_array.rs b/parquet/src/arrow/array_reader/byte_array.rs index 43db658d9324..4612f816146a 100644 --- a/parquet/src/arrow/array_reader/byte_array.rs +++ b/parquet/src/arrow/array_reader/byte_array.rs @@ -636,7 +636,7 @@ mod tests { assert_eq!(decoder.read(&mut output, 4..8).unwrap(), 0); - let valid = vec![false, false, true, true, false, true, true, false, false]; + let valid = [false, false, true, true, false, true, true, false, false]; let valid_buffer = Buffer::from_iter(valid.iter().cloned()); output.pad_nulls(0, 4, valid.len(), valid_buffer.as_slice()); @@ -690,7 +690,7 @@ mod tests { assert_eq!(decoder.read(&mut output, 4..8).unwrap(), 0); - let valid = vec![false, false, true, true, false, false]; + let valid = [false, false, true, true, false, false]; let valid_buffer = Buffer::from_iter(valid.iter().cloned()); output.pad_nulls(0, 2, valid.len(), valid_buffer.as_slice()); diff --git a/parquet/src/arrow/array_reader/byte_array_dictionary.rs b/parquet/src/arrow/array_reader/byte_array_dictionary.rs index 763a6ccee2c3..841f5a95fd4e 100644 --- a/parquet/src/arrow/array_reader/byte_array_dictionary.rs +++ b/parquet/src/arrow/array_reader/byte_array_dictionary.rs @@ -510,7 +510,7 @@ mod tests { assert_eq!(decoder.read(&mut output, 4..5).unwrap(), 1); assert_eq!(decoder.skip_values(4).unwrap(), 0); - let valid = vec![true, true, true, true, true]; + let valid = [true, true, true, true, true]; let valid_buffer = Buffer::from_iter(valid.iter().cloned()); output.pad_nulls(0, 5, 5, valid_buffer.as_slice()); diff --git a/parquet/src/arrow/array_reader/mod.rs b/parquet/src/arrow/array_reader/mod.rs index 1e781fb73ce5..a4ee5040590e 100644 --- a/parquet/src/arrow/array_reader/mod.rs +++ b/parquet/src/arrow/array_reader/mod.rs @@ -118,49 +118,6 @@ impl RowGroups for Arc { } } -pub(crate) struct FileReaderRowGroups { - /// The underling file reader - reader: Arc, - /// Optional list of row group indices to scan - row_groups: Option>, -} - -impl FileReaderRowGroups { - /// Creates a new [`RowGroups`] from a `FileReader` and an optional - /// list of row group indexes to scan - pub fn new(reader: Arc, row_groups: Option>) -> Self { - Self { reader, row_groups } - } -} - -impl RowGroups for FileReaderRowGroups { - fn num_rows(&self) -> usize { - match &self.row_groups { - None => self.reader.metadata().file_metadata().num_rows() as usize, - Some(row_groups) => { - let meta = self.reader.metadata().row_groups(); - row_groups - .iter() - .map(|x| meta[*x].num_rows() 
as usize) - .sum() - } - } - } - - fn column_chunks(&self, i: usize) -> Result> { - let iterator = match &self.row_groups { - Some(row_groups) => FilePageIterator::with_row_groups( - i, - Box::new(row_groups.clone().into_iter()), - Arc::clone(&self.reader), - )?, - None => FilePageIterator::new(i, Arc::clone(&self.reader))?, - }; - - Ok(Box::new(iterator)) - } -} - /// Uses `record_reader` to read up to `batch_size` records from `pages` /// /// Returns the number of records read, which can be less than `batch_size` if @@ -195,7 +152,7 @@ where Ok(records_read) } -/// Uses `record_reader` to skip up to `batch_size` records from`pages` +/// Uses `record_reader` to skip up to `batch_size` records from `pages` /// /// Returns the number of records skipped, which can be less than `batch_size` if /// pages is exhausted diff --git a/parquet/src/arrow/arrow_reader/mod.rs b/parquet/src/arrow/arrow_reader/mod.rs index 988738dac6ac..2acc0faf130f 100644 --- a/parquet/src/arrow/arrow_reader/mod.rs +++ b/parquet/src/arrow/arrow_reader/mod.rs @@ -26,19 +26,21 @@ use arrow_array::{RecordBatch, RecordBatchReader}; use arrow_schema::{ArrowError, DataType as ArrowType, Schema, SchemaRef}; use arrow_select::filter::prep_null_mask_filter; -use crate::arrow::array_reader::{build_array_reader, ArrayReader, FileReaderRowGroups}; +use crate::arrow::array_reader::{build_array_reader, ArrayReader}; use crate::arrow::schema::{parquet_to_arrow_schema_and_fields, ParquetField}; use crate::arrow::{FieldLevels, ProjectionMask}; use crate::errors::{ParquetError, Result}; use crate::file::metadata::ParquetMetaData; -use crate::file::reader::{ChunkReader, SerializedFileReader}; -use crate::file::serialized_reader::ReadOptionsBuilder; +use crate::file::reader::{ChunkReader, SerializedPageReader}; use crate::schema::types::SchemaDescriptor; mod filter; mod selection; pub use crate::arrow::array_reader::RowGroups; +use crate::column::page::{PageIterator, PageReader}; +use crate::file::footer; +use crate::file::page_index::index_reader; pub use filter::{ArrowPredicate, ArrowPredicateFn, RowFilter}; pub use selection::{RowSelection, RowSelector}; @@ -57,7 +59,7 @@ pub struct ArrowReaderBuilder { pub(crate) schema: SchemaRef, - pub(crate) fields: Option, + pub(crate) fields: Option>, pub(crate) batch_size: usize, @@ -75,27 +77,12 @@ pub struct ArrowReaderBuilder { } impl ArrowReaderBuilder { - pub(crate) fn new_builder( - input: T, - metadata: Arc, - options: ArrowReaderOptions, - ) -> Result { - let kv_metadata = match options.skip_arrow_metadata { - true => None, - false => metadata.file_metadata().key_value_metadata(), - }; - - let (schema, fields) = parquet_to_arrow_schema_and_fields( - metadata.file_metadata().schema_descr(), - ProjectionMask::all(), - kv_metadata, - )?; - - Ok(Self { + pub(crate) fn new_builder(input: T, metadata: ArrowReaderMetadata) -> Self { + Self { input, - metadata, - schema: Arc::new(schema), - fields, + metadata: metadata.metadata, + schema: metadata.schema, + fields: metadata.fields, batch_size: 1024, row_groups: None, projection: ProjectionMask::all(), @@ -103,7 +90,7 @@ impl ArrowReaderBuilder { selection: None, limit: None, offset: None, - }) + } } /// Returns a reference to the [`ParquetMetaData`] for this parquet file @@ -234,48 +221,186 @@ impl ArrowReaderOptions { } } +/// The cheaply clone-able metadata necessary to construct a [`ArrowReaderBuilder`] +/// +/// This allows loading the metadata for a file once and then using this to construct +/// multiple separate readers, for example, to 
distribute readers across multiple threads +#[derive(Debug, Clone)] +pub struct ArrowReaderMetadata { + pub(crate) metadata: Arc, + + pub(crate) schema: SchemaRef, + + pub(crate) fields: Option>, +} + +impl ArrowReaderMetadata { + /// Loads [`ArrowReaderMetadata`] from the provided [`ChunkReader`] + /// + /// See [`ParquetRecordBatchReaderBuilder::new_with_metadata`] for how this can be used + pub fn load(reader: &T, options: ArrowReaderOptions) -> Result { + let mut metadata = footer::parse_metadata(reader)?; + if options.page_index { + let column_index = metadata + .row_groups() + .iter() + .map(|rg| index_reader::read_columns_indexes(reader, rg.columns())) + .collect::>>()?; + metadata.set_column_index(Some(column_index)); + + let offset_index = metadata + .row_groups() + .iter() + .map(|rg| index_reader::read_pages_locations(reader, rg.columns())) + .collect::>>()?; + + metadata.set_offset_index(Some(offset_index)) + } + Self::try_new(Arc::new(metadata), options) + } + + pub(crate) fn try_new( + metadata: Arc, + options: ArrowReaderOptions, + ) -> Result { + let kv_metadata = match options.skip_arrow_metadata { + true => None, + false => metadata.file_metadata().key_value_metadata(), + }; + + let (schema, fields) = parquet_to_arrow_schema_and_fields( + metadata.file_metadata().schema_descr(), + ProjectionMask::all(), + kv_metadata, + )?; + + Ok(Self { + metadata, + schema: Arc::new(schema), + fields: fields.map(Arc::new), + }) + } + + /// Returns a reference to the [`ParquetMetaData`] for this parquet file + pub fn metadata(&self) -> &Arc { + &self.metadata + } + + /// Returns the parquet [`SchemaDescriptor`] for this parquet file + pub fn parquet_schema(&self) -> &SchemaDescriptor { + self.metadata.file_metadata().schema_descr() + } + + /// Returns the arrow [`SchemaRef`] for this parquet file + pub fn schema(&self) -> &SchemaRef { + &self.schema + } +} + #[doc(hidden)] /// A newtype used within [`ReaderOptionsBuilder`] to distinguish sync readers from async -pub struct SyncReader(SerializedFileReader); +pub struct SyncReader(T); /// A synchronous builder used to construct [`ParquetRecordBatchReader`] for a file /// /// For an async API see [`crate::arrow::async_reader::ParquetRecordBatchStreamBuilder`] +/// +/// See [`ArrowReaderBuilder`] for additional member functions pub type ParquetRecordBatchReaderBuilder = ArrowReaderBuilder>; -impl ArrowReaderBuilder> { +impl ParquetRecordBatchReaderBuilder { /// Create a new [`ParquetRecordBatchReaderBuilder`] + /// + /// ``` + /// # use std::sync::Arc; + /// # use bytes::Bytes; + /// # use arrow_array::{Int32Array, RecordBatch}; + /// # use arrow_schema::{DataType, Field, Schema}; + /// # use parquet::arrow::arrow_reader::{ParquetRecordBatchReader, ParquetRecordBatchReaderBuilder}; + /// # use parquet::arrow::ArrowWriter; + /// # let mut file: Vec = Vec::with_capacity(1024); + /// # let schema = Arc::new(Schema::new(vec![Field::new("i32", DataType::Int32, false)])); + /// # let mut writer = ArrowWriter::try_new(&mut file, schema.clone(), None).unwrap(); + /// # let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))]).unwrap(); + /// # writer.write(&batch).unwrap(); + /// # writer.close().unwrap(); + /// # let file = Bytes::from(file); + /// # + /// let mut builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap(); + /// + /// // Inspect metadata + /// assert_eq!(builder.metadata().num_row_groups(), 1); + /// + /// // Construct reader + /// let mut reader: ParquetRecordBatchReader = 
builder.with_row_groups(vec![0]).build().unwrap(); + /// + /// // Read data + /// let _batch = reader.next().unwrap().unwrap(); + /// ``` pub fn try_new(reader: T) -> Result { Self::try_new_with_options(reader, Default::default()) } /// Create a new [`ParquetRecordBatchReaderBuilder`] with [`ArrowReaderOptions`] pub fn try_new_with_options(reader: T, options: ArrowReaderOptions) -> Result { - let reader = match options.page_index { - true => { - let read_options = ReadOptionsBuilder::new().with_page_index().build(); - SerializedFileReader::new_with_options(reader, read_options)? - } - false => SerializedFileReader::new(reader)?, - }; + let metadata = ArrowReaderMetadata::load(&reader, options)?; + Ok(Self::new_with_metadata(reader, metadata)) + } - let metadata = Arc::clone(reader.metadata_ref()); - Self::new_builder(SyncReader(reader), metadata, options) + /// Create a [`ParquetRecordBatchReaderBuilder`] from the provided [`ArrowReaderMetadata`] + /// + /// This allows loading metadata once and using it to create multiple builders with + /// potentially different settings + /// + /// ``` + /// # use std::fs::metadata; + /// # use std::sync::Arc; + /// # use bytes::Bytes; + /// # use arrow_array::{Int32Array, RecordBatch}; + /// # use arrow_schema::{DataType, Field, Schema}; + /// # use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ParquetRecordBatchReader, ParquetRecordBatchReaderBuilder}; + /// # use parquet::arrow::ArrowWriter; + /// # let mut file: Vec = Vec::with_capacity(1024); + /// # let schema = Arc::new(Schema::new(vec![Field::new("i32", DataType::Int32, false)])); + /// # let mut writer = ArrowWriter::try_new(&mut file, schema.clone(), None).unwrap(); + /// # let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))]).unwrap(); + /// # writer.write(&batch).unwrap(); + /// # writer.close().unwrap(); + /// # let file = Bytes::from(file); + /// # + /// let metadata = ArrowReaderMetadata::load(&file, Default::default()).unwrap(); + /// let mut a = ParquetRecordBatchReaderBuilder::new_with_metadata(file.clone(), metadata.clone()).build().unwrap(); + /// let mut b = ParquetRecordBatchReaderBuilder::new_with_metadata(file, metadata).build().unwrap(); + /// + /// // Should be able to read from both in parallel + /// assert_eq!(a.next().unwrap().unwrap(), b.next().unwrap().unwrap()); + /// ``` + pub fn new_with_metadata(input: T, metadata: ArrowReaderMetadata) -> Self { + Self::new_builder(SyncReader(input), metadata) } /// Build a [`ParquetRecordBatchReader`] /// /// Note: this will eagerly evaluate any `RowFilter` before returning pub fn build(self) -> Result { - let reader = FileReaderRowGroups::new(Arc::new(self.input.0), self.row_groups); - - let mut filter = self.filter; - let mut selection = self.selection; - // Try to avoid allocate large buffer let batch_size = self .batch_size .min(self.metadata.file_metadata().num_rows() as usize); + + let row_groups = self + .row_groups + .unwrap_or_else(|| (0..self.metadata.num_row_groups()).collect()); + + let reader = ReaderRowGroups { + reader: Arc::new(self.input.0), + metadata: self.metadata, + row_groups, + }; + + let mut filter = self.filter; + let mut selection = self.selection; + if let Some(filter) = filter.as_mut() { for predicate in filter.predicates.iter_mut() { if !selects_any(selection.as_ref()) { @@ -283,7 +408,7 @@ impl ArrowReaderBuilder> { } let array_reader = build_array_reader( - self.fields.as_ref(), + self.fields.as_deref(), predicate.projection(), &reader, )?; @@ -298,7 +423,7 @@ 
impl ArrowReaderBuilder> { } let array_reader = - build_array_reader(self.fields.as_ref(), &self.projection, &reader)?; + build_array_reader(self.fields.as_deref(), &self.projection, &reader)?; // If selection is empty, truncate @@ -313,6 +438,63 @@ impl ArrowReaderBuilder> { } } +struct ReaderRowGroups { + reader: Arc, + + metadata: Arc, + /// Optional list of row group indices to scan + row_groups: Vec, +} + +impl RowGroups for ReaderRowGroups { + fn num_rows(&self) -> usize { + let meta = self.metadata.row_groups(); + self.row_groups + .iter() + .map(|x| meta[*x].num_rows() as usize) + .sum() + } + + fn column_chunks(&self, i: usize) -> Result> { + Ok(Box::new(ReaderPageIterator { + column_idx: i, + reader: self.reader.clone(), + metadata: self.metadata.clone(), + row_groups: self.row_groups.clone().into_iter(), + })) + } +} + +struct ReaderPageIterator { + reader: Arc, + column_idx: usize, + row_groups: std::vec::IntoIter, + metadata: Arc, +} + +impl Iterator for ReaderPageIterator { + type Item = Result>; + + fn next(&mut self) -> Option { + let rg_idx = self.row_groups.next()?; + let rg = self.metadata.row_group(rg_idx); + let meta = rg.column(self.column_idx); + let offset_index = self.metadata.offset_index(); + // `offset_index` may not exist and `i[rg_idx]` will be empty. + // To avoid an `i[rg_idx][self.column_idx]` panic, we need to filter out empty `i[rg_idx]`. + let page_locations = offset_index + .filter(|i| !i[rg_idx].is_empty()) + .map(|i| i[rg_idx][self.column_idx].clone()); + let total_rows = rg.num_rows() as usize; + let reader = self.reader.clone(); + + let ret = SerializedPageReader::new(reader, meta, total_rows, page_locations); + Some(ret.map(|x| Box::new(x) as _)) + } +} + +impl PageIterator for ReaderPageIterator {} + /// An `Iterator>` that yields [`RecordBatch`] /// read from a parquet data source pub struct ParquetRecordBatchReader { @@ -1493,7 +1675,7 @@ mod tests { _ => -1, }; - let mut fields = vec![Arc::new( + let fields = vec![Arc::new( Type::primitive_type_builder("leaf", T::get_physical_type()) .with_repetition(repetition) .with_converted_type(converted_type) @@ -1504,7 +1686,7 @@ mod tests { let schema = Arc::new( Type::group_type_builder("test_schema") - .with_fields(&mut fields) + .with_fields(fields) .build() .unwrap(), ); @@ -1850,7 +2032,7 @@ mod tests { #[test] fn test_dictionary_preservation() { - let mut fields = vec![Arc::new( + let fields = vec![Arc::new( Type::primitive_type_builder("leaf", PhysicalType::BYTE_ARRAY) .with_repetition(Repetition::OPTIONAL) .with_converted_type(ConvertedType::UTF8) @@ -1860,7 +2042,7 @@ let schema = Arc::new( Type::group_type_builder("test_schema") - .with_fields(&mut fields) + .with_fields(fields) .build() .unwrap(), ); @@ -2303,6 +2485,43 @@ mod tests { assert_eq!(reader.batch_size, num_rows as usize); } + #[test] + fn test_read_with_page_index_enabled() { + let testdata = arrow::util::test_util::parquet_test_data(); + + { + // `alltypes_tiny_pages.parquet` has page index + let path = format!("{testdata}/alltypes_tiny_pages.parquet"); + let test_file = File::open(path).unwrap(); + let builder = ParquetRecordBatchReaderBuilder::try_new_with_options( + test_file, + ArrowReaderOptions::new().with_page_index(true), + ) + .unwrap(); + assert!(!builder.metadata().offset_index().unwrap()[0].is_empty()); + let reader = builder.build().unwrap(); + let batches = reader.collect::, _>>().unwrap(); + assert_eq!(batches.len(), 8); + } + + { + // `alltypes_plain.parquet` doesn't
have page index + let path = format!("{testdata}/alltypes_plain.parquet"); + let test_file = File::open(path).unwrap(); + let builder = ParquetRecordBatchReaderBuilder::try_new_with_options( + test_file, + ArrowReaderOptions::new().with_page_index(true), + ) + .unwrap(); + // Although `Vec>` of each row group is empty, + // we should read the file successfully. + assert!(builder.metadata().offset_index().unwrap()[0].is_empty()); + let reader = builder.build().unwrap(); + let batches = reader.collect::, _>>().unwrap(); + assert_eq!(batches.len(), 1); + } + } + #[test] fn test_raw_repetition() { const MESSAGE_TYPE: &str = " diff --git a/parquet/src/arrow/arrow_writer/levels.rs b/parquet/src/arrow/arrow_writer/levels.rs index 47b01890301e..4a0bd551e1f9 100644 --- a/parquet/src/arrow/arrow_writer/levels.rs +++ b/parquet/src/arrow/arrow_writer/levels.rs @@ -42,19 +42,20 @@ use crate::errors::{ParquetError, Result}; use arrow_array::cast::AsArray; -use arrow_array::{Array, ArrayRef, FixedSizeListArray, OffsetSizeTrait, StructArray}; -use arrow_buffer::NullBuffer; +use arrow_array::{Array, ArrayRef, OffsetSizeTrait}; +use arrow_buffer::{NullBuffer, OffsetBuffer}; use arrow_schema::{DataType, Field}; use std::ops::Range; +use std::sync::Arc; -/// Performs a depth-first scan of the children of `array`, constructing [`LevelInfo`] +/// Performs a depth-first scan of the children of `array`, constructing [`ArrayLevels`] /// for each leaf column encountered pub(crate) fn calculate_array_levels( array: &ArrayRef, field: &Field, -) -> Result> { - let mut builder = LevelInfoBuilder::try_new(field, Default::default())?; - builder.write(array, 0..array.len()); +) -> Result> { + let mut builder = LevelInfoBuilder::try_new(field, Default::default(), array)?; + builder.write(0..array.len()); Ok(builder.finish()) } @@ -102,31 +103,57 @@ struct LevelContext { def_level: i16, } -/// A helper to construct [`LevelInfo`] from a potentially nested [`Field`] +/// A helper to construct [`ArrayLevels`] from a potentially nested [`Field`] enum LevelInfoBuilder { /// A primitive, leaf array - Primitive(LevelInfo), - /// A list array, contains the [`LevelInfoBuilder`] of the child and - /// the [`LevelContext`] of this list - List(Box, LevelContext), - /// A list array, contains the [`LevelInfoBuilder`] of its children and - /// the [`LevelContext`] of this struct array - Struct(Vec, LevelContext), + Primitive(ArrayLevels), + /// A list array + List( + Box, // Child Values + LevelContext, // Context + OffsetBuffer, // Offsets + Option, // Nulls + ), + /// A large list array + LargeList( + Box, // Child Values + LevelContext, // Context + OffsetBuffer, // Offsets + Option, // Nulls + ), + /// A fixed size list array + FixedSizeList( + Box, // Values + LevelContext, // Context + usize, // List Size + Option, // Nulls + ), + /// A struct array + Struct(Vec, LevelContext, Option), } impl LevelInfoBuilder { /// Create a new [`LevelInfoBuilder`] for the given [`Field`] and parent [`LevelContext`] - fn try_new(field: &Field, parent_ctx: LevelContext) -> Result { - match field.data_type() { - d if is_leaf(d) => Ok(Self::Primitive(LevelInfo::new( - parent_ctx, - field.is_nullable(), - ))), - DataType::Dictionary(_, v) if is_leaf(v.as_ref()) => Ok(Self::Primitive( - LevelInfo::new(parent_ctx, field.is_nullable()), - )), + fn try_new( + field: &Field, + parent_ctx: LevelContext, + array: &ArrayRef, + ) -> Result { + assert_eq!(field.data_type(), array.data_type()); + let is_nullable = field.is_nullable(); + + match array.data_type() 
{ + d if is_leaf(d) => { + let levels = ArrayLevels::new(parent_ctx, is_nullable, array.clone()); + Ok(Self::Primitive(levels)) + } + DataType::Dictionary(_, v) if is_leaf(v.as_ref()) => { + let levels = ArrayLevels::new(parent_ctx, is_nullable, array.clone()); + Ok(Self::Primitive(levels)) + } DataType::Struct(children) => { - let def_level = match field.is_nullable() { + let array = array.as_struct(); + let def_level = match is_nullable { true => parent_ctx.def_level + 1, false => parent_ctx.def_level, }; @@ -138,16 +165,17 @@ impl LevelInfoBuilder { let children = children .iter() - .map(|f| Self::try_new(f, ctx)) + .zip(array.columns()) + .map(|(f, a)| Self::try_new(f, ctx, a)) .collect::>()?; - Ok(Self::Struct(children, ctx)) + Ok(Self::Struct(children, ctx, array.nulls().cloned())) } DataType::List(child) | DataType::LargeList(child) | DataType::Map(child, _) | DataType::FixedSizeList(child, _) => { - let def_level = match field.is_nullable() { + let def_level = match is_nullable { true => parent_ctx.def_level + 2, false => parent_ctx.def_level + 1, }; @@ -157,79 +185,70 @@ impl LevelInfoBuilder { def_level, }; - let child = Self::try_new(child.as_ref(), ctx)?; - Ok(Self::List(Box::new(child), ctx)) + Ok(match field.data_type() { + DataType::List(_) => { + let list = array.as_list(); + let child = Self::try_new(child.as_ref(), ctx, list.values())?; + let offsets = list.offsets().clone(); + Self::List(Box::new(child), ctx, offsets, list.nulls().cloned()) + } + DataType::LargeList(_) => { + let list = array.as_list(); + let child = Self::try_new(child.as_ref(), ctx, list.values())?; + let offsets = list.offsets().clone(); + let nulls = list.nulls().cloned(); + Self::LargeList(Box::new(child), ctx, offsets, nulls) + } + DataType::Map(_, _) => { + let map = array.as_map(); + let entries = Arc::new(map.entries().clone()) as ArrayRef; + let child = Self::try_new(child.as_ref(), ctx, &entries)?; + let offsets = map.offsets().clone(); + Self::List(Box::new(child), ctx, offsets, map.nulls().cloned()) + } + DataType::FixedSizeList(_, size) => { + let list = array.as_fixed_size_list(); + let child = Self::try_new(child.as_ref(), ctx, list.values())?; + let nulls = list.nulls().cloned(); + Self::FixedSizeList(Box::new(child), ctx, *size as _, nulls) + } + _ => unreachable!(), + }) } d => Err(nyi_err!("Datatype {} is not yet supported", d)), } } - /// Finish this [`LevelInfoBuilder`] returning the [`LevelInfo`] for the leaf columns + /// Finish this [`LevelInfoBuilder`] returning the [`ArrayLevels`] for the leaf columns /// as enumerated by a depth-first search - fn finish(self) -> Vec { + fn finish(self) -> Vec { match self { LevelInfoBuilder::Primitive(v) => vec![v], - LevelInfoBuilder::List(v, _) => v.finish(), - LevelInfoBuilder::Struct(v, _) => { + LevelInfoBuilder::List(v, _, _, _) + | LevelInfoBuilder::LargeList(v, _, _, _) + | LevelInfoBuilder::FixedSizeList(v, _, _, _) => v.finish(), + LevelInfoBuilder::Struct(v, _, _) => { v.into_iter().flat_map(|l| l.finish()).collect() } } } /// Given an `array`, write the level data for the elements in `range` - fn write(&mut self, array: &dyn Array, range: Range) { - match array.data_type() { - d if is_leaf(d) => self.write_leaf(array, range), - DataType::Dictionary(_, v) if is_leaf(v.as_ref()) => { - self.write_leaf(array, range) - } - DataType::Struct(_) => { - let array = array.as_struct(); - self.write_struct(array, range) - } - DataType::List(_) => { - let array = array.as_list::(); - self.write_list( - array.value_offsets(), - array.nulls(), - 
array.values(), - range, - ) + fn write(&mut self, range: Range) { + match self { + LevelInfoBuilder::Primitive(info) => Self::write_leaf(info, range), + LevelInfoBuilder::List(child, ctx, offsets, nulls) => { + Self::write_list(child, ctx, offsets, nulls.as_ref(), range) } - DataType::LargeList(_) => { - let array = array.as_list::(); - self.write_list( - array.value_offsets(), - array.nulls(), - array.values(), - range, - ) + LevelInfoBuilder::LargeList(child, ctx, offsets, nulls) => { + Self::write_list(child, ctx, offsets, nulls.as_ref(), range) } - DataType::Map(_, _) => { - let array = array.as_map(); - // A Map is just as ListArray with a StructArray child, we therefore - // treat it as such to avoid code duplication - self.write_list( - array.value_offsets(), - array.nulls(), - array.entries(), - range, - ) + LevelInfoBuilder::FixedSizeList(child, ctx, size, nulls) => { + Self::write_fixed_size_list(child, ctx, *size, nulls.as_ref(), range) } - &DataType::FixedSizeList(_, size) => { - let array = array - .as_any() - .downcast_ref::() - .expect("unable to get fixed-size list array"); - - self.write_fixed_size_list( - size as usize, - array.nulls(), - array.values(), - range, - ) + LevelInfoBuilder::Struct(children, ctx, nulls) => { + Self::write_struct(children, ctx, nulls.as_ref(), range) } - _ => unreachable!(), } } @@ -237,22 +256,17 @@ impl LevelInfoBuilder { /// /// Note: MapArrays are `ListArray` under the hood and so are dispatched to this method fn write_list( - &mut self, + child: &mut LevelInfoBuilder, + ctx: &LevelContext, offsets: &[O], nulls: Option<&NullBuffer>, - values: &dyn Array, range: Range, ) { - let (child, ctx) = match self { - Self::List(child, ctx) => (child, ctx), - _ => unreachable!(), - }; - let offsets = &offsets[range.start..range.end + 1]; let write_non_null_slice = |child: &mut LevelInfoBuilder, start_idx: usize, end_idx: usize| { - child.write(values, start_idx..end_idx); + child.write(start_idx..end_idx); child.visit_leaves(|leaf| { let rep_levels = leaf.rep_levels.as_mut().unwrap(); let mut rev = rep_levels.iter_mut().rev(); @@ -324,12 +338,12 @@ impl LevelInfoBuilder { } /// Write `range` elements from StructArray `array` - fn write_struct(&mut self, array: &StructArray, range: Range) { - let (children, ctx) = match self { - Self::Struct(children, ctx) => (children, ctx), - _ => unreachable!(), - }; - + fn write_struct( + children: &mut [LevelInfoBuilder], + ctx: &LevelContext, + nulls: Option<&NullBuffer>, + range: Range, + ) { let write_null = |children: &mut [LevelInfoBuilder], range: Range| { for child in children { child.visit_leaves(|info| { @@ -346,12 +360,12 @@ impl LevelInfoBuilder { }; let write_non_null = |children: &mut [LevelInfoBuilder], range: Range| { - for (child_array, child) in array.columns().iter().zip(children) { - child.write(child_array, range.clone()) + for child in children { + child.write(range.clone()) } }; - match array.nulls() { + match nulls { Some(validity) => { let mut last_non_null_idx = None; let mut last_null_idx = None; @@ -388,22 +402,17 @@ impl LevelInfoBuilder { /// Write `range` elements from FixedSizeListArray with child data `values` and null bitmap `nulls`. 
fn write_fixed_size_list( - &mut self, + child: &mut LevelInfoBuilder, + ctx: &LevelContext, fixed_size: usize, nulls: Option<&NullBuffer>, - values: &dyn Array, range: Range, ) { - let (child, ctx) = match self { - Self::List(child, ctx) => (child, ctx), - _ => unreachable!(), - }; - let write_non_null = |child: &mut LevelInfoBuilder, start_idx: usize, end_idx: usize| { let values_start = start_idx * fixed_size; let values_end = end_idx * fixed_size; - child.write(values, values_start..values_end); + child.write(values_start..values_end); child.visit_leaves(|leaf| { let rep_levels = leaf.rep_levels.as_mut().unwrap(); @@ -481,12 +490,7 @@ impl LevelInfoBuilder { } /// Write a primitive array, as defined by [`is_leaf`] - fn write_leaf(&mut self, array: &dyn Array, range: Range) { - let info = match self { - Self::Primitive(info) => info, - _ => unreachable!(), - }; - + fn write_leaf(info: &mut ArrayLevels, range: Range) { let len = range.end - range.start; match &mut info.def_levels { @@ -494,7 +498,7 @@ impl LevelInfoBuilder { def_levels.reserve(len); info.non_null_indices.reserve(len); - match array.nulls() { + match info.array.logical_nulls() { Some(nulls) => { // TODO: Faster bitmask iteration (#1757) for i in range { @@ -523,11 +527,13 @@ impl LevelInfoBuilder { } /// Visits all children of this node in depth first order - fn visit_leaves(&mut self, visit: impl Fn(&mut LevelInfo) + Copy) { + fn visit_leaves(&mut self, visit: impl Fn(&mut ArrayLevels) + Copy) { match self { LevelInfoBuilder::Primitive(info) => visit(info), - LevelInfoBuilder::List(c, _) => c.visit_leaves(visit), - LevelInfoBuilder::Struct(children, _) => { + LevelInfoBuilder::List(c, _, _, _) + | LevelInfoBuilder::LargeList(c, _, _, _) + | LevelInfoBuilder::FixedSizeList(c, _, _, _) => c.visit_leaves(visit), + LevelInfoBuilder::Struct(children, _, _) => { for c in children { c.visit_leaves(visit) } @@ -537,8 +543,8 @@ impl LevelInfoBuilder { } /// The data necessary to write a primitive Arrow array to parquet, taking into account /// any non-primitive parents it may have in the arrow representation -#[derive(Debug, Eq, PartialEq, Clone)] -pub(crate) struct LevelInfo { +#[derive(Debug, Clone)] +pub(crate) struct ArrayLevels { /// Array's definition levels /// /// Present if `max_def_level != 0` @@ -558,10 +564,25 @@ pub(crate) struct LevelInfo { /// The maximum repetition for this leaf column max_rep_level: i16, + + /// The arrow array + array: ArrayRef, } -impl LevelInfo { - fn new(ctx: LevelContext, is_nullable: bool) -> Self { +impl PartialEq for ArrayLevels { + fn eq(&self, other: &Self) -> bool { + self.def_levels == other.def_levels + && self.rep_levels == other.rep_levels + && self.non_null_indices == other.non_null_indices + && self.max_def_level == other.max_def_level + && self.max_rep_level == other.max_rep_level + && self.array.as_ref() == other.array.as_ref() + } +} +impl Eq for ArrayLevels {} + +impl ArrayLevels { + fn new(ctx: LevelContext, is_nullable: bool, array: ArrayRef) -> Self { let max_rep_level = ctx.rep_level; let max_def_level = match is_nullable { true => ctx.def_level + 1, @@ -574,9 +595,14 @@ impl LevelInfo { non_null_indices: vec![], max_def_level, max_rep_level, + array, } } + pub fn array(&self) -> &ArrayRef { + &self.array + } + pub fn def_levels(&self) -> Option<&[i16]> { self.def_levels.as_deref() } @@ -597,6 +623,7 @@ mod tests { use std::sync::Arc; use arrow_array::builder::*; + use arrow_array::cast::AsArray; use arrow_array::types::Int32Type; use arrow_array::*; use 
arrow_buffer::{Buffer, ToByteSlice}; @@ -622,7 +649,7 @@ mod tests { let inner_list = ArrayDataBuilder::new(inner_type) .len(4) .add_buffer(offsets) - .add_child_data(primitives.into_data()) + .add_child_data(primitives.to_data()) .build() .unwrap(); @@ -638,12 +665,13 @@ mod tests { let levels = calculate_array_levels(&outer_list, &outer_field).unwrap(); assert_eq!(levels.len(), 1); - let expected = LevelInfo { + let expected = ArrayLevels { def_levels: Some(vec![2; 10]), rep_levels: Some(vec![0, 2, 2, 1, 2, 2, 2, 0, 1, 2]), non_null_indices: vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9], max_def_level: 2, max_rep_level: 2, + array: Arc::new(primitives), }; assert_eq!(&levels[0], &expected); } @@ -657,12 +685,13 @@ mod tests { let levels = calculate_array_levels(&array, &field).unwrap(); assert_eq!(levels.len(), 1); - let expected_levels = LevelInfo { + let expected_levels = ArrayLevels { def_levels: None, rep_levels: None, non_null_indices: (0..10).collect(), max_def_level: 0, max_rep_level: 0, + array, }; assert_eq!(&levels[0], &expected_levels); } @@ -682,12 +711,13 @@ mod tests { let levels = calculate_array_levels(&array, &field).unwrap(); assert_eq!(levels.len(), 1); - let expected_levels = LevelInfo { + let expected_levels = ArrayLevels { def_levels: Some(vec![1, 0, 1, 1, 0]), rep_levels: None, non_null_indices: vec![0, 2, 3], max_def_level: 1, max_rep_level: 0, + array, }; assert_eq!(&levels[0], &expected_levels); } @@ -706,7 +736,7 @@ mod tests { let list = ArrayDataBuilder::new(list_type.clone()) .len(5) .add_buffer(offsets) - .add_child_data(leaf_array.into_data()) + .add_child_data(leaf_array.to_data()) .build() .unwrap(); let list = make_array(list); @@ -715,12 +745,13 @@ mod tests { let levels = calculate_array_levels(&list, &list_field).unwrap(); assert_eq!(levels.len(), 1); - let expected_levels = LevelInfo { + let expected_levels = ArrayLevels { def_levels: Some(vec![1; 5]), rep_levels: Some(vec![0; 5]), non_null_indices: (0..5).collect(), max_def_level: 1, max_rep_level: 1, + array: Arc::new(leaf_array), }; assert_eq!(&levels[0], &expected_levels); @@ -737,7 +768,7 @@ mod tests { let list = ArrayDataBuilder::new(list_type.clone()) .len(5) .add_buffer(offsets) - .add_child_data(leaf_array.into_data()) + .add_child_data(leaf_array.to_data()) .null_bit_buffer(Some(Buffer::from([0b00011101]))) .build() .unwrap(); @@ -747,12 +778,13 @@ mod tests { let levels = calculate_array_levels(&list, &list_field).unwrap(); assert_eq!(levels.len(), 1); - let expected_levels = LevelInfo { + let expected_levels = ArrayLevels { def_levels: Some(vec![2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2]), rep_levels: Some(vec![0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1]), non_null_indices: (0..11).collect(), max_def_level: 2, max_rep_level: 1, + array: Arc::new(leaf_array), }; assert_eq!(&levels[0], &expected_levels); } @@ -778,7 +810,7 @@ mod tests { let list_type = DataType::List(Arc::new(leaf_field)); let list = ArrayData::builder(list_type.clone()) .len(5) - .add_child_data(leaf.into_data()) + .add_child_data(leaf.to_data()) .add_buffer(Buffer::from_iter([0_i32, 2, 2, 4, 8, 11])) .build() .unwrap(); @@ -795,12 +827,13 @@ mod tests { let levels = calculate_array_levels(&array, &struct_field).unwrap(); assert_eq!(levels.len(), 1); - let expected_levels = LevelInfo { + let expected_levels = ArrayLevels { def_levels: Some(vec![0, 2, 0, 3, 3, 3, 3, 3, 3, 3]), rep_levels: Some(vec![0, 0, 0, 0, 1, 1, 1, 0, 1, 1]), non_null_indices: (4..11).collect(), max_def_level: 3, max_rep_level: 1, + array: Arc::new(leaf), }; 
assert_eq!(&levels[0], &expected_levels); @@ -820,7 +853,7 @@ mod tests { let offsets = Buffer::from_iter([0_i32, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22]); let l1 = ArrayData::builder(l1_type.clone()) .len(11) - .add_child_data(leaf.into_data()) + .add_child_data(leaf.to_data()) .add_buffer(offsets) .build() .unwrap(); @@ -840,7 +873,7 @@ mod tests { let levels = calculate_array_levels(&l2, &l2_field).unwrap(); assert_eq!(levels.len(), 1); - let expected_levels = LevelInfo { + let expected_levels = ArrayLevels { def_levels: Some(vec![ 5, 5, 5, 5, 1, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ]), @@ -850,6 +883,7 @@ mod tests { non_null_indices: (0..22).collect(), max_def_level: 5, max_rep_level: 2, + array: Arc::new(leaf), }; assert_eq!(&levels[0], &expected_levels); @@ -871,7 +905,7 @@ mod tests { let list = ArrayData::builder(list_type.clone()) .len(4) .add_buffer(Buffer::from_iter(0_i32..5)) - .add_child_data(leaf.into_data()) + .add_child_data(leaf.to_data()) .build() .unwrap(); let list = make_array(list); @@ -880,12 +914,13 @@ mod tests { let levels = calculate_array_levels(&list, &list_field).unwrap(); assert_eq!(levels.len(), 1); - let expected_levels = LevelInfo { + let expected_levels = ArrayLevels { def_levels: Some(vec![1; 4]), rep_levels: Some(vec![0; 4]), non_null_indices: (0..4).collect(), max_def_level: 1, max_rep_level: 1, + array: Arc::new(leaf), }; assert_eq!(&levels[0], &expected_levels); @@ -898,7 +933,7 @@ mod tests { .len(4) .add_buffer(Buffer::from_iter([0_i32, 0, 3, 5, 7])) .null_bit_buffer(Some(Buffer::from([0b00001110]))) - .add_child_data(leaf.into_data()) + .add_child_data(leaf.to_data()) .build() .unwrap(); let list = make_array(list); @@ -911,12 +946,13 @@ mod tests { let levels = calculate_array_levels(&array, &struct_field).unwrap(); assert_eq!(levels.len(), 1); - let expected_levels = LevelInfo { + let expected_levels = ArrayLevels { def_levels: Some(vec![1, 3, 3, 3, 3, 3, 3, 3]), rep_levels: Some(vec![0, 0, 1, 1, 0, 1, 0, 1]), non_null_indices: (0..7).collect(), max_def_level: 3, max_rep_level: 1, + array: Arc::new(leaf), }; assert_eq!(&levels[0], &expected_levels); @@ -933,7 +969,7 @@ mod tests { let list_1 = ArrayData::builder(list_1_type.clone()) .len(7) .add_buffer(Buffer::from_iter([0_i32, 1, 3, 3, 6, 10, 10, 15])) - .add_child_data(leaf.into_data()) + .add_child_data(leaf.to_data()) .build() .unwrap(); @@ -958,12 +994,13 @@ mod tests { let levels = calculate_array_levels(&array, &struct_field).unwrap(); assert_eq!(levels.len(), 1); - let expected_levels = LevelInfo { + let expected_levels = ArrayLevels { def_levels: Some(vec![1, 5, 5, 5, 4, 5, 5, 5, 5, 5, 5, 5, 4, 5, 5, 5, 5, 5]), rep_levels: Some(vec![0, 0, 1, 2, 1, 0, 2, 2, 1, 2, 2, 2, 0, 1, 2, 2, 2, 2]), non_null_indices: (0..15).collect(), max_def_level: 5, max_rep_level: 2, + array: Arc::new(leaf), }; assert_eq!(&levels[0], &expected_levels); } @@ -980,9 +1017,10 @@ mod tests { // - {a: {b: {c: 6}}} let c = Int32Array::from_iter([Some(1), None, Some(3), None, Some(5), Some(6)]); + let leaf = Arc::new(c) as ArrayRef; let c_field = Arc::new(Field::new("c", DataType::Int32, true)); let b = StructArray::from(( - (vec![(c_field, Arc::new(c) as ArrayRef)]), + (vec![(c_field, leaf.clone())]), Buffer::from([0b00110111]), )); @@ -998,12 +1036,13 @@ mod tests { let levels = calculate_array_levels(&a_array, &a_field).unwrap(); assert_eq!(levels.len(), 1); - let expected_levels = LevelInfo { + let expected_levels = ArrayLevels { def_levels: Some(vec![3, 2, 3, 1, 0, 3]), rep_levels: None, 
non_null_indices: vec![0, 2, 5], max_def_level: 3, max_rep_level: 0, + array: leaf, }; assert_eq!(&levels[0], &expected_levels); } @@ -1020,7 +1059,7 @@ mod tests { .len(5) .add_buffer(a_value_offsets) .null_bit_buffer(Some(Buffer::from(vec![0b00011011]))) - .add_child_data(a_values.into_data()) + .add_child_data(a_values.to_data()) .build() .unwrap(); @@ -1029,21 +1068,21 @@ mod tests { let a = ListArray::from(a_list_data); let item_field = Field::new("item", a_list_type, true); - let mut builder = - LevelInfoBuilder::try_new(&item_field, Default::default()).unwrap(); - builder.write(&a, 2..4); + let mut builder = levels(&item_field, a); + builder.write(2..4); let levels = builder.finish(); assert_eq!(levels.len(), 1); let list_level = levels.get(0).unwrap(); - let expected_level = LevelInfo { + let expected_level = ArrayLevels { def_levels: Some(vec![0, 3, 3, 3]), rep_levels: Some(vec![0, 0, 1, 1]), non_null_indices: vec![3, 4, 5], max_def_level: 3, max_rep_level: 1, + array: Arc::new(a_values), }; assert_eq!(list_level, &expected_level); } @@ -1100,19 +1139,19 @@ mod tests { let g = ListArray::from(g_list_data); let e = StructArray::from(vec![ - (struct_field_f, Arc::new(f) as ArrayRef), + (struct_field_f, Arc::new(f.clone()) as ArrayRef), (struct_field_g, Arc::new(g) as ArrayRef), ]); let c = StructArray::from(vec![ - (struct_field_d, Arc::new(d) as ArrayRef), + (struct_field_d, Arc::new(d.clone()) as ArrayRef), (struct_field_e, Arc::new(e) as ArrayRef), ]); // build a record batch let batch = RecordBatch::try_new( Arc::new(schema), - vec![Arc::new(a), Arc::new(b), Arc::new(c)], + vec![Arc::new(a.clone()), Arc::new(b.clone()), Arc::new(c)], ) .unwrap(); @@ -1132,48 +1171,52 @@ mod tests { // test "a" levels let list_level = levels.get(0).unwrap(); - let expected_level = LevelInfo { + let expected_level = ArrayLevels { def_levels: None, rep_levels: None, non_null_indices: vec![0, 1, 2, 3, 4], max_def_level: 0, max_rep_level: 0, + array: Arc::new(a), }; assert_eq!(list_level, &expected_level); // test "b" levels let list_level = levels.get(1).unwrap(); - let expected_level = LevelInfo { + let expected_level = ArrayLevels { def_levels: Some(vec![1, 0, 0, 1, 1]), rep_levels: None, non_null_indices: vec![0, 3, 4], max_def_level: 1, max_rep_level: 0, + array: Arc::new(b), }; assert_eq!(list_level, &expected_level); // test "d" levels let list_level = levels.get(2).unwrap(); - let expected_level = LevelInfo { + let expected_level = ArrayLevels { def_levels: Some(vec![1, 1, 1, 2, 1]), rep_levels: None, non_null_indices: vec![3], max_def_level: 2, max_rep_level: 0, + array: Arc::new(d), }; assert_eq!(list_level, &expected_level); // test "f" levels let list_level = levels.get(3).unwrap(); - let expected_level = LevelInfo { + let expected_level = ArrayLevels { def_levels: Some(vec![3, 2, 3, 2, 3]), rep_levels: None, non_null_indices: vec![0, 2, 4], max_def_level: 3, max_rep_level: 0, + array: Arc::new(f), }; assert_eq!(list_level, &expected_level); } @@ -1270,27 +1313,31 @@ mod tests { }); assert_eq!(levels.len(), 2); + let map = batch.column(0).as_map(); + // test key levels let list_level = levels.get(0).unwrap(); - let expected_level = LevelInfo { + let expected_level = ArrayLevels { def_levels: Some(vec![1; 7]), rep_levels: Some(vec![0, 1, 0, 1, 0, 1, 1]), non_null_indices: vec![0, 1, 2, 3, 4, 5, 6], max_def_level: 1, max_rep_level: 1, + array: map.keys().clone(), }; assert_eq!(list_level, &expected_level); // test values levels let list_level = levels.get(1).unwrap(); - let expected_level = 
LevelInfo { + let expected_level = ArrayLevels { def_levels: Some(vec![2, 2, 2, 1, 2, 1, 2]), rep_levels: Some(vec![0, 1, 0, 1, 0, 1, 1]), non_null_indices: vec![0, 1, 2, 4, 6], max_def_level: 2, max_rep_level: 1, + array: map.values().clone(), }; assert_eq!(list_level, &expected_level); } @@ -1358,7 +1405,8 @@ mod tests { let array = Arc::new(list_builder.finish()); - let values_len = array.values().len(); + let values = array.values().as_struct().column(0).clone(); + let values_len = values.len(); assert_eq!(values_len, 5); let schema = Arc::new(Schema::new(vec![list_field])); @@ -1368,12 +1416,13 @@ mod tests { let levels = calculate_array_levels(rb.column(0), rb.schema().field(0)).unwrap(); let list_level = &levels[0]; - let expected_level = LevelInfo { + let expected_level = ArrayLevels { def_levels: Some(vec![4, 1, 0, 2, 2, 3, 4]), rep_levels: Some(vec![0, 0, 0, 0, 1, 0, 0]), non_null_indices: vec![0, 4], max_def_level: 4, max_rep_level: 1, + array: values, }; assert_eq!(list_level, &expected_level); @@ -1391,6 +1440,7 @@ mod tests { None, // Masked by struct array None, ]); + let values = inner.values().clone(); // This test assumes that nulls don't take up space assert_eq!(inner.values().len(), 7); @@ -1406,12 +1456,13 @@ mod tests { assert_eq!(levels.len(), 1); - let expected_level = LevelInfo { + let expected_level = ArrayLevels { def_levels: Some(vec![4, 4, 3, 2, 0, 4, 4, 0, 1]), rep_levels: Some(vec![0, 1, 0, 0, 0, 0, 1, 0, 0]), non_null_indices: vec![0, 1, 5, 6], max_def_level: 4, max_rep_level: 1, + array: values, }; assert_eq!(&levels[0], &expected_level); @@ -1422,14 +1473,16 @@ mod tests { // Test the null mask of a struct array and the null mask of a list array // masking out non-null elements of their children - let a1 = Arc::new(ListArray::from_iter_primitive::(vec![ + let a1 = ListArray::from_iter_primitive::(vec![ Some(vec![None]), // Masked by list array Some(vec![]), // Masked by list array Some(vec![Some(3), None]), Some(vec![Some(4), Some(5), None, Some(6)]), // Masked by struct array None, None, - ])) as ArrayRef; + ]); + let a1_values = a1.values().clone(); + let a1 = Arc::new(a1) as ArrayRef; let a2 = Arc::new(Int32Array::from_iter(vec![ Some(1), // Masked by list array @@ -1439,6 +1492,7 @@ mod tests { Some(5), None, ])) as ArrayRef; + let a2_values = a2.clone(); let field_a1 = Arc::new(Field::new("list", a1.data_type().clone(), true)); let field_a2 = Arc::new(Field::new("integers", a2.data_type().clone(), true)); @@ -1486,22 +1540,24 @@ mod tests { assert_eq!(levels.len(), 2); - let expected_level = LevelInfo { + let expected_level = ArrayLevels { def_levels: Some(vec![0, 0, 1, 6, 5, 2, 3, 1]), rep_levels: Some(vec![0, 0, 0, 0, 2, 0, 1, 0]), non_null_indices: vec![1], max_def_level: 6, max_rep_level: 2, + array: a1_values, }; assert_eq!(&levels[0], &expected_level); - let expected_level = LevelInfo { + let expected_level = ArrayLevels { def_levels: Some(vec![0, 0, 1, 3, 2, 4, 1]), rep_levels: Some(vec![0, 0, 0, 0, 0, 1, 0]), non_null_indices: vec![4], max_def_level: 4, max_rep_level: 1, + array: a2_values, }; assert_eq!(&levels[1], &expected_level); @@ -1522,23 +1578,24 @@ mod tests { builder.values().append_slice(&[9, 10]); builder.append(false); let a = builder.finish(); + let values = a.values().clone(); let item_field = Field::new("item", a.data_type().clone(), true); - let mut builder = - LevelInfoBuilder::try_new(&item_field, Default::default()).unwrap(); - builder.write(&a, 1..4); + let mut builder = levels(&item_field, a); + builder.write(1..4); let 
levels = builder.finish(); assert_eq!(levels.len(), 1); let list_level = levels.get(0).unwrap(); - let expected_level = LevelInfo { + let expected_level = ArrayLevels { def_levels: Some(vec![0, 0, 3, 3]), rep_levels: Some(vec![0, 0, 0, 1]), non_null_indices: vec![6, 7], max_def_level: 3, max_rep_level: 1, + array: values, }; assert_eq!(list_level, &expected_level); } @@ -1670,6 +1727,10 @@ mod tests { assert_eq!(array.values().len(), 8); assert_eq!(array.len(), 4); + let struct_values = array.values().as_struct(); + let values_a = struct_values.column(0).clone(); + let values_b = struct_values.column(1).clone(); + let schema = Arc::new(Schema::new(vec![list_field])); let rb = RecordBatch::try_new(schema, vec![array]).unwrap(); @@ -1678,20 +1739,22 @@ mod tests { let b_levels = &levels[1]; // [[{a: 1}, null], null, [null, null], [{a: null}, {a: 2}]] - let expected_a = LevelInfo { + let expected_a = ArrayLevels { def_levels: Some(vec![4, 2, 0, 2, 2, 3, 4]), rep_levels: Some(vec![0, 1, 0, 0, 1, 0, 1]), non_null_indices: vec![0, 7], max_def_level: 4, max_rep_level: 1, + array: values_a, }; // [[{b: 2}, null], null, [null, null], [{b: 3}, {b: 4}]] - let expected_b = LevelInfo { + let expected_b = ArrayLevels { def_levels: Some(vec![3, 2, 0, 2, 2, 3, 3]), rep_levels: Some(vec![0, 1, 0, 0, 1, 0, 1]), non_null_indices: vec![0, 6, 7], max_def_level: 3, max_rep_level: 1, + array: values_b, }; assert_eq!(a_levels, &expected_a); @@ -1704,24 +1767,25 @@ mod tests { builder.append(true); builder.append(false); builder.append(true); - let a = builder.finish(); + let array = builder.finish(); + let values = array.values().clone(); - let item_field = Field::new("item", a.data_type().clone(), true); - let mut builder = - LevelInfoBuilder::try_new(&item_field, Default::default()).unwrap(); - builder.write(&a, 0..3); + let item_field = Field::new("item", array.data_type().clone(), true); + let mut builder = levels(&item_field, array); + builder.write(0..3); let levels = builder.finish(); assert_eq!(levels.len(), 1); let list_level = levels.get(0).unwrap(); - let expected_level = LevelInfo { + let expected_level = ArrayLevels { def_levels: Some(vec![1, 0, 1]), rep_levels: Some(vec![0, 0, 0]), non_null_indices: vec![], max_def_level: 3, max_rep_level: 1, + array: values, }; assert_eq!(list_level, &expected_level); } @@ -1744,22 +1808,56 @@ mod tests { builder.values().append_null(); builder.append(false); let a = builder.finish(); + let values = a.values().as_list::().values().clone(); let item_field = Field::new("item", a.data_type().clone(), true); - let mut builder = - LevelInfoBuilder::try_new(&item_field, Default::default()).unwrap(); - builder.write(&a, 0..4); + let mut builder = levels(&item_field, a); + builder.write(0..4); let levels = builder.finish(); - let list_level = levels.get(0).unwrap(); - let expected_level = LevelInfo { + let expected_level = ArrayLevels { def_levels: Some(vec![5, 4, 5, 2, 5, 3, 5, 5, 4, 4, 0]), rep_levels: Some(vec![0, 2, 2, 1, 0, 1, 0, 2, 1, 2, 0]), non_null_indices: vec![0, 2, 3, 4, 5], max_def_level: 5, max_rep_level: 2, + array: values, }; - assert_eq!(list_level, &expected_level); + assert_eq!(levels[0], expected_level); + } + + #[test] + fn test_null_dictionary_values() { + let values = Int32Array::new( + vec![1, 2, 3, 4].into(), + Some(NullBuffer::from(vec![true, false, true, true])), + ); + let keys = Int32Array::new( + vec![1, 54, 2, 0].into(), + Some(NullBuffer::from(vec![true, false, true, true])), + ); + // [NULL, NULL, 3, 0] + let dict = 
DictionaryArray::new(keys, Arc::new(values)); + + let item_field = Field::new("item", dict.data_type().clone(), true); + + let mut builder = levels(&item_field, dict.clone()); + builder.write(0..4); + let levels = builder.finish(); + let expected_level = ArrayLevels { + def_levels: Some(vec![0, 0, 1, 1]), + rep_levels: None, + non_null_indices: vec![2, 3], + max_def_level: 1, + max_rep_level: 0, + array: Arc::new(dict), + }; + assert_eq!(levels[0], expected_level); + } + + fn levels(field: &Field, array: T) -> LevelInfoBuilder { + let v = Arc::new(array) as ArrayRef; + LevelInfoBuilder::try_new(field, Default::default(), &v).unwrap() } } diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs index ccec4ffb20c0..752eff86c5e9 100644 --- a/parquet/src/arrow/arrow_writer/mod.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -18,18 +18,19 @@ //! Contains writer which writes arrow data into parquet data. use bytes::Bytes; -use std::fmt::Debug; use std::io::{Read, Write}; use std::iter::Peekable; use std::slice::Iter; use std::sync::{Arc, Mutex}; use std::vec::IntoIter; -use thrift::protocol::{TCompactOutputProtocol, TSerializable}; +use thrift::protocol::TCompactOutputProtocol; use arrow_array::cast::AsArray; use arrow_array::types::*; -use arrow_array::{Array, FixedSizeListArray, RecordBatch, RecordBatchWriter}; -use arrow_schema::{ArrowError, DataType as ArrowDataType, IntervalUnit, SchemaRef}; +use arrow_array::{ArrayRef, RecordBatch, RecordBatchWriter}; +use arrow_schema::{ + ArrowError, DataType as ArrowDataType, Field, IntervalUnit, SchemaRef, +}; use super::schema::{ add_encoded_arrow_schema_to_metadata, arrow_to_parquet_schema, @@ -47,14 +48,15 @@ use crate::errors::{ParquetError, Result}; use crate::file::metadata::{ColumnChunkMetaData, KeyValue, RowGroupMetaDataPtr}; use crate::file::properties::{WriterProperties, WriterPropertiesPtr}; use crate::file::reader::{ChunkReader, Length}; -use crate::file::writer::SerializedFileWriter; +use crate::file::writer::{SerializedFileWriter, SerializedRowGroupWriter}; use crate::schema::types::{ColumnDescPtr, SchemaDescriptor}; -use levels::{calculate_array_levels, LevelInfo}; +use crate::thrift::TSerializable; +use levels::{calculate_array_levels, ArrayLevels}; mod byte_array; mod levels; -/// Arrow writer +/// Encodes [`RecordBatch`] to parquet /// /// Writes Arrow `RecordBatch`es to a Parquet writer. Multiple [`RecordBatch`] will be encoded /// to the same row group, up to `max_row_group_size` rows. Any remaining rows will be @@ -97,7 +99,7 @@ pub struct ArrowWriter { max_row_group_size: usize, } -impl Debug for ArrowWriter { +impl std::fmt::Debug for ArrowWriter { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let buffered_memory = self.in_progress_size(); f.debug_struct("ArrowWriter") @@ -150,7 +152,7 @@ impl ArrowWriter { Some(in_progress) => in_progress .writers .iter() - .map(|(_, x)| x.get_estimated_total_bytes() as usize) + .map(|x| x.get_estimated_total_bytes()) .sum(), None => 0, } @@ -208,8 +210,8 @@ impl ArrowWriter { }; let mut row_group_writer = self.writer.next_row_group()?; - for (chunk, close) in in_progress.close()? { - row_group_writer.append_column(&chunk, close)?; + for chunk in in_progress.close()? 
{ + chunk.append_to_row_group(&mut row_group_writer)?; } row_group_writer.close()?; Ok(()) @@ -246,25 +248,27 @@ impl RecordBatchWriter for ArrowWriter { } } -/// A list of [`Bytes`] comprising a single column chunk +/// A single column chunk produced by [`ArrowColumnWriter`] #[derive(Default)] -struct ArrowColumnChunk { +struct ArrowColumnChunkData { length: usize, data: Vec, } -impl Length for ArrowColumnChunk { +impl Length for ArrowColumnChunkData { fn len(&self) -> u64 { self.length as _ } } -impl ChunkReader for ArrowColumnChunk { - type T = ChainReader; +impl ChunkReader for ArrowColumnChunkData { + type T = ArrowColumnChunkReader; fn get_read(&self, start: u64) -> Result { assert_eq!(start, 0); // Assume append_column writes all data in one-shot - Ok(ChainReader(self.data.clone().into_iter().peekable())) + Ok(ArrowColumnChunkReader( + self.data.clone().into_iter().peekable(), + )) } fn get_bytes(&self, _start: u64, _length: usize) -> Result { @@ -272,10 +276,10 @@ impl ChunkReader for ArrowColumnChunk { } } -/// A [`Read`] for an iterator of [`Bytes`] -struct ChainReader(Peekable>); +/// A [`Read`] for [`ArrowColumnChunkData`] +struct ArrowColumnChunkReader(Peekable>); -impl Read for ChainReader { +impl Read for ArrowColumnChunkReader { fn read(&mut self, out: &mut [u8]) -> std::io::Result { let buffer = loop { match self.0.peek_mut() { @@ -295,11 +299,11 @@ impl Read for ChainReader { } } -/// A shared [`ArrowColumnChunk`] +/// A shared [`ArrowColumnChunkData`] /// /// This allows it to be owned by [`ArrowPageWriter`] whilst allowing access via /// [`ArrowRowGroupWriter`] on flush, without requiring self-referential borrows -type SharedColumnChunk = Arc>; +type SharedColumnChunk = Arc>; #[derive(Default)] struct ArrowPageWriter { @@ -345,25 +349,169 @@ impl PageWriter for ArrowPageWriter { } } -/// Encodes a leaf column to [`ArrowPageWriter`] -enum ArrowColumnWriter { +/// A leaf column that can be encoded by [`ArrowColumnWriter`] +#[derive(Debug)] +pub struct ArrowLeafColumn(ArrayLevels); + +/// Computes the [`ArrowLeafColumn`] for a potentially nested [`ArrayRef`] +pub fn compute_leaves(field: &Field, array: &ArrayRef) -> Result> { + let levels = calculate_array_levels(array, field)?; + Ok(levels.into_iter().map(ArrowLeafColumn).collect()) +} + +/// The data for a single column chunk, see [`ArrowColumnWriter`] +pub struct ArrowColumnChunk { + data: ArrowColumnChunkData, + close: ColumnCloseResult, +} + +impl std::fmt::Debug for ArrowColumnChunk { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ArrowColumnChunk") + .field("length", &self.data.length) + .finish_non_exhaustive() + } +} + +impl ArrowColumnChunk { + /// Calls [`SerializedRowGroupWriter::append_column`] with this column's data + pub fn append_to_row_group( + self, + writer: &mut SerializedRowGroupWriter<'_, W>, + ) -> Result<()> { + writer.append_column(&self.data, self.close) + } +} + +/// Encodes [`ArrowLeafColumn`] to [`ArrowColumnChunk`] +/// +/// Note: This is a low-level interface for applications that require fine-grained control +/// of encoding, see [`ArrowWriter`] for a higher-level interface +/// +/// ``` +/// // The arrow schema +/// # use std::sync::Arc; +/// # use arrow_array::*; +/// # use arrow_schema::*; +/// # use parquet::arrow::arrow_to_parquet_schema; +/// # use parquet::arrow::arrow_writer::{ArrowLeafColumn, compute_leaves, get_column_writers}; +/// # use parquet::file::properties::WriterProperties; +/// # use parquet::file::writer::SerializedFileWriter; 
+/// # +/// let schema = Arc::new(Schema::new(vec![ +/// Field::new("i32", DataType::Int32, false), +/// Field::new("f32", DataType::Float32, false), +/// ])); +/// +/// // Compute the parquet schema +/// let parquet_schema = arrow_to_parquet_schema(schema.as_ref()).unwrap(); +/// let props = Arc::new(WriterProperties::default()); +/// +/// // Create writers for each of the leaf columns +/// let col_writers = get_column_writers(&parquet_schema, &props, &schema).unwrap(); +/// +/// // Spawn a worker thread for each column +/// // This is for demonstration purposes, a thread-pool e.g. rayon or tokio, would be better +/// let mut workers: Vec<_> = col_writers +/// .into_iter() +/// .map(|mut col_writer| { +/// let (send, recv) = std::sync::mpsc::channel::(); +/// let handle = std::thread::spawn(move || { +/// for col in recv { +/// col_writer.write(&col)?; +/// } +/// col_writer.close() +/// }); +/// (handle, send) +/// }) +/// .collect(); +/// +/// // Create parquet writer +/// let root_schema = parquet_schema.root_schema_ptr(); +/// let mut out = Vec::with_capacity(1024); // This could be a File +/// let mut writer = SerializedFileWriter::new(&mut out, root_schema, props.clone()).unwrap(); +/// +/// // Start row group +/// let mut row_group = writer.next_row_group().unwrap(); +/// +/// // Columns to encode +/// let to_write = vec![ +/// Arc::new(Int32Array::from_iter_values([1, 2, 3])) as _, +/// Arc::new(Float32Array::from_iter_values([1., 45., -1.])) as _, +/// ]; +/// +/// // Spawn work to encode columns +/// let mut worker_iter = workers.iter_mut(); +/// for (arr, field) in to_write.iter().zip(&schema.fields) { +/// for leaves in compute_leaves(field, arr).unwrap() { +/// worker_iter.next().unwrap().1.send(leaves).unwrap(); +/// } +/// } +/// +/// // Finish up parallel column encoding +/// for (handle, send) in workers { +/// drop(send); // Drop send side to signal termination +/// let chunk = handle.join().unwrap().unwrap(); +/// chunk.append_to_row_group(&mut row_group).unwrap(); +/// } +/// row_group.close().unwrap(); +/// +/// let metadata = writer.close().unwrap(); +/// assert_eq!(metadata.num_rows, 3); +/// ``` +pub struct ArrowColumnWriter { + writer: ArrowColumnWriterImpl, + chunk: SharedColumnChunk, +} + +impl std::fmt::Debug for ArrowColumnWriter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ArrowColumnWriter").finish_non_exhaustive() + } +} + +enum ArrowColumnWriterImpl { ByteArray(GenericColumnWriter<'static, ByteArrayEncoder>), Column(ColumnWriter<'static>), } impl ArrowColumnWriter { + /// Write an [`ArrowLeafColumn`] + pub fn write(&mut self, col: &ArrowLeafColumn) -> Result<()> { + match &mut self.writer { + ArrowColumnWriterImpl::Column(c) => { + write_leaf(c, &col.0)?; + } + ArrowColumnWriterImpl::ByteArray(c) => { + write_primitive(c, col.0.array().as_ref(), &col.0)?; + } + } + Ok(()) + } + + /// Close this column returning the written [`ArrowColumnChunk`] + pub fn close(self) -> Result { + let close = match self.writer { + ArrowColumnWriterImpl::ByteArray(c) => c.close()?, + ArrowColumnWriterImpl::Column(c) => c.close()?, + }; + let chunk = Arc::try_unwrap(self.chunk).ok().unwrap(); + let data = chunk.into_inner().unwrap(); + Ok(ArrowColumnChunk { data, close }) + } + /// Returns the estimated total bytes for this column writer - fn get_estimated_total_bytes(&self) -> u64 { - match self { - ArrowColumnWriter::ByteArray(c) => c.get_estimated_total_bytes(), - ArrowColumnWriter::Column(c) => c.get_estimated_total_bytes(), + 
pub fn get_estimated_total_bytes(&self) -> usize { + match &self.writer { + ArrowColumnWriterImpl::ByteArray(c) => c.get_estimated_total_bytes() as _, + ArrowColumnWriterImpl::Column(c) => c.get_estimated_total_bytes() as _, } } } /// Encodes [`RecordBatch`] to a parquet row group struct ArrowRowGroupWriter { - writers: Vec<(SharedColumnChunk, ArrowColumnWriter)>, + writers: Vec, schema: SchemaRef, buffered_rows: usize, } @@ -374,11 +522,7 @@ impl ArrowRowGroupWriter { props: &WriterPropertiesPtr, arrow: &SchemaRef, ) -> Result { - let mut writers = Vec::with_capacity(arrow.fields.len()); - let mut leaves = parquet.columns().iter(); - for field in &arrow.fields { - get_arrow_column_writer(field.data_type(), props, &mut leaves, &mut writers)?; - } + let writers = get_column_writers(parquet, props, arrow)?; Ok(Self { writers, schema: arrow.clone(), @@ -388,49 +532,62 @@ impl ArrowRowGroupWriter { fn write(&mut self, batch: &RecordBatch) -> Result<()> { self.buffered_rows += batch.num_rows(); - let mut writers = self.writers.iter_mut().map(|(_, x)| x); - for (array, field) in batch.columns().iter().zip(&self.schema.fields) { - let mut levels = calculate_array_levels(array, field)?.into_iter(); - write_leaves(&mut writers, &mut levels, array.as_ref())?; + let mut writers = self.writers.iter_mut(); + for (field, column) in self.schema.fields().iter().zip(batch.columns()) { + for leaf in compute_leaves(field.as_ref(), column)? { + writers.next().unwrap().write(&leaf)? + } } Ok(()) } - fn close(self) -> Result> { + fn close(self) -> Result> { self.writers .into_iter() - .map(|(chunk, writer)| { - let close_result = match writer { - ArrowColumnWriter::ByteArray(c) => c.close()?, - ArrowColumnWriter::Column(c) => c.close()?, - }; - - let chunk = Arc::try_unwrap(chunk).ok().unwrap().into_inner().unwrap(); - Ok((chunk, close_result)) - }) + .map(|writer| writer.close()) .collect() } } -/// Get an [`ArrowColumnWriter`] along with a reference to its [`SharedColumnChunk`] +/// Returns the [`ArrowColumnWriter`] for a given schema +pub fn get_column_writers( + parquet: &SchemaDescriptor, + props: &WriterPropertiesPtr, + arrow: &SchemaRef, +) -> Result> { + let mut writers = Vec::with_capacity(arrow.fields.len()); + let mut leaves = parquet.columns().iter(); + for field in &arrow.fields { + get_arrow_column_writer(field.data_type(), props, &mut leaves, &mut writers)?; + } + Ok(writers) +} + +/// Gets the [`ArrowColumnWriter`] for the given `data_type` fn get_arrow_column_writer( data_type: &ArrowDataType, props: &WriterPropertiesPtr, leaves: &mut Iter<'_, ColumnDescPtr>, - out: &mut Vec<(SharedColumnChunk, ArrowColumnWriter)>, + out: &mut Vec, ) -> Result<()> { let col = |desc: &ColumnDescPtr| { let page_writer = Box::::default(); let chunk = page_writer.buffer.clone(); let writer = get_column_writer(desc.clone(), props.clone(), page_writer); - (chunk, ArrowColumnWriter::Column(writer)) + ArrowColumnWriter { + chunk, + writer: ArrowColumnWriterImpl::Column(writer), + } }; let bytes = |desc: &ColumnDescPtr| { let page_writer = Box::::default(); let chunk = page_writer.buffer.clone(); let writer = GenericColumnWriter::new(desc.clone(), props.clone(), page_writer); - (chunk, ArrowColumnWriter::ByteArray(writer)) + ArrowColumnWriter { + chunk, + writer: ArrowColumnWriterImpl::ByteArray(writer), + } }; match data_type { @@ -476,52 +633,8 @@ fn get_arrow_column_writer( Ok(()) } -/// Write the leaves of `array` in depth-first order to `writers` with `levels` -fn write_leaves<'a, W>( - writers: &mut W, - levels: 
&mut IntoIter, - array: &(dyn Array + 'static), -) -> Result<()> -where - W: Iterator, -{ - match array.data_type() { - ArrowDataType::List(_) => { - write_leaves(writers, levels, array.as_list::().values().as_ref())? - } - ArrowDataType::LargeList(_) => { - write_leaves(writers, levels, array.as_list::().values().as_ref())? - } - ArrowDataType::FixedSizeList(_, _) => { - let array = array.as_any().downcast_ref::().unwrap(); - write_leaves(writers, levels, array.values().as_ref())? - } - ArrowDataType::Struct(_) => { - for column in array.as_struct().columns() { - write_leaves(writers, levels, column.as_ref())? - } - } - ArrowDataType::Map(_, _) => { - let map = array.as_map(); - write_leaves(writers, levels, map.keys().as_ref())?; - write_leaves(writers, levels, map.values().as_ref())? - } - _ => { - let levels = levels.next().unwrap(); - match writers.next().unwrap() { - ArrowColumnWriter::Column(c) => write_leaf(c, array, levels)?, - ArrowColumnWriter::ByteArray(c) => write_primitive(c, array, levels)?, - }; - } - } - Ok(()) -} - -fn write_leaf( - writer: &mut ColumnWriter<'_>, - column: &dyn Array, - levels: LevelInfo, -) -> Result { +fn write_leaf(writer: &mut ColumnWriter<'_>, levels: &ArrayLevels) -> Result { + let column = levels.array().as_ref(); let indices = levels.non_null_indices(); match writer { ColumnWriter::Int32ColumnWriter(ref mut typed) => { @@ -676,7 +789,7 @@ fn write_leaf( fn write_primitive( writer: &mut GenericColumnWriter, values: &E::Values, - levels: LevelInfo, + levels: &ArrayLevels, ) -> Result { writer.write_batch_internal( values, @@ -859,7 +972,7 @@ mod tests { let expected_batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(a), Arc::new(b)]).unwrap(); - for buffer in vec![ + for buffer in [ get_bytes_after_close(schema.clone(), &expected_batch), get_bytes_by_into_inner(schema, &expected_batch), ] { @@ -1650,6 +1763,27 @@ mod tests { writer.close().unwrap(); } + #[test] + fn check_page_offset_index_with_nan() { + let values = Arc::new(Float64Array::from(vec![f64::NAN; 10])); + let schema = Schema::new(vec![Field::new("col", DataType::Float64, true)]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![values]).unwrap(); + + let mut out = Vec::with_capacity(1024); + let mut writer = ArrowWriter::try_new(&mut out, batch.schema(), None) + .expect("Unable to write file"); + writer.write(&batch).unwrap(); + let file_meta_data = writer.close().unwrap(); + for row_group in file_meta_data.row_groups { + for column in row_group.columns { + assert!(column.offset_index_offset.is_some()); + assert!(column.offset_index_length.is_some()); + assert!(column.column_index_offset.is_none()); + assert!(column.column_index_length.is_none()); + } + } + } + #[test] fn i8_single_column() { required_and_optional::(0..SMALL_SIZE as i8); @@ -1944,7 +2078,7 @@ mod tests { assert_eq!(a.value(0).len(), 0); assert_eq!(a.value(2).len(), 2); - assert_eq!(a.value(2).null_count(), 2); + assert_eq!(a.value(2).logical_nulls().unwrap().null_count(), 2); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap(); roundtrip(batch, None); @@ -2137,7 +2271,7 @@ mod tests { #[test] fn u32_min_max() { // check values roundtrip through parquet - let src = vec![ + let src = [ u32::MIN, u32::MIN + 1, (i32::MAX as u32) - 1, @@ -2178,7 +2312,7 @@ mod tests { #[test] fn u64_min_max() { // check values roundtrip through parquet - let src = vec![ + let src = [ u64::MIN, u64::MIN + 1, (i64::MAX as u64) - 1, diff --git a/parquet/src/arrow/async_reader/metadata.rs 
b/parquet/src/arrow/async_reader/metadata.rs index 076ae5c54052..fe7b4427647c 100644 --- a/parquet/src/arrow/async_reader/metadata.rs +++ b/parquet/src/arrow/async_reader/metadata.rs @@ -17,7 +17,7 @@ use crate::arrow::async_reader::AsyncFileReader; use crate::errors::{ParquetError, Result}; -use crate::file::footer::{decode_footer, read_metadata}; +use crate::file::footer::{decode_footer, decode_metadata}; use crate::file::metadata::ParquetMetaData; use crate::file::page_index::index::Index; use crate::file::page_index::index_reader::{ @@ -27,7 +27,6 @@ use bytes::Bytes; use futures::future::BoxFuture; use futures::FutureExt; use std::future::Future; -use std::io::Read; use std::ops::Range; /// A data source that can be used with [`MetadataLoader`] to load [`ParquetMetaData`] @@ -95,16 +94,14 @@ impl MetadataLoader { // Did not fetch the entire file metadata in the initial read, need to make a second request let (metadata, remainder) = if length > suffix_len - 8 { let metadata_start = file_size - length - 8; - let remaining_metadata = fetch.fetch(metadata_start..footer_start).await?; - - let reader = remaining_metadata.as_ref().chain(&suffix[..suffix_len - 8]); - (read_metadata(reader)?, None) + let meta = fetch.fetch(metadata_start..file_size - 8).await?; + (decode_metadata(&meta)?, None) } else { let metadata_start = file_size - length - 8 - footer_start; let slice = &suffix[metadata_start..suffix_len - 8]; ( - read_metadata(slice)?, + decode_metadata(slice)?, Some((footer_start, suffix.slice(..metadata_start))), ) }; diff --git a/parquet/src/arrow/async_reader/mod.rs b/parquet/src/arrow/async_reader/mod.rs index f17fb0751d52..875fff4dac57 100644 --- a/parquet/src/arrow/async_reader/mod.rs +++ b/parquet/src/arrow/async_reader/mod.rs @@ -22,13 +22,13 @@ //! # #[tokio::main(flavor="current_thread")] //! # async fn main() { //! # -//! use arrow_array::RecordBatch; -//! use arrow::util::pretty::pretty_format_batches; -//! use futures::TryStreamExt; -//! use tokio::fs::File; -//! -//! use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask}; -//! +//! # use arrow_array::RecordBatch; +//! # use arrow::util::pretty::pretty_format_batches; +//! # use futures::TryStreamExt; +//! # use tokio::fs::File; +//! # +//! # use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask}; +//! # //! # fn assert_batches_eq(batches: &[RecordBatch], expected_lines: &[&str]) { //! # let formatted = pretty_format_batches(batches).unwrap().to_string(); //! # let actual_lines: Vec<_> = formatted.trim().lines().collect(); @@ -38,7 +38,7 @@ //! # expected_lines, actual_lines //! # ); //! # } -//! +//! # //! let testdata = arrow::util::test_util::parquet_test_data(); //! let path = format!("{}/alltypes_plain.parquet", testdata); //! 
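The metadata loader above now decodes the footer directly from fetched byte ranges via `decode_footer` / `decode_metadata` rather than chaining `Read` adapters. A small sketch of that two-step decode on a fully buffered file, using only the two public helpers from `parquet::file::footer` (a real reader would also validate that the file is at least 8 bytes and that the reported length fits):

```rust
use bytes::Bytes;
use parquet::errors::Result;
use parquet::file::footer::{decode_footer, decode_metadata};
use parquet::file::metadata::ParquetMetaData;

/// Decode the metadata of a parquet file that is already in memory.
fn decode_from_suffix(file: &Bytes) -> Result<ParquetMetaData> {
    let file_size = file.len();

    // The last 8 bytes are the footer: 4-byte metadata length + "PAR1" magic.
    let footer: [u8; 8] = file[file_size - 8..].try_into().unwrap();
    let metadata_len = decode_footer(&footer)?;

    // The thrift-encoded metadata immediately precedes the footer.
    let metadata_bytes = &file[file_size - 8 - metadata_len..file_size - 8];
    decode_metadata(metadata_bytes)
}
```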
let file = File::open(path).await.unwrap(); @@ -77,7 +77,6 @@ use std::collections::VecDeque; use std::fmt::Formatter; - use std::io::SeekFrom; use std::ops::Range; use std::pin::Pin; @@ -88,7 +87,6 @@ use bytes::{Buf, Bytes}; use futures::future::{BoxFuture, FutureExt}; use futures::ready; use futures::stream::Stream; - use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt}; use arrow_array::RecordBatch; @@ -96,20 +94,24 @@ use arrow_schema::SchemaRef; use crate::arrow::array_reader::{build_array_reader, RowGroups}; use crate::arrow::arrow_reader::{ - apply_range, evaluate_predicate, selects_any, ArrowReaderBuilder, ArrowReaderOptions, - ParquetRecordBatchReader, RowFilter, RowSelection, + apply_range, evaluate_predicate, selects_any, ArrowReaderBuilder, + ArrowReaderMetadata, ArrowReaderOptions, ParquetRecordBatchReader, RowFilter, + RowSelection, }; use crate::arrow::ProjectionMask; +use crate::bloom_filter::{ + chunk_read_bloom_filter_header_and_offset, Sbbf, SBBF_HEADER_SIZE_ESTIMATE, +}; use crate::column::page::{PageIterator, PageReader}; - use crate::errors::{ParquetError, Result}; use crate::file::footer::{decode_footer, decode_metadata}; use crate::file::metadata::{ParquetMetaData, RowGroupMetaData}; use crate::file::reader::{ChunkReader, Length, SerializedPageReader}; -use crate::format::PageLocation; - use crate::file::FOOTER_SIZE; +use crate::format::{ + BloomFilterAlgorithm, BloomFilterCompression, BloomFilterHash, PageLocation, +}; mod metadata; pub use metadata::*; @@ -205,6 +207,29 @@ impl AsyncFileReader for T { } } +impl ArrowReaderMetadata { + /// Returns a new [`ArrowReaderMetadata`] for this builder + /// + /// See [`ParquetRecordBatchStreamBuilder::new_with_metadata`] for how this can be used + pub async fn load_async( + input: &mut T, + options: ArrowReaderOptions, + ) -> Result { + let mut metadata = input.get_metadata().await?; + + if options.page_index + && metadata.column_index().is_none() + && metadata.offset_index().is_none() + { + let m = Arc::try_unwrap(metadata).unwrap_or_else(|e| e.as_ref().clone()); + let mut loader = MetadataLoader::new(input, m); + loader.load_page_index(true, true).await?; + metadata = Arc::new(loader.finish()) + } + Self::try_new(metadata, options) + } +} + #[doc(hidden)] /// A newtype used within [`ReaderOptionsBuilder`] to distinguish sync readers from async /// @@ -218,32 +243,129 @@ pub struct AsyncReader(T); /// to use this information to select what specific columns, row groups, etc... 
/// they wish to be read by the resulting stream /// +/// See [`ArrowReaderBuilder`] for additional member functions pub type ParquetRecordBatchStreamBuilder = ArrowReaderBuilder>; -impl ArrowReaderBuilder> { +impl ParquetRecordBatchStreamBuilder { /// Create a new [`ParquetRecordBatchStreamBuilder`] with the provided parquet file - pub async fn new(mut input: T) -> Result { - let metadata = input.get_metadata().await?; - Self::new_builder(AsyncReader(input), metadata, Default::default()) + pub async fn new(input: T) -> Result { + Self::new_with_options(input, Default::default()).await } + /// Create a new [`ParquetRecordBatchStreamBuilder`] with the provided parquet file + /// and [`ArrowReaderOptions`] pub async fn new_with_options( mut input: T, options: ArrowReaderOptions, ) -> Result { - let mut metadata = input.get_metadata().await?; + let metadata = ArrowReaderMetadata::load_async(&mut input, options).await?; + Ok(Self::new_with_metadata(input, metadata)) + } - if options.page_index - && metadata.column_index().is_none() - && metadata.offset_index().is_none() - { - let m = Arc::try_unwrap(metadata).unwrap_or_else(|e| e.as_ref().clone()); - let mut loader = MetadataLoader::new(&mut input, m); - loader.load_page_index(true, true).await?; - metadata = Arc::new(loader.finish()) + /// Create a [`ParquetRecordBatchStreamBuilder`] from the provided [`ArrowReaderMetadata`] + /// + /// This allows loading metadata once and using it to create multiple builders with + /// potentially different settings + /// + /// ``` + /// # use std::fs::metadata; + /// # use std::sync::Arc; + /// # use bytes::Bytes; + /// # use arrow_array::{Int32Array, RecordBatch}; + /// # use arrow_schema::{DataType, Field, Schema}; + /// # use parquet::arrow::arrow_reader::ArrowReaderMetadata; + /// # use parquet::arrow::{ArrowWriter, ParquetRecordBatchStreamBuilder}; + /// # use tempfile::tempfile; + /// # use futures::StreamExt; + /// # #[tokio::main(flavor="current_thread")] + /// # async fn main() { + /// # + /// let mut file = tempfile().unwrap(); + /// # let schema = Arc::new(Schema::new(vec![Field::new("i32", DataType::Int32, false)])); + /// # let mut writer = ArrowWriter::try_new(&mut file, schema.clone(), None).unwrap(); + /// # let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))]).unwrap(); + /// # writer.write(&batch).unwrap(); + /// # writer.close().unwrap(); + /// # + /// let mut file = tokio::fs::File::from_std(file); + /// let meta = ArrowReaderMetadata::load_async(&mut file, Default::default()).await.unwrap(); + /// let mut a = ParquetRecordBatchStreamBuilder::new_with_metadata( + /// file.try_clone().await.unwrap(), + /// meta.clone() + /// ).build().unwrap(); + /// let mut b = ParquetRecordBatchStreamBuilder::new_with_metadata(file, meta).build().unwrap(); + /// + /// // Should be able to read from both in parallel + /// assert_eq!(a.next().await.unwrap().unwrap(), b.next().await.unwrap().unwrap()); + /// # } + /// ``` + pub fn new_with_metadata(input: T, metadata: ArrowReaderMetadata) -> Self { + Self::new_builder(AsyncReader(input), metadata) + } + + /// Read bloom filter for a column in a row group + /// Returns `None` if the column does not have a bloom filter + /// + /// We should call this function after other forms pruning, such as projection and predicate pushdown. 
+ pub async fn get_row_group_column_bloom_filter( + &mut self, + row_group_idx: usize, + column_idx: usize, + ) -> Result> { + let metadata = self.metadata.row_group(row_group_idx); + let column_metadata = metadata.column(column_idx); + + let offset: usize = if let Some(offset) = column_metadata.bloom_filter_offset() { + offset.try_into().map_err(|_| { + ParquetError::General("Bloom filter offset is invalid".to_string()) + })? + } else { + return Ok(None); + }; + + let buffer = match column_metadata.bloom_filter_length() { + Some(length) => self.input.0.get_bytes(offset..offset + length as usize), + None => self + .input + .0 + .get_bytes(offset..offset + SBBF_HEADER_SIZE_ESTIMATE), + } + .await?; + + let (header, bitset_offset) = + chunk_read_bloom_filter_header_and_offset(offset as u64, buffer.clone())?; + + match header.algorithm { + BloomFilterAlgorithm::BLOCK(_) => { + // this match exists to future proof the singleton algorithm enum + } + } + match header.compression { + BloomFilterCompression::UNCOMPRESSED(_) => { + // this match exists to future proof the singleton compression enum + } + } + match header.hash { + BloomFilterHash::XXHASH(_) => { + // this match exists to future proof the singleton hash enum + } } - Self::new_builder(AsyncReader(input), metadata, options) + let bitset = match column_metadata.bloom_filter_length() { + Some(_) => buffer.slice((bitset_offset as usize - offset)..), + None => { + let bitset_length: usize = header.num_bytes.try_into().map_err(|_| { + ParquetError::General("Bloom filter length is invalid".to_string()) + })?; + self.input + .0 + .get_bytes( + bitset_offset as usize..bitset_offset as usize + bitset_length, + ) + .await? + } + }; + Ok(Some(Sbbf::new(&bitset))) } /// Build a new [`ParquetRecordBatchStream`] @@ -297,7 +419,7 @@ type ReadResult = Result<(ReaderFactory, Option) struct ReaderFactory { metadata: Arc, - fields: Option, + fields: Option>, input: T, @@ -350,7 +472,7 @@ where .await?; let array_reader = build_array_reader( - self.fields.as_ref(), + self.fields.as_deref(), predicate_projection, &row_group, )?; @@ -403,7 +525,7 @@ where let reader = ParquetRecordBatchReader::new( batch_size, - build_array_reader(self.fields.as_ref(), &projection, &row_group)?, + build_array_reader(self.fields.as_deref(), &projection, &row_group)?, selection, ); @@ -569,27 +691,27 @@ impl<'a> InMemoryRowGroup<'a> { .iter() .zip(self.metadata.columns()) .enumerate() - .filter_map(|(idx, (chunk, chunk_meta))| { - (chunk.is_none() && projection.leaf_included(idx)).then(|| { - // If the first page does not start at the beginning of the column, - // then we need to also fetch a dictionary page. - let mut ranges = vec![]; - let (start, _len) = chunk_meta.byte_range(); - match page_locations[idx].first() { - Some(first) if first.offset as u64 != start => { - ranges.push(start as usize..first.offset as usize); - } - _ => (), + .filter(|&(idx, (chunk, _chunk_meta))| { + chunk.is_none() && projection.leaf_included(idx) + }) + .flat_map(|(idx, (_chunk, chunk_meta))| { + // If the first page does not start at the beginning of the column, + // then we need to also fetch a dictionary page. 
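As the new doc comment says, the bloom filter lookup is meant as a last pruning step after projection and predicate pushdown. A hedged sketch of using it to keep only row groups that might contain a value follows; the column index is assumed to be a parquet leaf column index, and the helper name is purely illustrative:

```rust
use parquet::arrow::ParquetRecordBatchStreamBuilder;
use parquet::errors::Result;

/// Returns the indices of row groups whose bloom filter might contain `value`.
/// Row groups without a bloom filter are conservatively kept.
async fn candidate_row_groups(
    file: tokio::fs::File,
    column_idx: usize,
    value: &str,
) -> Result<Vec<usize>> {
    let mut builder = ParquetRecordBatchStreamBuilder::new(file).await?;
    let num_row_groups = builder.metadata().num_row_groups();

    let mut keep = Vec::new();
    for rg in 0..num_row_groups {
        match builder
            .get_row_group_column_bloom_filter(rg, column_idx)
            .await?
        {
            // A negative bloom filter check is definitive; a positive one is not.
            Some(sbbf) if !sbbf.check(&value) => {}
            _ => keep.push(rg),
        }
    }
    Ok(keep)
}
```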
+ let mut ranges = vec![]; + let (start, _len) = chunk_meta.byte_range(); + match page_locations[idx].first() { + Some(first) if first.offset as u64 != start => { + ranges.push(start as usize..first.offset as usize); } + _ => (), + } - ranges.extend(selection.scan_ranges(&page_locations[idx])); - page_start_offsets - .push(ranges.iter().map(|range| range.start).collect()); + ranges.extend(selection.scan_ranges(&page_locations[idx])); + page_start_offsets + .push(ranges.iter().map(|range| range.start).collect()); - ranges - }) + ranges }) - .flatten() .collect(); let mut chunk_data = input.get_byte_ranges(fetch_ranges).await?.into_iter(); @@ -617,12 +739,11 @@ impl<'a> InMemoryRowGroup<'a> { .column_chunks .iter() .enumerate() - .filter_map(|(idx, chunk)| { - (chunk.is_none() && projection.leaf_included(idx)).then(|| { - let column = self.metadata.column(idx); - let (start, length) = column.byte_range(); - start as usize..(start + length) as usize - }) + .filter(|&(idx, chunk)| chunk.is_none() && projection.leaf_included(idx)) + .map(|(idx, _chunk)| { + let column = self.metadata.column(idx); + let (start, length) = column.byte_range(); + start as usize..(start + length) as usize }) .collect(); @@ -755,13 +876,19 @@ mod tests { use crate::file::footer::parse_metadata; use crate::file::page_index::index_reader; use crate::file::properties::WriterProperties; + use arrow::compute::kernels::cmp::eq; use arrow::error::Result as ArrowResult; + use arrow_array::builder::{ListBuilder, StringBuilder}; use arrow_array::cast::AsArray; use arrow_array::types::Int32Type; - use arrow_array::{Array, ArrayRef, Int32Array, StringArray}; - use futures::TryStreamExt; + use arrow_array::{ + Array, ArrayRef, Int32Array, Int8Array, Scalar, StringArray, UInt64Array, + }; + use arrow_schema::{DataType, Field, Schema}; + use futures::{StreamExt, TryStreamExt}; use rand::{thread_rng, Rng}; use std::sync::Mutex; + use tempfile::tempfile; #[derive(Clone)] struct TestReader { @@ -1167,14 +1294,16 @@ mod tests { }; let requests = test.requests.clone(); + let a_scalar = StringArray::from_iter_values(["b"]); let a_filter = ArrowPredicateFn::new( ProjectionMask::leaves(&parquet_schema, vec![0]), - |batch| arrow::compute::eq_dyn_utf8_scalar(batch.column(0), "b"), + move |batch| eq(batch.column(0), &Scalar::new(&a_scalar)), ); + let b_scalar = StringArray::from_iter_values(["4"]); let b_filter = ArrowPredicateFn::new( ProjectionMask::leaves(&parquet_schema, vec![1]), - |batch| arrow::compute::eq_dyn_utf8_scalar(batch.column(0), "4"), + move |batch| eq(batch.column(0), &Scalar::new(&b_scalar)), ); let filter = RowFilter::new(vec![Box::new(a_filter), Box::new(b_filter)]); @@ -1332,12 +1461,13 @@ mod tests { let a_filter = ArrowPredicateFn::new( ProjectionMask::leaves(&parquet_schema, vec![1]), - |batch| arrow::compute::eq_dyn_bool_scalar(batch.column(0), true), + |batch| Ok(batch.column(0).as_boolean().clone()), ); + let b_scalar = Int8Array::from(vec![2]); let b_filter = ArrowPredicateFn::new( ProjectionMask::leaves(&parquet_schema, vec![2]), - |batch| arrow::compute::eq_dyn_scalar(batch.column(0), 2_i32), + move |batch| eq(batch.column(0), &Scalar::new(&b_scalar)), ); let filter = RowFilter::new(vec![Box::new(a_filter), Box::new(b_filter)]); @@ -1409,7 +1539,7 @@ mod tests { let reader_factory = ReaderFactory { metadata, - fields, + fields: fields.map(Arc::new), input: async_reader, filter: None, limit: None, @@ -1481,4 +1611,162 @@ mod tests { assert_ne!(1024, file_rows); assert_eq!(stream.batch_size, file_rows); } + + 
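The predicate tests above migrate from the removed `eq_dyn_*_scalar` kernels to the general `eq` comparison kernel with a `Scalar` wrapper. A small standalone sketch of that pattern, using the same `arrow::compute::kernels::cmp::eq` import the tests add:

```rust
use arrow::compute::kernels::cmp::eq;
use arrow_array::{BooleanArray, Scalar, StringArray};

fn main() {
    let column = StringArray::from(vec!["a", "b", "c", "b"]);

    // A Scalar is a single-row Datum that broadcasts against the column.
    let needle = Scalar::new(StringArray::from_iter_values(["b"]));
    let mask: BooleanArray = eq(&column, &needle).unwrap();

    assert_eq!(mask, BooleanArray::from(vec![false, true, false, true]));
}
```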
#[tokio::test] + async fn test_get_row_group_column_bloom_filter_without_length() { + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{testdata}/data_index_bloom_encoding_stats.parquet"); + let data = Bytes::from(std::fs::read(path).unwrap()); + test_get_row_group_column_bloom_filter(data, false).await; + } + + #[tokio::test] + async fn test_get_row_group_column_bloom_filter_with_length() { + // convert to new parquet file with bloom_filter_length + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{testdata}/data_index_bloom_encoding_stats.parquet"); + let data = Bytes::from(std::fs::read(path).unwrap()); + let metadata = parse_metadata(&data).unwrap(); + let metadata = Arc::new(metadata); + let async_reader = TestReader { + data: data.clone(), + metadata: metadata.clone(), + requests: Default::default(), + }; + let builder = ParquetRecordBatchStreamBuilder::new(async_reader) + .await + .unwrap(); + let schema = builder.schema().clone(); + let stream = builder.build().unwrap(); + let batches = stream.try_collect::>().await.unwrap(); + + let mut parquet_data = Vec::new(); + let props = WriterProperties::builder() + .set_bloom_filter_enabled(true) + .build(); + let mut writer = + ArrowWriter::try_new(&mut parquet_data, schema, Some(props)).unwrap(); + for batch in batches { + writer.write(&batch).unwrap(); + } + writer.close().unwrap(); + + // test the new parquet file + test_get_row_group_column_bloom_filter(parquet_data.into(), true).await; + } + + async fn test_get_row_group_column_bloom_filter(data: Bytes, with_length: bool) { + let metadata = parse_metadata(&data).unwrap(); + let metadata = Arc::new(metadata); + + assert_eq!(metadata.num_row_groups(), 1); + let row_group = metadata.row_group(0); + let column = row_group.column(0); + assert_eq!(column.bloom_filter_length().is_some(), with_length); + + let async_reader = TestReader { + data: data.clone(), + metadata: metadata.clone(), + requests: Default::default(), + }; + + let mut builder = ParquetRecordBatchStreamBuilder::new(async_reader) + .await + .unwrap(); + + let sbbf = builder + .get_row_group_column_bloom_filter(0, 0) + .await + .unwrap() + .unwrap(); + assert!(sbbf.check(&"Hello")); + assert!(!sbbf.check(&"Hello_Not_Exists")); + } + + #[tokio::test] + async fn test_nested_skip() { + let schema = Arc::new(Schema::new(vec![ + Field::new("col_1", DataType::UInt64, false), + Field::new_list("col_2", Field::new("item", DataType::Utf8, true), true), + ])); + + // Default writer properties + let props = WriterProperties::builder() + .set_data_page_row_count_limit(256) + .set_write_batch_size(256) + .set_max_row_group_size(1024); + + // Write data + let mut file = tempfile().unwrap(); + let mut writer = + ArrowWriter::try_new(&mut file, schema.clone(), Some(props.build())).unwrap(); + + let mut builder = ListBuilder::new(StringBuilder::new()); + for id in 0..1024 { + match id % 3 { + 0 => builder + .append_value([Some("val_1".to_string()), Some(format!("id_{id}"))]), + 1 => builder.append_value([Some(format!("id_{id}"))]), + _ => builder.append_null(), + } + } + let refs = vec![ + Arc::new(UInt64Array::from_iter_values(0..1024)) as ArrayRef, + Arc::new(builder.finish()) as ArrayRef, + ]; + + let batch = RecordBatch::try_new(schema.clone(), refs).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + let selections = [ + RowSelection::from(vec![ + RowSelector::skip(313), + RowSelector::select(1), + RowSelector::skip(709), + RowSelector::select(1), + 
]), + RowSelection::from(vec![ + RowSelector::skip(255), + RowSelector::select(1), + RowSelector::skip(767), + RowSelector::select(1), + ]), + RowSelection::from(vec![ + RowSelector::select(255), + RowSelector::skip(1), + RowSelector::select(767), + RowSelector::skip(1), + ]), + RowSelection::from(vec![ + RowSelector::skip(254), + RowSelector::select(1), + RowSelector::select(1), + RowSelector::skip(767), + RowSelector::select(1), + ]), + ]; + + for selection in selections { + let expected = selection.row_count(); + // Read data + let mut reader = ParquetRecordBatchStreamBuilder::new_with_options( + tokio::fs::File::from_std(file.try_clone().unwrap()), + ArrowReaderOptions::new().with_page_index(true), + ) + .await + .unwrap(); + + reader = reader.with_row_selection(selection); + + let mut stream = reader.build().unwrap(); + + let mut total_rows = 0; + while let Some(rb) = stream.next().await { + let rb = rb.unwrap(); + total_rows += rb.num_rows(); + } + assert_eq!(total_rows, expected); + } + } } diff --git a/parquet/src/arrow/async_writer/mod.rs b/parquet/src/arrow/async_writer/mod.rs index 339618364324..0957b58697d7 100644 --- a/parquet/src/arrow/async_writer/mod.rs +++ b/parquet/src/arrow/async_writer/mod.rs @@ -77,12 +77,16 @@ pub struct AsyncArrowWriter { /// The inner buffer shared by the `sync_writer` and the `async_writer` shared_buffer: SharedBuffer, + + /// Trigger forced flushing once buffer size reaches this value + buffer_size: usize, } impl AsyncArrowWriter { /// Try to create a new Async Arrow Writer. /// - /// `buffer_size` determines the initial size of the intermediate buffer. + /// `buffer_size` determines the number of bytes to buffer before flushing + /// to the underlying [`AsyncWrite`] /// /// The intermediate buffer will automatically be resized if necessary /// @@ -102,6 +106,7 @@ impl AsyncArrowWriter { sync_writer, async_writer: writer, shared_buffer, + buffer_size, }) } @@ -111,7 +116,12 @@ impl AsyncArrowWriter { /// checked and flush if at least half full pub async fn write(&mut self, batch: &RecordBatch) -> Result<()> { self.sync_writer.write(batch)?; - Self::try_flush(&mut self.shared_buffer, &mut self.async_writer, false).await + Self::try_flush( + &mut self.shared_buffer, + &mut self.async_writer, + self.buffer_size, + ) + .await } /// Append [`KeyValue`] metadata in addition to those in [`WriterProperties`] @@ -128,7 +138,7 @@ impl AsyncArrowWriter { let metadata = self.sync_writer.close()?; // Force to flush the remaining data. 
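With the `AsyncArrowWriter` change above, `buffer_size` becomes a flush threshold rather than an initial allocation: `write` copies encoded bytes into the shared buffer and flushes to the `AsyncWrite` once at least `buffer_size` bytes are buffered, while `close` always flushes the remainder. A brief usage sketch, assuming the writer is re-exported as `parquet::arrow::AsyncArrowWriter` (with the async feature) and with a 1 MiB threshold chosen arbitrarily:

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, Int64Array, RecordBatch};
use parquet::arrow::AsyncArrowWriter;

#[tokio::main(flavor = "current_thread")]
async fn main() -> parquet::errors::Result<()> {
    let col = Arc::new(Int64Array::from_iter_values(0..1024)) as ArrayRef;
    let batch = RecordBatch::try_from_iter([("col", col)]).unwrap();

    let file = tokio::fs::File::from_std(tempfile::tempfile().unwrap());

    // Flush to the underlying file whenever at least 1 MiB has been buffered.
    let mut writer = AsyncArrowWriter::try_new(file, batch.schema(), 1024 * 1024, None)?;
    writer.write(&batch).await?;
    writer.close().await?;
    Ok(())
}
```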
- Self::try_flush(&mut self.shared_buffer, &mut self.async_writer, true).await?; + Self::try_flush(&mut self.shared_buffer, &mut self.async_writer, 0).await?; self.async_writer.shutdown().await?; Ok(metadata) @@ -139,16 +149,16 @@ impl AsyncArrowWriter { async fn try_flush( shared_buffer: &mut SharedBuffer, async_writer: &mut W, - force: bool, + buffer_size: usize, ) -> Result<()> { let mut buffer = shared_buffer.buffer.try_lock().unwrap(); - if !force && buffer.len() < buffer.capacity() / 2 { + if buffer.is_empty() || buffer.len() < buffer_size { // no need to flush return Ok(()); } async_writer - .write(buffer.as_slice()) + .write_all(buffer.as_slice()) .await .map_err(|e| ParquetError::External(Box::new(e)))?; @@ -197,7 +207,7 @@ impl Write for SharedBuffer { #[cfg(test)] mod tests { - use arrow_array::{ArrayRef, Int64Array, RecordBatchReader}; + use arrow_array::{ArrayRef, BinaryArray, Int64Array, RecordBatchReader}; use bytes::Bytes; use tokio::pin; @@ -364,4 +374,32 @@ mod tests { async_writer.close().await.unwrap(); } } + + #[tokio::test] + async fn test_async_writer_file() { + let col = Arc::new(Int64Array::from_iter_values([1, 2, 3])) as ArrayRef; + let col2 = Arc::new(BinaryArray::from_iter_values(vec![ + vec![0; 500000], + vec![0; 500000], + vec![0; 500000], + ])) as ArrayRef; + let to_write = + RecordBatch::try_from_iter([("col", col), ("col2", col2)]).unwrap(); + + let temp = tempfile::tempfile().unwrap(); + + let file = tokio::fs::File::from_std(temp.try_clone().unwrap()); + let mut writer = + AsyncArrowWriter::try_new(file, to_write.schema(), 0, None).unwrap(); + writer.write(&to_write).await.unwrap(); + writer.close().await.unwrap(); + + let mut reader = ParquetRecordBatchReaderBuilder::try_new(temp) + .unwrap() + .build() + .unwrap(); + let read = reader.next().unwrap().unwrap(); + + assert_eq!(to_write, read); + } } diff --git a/parquet/src/arrow/buffer/bit_util.rs b/parquet/src/arrow/buffer/bit_util.rs index 2781190331c5..b8e2e2f539d3 100644 --- a/parquet/src/arrow/buffer/bit_util.rs +++ b/parquet/src/arrow/buffer/bit_util.rs @@ -35,7 +35,7 @@ pub fn iter_set_bits_rev(bytes: &[u8]) -> impl Iterator + '_ { .prefix() .into_iter() .chain(unaligned.chunks().iter().cloned()) - .chain(unaligned.suffix().into_iter()); + .chain(unaligned.suffix()); iter.rev().flat_map(move |mut chunk| { let chunk_idx = chunk_end_idx - 64; @@ -84,7 +84,7 @@ mod tests { .iter() .enumerate() .rev() - .filter_map(|(x, y)| y.then(|| x)) + .filter_map(|(x, y)| y.then_some(x)) .collect(); assert_eq!(actual, expected); diff --git a/parquet/src/arrow/buffer/offset_buffer.rs b/parquet/src/arrow/buffer/offset_buffer.rs index c8732bc4ed13..07d78e8a3282 100644 --- a/parquet/src/arrow/buffer/offset_buffer.rs +++ b/parquet/src/arrow/buffer/offset_buffer.rs @@ -281,10 +281,10 @@ mod tests { buffer.try_push(v.as_bytes(), false).unwrap() } - let valid = vec![ + let valid = [ true, false, false, true, false, true, false, true, true, false, false, ]; - let valid_mask = Buffer::from_iter(valid.iter().cloned()); + let valid_mask = Buffer::from_iter(valid.iter().copied()); // Both trailing and leading nulls buffer.pad_nulls(1, values.len() - 1, valid.len() - 1, valid_mask.as_slice()); diff --git a/parquet/src/arrow/mod.rs b/parquet/src/arrow/mod.rs index aad4925c7c70..0174db6b517f 100644 --- a/parquet/src/arrow/mod.rs +++ b/parquet/src/arrow/mod.rs @@ -25,12 +25,13 @@ //!# Example of writing Arrow record batch to Parquet file //! //!```rust -//! use arrow_array::{Int32Array, ArrayRef}; -//! 
use arrow_array::RecordBatch; -//! use parquet::arrow::arrow_writer::ArrowWriter; -//! use parquet::file::properties::WriterProperties; -//! use std::fs::File; -//! use std::sync::Arc; +//! # use arrow_array::{Int32Array, ArrayRef}; +//! # use arrow_array::RecordBatch; +//! # use parquet::arrow::arrow_writer::ArrowWriter; +//! # use parquet::file::properties::WriterProperties; +//! # use tempfile::tempfile; +//! # use std::sync::Arc; +//! # use parquet::basic::Compression; //! let ids = Int32Array::from(vec![1, 2, 3, 4]); //! let vals = Int32Array::from(vec![5, 6, 7, 8]); //! let batch = RecordBatch::try_from_iter(vec![ @@ -38,9 +39,14 @@ //! ("val", Arc::new(vals) as ArrayRef), //! ]).unwrap(); //! -//! let file = File::create("data.parquet").unwrap(); +//! let file = tempfile().unwrap(); +//! +//! // WriterProperties can be used to set Parquet file options +//! let props = WriterProperties::builder() +//! .set_compression(Compression::SNAPPY) +//! .build(); //! -//! let mut writer = ArrowWriter::try_new(file, batch.schema(), None).unwrap(); +//! let mut writer = ArrowWriter::try_new(file, batch.schema(), Some(props)).unwrap(); //! //! writer.write(&batch).expect("Writing batch"); //! @@ -48,24 +54,11 @@ //! writer.close().unwrap(); //! ``` //! -//! `WriterProperties` can be used to set Parquet file options -//! ```rust -//! use parquet::file::properties::WriterProperties; -//! use parquet::basic::{ Compression, Encoding }; -//! use parquet::file::properties::WriterVersion; -//! -//! // File compression -//! let props = WriterProperties::builder() -//! .set_compression(Compression::SNAPPY) -//! .build(); -//! ``` -//! //! # Example of reading parquet file into arrow record batch //! //! ```rust -//! use std::fs::File; -//! use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; -//! +//! # use std::fs::File; +//! # use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; //! # use std::sync::Arc; //! # use arrow_array::Int32Array; //! # use arrow::datatypes::{DataType, Field, Schema}; @@ -88,7 +81,7 @@ //! # writer.write(&batch).expect("Writing batch"); //! # } //! # writer.close().unwrap(); -//! +//! # //! let file = File::open("data.parquet").unwrap(); //! //! 
let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap(); @@ -130,6 +123,13 @@ pub use self::schema::{ /// Schema metadata key used to store serialized Arrow IPC schema pub const ARROW_SCHEMA_META_KEY: &str = "ARROW:schema"; +/// The value of this metadata key, if present on [`Field::metadata`], will be used +/// to populate [`BasicTypeInfo::id`] +/// +/// [`Field::metadata`]: arrow_schema::Field::metadata +/// [`BasicTypeInfo::id`]: crate::schema::types::BasicTypeInfo::id +pub const PARQUET_FIELD_ID_META_KEY: &str = "PARQUET:field_id"; + /// A [`ProjectionMask`] identifies a set of columns within a potentially nested schema to project /// /// In particular, a [`ProjectionMask`] can be constructed from a list of leaf column indices diff --git a/parquet/src/arrow/schema/complex.rs b/parquet/src/arrow/schema/complex.rs index 0d19875d97de..9f85b2c284c6 100644 --- a/parquet/src/arrow/schema/complex.rs +++ b/parquet/src/arrow/schema/complex.rs @@ -19,7 +19,7 @@ use std::collections::HashMap; use std::sync::Arc; use crate::arrow::schema::primitive::convert_primitive; -use crate::arrow::ProjectionMask; +use crate::arrow::{ProjectionMask, PARQUET_FIELD_ID_META_KEY}; use crate::basic::{ConvertedType, Repetition}; use crate::errors::ParquetError; use crate::errors::Result; @@ -550,7 +550,16 @@ fn convert_field( field.with_metadata(hint.metadata().clone()) } - None => Field::new(name, data_type, nullable), + None => { + let mut ret = Field::new(name, data_type, nullable); + let basic_info = parquet_type.get_basic_info(); + if basic_info.has_id() { + let mut meta = HashMap::with_capacity(1); + meta.insert(PARQUET_FIELD_ID_META_KEY.to_string(), basic_info.id().to_string()); + ret.set_metadata(meta); + } + ret + }, } } diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs index cd6e8046cc63..d56cc42d4313 100644 --- a/parquet/src/arrow/schema/mod.rs +++ b/parquet/src/arrow/schema/mod.rs @@ -37,7 +37,7 @@ use crate::basic::{ }; use crate::errors::{ParquetError, Result}; use crate::file::{metadata::KeyValue, properties::WriterProperties}; -use crate::schema::types::{ColumnDescriptor, SchemaDescriptor, Type, TypePtr}; +use crate::schema::types::{ColumnDescriptor, SchemaDescriptor, Type}; mod complex; mod primitive; @@ -45,6 +45,8 @@ mod primitive; use crate::arrow::ProjectionMask; pub(crate) use complex::{ParquetField, ParquetFieldType}; +use super::PARQUET_FIELD_ID_META_KEY; + /// Convert Parquet schema to Arrow schema including optional metadata /// /// Attempts to decode any existing Arrow schema metadata, falling back @@ -230,13 +232,13 @@ pub(crate) fn add_encoded_arrow_schema_to_metadata( /// Convert arrow schema to parquet schema pub fn arrow_to_parquet_schema(schema: &Schema) -> Result { - let fields: Result> = schema + let fields = schema .fields() .iter() .map(|field| arrow_to_parquet_type(field).map(Arc::new)) - .collect(); + .collect::>()?; let group = Type::group_type_builder("arrow_schema") - .with_fields(&mut fields?) + .with_fields(fields) .build()?; Ok(SchemaDescriptor::new(Arc::new(group))) } @@ -268,12 +270,20 @@ fn parse_key_value_metadata( /// Convert parquet column schema to arrow field. 
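The new `PARQUET_FIELD_ID_META_KEY` constant lets Parquet field IDs round-trip through `Field::metadata`: the schema conversion below reads it when building Parquet types (see the `field_id` helper later in this diff) and writes it back when converting Parquet columns to Arrow fields. A small sketch of tagging an Arrow field with a field ID, assuming the constant is imported from `parquet::arrow`:

```rust
use std::collections::HashMap;

use arrow_schema::{DataType, Field, Schema};
use parquet::arrow::PARQUET_FIELD_ID_META_KEY;

fn main() {
    // Attach parquet field id 1 to the column via field metadata.
    let id_meta = HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "1".to_string())]);
    let field = Field::new("c1", DataType::Utf8, false).with_metadata(id_meta);
    let schema = Schema::new(vec![field]);

    // When written with ArrowWriter, the parquet leaf for "c1" carries id = 1,
    // and reading the file back restores the same metadata entry on the field.
    assert_eq!(
        schema.field(0).metadata().get(PARQUET_FIELD_ID_META_KEY),
        Some(&"1".to_string())
    );
}
```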
pub fn parquet_to_arrow_field(parquet_column: &ColumnDescriptor) -> Result { let field = complex::convert_type(&parquet_column.self_type_ptr())?; - - Ok(Field::new( + let mut ret = Field::new( parquet_column.name(), field.arrow_type, field.nullable, - )) + ); + + let basic_info = parquet_column.self_type().get_basic_info(); + if basic_info.has_id() { + let mut meta = HashMap::with_capacity(1); + meta.insert(PARQUET_FIELD_ID_META_KEY.to_string(), basic_info.id().to_string()); + ret.set_metadata(meta); + } + + Ok(ret) } pub fn decimal_length_from_precision(precision: u8) -> usize { @@ -295,14 +305,17 @@ fn arrow_to_parquet_type(field: &Field) -> Result { } else { Repetition::REQUIRED }; + let id = field_id(field); // create type from field match field.data_type() { DataType::Null => Type::primitive_type_builder(name, PhysicalType::INT32) .with_logical_type(Some(LogicalType::Unknown)) .with_repetition(repetition) + .with_id(id) .build(), DataType::Boolean => Type::primitive_type_builder(name, PhysicalType::BOOLEAN) .with_repetition(repetition) + .with_id(id) .build(), DataType::Int8 => Type::primitive_type_builder(name, PhysicalType::INT32) .with_logical_type(Some(LogicalType::Integer { @@ -310,6 +323,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result { is_signed: true, })) .with_repetition(repetition) + .with_id(id) .build(), DataType::Int16 => Type::primitive_type_builder(name, PhysicalType::INT32) .with_logical_type(Some(LogicalType::Integer { @@ -317,12 +331,15 @@ fn arrow_to_parquet_type(field: &Field) -> Result { is_signed: true, })) .with_repetition(repetition) + .with_id(id) .build(), DataType::Int32 => Type::primitive_type_builder(name, PhysicalType::INT32) .with_repetition(repetition) + .with_id(id) .build(), DataType::Int64 => Type::primitive_type_builder(name, PhysicalType::INT64) .with_repetition(repetition) + .with_id(id) .build(), DataType::UInt8 => Type::primitive_type_builder(name, PhysicalType::INT32) .with_logical_type(Some(LogicalType::Integer { @@ -330,6 +347,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result { is_signed: false, })) .with_repetition(repetition) + .with_id(id) .build(), DataType::UInt16 => Type::primitive_type_builder(name, PhysicalType::INT32) .with_logical_type(Some(LogicalType::Integer { @@ -337,6 +355,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result { is_signed: false, })) .with_repetition(repetition) + .with_id(id) .build(), DataType::UInt32 => Type::primitive_type_builder(name, PhysicalType::INT32) .with_logical_type(Some(LogicalType::Integer { @@ -344,6 +363,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result { is_signed: false, })) .with_repetition(repetition) + .with_id(id) .build(), DataType::UInt64 => Type::primitive_type_builder(name, PhysicalType::INT64) .with_logical_type(Some(LogicalType::Integer { @@ -351,18 +371,22 @@ fn arrow_to_parquet_type(field: &Field) -> Result { is_signed: false, })) .with_repetition(repetition) + .with_id(id) .build(), DataType::Float16 => Err(arrow_err!("Float16 arrays not supported")), DataType::Float32 => Type::primitive_type_builder(name, PhysicalType::FLOAT) .with_repetition(repetition) + .with_id(id) .build(), DataType::Float64 => Type::primitive_type_builder(name, PhysicalType::DOUBLE) .with_repetition(repetition) + .with_id(id) .build(), DataType::Timestamp(TimeUnit::Second, _) => { // Cannot represent seconds in LogicalType Type::primitive_type_builder(name, PhysicalType::INT64) .with_repetition(repetition) + .with_id(id) .build() } DataType::Timestamp(time_unit, tz) => { @@ -384,21 
+408,25 @@ fn arrow_to_parquet_type(field: &Field) -> Result { }, })) .with_repetition(repetition) + .with_id(id) .build() } DataType::Date32 => Type::primitive_type_builder(name, PhysicalType::INT32) .with_logical_type(Some(LogicalType::Date)) .with_repetition(repetition) + .with_id(id) .build(), // date64 is cast to date32 (#1666) DataType::Date64 => Type::primitive_type_builder(name, PhysicalType::INT32) .with_logical_type(Some(LogicalType::Date)) .with_repetition(repetition) + .with_id(id) .build(), DataType::Time32(TimeUnit::Second) => { // Cannot represent seconds in LogicalType Type::primitive_type_builder(name, PhysicalType::INT32) .with_repetition(repetition) + .with_id(id) .build() } DataType::Time32(unit) => Type::primitive_type_builder(name, PhysicalType::INT32) @@ -410,6 +438,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result { }, })) .with_repetition(repetition) + .with_id(id) .build(), DataType::Time64(unit) => Type::primitive_type_builder(name, PhysicalType::INT64) .with_logical_type(Some(LogicalType::Time { @@ -421,6 +450,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result { }, })) .with_repetition(repetition) + .with_id(id) .build(), DataType::Duration(_) => { Err(arrow_err!("Converting Duration to parquet not supported",)) @@ -429,17 +459,20 @@ fn arrow_to_parquet_type(field: &Field) -> Result { Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY) .with_converted_type(ConvertedType::INTERVAL) .with_repetition(repetition) + .with_id(id) .with_length(12) .build() } DataType::Binary | DataType::LargeBinary => { Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) .with_repetition(repetition) + .with_id(id) .build() } DataType::FixedSizeBinary(length) => { Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY) .with_repetition(repetition) + .with_id(id) .with_length(*length) .build() } @@ -459,6 +492,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result { }; Type::primitive_type_builder(name, physical_type) .with_repetition(repetition) + .with_id(id) .with_length(length) .with_logical_type(Some(LogicalType::Decimal { scale: *scale as i32, @@ -472,18 +506,20 @@ fn arrow_to_parquet_type(field: &Field) -> Result { Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) .with_logical_type(Some(LogicalType::String)) .with_repetition(repetition) + .with_id(id) .build() } DataType::List(f) | DataType::FixedSizeList(f, _) | DataType::LargeList(f) => { Type::group_type_builder(name) - .with_fields(&mut vec![Arc::new( + .with_fields(vec![Arc::new( Type::group_type_builder("list") - .with_fields(&mut vec![Arc::new(arrow_to_parquet_type(f)?)]) + .with_fields(vec![Arc::new(arrow_to_parquet_type(f)?)]) .with_repetition(Repetition::REPEATED) .build()?, )]) .with_logical_type(Some(LogicalType::List)) .with_repetition(repetition) + .with_id(id) .build() } DataType::Struct(fields) => { @@ -493,37 +529,31 @@ fn arrow_to_parquet_type(field: &Field) -> Result { ); } // recursively convert children to types/nodes - let fields: Result> = fields + let fields = fields .iter() .map(|f| arrow_to_parquet_type(f).map(Arc::new)) - .collect(); + .collect::>()?; Type::group_type_builder(name) - .with_fields(&mut fields?) 
+ .with_fields(fields) .with_repetition(repetition) + .with_id(id) .build() } DataType::Map(field, _) => { if let DataType::Struct(struct_fields) = field.data_type() { Type::group_type_builder(name) - .with_fields(&mut vec![Arc::new( + .with_fields(vec![Arc::new( Type::group_type_builder(field.name()) - .with_fields(&mut vec![ - Arc::new(arrow_to_parquet_type(&Field::new( - struct_fields[0].name(), - struct_fields[0].data_type().clone(), - false, - ))?), - Arc::new(arrow_to_parquet_type(&Field::new( - struct_fields[1].name(), - struct_fields[1].data_type().clone(), - struct_fields[1].is_nullable(), - ))?), + .with_fields(vec![ + Arc::new(arrow_to_parquet_type(&struct_fields[0])?), + Arc::new(arrow_to_parquet_type(&struct_fields[1])?), ]) .with_repetition(Repetition::REPEATED) .build()?, )]) .with_logical_type(Some(LogicalType::Map)) .with_repetition(repetition) + .with_id(id) .build() } else { Err(arrow_err!( @@ -534,7 +564,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result { DataType::Union(_, _) => unimplemented!("See ARROW-8817."), DataType::Dictionary(_, ref value) => { // Dictionary encoding not handled at the schema level - let dict_field = Field::new(name, *value.clone(), field.is_nullable()); + let dict_field = field.clone().with_data_type(value.as_ref().clone()); arrow_to_parquet_type(&dict_field) } DataType::RunEndEncoded(_, _) => Err(arrow_err!( @@ -543,6 +573,11 @@ fn arrow_to_parquet_type(field: &Field) -> Result { } } +fn field_id(field: &Field) -> Option { + let value = field.metadata().get(super::PARQUET_FIELD_ID_META_KEY)?; + value.parse().ok() // Fail quietly if not a valid integer +} + #[cfg(test)] mod tests { use super::*; @@ -551,7 +586,9 @@ mod tests { use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; + use crate::arrow::PARQUET_FIELD_ID_META_KEY; use crate::file::metadata::KeyValue; + use crate::file::reader::FileReader; use crate::{ arrow::{arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter}, schema::{parser::parse_message_type, types::SchemaDescriptor}, @@ -1555,17 +1592,18 @@ mod tests { #[test] fn test_arrow_schema_roundtrip() -> Result<()> { - // This tests the roundtrip of an Arrow schema - // Fields that are commented out fail roundtrip tests or are unsupported by the writer - let metadata: HashMap = - [("Key".to_string(), "Value".to_string())] - .iter() - .cloned() - .collect(); + let meta = |a: &[(&str, &str)]| -> HashMap { + a.iter() + .map(|(a, b)| (a.to_string(), b.to_string())) + .collect() + }; let schema = Schema::new_with_metadata( vec![ - Field::new("c1", DataType::Utf8, false), + Field::new("c1", DataType::Utf8, false).with_metadata(meta(&[ + ("Key", "Foo"), + (PARQUET_FIELD_ID_META_KEY, "2"), + ])), Field::new("c2", DataType::Binary, false), Field::new("c3", DataType::FixedSizeBinary(3), false), Field::new("c4", DataType::Boolean, false), @@ -1598,24 +1636,40 @@ mod tests { Field::new("c20", DataType::Interval(IntervalUnit::YearMonth), false), Field::new_list( "c21", - Field::new("list", DataType::Boolean, true), + Field::new("item", DataType::Boolean, true).with_metadata(meta(&[ + ("Key", "Bar"), + (PARQUET_FIELD_ID_META_KEY, "5"), + ])), + false, + ) + .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "4")])), + Field::new( + "c22", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Boolean, true)), + 5, + ), + false, + ), + Field::new_list( + "c23", + Field::new_large_list( + "inner", + Field::new( + "item", + DataType::Struct( + vec![ + Field::new("a", DataType::Int16, true), + Field::new("b", 
DataType::Float64, false), + ] + .into(), + ), + false, + ), + true, + ), false, ), - // Field::new( - // "c22", - // DataType::FixedSizeList(Box::new(DataType::Boolean), 5), - // false, - // ), - // Field::new( - // "c23", - // DataType::List(Box::new(DataType::LargeList(Box::new( - // DataType::Struct(vec![ - // Field::new("a", DataType::Int16, true), - // Field::new("b", DataType::Float64, false), - // ]), - // )))), - // true, - // ), Field::new( "c24", DataType::Struct(Fields::from(vec![ @@ -1626,6 +1680,7 @@ mod tests { ), Field::new("c25", DataType::Interval(IntervalUnit::YearMonth), true), Field::new("c26", DataType::Interval(IntervalUnit::DayTime), true), + // Duration types not supported // Field::new("c27", DataType::Duration(TimeUnit::Second), false), // Field::new("c28", DataType::Duration(TimeUnit::Millisecond), false), // Field::new("c29", DataType::Duration(TimeUnit::Microsecond), false), @@ -1639,19 +1694,29 @@ mod tests { true, 123, true, - ), + ) + .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "6")])), Field::new("c32", DataType::LargeBinary, true), Field::new("c33", DataType::LargeUtf8, true), - // Field::new( - // "c34", - // DataType::LargeList(Box::new(DataType::List(Box::new( - // DataType::Struct(vec![ - // Field::new("a", DataType::Int16, true), - // Field::new("b", DataType::Float64, true), - // ]), - // )))), - // true, - // ), + Field::new_large_list( + "c34", + Field::new_list( + "inner", + Field::new( + "item", + DataType::Struct( + vec![ + Field::new("a", DataType::Int16, true), + Field::new("b", DataType::Float64, true), + ] + .into(), + ), + true, + ), + true, + ), + true, + ), Field::new("c35", DataType::Null, true), Field::new("c36", DataType::Decimal128(2, 1), false), Field::new("c37", DataType::Decimal256(50, 20), false), @@ -1671,29 +1736,34 @@ mod tests { Field::new_map( "c40", "my_entries", - Field::new("my_key", DataType::Utf8, false), + Field::new("my_key", DataType::Utf8, false) + .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "8")])), Field::new_list( "my_value", - Field::new("item", DataType::Utf8, true), + Field::new("item", DataType::Utf8, true) + .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "10")])), true, - ), + ) + .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "9")])), false, // fails to roundtrip keys_sorted true, - ), + ) + .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "7")])), Field::new_map( "c41", "my_entries", Field::new("my_key", DataType::Utf8, false), Field::new_list( "my_value", - Field::new("item", DataType::Utf8, true), + Field::new("item", DataType::Utf8, true) + .with_metadata(meta(&[(PARQUET_FIELD_ID_META_KEY, "11")])), true, ), false, // fails to roundtrip keys_sorted false, ), ], - metadata, + meta(&[("Key", "Value")]), ); // write to an empty parquet file so that schema is serialized @@ -1707,9 +1777,94 @@ mod tests { // read file back let arrow_reader = ParquetRecordBatchReaderBuilder::try_new(file).unwrap(); + + // Check arrow schema let read_schema = arrow_reader.schema(); assert_eq!(&schema, read_schema.as_ref()); + // Walk schema finding field IDs + let mut stack = Vec::with_capacity(10); + let mut out = Vec::with_capacity(10); + + let root = arrow_reader.parquet_schema().root_schema_ptr(); + stack.push((root.name().to_string(), root)); + + while let Some((p, t)) = stack.pop() { + if t.is_group() { + for f in t.get_fields() { + stack.push((format!("{p}.{}", f.name()), f.clone())) + } + } + + let info = t.get_basic_info(); + if info.has_id() { + out.push(format!("{p} -> {}", info.id())) + } + } + 
out.sort_unstable(); + let out: Vec<_> = out.iter().map(|x| x.as_str()).collect(); + + assert_eq!( + &out, + &[ + "arrow_schema.c1 -> 2", + "arrow_schema.c21 -> 4", + "arrow_schema.c21.list.item -> 5", + "arrow_schema.c31 -> 6", + "arrow_schema.c40 -> 7", + "arrow_schema.c40.my_entries.my_key -> 8", + "arrow_schema.c40.my_entries.my_value -> 9", + "arrow_schema.c40.my_entries.my_value.list.item -> 10", + "arrow_schema.c41.my_entries.my_value.list.item -> 11", + ] + ); + + Ok(()) + } + + #[test] + fn test_read_parquet_field_ids_raw() -> Result<()> { + let meta = |a: &[(&str, &str)]| -> HashMap { + a.iter() + .map(|(a, b)| (a.to_string(), b.to_string())) + .collect() + }; + let schema = Schema::new_with_metadata( + vec![ + Field::new("c1", DataType::Utf8, true).with_metadata(meta(&[ + (PARQUET_FIELD_ID_META_KEY, "1"), + ])), + Field::new("c2", DataType::Utf8, true).with_metadata(meta(&[ + (PARQUET_FIELD_ID_META_KEY, "2"), + ])), + ], + HashMap::new(), + ); + + let writer = ArrowWriter::try_new( + vec![], + Arc::new(schema.clone()), + None, + )?; + let parquet_bytes = writer.into_inner()?; + + let reader = crate::file::reader::SerializedFileReader::new( + bytes::Bytes::from(parquet_bytes), + )?; + let schema_descriptor = reader.metadata().file_metadata().schema_descr_ptr(); + + // don't pass metadata so field ids are read from Parquet and not from serialized Arrow schema + let arrow_schema = crate::arrow::parquet_to_arrow_schema( + &schema_descriptor, + None, + )?; + + let parq_schema_descr = crate::arrow::arrow_to_parquet_schema(&arrow_schema)?; + let parq_fields = parq_schema_descr.root_schema().get_fields(); + assert_eq!(parq_fields.len(), 2); + assert_eq!(parq_fields[0].get_basic_info().id(), 1); + assert_eq!(parq_fields[1].get_basic_info().id(), 2); + Ok(()) } diff --git a/parquet/src/arrow/schema/primitive.rs b/parquet/src/arrow/schema/primitive.rs index 83d84b77ec06..7d8b6a04ee81 100644 --- a/parquet/src/arrow/schema/primitive.rs +++ b/parquet/src/arrow/schema/primitive.rs @@ -193,11 +193,11 @@ fn from_int64(info: &BasicTypeInfo, scale: i32, precision: i32) -> Result Ok(DataType::Int64), ( Some(LogicalType::Integer { - bit_width, + bit_width: 64, is_signed, }), _, - ) if bit_width == 64 => match is_signed { + ) => match is_signed { true => Ok(DataType::Int64), false => Ok(DataType::UInt64), }, diff --git a/parquet/src/basic.rs b/parquet/src/basic.rs index cc8d033f42a4..cdad3597ffef 100644 --- a/parquet/src/basic.rs +++ b/parquet/src/basic.rs @@ -18,6 +18,7 @@ //! Contains Rust mappings for Thrift definition. //! Refer to [`parquet.thrift`](https://github.com/apache/parquet-format/blob/master/src/main/thrift/parquet.thrift) file to see raw definitions. 
+use std::str::FromStr; use std::{fmt, str}; pub use crate::compression::{BrotliLevel, GzipLevel, ZstdLevel}; @@ -278,6 +279,29 @@ pub enum Encoding { BYTE_STREAM_SPLIT, } +impl FromStr for Encoding { + type Err = ParquetError; + + fn from_str(s: &str) -> Result { + match s { + "PLAIN" | "plain" => Ok(Encoding::PLAIN), + "PLAIN_DICTIONARY" | "plain_dictionary" => Ok(Encoding::PLAIN_DICTIONARY), + "RLE" | "rle" => Ok(Encoding::RLE), + "BIT_PACKED" | "bit_packed" => Ok(Encoding::BIT_PACKED), + "DELTA_BINARY_PACKED" | "delta_binary_packed" => { + Ok(Encoding::DELTA_BINARY_PACKED) + } + "DELTA_LENGTH_BYTE_ARRAY" | "delta_length_byte_array" => { + Ok(Encoding::DELTA_LENGTH_BYTE_ARRAY) + } + "DELTA_BYTE_ARRAY" | "delta_byte_array" => Ok(Encoding::DELTA_BYTE_ARRAY), + "RLE_DICTIONARY" | "rle_dictionary" => Ok(Encoding::RLE_DICTIONARY), + "BYTE_STREAM_SPLIT" | "byte_stream_split" => Ok(Encoding::BYTE_STREAM_SPLIT), + _ => Err(general_err!("unknown encoding: {}", s)), + } + } +} + // ---------------------------------------------------------------------- // Mirrors `parquet::CompressionCodec` @@ -295,6 +319,90 @@ pub enum Compression { LZ4_RAW, } +fn split_compression_string( + str_setting: &str, +) -> Result<(&str, Option), ParquetError> { + let split_setting = str_setting.split_once('('); + + match split_setting { + Some((codec, level_str)) => { + let level = + &level_str[..level_str.len() - 1] + .parse::() + .map_err(|_| { + ParquetError::General(format!( + "invalid compression level: {}", + level_str + )) + })?; + Ok((codec, Some(*level))) + } + None => Ok((str_setting, None)), + } +} + +fn check_level_is_none(level: &Option) -> Result<(), ParquetError> { + if level.is_some() { + return Err(ParquetError::General("level is not support".to_string())); + } + + Ok(()) +} + +fn require_level(codec: &str, level: Option) -> Result { + level.ok_or(ParquetError::General(format!("{} require level", codec))) +} + +impl FromStr for Compression { + type Err = ParquetError; + + fn from_str(s: &str) -> std::result::Result { + let (codec, level) = split_compression_string(s)?; + + let c = match codec { + "UNCOMPRESSED" | "uncompressed" => { + check_level_is_none(&level)?; + Compression::UNCOMPRESSED + } + "SNAPPY" | "snappy" => { + check_level_is_none(&level)?; + Compression::SNAPPY + } + "GZIP" | "gzip" => { + let level = require_level(codec, level)?; + Compression::GZIP(GzipLevel::try_new(level)?) + } + "LZO" | "lzo" => { + check_level_is_none(&level)?; + Compression::LZO + } + "BROTLI" | "brotli" => { + let level = require_level(codec, level)?; + Compression::BROTLI(BrotliLevel::try_new(level)?) + } + "LZ4" | "lz4" => { + check_level_is_none(&level)?; + Compression::LZ4 + } + "ZSTD" | "zstd" => { + let level = require_level(codec, level)?; + Compression::ZSTD(ZstdLevel::try_new(level as i32)?) 
+ } + "LZ4_RAW" | "lz4_raw" => { + check_level_is_none(&level)?; + Compression::LZ4_RAW + } + _ => { + return Err(ParquetError::General(format!( + "unsupport compression {codec}" + ))); + } + }; + + Ok(c) + } +} + // ---------------------------------------------------------------------- // Mirrors `parquet::PageType` @@ -2130,4 +2238,81 @@ mod tests { ); assert_eq!(ColumnOrder::UNDEFINED.sort_order(), SortOrder::SIGNED); } + + #[test] + fn test_parse_encoding() { + let mut encoding: Encoding = "PLAIN".parse().unwrap(); + assert_eq!(encoding, Encoding::PLAIN); + encoding = "PLAIN_DICTIONARY".parse().unwrap(); + assert_eq!(encoding, Encoding::PLAIN_DICTIONARY); + encoding = "RLE".parse().unwrap(); + assert_eq!(encoding, Encoding::RLE); + encoding = "BIT_PACKED".parse().unwrap(); + assert_eq!(encoding, Encoding::BIT_PACKED); + encoding = "DELTA_BINARY_PACKED".parse().unwrap(); + assert_eq!(encoding, Encoding::DELTA_BINARY_PACKED); + encoding = "DELTA_LENGTH_BYTE_ARRAY".parse().unwrap(); + assert_eq!(encoding, Encoding::DELTA_LENGTH_BYTE_ARRAY); + encoding = "DELTA_BYTE_ARRAY".parse().unwrap(); + assert_eq!(encoding, Encoding::DELTA_BYTE_ARRAY); + encoding = "RLE_DICTIONARY".parse().unwrap(); + assert_eq!(encoding, Encoding::RLE_DICTIONARY); + encoding = "BYTE_STREAM_SPLIT".parse().unwrap(); + assert_eq!(encoding, Encoding::BYTE_STREAM_SPLIT); + + // test lowercase + encoding = "byte_stream_split".parse().unwrap(); + assert_eq!(encoding, Encoding::BYTE_STREAM_SPLIT); + + // test unknown string + match "plain_xxx".parse::() { + Ok(e) => { + panic!("Should not be able to parse {:?}", e); + } + Err(e) => { + assert_eq!(e.to_string(), "Parquet error: unknown encoding: plain_xxx"); + } + } + } + + #[test] + fn test_parse_compression() { + let mut compress: Compression = "snappy".parse().unwrap(); + assert_eq!(compress, Compression::SNAPPY); + compress = "lzo".parse().unwrap(); + assert_eq!(compress, Compression::LZO); + compress = "zstd(3)".parse().unwrap(); + assert_eq!(compress, Compression::ZSTD(ZstdLevel::try_new(3).unwrap())); + compress = "LZ4_RAW".parse().unwrap(); + assert_eq!(compress, Compression::LZ4_RAW); + compress = "uncompressed".parse().unwrap(); + assert_eq!(compress, Compression::UNCOMPRESSED); + compress = "snappy".parse().unwrap(); + assert_eq!(compress, Compression::SNAPPY); + compress = "gzip(9)".parse().unwrap(); + assert_eq!(compress, Compression::GZIP(GzipLevel::try_new(9).unwrap())); + compress = "lzo".parse().unwrap(); + assert_eq!(compress, Compression::LZO); + compress = "brotli(3)".parse().unwrap(); + assert_eq!( + compress, + Compression::BROTLI(BrotliLevel::try_new(3).unwrap()) + ); + compress = "lz4".parse().unwrap(); + assert_eq!(compress, Compression::LZ4); + + // test unknown compression + let mut err = "plain_xxx".parse::().unwrap_err(); + assert_eq!( + err.to_string(), + "Parquet error: unknown encoding: plain_xxx" + ); + + // test invalid compress level + err = "gzip(-10)".parse::().unwrap_err(); + assert_eq!( + err.to_string(), + "Parquet error: unknown encoding: gzip(-10)" + ); + } } diff --git a/parquet/src/bin/parquet-fromcsv.rs b/parquet/src/bin/parquet-fromcsv.rs index 1ff6fecf5a81..1f5d0a62bbfa 100644 --- a/parquet/src/bin/parquet-fromcsv.rs +++ b/parquet/src/bin/parquet-fromcsv.rs @@ -321,7 +321,7 @@ fn configure_reader_builder(args: &Args, arrow_schema: Arc) -> ReaderBui let mut builder = ReaderBuilder::new(arrow_schema) .with_batch_size(args.batch_size) - .has_header(args.has_header) + .with_header(args.has_header) 
.with_delimiter(args.get_delimiter()); builder = configure_reader( @@ -386,9 +386,9 @@ fn convert_csv_to_parquet(args: &Args) -> Result<(), ParquetFromCsvError> { Compression::BROTLI(_) => { Box::new(brotli::Decompressor::new(input_file, 0)) as Box } - Compression::LZ4 => Box::new(lz4::Decoder::new(input_file).map_err(|e| { - ParquetFromCsvError::with_context(e, "Failed to create lz4::Decoder") - })?) as Box, + Compression::LZ4 => { + Box::new(lz4_flex::frame::FrameDecoder::new(input_file)) as Box + } Compression::ZSTD(_) => Box::new(zstd::Decoder::new(input_file).map_err(|e| { ParquetFromCsvError::with_context(e, "Failed to create zstd::Decoder") })?) as Box, @@ -606,7 +606,7 @@ mod tests { let reader_builder = configure_reader_builder(&args, arrow_schema); let builder_debug = format!("{reader_builder:?}"); - assert_debug_text(&builder_debug, "has_header", "false"); + assert_debug_text(&builder_debug, "header", "false"); assert_debug_text(&builder_debug, "delimiter", "Some(44)"); assert_debug_text(&builder_debug, "quote", "Some(34)"); assert_debug_text(&builder_debug, "terminator", "None"); @@ -641,7 +641,7 @@ mod tests { ])); let reader_builder = configure_reader_builder(&args, arrow_schema); let builder_debug = format!("{reader_builder:?}"); - assert_debug_text(&builder_debug, "has_header", "true"); + assert_debug_text(&builder_debug, "header", "true"); assert_debug_text(&builder_debug, "delimiter", "Some(9)"); assert_debug_text(&builder_debug, "quote", "None"); assert_debug_text(&builder_debug, "terminator", "Some(10)"); @@ -692,19 +692,9 @@ mod tests { encoder.into_inner() } Compression::LZ4 => { - let mut encoder = lz4::EncoderBuilder::new() - .build(input_file) - .map_err(|e| { - ParquetFromCsvError::with_context( - e, - "Failed to create lz4::Encoder", - ) - }) - .unwrap(); + let mut encoder = lz4_flex::frame::FrameEncoder::new(input_file); write_tmp_file(&mut encoder); - let (inner, err) = encoder.finish(); - err.unwrap(); - inner + encoder.finish().unwrap() } Compression::ZSTD(level) => { diff --git a/parquet/src/bin/parquet-layout.rs b/parquet/src/bin/parquet-layout.rs index d749bb8a4ba7..901ac9ea2309 100644 --- a/parquet/src/bin/parquet-layout.rs +++ b/parquet/src/bin/parquet-layout.rs @@ -38,12 +38,13 @@ use std::io::Read; use clap::Parser; use serde::Serialize; -use thrift::protocol::{TCompactInputProtocol, TSerializable}; +use thrift::protocol::TCompactInputProtocol; use parquet::basic::{Compression, Encoding}; use parquet::errors::Result; use parquet::file::reader::ChunkReader; use parquet::format::PageHeader; +use parquet::thrift::TSerializable; #[derive(Serialize, Debug)] struct ParquetFile { diff --git a/parquet/src/bloom_filter/mod.rs b/parquet/src/bloom_filter/mod.rs index 4d2040b7f258..e98aee9fd213 100644 --- a/parquet/src/bloom_filter/mod.rs +++ b/parquet/src/bloom_filter/mod.rs @@ -26,13 +26,12 @@ use crate::format::{ BloomFilterAlgorithm, BloomFilterCompression, BloomFilterHash, BloomFilterHeader, SplitBlockAlgorithm, Uncompressed, XxHash, }; -use bytes::{Buf, Bytes}; +use crate::thrift::{TCompactSliceInputProtocol, TSerializable}; +use bytes::Bytes; use std::hash::Hasher; use std::io::Write; use std::sync::Arc; -use thrift::protocol::{ - TCompactInputProtocol, TCompactOutputProtocol, TOutputProtocol, TSerializable, -}; +use thrift::protocol::{TCompactOutputProtocol, TOutputProtocol}; use twox_hash::XxHash64; /// Salt as defined in the [spec](https://github.com/apache/parquet-format/blob/master/BloomFilter.md#technical-approach). 
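The `parquet-fromcsv` change above swaps the `lz4` bindings for the pure-Rust `lz4_flex` frame codec when reading LZ4-compressed CSV input. A minimal round-trip sketch of the frame encoder/decoder it now uses, assuming the `lz4_flex` crate with its default frame feature:

```rust
use std::io::{Read, Write};

use lz4_flex::frame::{FrameDecoder, FrameEncoder};

fn main() {
    // Compress into an in-memory LZ4 frame.
    let mut encoder = FrameEncoder::new(Vec::new());
    encoder.write_all(b"a,b\n1,2\n").unwrap();
    let compressed = encoder.finish().unwrap();

    // Decompress by wrapping any `Read` in a FrameDecoder.
    let mut decoder = FrameDecoder::new(compressed.as_slice());
    let mut out = String::new();
    decoder.read_to_string(&mut out).unwrap();
    assert_eq!(out, "a,b\n1,2\n");
}
```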
@@ -133,15 +132,14 @@ impl std::ops::IndexMut for Block { #[derive(Debug, Clone)] pub struct Sbbf(Vec); -const SBBF_HEADER_SIZE_ESTIMATE: usize = 20; +pub(crate) const SBBF_HEADER_SIZE_ESTIMATE: usize = 20; -/// given an initial offset, and a [ChunkReader], try to read out a bloom filter header and return +/// given an initial offset, and a byte buffer, try to read out a bloom filter header and return /// both the header and the offset after it (for bitset). -fn chunk_read_bloom_filter_header_and_offset( +pub(crate) fn chunk_read_bloom_filter_header_and_offset( offset: u64, - reader: Arc, + buffer: Bytes, ) -> Result<(BloomFilterHeader, u64), ParquetError> { - let buffer = reader.get_bytes(offset, SBBF_HEADER_SIZE_ESTIMATE)?; let (header, length) = read_bloom_filter_header_and_length(buffer)?; Ok((header, offset + length)) } @@ -149,19 +147,15 @@ fn chunk_read_bloom_filter_header_and_offset( /// given a [Bytes] buffer, try to read out a bloom filter header and return both the header and /// length of the header. #[inline] -fn read_bloom_filter_header_and_length( +pub(crate) fn read_bloom_filter_header_and_length( buffer: Bytes, ) -> Result<(BloomFilterHeader, u64), ParquetError> { let total_length = buffer.len(); - let mut buf_reader = buffer.reader(); - let mut prot = TCompactInputProtocol::new(&mut buf_reader); + let mut prot = TCompactSliceInputProtocol::new(buffer.as_ref()); let header = BloomFilterHeader::read_from_in_protocol(&mut prot).map_err(|e| { ParquetError::General(format!("Could not read bloom filter header: {e}")) })?; - Ok(( - header, - (total_length - buf_reader.into_inner().remaining()) as u64, - )) + Ok((header, (total_length - prot.as_slice().len()) as u64)) } pub(crate) const BITSET_MIN_LENGTH: usize = 32; @@ -205,7 +199,7 @@ impl Sbbf { Self::new(&bitset) } - fn new(bitset: &[u8]) -> Self { + pub(crate) fn new(bitset: &[u8]) -> Self { let data = bitset .chunks_exact(4 * 8) .map(|chunk| { @@ -271,8 +265,13 @@ impl Sbbf { return Ok(None); }; + let buffer = match column_metadata.bloom_filter_length() { + Some(length) => reader.get_bytes(offset, length as usize), + None => reader.get_bytes(offset, SBBF_HEADER_SIZE_ESTIMATE), + }?; + let (header, bitset_offset) = - chunk_read_bloom_filter_header_and_offset(offset, reader.clone())?; + chunk_read_bloom_filter_header_and_offset(offset, buffer.clone())?; match header.algorithm { BloomFilterAlgorithm::BLOCK(_) => { @@ -289,11 +288,17 @@ impl Sbbf { // this match exists to future proof the singleton hash enum } } - // length in bytes - let length: usize = header.num_bytes.try_into().map_err(|_| { - ParquetError::General("Bloom filter length is invalid".to_string()) - })?; - let bitset = reader.get_bytes(bitset_offset, length)?; + + let bitset = match column_metadata.bloom_filter_length() { + Some(_) => buffer.slice((bitset_offset - offset) as usize..), + None => { + let bitset_length: usize = header.num_bytes.try_into().map_err(|_| { + ParquetError::General("Bloom filter length is invalid".to_string()) + })?; + reader.get_bytes(bitset_offset, bitset_length)? + } + }; + Ok(Some(Self::new(&bitset))) } diff --git a/parquet/src/column/page.rs b/parquet/src/column/page.rs index 654cd0816039..933e42386272 100644 --- a/parquet/src/column/page.rs +++ b/parquet/src/column/page.rs @@ -58,7 +58,7 @@ pub enum Page { } impl Page { - /// Returns [`PageType`](crate::basic::PageType) for this page. + /// Returns [`PageType`] for this page. pub fn page_type(&self) -> PageType { match self { Page::DataPage { .. 
} => PageType::DATA_PAGE, @@ -85,7 +85,7 @@ impl Page { } } - /// Returns this page [`Encoding`](crate::basic::Encoding). + /// Returns this page [`Encoding`]. pub fn encoding(&self) -> Encoding { match self { Page::DataPage { encoding, .. } => *encoding, @@ -94,7 +94,7 @@ impl Page { } } - /// Returns optional [`Statistics`](crate::file::statistics::Statistics). + /// Returns optional [`Statistics`]. pub fn statistics(&self) -> Option<&Statistics> { match self { Page::DataPage { ref statistics, .. } => statistics.as_ref(), @@ -320,6 +320,20 @@ pub trait PageReader: Iterator> + Send { /// Skips reading the next page, returns an error if no /// column index information fn skip_next_page(&mut self) -> Result<()>; + + /// Returns `true` if the next page can be assumed to contain the start of a new record + /// + /// Prior to parquet V2 the specification was ambiguous as to whether a single record + /// could be split across multiple pages, and prior to [(#4327)] the Rust writer would do + /// this in certain situations. However, correctly interpreting the offset index relies on + /// this assumption holding [(#4943)], and so this mechanism is provided for a [`PageReader`] + /// to signal this to the calling context + /// + /// [(#4327)]: https://github.com/apache/arrow-rs/pull/4327 + /// [(#4943)]: https://github.com/apache/arrow-rs/pull/4943 + fn at_record_boundary(&mut self) -> Result { + Ok(self.peek_next_page()?.is_none()) + } } /// API for writing pages in a column chunk. diff --git a/parquet/src/column/reader.rs b/parquet/src/column/reader.rs index 88967e179271..52ad4d644c95 100644 --- a/parquet/src/column/reader.rs +++ b/parquet/src/column/reader.rs @@ -212,14 +212,17 @@ where Ok((values, levels)) } - /// Read up to `num_records` returning the number of complete records, non-null - /// values and levels decoded + /// Read up to `max_records` whole records, returning the number of complete + /// records, non-null values and levels decoded. All levels for a given record + /// will be read, i.e. the next repetition level, if any, will be 0 /// - /// If the max definition level is 0, `def_levels` will be ignored, otherwise it will be + /// If the max definition level is 0, `def_levels` will be ignored and the number of records, + /// non-null values and levels decoded will all be equal, otherwise `def_levels` will be /// populated with the number of levels read, with an error returned if it is `None`. /// - /// If the max repetition level is 0, `rep_levels` will be ignored, otherwise it will be - /// populated with the number of levels read, with an error returned if it is `None`. + /// If the max repetition level is 0, `rep_levels` will be ignored and the number of records + /// and levels decoded will both be equal, otherwise `rep_levels` will be populated with + /// the number of levels read, with an error returned if it is `None`. /// /// `values` will be contiguously populated with the non-null values. 
Note that if the column /// is not required, this may be less than either `max_records` or the number of levels read @@ -266,7 +269,7 @@ where // Reached end of page, which implies records_read < remaining_records // as otherwise would have stopped reading before reaching the end assert!(records_read < remaining_records); // Sanity check - records_read += 1; + records_read += reader.flush_partial() as usize; } (records_read, levels_read) } @@ -377,7 +380,7 @@ where // Reached end of page, which implies records_read < remaining_records // as otherwise would have stopped reading before reaching the end assert!(records_read < remaining_records); // Sanity check - records_read += 1; + records_read += decoder.flush_partial() as usize; } (records_read, levels_read) @@ -488,7 +491,7 @@ where offset += bytes_read; self.has_record_delimiter = - self.page_reader.peek_next_page()?.is_none(); + self.page_reader.at_record_boundary()?; self.rep_level_decoder .as_mut() @@ -545,7 +548,7 @@ where // across multiple pages, however, the parquet writer // used to do this so we preserve backwards compatibility self.has_record_delimiter = - self.page_reader.peek_next_page()?.is_none(); + self.page_reader.at_record_boundary()?; self.rep_level_decoder.as_mut().unwrap().set_data( Encoding::RLE, diff --git a/parquet/src/column/reader/decoder.rs b/parquet/src/column/reader/decoder.rs index 369b335dc98f..27ffb7637e18 100644 --- a/parquet/src/column/reader/decoder.rs +++ b/parquet/src/column/reader/decoder.rs @@ -102,6 +102,9 @@ pub trait RepetitionLevelDecoder: ColumnLevelDecoder { num_records: usize, num_levels: usize, ) -> Result<(usize, usize)>; + + /// Flush any partially read or skipped record + fn flush_partial(&mut self) -> bool; } pub trait DefinitionLevelDecoder: ColumnLevelDecoder { @@ -519,6 +522,10 @@ impl RepetitionLevelDecoder for RepetitionLevelDecoderImpl { } Ok((total_records_read, total_levels_read)) } + + fn flush_partial(&mut self) -> bool { + std::mem::take(&mut self.has_partial) + } } #[cfg(test)] diff --git a/parquet/src/column/writer/mod.rs b/parquet/src/column/writer/mod.rs index 1cacfe793328..8c1c55409988 100644 --- a/parquet/src/column/writer/mod.rs +++ b/parquet/src/column/writer/mod.rs @@ -500,14 +500,11 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> { let metadata = self.write_column_metadata()?; self.page_writer.close()?; - let (column_index, offset_index) = if self.column_index_builder.valid() { - // build the column and offset index - let column_index = self.column_index_builder.build_to_thrift(); - let offset_index = self.offset_index_builder.build_to_thrift(); - (Some(column_index), Some(offset_index)) - } else { - (None, None) - }; + let column_index = self + .column_index_builder + .valid() + .then(|| self.column_index_builder.build_to_thrift()); + let offset_index = Some(self.offset_index_builder.build_to_thrift()); Ok(ColumnCloseResult { bytes_written: self.column_metrics.total_bytes_written, @@ -2111,10 +2108,10 @@ mod tests { #[test] fn test_byte_array_statistics() { - let input = vec!["aawaa", "zz", "aaw", "m", "qrs"] + let input = ["aawaa", "zz", "aaw", "m", "qrs"] .iter() .map(|&s| s.into()) - .collect::>(); + .collect::>(); let stats = statistics_roundtrip::(&input); assert!(!stats.is_min_max_backwards_compatible()); @@ -2129,13 +2126,10 @@ mod tests { #[test] fn test_fixed_len_byte_array_statistics() { - let input = vec!["aawaa", "zz ", "aaw ", "m ", "qrs "] + let input = ["aawaa", "zz ", "aaw ", "m ", "qrs "] .iter() - .map(|&s| { - let b: ByteArray = 
s.into(); - b.into() - }) - .collect::>(); + .map(|&s| ByteArray::from(s).into()) + .collect::>(); let stats = statistics_roundtrip::(&input); assert!(stats.has_min_max_set()); diff --git a/parquet/src/compression.rs b/parquet/src/compression.rs index f1831ed48444..9e0eee0e3e04 100644 --- a/parquet/src/compression.rs +++ b/parquet/src/compression.rs @@ -388,7 +388,7 @@ mod lz4_codec { use std::io::{Read, Write}; use crate::compression::Codec; - use crate::errors::Result; + use crate::errors::{ParquetError, Result}; const LZ4_BUFFER_SIZE: usize = 4096; @@ -409,7 +409,7 @@ mod lz4_codec { output_buf: &mut Vec, _uncompress_size: Option, ) -> Result { - let mut decoder = lz4::Decoder::new(input_buf)?; + let mut decoder = lz4_flex::frame::FrameDecoder::new(input_buf); let mut buffer: [u8; LZ4_BUFFER_SIZE] = [0; LZ4_BUFFER_SIZE]; let mut total_len = 0; loop { @@ -424,7 +424,7 @@ mod lz4_codec { } fn compress(&mut self, input_buf: &[u8], output_buf: &mut Vec) -> Result<()> { - let mut encoder = lz4::EncoderBuilder::new().build(output_buf)?; + let mut encoder = lz4_flex::frame::FrameEncoder::new(output_buf); let mut from = 0; loop { let to = std::cmp::min(from + LZ4_BUFFER_SIZE, input_buf.len()); @@ -434,7 +434,10 @@ mod lz4_codec { break; } } - encoder.finish().1.map_err(|e| e.into()) + match encoder.finish() { + Ok(_) => Ok(()), + Err(e) => Err(ParquetError::External(Box::new(e))), + } } } } @@ -551,11 +554,7 @@ mod lz4_raw_codec { } }; output_buf.resize(offset + required_len, 0); - match lz4::block::decompress_to_buffer( - input_buf, - Some(required_len.try_into().unwrap()), - &mut output_buf[offset..], - ) { + match lz4_flex::block::decompress_into(input_buf, &mut output_buf[offset..]) { Ok(n) => { if n != required_len { return Err(ParquetError::General( @@ -564,25 +563,20 @@ mod lz4_raw_codec { } Ok(n) } - Err(e) => Err(e.into()), + Err(e) => Err(ParquetError::External(Box::new(e))), } } fn compress(&mut self, input_buf: &[u8], output_buf: &mut Vec) -> Result<()> { let offset = output_buf.len(); - let required_len = lz4::block::compress_bound(input_buf.len())?; + let required_len = lz4_flex::block::get_maximum_output_size(input_buf.len()); output_buf.resize(offset + required_len, 0); - match lz4::block::compress_to_buffer( - input_buf, - None, - false, - &mut output_buf[offset..], - ) { + match lz4_flex::block::compress_into(input_buf, &mut output_buf[offset..]) { Ok(n) => { output_buf.truncate(offset + n); Ok(()) } - Err(e) => Err(e.into()), + Err(e) => Err(ParquetError::External(Box::new(e))), } } } @@ -666,11 +660,11 @@ mod lz4_hadoop_codec { "Not enough bytes to hold advertised output", )); } - let decompressed_size = lz4::block::decompress_to_buffer( + let decompressed_size = lz4_flex::decompress_into( &input[..expected_compressed_size as usize], - Some(output_len as i32), output, - )?; + ) + .map_err(|e| ParquetError::External(Box::new(e)))?; if decompressed_size != expected_decompressed_size as usize { return Err(io::Error::new( io::ErrorKind::Other, diff --git a/parquet/src/encodings/decoding.rs b/parquet/src/encodings/decoding.rs index 8058335875c9..7aed6df419ee 100644 --- a/parquet/src/encodings/decoding.rs +++ b/parquet/src/encodings/decoding.rs @@ -1128,9 +1128,9 @@ mod tests { #[test] fn test_plain_decode_int32() { - let data = vec![42, 18, 52]; + let data = [42, 18, 52]; let data_bytes = Int32Type::to_byte_array(&data[..]); - let mut buffer = vec![0; 3]; + let mut buffer = [0; 3]; test_plain_decode::( ByteBufferPtr::new(data_bytes), 3, @@ -1142,7 +1142,7 @@ mod tests { #[test] 
fn test_plain_skip_int32() { - let data = vec![42, 18, 52]; + let data = [42, 18, 52]; let data_bytes = Int32Type::to_byte_array(&data[..]); test_plain_skip::( ByteBufferPtr::new(data_bytes), @@ -1155,7 +1155,7 @@ mod tests { #[test] fn test_plain_skip_all_int32() { - let data = vec![42, 18, 52]; + let data = [42, 18, 52]; let data_bytes = Int32Type::to_byte_array(&data[..]); test_plain_skip::(ByteBufferPtr::new(data_bytes), 3, 5, -1, &[]); } @@ -1165,7 +1165,7 @@ mod tests { let data = [42, 18, 52]; let expected_data = [0, 42, 0, 18, 0, 0, 52, 0]; let data_bytes = Int32Type::to_byte_array(&data[..]); - let mut buffer = vec![0; 8]; + let mut buffer = [0; 8]; let num_nulls = 5; let valid_bits = [0b01001010]; test_plain_decode_spaced::( @@ -1181,9 +1181,9 @@ mod tests { #[test] fn test_plain_decode_int64() { - let data = vec![42, 18, 52]; + let data = [42, 18, 52]; let data_bytes = Int64Type::to_byte_array(&data[..]); - let mut buffer = vec![0; 3]; + let mut buffer = [0; 3]; test_plain_decode::( ByteBufferPtr::new(data_bytes), 3, @@ -1195,7 +1195,7 @@ mod tests { #[test] fn test_plain_skip_int64() { - let data = vec![42, 18, 52]; + let data = [42, 18, 52]; let data_bytes = Int64Type::to_byte_array(&data[..]); test_plain_skip::( ByteBufferPtr::new(data_bytes), @@ -1208,16 +1208,16 @@ mod tests { #[test] fn test_plain_skip_all_int64() { - let data = vec![42, 18, 52]; + let data = [42, 18, 52]; let data_bytes = Int64Type::to_byte_array(&data[..]); test_plain_skip::(ByteBufferPtr::new(data_bytes), 3, 3, -1, &[]); } #[test] fn test_plain_decode_float() { - let data = vec![PI_f32, 2.414, 12.51]; + let data = [PI_f32, 2.414, 12.51]; let data_bytes = FloatType::to_byte_array(&data[..]); - let mut buffer = vec![0.0; 3]; + let mut buffer = [0.0; 3]; test_plain_decode::( ByteBufferPtr::new(data_bytes), 3, @@ -1229,7 +1229,7 @@ mod tests { #[test] fn test_plain_skip_float() { - let data = vec![PI_f32, 2.414, 12.51]; + let data = [PI_f32, 2.414, 12.51]; let data_bytes = FloatType::to_byte_array(&data[..]); test_plain_skip::( ByteBufferPtr::new(data_bytes), @@ -1242,14 +1242,14 @@ mod tests { #[test] fn test_plain_skip_all_float() { - let data = vec![PI_f32, 2.414, 12.51]; + let data = [PI_f32, 2.414, 12.51]; let data_bytes = FloatType::to_byte_array(&data[..]); test_plain_skip::(ByteBufferPtr::new(data_bytes), 3, 4, -1, &[]); } #[test] fn test_plain_skip_double() { - let data = vec![PI_f64, 2.414f64, 12.51f64]; + let data = [PI_f64, 2.414f64, 12.51f64]; let data_bytes = DoubleType::to_byte_array(&data[..]); test_plain_skip::( ByteBufferPtr::new(data_bytes), @@ -1262,16 +1262,16 @@ mod tests { #[test] fn test_plain_skip_all_double() { - let data = vec![PI_f64, 2.414f64, 12.51f64]; + let data = [PI_f64, 2.414f64, 12.51f64]; let data_bytes = DoubleType::to_byte_array(&data[..]); test_plain_skip::(ByteBufferPtr::new(data_bytes), 3, 5, -1, &[]); } #[test] fn test_plain_decode_double() { - let data = vec![PI_f64, 2.414f64, 12.51f64]; + let data = [PI_f64, 2.414f64, 12.51f64]; let data_bytes = DoubleType::to_byte_array(&data[..]); - let mut buffer = vec![0.0f64; 3]; + let mut buffer = [0.0f64; 3]; test_plain_decode::( ByteBufferPtr::new(data_bytes), 3, @@ -1283,13 +1283,13 @@ mod tests { #[test] fn test_plain_decode_int96() { - let mut data = vec![Int96::new(); 4]; + let mut data = [Int96::new(); 4]; data[0].set_data(11, 22, 33); data[1].set_data(44, 55, 66); data[2].set_data(10, 20, 30); data[3].set_data(40, 50, 60); let data_bytes = Int96Type::to_byte_array(&data[..]); - let mut buffer = vec![Int96::new(); 
4]; + let mut buffer = [Int96::new(); 4]; test_plain_decode::( ByteBufferPtr::new(data_bytes), 4, @@ -1301,7 +1301,7 @@ mod tests { #[test] fn test_plain_skip_int96() { - let mut data = vec![Int96::new(); 4]; + let mut data = [Int96::new(); 4]; data[0].set_data(11, 22, 33); data[1].set_data(44, 55, 66); data[2].set_data(10, 20, 30); @@ -1318,7 +1318,7 @@ mod tests { #[test] fn test_plain_skip_all_int96() { - let mut data = vec![Int96::new(); 4]; + let mut data = [Int96::new(); 4]; data[0].set_data(11, 22, 33); data[1].set_data(44, 55, 66); data[2].set_data(10, 20, 30); @@ -1329,11 +1329,11 @@ mod tests { #[test] fn test_plain_decode_bool() { - let data = vec![ + let data = [ false, true, false, false, true, false, true, true, false, true, ]; let data_bytes = BoolType::to_byte_array(&data[..]); - let mut buffer = vec![false; 10]; + let mut buffer = [false; 10]; test_plain_decode::( ByteBufferPtr::new(data_bytes), 10, @@ -1345,7 +1345,7 @@ mod tests { #[test] fn test_plain_skip_bool() { - let data = vec![ + let data = [ false, true, false, false, true, false, true, true, false, true, ]; let data_bytes = BoolType::to_byte_array(&data[..]); @@ -1360,7 +1360,7 @@ mod tests { #[test] fn test_plain_skip_all_bool() { - let data = vec![ + let data = [ false, true, false, false, true, false, true, true, false, true, ]; let data_bytes = BoolType::to_byte_array(&data[..]); diff --git a/parquet/src/file/footer.rs b/parquet/src/file/footer.rs index fcd6a300c5fb..53496a66b572 100644 --- a/parquet/src/file/footer.rs +++ b/parquet/src/file/footer.rs @@ -18,7 +18,7 @@ use std::{io::Read, sync::Arc}; use crate::format::{ColumnOrder as TColumnOrder, FileMetaData as TFileMetaData}; -use thrift::protocol::{TCompactInputProtocol, TSerializable}; +use crate::thrift::{TCompactSliceInputProtocol, TSerializable}; use crate::basic::ColumnOrder; @@ -62,18 +62,13 @@ pub fn parse_metadata(chunk_reader: &R) -> Result Result { - read_metadata(metadata_read) -} - -/// Decodes [`ParquetMetaData`] from the provided [`Read`] -pub(crate) fn read_metadata(read: R) -> Result { +pub fn decode_metadata(buf: &[u8]) -> Result { // TODO: row group filtering - let mut prot = TCompactInputProtocol::new(read); + let mut prot = TCompactSliceInputProtocol::new(buf); let t_file_metadata: TFileMetaData = TFileMetaData::read_from_in_protocol(&mut prot) .map_err(|e| ParquetError::General(format!("Could not parse metadata: {e}")))?; let schema = types::from_thrift(&t_file_metadata.schema)?; @@ -103,13 +98,9 @@ pub fn decode_footer(slice: &[u8; FOOTER_SIZE]) -> Result { } // get the metadata length from the footer - let metadata_len = i32::from_le_bytes(slice[..4].try_into().unwrap()); - metadata_len.try_into().map_err(|_| { - general_err!( - "Invalid Parquet file. Metadata length is less than zero ({})", - metadata_len - ) - }) + let metadata_len = u32::from_le_bytes(slice[..4].try_into().unwrap()); + // u32 won't be larger than usize in most cases + Ok(metadata_len as usize) } /// Parses column orders from Thrift definition. @@ -175,16 +166,6 @@ mod tests { ); } - #[test] - fn test_parse_metadata_invalid_length() { - let test_file = Bytes::from(vec![0, 0, 0, 255, b'P', b'A', b'R', b'1']); - let reader_result = parse_metadata(&test_file); - assert_eq!( - reader_result.unwrap_err().to_string(), - "Parquet error: Invalid Parquet file. 
Metadata length is less than zero (-16777216)" - ); - } - #[test] fn test_parse_metadata_invalid_start() { let test_file = Bytes::from(vec![255, 0, 0, 0, b'P', b'A', b'R', b'1']); @@ -198,7 +179,7 @@ mod tests { #[test] fn test_metadata_column_orders_parse() { // Define simple schema, we do not need to provide logical types. - let mut fields = vec![ + let fields = vec![ Arc::new( SchemaType::primitive_type_builder("col1", Type::INT32) .build() @@ -211,7 +192,7 @@ mod tests { ), ]; let schema = SchemaType::group_type_builder("schema") - .with_fields(&mut fields) + .with_fields(fields) .build() .unwrap(); let schema_descr = SchemaDescriptor::new(Arc::new(schema)); diff --git a/parquet/src/file/metadata.rs b/parquet/src/file/metadata.rs index bb8346306cf9..1f46c8105ebc 100644 --- a/parquet/src/file/metadata.rs +++ b/parquet/src/file/metadata.rs @@ -155,13 +155,13 @@ impl ParquetMetaData { } /// Override the column index - #[allow(dead_code)] + #[cfg(feature = "arrow")] pub(crate) fn set_column_index(&mut self, index: Option) { self.column_index = index; } /// Override the offset index - #[allow(dead_code)] + #[cfg(feature = "arrow")] pub(crate) fn set_offset_index(&mut self, index: Option) { self.offset_index = index; } @@ -230,7 +230,9 @@ impl FileMetaData { self.key_value_metadata.as_ref() } - /// Returns Parquet ['Type`] that describes schema in this file. + /// Returns Parquet [`Type`] that describes schema in this file. + /// + /// [`Type`]: crate::schema::types::Type pub fn schema(&self) -> &SchemaType { self.schema_descr.root_schema() } @@ -277,6 +279,9 @@ pub struct RowGroupMetaData { sorting_columns: Option>, total_byte_size: i64, schema_descr: SchemaDescPtr, + // We can't infer from file offset of first column since there may empty columns in row group. + file_offset: Option, + ordinal: Option, } impl RowGroupMetaData { @@ -330,6 +335,18 @@ impl RowGroupMetaData { self.schema_descr.clone() } + /// Returns ordinal of this row group in file + #[inline(always)] + pub fn ordinal(&self) -> Option { + self.ordinal + } + + /// Returns file offset of this row group in file. + #[inline(always)] + pub fn file_offset(&self) -> Option { + self.file_offset + } + /// Method to convert from Thrift. pub fn from_thrift( schema_descr: SchemaDescPtr, @@ -350,6 +367,8 @@ impl RowGroupMetaData { sorting_columns, total_byte_size, schema_descr, + file_offset: rg.file_offset, + ordinal: rg.ordinal, }) } @@ -360,9 +379,9 @@ impl RowGroupMetaData { total_byte_size: self.total_byte_size, num_rows: self.num_rows, sorting_columns: self.sorting_columns().cloned(), - file_offset: None, - total_compressed_size: None, - ordinal: None, + file_offset: self.file_offset(), + total_compressed_size: Some(self.compressed_size()), + ordinal: self.ordinal, } } @@ -381,9 +400,11 @@ impl RowGroupMetaDataBuilder { Self(RowGroupMetaData { columns: Vec::with_capacity(schema_descr.num_columns()), schema_descr, + file_offset: None, num_rows: 0, sorting_columns: None, total_byte_size: 0, + ordinal: None, }) } @@ -411,6 +432,17 @@ impl RowGroupMetaDataBuilder { self } + /// Sets ordinal for this row group. + pub fn set_ordinal(mut self, value: i16) -> Self { + self.0.ordinal = Some(value); + self + } + + pub fn set_file_offset(mut self, value: i64) -> Self { + self.0.file_offset = Some(value); + self + } + /// Builds row group metadata. 
pub fn build(self) -> Result { if self.0.schema_descr.num_columns() != self.0.columns.len() { @@ -442,6 +474,7 @@ pub struct ColumnChunkMetaData { statistics: Option, encoding_stats: Option>, bloom_filter_offset: Option, + bloom_filter_length: Option, offset_index_offset: Option, offset_index_length: Option, column_index_offset: Option, @@ -559,6 +592,11 @@ impl ColumnChunkMetaData { self.bloom_filter_offset } + /// Returns the offset for the bloom filter. + pub fn bloom_filter_length(&self) -> Option { + self.bloom_filter_length + } + /// Returns the offset for the column index. pub fn column_index_offset(&self) -> Option { self.column_index_offset @@ -614,7 +652,7 @@ impl ColumnChunkMetaData { let data_page_offset = col_metadata.data_page_offset; let index_page_offset = col_metadata.index_page_offset; let dictionary_page_offset = col_metadata.dictionary_page_offset; - let statistics = statistics::from_thrift(column_type, col_metadata.statistics); + let statistics = statistics::from_thrift(column_type, col_metadata.statistics)?; let encoding_stats = col_metadata .encoding_stats .as_ref() @@ -625,6 +663,7 @@ impl ColumnChunkMetaData { }) .transpose()?; let bloom_filter_offset = col_metadata.bloom_filter_offset; + let bloom_filter_length = col_metadata.bloom_filter_length; let offset_index_offset = cc.offset_index_offset; let offset_index_length = cc.offset_index_length; let column_index_offset = cc.column_index_offset; @@ -645,6 +684,7 @@ impl ColumnChunkMetaData { statistics, encoding_stats, bloom_filter_offset, + bloom_filter_length, offset_index_offset, offset_index_length, column_index_offset, @@ -690,6 +730,7 @@ impl ColumnChunkMetaData { .as_ref() .map(|vec| vec.iter().map(page_encoding_stats::to_thrift).collect()), bloom_filter_offset: self.bloom_filter_offset, + bloom_filter_length: self.bloom_filter_length, } } @@ -720,6 +761,7 @@ impl ColumnChunkMetaDataBuilder { statistics: None, encoding_stats: None, bloom_filter_offset: None, + bloom_filter_length: None, offset_index_offset: None, offset_index_length: None, column_index_offset: None, @@ -805,6 +847,12 @@ impl ColumnChunkMetaDataBuilder { self } + /// Sets optional bloom filter length in bytes. + pub fn set_bloom_filter_length(mut self, value: Option) -> Self { + self.0.bloom_filter_length = value; + self + } + /// Sets optional offset index offset in bytes. pub fn set_offset_index_offset(mut self, value: Option) -> Self { self.0.offset_index_offset = value; @@ -966,6 +1014,7 @@ mod tests { .set_num_rows(1000) .set_total_byte_size(2000) .set_column_metadata(columns) + .set_ordinal(1) .build() .unwrap(); @@ -1020,6 +1069,7 @@ mod tests { }, ]) .set_bloom_filter_offset(Some(6000)) + .set_bloom_filter_length(Some(25)) .set_offset_index_offset(Some(7000)) .set_offset_index_length(Some(25)) .set_column_index_offset(Some(8000)) @@ -1079,7 +1129,7 @@ mod tests { /// Returns sample schema descriptor so we can create column metadata. 
fn get_test_schema_descr() -> SchemaDescPtr { let schema = SchemaType::group_type_builder("schema") - .with_fields(&mut vec![ + .with_fields(vec![ Arc::new( SchemaType::primitive_type_builder("a", Type::INT32) .build() diff --git a/parquet/src/file/page_index/index_reader.rs b/parquet/src/file/page_index/index_reader.rs index c36708a59aeb..ae3bf3699c1c 100644 --- a/parquet/src/file/page_index/index_reader.rs +++ b/parquet/src/file/page_index/index_reader.rs @@ -24,9 +24,8 @@ use crate::file::metadata::ColumnChunkMetaData; use crate::file::page_index::index::{Index, NativeIndex}; use crate::file::reader::ChunkReader; use crate::format::{ColumnIndex, OffsetIndex, PageLocation}; -use std::io::Cursor; +use crate::thrift::{TCompactSliceInputProtocol, TSerializable}; use std::ops::Range; -use thrift::protocol::{TCompactInputProtocol, TSerializable}; /// Computes the covering range of two optional ranges /// @@ -116,7 +115,7 @@ pub fn read_pages_locations( pub(crate) fn decode_offset_index( data: &[u8], ) -> Result, ParquetError> { - let mut prot = TCompactInputProtocol::new(data); + let mut prot = TCompactSliceInputProtocol::new(data); let offset = OffsetIndex::read_from_in_protocol(&mut prot)?; Ok(offset.page_locations) } @@ -125,8 +124,7 @@ pub(crate) fn decode_column_index( data: &[u8], column_type: Type, ) -> Result { - let mut d = Cursor::new(data); - let mut prot = TCompactInputProtocol::new(&mut d); + let mut prot = TCompactSliceInputProtocol::new(data); let index = ColumnIndex::read_from_in_protocol(&mut prot)?; diff --git a/parquet/src/file/properties.rs b/parquet/src/file/properties.rs index 3d6390c036ae..93b034cf4f60 100644 --- a/parquet/src/file/properties.rs +++ b/parquet/src/file/properties.rs @@ -16,6 +16,7 @@ // under the License. //! Configuration via [`WriterProperties`] and [`ReaderProperties`] +use std::str::FromStr; use std::{collections::HashMap, sync::Arc}; use crate::basic::{Compression, Encoding}; @@ -72,6 +73,18 @@ impl WriterVersion { } } +impl FromStr for WriterVersion { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "PARQUET_1_0" | "parquet_1_0" => Ok(WriterVersion::PARQUET_1_0), + "PARQUET_2_0" | "parquet_2_0" => Ok(WriterVersion::PARQUET_2_0), + _ => Err(format!("Invalid writer version: {}", s)), + } + } +} + /// Reference counted writer properties. pub type WriterPropertiesPtr = Arc; @@ -550,9 +563,7 @@ impl WriterPropertiesBuilder { /// Helper method to get existing or new mutable reference of column properties. #[inline] fn get_mut_props(&mut self, col: ColumnPath) -> &mut ColumnProperties { - self.column_properties - .entry(col) - .or_insert_with(Default::default) + self.column_properties.entry(col).or_default() } /// Sets encoding for a column. 
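(Usage sketch, not from this patch: the `FromStr` impl added above, together with the matching `EnabledStatistics` impl in the next hunk, lets writer options be parsed from strings, e.g. CLI flags. The helper name and string values are illustrative only.)

use parquet::file::properties::{EnabledStatistics, WriterProperties, WriterVersion};

// Both upper- and lower-case spellings are accepted by the new FromStr impls.
fn writer_props_from_strings(version: &str, stats: &str) -> WriterProperties {
    let version: WriterVersion = version.parse().expect("valid writer version");
    let stats: EnabledStatistics = stats.parse().expect("valid statistics level");
    WriterProperties::builder()
        .set_writer_version(version)
        .set_statistics_enabled(stats)
        .build()
}

// e.g. writer_props_from_strings("parquet_2_0", "page")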
@@ -657,6 +668,19 @@ pub enum EnabledStatistics { Page, } +impl FromStr for EnabledStatistics { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "NONE" | "none" => Ok(EnabledStatistics::None), + "CHUNK" | "chunk" => Ok(EnabledStatistics::Chunk), + "PAGE" | "page" => Ok(EnabledStatistics::Page), + _ => Err(format!("Invalid statistics arg: {}", s)), + } + } +} + impl Default for EnabledStatistics { fn default() -> Self { DEFAULT_STATISTICS_ENABLED @@ -1184,4 +1208,46 @@ mod tests { assert_eq!(props.codec_options(), &codec_options); } + + #[test] + fn test_parse_writerversion() { + let mut writer_version = "PARQUET_1_0".parse::().unwrap(); + assert_eq!(writer_version, WriterVersion::PARQUET_1_0); + writer_version = "PARQUET_2_0".parse::().unwrap(); + assert_eq!(writer_version, WriterVersion::PARQUET_2_0); + + // test lowercase + writer_version = "parquet_1_0".parse::().unwrap(); + assert_eq!(writer_version, WriterVersion::PARQUET_1_0); + + // test invalid version + match "PARQUET_-1_0".parse::() { + Ok(_) => panic!("Should not be able to parse PARQUET_-1_0"), + Err(e) => { + assert_eq!(e, "Invalid writer version: PARQUET_-1_0"); + } + } + } + + #[test] + fn test_parse_enabledstatistics() { + let mut enabled_statistics = "NONE".parse::().unwrap(); + assert_eq!(enabled_statistics, EnabledStatistics::None); + enabled_statistics = "CHUNK".parse::().unwrap(); + assert_eq!(enabled_statistics, EnabledStatistics::Chunk); + enabled_statistics = "PAGE".parse::().unwrap(); + assert_eq!(enabled_statistics, EnabledStatistics::Page); + + // test lowercase + enabled_statistics = "none".parse::().unwrap(); + assert_eq!(enabled_statistics, EnabledStatistics::None); + + //test invalid statistics + match "ChunkAndPage".parse::() { + Ok(_) => panic!("Should not be able to parse ChunkAndPage"), + Err(e) => { + assert_eq!(e, "Invalid statistics arg: ChunkAndPage"); + } + } + } } diff --git a/parquet/src/file/serialized_reader.rs b/parquet/src/file/serialized_reader.rs index f685f14bd92f..b60d30ffea23 100644 --- a/parquet/src/file/serialized_reader.rs +++ b/parquet/src/file/serialized_reader.rs @@ -19,7 +19,6 @@ //! Also contains implementations of the ChunkReader for files (with buffering) and byte arrays (RAM) use std::collections::VecDeque; -use std::io::Cursor; use std::iter; use std::{convert::TryFrom, fs::File, io::Read, path::Path, sync::Arc}; @@ -40,8 +39,9 @@ use crate::format::{PageHeader, PageLocation, PageType}; use crate::record::reader::RowIter; use crate::record::Row; use crate::schema::types::Type as SchemaType; +use crate::thrift::{TCompactSliceInputProtocol, TSerializable}; use crate::util::memory::ByteBufferPtr; -use thrift::protocol::{TCompactInputProtocol, TSerializable}; +use thrift::protocol::TCompactInputProtocol; impl TryFrom for SerializedFileReader { type Error = ParquetError; @@ -76,7 +76,7 @@ impl<'a> TryFrom<&'a str> for SerializedFileReader { } } -/// Conversion into a [`RowIter`](crate::record::reader::RowIter) +/// Conversion into a [`RowIter`] /// using the full file schema over all row groups. 
impl IntoIterator for SerializedFileReader { type Item = Result; @@ -242,11 +242,6 @@ impl SerializedFileReader { }) } } - - #[cfg(feature = "arrow")] - pub(crate) fn metadata_ref(&self) -> &Arc { - &self.metadata - } } /// Get midpoint offset for a row group @@ -442,8 +437,10 @@ pub(crate) fn decode_page( let result = match page_header.type_ { PageType::DICTIONARY_PAGE => { - assert!(page_header.dictionary_page_header.is_some()); - let dict_header = page_header.dictionary_page_header.as_ref().unwrap(); + let dict_header = + page_header.dictionary_page_header.as_ref().ok_or_else(|| { + ParquetError::General("Missing dictionary page header".to_string()) + })?; let is_sorted = dict_header.is_sorted.unwrap_or(false); Page::DictionaryPage { buf: buffer, @@ -453,20 +450,22 @@ pub(crate) fn decode_page( } } PageType::DATA_PAGE => { - assert!(page_header.data_page_header.is_some()); - let header = page_header.data_page_header.unwrap(); + let header = page_header.data_page_header.ok_or_else(|| { + ParquetError::General("Missing V1 data page header".to_string()) + })?; Page::DataPage { buf: buffer, num_values: header.num_values as u32, encoding: Encoding::try_from(header.encoding)?, def_level_encoding: Encoding::try_from(header.definition_level_encoding)?, rep_level_encoding: Encoding::try_from(header.repetition_level_encoding)?, - statistics: statistics::from_thrift(physical_type, header.statistics), + statistics: statistics::from_thrift(physical_type, header.statistics)?, } } PageType::DATA_PAGE_V2 => { - assert!(page_header.data_page_header_v2.is_some()); - let header = page_header.data_page_header_v2.unwrap(); + let header = page_header.data_page_header_v2.ok_or_else(|| { + ParquetError::General("Missing V2 data page header".to_string()) + })?; let is_compressed = header.is_compressed.unwrap_or(true); Page::DataPageV2 { buf: buffer, @@ -477,7 +476,7 @@ pub(crate) fn decode_page( def_levels_byte_len: header.definition_levels_byte_length as u32, rep_levels_byte_len: header.repetition_levels_byte_length as u32, is_compressed, - statistics: statistics::from_thrift(physical_type, header.statistics), + statistics: statistics::from_thrift(physical_type, header.statistics)?, } } _ => { @@ -662,11 +661,11 @@ impl PageReader for SerializedPageReader { let buffer = self.reader.get_bytes(front.offset as u64, page_len)?; - let mut cursor = Cursor::new(buffer.as_ref()); - let header = read_page_header(&mut cursor)?; - let offset = cursor.position(); + let mut prot = TCompactSliceInputProtocol::new(buffer.as_ref()); + let header = PageHeader::read_from_in_protocol(&mut prot)?; + let offset = buffer.len() - prot.as_slice().len(); - let bytes = buffer.slice(offset as usize..); + let bytes = buffer.slice(offset..); decode_page( header, bytes.into(), @@ -771,6 +770,15 @@ impl PageReader for SerializedPageReader { } } } + + fn at_record_boundary(&mut self) -> Result { + match &mut self.state { + SerializedPageReaderState::Values { .. } => { + Ok(self.peek_next_page()?.is_none()) + } + SerializedPageReaderState::Pages { .. 
} => Ok(true), + } + } } #[cfg(test)] @@ -852,38 +860,23 @@ mod tests { #[test] fn test_file_reader_into_iter() { let path = get_test_path("alltypes_plain.parquet"); - let vec = vec![path.clone(), path] - .iter() - .map(|p| SerializedFileReader::try_from(p.as_path()).unwrap()) - .flat_map(|r| r.into_iter()) - .flat_map(|r| r.unwrap().get_int(0)) - .collect::>(); - - // rows in the parquet file are not sorted by "id" - // each file contains [id:4, id:5, id:6, id:7, id:2, id:3, id:0, id:1] - assert_eq!(vec, vec![4, 5, 6, 7, 2, 3, 0, 1, 4, 5, 6, 7, 2, 3, 0, 1]); + let reader = SerializedFileReader::try_from(path.as_path()).unwrap(); + let iter = reader.into_iter(); + let values: Vec<_> = iter.flat_map(|x| x.unwrap().get_int(0)).collect(); + + assert_eq!(values, &[4, 5, 6, 7, 2, 3, 0, 1]); } #[test] fn test_file_reader_into_iter_project() { let path = get_test_path("alltypes_plain.parquet"); - let result = vec![path] - .iter() - .map(|p| SerializedFileReader::try_from(p.as_path()).unwrap()) - .flat_map(|r| { - let schema = "message schema { OPTIONAL INT32 id; }"; - let proj = parse_message_type(schema).ok(); - - r.into_iter().project(proj).unwrap() - }) - .map(|r| format!("{}", r.unwrap())) - .collect::>() - .join(","); + let reader = SerializedFileReader::try_from(path.as_path()).unwrap(); + let schema = "message schema { OPTIONAL INT32 id; }"; + let proj = parse_message_type(schema).ok(); + let iter = reader.into_iter().project(proj).unwrap(); + let values: Vec<_> = iter.flat_map(|x| x.unwrap().get_int(0)).collect(); - assert_eq!( - result, - "{id: 4},{id: 5},{id: 6},{id: 7},{id: 2},{id: 3},{id: 0},{id: 1}" - ); + assert_eq!(values, &[4, 5, 6, 7, 2, 3, 0, 1]); } #[test] diff --git a/parquet/src/file/statistics.rs b/parquet/src/file/statistics.rs index 939ce037f968..b36e37a80c97 100644 --- a/parquet/src/file/statistics.rs +++ b/parquet/src/file/statistics.rs @@ -44,6 +44,7 @@ use crate::format::Statistics as TStatistics; use crate::basic::Type; use crate::data_type::private::ParquetValueType; use crate::data_type::*; +use crate::errors::{ParquetError, Result}; use crate::util::bit_util::from_le_slice; pub(crate) mod private { @@ -119,15 +120,18 @@ macro_rules! statistics_enum_func { pub fn from_thrift( physical_type: Type, thrift_stats: Option, -) -> Option { - match thrift_stats { +) -> Result> { + Ok(match thrift_stats { Some(stats) => { // Number of nulls recorded, when it is not available, we just mark it as 0. let null_count = stats.null_count.unwrap_or(0); - assert!( - null_count >= 0, - "Statistics null count is negative ({null_count})" - ); + + if null_count < 0 { + return Err(ParquetError::General(format!( + "Statistics null count is negative {}", + null_count + ))); + } // Generic null count. let null_count = null_count as u64; @@ -221,7 +225,7 @@ pub fn from_thrift( Some(res) } None => None, - } + }) } // Convert Statistics into Thrift definition. 
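(Usage sketch, not from this patch: `from_thrift` is now fallible, so a negative `null_count` surfaces as an `Err` instead of a panic. Assumes the public `parquet::file::statistics` and `parquet::format` re-exports; the values are made up for illustration.)

use parquet::basic::Type;
use parquet::file::statistics::from_thrift;
use parquet::format::Statistics as TStatistics;

fn reject_negative_null_count() {
    // All other fields default to None now that the generated Statistics derives Default.
    let thrift_stats = TStatistics {
        null_count: Some(-10),
        ..Default::default()
    };
    assert!(from_thrift(Type::INT32, Some(thrift_stats)).is_err());
}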
@@ -594,7 +598,7 @@ mod tests { } #[test] - #[should_panic(expected = "Statistics null count is negative (-10)")] + #[should_panic(expected = "General(\"Statistics null count is negative -10\")")] fn test_statistics_negative_null_count() { let thrift_stats = TStatistics { max: None, @@ -605,13 +609,13 @@ mod tests { min_value: None, }; - from_thrift(Type::INT32, Some(thrift_stats)); + from_thrift(Type::INT32, Some(thrift_stats)).unwrap(); } #[test] fn test_statistics_thrift_none() { - assert_eq!(from_thrift(Type::INT32, None), None); - assert_eq!(from_thrift(Type::BYTE_ARRAY, None), None); + assert_eq!(from_thrift(Type::INT32, None).unwrap(), None); + assert_eq!(from_thrift(Type::BYTE_ARRAY, None).unwrap(), None); } #[test] @@ -715,7 +719,7 @@ mod tests { fn check_stats(stats: Statistics) { let tpe = stats.physical_type(); let thrift_stats = to_thrift(Some(&stats)); - assert_eq!(from_thrift(tpe, thrift_stats), Some(stats)); + assert_eq!(from_thrift(tpe, thrift_stats).unwrap(), Some(stats)); } check_stats(Statistics::boolean(Some(false), Some(true), None, 7, true)); diff --git a/parquet/src/file/writer.rs b/parquet/src/file/writer.rs index bde350a1ea42..7796be6013df 100644 --- a/parquet/src/file/writer.rs +++ b/parquet/src/file/writer.rs @@ -21,10 +21,11 @@ use crate::bloom_filter::Sbbf; use crate::format as parquet; use crate::format::{ColumnIndex, OffsetIndex, RowGroup}; +use crate::thrift::TSerializable; use std::fmt::Debug; use std::io::{BufWriter, IoSlice, Read}; use std::{io::Write, sync::Arc}; -use thrift::protocol::{TCompactOutputProtocol, TSerializable}; +use thrift::protocol::TCompactOutputProtocol; use crate::column::writer::{ get_typed_column_writer_mut, ColumnCloseResult, ColumnWriterImpl, @@ -115,7 +116,8 @@ pub type OnCloseRowGroup<'a> = Box< Vec>, Vec>, ) -> Result<()> - + 'a, + + 'a + + Send, >; // ---------------------------------------------------------------------- @@ -183,6 +185,8 @@ impl SerializedFileWriter { /// previous row group must be finalised and closed using `RowGroupWriter::close` method. 
pub fn next_row_group(&mut self) -> Result> { self.assert_previous_writer_closed()?; + let ordinal = self.row_group_index; + self.row_group_index += 1; let row_groups = &mut self.row_groups; @@ -204,6 +208,7 @@ impl SerializedFileWriter { self.descr.clone(), self.props.clone(), &mut self.buf, + ordinal as i16, Some(Box::new(on_close)), ); Ok(row_group_writer) @@ -264,12 +269,15 @@ impl SerializedFileWriter { Some(bloom_filter) => { let start_offset = self.buf.bytes_written(); bloom_filter.write(&mut self.buf)?; + let end_offset = self.buf.bytes_written(); // set offset and index for bloom filter - column_chunk + let column_chunk_meta = column_chunk .meta_data .as_mut() - .expect("can't have bloom filter without column metadata") - .bloom_filter_offset = Some(start_offset as i64); + .expect("can't have bloom filter without column metadata"); + column_chunk_meta.bloom_filter_offset = Some(start_offset as i64); + column_chunk_meta.bloom_filter_length = + Some((end_offset - start_offset) as i32); } None => {} } @@ -347,7 +355,7 @@ impl SerializedFileWriter { let end_pos = self.buf.bytes_written(); // Write footer - let metadata_len = (end_pos - start_pos) as i32; + let metadata_len = (end_pos - start_pos) as u32; self.buf.write_all(&metadata_len.to_le_bytes())?; self.buf.write_all(&PARQUET_MAGIC)?; @@ -409,6 +417,8 @@ pub struct SerializedRowGroupWriter<'a, W: Write> { bloom_filters: Vec>, column_indexes: Vec>, offset_indexes: Vec>, + row_group_index: i16, + file_offset: i64, on_close: Option>, } @@ -418,16 +428,22 @@ impl<'a, W: Write + Send> SerializedRowGroupWriter<'a, W> { /// - `schema_descr` - the schema to write /// - `properties` - writer properties /// - `buf` - the buffer to write data to + /// - `row_group_index` - row group index in this parquet file. + /// - `file_offset` - file offset of this row group in this parquet file. 
/// - `on_close` - an optional callback that will invoked on [`Self::close`] pub fn new( schema_descr: SchemaDescPtr, properties: WriterPropertiesPtr, buf: &'a mut TrackedWrite, + row_group_index: i16, on_close: Option>, ) -> Self { let num_columns = schema_descr.num_columns(); + let file_offset = buf.bytes_written() as i64; Self { buf, + row_group_index, + file_offset, on_close, total_rows_written: None, descr: schema_descr, @@ -603,6 +619,8 @@ impl<'a, W: Write + Send> SerializedRowGroupWriter<'a, W> { .set_total_byte_size(self.total_uncompressed_bytes) .set_num_rows(self.total_rows_written.unwrap_or(0) as i64) .set_sorting_columns(self.props.sorting_columns().cloned()) + .set_ordinal(self.row_group_index) + .set_file_offset(self.file_offset) .build()?; let metadata = Arc::new(row_group_metadata); @@ -742,6 +760,8 @@ mod tests { use crate::column::reader::get_typed_column_reader; use crate::compression::{create_codec, Codec, CodecOptionsBuilder}; use crate::data_type::{BoolType, Int32Type}; + use crate::file::page_index::index::Index; + use crate::file::properties::EnabledStatistics; use crate::file::reader::ChunkReader; use crate::file::serialized_reader::ReadOptionsBuilder; use crate::file::{ @@ -760,7 +780,7 @@ mod tests { let file = tempfile::tempfile().unwrap(); let schema = Arc::new( types::Type::group_type_builder("schema") - .with_fields(&mut vec![Arc::new( + .with_fields(vec![Arc::new( types::Type::primitive_type_builder("col1", Type::INT32) .build() .unwrap(), @@ -786,7 +806,7 @@ mod tests { let file = tempfile::tempfile().unwrap(); let schema = Arc::new( types::Type::group_type_builder("schema") - .with_fields(&mut vec![ + .with_fields(vec![ Arc::new( types::Type::primitive_type_builder("col1", Type::INT32) .with_repetition(Repetition::REQUIRED) @@ -833,7 +853,7 @@ mod tests { let schema = Arc::new( types::Type::group_type_builder("schema") - .with_fields(&mut vec![Arc::new( + .with_fields(vec![Arc::new( types::Type::primitive_type_builder("col1", Type::INT32) .build() .unwrap(), @@ -856,7 +876,7 @@ mod tests { let schema = Arc::new( types::Type::group_type_builder("schema") - .with_fields(&mut vec![Arc::new( + .with_fields(vec![Arc::new( types::Type::primitive_type_builder("col1", Type::INT32) .build() .unwrap(), @@ -905,7 +925,7 @@ mod tests { ); let schema = Arc::new( types::Type::group_type_builder("schema") - .with_fields(&mut vec![field.clone()]) + .with_fields(vec![field.clone()]) .build() .unwrap(), ); @@ -948,7 +968,7 @@ mod tests { let schema = Arc::new( types::Type::group_type_builder("schema") - .with_fields(&mut vec![ + .with_fields(vec![ Arc::new( types::Type::primitive_type_builder("col1", Type::INT32) .build() @@ -1133,7 +1153,8 @@ mod tests { statistics: from_thrift( physical_type, to_thrift(statistics.as_ref()), - ), + ) + .unwrap(), } } Page::DataPageV2 { @@ -1166,7 +1187,8 @@ mod tests { statistics: from_thrift( physical_type, to_thrift(statistics.as_ref()), - ), + ) + .unwrap(), } } Page::DictionaryPage { @@ -1293,7 +1315,7 @@ mod tests { { let schema = Arc::new( types::Type::group_type_builder("schema") - .with_fields(&mut vec![Arc::new( + .with_fields(vec![Arc::new( types::Type::primitive_type_builder("col1", D::get_physical_type()) .with_repetition(Repetition::REQUIRED) .build() @@ -1312,6 +1334,7 @@ mod tests { let mut rows: i64 = 0; for (idx, subset) in data.iter().enumerate() { + let row_group_file_offset = file_writer.buf.bytes_written(); let mut row_group_writer = file_writer.next_row_group().unwrap(); if let Some(mut writer) = 
row_group_writer.next_column().unwrap() { rows += writer @@ -1323,6 +1346,8 @@ mod tests { let last_group = row_group_writer.close().unwrap(); let flushed = file_writer.flushed_row_groups(); assert_eq!(flushed.len(), idx + 1); + assert_eq!(Some(idx as i16), last_group.ordinal()); + assert_eq!(Some(row_group_file_offset as i64), last_group.file_offset()); assert_eq!(flushed[idx].as_ref(), last_group.as_ref()); } let file_metadata = file_writer.close().unwrap(); @@ -1448,7 +1473,7 @@ mod tests { ) { let schema = Arc::new( types::Type::group_type_builder("schema") - .with_fields(&mut vec![Arc::new( + .with_fields(vec![Arc::new( types::Type::primitive_type_builder("col1", Type::INT32) .with_repetition(Repetition::REQUIRED) .build() @@ -1648,4 +1673,62 @@ mod tests { let reader = SerializedFileReader::new_with_options(file, options).unwrap(); test_read(reader); } + + #[test] + fn test_disabled_statistics() { + let message_type = " + message test_schema { + REQUIRED INT32 a; + REQUIRED INT32 b; + } + "; + let schema = Arc::new(parse_message_type(message_type).unwrap()); + let props = WriterProperties::builder() + .set_statistics_enabled(EnabledStatistics::None) + .set_column_statistics_enabled("a".into(), EnabledStatistics::Page) + .build(); + let mut file = Vec::with_capacity(1024); + let mut file_writer = + SerializedFileWriter::new(&mut file, schema, Arc::new(props)).unwrap(); + + let mut row_group_writer = file_writer.next_row_group().unwrap(); + let mut a_writer = row_group_writer.next_column().unwrap().unwrap(); + let col_writer = a_writer.typed::(); + col_writer.write_batch(&[1, 2, 3], None, None).unwrap(); + a_writer.close().unwrap(); + + let mut b_writer = row_group_writer.next_column().unwrap().unwrap(); + let col_writer = b_writer.typed::(); + col_writer.write_batch(&[4, 5, 6], None, None).unwrap(); + b_writer.close().unwrap(); + row_group_writer.close().unwrap(); + + let metadata = file_writer.close().unwrap(); + assert_eq!(metadata.row_groups.len(), 1); + let row_group = &metadata.row_groups[0]; + assert_eq!(row_group.columns.len(), 2); + // Column "a" has both offset and column index, as requested + assert!(row_group.columns[0].offset_index_offset.is_some()); + assert!(row_group.columns[0].column_index_offset.is_some()); + // Column "b" should only have offset index + assert!(row_group.columns[1].offset_index_offset.is_some()); + assert!(row_group.columns[1].column_index_offset.is_none()); + + let options = ReadOptionsBuilder::new().with_page_index().build(); + let reader = + SerializedFileReader::new_with_options(Bytes::from(file), options).unwrap(); + + let offset_index = reader.metadata().offset_index().unwrap(); + assert_eq!(offset_index.len(), 1); // 1 row group + assert_eq!(offset_index[0].len(), 2); // 2 columns + + let column_index = reader.metadata().column_index().unwrap(); + assert_eq!(column_index.len(), 1); // 1 row group + assert_eq!(column_index[0].len(), 2); // 2 column + + let a_idx = &column_index[0][0]; + assert!(matches!(a_idx, Index::INT32(_)), "{a_idx:?}"); + let b_idx = &column_index[0][1]; + assert!(matches!(b_idx, Index::NONE), "{b_idx:?}"); + } } diff --git a/parquet/src/format.rs b/parquet/src/format.rs index 0851b2287fba..46adc39e6406 100644 --- a/parquet/src/format.rs +++ b/parquet/src/format.rs @@ -1,9 +1,10 @@ -// Autogenerated by Thrift Compiler (0.17.0) +// Autogenerated by Thrift Compiler (0.19.0) // DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING +#![allow(dead_code)] #![allow(unused_imports)] #![allow(unused_extern_crates)] 
-#![allow(clippy::too_many_arguments, clippy::type_complexity, clippy::vec_box)] +#![allow(clippy::too_many_arguments, clippy::type_complexity, clippy::vec_box, clippy::wrong_self_convention)] #![cfg_attr(rustfmt, rustfmt_skip)] use std::cell::RefCell; @@ -52,12 +53,12 @@ impl Type { ]; } -impl TSerializable for Type { +impl crate::thrift::TSerializable for Type { #[allow(clippy::trivially_copy_pass_by_ref)] - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { o_prot.write_i32(self.0) } - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { let enum_value = i_prot.read_i32()?; Ok(Type::from(enum_value)) } @@ -99,7 +100,7 @@ impl From<&Type> for i32 { /// DEPRECATED: Common types used by frameworks(e.g. hive, pig) using parquet. /// ConvertedType is superseded by LogicalType. This enum should not be extended. -/// +/// /// See LogicalTypes.md for conversion between ConvertedType and LogicalType. #[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct ConvertedType(pub i32); @@ -117,12 +118,12 @@ impl ConvertedType { /// an enum is converted into a binary field pub const ENUM: ConvertedType = ConvertedType(4); /// A decimal value. - /// + /// /// This may be used to annotate binary or fixed primitive types. The /// underlying byte array stores the unscaled value encoded as two's /// complement using big-endian byte order (the most significant byte is the /// zeroth element). The value of the decimal is the value * 10^{-scale}. - /// + /// /// This must be accompanied by a (maximum) precision and a scale in the /// SchemaElement. The precision specifies the number of digits in the decimal /// and the scale stores the location of the decimal point. For example 1.23 @@ -130,62 +131,62 @@ impl ConvertedType { /// 2 digits over). pub const DECIMAL: ConvertedType = ConvertedType(5); /// A Date - /// + /// /// Stored as days since Unix epoch, encoded as the INT32 physical type. - /// + /// pub const DATE: ConvertedType = ConvertedType(6); /// A time - /// + /// /// The total number of milliseconds since midnight. The value is stored /// as an INT32 physical type. pub const TIME_MILLIS: ConvertedType = ConvertedType(7); /// A time. - /// + /// /// The total number of microseconds since midnight. The value is stored as /// an INT64 physical type. pub const TIME_MICROS: ConvertedType = ConvertedType(8); /// A date/time combination - /// + /// /// Date and time recorded as milliseconds since the Unix epoch. Recorded as /// a physical type of INT64. pub const TIMESTAMP_MILLIS: ConvertedType = ConvertedType(9); /// A date/time combination - /// + /// /// Date and time recorded as microseconds since the Unix epoch. The value is /// stored as an INT64 physical type. pub const TIMESTAMP_MICROS: ConvertedType = ConvertedType(10); /// An unsigned integer value. - /// + /// /// The number describes the maximum number of meaningful data bits in /// the stored value. 8, 16 and 32 bit values are stored using the /// INT32 physical type. 64 bit values are stored using the INT64 /// physical type. - /// + /// pub const UINT_8: ConvertedType = ConvertedType(11); pub const UINT_16: ConvertedType = ConvertedType(12); pub const UINT_32: ConvertedType = ConvertedType(13); pub const UINT_64: ConvertedType = ConvertedType(14); /// A signed integer value. 
- /// + /// /// The number describes the maximum number of meaningful data bits in /// the stored value. 8, 16 and 32 bit values are stored using the /// INT32 physical type. 64 bit values are stored using the INT64 /// physical type. - /// + /// pub const INT_8: ConvertedType = ConvertedType(15); pub const INT_16: ConvertedType = ConvertedType(16); pub const INT_32: ConvertedType = ConvertedType(17); pub const INT_64: ConvertedType = ConvertedType(18); /// An embedded JSON document - /// + /// /// A JSON document embedded within a single UTF8 column. pub const JSON: ConvertedType = ConvertedType(19); /// An embedded BSON document - /// + /// /// A BSON document embedded within a single BINARY column. pub const BSON: ConvertedType = ConvertedType(20); /// An interval of time - /// + /// /// This type annotates data stored as a FIXED_LEN_BYTE_ARRAY of length 12 /// This data is composed of three separate little endian unsigned /// integers. Each stores a component of a duration of time. The first @@ -221,12 +222,12 @@ impl ConvertedType { ]; } -impl TSerializable for ConvertedType { +impl crate::thrift::TSerializable for ConvertedType { #[allow(clippy::trivially_copy_pass_by_ref)] - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { o_prot.write_i32(self.0) } - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { let enum_value = i_prot.read_i32()?; Ok(ConvertedType::from(enum_value)) } @@ -298,12 +299,12 @@ impl FieldRepetitionType { ]; } -impl TSerializable for FieldRepetitionType { +impl crate::thrift::TSerializable for FieldRepetitionType { #[allow(clippy::trivially_copy_pass_by_ref)] - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { o_prot.write_i32(self.0) } - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { let enum_value = i_prot.read_i32()?; Ok(FieldRepetitionType::from(enum_value)) } @@ -396,12 +397,12 @@ impl Encoding { ]; } -impl TSerializable for Encoding { +impl crate::thrift::TSerializable for Encoding { #[allow(clippy::trivially_copy_pass_by_ref)] - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { o_prot.write_i32(self.0) } - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { let enum_value = i_prot.read_i32()?; Ok(Encoding::from(enum_value)) } @@ -443,11 +444,11 @@ impl From<&Encoding> for i32 { } /// Supported compression algorithms. -/// +/// /// Codecs added in format version X.Y can be read by readers based on X.Y and later. /// Codec support may vary between readers based on the format version and /// libraries available at runtime. -/// +/// /// See Compression.md for a detailed specification of these algorithms. 
#[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct CompressionCodec(pub i32); @@ -473,12 +474,12 @@ impl CompressionCodec { ]; } -impl TSerializable for CompressionCodec { +impl crate::thrift::TSerializable for CompressionCodec { #[allow(clippy::trivially_copy_pass_by_ref)] - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { o_prot.write_i32(self.0) } - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { let enum_value = i_prot.read_i32()?; Ok(CompressionCodec::from(enum_value)) } @@ -534,12 +535,12 @@ impl PageType { ]; } -impl TSerializable for PageType { +impl crate::thrift::TSerializable for PageType { #[allow(clippy::trivially_copy_pass_by_ref)] - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { o_prot.write_i32(self.0) } - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { let enum_value = i_prot.read_i32()?; Ok(PageType::from(enum_value)) } @@ -591,12 +592,12 @@ impl BoundaryOrder { ]; } -impl TSerializable for BoundaryOrder { +impl crate::thrift::TSerializable for BoundaryOrder { #[allow(clippy::trivially_copy_pass_by_ref)] - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { o_prot.write_i32(self.0) } - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { let enum_value = i_prot.read_i32()?; Ok(BoundaryOrder::from(enum_value)) } @@ -637,17 +638,17 @@ impl From<&BoundaryOrder> for i32 { /// Statistics per row group and per page /// All fields are optional. -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct Statistics { /// DEPRECATED: min and max value of the column. Use min_value and max_value. - /// + /// /// Values are encoded using PLAIN encoding, except that variable-length byte /// arrays do not include a length prefix. - /// + /// /// These fields encode min and max values determined by signed comparison /// only. New files should use the correct order for a column's logical type /// and store the values in the min_value and max_value fields. - /// + /// /// To support older readers, these may be set when the column order is /// signed. pub max: Option>, @@ -657,7 +658,7 @@ pub struct Statistics { /// count of distinct values occurring pub distinct_count: Option, /// Min and max values for the column, determined by its ColumnOrder. - /// + /// /// Values are encoded using PLAIN encoding, except that variable-length byte /// arrays do not include a length prefix. 
pub max_value: Option>, @@ -677,8 +678,8 @@ impl Statistics { } } -impl TSerializable for Statistics { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for Statistics { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option> = None; let mut f_2: Option> = None; @@ -734,7 +735,7 @@ impl TSerializable for Statistics { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("Statistics"); o_prot.write_struct_begin(&struct_ident)?; if let Some(ref fld_var) = self.max { @@ -772,25 +773,12 @@ impl TSerializable for Statistics { } } -impl Default for Statistics { - fn default() -> Self { - Statistics{ - max: Some(Vec::new()), - min: Some(Vec::new()), - null_count: Some(0), - distinct_count: Some(0), - max_value: Some(Vec::new()), - min_value: Some(Vec::new()), - } - } -} - // // StringType // /// Empty structs to use as logical type annotations -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct StringType { } @@ -800,27 +788,22 @@ impl StringType { } } -impl TSerializable for StringType { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for StringType { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { break; } - let field_id = field_id(&field_ident)?; - match field_id { - _ => { - i_prot.skip(field_ident.field_type)?; - }, - }; + i_prot.skip(field_ident.field_type)?; i_prot.read_field_end()?; } i_prot.read_struct_end()?; let ret = StringType {}; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("StringType"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_stop()?; @@ -828,17 +811,11 @@ impl TSerializable for StringType { } } -impl Default for StringType { - fn default() -> Self { - StringType{} - } -} - // // UUIDType // -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct UUIDType { } @@ -848,27 +825,22 @@ impl UUIDType { } } -impl TSerializable for UUIDType { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for UUIDType { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { break; } - let field_id = field_id(&field_ident)?; - match field_id { - _ => { - i_prot.skip(field_ident.field_type)?; - }, - }; + i_prot.skip(field_ident.field_type)?; i_prot.read_field_end()?; } i_prot.read_struct_end()?; let ret = UUIDType {}; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("UUIDType"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_stop()?; @@ -876,17 +848,11 @@ impl 
TSerializable for UUIDType { } } -impl Default for UUIDType { - fn default() -> Self { - UUIDType{} - } -} - // // MapType // -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct MapType { } @@ -896,27 +862,22 @@ impl MapType { } } -impl TSerializable for MapType { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for MapType { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { break; } - let field_id = field_id(&field_ident)?; - match field_id { - _ => { - i_prot.skip(field_ident.field_type)?; - }, - }; + i_prot.skip(field_ident.field_type)?; i_prot.read_field_end()?; } i_prot.read_struct_end()?; let ret = MapType {}; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("MapType"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_stop()?; @@ -924,17 +885,11 @@ impl TSerializable for MapType { } } -impl Default for MapType { - fn default() -> Self { - MapType{} - } -} - // // ListType // -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct ListType { } @@ -944,27 +899,22 @@ impl ListType { } } -impl TSerializable for ListType { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for ListType { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { break; } - let field_id = field_id(&field_ident)?; - match field_id { - _ => { - i_prot.skip(field_ident.field_type)?; - }, - }; + i_prot.skip(field_ident.field_type)?; i_prot.read_field_end()?; } i_prot.read_struct_end()?; let ret = ListType {}; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("ListType"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_stop()?; @@ -972,17 +922,11 @@ impl TSerializable for ListType { } } -impl Default for ListType { - fn default() -> Self { - ListType{} - } -} - // // EnumType // -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct EnumType { } @@ -992,27 +936,22 @@ impl EnumType { } } -impl TSerializable for EnumType { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for EnumType { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { break; } - let field_id = field_id(&field_ident)?; - match field_id { - _ => { - i_prot.skip(field_ident.field_type)?; - }, - }; + i_prot.skip(field_ident.field_type)?; i_prot.read_field_end()?; } i_prot.read_struct_end()?; let ret = EnumType {}; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn 
write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("EnumType"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_stop()?; @@ -1020,17 +959,11 @@ impl TSerializable for EnumType { } } -impl Default for EnumType { - fn default() -> Self { - EnumType{} - } -} - // // DateType // -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct DateType { } @@ -1040,27 +973,22 @@ impl DateType { } } -impl TSerializable for DateType { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for DateType { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { break; } - let field_id = field_id(&field_ident)?; - match field_id { - _ => { - i_prot.skip(field_ident.field_type)?; - }, - }; + i_prot.skip(field_ident.field_type)?; i_prot.read_field_end()?; } i_prot.read_struct_end()?; let ret = DateType {}; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("DateType"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_stop()?; @@ -1068,22 +996,16 @@ impl TSerializable for DateType { } } -impl Default for DateType { - fn default() -> Self { - DateType{} - } -} - // // NullType // /// Logical type to annotate a column that is always null. -/// +/// /// Sometimes when discovering the schema of existing data, values are always /// null and the physical type can't be determined. This annotation signals /// the case where the physical type was guessed from all null values. -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct NullType { } @@ -1093,27 +1015,22 @@ impl NullType { } } -impl TSerializable for NullType { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for NullType { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { break; } - let field_id = field_id(&field_ident)?; - match field_id { - _ => { - i_prot.skip(field_ident.field_type)?; - }, - }; + i_prot.skip(field_ident.field_type)?; i_prot.read_field_end()?; } i_prot.read_struct_end()?; let ret = NullType {}; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("NullType"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_stop()?; @@ -1121,21 +1038,18 @@ impl TSerializable for NullType { } } -impl Default for NullType { - fn default() -> Self { - NullType{} - } -} - // // DecimalType // /// Decimal logical type annotation -/// +/// +/// Scale must be zero or a positive integer less than or equal to the precision. +/// Precision must be a non-zero positive integer. +/// /// To maintain forward-compatibility in v1, implementations using this logical /// type must also set scale and precision on the annotated SchemaElement. 
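
A second pattern running through these hunks replaces hand-written `impl Default` blocks with `#[derive(Default)]`. For the empty marker structs (`StringType`, `UUIDType`, `MapType`, `ListType`, `EnumType`, `DateType`, `NullType`, ...) the derived impl is identical to the removed one, but for `Statistics` the behaviour changes: the removed impl filled every optional field with `Some` of an empty or zero value, whereas the derived impl yields `None` throughout. A small sketch of the difference, using the struct shapes shown above:

```rust
// Empty marker struct: the derived Default is identical to the removed manual impl.
#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct StringType {}

fn main() {
    // For Statistics, the removed manual impl produced placeholder values:
    //   Statistics { max: Some(Vec::new()), null_count: Some(0), .. }
    // while the derived impl now yields None for every optional field, i.e. a
    // default-constructed Statistics no longer claims empty/zero statistics.
    let s = StringType::default();
    println!("{s:?}");
}
```
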
-/// +/// /// Allowed for physical types: INT32, INT64, FIXED, and BINARY #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct DecimalType { @@ -1152,8 +1066,8 @@ impl DecimalType { } } -impl TSerializable for DecimalType { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for DecimalType { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option = None; let mut f_2: Option = None; @@ -1187,7 +1101,7 @@ impl TSerializable for DecimalType { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("DecimalType"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_begin(&TFieldIdentifier::new("scale", TType::I32, 1))?; @@ -1206,7 +1120,7 @@ impl TSerializable for DecimalType { // /// Time units for logical types -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct MilliSeconds { } @@ -1216,27 +1130,22 @@ impl MilliSeconds { } } -impl TSerializable for MilliSeconds { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for MilliSeconds { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { break; } - let field_id = field_id(&field_ident)?; - match field_id { - _ => { - i_prot.skip(field_ident.field_type)?; - }, - }; + i_prot.skip(field_ident.field_type)?; i_prot.read_field_end()?; } i_prot.read_struct_end()?; let ret = MilliSeconds {}; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("MilliSeconds"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_stop()?; @@ -1244,17 +1153,11 @@ impl TSerializable for MilliSeconds { } } -impl Default for MilliSeconds { - fn default() -> Self { - MilliSeconds{} - } -} - // // MicroSeconds // -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct MicroSeconds { } @@ -1264,27 +1167,22 @@ impl MicroSeconds { } } -impl TSerializable for MicroSeconds { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for MicroSeconds { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { break; } - let field_id = field_id(&field_ident)?; - match field_id { - _ => { - i_prot.skip(field_ident.field_type)?; - }, - }; + i_prot.skip(field_ident.field_type)?; i_prot.read_field_end()?; } i_prot.read_struct_end()?; let ret = MicroSeconds {}; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("MicroSeconds"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_stop()?; @@ -1292,17 +1190,11 @@ impl TSerializable for MicroSeconds { } } -impl Default for 
MicroSeconds { - fn default() -> Self { - MicroSeconds{} - } -} - // // NanoSeconds // -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct NanoSeconds { } @@ -1312,27 +1204,22 @@ impl NanoSeconds { } } -impl TSerializable for NanoSeconds { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for NanoSeconds { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { break; } - let field_id = field_id(&field_ident)?; - match field_id { - _ => { - i_prot.skip(field_ident.field_type)?; - }, - }; + i_prot.skip(field_ident.field_type)?; i_prot.read_field_end()?; } i_prot.read_struct_end()?; let ret = NanoSeconds {}; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("NanoSeconds"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_stop()?; @@ -1340,12 +1227,6 @@ impl TSerializable for NanoSeconds { } } -impl Default for NanoSeconds { - fn default() -> Self { - NanoSeconds{} - } -} - // // TimeUnit // @@ -1357,8 +1238,8 @@ pub enum TimeUnit { NANOS(NanoSeconds), } -impl TSerializable for TimeUnit { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for TimeUnit { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { let mut ret: Option = None; let mut received_field_count = 0; i_prot.read_struct_begin()?; @@ -1420,7 +1301,7 @@ impl TSerializable for TimeUnit { Ok(ret.expect("return value should have been constructed")) } } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("TimeUnit"); o_prot.write_struct_begin(&struct_ident)?; match *self { @@ -1450,7 +1331,7 @@ impl TSerializable for TimeUnit { // /// Timestamp logical type annotation -/// +/// /// Allowed for physical types: INT64 #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct TimestampType { @@ -1467,8 +1348,8 @@ impl TimestampType { } } -impl TSerializable for TimestampType { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for TimestampType { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option = None; let mut f_2: Option = None; @@ -1502,7 +1383,7 @@ impl TSerializable for TimestampType { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("TimestampType"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_begin(&TFieldIdentifier::new("isAdjustedToUTC", TType::Bool, 1))?; @@ -1521,7 +1402,7 @@ impl TSerializable for TimestampType { // /// Time logical type annotation -/// +/// /// Allowed for physical types: INT32 (millis), INT64 (micros, nanos) #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct TimeType { @@ -1538,8 +1419,8 @@ impl TimeType { } } -impl TSerializable for TimeType { - fn 
read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for TimeType { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option = None; let mut f_2: Option = None; @@ -1573,7 +1454,7 @@ impl TSerializable for TimeType { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("TimeType"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_begin(&TFieldIdentifier::new("isAdjustedToUTC", TType::Bool, 1))?; @@ -1592,9 +1473,9 @@ impl TSerializable for TimeType { // /// Integer logical type annotation -/// +/// /// bitWidth must be 8, 16, 32, or 64. -/// +/// /// Allowed for physical types: INT32, INT64 #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct IntType { @@ -1611,8 +1492,8 @@ impl IntType { } } -impl TSerializable for IntType { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for IntType { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option = None; let mut f_2: Option = None; @@ -1646,7 +1527,7 @@ impl TSerializable for IntType { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("IntType"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_begin(&TFieldIdentifier::new("bitWidth", TType::I08, 1))?; @@ -1665,9 +1546,9 @@ impl TSerializable for IntType { // /// Embedded JSON logical type annotation -/// +/// /// Allowed for physical types: BINARY -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct JsonType { } @@ -1677,27 +1558,22 @@ impl JsonType { } } -impl TSerializable for JsonType { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for JsonType { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { break; } - let field_id = field_id(&field_ident)?; - match field_id { - _ => { - i_prot.skip(field_ident.field_type)?; - }, - }; + i_prot.skip(field_ident.field_type)?; i_prot.read_field_end()?; } i_prot.read_struct_end()?; let ret = JsonType {}; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("JsonType"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_stop()?; @@ -1705,20 +1581,14 @@ impl TSerializable for JsonType { } } -impl Default for JsonType { - fn default() -> Self { - JsonType{} - } -} - // // BsonType // /// Embedded BSON logical type annotation -/// +/// /// Allowed for physical types: BINARY -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct BsonType { } @@ -1728,27 +1598,22 @@ impl BsonType { } } -impl TSerializable for BsonType { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { 
+impl crate::thrift::TSerializable for BsonType { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { break; } - let field_id = field_id(&field_ident)?; - match field_id { - _ => { - i_prot.skip(field_ident.field_type)?; - }, - }; + i_prot.skip(field_ident.field_type)?; i_prot.read_field_end()?; } i_prot.read_struct_end()?; let ret = BsonType {}; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("BsonType"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_stop()?; @@ -1756,12 +1621,6 @@ impl TSerializable for BsonType { } } -impl Default for BsonType { - fn default() -> Self { - BsonType{} - } -} - // // LogicalType // @@ -1783,8 +1642,8 @@ pub enum LogicalType { UUID(UUIDType), } -impl TSerializable for LogicalType { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for LogicalType { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { let mut ret: Option = None; let mut received_field_count = 0; i_prot.read_struct_begin()?; @@ -1916,7 +1775,7 @@ impl TSerializable for LogicalType { Ok(ret.expect("return value should have been constructed")) } } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("LogicalType"); o_prot.write_struct_begin(&struct_ident)?; match *self { @@ -2003,7 +1862,7 @@ impl TSerializable for LogicalType { pub struct SchemaElement { /// Data type for this field. Not set if the current element is a non-leaf node pub type_: Option, - /// If type is FIXED_LEN_BYTE_ARRAY, this is the byte length of the vales. + /// If type is FIXED_LEN_BYTE_ARRAY, this is the byte length of the values. /// Otherwise, if specified, this is the maximum bit length to store any of the values. /// (e.g. a low cardinality INT col could have this set to 3). Note that this is /// in the schema, and therefore fixed for the entire file. @@ -2020,12 +1879,12 @@ pub struct SchemaElement { pub num_children: Option, /// DEPRECATED: When the schema is the result of a conversion from another model. /// Used to record the original type to help with cross conversion. - /// + /// /// This is superseded by logicalType. pub converted_type: Option, /// DEPRECATED: Used when this column contains decimal data. /// See the DECIMAL converted type for more details. - /// + /// /// This is superseded by using the DecimalType annotation in logicalType. pub scale: Option, pub precision: Option, @@ -2033,7 +1892,7 @@ pub struct SchemaElement { /// original field id in the parquet schema pub field_id: Option, /// The logical type of this SchemaElement - /// + /// /// LogicalType replaces ConvertedType, but ConvertedType is still required /// for some logical types to ensure forward-compatibility in format v1. 
pub logical_type: Option, @@ -2056,8 +1915,8 @@ impl SchemaElement { } } -impl TSerializable for SchemaElement { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for SchemaElement { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option = None; let mut f_2: Option = None; @@ -2138,7 +1997,7 @@ impl TSerializable for SchemaElement { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("SchemaElement"); o_prot.write_struct_begin(&struct_ident)?; if let Some(ref fld_var) = self.type_ { @@ -2225,8 +2084,8 @@ impl DataPageHeader { } } -impl TSerializable for DataPageHeader { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for DataPageHeader { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option = None; let mut f_2: Option = None; @@ -2280,7 +2139,7 @@ impl TSerializable for DataPageHeader { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("DataPageHeader"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_begin(&TFieldIdentifier::new("num_values", TType::I32, 1))?; @@ -2309,7 +2168,7 @@ impl TSerializable for DataPageHeader { // IndexPageHeader // -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct IndexPageHeader { } @@ -2319,27 +2178,22 @@ impl IndexPageHeader { } } -impl TSerializable for IndexPageHeader { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for IndexPageHeader { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { break; } - let field_id = field_id(&field_ident)?; - match field_id { - _ => { - i_prot.skip(field_ident.field_type)?; - }, - }; + i_prot.skip(field_ident.field_type)?; i_prot.read_field_end()?; } i_prot.read_struct_end()?; let ret = IndexPageHeader {}; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("IndexPageHeader"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_stop()?; @@ -2347,16 +2201,14 @@ impl TSerializable for IndexPageHeader { } } -impl Default for IndexPageHeader { - fn default() -> Self { - IndexPageHeader{} - } -} - // // DictionaryPageHeader // +/// The dictionary page must be placed at the first position of the column chunk +/// if it is partly or completely dictionary encoded. At most one dictionary page +/// can be placed in a column chunk. 
+/// #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct DictionaryPageHeader { /// Number of values in the dictionary * @@ -2377,8 +2229,8 @@ impl DictionaryPageHeader { } } -impl TSerializable for DictionaryPageHeader { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for DictionaryPageHeader { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option = None; let mut f_2: Option = None; @@ -2418,7 +2270,7 @@ impl TSerializable for DictionaryPageHeader { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("DictionaryPageHeader"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_begin(&TFieldIdentifier::new("num_values", TType::I32, 1))?; @@ -2444,7 +2296,7 @@ impl TSerializable for DictionaryPageHeader { /// New page format allowing reading levels without decompressing the data /// Repetition and definition levels are uncompressed /// The remaining section containing the data is compressed if is_compressed is true -/// +/// #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct DataPageHeaderV2 { /// Number of values, including NULLs, in this data page. * @@ -2485,8 +2337,8 @@ impl DataPageHeaderV2 { } } -impl TSerializable for DataPageHeaderV2 { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for DataPageHeaderV2 { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option = None; let mut f_2: Option = None; @@ -2560,7 +2412,7 @@ impl TSerializable for DataPageHeaderV2 { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("DataPageHeaderV2"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_begin(&TFieldIdentifier::new("num_values", TType::I32, 1))?; @@ -2601,7 +2453,7 @@ impl TSerializable for DataPageHeaderV2 { // /// Block-based algorithm type annotation. 
* -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct SplitBlockAlgorithm { } @@ -2611,27 +2463,22 @@ impl SplitBlockAlgorithm { } } -impl TSerializable for SplitBlockAlgorithm { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for SplitBlockAlgorithm { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { break; } - let field_id = field_id(&field_ident)?; - match field_id { - _ => { - i_prot.skip(field_ident.field_type)?; - }, - }; + i_prot.skip(field_ident.field_type)?; i_prot.read_field_end()?; } i_prot.read_struct_end()?; let ret = SplitBlockAlgorithm {}; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("SplitBlockAlgorithm"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_stop()?; @@ -2639,12 +2486,6 @@ impl TSerializable for SplitBlockAlgorithm { } } -impl Default for SplitBlockAlgorithm { - fn default() -> Self { - SplitBlockAlgorithm{} - } -} - // // BloomFilterAlgorithm // @@ -2654,8 +2495,8 @@ pub enum BloomFilterAlgorithm { BLOCK(SplitBlockAlgorithm), } -impl TSerializable for BloomFilterAlgorithm { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for BloomFilterAlgorithm { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { let mut ret: Option = None; let mut received_field_count = 0; i_prot.read_struct_begin()?; @@ -2703,7 +2544,7 @@ impl TSerializable for BloomFilterAlgorithm { Ok(ret.expect("return value should have been constructed")) } } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("BloomFilterAlgorithm"); o_prot.write_struct_begin(&struct_ident)?; match *self { @@ -2724,8 +2565,8 @@ impl TSerializable for BloomFilterAlgorithm { /// Hash strategy type annotation. xxHash is an extremely fast non-cryptographic hash /// algorithm. It uses 64 bits version of xxHash. 
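
As the comment above says, the Bloom filter hash (`BloomFilterHash::XXHASH`) is the 64-bit variant of xxHash; the Parquet specification applies it with a seed of 0 to the plain-encoded value. A hedged illustration using the third-party `twox_hash` crate, which is only an example dependency and not something this diff introduces:

```rust
// Illustrative only: 64-bit xxHash with seed 0 over a plain-encoded value,
// as used for Parquet Bloom filter membership hashing (BloomFilterHash::XXHASH).
// twox_hash is an external crate chosen for the example, not used by this diff.
use std::hash::Hasher;
use twox_hash::XxHash64;

fn bloom_hash(plain_encoded: &[u8]) -> u64 {
    let mut hasher = XxHash64::with_seed(0);
    hasher.write(plain_encoded);
    hasher.finish()
}

fn main() {
    println!("{:#018x}", bloom_hash(b"parquet"));
}
```
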
-/// -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +/// +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct XxHash { } @@ -2735,27 +2576,22 @@ impl XxHash { } } -impl TSerializable for XxHash { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for XxHash { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { break; } - let field_id = field_id(&field_ident)?; - match field_id { - _ => { - i_prot.skip(field_ident.field_type)?; - }, - }; + i_prot.skip(field_ident.field_type)?; i_prot.read_field_end()?; } i_prot.read_struct_end()?; let ret = XxHash {}; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("XxHash"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_stop()?; @@ -2763,12 +2599,6 @@ impl TSerializable for XxHash { } } -impl Default for XxHash { - fn default() -> Self { - XxHash{} - } -} - // // BloomFilterHash // @@ -2778,8 +2608,8 @@ pub enum BloomFilterHash { XXHASH(XxHash), } -impl TSerializable for BloomFilterHash { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for BloomFilterHash { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { let mut ret: Option = None; let mut received_field_count = 0; i_prot.read_struct_begin()?; @@ -2827,7 +2657,7 @@ impl TSerializable for BloomFilterHash { Ok(ret.expect("return value should have been constructed")) } } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("BloomFilterHash"); o_prot.write_struct_begin(&struct_ident)?; match *self { @@ -2847,8 +2677,8 @@ impl TSerializable for BloomFilterHash { // /// The compression used in the Bloom filter. 
-/// -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +/// +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct Uncompressed { } @@ -2858,27 +2688,22 @@ impl Uncompressed { } } -impl TSerializable for Uncompressed { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for Uncompressed { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { break; } - let field_id = field_id(&field_ident)?; - match field_id { - _ => { - i_prot.skip(field_ident.field_type)?; - }, - }; + i_prot.skip(field_ident.field_type)?; i_prot.read_field_end()?; } i_prot.read_struct_end()?; let ret = Uncompressed {}; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("Uncompressed"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_stop()?; @@ -2886,12 +2711,6 @@ impl TSerializable for Uncompressed { } } -impl Default for Uncompressed { - fn default() -> Self { - Uncompressed{} - } -} - // // BloomFilterCompression // @@ -2901,8 +2720,8 @@ pub enum BloomFilterCompression { UNCOMPRESSED(Uncompressed), } -impl TSerializable for BloomFilterCompression { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for BloomFilterCompression { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { let mut ret: Option = None; let mut received_field_count = 0; i_prot.read_struct_begin()?; @@ -2950,7 +2769,7 @@ impl TSerializable for BloomFilterCompression { Ok(ret.expect("return value should have been constructed")) } } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("BloomFilterCompression"); o_prot.write_struct_begin(&struct_ident)?; match *self { @@ -2971,7 +2790,7 @@ impl TSerializable for BloomFilterCompression { /// Bloom filter header is stored at beginning of Bloom filter data of each column /// and followed by its bitset. 
-/// +/// #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct BloomFilterHeader { /// The size of bitset in bytes * @@ -2995,8 +2814,8 @@ impl BloomFilterHeader { } } -impl TSerializable for BloomFilterHeader { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for BloomFilterHeader { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option = None; let mut f_2: Option = None; @@ -3044,7 +2863,7 @@ impl TSerializable for BloomFilterHeader { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("BloomFilterHeader"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_begin(&TFieldIdentifier::new("numBytes", TType::I32, 1))?; @@ -3076,32 +2895,22 @@ pub struct PageHeader { pub uncompressed_page_size: i32, /// Compressed (and potentially encrypted) page size in bytes, not including this header * pub compressed_page_size: i32, - /// The 32bit CRC for the page, to be be calculated as follows: - /// - Using the standard CRC32 algorithm - /// - On the data only, i.e. this header should not be included. 'Data' - /// hereby refers to the concatenation of the repetition levels, the - /// definition levels and the column value, in this exact order. - /// - On the encoded versions of the repetition levels, definition levels and - /// column values - /// - On the compressed versions of the repetition levels, definition levels - /// and column values where possible; - /// - For v1 data pages, the repetition levels, definition levels and column - /// values are always compressed together. If a compression scheme is - /// specified, the CRC shall be calculated on the compressed version of - /// this concatenation. If no compression scheme is specified, the CRC - /// shall be calculated on the uncompressed version of this concatenation. - /// - For v2 data pages, the repetition levels and definition levels are - /// handled separately from the data and are never compressed (only - /// encoded). If a compression scheme is specified, the CRC shall be - /// calculated on the concatenation of the uncompressed repetition levels, - /// uncompressed definition levels and the compressed column values. - /// If no compression scheme is specified, the CRC shall be calculated on - /// the uncompressed concatenation. - /// - In encrypted columns, CRC is calculated after page encryption; the - /// encryption itself is performed after page compression (if compressed) + /// The 32-bit CRC checksum for the page, to be be calculated as follows: + /// + /// - The standard CRC32 algorithm is used (with polynomial 0x04C11DB7, + /// the same as in e.g. GZip). + /// - All page types can have a CRC (v1 and v2 data pages, dictionary pages, + /// etc.). + /// - The CRC is computed on the serialization binary representation of the page + /// (as written to disk), excluding the page header. For example, for v1 + /// data pages, the CRC is computed on the concatenation of repetition levels, + /// definition levels and column values (optionally compressed, optionally + /// encrypted). + /// - The CRC computation therefore takes place after any compression + /// and encryption steps, if any. + /// /// If enabled, this allows for disabling checksumming in HDFS if only a few /// pages need to be read. 
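
The rewritten comment above pins the checksum down: standard CRC-32 (the GZip polynomial), computed over the page bytes exactly as written to disk, i.e. after any compression and encryption, and excluding the page header itself. A sketch of how a writer might populate `crc` under that definition; `crc32fast` is an assumed helper crate for the example, not something added by this diff:

```rust
// Illustrative only: compute PageHeader.crc over the serialized page body
// (already compressed/encrypted, page header excluded) with the standard
// CRC-32 (IEEE / GZip) polynomial, here via the crc32fast crate.
fn page_crc(serialized_page_body: &[u8]) -> i32 {
    let mut hasher = crc32fast::Hasher::new();
    hasher.update(serialized_page_body);
    hasher.finalize() as i32 // the Thrift field is an i32
}

fn main() {
    let body = [0u8, 1, 2, 3]; // stand-in for the on-disk page bytes
    println!("crc = {}", page_crc(&body));
}
```
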
- /// pub crc: Option, pub data_page_header: Option, pub index_page_header: Option, @@ -3124,8 +2933,8 @@ impl PageHeader { } } -impl TSerializable for PageHeader { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for PageHeader { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option = None; let mut f_2: Option = None; @@ -3196,7 +3005,7 @@ impl TSerializable for PageHeader { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("PageHeader"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_begin(&TFieldIdentifier::new("type", TType::I32, 1))?; @@ -3258,8 +3067,8 @@ impl KeyValue { } } -impl TSerializable for KeyValue { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for KeyValue { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option = None; let mut f_2: Option = None; @@ -3292,7 +3101,7 @@ impl TSerializable for KeyValue { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("KeyValue"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_begin(&TFieldIdentifier::new("key", TType::String, 1))?; @@ -3334,8 +3143,8 @@ impl SortingColumn { } } -impl TSerializable for SortingColumn { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for SortingColumn { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option = None; let mut f_2: Option = None; @@ -3376,7 +3185,7 @@ impl TSerializable for SortingColumn { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("SortingColumn"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_begin(&TFieldIdentifier::new("column_idx", TType::I32, 1))?; @@ -3418,8 +3227,8 @@ impl PageEncodingStats { } } -impl TSerializable for PageEncodingStats { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for PageEncodingStats { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option = None; let mut f_2: Option = None; @@ -3460,7 +3269,7 @@ impl TSerializable for PageEncodingStats { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("PageEncodingStats"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_begin(&TFieldIdentifier::new("page_type", TType::I32, 1))?; @@ -3516,10 +3325,16 @@ pub struct ColumnMetaData { pub encoding_stats: Option>, /// Byte offset from beginning of file to Bloom filter data. * pub bloom_filter_offset: Option, + /// Size of Bloom filter data including the serialized header, in bytes. 
+ /// Added in 2.10 so readers may not read this field from old files and + /// it can be obtained after the BloomFilterHeader has been deserialized. + /// Writers should write this field so readers can read the bloom filter + /// in a single I/O. + pub bloom_filter_length: Option, } impl ColumnMetaData { - pub fn new(type_: Type, encodings: Vec, path_in_schema: Vec, codec: CompressionCodec, num_values: i64, total_uncompressed_size: i64, total_compressed_size: i64, key_value_metadata: F8, data_page_offset: i64, index_page_offset: F10, dictionary_page_offset: F11, statistics: F12, encoding_stats: F13, bloom_filter_offset: F14) -> ColumnMetaData where F8: Into>>, F10: Into>, F11: Into>, F12: Into>, F13: Into>>, F14: Into> { + pub fn new(type_: Type, encodings: Vec, path_in_schema: Vec, codec: CompressionCodec, num_values: i64, total_uncompressed_size: i64, total_compressed_size: i64, key_value_metadata: F8, data_page_offset: i64, index_page_offset: F10, dictionary_page_offset: F11, statistics: F12, encoding_stats: F13, bloom_filter_offset: F14, bloom_filter_length: F15) -> ColumnMetaData where F8: Into>>, F10: Into>, F11: Into>, F12: Into>, F13: Into>>, F14: Into>, F15: Into> { ColumnMetaData { type_, encodings, @@ -3535,12 +3350,13 @@ impl ColumnMetaData { statistics: statistics.into(), encoding_stats: encoding_stats.into(), bloom_filter_offset: bloom_filter_offset.into(), + bloom_filter_length: bloom_filter_length.into(), } } } -impl TSerializable for ColumnMetaData { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for ColumnMetaData { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option = None; let mut f_2: Option> = None; @@ -3556,6 +3372,7 @@ impl TSerializable for ColumnMetaData { let mut f_12: Option = None; let mut f_13: Option> = None; let mut f_14: Option = None; + let mut f_15: Option = None; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { @@ -3643,6 +3460,10 @@ impl TSerializable for ColumnMetaData { let val = i_prot.read_i64()?; f_14 = Some(val); }, + 15 => { + let val = i_prot.read_i32()?; + f_15 = Some(val); + }, _ => { i_prot.skip(field_ident.field_type)?; }, @@ -3673,10 +3494,11 @@ impl TSerializable for ColumnMetaData { statistics: f_12, encoding_stats: f_13, bloom_filter_offset: f_14, + bloom_filter_length: f_15, }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("ColumnMetaData"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_begin(&TFieldIdentifier::new("type", TType::I32, 1))?; @@ -3749,6 +3571,11 @@ impl TSerializable for ColumnMetaData { o_prot.write_i64(fld_var)?; o_prot.write_field_end()? } + if let Some(fld_var) = self.bloom_filter_length { + o_prot.write_field_begin(&TFieldIdentifier::new("bloom_filter_length", TType::I32, 15))?; + o_prot.write_i32(fld_var)?; + o_prot.write_field_end()? 
+ } o_prot.write_field_stop()?; o_prot.write_struct_end() } @@ -3758,7 +3585,7 @@ impl TSerializable for ColumnMetaData { // EncryptionWithFooterKey // -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct EncryptionWithFooterKey { } @@ -3768,27 +3595,22 @@ impl EncryptionWithFooterKey { } } -impl TSerializable for EncryptionWithFooterKey { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for EncryptionWithFooterKey { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { break; } - let field_id = field_id(&field_ident)?; - match field_id { - _ => { - i_prot.skip(field_ident.field_type)?; - }, - }; + i_prot.skip(field_ident.field_type)?; i_prot.read_field_end()?; } i_prot.read_struct_end()?; let ret = EncryptionWithFooterKey {}; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("EncryptionWithFooterKey"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_stop()?; @@ -3796,12 +3618,6 @@ impl TSerializable for EncryptionWithFooterKey { } } -impl Default for EncryptionWithFooterKey { - fn default() -> Self { - EncryptionWithFooterKey{} - } -} - // // EncryptionWithColumnKey // @@ -3823,8 +3639,8 @@ impl EncryptionWithColumnKey { } } -impl TSerializable for EncryptionWithColumnKey { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for EncryptionWithColumnKey { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option> = None; let mut f_2: Option> = None; @@ -3863,7 +3679,7 @@ impl TSerializable for EncryptionWithColumnKey { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("EncryptionWithColumnKey"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_begin(&TFieldIdentifier::new("path_in_schema", TType::List, 1))?; @@ -3893,8 +3709,8 @@ pub enum ColumnCryptoMetaData { ENCRYPTIONWITHCOLUMNKEY(EncryptionWithColumnKey), } -impl TSerializable for ColumnCryptoMetaData { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for ColumnCryptoMetaData { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { let mut ret: Option = None; let mut received_field_count = 0; i_prot.read_struct_begin()?; @@ -3949,7 +3765,7 @@ impl TSerializable for ColumnCryptoMetaData { Ok(ret.expect("return value should have been constructed")) } } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("ColumnCryptoMetaData"); o_prot.write_struct_begin(&struct_ident)?; match *self { @@ -3977,14 +3793,14 @@ impl TSerializable for ColumnCryptoMetaData { pub struct ColumnChunk { /// File where column data is stored. If not set, assumed to be same file as /// metadata. This path is relative to the current file. 
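
The optional `bloom_filter_length` field (id 15) added to `ColumnMetaData` a few hunks above complements the existing `bloom_filter_offset`: when both are present, a reader can fetch the serialized `BloomFilterHeader` and its bitset with a single positioned read instead of first decoding the header to discover the size. A rough sketch of that read path; the helper and its signature are hypothetical, only the two metadata fields come from this diff:

```rust
use std::fs::File;
use std::io::{Read, Seek, SeekFrom};

// Hypothetical helper: with bloom_filter_length available, the whole filter
// (serialized BloomFilterHeader followed by its bitset) is one contiguous
// read. Without it, the header must be read and decoded first to learn the size.
fn read_bloom_filter_bytes(
    file: &mut File,
    bloom_filter_offset: i64,
    bloom_filter_length: i32,
) -> std::io::Result<Vec<u8>> {
    file.seek(SeekFrom::Start(bloom_filter_offset as u64))?;
    let mut buf = vec![0u8; bloom_filter_length as usize];
    file.read_exact(&mut buf)?;
    Ok(buf)
}
```
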
- /// + /// pub file_path: Option, /// Byte offset in file_path to the ColumnMetaData * pub file_offset: i64, /// Column metadata for this chunk. This is the same content as what is at /// file_path/file_offset. Having it here has it replicated in the file /// metadata. - /// + /// pub meta_data: Option, /// File offset of ColumnChunk's OffsetIndex * pub offset_index_offset: Option, @@ -4016,8 +3832,8 @@ impl ColumnChunk { } } -impl TSerializable for ColumnChunk { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for ColumnChunk { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option = None; let mut f_2: Option = None; @@ -4092,7 +3908,7 @@ impl TSerializable for ColumnChunk { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("ColumnChunk"); o_prot.write_struct_begin(&struct_ident)?; if let Some(ref fld_var) = self.file_path { @@ -4151,7 +3967,7 @@ impl TSerializable for ColumnChunk { pub struct RowGroup { /// Metadata for each column chunk in this row group. /// This list must have the same order as the SchemaElement list in FileMetaData. - /// + /// pub columns: Vec, /// Total byte size of all the uncompressed column data in this row group * pub total_byte_size: i64, @@ -4184,8 +4000,8 @@ impl RowGroup { } } -impl TSerializable for RowGroup { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for RowGroup { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option> = None; let mut f_2: Option = None; @@ -4262,7 +4078,7 @@ impl TSerializable for RowGroup { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("RowGroup"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_begin(&TFieldIdentifier::new("columns", TType::List, 1))?; @@ -4312,7 +4128,7 @@ impl TSerializable for RowGroup { // /// Empty struct to signal the order defined by the physical or logical type -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct TypeDefinedOrder { } @@ -4322,27 +4138,22 @@ impl TypeDefinedOrder { } } -impl TSerializable for TypeDefinedOrder { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for TypeDefinedOrder { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; loop { let field_ident = i_prot.read_field_begin()?; if field_ident.field_type == TType::Stop { break; } - let field_id = field_id(&field_ident)?; - match field_id { - _ => { - i_prot.skip(field_ident.field_type)?; - }, - }; + i_prot.skip(field_ident.field_type)?; i_prot.read_field_end()?; } i_prot.read_struct_end()?; let ret = TypeDefinedOrder {}; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("TypeDefinedOrder"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_stop()?; @@ -4350,12 
+4161,6 @@ impl TSerializable for TypeDefinedOrder { } } -impl Default for TypeDefinedOrder { - fn default() -> Self { - TypeDefinedOrder{} - } -} - // // ColumnOrder // @@ -4365,8 +4170,8 @@ pub enum ColumnOrder { TYPEORDER(TypeDefinedOrder), } -impl TSerializable for ColumnOrder { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for ColumnOrder { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { let mut ret: Option = None; let mut received_field_count = 0; i_prot.read_struct_begin()?; @@ -4414,7 +4219,7 @@ impl TSerializable for ColumnOrder { Ok(ret.expect("return value should have been constructed")) } } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("ColumnOrder"); o_prot.write_struct_begin(&struct_ident)?; match *self { @@ -4455,8 +4260,8 @@ impl PageLocation { } } -impl TSerializable for PageLocation { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for PageLocation { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option = None; let mut f_2: Option = None; @@ -4497,7 +4302,7 @@ impl TSerializable for PageLocation { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("PageLocation"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_begin(&TFieldIdentifier::new("offset", TType::I64, 1))?; @@ -4533,8 +4338,8 @@ impl OffsetIndex { } } -impl TSerializable for OffsetIndex { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for OffsetIndex { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option> = None; loop { @@ -4567,7 +4372,7 @@ impl TSerializable for OffsetIndex { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("OffsetIndex"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_begin(&TFieldIdentifier::new("page_locations", TType::List, 1))?; @@ -4596,13 +4401,14 @@ pub struct ColumnIndex { /// byte\[0\], so that all lists have the same length. If false, the /// corresponding entries in min_values and max_values must be valid. pub null_pages: Vec, - /// Two lists containing lower and upper bounds for the values of each page. - /// These may be the actual minimum and maximum values found on a page, but - /// can also be (more compact) values that do not exist on a page. For - /// example, instead of storing ""Blart Versenwald III", a writer may set - /// min_values\[i\]="B", max_values\[i\]="C". Such more compact values must still - /// be valid values within the column's logical type. Readers must make sure - /// that list entries are populated before using them by inspecting null_pages. + /// Two lists containing lower and upper bounds for the values of each page + /// determined by the ColumnOrder of the column. 
These may be the actual + /// minimum and maximum values found on a page, but can also be (more compact) + /// values that do not exist on a page. For example, instead of storing ""Blart + /// Versenwald III", a writer may set min_values\[i\]="B", max_values\[i\]="C". + /// Such more compact values must still be valid values within the column's + /// logical type. Readers must make sure that list entries are populated before + /// using them by inspecting null_pages. pub min_values: Vec>, pub max_values: Vec>, /// Stores whether both min_values and max_values are ordered and if so, in @@ -4626,8 +4432,8 @@ impl ColumnIndex { } } -impl TSerializable for ColumnIndex { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for ColumnIndex { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option> = None; let mut f_2: Option>> = None; @@ -4705,7 +4511,7 @@ impl TSerializable for ColumnIndex { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("ColumnIndex"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_begin(&TFieldIdentifier::new("null_pages", TType::List, 1))?; @@ -4750,7 +4556,7 @@ impl TSerializable for ColumnIndex { // AesGcmV1 // -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct AesGcmV1 { /// AAD prefix * pub aad_prefix: Option>, @@ -4771,8 +4577,8 @@ impl AesGcmV1 { } } -impl TSerializable for AesGcmV1 { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for AesGcmV1 { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option> = None; let mut f_2: Option> = None; @@ -4810,7 +4616,7 @@ impl TSerializable for AesGcmV1 { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("AesGcmV1"); o_prot.write_struct_begin(&struct_ident)?; if let Some(ref fld_var) = self.aad_prefix { @@ -4833,21 +4639,11 @@ impl TSerializable for AesGcmV1 { } } -impl Default for AesGcmV1 { - fn default() -> Self { - AesGcmV1{ - aad_prefix: Some(Vec::new()), - aad_file_unique: Some(Vec::new()), - supply_aad_prefix: Some(false), - } - } -} - // // AesGcmCtrV1 // -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct AesGcmCtrV1 { /// AAD prefix * pub aad_prefix: Option>, @@ -4868,8 +4664,8 @@ impl AesGcmCtrV1 { } } -impl TSerializable for AesGcmCtrV1 { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for AesGcmCtrV1 { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option> = None; let mut f_2: Option> = None; @@ -4907,7 +4703,7 @@ impl TSerializable for AesGcmCtrV1 { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("AesGcmCtrV1"); 
o_prot.write_struct_begin(&struct_ident)?; if let Some(ref fld_var) = self.aad_prefix { @@ -4930,16 +4726,6 @@ impl TSerializable for AesGcmCtrV1 { } } -impl Default for AesGcmCtrV1 { - fn default() -> Self { - AesGcmCtrV1{ - aad_prefix: Some(Vec::new()), - aad_file_unique: Some(Vec::new()), - supply_aad_prefix: Some(false), - } - } -} - // // EncryptionAlgorithm // @@ -4950,8 +4736,8 @@ pub enum EncryptionAlgorithm { AESGCMCTRV1(AesGcmCtrV1), } -impl TSerializable for EncryptionAlgorithm { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for EncryptionAlgorithm { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { let mut ret: Option = None; let mut received_field_count = 0; i_prot.read_struct_begin()?; @@ -5006,7 +4792,7 @@ impl TSerializable for EncryptionAlgorithm { Ok(ret.expect("return value should have been constructed")) } } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("EncryptionAlgorithm"); o_prot.write_struct_begin(&struct_ident)?; match *self { @@ -5051,19 +4837,22 @@ pub struct FileMetaData { /// String for application that wrote this file. This should be in the format /// `` version `` (build ``). /// e.g. impala version 1.0 (build 6cf94d29b2b7115df4de2c06e2ab4326d721eb55) - /// + /// pub created_by: Option, - /// Sort order used for the min_value and max_value fields of each column in - /// this file. Sort orders are listed in the order matching the columns in the - /// schema. The indexes are not necessary the same though, because only leaf - /// nodes of the schema are represented in the list of sort orders. - /// - /// Without column_orders, the meaning of the min_value and max_value fields is - /// undefined. To ensure well-defined behaviour, if min_value and max_value are - /// written to a Parquet file, column_orders must be written as well. - /// - /// The obsolete min and max fields are always sorted by signed comparison - /// regardless of column_orders. + /// Sort order used for the min_value and max_value fields in the Statistics + /// objects and the min_values and max_values fields in the ColumnIndex + /// objects of each column in this file. Sort orders are listed in the order + /// matching the columns in the schema. The indexes are not necessary the same + /// though, because only leaf nodes of the schema are represented in the list + /// of sort orders. + /// + /// Without column_orders, the meaning of the min_value and max_value fields + /// in the Statistics object and the ColumnIndex object is undefined. To ensure + /// well-defined behaviour, if these fields are written to a Parquet file, + /// column_orders must be written as well. + /// + /// The obsolete min and max fields in the Statistics object are always sorted + /// by signed comparison regardless of column_orders. pub column_orders: Option>, /// Encryption algorithm. This field is set only in encrypted files /// with plaintext footer. 
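Review note: dropping the hand-written `Default` impls for `TypeDefinedOrder`, `AesGcmV1`, and `AesGcmCtrV1` in favour of `#[derive(Default)]` also changes the default values: the removed impls filled the optional fields with `Some(Vec::new())` / `Some(false)`, while the derived impl leaves them as `None`. A small illustration, assuming the `Option`-typed fields declared above:

```rust
// With #[derive(Default)], every Option field defaults to None rather than
// the Some(..) values produced by the removed hand-written impl.
let alg = AesGcmV1::default();
assert_eq!(alg.aad_prefix, None);
assert_eq!(alg.aad_file_unique, None);
assert_eq!(alg.supply_aad_prefix, None);
```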
Files with encrypted footer store algorithm id @@ -5090,8 +4879,8 @@ impl FileMetaData { } } -impl TSerializable for FileMetaData { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for FileMetaData { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option = None; let mut f_2: Option> = None; @@ -5193,7 +4982,7 @@ impl TSerializable for FileMetaData { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("FileMetaData"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_begin(&TFieldIdentifier::new("version", TType::I32, 1))?; @@ -5279,8 +5068,8 @@ impl FileCryptoMetaData { } } -impl TSerializable for FileCryptoMetaData { - fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result { +impl crate::thrift::TSerializable for FileCryptoMetaData { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result { i_prot.read_struct_begin()?; let mut f_1: Option = None; let mut f_2: Option> = None; @@ -5313,7 +5102,7 @@ impl TSerializable for FileCryptoMetaData { }; Ok(ret) } - fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> { + fn write_to_out_protocol(&self, o_prot: &mut T) -> thrift::Result<()> { let struct_ident = TStructIdentifier::new("FileCryptoMetaData"); o_prot.write_struct_begin(&struct_ident)?; o_prot.write_field_begin(&TFieldIdentifier::new("encryption_algorithm", TType::Struct, 1))?; diff --git a/parquet/src/lib.rs b/parquet/src/lib.rs index 2371f8837bb0..f1612c90cc2a 100644 --- a/parquet/src/lib.rs +++ b/parquet/src/lib.rs @@ -88,3 +88,5 @@ pub mod bloom_filter; pub mod file; pub mod record; pub mod schema; + +pub mod thrift; diff --git a/parquet/src/record/reader.rs b/parquet/src/record/reader.rs index 780e9822488d..2a9b6dbb0bed 100644 --- a/parquet/src/record/reader.rs +++ b/parquet/src/record/reader.rs @@ -16,7 +16,7 @@ // under the License. //! Contains implementation of record assembly and converting Parquet types into -//! [`Row`](crate::record::Row)s. +//! [`Row`]s. use std::{collections::HashMap, fmt, sync::Arc}; @@ -274,7 +274,7 @@ impl TreeBuilder { let required_field = Type::group_type_builder(field.name()) .with_repetition(Repetition::REQUIRED) .with_converted_type(field.get_basic_info().converted_type()) - .with_fields(&mut Vec::from(field.get_fields())) + .with_fields(field.get_fields().to_vec()) .build()?; path.pop(); @@ -618,7 +618,7 @@ impl fmt::Display for Reader { // Row iterators /// The enum Either with variants That represents a reference and a box of -/// [`FileReader`](crate::file::reader::FileReader). +/// [`FileReader`]. enum Either<'a> { Left(&'a dyn FileReader), Right(Box), @@ -633,7 +633,7 @@ impl<'a> Either<'a> { } } -/// Iterator of [`Row`](crate::record::Row)s. +/// Iterator of [`Row`]s. /// It is used either for a single row group to iterate over data in that row group, or /// an entire file with auto buffering of all row groups. pub struct RowIter<'a> { @@ -646,7 +646,7 @@ pub struct RowIter<'a> { } impl<'a> RowIter<'a> { - /// Creates a new iterator of [`Row`](crate::record::Row)s. + /// Creates a new iterator of [`Row`]s. 
fn new( file_reader: Option>, row_iter: Option, @@ -668,7 +668,7 @@ impl<'a> RowIter<'a> { } } - /// Creates iterator of [`Row`](crate::record::Row)s for all row groups in a + /// Creates iterator of [`Row`]s for all row groups in a /// file. pub fn from_file(proj: Option, reader: &'a dyn FileReader) -> Result { let either = Either::Left(reader); @@ -680,7 +680,7 @@ impl<'a> RowIter<'a> { Ok(Self::new(Some(either), None, descr)) } - /// Creates iterator of [`Row`](crate::record::Row)s for a specific row group. + /// Creates iterator of [`Row`]s for a specific row group. pub fn from_row_group( proj: Option, reader: &'a dyn RowGroupReader, @@ -694,8 +694,7 @@ impl<'a> RowIter<'a> { Ok(Self::new(None, Some(row_iter), descr)) } - /// Creates a iterator of [`Row`](crate::record::Row)s from a - /// [`FileReader`](crate::file::reader::FileReader) using the full file schema. + /// Creates a iterator of [`Row`]s from a [`FileReader`] using the full file schema. pub fn from_file_into(reader: Box) -> Self { let either = Either::Right(reader); let descr = either @@ -707,7 +706,7 @@ impl<'a> RowIter<'a> { Self::new(Some(either), None, descr) } - /// Tries to create a iterator of [`Row`](crate::record::Row)s using projections. + /// Tries to create a iterator of [`Row`]s using projections. /// Returns a error if a file reader is not the source of this iterator. /// /// The Projected schema can be a subset of or equal to the file schema, @@ -748,6 +747,12 @@ impl<'a> RowIter<'a> { } } + /// Sets batch size for this row iter. + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.tree_builder = self.tree_builder.with_batch_size(batch_size); + self + } + /// Returns common tree builder, so the same settings are applied to both iterators /// from file reader and row group. #[inline] @@ -793,7 +798,7 @@ impl<'a> Iterator for RowIter<'a> { } } -/// Internal iterator of [`Row`](crate::record::Row)s for a reader. +/// Internal iterator of [`Row`]s for a reader. 
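Review note: `with_batch_size` is a new builder-style setter on `RowIter` that forwards the batch size to the underlying tree builder. A hedged usage sketch (the file path and batch size are illustrative, and the import paths are indicative; the reader construction and `get_int` usage mirror the tests below):

```rust
use std::convert::TryFrom;

use parquet::file::reader::SerializedFileReader;
use parquet::record::reader::RowIter;
use parquet::record::RowAccessor;

// "data.parquet" is a placeholder path.
let reader = SerializedFileReader::try_from("data.parquet").unwrap();

// Configure the assembly batch size up front, then iterate rows.
let iter = RowIter::from_file_into(Box::new(reader)).with_batch_size(1024);
let values: Vec<i32> = iter.flat_map(|r| r.unwrap().get_int(0)).collect();
```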
pub struct ReaderIter { root_reader: Reader, records_left: usize, @@ -829,7 +834,7 @@ mod tests { use crate::errors::Result; use crate::file::reader::{FileReader, SerializedFileReader}; - use crate::record::api::{Field, Row, RowAccessor, RowFormatter}; + use crate::record::api::{Field, Row, RowAccessor}; use crate::schema::parser::parse_message_type; use crate::util::test_common::file_util::{get_test_file, get_test_path}; use std::convert::TryFrom; @@ -1501,33 +1506,26 @@ mod tests { #[test] fn test_file_reader_iter() { let path = get_test_path("alltypes_plain.parquet"); - let vec = vec![path] - .iter() - .map(|p| SerializedFileReader::try_from(p.as_path()).unwrap()) - .flat_map(|r| RowIter::from_file_into(Box::new(r))) - .flat_map(|r| r.unwrap().get_int(0)) - .collect::>(); - - assert_eq!(vec, vec![4, 5, 6, 7, 2, 3, 0, 1]); + let reader = SerializedFileReader::try_from(path.as_path()).unwrap(); + let iter = RowIter::from_file_into(Box::new(reader)); + + let values: Vec<_> = iter.flat_map(|r| r.unwrap().get_int(0)).collect(); + assert_eq!(values, &[4, 5, 6, 7, 2, 3, 0, 1]); } #[test] fn test_file_reader_iter_projection() { let path = get_test_path("alltypes_plain.parquet"); - let values = vec![path] - .iter() - .map(|p| SerializedFileReader::try_from(p.as_path()).unwrap()) - .flat_map(|r| { - let schema = "message schema { OPTIONAL INT32 id; }"; - let proj = parse_message_type(schema).ok(); - - RowIter::from_file_into(Box::new(r)).project(proj).unwrap() - }) - .map(|r| format!("id:{}", r.unwrap().fmt(0))) - .collect::>() - .join(", "); - - assert_eq!(values, "id:4, id:5, id:6, id:7, id:2, id:3, id:0, id:1"); + let reader = SerializedFileReader::try_from(path.as_path()).unwrap(); + let schema = "message schema { OPTIONAL INT32 id; }"; + let proj = parse_message_type(schema).ok(); + + let iter = RowIter::from_file_into(Box::new(reader)) + .project(proj) + .unwrap(); + let values: Vec<_> = iter.flat_map(|r| r.unwrap().get_int(0)).collect(); + + assert_eq!(values, &[4, 5, 6, 7, 2, 3, 0, 1]); } #[test] diff --git a/parquet/src/schema/mod.rs b/parquet/src/schema/mod.rs index 1ebee2e06e83..ead7f1d2c0f8 100644 --- a/parquet/src/schema/mod.rs +++ b/parquet/src/schema/mod.rs @@ -45,7 +45,7 @@ //! .unwrap(); //! //! let schema = Type::group_type_builder("schema") -//! .with_fields(&mut vec![Arc::new(field_a), Arc::new(field_b)]) +//! .with_fields(vec![Arc::new(field_a), Arc::new(field_b)]) //! .build() //! .unwrap(); //! diff --git a/parquet/src/schema/parser.rs b/parquet/src/schema/parser.rs index c09f13603d29..d589f8c1100a 100644 --- a/parquet/src/schema/parser.rs +++ b/parquet/src/schema/parser.rs @@ -17,7 +17,7 @@ //! Parquet schema parser. //! Provides methods to parse and validate string message type into Parquet -//! [`Type`](crate::schema::types::Type). +//! [`Type`]. //! //! # Example //! @@ -50,7 +50,7 @@ use crate::basic::{ use crate::errors::{ParquetError, Result}; use crate::schema::types::{Type, TypePtr}; -/// Parses message type as string into a Parquet [`Type`](crate::schema::types::Type) +/// Parses message type as string into a Parquet [`Type`] /// which, for example, could be used to extract individual columns. Returns Parquet /// general error when parsing or validation fails. 
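Review note: as the doc comment above says, `parse_message_type` turns a schema string into a Parquet `Type` and returns a general error when parsing or validation fails. A small usage sketch (the schema text is illustrative):

```rust
use parquet::schema::parser::parse_message_type;

let schema = "
    message schema {
        REQUIRED INT32 id;
        OPTIONAL BYTE_ARRAY name (UTF8);
    }
";
// Panics here stand in for proper error handling of the ParquetError.
let message = parse_message_type(schema).expect("schema should parse");
assert_eq!(message.get_fields().len(), 2);
```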
pub fn parse_message_type(message_type: &str) -> Result { @@ -205,9 +205,8 @@ impl<'a> Parser<'a> { .tokenizer .next() .ok_or_else(|| general_err!("Expected name, found None"))?; - let mut fields = self.parse_child_types()?; Type::group_type_builder(name) - .with_fields(&mut fields) + .with_fields(self.parse_child_types()?) .build() } _ => Err(general_err!("Message type does not start with 'message'")), @@ -290,17 +289,14 @@ impl<'a> Parser<'a> { None }; - let mut fields = self.parse_child_types()?; let mut builder = Type::group_type_builder(name) .with_logical_type(logical_type) .with_converted_type(converted_type) - .with_fields(&mut fields); + .with_fields(self.parse_child_types()?) + .with_id(id); if let Some(rep) = repetition { builder = builder.with_repetition(rep); } - if let Some(id) = id { - builder = builder.with_id(id); - } builder.build() } @@ -516,17 +512,15 @@ impl<'a> Parser<'a> { }; assert_token(self.tokenizer.next(), ";")?; - let mut builder = Type::primitive_type_builder(name, physical_type) + Type::primitive_type_builder(name, physical_type) .with_repetition(repetition) .with_logical_type(logical_type) .with_converted_type(converted_type) .with_length(length) .with_precision(precision) - .with_scale(scale); - if let Some(id) = id { - builder = builder.with_id(id); - } - builder.build() + .with_scale(scale) + .with_id(id) + .build() } } @@ -845,7 +839,7 @@ mod tests { let message = parse(schema).unwrap(); let expected = Type::group_type_builder("root") - .with_fields(&mut vec![ + .with_fields(vec![ Arc::new( Type::primitive_type_builder( "f1", @@ -906,16 +900,16 @@ mod tests { let message = parse(schema).unwrap(); let expected = Type::group_type_builder("root") - .with_fields(&mut vec![Arc::new( + .with_fields(vec![Arc::new( Type::group_type_builder("a0") .with_repetition(Repetition::REQUIRED) - .with_fields(&mut vec![ + .with_fields(vec![ Arc::new( Type::group_type_builder("a1") .with_repetition(Repetition::OPTIONAL) .with_logical_type(Some(LogicalType::List)) .with_converted_type(ConvertedType::LIST) - .with_fields(&mut vec![Arc::new( + .with_fields(vec![Arc::new( Type::primitive_type_builder( "a2", PhysicalType::BYTE_ARRAY, @@ -933,10 +927,10 @@ mod tests { .with_repetition(Repetition::OPTIONAL) .with_logical_type(Some(LogicalType::List)) .with_converted_type(ConvertedType::LIST) - .with_fields(&mut vec![Arc::new( + .with_fields(vec![Arc::new( Type::group_type_builder("b2") .with_repetition(Repetition::REPEATED) - .with_fields(&mut vec![ + .with_fields(vec![ Arc::new( Type::primitive_type_builder( "b3", @@ -984,7 +978,7 @@ mod tests { "; let message = parse(schema).unwrap(); - let mut fields = vec![ + let fields = vec![ Arc::new( Type::primitive_type_builder("_1", PhysicalType::INT32) .with_repetition(Repetition::REQUIRED) @@ -1027,7 +1021,7 @@ mod tests { ]; let expected = Type::group_type_builder("root") - .with_fields(&mut fields) + .with_fields(fields) .build() .unwrap(); assert_eq!(message, expected); @@ -1051,7 +1045,7 @@ mod tests { "; let message = parse(schema).unwrap(); - let mut fields = vec![ + let fields = vec![ Arc::new( Type::primitive_type_builder("_1", PhysicalType::INT32) .with_repetition(Repetition::REQUIRED) @@ -1135,7 +1129,7 @@ mod tests { ]; let expected = Type::group_type_builder("root") - .with_fields(&mut fields) + .with_fields(fields) .build() .unwrap(); assert_eq!(message, expected); diff --git a/parquet/src/schema/printer.rs b/parquet/src/schema/printer.rs index ad4acb0cb8b1..fe63e758b251 100644 --- a/parquet/src/schema/printer.rs +++ 
b/parquet/src/schema/printer.rs @@ -51,8 +51,7 @@ use crate::file::metadata::{ }; use crate::schema::types::Type; -/// Prints Parquet metadata [`ParquetMetaData`](crate::file::metadata::ParquetMetaData) -/// information. +/// Prints Parquet metadata [`ParquetMetaData`] information. #[allow(unused_must_use)] pub fn print_parquet_metadata(out: &mut dyn io::Write, metadata: &ParquetMetaData) { print_file_metadata(out, metadata.file_metadata()); @@ -68,8 +67,7 @@ pub fn print_parquet_metadata(out: &mut dyn io::Write, metadata: &ParquetMetaDat } } -/// Prints file metadata [`FileMetaData`](crate::file::metadata::FileMetaData) -/// information. +/// Prints file metadata [`FileMetaData`] information. #[allow(unused_must_use)] pub fn print_file_metadata(out: &mut dyn io::Write, file_metadata: &FileMetaData) { writeln!(out, "version: {}", file_metadata.version()); @@ -92,7 +90,7 @@ pub fn print_file_metadata(out: &mut dyn io::Write, file_metadata: &FileMetaData print_schema(out, schema); } -/// Prints Parquet [`Type`](crate::schema::types::Type) information. +/// Prints Parquet [`Type`] information. #[allow(unused_must_use)] pub fn print_schema(out: &mut dyn io::Write, tp: &Type) { // TODO: better if we can pass fmt::Write to Printer. @@ -169,6 +167,11 @@ fn print_column_chunk_metadata( Some(bfo) => bfo.to_string(), }; writeln!(out, "bloom filter offset: {bloom_filter_offset_str}"); + let bloom_filter_length_str = match cc_metadata.bloom_filter_length() { + None => "N/A".to_owned(), + Some(bfo) => bfo.to_string(), + }; + writeln!(out, "bloom filter length: {bloom_filter_length_str}"); let offset_index_offset_str = match cc_metadata.offset_index_offset() { None => "N/A".to_owned(), Some(oio) => oio.to_string(), @@ -695,40 +698,40 @@ mod tests { let f1 = Type::primitive_type_builder("f1", PhysicalType::INT32) .with_repetition(Repetition::REQUIRED) .with_converted_type(ConvertedType::INT_32) - .with_id(0) + .with_id(Some(0)) .build(); let f2 = Type::primitive_type_builder("f2", PhysicalType::BYTE_ARRAY) .with_converted_type(ConvertedType::UTF8) - .with_id(1) + .with_id(Some(1)) .build(); let f3 = Type::primitive_type_builder("f3", PhysicalType::BYTE_ARRAY) .with_logical_type(Some(LogicalType::String)) - .with_id(1) + .with_id(Some(1)) .build(); let f4 = Type::primitive_type_builder("f4", PhysicalType::FIXED_LEN_BYTE_ARRAY) .with_repetition(Repetition::REPEATED) .with_converted_type(ConvertedType::INTERVAL) .with_length(12) - .with_id(2) + .with_id(Some(2)) .build(); - let mut struct_fields = vec![ + let struct_fields = vec![ Arc::new(f1.unwrap()), Arc::new(f2.unwrap()), Arc::new(f3.unwrap()), ]; let field = Type::group_type_builder("field") .with_repetition(Repetition::OPTIONAL) - .with_fields(&mut struct_fields) - .with_id(1) + .with_fields(struct_fields) + .with_id(Some(1)) .build() .unwrap(); - let mut fields = vec![Arc::new(field), Arc::new(f4.unwrap())]; + let fields = vec![Arc::new(field), Arc::new(f4.unwrap())]; let message = Type::group_type_builder("schema") - .with_fields(&mut fields) - .with_id(2) + .with_fields(fields) + .with_id(Some(2)) .build() .unwrap(); p.print(&message); @@ -756,7 +759,7 @@ mod tests { .with_repetition(Repetition::OPTIONAL) .with_logical_type(Some(LogicalType::List)) .with_converted_type(ConvertedType::LIST) - .with_fields(&mut vec![Arc::new(a2)]) + .with_fields(vec![Arc::new(a2)]) .build() .unwrap(); @@ -773,7 +776,7 @@ mod tests { let b2 = Type::group_type_builder("b2") .with_repetition(Repetition::REPEATED) .with_converted_type(ConvertedType::NONE) - 
.with_fields(&mut vec![Arc::new(b3), Arc::new(b4)]) + .with_fields(vec![Arc::new(b3), Arc::new(b4)]) .build() .unwrap(); @@ -781,18 +784,18 @@ mod tests { .with_repetition(Repetition::OPTIONAL) .with_logical_type(Some(LogicalType::List)) .with_converted_type(ConvertedType::LIST) - .with_fields(&mut vec![Arc::new(b2)]) + .with_fields(vec![Arc::new(b2)]) .build() .unwrap(); let a0 = Type::group_type_builder("a0") .with_repetition(Repetition::REQUIRED) - .with_fields(&mut vec![Arc::new(a1), Arc::new(b1)]) + .with_fields(vec![Arc::new(a1), Arc::new(b1)]) .build() .unwrap(); let message = Type::group_type_builder("root") - .with_fields(&mut vec![Arc::new(a0)]) + .with_fields(vec![Arc::new(a0)]) .build() .unwrap(); @@ -815,7 +818,7 @@ mod tests { let field = Type::group_type_builder("field") .with_repetition(Repetition::OPTIONAL) - .with_fields(&mut vec![Arc::new(f1), Arc::new(f2)]) + .with_fields(vec![Arc::new(f1), Arc::new(f2)]) .build() .unwrap(); @@ -827,7 +830,7 @@ mod tests { .unwrap(); let message = Type::group_type_builder("schema") - .with_fields(&mut vec![Arc::new(field), Arc::new(f3)]) + .with_fields(vec![Arc::new(field), Arc::new(f3)]) .build() .unwrap(); @@ -861,7 +864,7 @@ mod tests { .unwrap(); let message = Type::group_type_builder("schema") - .with_fields(&mut vec![Arc::new(f1), Arc::new(f2)]) + .with_fields(vec![Arc::new(f1), Arc::new(f2)]) .build() .unwrap(); diff --git a/parquet/src/schema/types.rs b/parquet/src/schema/types.rs index fd22cedeacaa..f4cb3a9956d6 100644 --- a/parquet/src/schema/types.rs +++ b/parquet/src/schema/types.rs @@ -219,53 +219,52 @@ impl<'a> PrimitiveTypeBuilder<'a> { } } - /// Sets [`Repetition`](crate::basic::Repetition) for this field and returns itself. - pub fn with_repetition(mut self, repetition: Repetition) -> Self { - self.repetition = repetition; - self + /// Sets [`Repetition`] for this field and returns itself. + pub fn with_repetition(self, repetition: Repetition) -> Self { + Self { repetition, ..self } } - /// Sets [`ConvertedType`](crate::basic::ConvertedType) for this field and returns itself. - pub fn with_converted_type(mut self, converted_type: ConvertedType) -> Self { - self.converted_type = converted_type; - self + /// Sets [`ConvertedType`] for this field and returns itself. + pub fn with_converted_type(self, converted_type: ConvertedType) -> Self { + Self { + converted_type, + ..self + } } - /// Sets [`LogicalType`](crate::basic::LogicalType) for this field and returns itself. + /// Sets [`LogicalType`] for this field and returns itself. /// If only the logical type is populated for a primitive type, the converted type /// will be automatically populated, and can thus be omitted. - pub fn with_logical_type(mut self, logical_type: Option) -> Self { - self.logical_type = logical_type; - self + pub fn with_logical_type(self, logical_type: Option) -> Self { + Self { + logical_type, + ..self + } } /// Sets type length and returns itself. /// This is only applied to FIXED_LEN_BYTE_ARRAY and INT96 (INTERVAL) types, because /// they maintain fixed size underlying byte array. /// By default, value is `0`. - pub fn with_length(mut self, length: i32) -> Self { - self.length = length; - self + pub fn with_length(self, length: i32) -> Self { + Self { length, ..self } } /// Sets precision for Parquet DECIMAL physical type and returns itself. /// By default, it equals to `0` and used only for decimal context. 
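Review note: the updated call sites above show the two builder signature changes together: `with_fields` now takes the `Vec` of child types by value, and `with_id` takes an `Option<i32>`. A minimal call-site sketch under the new API (field names are illustrative; the pattern follows the updated tests):

```rust
use std::sync::Arc;

use parquet::basic::{ConvertedType, Repetition, Type as PhysicalType};
use parquet::schema::types::Type;

let f1 = Type::primitive_type_builder("f1", PhysicalType::INT32)
    .with_repetition(Repetition::REQUIRED)
    .with_converted_type(ConvertedType::INT_32)
    .with_id(Some(0)) // was .with_id(0)
    .build()
    .unwrap();

let message = Type::group_type_builder("schema")
    .with_fields(vec![Arc::new(f1)]) // was .with_fields(&mut fields)
    .with_id(Some(2))
    .build()
    .unwrap();
```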
- pub fn with_precision(mut self, precision: i32) -> Self { - self.precision = precision; - self + pub fn with_precision(self, precision: i32) -> Self { + Self { precision, ..self } } /// Sets scale for Parquet DECIMAL physical type and returns itself. /// By default, it equals to `0` and used only for decimal context. - pub fn with_scale(mut self, scale: i32) -> Self { - self.scale = scale; - self + pub fn with_scale(self, scale: i32) -> Self { + Self { scale, ..self } } /// Sets optional field id and returns itself. - pub fn with_id(mut self, id: i32) -> Self { - self.id = Some(id); - self + pub fn with_id(self, id: Option) -> Self { + Self { id, ..self } } /// Creates a new `PrimitiveType` instance from the collected attributes. @@ -553,35 +552,37 @@ impl<'a> GroupTypeBuilder<'a> { } } - /// Sets [`Repetition`](crate::basic::Repetition) for this field and returns itself. + /// Sets [`Repetition`] for this field and returns itself. pub fn with_repetition(mut self, repetition: Repetition) -> Self { self.repetition = Some(repetition); self } - /// Sets [`ConvertedType`](crate::basic::ConvertedType) for this field and returns itself. - pub fn with_converted_type(mut self, converted_type: ConvertedType) -> Self { - self.converted_type = converted_type; - self + /// Sets [`ConvertedType`] for this field and returns itself. + pub fn with_converted_type(self, converted_type: ConvertedType) -> Self { + Self { + converted_type, + ..self + } } - /// Sets [`LogicalType`](crate::basic::LogicalType) for this field and returns itself. - pub fn with_logical_type(mut self, logical_type: Option) -> Self { - self.logical_type = logical_type; - self + /// Sets [`LogicalType`] for this field and returns itself. + pub fn with_logical_type(self, logical_type: Option) -> Self { + Self { + logical_type, + ..self + } } /// Sets a list of fields that should be child nodes of this field. /// Returns updated self. - pub fn with_fields(mut self, fields: &mut Vec) -> Self { - self.fields.append(fields); - self + pub fn with_fields(self, fields: Vec) -> Self { + Self { fields, ..self } } /// Sets optional field id and returns itself. - pub fn with_id(mut self, id: i32) -> Self { - self.id = Some(id); - self + pub fn with_id(self, id: Option) -> Self { + Self { id, ..self } } /// Creates a new `GroupType` instance from the gathered attributes. @@ -628,18 +629,18 @@ impl BasicTypeInfo { self.repetition.is_some() } - /// Returns [`Repetition`](crate::basic::Repetition) value for the type. + /// Returns [`Repetition`] value for the type. pub fn repetition(&self) -> Repetition { assert!(self.repetition.is_some()); self.repetition.unwrap() } - /// Returns [`ConvertedType`](crate::basic::ConvertedType) value for the type. + /// Returns [`ConvertedType`] value for the type. pub fn converted_type(&self) -> ConvertedType { self.converted_type } - /// Returns [`LogicalType`](crate::basic::LogicalType) value for the type. + /// Returns [`LogicalType`] value for the type. pub fn logical_type(&self) -> Option { // Unlike ConvertedType, LogicalType cannot implement Copy, thus we clone it self.logical_type.clone() @@ -786,12 +787,12 @@ impl ColumnDescriptor { &self.path } - /// Returns self type [`Type`](crate::schema::types::Type) for this leaf column. + /// Returns self type [`Type`] for this leaf column. pub fn self_type(&self) -> &Type { self.primitive_type.as_ref() } - /// Returns self type [`TypePtr`](crate::schema::types::TypePtr) for this leaf + /// Returns self type [`TypePtr`] for this leaf /// column. 
pub fn self_type_ptr(&self) -> TypePtr { self.primitive_type.clone() @@ -802,12 +803,12 @@ impl ColumnDescriptor { self.primitive_type.name() } - /// Returns [`ConvertedType`](crate::basic::ConvertedType) for this column. + /// Returns [`ConvertedType`] for this column. pub fn converted_type(&self) -> ConvertedType { self.primitive_type.get_basic_info().converted_type() } - /// Returns [`LogicalType`](crate::basic::LogicalType) for this column. + /// Returns [`LogicalType`] for this column. pub fn logical_type(&self) -> Option { self.primitive_type.get_basic_info().logical_type() } @@ -927,14 +928,13 @@ impl SchemaDescriptor { self.leaves.len() } - /// Returns column root [`Type`](crate::schema::types::Type) for a leaf position. + /// Returns column root [`Type`] for a leaf position. pub fn get_column_root(&self, i: usize) -> &Type { let result = self.column_root_of(i); result.as_ref() } - /// Returns column root [`Type`](crate::schema::types::Type) pointer for a leaf - /// position. + /// Returns column root [`Type`] pointer for a leaf position. pub fn get_column_root_ptr(&self, i: usize) -> TypePtr { let result = self.column_root_of(i); result.clone() @@ -959,7 +959,7 @@ impl SchemaDescriptor { &self.schema.get_fields()[self.get_column_root_idx(i)] } - /// Returns schema as [`Type`](crate::schema::types::Type). + /// Returns schema as [`Type`]. pub fn root_schema(&self) -> &Type { self.schema.as_ref() } @@ -1093,16 +1093,14 @@ fn from_thrift_helper( let scale = elements[index].scale.unwrap_or(-1); let precision = elements[index].precision.unwrap_or(-1); let name = &elements[index].name; - let mut builder = Type::primitive_type_builder(name, physical_type) + let builder = Type::primitive_type_builder(name, physical_type) .with_repetition(repetition) .with_converted_type(converted_type) .with_logical_type(logical_type) .with_length(length) .with_precision(precision) - .with_scale(scale); - if let Some(id) = field_id { - builder = builder.with_id(id); - } + .with_scale(scale) + .with_id(field_id); Ok((index + 1, Arc::new(builder.build()?))) } Some(n) => { @@ -1122,7 +1120,8 @@ fn from_thrift_helper( let mut builder = Type::group_type_builder(&elements[index].name) .with_converted_type(converted_type) .with_logical_type(logical_type) - .with_fields(&mut fields); + .with_fields(fields) + .with_id(field_id); if let Some(rep) = repetition { // Sometimes parquet-cpp and parquet-mr set repetition level REQUIRED or // REPEATED for root node. 
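Review note: internally the setters move from `mut self` plus field assignment to struct update syntax, and because `with_id` now accepts an `Option`, callers such as `from_thrift_helper` below can forward the optional field id without a conditional rebind. A sketch of the pattern on a stripped-down, hypothetical builder (not the real type):

```rust
// Hypothetical miniature builder illustrating the new setter style.
struct MiniBuilder {
    length: i32,
    id: Option<i32>,
}

impl MiniBuilder {
    // Struct update syntax replaces `mut self` + assignment.
    fn with_length(self, length: i32) -> Self {
        Self { length, ..self }
    }

    // Taking Option<i32> lets a caller write `.with_id(field_id)` directly,
    // instead of `if let Some(id) = field_id { builder = builder.with_id(id) }`.
    fn with_id(self, id: Option<i32>) -> Self {
        Self { id, ..self }
    }
}
```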
@@ -1135,9 +1134,6 @@ fn from_thrift_helper( builder = builder.with_repetition(rep); } } - if let Some(id) = field_id { - builder = builder.with_id(id); - } Ok((next_index, Arc::new(builder.build().unwrap()))) } } @@ -1243,7 +1239,7 @@ mod tests { bit_width: 32, is_signed: true, })) - .with_id(0) + .with_id(Some(0)) .build(); assert!(result.is_ok()); @@ -1525,22 +1521,22 @@ mod tests { fn test_group_type() { let f1 = Type::primitive_type_builder("f1", PhysicalType::INT32) .with_converted_type(ConvertedType::INT_32) - .with_id(0) + .with_id(Some(0)) .build(); assert!(f1.is_ok()); let f2 = Type::primitive_type_builder("f2", PhysicalType::BYTE_ARRAY) .with_converted_type(ConvertedType::UTF8) - .with_id(1) + .with_id(Some(1)) .build(); assert!(f2.is_ok()); - let mut fields = vec![Arc::new(f1.unwrap()), Arc::new(f2.unwrap())]; + let fields = vec![Arc::new(f1.unwrap()), Arc::new(f2.unwrap())]; let result = Type::group_type_builder("foo") .with_repetition(Repetition::REPEATED) .with_logical_type(Some(LogicalType::List)) - .with_fields(&mut fields) - .with_id(1) + .with_fields(fields) + .with_id(Some(1)) .build(); assert!(result.is_ok()); @@ -1630,17 +1626,17 @@ mod tests { let list = Type::group_type_builder("records") .with_repetition(Repetition::REPEATED) .with_converted_type(ConvertedType::LIST) - .with_fields(&mut vec![Arc::new(item1), Arc::new(item2), Arc::new(item3)]) + .with_fields(vec![Arc::new(item1), Arc::new(item2), Arc::new(item3)]) .build()?; let bag = Type::group_type_builder("bag") .with_repetition(Repetition::OPTIONAL) - .with_fields(&mut vec![Arc::new(list)]) + .with_fields(vec![Arc::new(list)]) .build()?; fields.push(Arc::new(bag)); let schema = Type::group_type_builder("schema") .with_repetition(Repetition::REPEATED) - .with_fields(&mut fields) + .with_fields(fields) .build()?; let descr = SchemaDescriptor::new(Arc::new(schema)); @@ -1656,8 +1652,8 @@ mod tests { // required int64 item1 2 1 // optional boolean item2 3 1 // repeated int32 item3 3 2 - let ex_max_def_levels = vec![0, 1, 1, 2, 3, 3]; - let ex_max_rep_levels = vec![0, 0, 1, 1, 1, 2]; + let ex_max_def_levels = [0, 1, 1, 2, 3, 3]; + let ex_max_rep_levels = [0, 0, 1, 1, 1, 2]; for i in 0..nleaves { let col = descr.column(i); @@ -1789,13 +1785,9 @@ mod tests { // function to create a new group type for testing fn test_new_group_type(name: &str, repetition: Repetition, types: Vec) -> Type { - let mut fields = Vec::new(); - for tpe in types { - fields.push(Arc::new(tpe)) - } Type::group_type_builder(name) .with_repetition(repetition) - .with_fields(&mut fields) + .with_fields(types.into_iter().map(Arc::new).collect()) .build() .unwrap() } diff --git a/parquet/src/thrift.rs b/parquet/src/thrift.rs new file mode 100644 index 000000000000..57f52edc6ef0 --- /dev/null +++ b/parquet/src/thrift.rs @@ -0,0 +1,284 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Custom thrift definitions + +use thrift::protocol::{ + TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, + TMessageIdentifier, TOutputProtocol, TSetIdentifier, TStructIdentifier, TType, +}; + +/// Reads and writes the struct to Thrift protocols. +/// +/// Unlike [`thrift::protocol::TSerializable`] this uses generics instead of trait objects +pub trait TSerializable: Sized { + fn read_from_in_protocol(i_prot: &mut T) -> thrift::Result; + fn write_to_out_protocol( + &self, + o_prot: &mut T, + ) -> thrift::Result<()>; +} + +/// A more performant implementation of [`TCompactInputProtocol`] that reads a slice +/// +/// [`TCompactInputProtocol`]: thrift::protocol::TCompactInputProtocol +pub(crate) struct TCompactSliceInputProtocol<'a> { + buf: &'a [u8], + // Identifier of the last field deserialized for a struct. + last_read_field_id: i16, + // Stack of the last read field ids (a new entry is added each time a nested struct is read). + read_field_id_stack: Vec, + // Boolean value for a field. + // Saved because boolean fields and their value are encoded in a single byte, + // and reading the field only occurs after the field id is read. + pending_read_bool_value: Option, +} + +impl<'a> TCompactSliceInputProtocol<'a> { + pub fn new(buf: &'a [u8]) -> Self { + Self { + buf, + last_read_field_id: 0, + read_field_id_stack: Vec::with_capacity(16), + pending_read_bool_value: None, + } + } + + pub fn as_slice(&self) -> &'a [u8] { + self.buf + } + + fn read_vlq(&mut self) -> thrift::Result { + let mut in_progress = 0; + let mut shift = 0; + loop { + let byte = self.read_byte()?; + in_progress |= ((byte & 0x7F) as u64) << shift; + shift += 7; + if byte & 0x80 == 0 { + return Ok(in_progress); + } + } + } + + fn read_zig_zag(&mut self) -> thrift::Result { + let val = self.read_vlq()?; + Ok((val >> 1) as i64 ^ -((val & 1) as i64)) + } + + fn read_list_set_begin(&mut self) -> thrift::Result<(TType, i32)> { + let header = self.read_byte()?; + let element_type = collection_u8_to_type(header & 0x0F)?; + + let possible_element_count = (header & 0xF0) >> 4; + let element_count = if possible_element_count != 15 { + // high bits set high if count and type encoded separately + possible_element_count as i32 + } else { + self.read_vlq()? 
as _ + }; + + Ok((element_type, element_count)) + } +} + +impl<'a> TInputProtocol for TCompactSliceInputProtocol<'a> { + fn read_message_begin(&mut self) -> thrift::Result { + unimplemented!() + } + + fn read_message_end(&mut self) -> thrift::Result<()> { + unimplemented!() + } + + fn read_struct_begin(&mut self) -> thrift::Result> { + self.read_field_id_stack.push(self.last_read_field_id); + self.last_read_field_id = 0; + Ok(None) + } + + fn read_struct_end(&mut self) -> thrift::Result<()> { + self.last_read_field_id = self + .read_field_id_stack + .pop() + .expect("should have previous field ids"); + Ok(()) + } + + fn read_field_begin(&mut self) -> thrift::Result { + // we can read at least one byte, which is: + // - the type + // - the field delta and the type + let field_type = self.read_byte()?; + let field_delta = (field_type & 0xF0) >> 4; + let field_type = match field_type & 0x0F { + 0x01 => { + self.pending_read_bool_value = Some(true); + Ok(TType::Bool) + } + 0x02 => { + self.pending_read_bool_value = Some(false); + Ok(TType::Bool) + } + ttu8 => u8_to_type(ttu8), + }?; + + match field_type { + TType::Stop => Ok( + TFieldIdentifier::new::, String, Option>( + None, + TType::Stop, + None, + ), + ), + _ => { + if field_delta != 0 { + self.last_read_field_id += field_delta as i16; + } else { + self.last_read_field_id = self.read_i16()?; + }; + + Ok(TFieldIdentifier { + name: None, + field_type, + id: Some(self.last_read_field_id), + }) + } + } + } + + fn read_field_end(&mut self) -> thrift::Result<()> { + Ok(()) + } + + fn read_bool(&mut self) -> thrift::Result { + match self.pending_read_bool_value.take() { + Some(b) => Ok(b), + None => { + let b = self.read_byte()?; + match b { + 0x01 => Ok(true), + 0x02 => Ok(false), + unkn => Err(thrift::Error::Protocol(thrift::ProtocolError { + kind: thrift::ProtocolErrorKind::InvalidData, + message: format!("cannot convert {} into bool", unkn), + })), + } + } + } + } + + fn read_bytes(&mut self) -> thrift::Result> { + let len = self.read_vlq()? as usize; + let ret = self.buf.get(..len).ok_or_else(eof_error)?.to_vec(); + self.buf = &self.buf[len..]; + Ok(ret) + } + + fn read_i8(&mut self) -> thrift::Result { + Ok(self.read_byte()? as _) + } + + fn read_i16(&mut self) -> thrift::Result { + Ok(self.read_zig_zag()? as _) + } + + fn read_i32(&mut self) -> thrift::Result { + Ok(self.read_zig_zag()? 
as _) + } + + fn read_i64(&mut self) -> thrift::Result { + self.read_zig_zag() + } + + fn read_double(&mut self) -> thrift::Result { + let slice = (self.buf[..8]).try_into().unwrap(); + self.buf = &self.buf[8..]; + Ok(f64::from_le_bytes(slice)) + } + + fn read_string(&mut self) -> thrift::Result { + let bytes = self.read_bytes()?; + String::from_utf8(bytes).map_err(From::from) + } + + fn read_list_begin(&mut self) -> thrift::Result { + let (element_type, element_count) = self.read_list_set_begin()?; + Ok(TListIdentifier::new(element_type, element_count)) + } + + fn read_list_end(&mut self) -> thrift::Result<()> { + Ok(()) + } + + fn read_set_begin(&mut self) -> thrift::Result { + unimplemented!() + } + + fn read_set_end(&mut self) -> thrift::Result<()> { + unimplemented!() + } + + fn read_map_begin(&mut self) -> thrift::Result { + unimplemented!() + } + + fn read_map_end(&mut self) -> thrift::Result<()> { + Ok(()) + } + + #[inline] + fn read_byte(&mut self) -> thrift::Result { + let ret = *self.buf.first().ok_or_else(eof_error)?; + self.buf = &self.buf[1..]; + Ok(ret) + } +} + +fn collection_u8_to_type(b: u8) -> thrift::Result { + match b { + 0x01 => Ok(TType::Bool), + o => u8_to_type(o), + } +} + +fn u8_to_type(b: u8) -> thrift::Result { + match b { + 0x00 => Ok(TType::Stop), + 0x03 => Ok(TType::I08), // equivalent to TType::Byte + 0x04 => Ok(TType::I16), + 0x05 => Ok(TType::I32), + 0x06 => Ok(TType::I64), + 0x07 => Ok(TType::Double), + 0x08 => Ok(TType::String), + 0x09 => Ok(TType::List), + 0x0A => Ok(TType::Set), + 0x0B => Ok(TType::Map), + 0x0C => Ok(TType::Struct), + unkn => Err(thrift::Error::Protocol(thrift::ProtocolError { + kind: thrift::ProtocolErrorKind::InvalidData, + message: format!("cannot convert {} into TType", unkn), + })), + } +} + +fn eof_error() -> thrift::Error { + thrift::Error::Transport(thrift::TransportError { + kind: thrift::TransportErrorKind::EndOfFile, + message: "Unexpected EOF".to_string(), + }) +} diff --git a/parquet_derive/src/lib.rs b/parquet_derive/src/lib.rs index 0f875401f0e9..c6641cd8091d 100644 --- a/parquet_derive/src/lib.rs +++ b/parquet_derive/src/lib.rs @@ -130,7 +130,7 @@ pub fn parquet_record_writer(input: proc_macro::TokenStream) -> proc_macro::Toke #field_types );*; let group = ParquetType::group_type_builder("rust_schema") - .with_fields(&mut fields) + .with_fields(fields) .build()?; Ok(group.into()) } diff --git a/parquet_derive_test/Cargo.toml b/parquet_derive_test/Cargo.toml index be24db85a109..a5d2e76d4503 100644 --- a/parquet_derive_test/Cargo.toml +++ b/parquet_derive_test/Cargo.toml @@ -31,4 +31,4 @@ rust-version = { workspace = true } [dependencies] parquet = { workspace = true } parquet_derive = { path = "../parquet_derive", default-features = false } -chrono = { version="0.4.23", default-features = false, features = [ "clock" ] } +chrono = { workspace = true }
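Review note, tying the new pieces together: the generic `TSerializable` impls in the generated code can now be driven by the slice-based `TCompactSliceInputProtocol`, which decodes the thrift compact protocol (varint/zigzag integers, field deltas, inline bools) directly from a borrowed byte slice instead of going through a `Read`-based transport. A minimal crate-internal sketch; note the protocol type is `pub(crate)`, and the module path of the generated `FileMetaData` struct is assumed to be `crate::format`:

```rust
use crate::format::FileMetaData;
use crate::thrift::{TCompactSliceInputProtocol, TSerializable};

// Decode a thrift-compact-encoded FileMetaData payload from an in-memory slice.
fn decode_footer_metadata(bytes: &[u8]) -> thrift::Result<FileMetaData> {
    let mut prot = TCompactSliceInputProtocol::new(bytes);
    FileMetaData::read_from_in_protocol(&mut prot)
}
```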