diff --git a/.asf.yaml b/.asf.yaml index 968c6779215a..9541db89daf8 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -38,4 +38,10 @@ github: # require branches to be up-to-date before merging strict: true # don't require any jobs to pass - contexts: [] \ No newline at end of file + contexts: [] + +# publishes the content of the `asf-site` branch to +# https://arrow.apache.org/rust/ +publish: + whoami: asf-site + subdir: rust diff --git a/.gitattributes b/.gitattributes index fac7bf85a77f..b7b0d51ff478 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,6 +1,3 @@ -r/R/RcppExports.R linguist-generated=true -r/R/arrowExports.R linguist-generated=true -r/src/RcppExports.cpp linguist-generated=true -r/src/arrowExports.cpp linguist-generated=true -r/man/*.Rd linguist-generated=true - +parquet/src/format.rs linguist-generated +arrow-flight/src/arrow.flight.protocol.rs linguist-generated +arrow-flight/src/sql/arrow.flight.protocol.sql.rs linguist-generated diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml index 0ef6532da477..aa1d1d9c14da 100644 --- a/.github/actions/setup-builder/action.yaml +++ b/.github/actions/setup-builder/action.yaml @@ -20,8 +20,12 @@ description: 'Prepare Rust Build Environment' inputs: rust-version: description: 'version of rust to install (e.g. stable)' - required: true + required: false default: 'stable' + target: + description: 'target architecture(s)' + required: false + default: 'x86_64-unknown-linux-gnu' runs: using: "composite" steps: @@ -51,6 +55,17 @@ runs: shell: bash run: | echo "Installing ${{ inputs.rust-version }}" - rustup toolchain install ${{ inputs.rust-version }} + rustup toolchain install ${{ inputs.rust-version }} --target ${{ inputs.target }} rustup default ${{ inputs.rust-version }} - echo "CARGO_TARGET_DIR=/github/home/target" >> $GITHUB_ENV + - name: Disable debuginfo generation + # Disable full debug symbol generation to speed up CI build and keep memory down + # "1" means line tables only, which is useful for panic tracebacks. + shell: bash + run: echo "RUSTFLAGS=-C debuginfo=1" >> $GITHUB_ENV + - name: Enable backtraces + shell: bash + run: echo "RUST_BACKTRACE=1" >> $GITHUB_ENV + - name: Fixup git permissions + # https://github.com/actions/checkout/issues/766 + shell: bash + run: git config --global --add safe.directory "$GITHUB_WORKSPACE" diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 9c4cda5d034d..ffde5378da93 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -6,10 +6,17 @@ updates: interval: daily open-pull-requests-limit: 10 target-branch: master - labels: [auto-dependencies] + labels: [ auto-dependencies, arrow ] + - package-ecosystem: cargo + directory: "/object_store" + schedule: + interval: daily + open-pull-requests-limit: 10 + target-branch: master + labels: [ auto-dependencies, object_store ] - package-ecosystem: "github-actions" directory: "/" schedule: interval: "daily" open-pull-requests-limit: 10 - labels: [auto-dependencies] + labels: [ auto-dependencies ] diff --git a/.github/workflows/arrow.yml b/.github/workflows/arrow.yml index d34ee3b49b5c..d3b2526740fa 100644 --- a/.github/workflows/arrow.yml +++ b/.github/workflows/arrow.yml @@ -18,6 +18,10 @@ # tests for arrow crate name: arrow +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + on: # always trigger push: @@ -25,8 +29,23 @@ on: - master pull_request: paths: - - arrow/** - .github/** + - arrow-arith/** + - arrow-array/** + - arrow-buffer/** + - arrow-cast/** + - arrow-csv/** + - arrow-data/** + - arrow-integration-test/** + - arrow-ipc/** + - arrow-json/** + - arrow-avro/** + - arrow-ord/** + - arrow-row/** + - arrow-schema/** + - arrow-select/** + - arrow-string/** + - arrow/** jobs: @@ -36,24 +55,46 @@ jobs: runs-on: ubuntu-latest container: image: amd64/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. - RUSTFLAGS: "-C debuginfo=1" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: stable - - name: Test - run: | - cargo test -p arrow - - name: Test --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict - run: | - cargo test -p arrow --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict + - name: Test arrow-buffer with all features + run: cargo test -p arrow-buffer --all-features + - name: Test arrow-data with all features + run: cargo test -p arrow-data --all-features + - name: Test arrow-schema with all features + run: cargo test -p arrow-schema --all-features + - name: Test arrow-array with all features + run: cargo test -p arrow-array --all-features + - name: Test arrow-select with all features + run: cargo test -p arrow-select --all-features + - name: Test arrow-cast with all features + run: cargo test -p arrow-cast --all-features + - name: Test arrow-ipc with all features + run: cargo test -p arrow-ipc --all-features + - name: Test arrow-csv with all features + run: cargo test -p arrow-csv --all-features + - name: Test arrow-json with all features + run: cargo test -p arrow-json --all-features + - name: Test arrow-avro with all features + run: cargo test -p arrow-avro --all-features + - name: Test arrow-string with all features + run: cargo test -p arrow-string --all-features + - name: Test arrow-ord with all features + run: cargo test -p arrow-ord --all-features + - name: Test arrow-arith with all features + run: cargo test -p arrow-arith --all-features + - name: Test arrow-row with all features + run: cargo test -p arrow-row --all-features + - name: Test arrow-integration-test with all features + run: cargo test -p arrow-integration-test --all-features + - name: Test arrow with default features + run: cargo test -p arrow + - name: Test arrow with all features except pyarrow + run: cargo test -p arrow --features=force_validate,prettyprint,ipc_compression,ffi,chrono-tz - name: Run examples run: | # Test arrow examples @@ -64,99 +105,52 @@ jobs: - name: Run non-archery based integration-tests run: cargo test -p arrow-integration-testing - # test compilaton features + # test compilation features linux-features: name: Check Compilation runs-on: ubuntu-latest container: image: amd64/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. - RUSTFLAGS: "-C debuginfo=1" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: stable - name: Check compilation - run: | - cargo check -p arrow + run: cargo check -p arrow - name: Check compilation --no-default-features - run: | - cargo check -p arrow --no-default-features + run: cargo check -p arrow --no-default-features - name: Check compilation --all-targets - run: | - cargo check -p arrow --all-targets + run: cargo check -p arrow --all-targets - name: Check compilation --no-default-features --all-targets - run: | - cargo check -p arrow --no-default-features --all-targets + run: cargo check -p arrow --no-default-features --all-targets - name: Check compilation --no-default-features --all-targets --features test_utils - run: | - cargo check -p arrow --no-default-features --all-targets --features test_utils - - # test the --features "simd" of the arrow crate. This requires nightly Rust. - linux-test-simd: - name: Test SIMD on AMD64 Rust ${{ matrix.rust }} - runs-on: ubuntu-latest - container: - image: amd64/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. - RUSTFLAGS: "-C debuginfo=1" - steps: - - uses: actions/checkout@v3 - with: - submodules: true - - name: Setup Rust toolchain - uses: ./.github/actions/setup-builder - with: - rust-version: nightly - - name: Run tests --features "simd" - run: | - cargo test -p arrow --features "simd" - - name: Check compilation --features "simd" - run: | - cargo check -p arrow --features simd - - name: Check compilation --features simd --all-targets - run: | - cargo check -p arrow --features simd --all-targets + run: cargo check -p arrow --no-default-features --all-targets --features test_utils + - name: Check compilation --no-default-features --all-targets --features ffi + run: cargo check -p arrow --no-default-features --all-targets --features ffi + - name: Check compilation --no-default-features --all-targets --features chrono-tz + run: cargo check -p arrow --no-default-features --all-targets --features chrono-tz - # test the arrow crate builds against wasm32 in stable rust + # test the arrow crate builds against wasm32 in nightly rust wasm32-build: name: Build wasm32 runs-on: ubuntu-latest container: image: amd64/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. - RUSTFLAGS: "-C debuginfo=1" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - - name: Cache Cargo - uses: actions/cache@v3 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder with: - path: /github/home/.cargo - key: cargo-wasm32-cache3- - - name: Setup Rust toolchain for WASM - run: | - rustup toolchain install nightly - rustup override set nightly - rustup target add wasm32-unknown-unknown - rustup target add wasm32-wasi - - name: Build - run: | - cd arrow - cargo build --no-default-features --features=json,csv,ipc,simd,ffi --target wasm32-unknown-unknown - cargo build --no-default-features --features=json,csv,ipc,simd,ffi --target wasm32-wasi + target: wasm32-unknown-unknown,wasm32-wasi + - name: Build wasm32-unknown-unknown + run: cargo build -p arrow --no-default-features --features=json,csv,ipc,ffi --target wasm32-unknown-unknown + - name: Build wasm32-wasi + run: cargo build -p arrow --no-default-features --features=json,csv,ipc,ffi --target wasm32-wasi clippy: name: Clippy @@ -164,14 +158,42 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: stable - name: Setup Clippy - run: | - rustup component add clippy - - name: Run clippy - run: | - cargo clippy -p arrow --features=prettyprint,csv,ipc,test_utils,ffi,ipc_compression,dyn_cmp_dict --all-targets -- -D warnings + run: rustup component add clippy + - name: Clippy arrow-buffer with all features + run: cargo clippy -p arrow-buffer --all-targets --all-features -- -D warnings + - name: Clippy arrow-data with all features + run: cargo clippy -p arrow-data --all-targets --all-features -- -D warnings + - name: Clippy arrow-schema with all features + run: cargo clippy -p arrow-schema --all-targets --all-features -- -D warnings + - name: Clippy arrow-array with all features + run: cargo clippy -p arrow-array --all-targets --all-features -- -D warnings + - name: Clippy arrow-select with all features + run: cargo clippy -p arrow-select --all-targets --all-features -- -D warnings + - name: Clippy arrow-cast with all features + run: cargo clippy -p arrow-cast --all-targets --all-features -- -D warnings + - name: Clippy arrow-ipc with all features + run: cargo clippy -p arrow-ipc --all-targets --all-features -- -D warnings + - name: Clippy arrow-csv with all features + run: cargo clippy -p arrow-csv --all-targets --all-features -- -D warnings + - name: Clippy arrow-json with all features + run: cargo clippy -p arrow-json --all-targets --all-features -- -D warnings + - name: Clippy arrow-avro with all features + run: cargo clippy -p arrow-avro --all-targets --all-features -- -D warnings + - name: Clippy arrow-string with all features + run: cargo clippy -p arrow-string --all-targets --all-features -- -D warnings + - name: Clippy arrow-ord with all features + run: cargo clippy -p arrow-ord --all-targets --all-features -- -D warnings + - name: Clippy arrow-arith with all features + run: cargo clippy -p arrow-arith --all-targets --all-features -- -D warnings + - name: Clippy arrow-row with all features + run: cargo clippy -p arrow-row --all-targets --all-features -- -D warnings + - name: Clippy arrow with all features + run: cargo clippy -p arrow --all-features --all-targets -- -D warnings + - name: Clippy arrow-integration-test with all features + run: cargo clippy -p arrow-integration-test --all-targets --all-features -- -D warnings + - name: Clippy arrow-integration-testing with all features + run: cargo clippy -p arrow-integration-testing --all-targets --all-features -- -D warnings diff --git a/.github/workflows/arrow_flight.yml b/.github/workflows/arrow_flight.yml index 86a67ff9a6a4..242e0f2a3b0d 100644 --- a/.github/workflows/arrow_flight.yml +++ b/.github/workflows/arrow_flight.yml @@ -19,6 +19,9 @@ # tests for arrow_flight crate name: arrow_flight +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true # trigger for all PRs that touch certain files and changes to master on: @@ -27,35 +30,51 @@ on: - master pull_request: paths: - - arrow/** + - arrow-array/** + - arrow-buffer/** + - arrow-cast/** + - arrow-data/** - arrow-flight/** + - arrow-ipc/** + - arrow-schema/** + - arrow-select/** - .github/** jobs: - # test the crate linux-test: name: Test runs-on: ubuntu-latest container: image: amd64/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. - RUSTFLAGS: "-C debuginfo=1" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: stable - name: Test run: | cargo test -p arrow-flight - name: Test --all-features run: | cargo test -p arrow-flight --all-features + - name: Test --examples + run: | + cargo test -p arrow-flight --features=flight-sql-experimental,tls --examples + + vendor: + name: Verify Vendored Code + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v4 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + - name: Run gen + run: ./arrow-flight/regen.sh + - name: Verify workspace clean (if this fails, run ./arrow-flight/regen.sh and check in results) + run: git diff --exit-code clippy: name: Clippy @@ -63,14 +82,10 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: stable - name: Setup Clippy - run: | - rustup component add clippy + run: rustup component add clippy - name: Run clippy - run: | - cargo clippy -p arrow-flight --all-features -- -D warnings + run: cargo clippy -p arrow-flight --all-targets --all-features -- -D warnings diff --git a/.github/workflows/audit.yml b/.github/workflows/audit.yml new file mode 100644 index 000000000000..2c1dcdfd2100 --- /dev/null +++ b/.github/workflows/audit.yml @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +name: audit + +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + +# trigger for all PRs that touch certain files and changes to master +on: + push: + branches: + - master + pull_request: + paths: + - '**/Cargo.toml' + - '**/Cargo.lock' + +jobs: + cargo-audit: + name: Audit + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install cargo-audit + run: cargo install cargo-audit + - name: Run audit check + run: cargo audit diff --git a/.github/workflows/cancel.yml b/.github/workflows/cancel.yml deleted file mode 100644 index a98c8ee5d225..000000000000 --- a/.github/workflows/cancel.yml +++ /dev/null @@ -1,54 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# Attempt to cancel stale workflow runs to save github actions runner time -name: cancel - -on: - workflow_run: - # The name of another workflow (whichever one) that always runs on PRs - workflows: ['Dev'] - types: ['requested'] - -jobs: - cancel-stale-workflow-runs: - name: "Cancel stale workflow runs" - runs-on: ubuntu-latest - steps: - # Unfortunately, we need to define a separate cancellation step for - # each workflow where we want to cancel stale runs. - - uses: potiuk/cancel-workflow-runs@master - name: "Cancel stale Dev runs" - with: - cancelMode: allDuplicates - token: ${{ secrets.GITHUB_TOKEN }} - workflowFileName: dev.yml - skipEventTypes: '["push", "schedule"]' - - uses: potiuk/cancel-workflow-runs@master - name: "Cancel stale Integration runs" - with: - cancelMode: allDuplicates - token: ${{ secrets.GITHUB_TOKEN }} - workflowFileName: integration.yml - skipEventTypes: '["push", "schedule"]' - - uses: potiuk/cancel-workflow-runs@master - name: "Cancel stale Rust runs" - with: - cancelMode: allDuplicates - token: ${{ secrets.GITHUB_TOKEN }} - workflowFileName: rust.yml - skipEventTypes: '["push", "schedule"]' diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index e688428e187c..64b2ca437067 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -17,6 +17,10 @@ name: coverage +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + # Trigger only on pushes to master, not pull requests on: push: @@ -32,7 +36,7 @@ jobs: # otherwise we get this error: # Failed to run tests: ASLR disable failed: EPERM: Operation not permitted steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index 57dc19482761..2026e257ab29 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -17,6 +17,10 @@ name: dev +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + # trigger for all PRs and changes to master on: push: @@ -34,9 +38,9 @@ jobs: name: Release Audit Tool (RAT) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.8 - name: Audit licenses @@ -46,12 +50,12 @@ jobs: name: Markdown format runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-node@v3 + - uses: actions/checkout@v4 + - uses: actions/setup-node@v4 with: node-version: "14" - name: Prettier check run: | # if you encounter error, run the command below and commit the changes - npx prettier@2.3.2 --write {arrow,arrow-flight,dev,integration-testing,parquet}/**/*.md README.md CODE_OF_CONDUCT.md CONTRIBUTING.md + npx prettier@2.3.2 --write {arrow,arrow-flight,dev,arrow-integration-testing,parquet}/**/*.md README.md CODE_OF_CONDUCT.md CONTRIBUTING.md git diff --exit-code diff --git a/.github/workflows/dev_pr.yml b/.github/workflows/dev_pr.yml index 38bb39390097..0d60ae006796 100644 --- a/.github/workflows/dev_pr.yml +++ b/.github/workflows/dev_pr.yml @@ -17,6 +17,10 @@ name: dev_pr +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + # Trigger whenever a PR is changed (title as well as new / changed commits) on: pull_request_target: @@ -29,15 +33,18 @@ jobs: process: name: Process runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: write steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Assign GitHub labels if: | github.event_name == 'pull_request_target' && (github.event.action == 'opened' || github.event.action == 'synchronize') - uses: actions/labeler@v4.0.1 + uses: actions/labeler@v5.0.0 with: repo-token: ${{ secrets.GITHUB_TOKEN }} configuration-path: .github/workflows/dev_pr/labeler.yml diff --git a/.github/workflows/dev_pr/labeler.yml b/.github/workflows/dev_pr/labeler.yml index aadf9c377c64..cae015018eac 100644 --- a/.github/workflows/dev_pr/labeler.yml +++ b/.github/workflows/dev_pr/labeler.yml @@ -16,16 +16,40 @@ # under the License. arrow: - - arrow/**/* + - changed-files: + - any-glob-to-any-file: + - 'arrow-arith/**/*' + - 'arrow-array/**/*' + - 'arrow-buffer/**/*' + - 'arrow-cast/**/*' + - 'arrow-csv/**/*' + - 'arrow-data/**/*' + - 'arrow-flight/**/*' + - 'arrow-integration-test/**/*' + - 'arrow-integration-testing/**/*' + - 'arrow-ipc/**/*' + - 'arrow-json/**/*' + - 'arrow-avro/**/*' + - 'arrow-ord/**/*' + - 'arrow-row/**/*' + - 'arrow-schema/**/*' + - 'arrow-select/**/*' + - 'arrow-string/**/*' + - 'arrow/**/*' arrow-flight: - - arrow-flight/**/* + - changed-files: + - any-glob-to-any-file: + - 'arrow-flight/**/*' parquet: - - parquet/**/* + - changed-files: + - any-glob-to-any-file: [ 'parquet/**/*' ] parquet-derive: - - parquet_derive/**/* + - changed-files: + - any-glob-to-any-file: [ 'parquet_derive/**/*' ] object-store: - - object_store/**/* + - changed-files: + - any-glob-to-any-file: [ 'object_store/**/*' ] diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 5e82d76febe6..721260892402 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -17,6 +17,10 @@ name: docs +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + # trigger for all PRs and changes to master on: push: @@ -37,19 +41,57 @@ jobs: container: image: ${{ matrix.arch }}/rust env: - RUSTDOCFLAGS: "-Dwarnings" + RUSTDOCFLAGS: "-Dwarnings --enable-index-page -Zunstable-options" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Install python dev run: | apt update - apt install -y libpython3.9-dev + apt install -y libpython3.11-dev - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: ${{ matrix.rust }} - name: Run cargo doc + run: cargo doc --document-private-items --no-deps --workspace --all-features + - name: Fix file permissions + shell: sh + run: | + chmod -c -R +rX "target/doc" | + while read line; do + echo "::warning title=Invalid file permissions automatically fixed::$line" + done + - name: Upload artifacts + uses: actions/upload-pages-artifact@v2 + with: + name: crate-docs + path: target/doc + + deploy: + # Only deploy if a push to master + if: github.ref_name == 'master' && github.event_name == 'push' + needs: docs + permissions: + contents: write + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Download crate docs + uses: actions/download-artifact@v3 + with: + name: crate-docs + path: website/build + - name: Prepare website run: | - cargo doc --document-private-items --no-deps --workspace --all-features + tar -xf website/build/artifact.tar -C website/build + rm website/build/artifact.tar + cp .asf.yaml ./website/build/.asf.yaml + - name: Deploy to gh-pages + uses: peaceiris/actions-gh-pages@v3.9.3 + if: github.event_name == 'push' && github.ref_name == 'master' + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + publish_dir: website/build + publish_branch: asf-site diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 10a8e30212a9..1604a7be4372 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -17,6 +17,10 @@ name: integration +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + # trigger for all PRs that touch certain files and changes to master on: push: @@ -24,10 +28,24 @@ on: - master pull_request: paths: - - arrow/** - - arrow-pyarrow-integration-testing/** - - integration-testing/** - .github/** + - arrow-array/** + - arrow-buffer/** + - arrow-cast/** + - arrow-csv/** + - arrow-data/** + - arrow-integration-test/** + - arrow-integration-testing/** + - arrow-ipc/** + - arrow-json/** + - arrow-avro/** + - arrow-ord/** + - arrow-pyarrow-integration-testing/** + - arrow-schema/** + - arrow-select/** + - arrow-sort/** + - arrow-string/** + - arrow/** jobs: @@ -39,7 +57,15 @@ jobs: env: ARROW_USE_CCACHE: OFF ARROW_CPP_EXE_PATH: /build/cpp/debug + ARROW_RUST_EXE_PATH: /build/rust/debug BUILD_DOCS_CPP: OFF + ARROW_INTEGRATION_CPP: ON + ARROW_INTEGRATION_CSHARP: ON + ARROW_INTEGRATION_GO: ON + ARROW_INTEGRATION_JAVA: ON + ARROW_INTEGRATION_JS: ON + # https://github.com/apache/arrow/pull/38403/files#r1371281630 + ARCHERY_INTEGRATION_WITH_RUST: ON # These are necessary because the github runner overrides $HOME # https://github.com/actions/runner/issues/863 RUSTUP_HOME: /root/.rustup @@ -59,49 +85,20 @@ jobs: - name: Check cmake run: which cmake - name: Checkout Arrow - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: repository: apache/arrow submodules: true fetch-depth: 0 - name: Checkout Arrow Rust - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: path: rust fetch-depth: 0 - - name: Make build directory - run: mkdir /build - - name: Build Rust - run: conda run --no-capture-output ci/scripts/rust_build.sh $PWD /build - - name: Build C++ - run: conda run --no-capture-output ci/scripts/cpp_build.sh $PWD /build - - name: Build C# - run: conda run --no-capture-output ci/scripts/csharp_build.sh $PWD /build - - name: Build Go - run: conda run --no-capture-output ci/scripts/go_build.sh $PWD - - name: Build Java - run: conda run --no-capture-output ci/scripts/java_build.sh $PWD /build - # Temporarily disable JS https://issues.apache.org/jira/browse/ARROW-17410 - # - name: Build JS - # run: conda run --no-capture-output ci/scripts/js_build.sh $PWD /build - - name: Install archery - run: conda run --no-capture-output pip install -e dev/archery - - name: Run integration tests - run: | - conda run --no-capture-output archery integration \ - --run-flight \ - --with-cpp=1 \ - --with-csharp=1 \ - --with-java=1 \ - --with-js=0 \ - --with-go=1 \ - --with-rust=1 \ - --gold-dirs=testing/data/arrow-ipc-stream/integration/0.14.1 \ - --gold-dirs=testing/data/arrow-ipc-stream/integration/0.17.1 \ - --gold-dirs=testing/data/arrow-ipc-stream/integration/1.0.0-bigendian \ - --gold-dirs=testing/data/arrow-ipc-stream/integration/1.0.0-littleendian \ - --gold-dirs=testing/data/arrow-ipc-stream/integration/2.0.0-compression \ - --gold-dirs=testing/data/arrow-ipc-stream/integration/4.0.0-shareddict + - name: Build + run: conda run --no-capture-output ci/scripts/integration_arrow_build.sh $PWD /build + - name: Run + run: conda run --no-capture-output ci/scripts/integration_arrow.sh $PWD /build # test FFI against the C-Data interface exposed by pyarrow pyarrow-integration-test: @@ -110,8 +107,10 @@ jobs: strategy: matrix: rust: [ stable ] + # PyArrow 13 was the last version prior to introduction to Arrow PyCapsules + pyarrow: [ "13", "14" ] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain @@ -130,23 +129,23 @@ jobs: path: /home/runner/target # this key is not equal because maturin uses different compilation flags. key: ${{ runner.os }}-${{ matrix.arch }}-target-maturin-cache-${{ matrix.rust }}- - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: - python-version: '3.7' + python-version: '3.8' - name: Upgrade pip and setuptools run: pip install --upgrade pip setuptools wheel virtualenv - name: Create virtualenv and install dependencies run: | virtualenv venv source venv/bin/activate - pip install maturin toml pytest pytz pyarrow>=5.0 + pip install maturin toml pytest pytz pyarrow==${{ matrix.pyarrow }} + - name: Run Rust tests + run: | + source venv/bin/activate + cargo test -p arrow --test pyarrow --features pyarrow - name: Run tests - env: - CARGO_HOME: "/home/runner/.cargo" - CARGO_TARGET_DIR: "/home/runner/target" run: | source venv/bin/activate - pushd arrow-pyarrow-integration-testing + cd arrow-pyarrow-integration-testing maturin develop pytest -v . - popd diff --git a/.github/workflows/miri.sh b/.github/workflows/miri.sh index 56da5c5c5d3e..5057c876b952 100755 --- a/.github/workflows/miri.sh +++ b/.github/workflows/miri.sh @@ -5,13 +5,14 @@ # Must be run with nightly rust for example # rustup default nightly - -# stacked borrows checking uses too much memory to run successfully in github actions -# re-enable if the CI is migrated to something more powerful (https://github.com/apache/arrow-rs/issues/1833) -# see also https://github.com/rust-lang/miri/issues/1367 -export MIRIFLAGS="-Zmiri-disable-isolation -Zmiri-disable-stacked-borrows" +export MIRIFLAGS="-Zmiri-disable-isolation" cargo miri setup cargo clean echo "Starting Arrow MIRI run..." -cargo miri test -p arrow -- --skip csv --skip ipc --skip json +cargo miri test -p arrow-buffer +cargo miri test -p arrow-data --features ffi +cargo miri test -p arrow-schema --features ffi +cargo miri test -p arrow-array +cargo miri test -p arrow-arith +cargo miri test -p arrow-ord diff --git a/.github/workflows/miri.yaml b/.github/workflows/miri.yaml index b4669bbcccc0..19b432121b6f 100644 --- a/.github/workflows/miri.yaml +++ b/.github/workflows/miri.yaml @@ -17,6 +17,10 @@ name: miri +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + # trigger for all PRs that touch certain files and changes to master on: push: @@ -24,15 +28,26 @@ on: - master pull_request: paths: - - arrow/** - .github/** + - arrow-array/** + - arrow-buffer/** + - arrow-cast/** + - arrow-csv/** + - arrow-data/** + - arrow-ipc/** + - arrow-json/** + - arrow-avro/** + - arrow-schema/** + - arrow-select/** + - arrow-string/** + - arrow/** jobs: miri-checks: name: MIRI runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain @@ -44,5 +59,4 @@ jobs: env: RUST_BACKTRACE: full RUST_LOG: "trace" - run: | - bash .github/workflows/miri.sh + run: bash .github/workflows/miri.sh diff --git a/.github/workflows/object_store.yml b/.github/workflows/object_store.yml index 6996aa706636..ecffa29b067c 100644 --- a/.github/workflows/object_store.yml +++ b/.github/workflows/object_store.yml @@ -19,6 +19,10 @@ # tests for `object_store` crate name: object_store +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + # trigger for all PRs that touch certain files and changes to master on: push: @@ -35,29 +39,51 @@ jobs: runs-on: ubuntu-latest container: image: amd64/rust + defaults: + run: + working-directory: object_store steps: - - uses: actions/checkout@v3 - - name: Setup Rust toolchain with clippy - run: | - rustup toolchain install stable - rustup default stable - rustup component add clippy + - uses: actions/checkout@v4 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + - name: Setup Clippy + run: rustup component add clippy # Run different tests for the library on its own as well as # all targets to ensure that it still works in the absence of # features that might be enabled by dev-dependencies of other # targets. - name: Run clippy with default features - run: cargo clippy -p object_store -- -D warnings + run: cargo clippy -- -D warnings - name: Run clippy with aws feature - run: cargo clippy -p object_store --features aws -- -D warnings + run: cargo clippy --features aws -- -D warnings - name: Run clippy with gcp feature - run: cargo clippy -p object_store --features gcp -- -D warnings + run: cargo clippy --features gcp -- -D warnings - name: Run clippy with azure feature - run: cargo clippy -p object_store --features azure -- -D warnings + run: cargo clippy --features azure -- -D warnings + - name: Run clippy with http feature + run: cargo clippy --features http -- -D warnings - name: Run clippy with all features - run: cargo clippy -p object_store --all-features -- -D warnings + run: cargo clippy --all-features -- -D warnings - name: Run clippy with all features and all targets - run: cargo clippy -p object_store --all-features --all-targets -- -D warnings + run: cargo clippy --all-features --all-targets -- -D warnings + + # test doc links still work + # + # Note that since object_store is not part of the main workspace, + # this needs a separate docs job as it is not covered by + # `cargo doc --workspace` + docs: + name: Rustdocs + runs-on: ubuntu-latest + defaults: + run: + working-directory: object_store + env: + RUSTDOCFLAGS: "-Dwarnings" + steps: + - uses: actions/checkout@v4 + - name: Run cargo doc + run: cargo doc --document-private-items --no-deps --workspace --all-features # test the crate # This runs outside a container to workaround lack of support for passing arguments @@ -65,39 +91,49 @@ jobs: linux-test: name: Emulator Tests runs-on: ubuntu-latest + defaults: + run: + working-directory: object_store env: # Disable full debug symbol generation to speed up CI build and keep memory down # "1" means line tables only, which is useful for panic tracebacks. RUSTFLAGS: "-C debuginfo=1" - # https://github.com/rust-lang/cargo/issues/10280 - CARGO_NET_GIT_FETCH_WITH_CLI: "true" RUST_BACKTRACE: "1" # Run integration tests TEST_INTEGRATION: 1 EC2_METADATA_ENDPOINT: http://localhost:1338 - AZURE_USE_EMULATOR: "1" + AZURE_CONTAINER_NAME: test-bucket + AZURE_STORAGE_USE_EMULATOR: "1" AZURITE_BLOB_STORAGE_URL: "http://localhost:10000" AZURITE_QUEUE_STORAGE_URL: "http://localhost:10001" + AWS_BUCKET: test-bucket + AWS_DEFAULT_REGION: "us-east-1" + AWS_ACCESS_KEY_ID: test + AWS_SECRET_ACCESS_KEY: test + AWS_ENDPOINT: http://localhost:4566 + AWS_ALLOW_HTTP: true + HTTP_URL: "http://localhost:8080" + GOOGLE_BUCKET: test-bucket GOOGLE_SERVICE_ACCOUNT: "/tmp/gcs.json" - OBJECT_STORE_BUCKET: test-bucket steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Configure Fake GCS Server (GCP emulation) + # Custom image - see fsouza/fake-gcs-server#1164 run: | - docker run -d -p 4443:4443 fsouza/fake-gcs-server -scheme http + docker run -d -p 4443:4443 tustvold/fake-gcs-server -scheme http -backend memory -public-host localhost:4443 + # Give the container a moment to start up prior to configuring it + sleep 1 curl -v -X POST --data-binary '{"name":"test-bucket"}' -H "Content-Type: application/json" "http://localhost:4443/storage/v1/b" - echo '{"gcs_base_url": "http://localhost:4443", "disable_oauth": true, "client_email": "", "private_key": ""}' > "$GOOGLE_SERVICE_ACCOUNT" + echo '{"gcs_base_url": "http://localhost:4443", "disable_oauth": true, "client_email": "", "private_key": "", "private_key_id": ""}' > "$GOOGLE_SERVICE_ACCOUNT" + + - name: Setup WebDav + run: docker run -d -p 8080:80 rclone/rclone serve webdav /data --addr :80 - name: Setup LocalStack (AWS emulation) - env: - AWS_DEFAULT_REGION: "us-east-1" - AWS_ACCESS_KEY_ID: test - AWS_SECRET_ACCESS_KEY: test - AWS_ENDPOINT: http://localhost:4566 run: | - docker run -d -p 4566:4566 localstack/localstack:0.14.4 + docker run -d -p 4566:4566 localstack/localstack:3.0.1 docker run -d -p 1338:1338 amazon/amazon-ec2-metadata-mock:v1.9.2 --imdsv2 aws --endpoint-url=http://localhost:4566 s3 mb s3://test-bucket @@ -114,11 +150,26 @@ jobs: rustup default stable - name: Run object_store tests - env: - OBJECT_STORE_AWS_DEFAULT_REGION: "us-east-1" - OBJECT_STORE_AWS_ACCESS_KEY_ID: test - OBJECT_STORE_AWS_SECRET_ACCESS_KEY: test - OBJECT_STORE_AWS_ENDPOINT: http://localhost:4566 - run: | - # run tests - cargo test -p object_store --features=aws,azure,gcp + run: cargo test --features=aws,azure,gcp,http + + # test the object_store crate builds against wasm32 in stable rust + wasm32-build: + name: Build wasm32 + runs-on: ubuntu-latest + container: + image: amd64/rust + defaults: + run: + working-directory: object_store + steps: + - uses: actions/checkout@v4 + with: + submodules: true + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + target: wasm32-unknown-unknown,wasm32-wasi + - name: Build wasm32-unknown-unknown + run: cargo build --target wasm32-unknown-unknown + - name: Build wasm32-wasi + run: cargo build --target wasm32-wasi diff --git a/.github/workflows/parquet.yml b/.github/workflows/parquet.yml index 42cb06bb0a86..a4e654892662 100644 --- a/.github/workflows/parquet.yml +++ b/.github/workflows/parquet.yml @@ -19,6 +19,9 @@ # tests for parquet crate name: "parquet" +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true # trigger for all PRs that touch certain files and changes to master on: @@ -28,6 +31,16 @@ on: pull_request: paths: - arrow/** + - arrow-array/** + - arrow-buffer/** + - arrow-cast/** + - arrow-data/** + - arrow-schema/** + - arrow-select/** + - arrow-ipc/** + - arrow-csv/** + - arrow-json/** + - arrow-avro/** - parquet/** - .github/** @@ -38,25 +51,22 @@ jobs: runs-on: ubuntu-latest container: image: amd64/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. - RUSTFLAGS: "-C debuginfo=1" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: stable - name: Test - run: | - cargo test -p parquet + run: cargo test -p parquet - name: Test --all-features + run: cargo test -p parquet --all-features + - name: Run examples run: | - cargo test -p parquet --all-features - + # Test parquet examples + cargo run -p parquet --example read_parquet + cargo run -p parquet --example async_read_parquet --features="async" + cargo run -p parquet --example read_with_rowgroup --features="async" # test compilation linux-features: @@ -64,18 +74,12 @@ jobs: runs-on: ubuntu-latest container: image: amd64/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. - RUSTFLAGS: "-C debuginfo=1" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: stable # Run different tests for the library on its own as well as # all targets to ensure that it still works in the absence of @@ -88,29 +92,78 @@ jobs: # 3. compiles with just arrow feature # 3. compiles with all features - name: Check compilation - run: | - cargo check -p parquet + run: cargo check -p parquet - name: Check compilation --no-default-features - run: | - cargo check -p parquet --no-default-features + run: cargo check -p parquet --no-default-features - name: Check compilation --no-default-features --features arrow - run: | - cargo check -p parquet --no-default-features --features arrow + run: cargo check -p parquet --no-default-features --features arrow - name: Check compilation --no-default-features --all-features - run: | - cargo check -p parquet --all-features + run: cargo check -p parquet --all-features - name: Check compilation --all-targets - run: | - cargo check -p parquet --all-targets + run: cargo check -p parquet --all-targets - name: Check compilation --all-targets --no-default-features - run: | - cargo check -p parquet --all-targets --no-default-features + run: cargo check -p parquet --all-targets --no-default-features - name: Check compilation --all-targets --no-default-features --features arrow - run: | - cargo check -p parquet --all-targets --no-default-features --features arrow + run: cargo check -p parquet --all-targets --no-default-features --features arrow - name: Check compilation --all-targets --all-features + run: cargo check -p parquet --all-targets --all-features + - name: Check compilation --all-targets --no-default-features --features json + run: cargo check -p parquet --all-targets --no-default-features --features json + + # test the parquet crate builds against wasm32 in stable rust + wasm32-build: + name: Build wasm32 + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v4 + with: + submodules: true + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + target: wasm32-unknown-unknown,wasm32-wasi + - name: Install clang # Needed for zlib compilation + run: apt-get update && apt-get install -y clang gcc-multilib + - name: Build wasm32-unknown-unknown + run: cargo build -p parquet --target wasm32-unknown-unknown + - name: Build wasm32-wasi + run: cargo build -p parquet --target wasm32-wasi + + pyspark-integration-test: + name: PySpark Integration Test + runs-on: ubuntu-latest + strategy: + matrix: + rust: [ stable ] + steps: + - uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + cache: "pip" + - name: Install Python dependencies run: | - cargo check -p parquet --all-targets --all-features + cd parquet/pytest + pip install -r requirements.txt + - name: Black check the test files + run: | + cd parquet/pytest + black --check *.py --verbose + - name: Setup Rust toolchain + run: | + rustup toolchain install ${{ matrix.rust }} + rustup default ${{ matrix.rust }} + - name: Install binary for checking + run: | + cargo install --path parquet --bin parquet-show-bloom-filter --features=cli + cargo install --path parquet --bin parquet-fromcsv --features=arrow,cli + - name: Run pytest + run: | + cd parquet/pytest + pytest -v clippy: name: Clippy @@ -118,14 +171,10 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: stable - name: Setup Clippy - run: | - rustup component add clippy + run: rustup component add clippy - name: Run clippy - run: | - cargo clippy -p parquet --all-targets --all-features -- -D warnings + run: cargo clippy -p parquet --all-targets --all-features -- -D warnings diff --git a/.github/workflows/parquet_derive.yml b/.github/workflows/parquet_derive.yml index bd70fc30d1c5..d8b02f73a8aa 100644 --- a/.github/workflows/parquet_derive.yml +++ b/.github/workflows/parquet_derive.yml @@ -19,6 +19,9 @@ # tests for parquet_derive crate name: parquet_derive +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true # trigger for all PRs that touch certain files and changes to master on: @@ -39,21 +42,14 @@ jobs: runs-on: ubuntu-latest container: image: amd64/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. - RUSTFLAGS: "-C debuginfo=1" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: stable - name: Test - run: | - cargo test -p parquet_derive + run: cargo test -p parquet_derive clippy: name: Clippy @@ -61,14 +57,10 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: stable - name: Setup Clippy - run: | - rustup component add clippy + run: rustup component add clippy - name: Run clippy - run: | - cargo clippy -p parquet_derive --all-features -- -D warnings + run: cargo clippy -p parquet_derive --all-features -- -D warnings diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c04d5643b49a..9c4b28b691b7 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -18,6 +18,10 @@ # workspace wide tests name: rust +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + # trigger for all PRs and changes to master on: push: @@ -33,12 +37,11 @@ jobs: name: Test on Mac runs-on: macos-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Install protoc with brew - run: | - brew install protobuf + run: brew install protobuf - name: Setup Rust toolchain run: | rustup toolchain install stable --no-self-update @@ -57,7 +60,7 @@ jobs: name: Test on Windows runs-on: windows-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Install protobuf compiler in /d/protoc @@ -90,11 +93,41 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 - - name: Setup toolchain - run: | - rustup toolchain install stable - rustup default stable - rustup component add rustfmt - - name: Run + - uses: actions/checkout@v4 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + - name: Setup rustfmt + run: rustup component add rustfmt + - name: Format arrow + run: cargo fmt --all -- --check + - name: Format object_store + working-directory: object_store run: cargo fmt --all -- --check + + msrv: + name: Verify MSRV + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v4 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + - name: Install cargo-msrv + run: cargo install cargo-msrv + - name: Check arrow + working-directory: arrow + run: cargo msrv verify + - name: Check parquet + working-directory: parquet + run: cargo msrv verify + - name: Check arrow-flight + working-directory: arrow-flight + run: cargo msrv verify + - name: Downgrade object_store dependencies + working-directory: object_store + # Necessary because 1.30.0 updates MSRV to 1.63 + run: cargo update -p tokio --precise 1.29.1 + - name: Check object_store + working-directory: object_store + run: cargo msrv verify diff --git a/.gitignore b/.gitignore index 2a21776aa545..52ad19cb077d 100644 --- a/.gitignore +++ b/.gitignore @@ -20,7 +20,7 @@ __blobstorage__ # .bak files *.bak - +*.bak2 # OS-specific .gitignores # Mac .gitignore @@ -92,3 +92,6 @@ $RECYCLE.BIN/ # Windows shortcuts *.lnk +# Python virtual env in parquet crate +parquet/pytest/venv/ +__pycache__/ diff --git a/CHANGELOG-old.md b/CHANGELOG-old.md index 70322b5cfd1d..336adff990bd 100644 --- a/CHANGELOG-old.md +++ b/CHANGELOG-old.md @@ -17,9 +17,2116 @@ under the License. --> - # Historical Changelog +## [48.0.0](https://github.com/apache/arrow-rs/tree/48.0.0) (2023-10-18) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/47.0.0...48.0.0) + +**Breaking changes:** + +- Evaluate null\_regex for string type in csv \(now such values will be parsed as `Null` rather than `""`\) [\#4942](https://github.com/apache/arrow-rs/pull/4942) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([haohuaijin](https://github.com/haohuaijin)) +- fix\(csv\)!: infer null for empty column. [\#4910](https://github.com/apache/arrow-rs/pull/4910) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kskalski](https://github.com/kskalski)) +- feat: log headers/trailers in flight CLI \(+ minor fixes\) [\#4898](https://github.com/apache/arrow-rs/pull/4898) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- fix\(arrow-json\)!: include null fields in schema inference with a type of Null [\#4894](https://github.com/apache/arrow-rs/pull/4894) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kskalski](https://github.com/kskalski)) +- Mark OnCloseRowGroup Send [\#4893](https://github.com/apache/arrow-rs/pull/4893) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([devinjdangelo](https://github.com/devinjdangelo)) +- Specialize Thrift Decoding \(~40% Faster\) \(\#4891\) [\#4892](https://github.com/apache/arrow-rs/pull/4892) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Make ArrowRowGroupWriter Public and SerializedRowGroupWriter Send [\#4850](https://github.com/apache/arrow-rs/pull/4850) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([devinjdangelo](https://github.com/devinjdangelo)) + +**Implemented enhancements:** + +- Allow schema fields to merge with `Null` datatype [\#4901](https://github.com/apache/arrow-rs/issues/4901) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add option to FlightDataEncoder to always send dictionaries [\#4895](https://github.com/apache/arrow-rs/issues/4895) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Rework Thrift Encoding / Decoding of Parquet Metadata [\#4891](https://github.com/apache/arrow-rs/issues/4891) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Plans for supporting Extension Array to support Fixed shape tensor Array [\#4890](https://github.com/apache/arrow-rs/issues/4890) +- Implement Take for UnionArray [\#4882](https://github.com/apache/arrow-rs/issues/4882) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Check precision overflow for casting floating to decimal [\#4865](https://github.com/apache/arrow-rs/issues/4865) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Replace lexical [\#4774](https://github.com/apache/arrow-rs/issues/4774) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add read access to settings in `csv::WriterBuilder` [\#4735](https://github.com/apache/arrow-rs/issues/4735) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve the performance of "DictionaryValue" row encoding [\#4712](https://github.com/apache/arrow-rs/issues/4712) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Fixed bugs:** + +- Should we make blank values and empty string to `None` in csv? [\#4939](https://github.com/apache/arrow-rs/issues/4939) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[FlightSQL\] SubstraitPlan structure is not exported [\#4932](https://github.com/apache/arrow-rs/issues/4932) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Loading page index breaks skipping of pages with nested types [\#4921](https://github.com/apache/arrow-rs/issues/4921) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- CSV schema inference assumes `Utf8` for empty columns [\#4903](https://github.com/apache/arrow-rs/issues/4903) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet: Field Ids are not read from a Parquet file without serialized arrow schema [\#4877](https://github.com/apache/arrow-rs/issues/4877) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- make\_primitive\_scalar function loses DataType Internal information [\#4851](https://github.com/apache/arrow-rs/issues/4851) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- StructBuilder doesn't handle nulls correctly for empty structs [\#4842](https://github.com/apache/arrow-rs/issues/4842) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `NullArray::is_null()` returns `false` incorrectly [\#4835](https://github.com/apache/arrow-rs/issues/4835) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- cast\_string\_to\_decimal should check precision overflow [\#4829](https://github.com/apache/arrow-rs/issues/4829) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Null fields are omitted by `infer_json_schema_from_seekable` [\#4814](https://github.com/apache/arrow-rs/issues/4814) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Support for reading JSON Array to Arrow [\#4905](https://github.com/apache/arrow-rs/issues/4905) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Assume Pages Delimit Records When Offset Index Loaded \(\#4921\) [\#4943](https://github.com/apache/arrow-rs/pull/4943) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Update pyo3 requirement from 0.19 to 0.20 [\#4941](https://github.com/apache/arrow-rs/pull/4941) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([crepererum](https://github.com/crepererum)) +- Add `FileWriter` schema getter [\#4940](https://github.com/apache/arrow-rs/pull/4940) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([haixuanTao](https://github.com/haixuanTao)) +- feat: support parsing for parquet writer option [\#4938](https://github.com/apache/arrow-rs/pull/4938) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([fansehep](https://github.com/fansehep)) +- Export `SubstraitPlan` structure in arrow\_flight::sql \(\#4932\) [\#4933](https://github.com/apache/arrow-rs/pull/4933) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([amartins23](https://github.com/amartins23)) +- Update zstd requirement from 0.12.0 to 0.13.0 [\#4923](https://github.com/apache/arrow-rs/pull/4923) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- feat: add method for async read bloom filter [\#4917](https://github.com/apache/arrow-rs/pull/4917) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([hengfeiyang](https://github.com/hengfeiyang)) +- Minor: Clarify rationale for `FlightDataEncoder` API, add examples [\#4916](https://github.com/apache/arrow-rs/pull/4916) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Update regex-syntax requirement from 0.7.1 to 0.8.0 [\#4914](https://github.com/apache/arrow-rs/pull/4914) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- feat: document & streamline flight SQL CLI [\#4912](https://github.com/apache/arrow-rs/pull/4912) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- Support Arbitrary JSON values in JSON Reader \(\#4905\) [\#4911](https://github.com/apache/arrow-rs/pull/4911) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Cleanup CSV WriterBuilder, Default to AutoSI Second Precision \(\#4735\) [\#4909](https://github.com/apache/arrow-rs/pull/4909) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update proc-macro2 requirement from =1.0.68 to =1.0.69 [\#4907](https://github.com/apache/arrow-rs/pull/4907) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- chore: add csv example [\#4904](https://github.com/apache/arrow-rs/pull/4904) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([fansehep](https://github.com/fansehep)) +- feat\(schema\): allow null fields to be merged with other datatypes [\#4902](https://github.com/apache/arrow-rs/pull/4902) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kskalski](https://github.com/kskalski)) +- Update proc-macro2 requirement from =1.0.67 to =1.0.68 [\#4900](https://github.com/apache/arrow-rs/pull/4900) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add option to `FlightDataEncoder` to always resend batch dictionaries [\#4896](https://github.com/apache/arrow-rs/pull/4896) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alexwilcoxson-rel](https://github.com/alexwilcoxson-rel)) +- Fix integration tests [\#4889](https://github.com/apache/arrow-rs/pull/4889) ([tustvold](https://github.com/tustvold)) +- Support Parsing Avro File Headers [\#4888](https://github.com/apache/arrow-rs/pull/4888) ([tustvold](https://github.com/tustvold)) +- Support parquet bloom filter length [\#4885](https://github.com/apache/arrow-rs/pull/4885) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([letian-jiang](https://github.com/letian-jiang)) +- Replace lz4 with lz4\_flex Allowing Compilation for WASM [\#4884](https://github.com/apache/arrow-rs/pull/4884) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Implement Take for UnionArray [\#4883](https://github.com/apache/arrow-rs/pull/4883) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([avantgardnerio](https://github.com/avantgardnerio)) +- Update tonic-build requirement from =0.10.1 to =0.10.2 [\#4881](https://github.com/apache/arrow-rs/pull/4881) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- parquet: Read field IDs from Parquet Schema [\#4878](https://github.com/apache/arrow-rs/pull/4878) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Samrose-Ahmed](https://github.com/Samrose-Ahmed)) +- feat: improve flight CLI error handling [\#4873](https://github.com/apache/arrow-rs/pull/4873) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- Support Encoding Parquet Columns in Parallel [\#4871](https://github.com/apache/arrow-rs/pull/4871) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Check precision overflow for casting floating to decimal [\#4866](https://github.com/apache/arrow-rs/pull/4866) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Make align\_buffers as public API [\#4863](https://github.com/apache/arrow-rs/pull/4863) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Enable new integration tests \(\#4828\) [\#4862](https://github.com/apache/arrow-rs/pull/4862) ([tustvold](https://github.com/tustvold)) +- Faster Serde Integration \(~80% faster\) [\#4861](https://github.com/apache/arrow-rs/pull/4861) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- fix: make\_primitive\_scalar bug [\#4852](https://github.com/apache/arrow-rs/pull/4852) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([JasonLi-cn](https://github.com/JasonLi-cn)) +- Update tonic-build requirement from =0.10.0 to =0.10.1 [\#4846](https://github.com/apache/arrow-rs/pull/4846) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Allow Constructing Non-Empty StructArray with no Fields \(\#4842\) [\#4845](https://github.com/apache/arrow-rs/pull/4845) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Refine documentation to `Array::is_null` [\#4838](https://github.com/apache/arrow-rs/pull/4838) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- fix: add missing precision overflow checking for `cast_string_to_decimal` [\#4830](https://github.com/apache/arrow-rs/pull/4830) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jonahgao](https://github.com/jonahgao)) +## [47.0.0](https://github.com/apache/arrow-rs/tree/47.0.0) (2023-09-19) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/46.0.0...47.0.0) + +**Breaking changes:** + +- Make FixedSizeBinaryArray value\_data return a reference [\#4820](https://github.com/apache/arrow-rs/issues/4820) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Update prost to v0.12.1 [\#4825](https://github.com/apache/arrow-rs/pull/4825) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- feat: FixedSizeBinaryArray::value\_data return reference [\#4821](https://github.com/apache/arrow-rs/pull/4821) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Stateless Row Encoding / Don't Preserve Dictionaries in `RowConverter` \(\#4811\) [\#4819](https://github.com/apache/arrow-rs/pull/4819) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- fix: entries field is non-nullable [\#4808](https://github.com/apache/arrow-rs/pull/4808) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Fix flight sql do put handling, add bind parameter support to FlightSQL cli client [\#4797](https://github.com/apache/arrow-rs/pull/4797) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([suremarc](https://github.com/suremarc)) +- Remove unused dyn\_cmp\_dict feature [\#4766](https://github.com/apache/arrow-rs/pull/4766) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add underlying `std::io::Error` to `IoError` and add `IpcError` variant [\#4726](https://github.com/apache/arrow-rs/pull/4726) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alexandreyc](https://github.com/alexandreyc)) + +**Implemented enhancements:** + +- Row Format Adapative Block Size [\#4812](https://github.com/apache/arrow-rs/issues/4812) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Stateless Row Conversion [\#4811](https://github.com/apache/arrow-rs/issues/4811) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Add option to specify custom null values for CSV reader [\#4794](https://github.com/apache/arrow-rs/issues/4794) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet::record::RowIter cannot be customized with batch\_size and defaults to 1024 [\#4782](https://github.com/apache/arrow-rs/issues/4782) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `DynScalar` abstraction \(something that makes it easy to create scalar `Datum`s\) [\#4781](https://github.com/apache/arrow-rs/issues/4781) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `Datum` is not exported as part of `arrow` \(it is only exported in `arrow_array`\) [\#4780](https://github.com/apache/arrow-rs/issues/4780) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `Scalar` is not exported as part of `arrow` \(it is only exported in `arrow_array`\) [\#4779](https://github.com/apache/arrow-rs/issues/4779) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support IntoPyArrow for impl RecordBatchReader [\#4730](https://github.com/apache/arrow-rs/issues/4730) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Datum Based String Kernels [\#4595](https://github.com/apache/arrow-rs/issues/4595) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Fixed bugs:** + +- MapArray::new\_from\_strings creates nullable entries field [\#4807](https://github.com/apache/arrow-rs/issues/4807) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- pyarrow module can't roundtrip tensor arrays [\#4805](https://github.com/apache/arrow-rs/issues/4805) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `concat_batches` errors with "schema mismatch" error when only metadata differs [\#4799](https://github.com/apache/arrow-rs/issues/4799) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- panic in `cmp` kernels with DictionaryArrays: `Option::unwrap()` on a `None` value' [\#4788](https://github.com/apache/arrow-rs/issues/4788) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- stream ffi panics if schema metadata values aren't valid utf8 [\#4750](https://github.com/apache/arrow-rs/issues/4750) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Regression: Incorrect Sorting of `*ListArray` in 46.0.0 [\#4746](https://github.com/apache/arrow-rs/issues/4746) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Row is no longer comparable after reuse [\#4741](https://github.com/apache/arrow-rs/issues/4741) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- DoPut FlightSQL handler inadvertently consumes schema at start of Request\\> [\#4658](https://github.com/apache/arrow-rs/issues/4658) +- Return error when converting schema [\#4752](https://github.com/apache/arrow-rs/pull/4752) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Implement PyArrowType for `Box` [\#4751](https://github.com/apache/arrow-rs/pull/4751) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) + +**Closed issues:** + +- Building arrow-rust for target wasm32-wasi falied to compile packed\_simd\_2 [\#4717](https://github.com/apache/arrow-rs/issues/4717) + +**Merged pull requests:** + +- Respect FormatOption::nulls for NullArray [\#4836](https://github.com/apache/arrow-rs/pull/4836) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix merge\_dictionary\_values in selection kernels [\#4833](https://github.com/apache/arrow-rs/pull/4833) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix like scalar null [\#4832](https://github.com/apache/arrow-rs/pull/4832) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- More chrono deprecations [\#4822](https://github.com/apache/arrow-rs/pull/4822) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Adaptive Row Block Size \(\#4812\) [\#4818](https://github.com/apache/arrow-rs/pull/4818) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update proc-macro2 requirement from =1.0.66 to =1.0.67 [\#4816](https://github.com/apache/arrow-rs/pull/4816) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Do not check schema for equality in concat\_batches [\#4815](https://github.com/apache/arrow-rs/pull/4815) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- fix: export record batch through stream [\#4806](https://github.com/apache/arrow-rs/pull/4806) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Improve CSV Reader Benchmark Coverage of Small Primitives [\#4803](https://github.com/apache/arrow-rs/pull/4803) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- csv: Add option to specify custom null values [\#4795](https://github.com/apache/arrow-rs/pull/4795) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([vrongmeal](https://github.com/vrongmeal)) +- Expand docstring and add example to `Scalar` [\#4793](https://github.com/apache/arrow-rs/pull/4793) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Re-export array crate root \(\#4780\) \(\#4779\) [\#4791](https://github.com/apache/arrow-rs/pull/4791) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix DictionaryArray::normalized\_keys \(\#4788\) [\#4789](https://github.com/apache/arrow-rs/pull/4789) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Allow custom tree builder for parquet::record::RowIter [\#4783](https://github.com/apache/arrow-rs/pull/4783) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([YuraKotov](https://github.com/YuraKotov)) +- Bump actions/checkout from 3 to 4 [\#4767](https://github.com/apache/arrow-rs/pull/4767) ([dependabot[bot]](https://github.com/apps/dependabot)) +- fix: avoid panic if offset index not exists. [\#4761](https://github.com/apache/arrow-rs/pull/4761) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([RinChanNOWWW](https://github.com/RinChanNOWWW)) +- Relax constraints on PyArrowType [\#4757](https://github.com/apache/arrow-rs/pull/4757) ([tustvold](https://github.com/tustvold)) +- Chrono deprecations [\#4748](https://github.com/apache/arrow-rs/pull/4748) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix List Sorting, Revert Removal of Rank Kernels [\#4747](https://github.com/apache/arrow-rs/pull/4747) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Clear row buffer before reuse [\#4742](https://github.com/apache/arrow-rs/pull/4742) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([yjshen](https://github.com/yjshen)) +- Datum based like kernels \(\#4595\) [\#4732](https://github.com/apache/arrow-rs/pull/4732) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- feat: expose DoGet response headers & trailers [\#4727](https://github.com/apache/arrow-rs/pull/4727) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- Cleanup length and bit\_length kernels [\#4718](https://github.com/apache/arrow-rs/pull/4718) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +## [46.0.0](https://github.com/apache/arrow-rs/tree/46.0.0) (2023-08-21) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/45.0.0...46.0.0) + +**Breaking changes:** + +- API improvement: `batches_to_flight_data` forces clone [\#4656](https://github.com/apache/arrow-rs/issues/4656) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add AnyDictionary Abstraction and Take ArrayRef in DictionaryArray::with\_values [\#4707](https://github.com/apache/arrow-rs/pull/4707) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Cleanup parquet type builders [\#4706](https://github.com/apache/arrow-rs/pull/4706) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Take kernel dyn Array [\#4705](https://github.com/apache/arrow-rs/pull/4705) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Improve ergonomics of Scalar [\#4704](https://github.com/apache/arrow-rs/pull/4704) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Datum based comparison kernels \(\#4596\) [\#4701](https://github.com/apache/arrow-rs/pull/4701) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Improve `Array` Logical Nullability [\#4691](https://github.com/apache/arrow-rs/pull/4691) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Validate ArrayData Buffer Alignment and Automatically Align IPC buffers \(\#4255\) [\#4681](https://github.com/apache/arrow-rs/pull/4681) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- More intuitive bool-to-string casting [\#4666](https://github.com/apache/arrow-rs/pull/4666) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([fsdvh](https://github.com/fsdvh)) +- enhancement: batches\_to\_flight\_data use a schema ref as param. [\#4665](https://github.com/apache/arrow-rs/pull/4665) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([jackwener](https://github.com/jackwener)) +- fix: from\_thrift avoid panic when stats in invalid. [\#4642](https://github.com/apache/arrow-rs/pull/4642) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jackwener](https://github.com/jackwener)) +- bug: Add some missing field in row group metadata: ordinal, total co… [\#4636](https://github.com/apache/arrow-rs/pull/4636) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([liurenjie1024](https://github.com/liurenjie1024)) +- Remove deprecated limit kernel [\#4597](https://github.com/apache/arrow-rs/pull/4597) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- parquet: support setting the field\_id with an ArrowWriter [\#4702](https://github.com/apache/arrow-rs/issues/4702) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support references in i256 arithmetic ops [\#4694](https://github.com/apache/arrow-rs/issues/4694) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Precision-Loss Decimal Arithmetic [\#4664](https://github.com/apache/arrow-rs/issues/4664) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Faster i256 Division [\#4663](https://github.com/apache/arrow-rs/issues/4663) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `concat_batches` for 0 columns [\#4661](https://github.com/apache/arrow-rs/issues/4661) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `filter_record_batch` should support filtering record batch without columns [\#4647](https://github.com/apache/arrow-rs/issues/4647) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve speed of `lexicographical_partition_ranges` [\#4614](https://github.com/apache/arrow-rs/issues/4614) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- object\_store: multipart ranges for HTTP [\#4612](https://github.com/apache/arrow-rs/issues/4612) +- Add Rank Function [\#4606](https://github.com/apache/arrow-rs/issues/4606) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Datum Based Comparison Kernels [\#4596](https://github.com/apache/arrow-rs/issues/4596) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Convenience method to create `DataType::List` correctly [\#4544](https://github.com/apache/arrow-rs/issues/4544) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Remove Deprecated Arithmetic Kernels [\#4481](https://github.com/apache/arrow-rs/issues/4481) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Equality kernel where null==null gives true [\#4438](https://github.com/apache/arrow-rs/issues/4438) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Parquet ArrowWriter Ignores Nulls in Dictionary Values [\#4690](https://github.com/apache/arrow-rs/issues/4690) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Schema Nullability Validation Fails to Account for Dictionary Nulls [\#4689](https://github.com/apache/arrow-rs/issues/4689) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Comparison Kernels Ignore Nulls in Dictionary Values [\#4688](https://github.com/apache/arrow-rs/issues/4688) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Casting List to String Ignores Format Options [\#4669](https://github.com/apache/arrow-rs/issues/4669) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Double free in C Stream Interface [\#4659](https://github.com/apache/arrow-rs/issues/4659) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- CI Failing On Packed SIMD [\#4651](https://github.com/apache/arrow-rs/issues/4651) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `RowInterner::size()` much too low for high cardinality dictionary columns [\#4645](https://github.com/apache/arrow-rs/issues/4645) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Decimal PrimitiveArray change datatype after try\_unary [\#4644](https://github.com/apache/arrow-rs/issues/4644) +- Better explanation in docs for Dictionary field encoding using RowConverter [\#4639](https://github.com/apache/arrow-rs/issues/4639) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `List(FixedSizeBinary)` array equality check may return wrong result [\#4637](https://github.com/apache/arrow-rs/issues/4637) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `arrow::compute::nullif` panics if `NullArray` is provided [\#4634](https://github.com/apache/arrow-rs/issues/4634) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Empty lists in FixedSizeListArray::try\_new is not handled [\#4623](https://github.com/apache/arrow-rs/issues/4623) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Bounds checking in `MutableBuffer::set_null_bits` can be bypassed [\#4620](https://github.com/apache/arrow-rs/issues/4620) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- TypedDictionaryArray Misleading Null Behaviour [\#4616](https://github.com/apache/arrow-rs/issues/4616) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- bug: Parquet writer missing row group metadata fields such as `compressed_size`, `file offset`. [\#4610](https://github.com/apache/arrow-rs/issues/4610) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `new_null_array` generates an invalid union array [\#4600](https://github.com/apache/arrow-rs/issues/4600) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Footer parsing fails for very large parquet file. [\#4592](https://github.com/apache/arrow-rs/issues/4592) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- bug\(parquet\): Disabling global statistics but enabling for particular column breaks reading [\#4587](https://github.com/apache/arrow-rs/issues/4587) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `arrow::compute::concat` panics for dense union arrays with non-trivial type IDs [\#4578](https://github.com/apache/arrow-rs/issues/4578) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- \[object\_store\] when Create a AmazonS3 instance work with MinIO without set endpoint got error MissingRegion [\#4617](https://github.com/apache/arrow-rs/issues/4617) + +**Merged pull requests:** + +- Add distinct kernels \(\#960\) \(\#4438\) [\#4716](https://github.com/apache/arrow-rs/pull/4716) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update parquet object\_store 0.7 [\#4715](https://github.com/apache/arrow-rs/pull/4715) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Support Field ID in ArrowWriter \(\#4702\) [\#4710](https://github.com/apache/arrow-rs/pull/4710) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Remove rank kernels [\#4703](https://github.com/apache/arrow-rs/pull/4703) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support references in i256 arithmetic ops [\#4692](https://github.com/apache/arrow-rs/pull/4692) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Cleanup DynComparator \(\#2654\) [\#4687](https://github.com/apache/arrow-rs/pull/4687) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Separate metadata fetch from `ArrowReaderBuilder` construction \(\#4674\) [\#4676](https://github.com/apache/arrow-rs/pull/4676) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- cleanup some assert\(\) with error propagation [\#4673](https://github.com/apache/arrow-rs/pull/4673) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zeevm](https://github.com/zeevm)) +- Faster i256 Division \(2-100x\) \(\#4663\) [\#4672](https://github.com/apache/arrow-rs/pull/4672) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix MSRV CI [\#4671](https://github.com/apache/arrow-rs/pull/4671) ([tustvold](https://github.com/tustvold)) +- Fix equality of nested nullable FixedSizeBinary \(\#4637\) [\#4670](https://github.com/apache/arrow-rs/pull/4670) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use ArrayFormatter in cast kernel [\#4668](https://github.com/apache/arrow-rs/pull/4668) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Minor: Improve API docs for FlightSQL metadata builders [\#4667](https://github.com/apache/arrow-rs/pull/4667) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Support `concat_batches` for 0 columns [\#4662](https://github.com/apache/arrow-rs/pull/4662) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- fix ownership of c stream error [\#4660](https://github.com/apache/arrow-rs/pull/4660) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Minor: Fix illustration for dict encoding [\#4657](https://github.com/apache/arrow-rs/pull/4657) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([JayjeetAtGithub](https://github.com/JayjeetAtGithub)) +- minor: move comment to the correct location [\#4655](https://github.com/apache/arrow-rs/pull/4655) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- Update packed\_simd and run miri tests on simd code [\#4654](https://github.com/apache/arrow-rs/pull/4654) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- impl `From>` for `BufferBuilder` and `MutableBuffer` [\#4650](https://github.com/apache/arrow-rs/pull/4650) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Filter record batch with 0 columns [\#4648](https://github.com/apache/arrow-rs/pull/4648) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Account for child `Bucket` size in OrderPreservingInterner [\#4646](https://github.com/apache/arrow-rs/pull/4646) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Implement `Default`,`Extend` and `FromIterator` for `BufferBuilder` [\#4638](https://github.com/apache/arrow-rs/pull/4638) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- fix\(select\): handle `NullArray` in `nullif` [\#4635](https://github.com/apache/arrow-rs/pull/4635) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kawadakk](https://github.com/kawadakk)) +- Move `BufferBuilder` to `arrow-buffer` [\#4630](https://github.com/apache/arrow-rs/pull/4630) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- allow zero sized empty fixed [\#4626](https://github.com/apache/arrow-rs/pull/4626) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([smiklos](https://github.com/smiklos)) +- fix: compute\_dictionary\_mapping use wrong offsetSize [\#4625](https://github.com/apache/arrow-rs/pull/4625) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- impl `FromIterator` for `MutableBuffer` [\#4624](https://github.com/apache/arrow-rs/pull/4624) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- expand docs for FixedSizeListArray [\#4622](https://github.com/apache/arrow-rs/pull/4622) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([smiklos](https://github.com/smiklos)) +- fix\(buffer\): panic on end index overflow in `MutableBuffer::set_null_bits` [\#4621](https://github.com/apache/arrow-rs/pull/4621) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kawadakk](https://github.com/kawadakk)) +- impl `Default` for `arrow_buffer::buffer::MutableBuffer` [\#4619](https://github.com/apache/arrow-rs/pull/4619) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Minor: improve docs and add example for lexicographical\_partition\_ranges [\#4615](https://github.com/apache/arrow-rs/pull/4615) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Cleanup sort [\#4613](https://github.com/apache/arrow-rs/pull/4613) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add rank function \(\#4606\) [\#4609](https://github.com/apache/arrow-rs/pull/4609) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add more docs and examples for ListArray and OffsetsBuffer [\#4607](https://github.com/apache/arrow-rs/pull/4607) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Simplify dictionary sort [\#4605](https://github.com/apache/arrow-rs/pull/4605) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Consolidate sort benchmarks [\#4604](https://github.com/apache/arrow-rs/pull/4604) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Don't Reorder Nulls in sort\_to\_indices \(\#4545\) [\#4603](https://github.com/apache/arrow-rs/pull/4603) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- fix\(data\): create child arrays of correct length when building a sparse union null array [\#4601](https://github.com/apache/arrow-rs/pull/4601) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kawadakk](https://github.com/kawadakk)) +- Use u32 metadata\_len when parsing footer of parquet. [\#4599](https://github.com/apache/arrow-rs/pull/4599) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Berrysoft](https://github.com/Berrysoft)) +- fix\(data\): map type ID to child index before indexing a union child array [\#4598](https://github.com/apache/arrow-rs/pull/4598) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kawadakk](https://github.com/kawadakk)) +- Remove deprecated arithmetic kernels \(\#4481\) [\#4594](https://github.com/apache/arrow-rs/pull/4594) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Test Disabled Page Statistics \(\#4587\) [\#4589](https://github.com/apache/arrow-rs/pull/4589) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Cleanup ArrayData::buffers [\#4583](https://github.com/apache/arrow-rs/pull/4583) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use contains\_nulls in ArrayData equality of byte arrays [\#4582](https://github.com/apache/arrow-rs/pull/4582) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Vectorized lexicographical\_partition\_ranges \(~80% faster\) [\#4575](https://github.com/apache/arrow-rs/pull/4575) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- chore: add datatype new\_list [\#4561](https://github.com/apache/arrow-rs/pull/4561) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([fansehep](https://github.com/fansehep)) +## [45.0.0](https://github.com/apache/arrow-rs/tree/45.0.0) (2023-07-30) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/44.0.0...45.0.0) + +**Breaking changes:** + +- Fix timezoned timestamp arithmetic [\#4546](https://github.com/apache/arrow-rs/pull/4546) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alexandreyc](https://github.com/alexandreyc)) + +**Implemented enhancements:** + +- Use FormatOptions in Const Contexts [\#4580](https://github.com/apache/arrow-rs/issues/4580) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Human Readable Duration Display [\#4554](https://github.com/apache/arrow-rs/issues/4554) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `BooleanBuilder`: Add `validity_slice` method for accessing validity bits [\#4535](https://github.com/apache/arrow-rs/issues/4535) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `FixedSizedListArray` for `length` kernel [\#4517](https://github.com/apache/arrow-rs/issues/4517) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `RowCoverter::convert` that targets an existing `Rows` [\#4479](https://github.com/apache/arrow-rs/issues/4479) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Panic `assertion failed: idx < self.len` when casting DictionaryArrays with nulls [\#4576](https://github.com/apache/arrow-rs/issues/4576) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- arrow-arith is\_null is buggy with NullArray [\#4565](https://github.com/apache/arrow-rs/issues/4565) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Incorrect Interval to Duration Casting [\#4553](https://github.com/apache/arrow-rs/issues/4553) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Too large validity buffer pre-allocation in `FixedSizeListBuilder::new` [\#4549](https://github.com/apache/arrow-rs/issues/4549) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Like with wildcards fail to match fields with new lines. [\#4547](https://github.com/apache/arrow-rs/issues/4547) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Timestamp Interval Arithmetic Ignores Timezone [\#4457](https://github.com/apache/arrow-rs/issues/4457) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- refactor: simplify hour\_dyn\(\) with time\_fraction\_dyn\(\) [\#4588](https://github.com/apache/arrow-rs/pull/4588) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- Move from\_iter\_values to GenericByteArray [\#4586](https://github.com/apache/arrow-rs/pull/4586) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Mark GenericByteArray::new\_unchecked unsafe [\#4584](https://github.com/apache/arrow-rs/pull/4584) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Configurable Duration Display [\#4581](https://github.com/apache/arrow-rs/pull/4581) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix take\_bytes Null and Overflow Handling \(\#4576\) [\#4579](https://github.com/apache/arrow-rs/pull/4579) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Move chrono-tz arithmetic tests to integration [\#4571](https://github.com/apache/arrow-rs/pull/4571) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Write Page Offset Index For All-Nan Pages [\#4567](https://github.com/apache/arrow-rs/pull/4567) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([MachaelLee](https://github.com/MachaelLee)) +- support NullArray un arith/boolean kernel [\#4566](https://github.com/apache/arrow-rs/pull/4566) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([smiklos](https://github.com/smiklos)) +- Remove Sync from arrow-flight example [\#4564](https://github.com/apache/arrow-rs/pull/4564) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Fix interval to duration casting \(\#4553\) [\#4562](https://github.com/apache/arrow-rs/pull/4562) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- docs: fix wrong parameter name [\#4559](https://github.com/apache/arrow-rs/pull/4559) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([SteveLauC](https://github.com/SteveLauC)) +- Fix FixedSizeListBuilder capacity \(\#4549\) [\#4552](https://github.com/apache/arrow-rs/pull/4552) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- docs: fix wrong inline code snippet in parquet document [\#4550](https://github.com/apache/arrow-rs/pull/4550) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([SteveLauC](https://github.com/SteveLauC)) +- fix multiline wildcard likes \(fixes \#4547\) [\#4548](https://github.com/apache/arrow-rs/pull/4548) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nl5887](https://github.com/nl5887)) +- Provide default `is_empty` impl for `arrow::array::ArrayBuilder` [\#4543](https://github.com/apache/arrow-rs/pull/4543) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Add RowConverter::append \(\#4479\) [\#4541](https://github.com/apache/arrow-rs/pull/4541) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Clarify GenericColumnReader::read\_records [\#4540](https://github.com/apache/arrow-rs/pull/4540) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Initial loongarch port [\#4538](https://github.com/apache/arrow-rs/pull/4538) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xiangzhai](https://github.com/xiangzhai)) +- Update proc-macro2 requirement from =1.0.64 to =1.0.66 [\#4537](https://github.com/apache/arrow-rs/pull/4537) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- add a validity slice access for boolean array builders [\#4536](https://github.com/apache/arrow-rs/pull/4536) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ChristianBeilschmidt](https://github.com/ChristianBeilschmidt)) +- use new num version instead of explicit num-complex dependency [\#4532](https://github.com/apache/arrow-rs/pull/4532) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mwlon](https://github.com/mwlon)) +- feat: Support `FixedSizedListArray` for `length` kernel [\#4520](https://github.com/apache/arrow-rs/pull/4520) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +## [44.0.0](https://github.com/apache/arrow-rs/tree/44.0.0) (2023-07-14) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/43.0.0...44.0.0) + +**Breaking changes:** + +- Use Parser for cast kernel \(\#4512\) [\#4513](https://github.com/apache/arrow-rs/pull/4513) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add Datum based arithmetic kernels \(\#3999\) [\#4465](https://github.com/apache/arrow-rs/pull/4465) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- eq\_dyn\_binary\_scalar should support FixedSizeBinary types [\#4491](https://github.com/apache/arrow-rs/issues/4491) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Port Tests from Deprecated Arithmetic Kernels [\#4480](https://github.com/apache/arrow-rs/issues/4480) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement RecordBatchReader for Boxed trait object [\#4474](https://github.com/apache/arrow-rs/issues/4474) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `Date` - `Date` kernel [\#4383](https://github.com/apache/arrow-rs/issues/4383) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Default FlightSqlService Implementations [\#4372](https://github.com/apache/arrow-rs/issues/4372) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Fixed bugs:** + +- Parquet: `AsyncArrowWriter` to a file corrupts the footer for large columns [\#4526](https://github.com/apache/arrow-rs/issues/4526) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[object\_store\] Failure to send bytes to azure [\#4522](https://github.com/apache/arrow-rs/issues/4522) +- Cannot cast string '2021-01-02' to value of Date64 type [\#4512](https://github.com/apache/arrow-rs/issues/4512) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Incorrect Interval Subtraction [\#4489](https://github.com/apache/arrow-rs/issues/4489) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Interval Negation Incorrect [\#4488](https://github.com/apache/arrow-rs/issues/4488) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet: AsyncArrowWriter inner buffer is not correctly limited and causes OOM [\#4477](https://github.com/apache/arrow-rs/issues/4477) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Merged pull requests:** + +- Fix AsyncArrowWriter flush for large buffer sizes \(\#4526\) [\#4527](https://github.com/apache/arrow-rs/pull/4527) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Cleanup cast\_primitive\_to\_list [\#4511](https://github.com/apache/arrow-rs/pull/4511) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Bump actions/upload-pages-artifact from 1 to 2 [\#4508](https://github.com/apache/arrow-rs/pull/4508) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Support Date - Date \(\#4383\) [\#4504](https://github.com/apache/arrow-rs/pull/4504) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Bump actions/labeler from 4.2.0 to 4.3.0 [\#4501](https://github.com/apache/arrow-rs/pull/4501) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update proc-macro2 requirement from =1.0.63 to =1.0.64 [\#4500](https://github.com/apache/arrow-rs/pull/4500) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add negate kernels \(\#4488\) [\#4494](https://github.com/apache/arrow-rs/pull/4494) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add Datum Arithmetic tests, Fix Interval Substraction \(\#4480\) [\#4493](https://github.com/apache/arrow-rs/pull/4493) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- support FixedSizeBinary types in eq\_dyn\_binary\_scalar/neq\_dyn\_binary\_scalar [\#4492](https://github.com/apache/arrow-rs/pull/4492) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([maxburke](https://github.com/maxburke)) +- Add default implementations to the FlightSqlService trait [\#4485](https://github.com/apache/arrow-rs/pull/4485) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([rossjones](https://github.com/rossjones)) +- add num-complex requirement [\#4482](https://github.com/apache/arrow-rs/pull/4482) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mwlon](https://github.com/mwlon)) +- fix incorrect buffer size limiting in parquet async writer [\#4478](https://github.com/apache/arrow-rs/pull/4478) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([richox](https://github.com/richox)) +- feat: support RecordBatchReader on boxed trait objects [\#4475](https://github.com/apache/arrow-rs/pull/4475) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Improve in-place primitive sorts by 13-67% [\#4473](https://github.com/apache/arrow-rs/pull/4473) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Add Scalar/Datum abstraction \(\#1047\) [\#4393](https://github.com/apache/arrow-rs/pull/4393) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +## [43.0.0](https://github.com/apache/arrow-rs/tree/43.0.0) (2023-06-30) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/42.0.0...43.0.0) + +**Breaking changes:** + +- Simplify ffi import/export [\#4447](https://github.com/apache/arrow-rs/pull/4447) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Virgiel](https://github.com/Virgiel)) +- Return Result from Parquet Row APIs [\#4428](https://github.com/apache/arrow-rs/pull/4428) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zeevm](https://github.com/zeevm)) +- Remove Binary Dictionary Arithmetic Support [\#4407](https://github.com/apache/arrow-rs/pull/4407) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Request: a way to copy a `Row` to `Rows` [\#4466](https://github.com/apache/arrow-rs/issues/4466) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Reuse schema when importing from FFI [\#4444](https://github.com/apache/arrow-rs/issues/4444) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[FlightSQL\] Allow implementations of `FlightSqlService` to handle custom actions and commands [\#4439](https://github.com/apache/arrow-rs/issues/4439) +- Support `NullBuilder` [\#4429](https://github.com/apache/arrow-rs/issues/4429) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Regression in in parquet `42.0.0` : Bad parquet column indexes for All Null Columns, resulting in `Parquet error: StructArrayReader out of sync` on read [\#4459](https://github.com/apache/arrow-rs/issues/4459) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Regression in 42.0.0: Parsing fractional intervals without leading 0 is not supported [\#4424](https://github.com/apache/arrow-rs/issues/4424) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- doc: deploy crate docs to GitHub pages [\#4436](https://github.com/apache/arrow-rs/pull/4436) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xxchan](https://github.com/xxchan)) + +**Merged pull requests:** + +- Append Row to Rows \(\#4466\) [\#4470](https://github.com/apache/arrow-rs/pull/4470) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat\(flight-sql\): Allow implementations of FlightSqlService to handle custom actions and commands [\#4463](https://github.com/apache/arrow-rs/pull/4463) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([amartins23](https://github.com/amartins23)) +- Docs: Add clearer API doc links [\#4461](https://github.com/apache/arrow-rs/pull/4461) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Fix empty offset index for all null columns \(\#4459\) [\#4460](https://github.com/apache/arrow-rs/pull/4460) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Bump peaceiris/actions-gh-pages from 3.9.2 to 3.9.3 [\#4455](https://github.com/apache/arrow-rs/pull/4455) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Convince the compiler to auto-vectorize the range check in parquet DictionaryBuffer [\#4453](https://github.com/apache/arrow-rs/pull/4453) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jhorstmann](https://github.com/jhorstmann)) +- fix docs deployment [\#4452](https://github.com/apache/arrow-rs/pull/4452) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xxchan](https://github.com/xxchan)) +- Update indexmap requirement from 1.9 to 2.0 [\#4451](https://github.com/apache/arrow-rs/pull/4451) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update proc-macro2 requirement from =1.0.60 to =1.0.63 [\#4450](https://github.com/apache/arrow-rs/pull/4450) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Bump actions/deploy-pages from 1 to 2 [\#4449](https://github.com/apache/arrow-rs/pull/4449) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Revise error message in From\ for ScalarBuffer [\#4446](https://github.com/apache/arrow-rs/pull/4446) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- minor: remove useless mut [\#4443](https://github.com/apache/arrow-rs/pull/4443) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- unify substring for binary&utf8 [\#4442](https://github.com/apache/arrow-rs/pull/4442) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- Casting fixedsizelist to list/largelist [\#4433](https://github.com/apache/arrow-rs/pull/4433) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jayzhan211](https://github.com/jayzhan211)) +- feat: support `NullBuilder` [\#4430](https://github.com/apache/arrow-rs/pull/4430) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Remove Float64 -\> Float32 cast in IPC Reader [\#4427](https://github.com/apache/arrow-rs/pull/4427) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ming08108](https://github.com/ming08108)) +- Parse intervals like `.5` the same as `0.5` [\#4425](https://github.com/apache/arrow-rs/pull/4425) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- feat: add strict mode to json reader [\#4421](https://github.com/apache/arrow-rs/pull/4421) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([blinkseb](https://github.com/blinkseb)) +- Add DictionaryArray::occupancy [\#4415](https://github.com/apache/arrow-rs/pull/4415) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +## [42.0.0](https://github.com/apache/arrow-rs/tree/42.0.0) (2023-06-16) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/41.0.0...42.0.0) + +**Breaking changes:** + +- Remove 64-bit to 32-bit Cast from IPC Reader [\#4412](https://github.com/apache/arrow-rs/pull/4412) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ming08108](https://github.com/ming08108)) +- Truncate Min/Max values in the Column Index [\#4389](https://github.com/apache/arrow-rs/pull/4389) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([AdamGS](https://github.com/AdamGS)) +- feat\(flight\): harmonize server metadata APIs [\#4384](https://github.com/apache/arrow-rs/pull/4384) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([roeap](https://github.com/roeap)) +- Move record delimiting into ColumnReader \(\#4365\) [\#4376](https://github.com/apache/arrow-rs/pull/4376) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Changed array\_to\_json\_array to take &dyn Array [\#4370](https://github.com/apache/arrow-rs/pull/4370) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dadepo](https://github.com/dadepo)) +- Make PrimitiveArray::with\_timezone consuming [\#4366](https://github.com/apache/arrow-rs/pull/4366) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Add doc example of constructing a MapArray [\#4385](https://github.com/apache/arrow-rs/issues/4385) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `millisecond` and `microsecond` functions [\#4374](https://github.com/apache/arrow-rs/issues/4374) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Changed array\_to\_json\_array to take &dyn Array [\#4369](https://github.com/apache/arrow-rs/issues/4369) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- compute::ord kernel for getting min and max of two scalar/array values [\#4347](https://github.com/apache/arrow-rs/issues/4347) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Release 41.0.0 of arrow/arrow-flight/parquet/parquet-derive [\#4346](https://github.com/apache/arrow-rs/issues/4346) +- Refactor CAST tests to use new cast array syntax [\#4336](https://github.com/apache/arrow-rs/issues/4336) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- pass bytes directly to parquet's KeyValue [\#4317](https://github.com/apache/arrow-rs/issues/4317) +- PyArrow conversions could return TypeError if provided incorrect Python type [\#4312](https://github.com/apache/arrow-rs/issues/4312) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Have array\_to\_json\_array support Map [\#4297](https://github.com/apache/arrow-rs/issues/4297) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- FlightSQL: Add helpers to create `CommandGetXdbcTypeInfo` responses \(`XdbcInfoValue` and builders\) [\#4257](https://github.com/apache/arrow-rs/issues/4257) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Have array\_to\_json\_array support FixedSizeList [\#4248](https://github.com/apache/arrow-rs/issues/4248) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Truncate ColumnIndex ByteArray Statistics [\#4126](https://github.com/apache/arrow-rs/issues/4126) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Arrow compute kernel regards selection vector [\#4095](https://github.com/apache/arrow-rs/issues/4095) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Wrongly calculated data compressed length in IPC writer [\#4410](https://github.com/apache/arrow-rs/issues/4410) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Take Kernel Handles Nullable Indices Incorrectly [\#4404](https://github.com/apache/arrow-rs/issues/4404) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- StructBuilder::new Doesn't Validate Builder DataTypes [\#4397](https://github.com/apache/arrow-rs/issues/4397) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet error: Not all children array length are the same! when using RowSelection to read a parquet file [\#4396](https://github.com/apache/arrow-rs/issues/4396) +- RecordReader::skip\_records Is Incorrect for Repeated Columns [\#4368](https://github.com/apache/arrow-rs/issues/4368) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- List-of-String Array panics in the presence of row filters [\#4365](https://github.com/apache/arrow-rs/issues/4365) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Fail to read block compressed gzip files with parquet-fromcsv [\#4173](https://github.com/apache/arrow-rs/issues/4173) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Closed issues:** + +- Have a parquet file not able to be deduped via arrow-rs, complains about Decimal precision? [\#4356](https://github.com/apache/arrow-rs/issues/4356) +- Question: Could we move `dict_id, dict_is_ordered` into DataType? [\#4325](https://github.com/apache/arrow-rs/issues/4325) + +**Merged pull requests:** + +- Fix reading gzip file with multiple gzip headers in parquet-fromcsv. [\#4419](https://github.com/apache/arrow-rs/pull/4419) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([ghuls](https://github.com/ghuls)) +- Cleanup nullif kernel [\#4416](https://github.com/apache/arrow-rs/pull/4416) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix bug in IPC logic that determines if the buffer should be compressed or not [\#4411](https://github.com/apache/arrow-rs/pull/4411) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([lwpyr](https://github.com/lwpyr)) +- Faster unpacking of Int32Type dictionary [\#4406](https://github.com/apache/arrow-rs/pull/4406) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Improve `take` kernel performance on primitive arrays, fix bad null index handling \(\#4404\) [\#4405](https://github.com/apache/arrow-rs/pull/4405) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- More take benchmarks [\#4403](https://github.com/apache/arrow-rs/pull/4403) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add `BooleanBuffer::new_unset` and `BooleanBuffer::new_set` and `BooleanArray::new_null` constructors [\#4402](https://github.com/apache/arrow-rs/pull/4402) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add PrimitiveBuilder type constructors [\#4401](https://github.com/apache/arrow-rs/pull/4401) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- StructBuilder Validate Child Data \(\#4397\) [\#4400](https://github.com/apache/arrow-rs/pull/4400) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Faster UTF-8 truncation [\#4399](https://github.com/apache/arrow-rs/pull/4399) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Minor: Derive `Hash` impls for `CastOptions` and `FormatOptions` [\#4395](https://github.com/apache/arrow-rs/pull/4395) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Fix typo in README [\#4394](https://github.com/apache/arrow-rs/pull/4394) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([okue](https://github.com/okue)) +- Improve parquet `WriterProperites` and `ReaderProperties` docs [\#4392](https://github.com/apache/arrow-rs/pull/4392) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Cleanup downcast macros [\#4391](https://github.com/apache/arrow-rs/pull/4391) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update proc-macro2 requirement from =1.0.59 to =1.0.60 [\#4388](https://github.com/apache/arrow-rs/pull/4388) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Consolidate ByteArray::from\_iterator [\#4386](https://github.com/apache/arrow-rs/pull/4386) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add MapArray constructors and doc example [\#4382](https://github.com/apache/arrow-rs/pull/4382) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Documentation Improvements [\#4381](https://github.com/apache/arrow-rs/pull/4381) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add NullBuffer and BooleanBuffer From conversions [\#4380](https://github.com/apache/arrow-rs/pull/4380) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add more examples of constructing Boolean, Primitive, String, and Decimal Arrays, and From impl for i256 [\#4379](https://github.com/apache/arrow-rs/pull/4379) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add ListArrayReader benchmarks [\#4378](https://github.com/apache/arrow-rs/pull/4378) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Update comfy-table requirement from 6.0 to 7.0 [\#4377](https://github.com/apache/arrow-rs/pull/4377) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- feat: Add`microsecond` and `millisecond` kernels [\#4375](https://github.com/apache/arrow-rs/pull/4375) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Update hashbrown requirement from 0.13 to 0.14 [\#4373](https://github.com/apache/arrow-rs/pull/4373) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- minor: use as\_boolean to resolve TODO [\#4367](https://github.com/apache/arrow-rs/pull/4367) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- Have array\_to\_json\_array support MapArray [\#4364](https://github.com/apache/arrow-rs/pull/4364) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dadepo](https://github.com/dadepo)) +- deprecate: as\_decimal\_array [\#4363](https://github.com/apache/arrow-rs/pull/4363) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Add support for FixedSizeList in array\_to\_json\_array [\#4361](https://github.com/apache/arrow-rs/pull/4361) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dadepo](https://github.com/dadepo)) +- refact: use as\_primitive in cast.rs test [\#4360](https://github.com/apache/arrow-rs/pull/4360) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- feat\(flight\): add xdbc type info helpers [\#4359](https://github.com/apache/arrow-rs/pull/4359) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([roeap](https://github.com/roeap)) +- Minor: float16 to json [\#4358](https://github.com/apache/arrow-rs/pull/4358) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Raise TypeError on PyArrow import [\#4316](https://github.com/apache/arrow-rs/pull/4316) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Arrow Cast: Fixed Point Arithmetic for Interval Parsing [\#4291](https://github.com/apache/arrow-rs/pull/4291) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mr-brobot](https://github.com/mr-brobot)) +## [41.0.0](https://github.com/apache/arrow-rs/tree/41.0.0) (2023-06-02) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/40.0.0...41.0.0) + +**Breaking changes:** + +- Rename list contains kernels to in\_list \(\#4289\) [\#4342](https://github.com/apache/arrow-rs/pull/4342) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Move BooleanBufferBuilder and NullBufferBuilder to arrow\_buffer [\#4338](https://github.com/apache/arrow-rs/pull/4338) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add separate row\_count and level\_count to PageMetadata \(\#4321\) [\#4326](https://github.com/apache/arrow-rs/pull/4326) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Treat legacy TIMSETAMP\_X converted types as UTC [\#4309](https://github.com/apache/arrow-rs/pull/4309) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([sergiimk](https://github.com/sergiimk)) +- Simplify parquet PageIterator [\#4306](https://github.com/apache/arrow-rs/pull/4306) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add Builder style APIs and docs for `FlightData`,` FlightInfo`, `FlightEndpoint`, `Locaation` and `Ticket` [\#4294](https://github.com/apache/arrow-rs/pull/4294) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Make GenericColumnWriter Send [\#4287](https://github.com/apache/arrow-rs/pull/4287) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- feat: update flight-sql to latest specs [\#4250](https://github.com/apache/arrow-rs/pull/4250) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([roeap](https://github.com/roeap)) +- feat\(api!\): make ArrowArrayStreamReader Send [\#4232](https://github.com/apache/arrow-rs/pull/4232) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) + +**Implemented enhancements:** + +- Make SerializedRowGroupReader::new\(\) Public [\#4330](https://github.com/apache/arrow-rs/issues/4330) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Speed up i256 division and remainder operations [\#4302](https://github.com/apache/arrow-rs/issues/4302) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- export function parquet\_to\_array\_schema\_and\_fields [\#4298](https://github.com/apache/arrow-rs/issues/4298) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- FLightSQL: add helpers to create `CommandGetCatalogs`, `CommandGetSchemas`, and `CommandGetTables` requests [\#4295](https://github.com/apache/arrow-rs/issues/4295) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Make ColumnWriter Send [\#4286](https://github.com/apache/arrow-rs/issues/4286) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add Builder for `FlightInfo` to make it easier to create new requests [\#4281](https://github.com/apache/arrow-rs/issues/4281) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Support Writing/Reading Decimal256 to/from Parquet [\#4264](https://github.com/apache/arrow-rs/issues/4264) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- FlightSQL: Add helpers to create `CommandGetSqlInfo` responses \(`SqlInfoValue` and builders\) [\#4256](https://github.com/apache/arrow-rs/issues/4256) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Update flight-sql implementation to latest specs [\#4249](https://github.com/apache/arrow-rs/issues/4249) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Make ArrowArrayStreamReader Send [\#4222](https://github.com/apache/arrow-rs/issues/4222) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support writing FixedSizeList to Parquet [\#4214](https://github.com/apache/arrow-rs/issues/4214) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Cast between `Intervals` [\#4181](https://github.com/apache/arrow-rs/issues/4181) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Splice Parquet Data [\#4155](https://github.com/apache/arrow-rs/issues/4155) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- CSV Schema More Flexible Timestamp Inference [\#4131](https://github.com/apache/arrow-rs/issues/4131) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Doc for arrow\_flight::sql is missing enums that are Xdbc related [\#4339](https://github.com/apache/arrow-rs/issues/4339) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- concat\_batches panics with total\_len \<= bit\_len assertion for records with lists [\#4324](https://github.com/apache/arrow-rs/issues/4324) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Incorrect PageMetadata Row Count returned for V1 DataPage [\#4321](https://github.com/apache/arrow-rs/issues/4321) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[parquet\] Not following the spec for TIMESTAMP\_MILLIS legacy converted types [\#4308](https://github.com/apache/arrow-rs/issues/4308) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- ambiguous glob re-exports of contains\_utf8 [\#4289](https://github.com/apache/arrow-rs/issues/4289) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- flight\_sql\_client --header "key: value" yields a value with a leading whitespace [\#4270](https://github.com/apache/arrow-rs/issues/4270) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Casting Timestamp to date is off by one day for dates before 1970-01-01 [\#4211](https://github.com/apache/arrow-rs/issues/4211) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Don't infer 16-byte decimal as decimal256 [\#4349](https://github.com/apache/arrow-rs/pull/4349) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Fix MutableArrayData::extend\_nulls \(\#1230\) [\#4343](https://github.com/apache/arrow-rs/pull/4343) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update FlightSQL metadata locations, names and docs [\#4341](https://github.com/apache/arrow-rs/pull/4341) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- chore: expose Xdbc related FlightSQL enums [\#4340](https://github.com/apache/arrow-rs/pull/4340) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([appletreeisyellow](https://github.com/appletreeisyellow)) +- Update pyo3 requirement from 0.18 to 0.19 [\#4335](https://github.com/apache/arrow-rs/pull/4335) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Skip unnecessary null checks in MutableArrayData [\#4333](https://github.com/apache/arrow-rs/pull/4333) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: add read parquet by custom rowgroup examples [\#4332](https://github.com/apache/arrow-rs/pull/4332) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([sundy-li](https://github.com/sundy-li)) +- Make SerializedRowGroupReader::new\(\) public [\#4331](https://github.com/apache/arrow-rs/pull/4331) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([burmecia](https://github.com/burmecia)) +- Don't split record across pages \(\#3680\) [\#4327](https://github.com/apache/arrow-rs/pull/4327) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- fix date conversion if timestamp below unixtimestamp [\#4323](https://github.com/apache/arrow-rs/pull/4323) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) +- Short-circuit on exhausted page in skip\_records [\#4320](https://github.com/apache/arrow-rs/pull/4320) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Handle trailing padding when skipping repetition levels \(\#3911\) [\#4319](https://github.com/apache/arrow-rs/pull/4319) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Use `page_size` consistently, deprecate `pagesize` in parquet WriterProperties [\#4313](https://github.com/apache/arrow-rs/pull/4313) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Add roundtrip tests for Decimal256 and fix issues \(\#4264\) [\#4311](https://github.com/apache/arrow-rs/pull/4311) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Expose page-level arrow reader API \(\#4298\) [\#4307](https://github.com/apache/arrow-rs/pull/4307) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Speed up i256 division and remainder operations [\#4303](https://github.com/apache/arrow-rs/pull/4303) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- feat\(flight\): support int32\_to\_int32\_list\_map in sql infos [\#4300](https://github.com/apache/arrow-rs/pull/4300) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([roeap](https://github.com/roeap)) +- feat\(flight\): add helpers to handle `CommandGetCatalogs`, `CommandGetSchemas`, and `CommandGetTables` requests [\#4296](https://github.com/apache/arrow-rs/pull/4296) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([roeap](https://github.com/roeap)) +- Improve docs and tests for `SqlInfoList [\#4293](https://github.com/apache/arrow-rs/pull/4293) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- minor: fix arrow\_row docs.rs links [\#4292](https://github.com/apache/arrow-rs/pull/4292) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([roeap](https://github.com/roeap)) +- Update proc-macro2 requirement from =1.0.58 to =1.0.59 [\#4290](https://github.com/apache/arrow-rs/pull/4290) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Improve `ArrowWriter` memory usage: Buffer Pages in ArrowWriter instead of RecordBatch \(\#3871\) [\#4280](https://github.com/apache/arrow-rs/pull/4280) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Minor: Add more docstrings in arrow-flight [\#4279](https://github.com/apache/arrow-rs/pull/4279) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Add `Debug` impls for `ArrowWriter` and `SerializedFileWriter` [\#4278](https://github.com/apache/arrow-rs/pull/4278) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Expose `RecordBatchWriter` to `arrow` crate [\#4277](https://github.com/apache/arrow-rs/pull/4277) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alexandreyc](https://github.com/alexandreyc)) +- Update criterion requirement from 0.4 to 0.5 [\#4275](https://github.com/apache/arrow-rs/pull/4275) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add parquet-concat [\#4274](https://github.com/apache/arrow-rs/pull/4274) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Convert FixedSizeListArray to GenericListArray [\#4273](https://github.com/apache/arrow-rs/pull/4273) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: support 'Decimal256' for parquet [\#4272](https://github.com/apache/arrow-rs/pull/4272) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Weijun-H](https://github.com/Weijun-H)) +- Strip leading whitespace from flight\_sql\_client custom header values [\#4271](https://github.com/apache/arrow-rs/pull/4271) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([mkmik](https://github.com/mkmik)) +- Add Append Column API \(\#4155\) [\#4269](https://github.com/apache/arrow-rs/pull/4269) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Derive Default for WriterProperties [\#4268](https://github.com/apache/arrow-rs/pull/4268) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Parquet Reader/writer for fixed-size list arrays [\#4267](https://github.com/apache/arrow-rs/pull/4267) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dexterduck](https://github.com/dexterduck)) +- feat\(flight\): add sql-info helpers [\#4266](https://github.com/apache/arrow-rs/pull/4266) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([roeap](https://github.com/roeap)) +- Convert parquet metadata back to builders [\#4265](https://github.com/apache/arrow-rs/pull/4265) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add constructors for FixedSize array types \(\#3879\) [\#4263](https://github.com/apache/arrow-rs/pull/4263) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Extract IPC ArrayReader struct [\#4259](https://github.com/apache/arrow-rs/pull/4259) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update object\_store requirement from 0.5 to 0.6 [\#4258](https://github.com/apache/arrow-rs/pull/4258) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Support Absolute Timestamps in CSV Schema Inference \(\#4131\) [\#4217](https://github.com/apache/arrow-rs/pull/4217) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: cast between `Intervals` [\#4182](https://github.com/apache/arrow-rs/pull/4182) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +## [40.0.0](https://github.com/apache/arrow-rs/tree/40.0.0) (2023-05-19) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/39.0.0...40.0.0) + +**Breaking changes:** + +- Prefetch page index \(\#4090\) [\#4216](https://github.com/apache/arrow-rs/pull/4216) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add RecordBatchWriter trait and implement it for CSV, JSON, IPC and P… [\#4206](https://github.com/apache/arrow-rs/pull/4206) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alexandreyc](https://github.com/alexandreyc)) +- Remove powf\_scalar kernel [\#4187](https://github.com/apache/arrow-rs/pull/4187) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Allow format specification in cast [\#4169](https://github.com/apache/arrow-rs/pull/4169) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([parthchandra](https://github.com/parthchandra)) + +**Implemented enhancements:** + +- ObjectStore with\_url Should Handle Path [\#4199](https://github.com/apache/arrow-rs/issues/4199) +- Support `Interval` +/- `Interval` [\#4178](https://github.com/apache/arrow-rs/issues/4178) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[parquet\] add compression info to `print_column_chunk_metadata()` [\#4172](https://github.com/apache/arrow-rs/issues/4172) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Allow cast to take in a format specification [\#4168](https://github.com/apache/arrow-rs/issues/4168) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support extended pow arithmetic [\#4166](https://github.com/apache/arrow-rs/issues/4166) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Preload page index for async ParquetObjectReader [\#4090](https://github.com/apache/arrow-rs/issues/4090) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Fixed bugs:** + +- Subtracting `Timestamp` from `Timestamp` should produce a `Duration` \(not `Timestamp`\) [\#3964](https://github.com/apache/arrow-rs/issues/3964) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Arrow Arithmetic: Subtract timestamps [\#4244](https://github.com/apache/arrow-rs/pull/4244) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mr-brobot](https://github.com/mr-brobot)) +- Update proc-macro2 requirement from =1.0.57 to =1.0.58 [\#4236](https://github.com/apache/arrow-rs/pull/4236) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Fix Nightly Clippy Lints [\#4233](https://github.com/apache/arrow-rs/pull/4233) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Minor: use all primitive types in test\_layouts [\#4229](https://github.com/apache/arrow-rs/pull/4229) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Add close method to RecordBatchWriter trait [\#4228](https://github.com/apache/arrow-rs/pull/4228) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alexandreyc](https://github.com/alexandreyc)) +- Update proc-macro2 requirement from =1.0.56 to =1.0.57 [\#4219](https://github.com/apache/arrow-rs/pull/4219) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Feat docs [\#4215](https://github.com/apache/arrow-rs/pull/4215) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Folyd](https://github.com/Folyd)) +- feat: Support bitwise and boolean aggregate functions [\#4210](https://github.com/apache/arrow-rs/pull/4210) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Document how to sort a RecordBatch [\#4204](https://github.com/apache/arrow-rs/pull/4204) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix incorrect cast Timestamp with Timezone [\#4201](https://github.com/apache/arrow-rs/pull/4201) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([aprimadi](https://github.com/aprimadi)) +- Add implementation of `RecordBatchReader` for CSV reader [\#4195](https://github.com/apache/arrow-rs/pull/4195) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alexandreyc](https://github.com/alexandreyc)) +- Add Sliced ListArray test \(\#3748\) [\#4186](https://github.com/apache/arrow-rs/pull/4186) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- refactor: simplify can\_cast\_types code. [\#4185](https://github.com/apache/arrow-rs/pull/4185) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- Minor: support new types in struct\_builder.rs [\#4177](https://github.com/apache/arrow-rs/pull/4177) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- feat: add compression info to print\_column\_chunk\_metadata\(\) [\#4176](https://github.com/apache/arrow-rs/pull/4176) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([SteveLauC](https://github.com/SteveLauC)) +## [39.0.0](https://github.com/apache/arrow-rs/tree/39.0.0) (2023-05-05) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/38.0.0...39.0.0) + +**Breaking changes:** + +- Allow creating unbuffered streamreader [\#4165](https://github.com/apache/arrow-rs/pull/4165) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ming08108](https://github.com/ming08108)) +- Cleanup ChunkReader \(\#4118\) [\#4156](https://github.com/apache/arrow-rs/pull/4156) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Remove Type from NativeIndex [\#4146](https://github.com/apache/arrow-rs/pull/4146) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Don't Duplicate Offset Index on RowGroupMetadata [\#4142](https://github.com/apache/arrow-rs/pull/4142) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Return BooleanBuffer from BooleanBufferBuilder [\#4140](https://github.com/apache/arrow-rs/pull/4140) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Cleanup CSV schema inference \(\#4129\) \(\#4130\) [\#4133](https://github.com/apache/arrow-rs/pull/4133) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Remove deprecated parquet ArrowReader [\#4125](https://github.com/apache/arrow-rs/pull/4125) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- refactor: construct `StructArray` w/ `FieldRef` [\#4116](https://github.com/apache/arrow-rs/pull/4116) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([crepererum](https://github.com/crepererum)) +- Ignore Field Metadata in equals\_datatype for Dictionary, RunEndEncoded, Map and Union [\#4111](https://github.com/apache/arrow-rs/pull/4111) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Add StructArray Constructors \(\#3879\) [\#4064](https://github.com/apache/arrow-rs/pull/4064) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Release 39.0.0 of arrow/arrow-flight/parquet/parquet-derive \(next release after 38.0.0\) [\#4170](https://github.com/apache/arrow-rs/issues/4170) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Fixed point decimal multiplication for DictionaryArray [\#4135](https://github.com/apache/arrow-rs/issues/4135) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Remove Seek Requirement from CSV ReaderBuilder [\#4130](https://github.com/apache/arrow-rs/issues/4130) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Inconsistent CSV Inference and Parsing DateTime Handling [\#4129](https://github.com/apache/arrow-rs/issues/4129) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support accessing ipc Reader/Writer inner by reference [\#4121](https://github.com/apache/arrow-rs/issues/4121) +- Add Type Declarations for All Primitive Tensors and Buffer Builders [\#4112](https://github.com/apache/arrow-rs/issues/4112) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `Interval + Timestamp` and `Interval + Date` in addition to `Timestamp + Interval` and `Interval + Date` [\#4094](https://github.com/apache/arrow-rs/issues/4094) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Enable setting FlightDescriptor on FlightDataEncoderBuilder [\#3855](https://github.com/apache/arrow-rs/issues/3855) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Fixed bugs:** + +- Parquet Page Index Reader Assumes Consecutive Offsets [\#4149](https://github.com/apache/arrow-rs/issues/4149) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Equality of nested data types [\#4110](https://github.com/apache/arrow-rs/issues/4110) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- Improve Documentation of Parquet ChunkReader [\#4118](https://github.com/apache/arrow-rs/issues/4118) + +**Closed issues:** + +- add specific error log for empty JSON array [\#4105](https://github.com/apache/arrow-rs/issues/4105) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Prep for 39.0.0 [\#4171](https://github.com/apache/arrow-rs/pull/4171) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([iajoiner](https://github.com/iajoiner)) +- Support Compression in parquet-fromcsv [\#4160](https://github.com/apache/arrow-rs/pull/4160) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([suxiaogang223](https://github.com/suxiaogang223)) +- feat: support bitwise shift left/right with scalars [\#4159](https://github.com/apache/arrow-rs/pull/4159) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Cleanup reading page index \(\#4149\) \(\#4090\) [\#4151](https://github.com/apache/arrow-rs/pull/4151) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- feat: support `bitwise` shift left/right [\#4148](https://github.com/apache/arrow-rs/pull/4148) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Don't hardcode port in FlightSQL tests [\#4145](https://github.com/apache/arrow-rs/pull/4145) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Better flight SQL example codes [\#4144](https://github.com/apache/arrow-rs/pull/4144) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([sundy-li](https://github.com/sundy-li)) +- chore: clean the code by using `as_primitive` [\#4143](https://github.com/apache/arrow-rs/pull/4143) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- docs: fix the wrong ln command in CONTRIBUTING.md [\#4139](https://github.com/apache/arrow-rs/pull/4139) ([SteveLauC](https://github.com/SteveLauC)) +- Infer Float64 for JSON Numerics Beyond Bounds of i64 [\#4138](https://github.com/apache/arrow-rs/pull/4138) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([SteveLauC](https://github.com/SteveLauC)) +- Support fixed point multiplication for DictionaryArray of Decimals [\#4136](https://github.com/apache/arrow-rs/pull/4136) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Make arrow\_json::ReaderBuilder method names consistent [\#4128](https://github.com/apache/arrow-rs/pull/4128) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: add get\_{ref, mut} to arrow\_ipc Reader and Writer [\#4122](https://github.com/apache/arrow-rs/pull/4122) ([sticnarf](https://github.com/sticnarf)) +- feat: support `Interval` + `Timestamp` and `Interval` + `Date` [\#4117](https://github.com/apache/arrow-rs/pull/4117) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Support NullArray in JSON Reader [\#4114](https://github.com/apache/arrow-rs/pull/4114) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jiangzhx](https://github.com/jiangzhx)) +- Add Type Declarations for All Primitive Tensors and Buffer Builders [\#4113](https://github.com/apache/arrow-rs/pull/4113) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Update regex-syntax requirement from 0.6.27 to 0.7.1 [\#4107](https://github.com/apache/arrow-rs/pull/4107) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- feat: set FlightDescriptor on FlightDataEncoderBuilder [\#4101](https://github.com/apache/arrow-rs/pull/4101) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([Weijun-H](https://github.com/Weijun-H)) +- optimize cast for same decimal type and same scale [\#4088](https://github.com/apache/arrow-rs/pull/4088) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) + +## [38.0.0](https://github.com/apache/arrow-rs/tree/38.0.0) (2023-04-21) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/37.0.0...38.0.0) + +**Breaking changes:** + +- Remove DataType from PrimitiveArray constructors [\#4098](https://github.com/apache/arrow-rs/pull/4098) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use Into\\> for PrimitiveArray::with\_timezone [\#4097](https://github.com/apache/arrow-rs/pull/4097) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Store StructArray entries in MapArray [\#4085](https://github.com/apache/arrow-rs/pull/4085) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add DictionaryArray Constructors \(\#3879\) [\#4068](https://github.com/apache/arrow-rs/pull/4068) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Relax JSON schema inference generics [\#4063](https://github.com/apache/arrow-rs/pull/4063) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Remove ArrayData from Array \(\#3880\) [\#4061](https://github.com/apache/arrow-rs/pull/4061) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add CommandGetXdbcTypeInfo to Flight SQL Server [\#4055](https://github.com/apache/arrow-rs/pull/4055) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([c-thiel](https://github.com/c-thiel)) +- Remove old JSON Reader and Decoder \(\#3610\) [\#4052](https://github.com/apache/arrow-rs/pull/4052) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use BufRead for JSON Schema Inference [\#4041](https://github.com/apache/arrow-rs/pull/4041) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([WenyXu](https://github.com/WenyXu)) + +**Implemented enhancements:** + +- Support dyn\_compare\_scalar for Decimal256 [\#4083](https://github.com/apache/arrow-rs/issues/4083) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Better JSON Reader Error Messages [\#4076](https://github.com/apache/arrow-rs/issues/4076) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Additional data type groups [\#4056](https://github.com/apache/arrow-rs/issues/4056) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Async JSON reader [\#4043](https://github.com/apache/arrow-rs/issues/4043) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Field::contains Should Recurse into DataType [\#4029](https://github.com/apache/arrow-rs/issues/4029) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Prevent UnionArray with Repeated Type IDs [\#3982](https://github.com/apache/arrow-rs/issues/3982) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `Timestamp` `+`/`-` `Interval` types [\#3963](https://github.com/apache/arrow-rs/issues/3963) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- First-Class Array Abstractions [\#3880](https://github.com/apache/arrow-rs/issues/3880) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Fixed bugs:** + +- Update readme to remove reference to Jira [\#4091](https://github.com/apache/arrow-rs/issues/4091) +- OffsetBuffer::new Rejects 0 Offsets [\#4066](https://github.com/apache/arrow-rs/issues/4066) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet AsyncArrowWriter not shutting down inner async writer. [\#4058](https://github.com/apache/arrow-rs/issues/4058) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Flight SQL Server missing command type.googleapis.com/arrow.flight.protocol.sql.CommandGetXdbcTypeInfo [\#4054](https://github.com/apache/arrow-rs/issues/4054) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- RawJsonReader Errors with Empty Schema [\#4053](https://github.com/apache/arrow-rs/issues/4053) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- RawJsonReader Integer Truncation [\#4049](https://github.com/apache/arrow-rs/issues/4049) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Sparse UnionArray Equality Incorrect Offset Handling [\#4044](https://github.com/apache/arrow-rs/issues/4044) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- Write blog about improvements in JSON and CSV processing [\#4062](https://github.com/apache/arrow-rs/issues/4062) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Parquet reader of Int96 columns and coercion to timestamps [\#4075](https://github.com/apache/arrow-rs/issues/4075) +- Serializing timestamp from int \(json raw decoder\) [\#4069](https://github.com/apache/arrow-rs/issues/4069) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support casting to/from Interval and Duration [\#3998](https://github.com/apache/arrow-rs/issues/3998) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Fix Docs Typos [\#4100](https://github.com/apache/arrow-rs/pull/4100) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([rnarkk](https://github.com/rnarkk)) +- Update tonic-build requirement from =0.9.1 to =0.9.2 [\#4099](https://github.com/apache/arrow-rs/pull/4099) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Increase minimum chrono version to 0.4.24 [\#4093](https://github.com/apache/arrow-rs/pull/4093) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Simplify reference to GitHub issues [\#4092](https://github.com/apache/arrow-rs/pull/4092) ([bkmgit](https://github.com/bkmgit)) +- \[Minor\]: Add `Hash` trait to SortOptions. [\#4089](https://github.com/apache/arrow-rs/pull/4089) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mustafasrepo](https://github.com/mustafasrepo)) +- Include byte offsets in parquet-layout [\#4086](https://github.com/apache/arrow-rs/pull/4086) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- feat: Support dyn\_compare\_scalar for Decimal256 [\#4084](https://github.com/apache/arrow-rs/pull/4084) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Add ByteArray constructors \(\#3879\) [\#4081](https://github.com/apache/arrow-rs/pull/4081) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update prost-build requirement from =0.11.8 to =0.11.9 [\#4080](https://github.com/apache/arrow-rs/pull/4080) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Improve JSON decoder errors \(\#4076\) [\#4079](https://github.com/apache/arrow-rs/pull/4079) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix Timestamp Numeric Truncation in JSON Reader [\#4074](https://github.com/apache/arrow-rs/pull/4074) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Serialize numeric to tape \(\#4069\) [\#4073](https://github.com/apache/arrow-rs/pull/4073) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: Prevent UnionArray with Repeated Type IDs [\#4070](https://github.com/apache/arrow-rs/pull/4070) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Add PrimitiveArray::try\_new \(\#3879\) [\#4067](https://github.com/apache/arrow-rs/pull/4067) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add ListArray Constructors \(\#3879\) [\#4065](https://github.com/apache/arrow-rs/pull/4065) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Shutdown parquet async writer [\#4059](https://github.com/apache/arrow-rs/pull/4059) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([kindly](https://github.com/kindly)) +- feat: additional data type groups [\#4057](https://github.com/apache/arrow-rs/pull/4057) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Fix precision loss in Raw JSON decoder \(\#4049\) [\#4051](https://github.com/apache/arrow-rs/pull/4051) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use lexical\_core in CSV and JSON parser \(~25% faster\) [\#4050](https://github.com/apache/arrow-rs/pull/4050) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add offsets accessors to variable length arrays \(\#3879\) [\#4048](https://github.com/apache/arrow-rs/pull/4048) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Document Async decoder usage \(\#4043\) \(\#78\) [\#4046](https://github.com/apache/arrow-rs/pull/4046) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix sparse union array equality \(\#4044\) [\#4045](https://github.com/apache/arrow-rs/pull/4045) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: DataType::contains support nested type [\#4042](https://github.com/apache/arrow-rs/pull/4042) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- feat: Support Timestamp +/- Interval types [\#4038](https://github.com/apache/arrow-rs/pull/4038) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Fix object\_store CI [\#4037](https://github.com/apache/arrow-rs/pull/4037) ([tustvold](https://github.com/tustvold)) +- feat: cast from/to interval and duration [\#4020](https://github.com/apache/arrow-rs/pull/4020) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) + +## [37.0.0](https://github.com/apache/arrow-rs/tree/37.0.0) (2023-04-07) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/36.0.0...37.0.0) + +**Breaking changes:** + +- Fix timestamp handling in cast kernel \(\#1936\) \(\#4033\) [\#4034](https://github.com/apache/arrow-rs/pull/4034) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update tonic 0.9.1 [\#4011](https://github.com/apache/arrow-rs/pull/4011) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Use FieldRef in DataType \(\#3955\) [\#3983](https://github.com/apache/arrow-rs/pull/3983) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Store Timezone as Arc\ [\#3976](https://github.com/apache/arrow-rs/pull/3976) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Panic instead of discarding nulls converting StructArray to RecordBatch - \(\#3951\) [\#3953](https://github.com/apache/arrow-rs/pull/3953) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix\(flight\_sql\): PreparedStatement has no token for auth. [\#3948](https://github.com/apache/arrow-rs/pull/3948) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([youngsofun](https://github.com/youngsofun)) +- Add Strongly Typed Array Slice \(\#3929\) [\#3930](https://github.com/apache/arrow-rs/pull/3930) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add Zero-Copy Conversion between Vec and MutableBuffer [\#3920](https://github.com/apache/arrow-rs/pull/3920) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Support Decimals cast to Utf8/LargeUtf [\#3991](https://github.com/apache/arrow-rs/issues/3991) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support Date32/Date64 minus Interval [\#3962](https://github.com/apache/arrow-rs/issues/3962) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Reduce Cloning of Field [\#3955](https://github.com/apache/arrow-rs/issues/3955) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Support Deserializing Serde DataTypes to Arrow [\#3949](https://github.com/apache/arrow-rs/issues/3949) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add multiply\_fixed\_point [\#3946](https://github.com/apache/arrow-rs/issues/3946) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Strongly Typed Array Slicing [\#3929](https://github.com/apache/arrow-rs/issues/3929) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Make it easier to match FlightSQL messages [\#3874](https://github.com/apache/arrow-rs/issues/3874) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Support Casting Between Binary / LargeBinary and FixedSizeBinary [\#3826](https://github.com/apache/arrow-rs/issues/3826) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Incorrect Overflow Casting String to Timestamp [\#4033](https://github.com/apache/arrow-rs/issues/4033) +- f16::ZERO and f16::ONE are mixed up [\#4016](https://github.com/apache/arrow-rs/issues/4016) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Handle overflow precision when casting from integer to decimal [\#3995](https://github.com/apache/arrow-rs/issues/3995) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- PrimitiveDictionaryBuilder.finish should use actual value type [\#3971](https://github.com/apache/arrow-rs/issues/3971) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- RecordBatch From StructArray Silently Discards Nulls [\#3952](https://github.com/apache/arrow-rs/issues/3952) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- I256 Checked Subtraction Overflows for i256::MINUS\_ONE [\#3942](https://github.com/apache/arrow-rs/issues/3942) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- I256 Checked Multiply Overflows for i256::MIN [\#3941](https://github.com/apache/arrow-rs/issues/3941) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Remove non-existent `js` feature from README [\#4000](https://github.com/apache/arrow-rs/issues/4000) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support take on MapArray [\#3875](https://github.com/apache/arrow-rs/issues/3875) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Prep for 37.0.0 [\#4031](https://github.com/apache/arrow-rs/pull/4031) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([iajoiner](https://github.com/iajoiner)) +- Add RecordBatch::with\_schema [\#4028](https://github.com/apache/arrow-rs/pull/4028) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Only require compatible batch schema in ArrowWriter [\#4027](https://github.com/apache/arrow-rs/pull/4027) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add Fields::contains [\#4026](https://github.com/apache/arrow-rs/pull/4026) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Minor: add methods "is\_positive" and "signum" to i256 [\#4024](https://github.com/apache/arrow-rs/pull/4024) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Deprecate Array::data \(\#3880\) [\#4019](https://github.com/apache/arrow-rs/pull/4019) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: add tests for ArrowNativeTypeOp [\#4018](https://github.com/apache/arrow-rs/pull/4018) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- fix: f16::ZERO and f16::ONE are mixed up [\#4017](https://github.com/apache/arrow-rs/pull/4017) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Minor: Float16Tensor [\#4013](https://github.com/apache/arrow-rs/pull/4013) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Add FlightSQL module docs and links to `arrow-flight` crates [\#4012](https://github.com/apache/arrow-rs/pull/4012) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Update proc-macro2 requirement from =1.0.54 to =1.0.56 [\#4008](https://github.com/apache/arrow-rs/pull/4008) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Cleanup Primitive take [\#4006](https://github.com/apache/arrow-rs/pull/4006) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Deprecate combine\_option\_bitmap [\#4005](https://github.com/apache/arrow-rs/pull/4005) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Minor: add tests for BooleanBuffer [\#4004](https://github.com/apache/arrow-rs/pull/4004) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- feat: support to read/write customized metadata in ipc files [\#4003](https://github.com/apache/arrow-rs/pull/4003) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([framlog](https://github.com/framlog)) +- Cleanup more uses of Array::data \(\#3880\) [\#4002](https://github.com/apache/arrow-rs/pull/4002) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Remove js feature from README [\#4001](https://github.com/apache/arrow-rs/pull/4001) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([akazukin5151](https://github.com/akazukin5151)) +- feat: add the implementation BitXor to BooleanBuffer [\#3997](https://github.com/apache/arrow-rs/pull/3997) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Handle precision overflow when casting from integer to decimal [\#3996](https://github.com/apache/arrow-rs/pull/3996) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support CAST from Decimal datatype to String [\#3994](https://github.com/apache/arrow-rs/pull/3994) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) +- Add Field Constructors for Complex Fields [\#3992](https://github.com/apache/arrow-rs/pull/3992) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- fix: remove unused type parameters. [\#3986](https://github.com/apache/arrow-rs/pull/3986) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([youngsofun](https://github.com/youngsofun)) +- Add UnionFields \(\#3955\) [\#3981](https://github.com/apache/arrow-rs/pull/3981) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Cleanup Fields Serde [\#3980](https://github.com/apache/arrow-rs/pull/3980) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support Rust structures --\> `RecordBatch` by adding `Serde` support to `RawDecoder` \(\#3949\) [\#3979](https://github.com/apache/arrow-rs/pull/3979) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Convert string\_to\_timestamp\_nanos to doctest [\#3978](https://github.com/apache/arrow-rs/pull/3978) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix documentation of string\_to\_timestamp\_nanos [\#3977](https://github.com/apache/arrow-rs/pull/3977) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([byteink](https://github.com/byteink)) +- add Date32/Date64 support to subtract\_dyn [\#3974](https://github.com/apache/arrow-rs/pull/3974) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([SinanGncgl](https://github.com/SinanGncgl)) +- PrimitiveDictionaryBuilder.finish should use actual value type [\#3972](https://github.com/apache/arrow-rs/pull/3972) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Update proc-macro2 requirement from =1.0.53 to =1.0.54 [\#3968](https://github.com/apache/arrow-rs/pull/3968) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Async writer tweaks [\#3967](https://github.com/apache/arrow-rs/pull/3967) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Fix reading ipc files with unordered projections [\#3966](https://github.com/apache/arrow-rs/pull/3966) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([framlog](https://github.com/framlog)) +- Add Fields abstraction \(\#3955\) [\#3965](https://github.com/apache/arrow-rs/pull/3965) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- feat: cast between `Binary`/`LargeBinary` and `FixedSizeBinary` [\#3961](https://github.com/apache/arrow-rs/pull/3961) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- feat: support async writer \(\#1269\) [\#3957](https://github.com/apache/arrow-rs/pull/3957) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([ShiKaiWi](https://github.com/ShiKaiWi)) +- Add ListBuilder::append\_value \(\#3949\) [\#3954](https://github.com/apache/arrow-rs/pull/3954) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Improve array builder documentation \(\#3949\) [\#3951](https://github.com/apache/arrow-rs/pull/3951) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Faster i256 parsing [\#3950](https://github.com/apache/arrow-rs/pull/3950) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add multiply\_fixed\_point [\#3945](https://github.com/apache/arrow-rs/pull/3945) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- feat: enable metadata import/export through C data interface [\#3944](https://github.com/apache/arrow-rs/pull/3944) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Fix checked i256 arithmetic \(\#3942\) \(\#3941\) [\#3943](https://github.com/apache/arrow-rs/pull/3943) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Avoid memory copies in take\_list [\#3940](https://github.com/apache/arrow-rs/pull/3940) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Faster decimal parsing \(30-60%\) [\#3939](https://github.com/apache/arrow-rs/pull/3939) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([spebern](https://github.com/spebern)) +- Fix: FlightSqlClient panic when execute\_update. [\#3938](https://github.com/apache/arrow-rs/pull/3938) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([youngsofun](https://github.com/youngsofun)) +- Cleanup row count handling in JSON writer [\#3934](https://github.com/apache/arrow-rs/pull/3934) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add typed buffers to UnionArray \(\#3880\) [\#3933](https://github.com/apache/arrow-rs/pull/3933) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: add take for MapArray [\#3925](https://github.com/apache/arrow-rs/pull/3925) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Deprecate Array::data\_ref \(\#3880\) [\#3923](https://github.com/apache/arrow-rs/pull/3923) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Zero-copy conversion from Vec to PrimitiveArray [\#3917](https://github.com/apache/arrow-rs/pull/3917) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: Add Commands enum to decode prost messages to strong type [\#3887](https://github.com/apache/arrow-rs/pull/3887) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([stuartcarnie](https://github.com/stuartcarnie)) +## [36.0.0](https://github.com/apache/arrow-rs/tree/36.0.0) (2023-03-24) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/35.0.0...36.0.0) + +**Breaking changes:** + +- Use dyn Array in sort kernels [\#3931](https://github.com/apache/arrow-rs/pull/3931) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Enforce struct nullability in JSON raw reader \(\#3900\) \(\#3904\) [\#3906](https://github.com/apache/arrow-rs/pull/3906) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Return ScalarBuffer from PrimitiveArray::values \(\#3879\) [\#3896](https://github.com/apache/arrow-rs/pull/3896) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use BooleanBuffer in BooleanArray \(\#3879\) [\#3895](https://github.com/apache/arrow-rs/pull/3895) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Seal ArrowPrimitiveType [\#3882](https://github.com/apache/arrow-rs/pull/3882) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support compression levels [\#3847](https://github.com/apache/arrow-rs/pull/3847) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([spebern](https://github.com/spebern)) + +**Implemented enhancements:** + +- Improve speed of parsing string to Times [\#3919](https://github.com/apache/arrow-rs/issues/3919) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- feat: add comparison/sort support for Float16 [\#3914](https://github.com/apache/arrow-rs/issues/3914) +- Pinned version in arrow-flight's build-dependencies are causing conflicts [\#3876](https://github.com/apache/arrow-rs/issues/3876) +- Add compression options \(levels\) [\#3844](https://github.com/apache/arrow-rs/issues/3844) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use Unsigned Integer for Fixed Size DataType [\#3815](https://github.com/apache/arrow-rs/issues/3815) +- Common trait for RecordBatch and StructArray [\#3764](https://github.com/apache/arrow-rs/issues/3764) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Allow precision loss on multiplying decimal arrays [\#3689](https://github.com/apache/arrow-rs/issues/3689) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Raw JSON Reader Allows Non-Nullable Struct Children to Contain Nulls [\#3904](https://github.com/apache/arrow-rs/issues/3904) +- Nullable field with nested not nullable map in json [\#3900](https://github.com/apache/arrow-rs/issues/3900) +- parquet\_derive doesn't support Vec\ [\#3864](https://github.com/apache/arrow-rs/issues/3864) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[REGRESSION\] Parsing timestamps with lower case time separator [\#3863](https://github.com/apache/arrow-rs/issues/3863) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[REGRESSION\] Parsing timestamps with leap seconds [\#3861](https://github.com/apache/arrow-rs/issues/3861) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[REGRESSION\] Parsing timestamps with fractional seconds / microseconds / milliseconds / nanoseconds [\#3859](https://github.com/apache/arrow-rs/issues/3859) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- CSV Reader Doesn't set Timezone [\#3841](https://github.com/apache/arrow-rs/issues/3841) +- PyArrowConvert Leaks Memory [\#3683](https://github.com/apache/arrow-rs/issues/3683) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Derive RunArray Clone [\#3932](https://github.com/apache/arrow-rs/pull/3932) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Move protoc generation to binary crate, unpin prost/tonic build \(\#3876\) [\#3927](https://github.com/apache/arrow-rs/pull/3927) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Fix JSON Temporal Encoding of Multiple Batches [\#3924](https://github.com/apache/arrow-rs/pull/3924) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([doki23](https://github.com/doki23)) +- Cleanup uses of Array::data\_ref \(\#3880\) [\#3918](https://github.com/apache/arrow-rs/pull/3918) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support microsecond and nanosecond in interval parsing [\#3916](https://github.com/apache/arrow-rs/pull/3916) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- feat: add comparison/sort support for Float16 [\#3915](https://github.com/apache/arrow-rs/pull/3915) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([izveigor](https://github.com/izveigor)) +- Add AsArray trait for more ergonomic downcasting [\#3912](https://github.com/apache/arrow-rs/pull/3912) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add OffsetBuffer::new [\#3910](https://github.com/apache/arrow-rs/pull/3910) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add PrimitiveArray::new \(\#3879\) [\#3909](https://github.com/apache/arrow-rs/pull/3909) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support timezones in CSV reader \(\#3841\) [\#3908](https://github.com/apache/arrow-rs/pull/3908) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Improve ScalarBuffer debug output [\#3907](https://github.com/apache/arrow-rs/pull/3907) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update proc-macro2 requirement from =1.0.52 to =1.0.53 [\#3905](https://github.com/apache/arrow-rs/pull/3905) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Re-export parquet compression level structs [\#3903](https://github.com/apache/arrow-rs/pull/3903) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Fix parsing timestamps of exactly 32 characters [\#3902](https://github.com/apache/arrow-rs/pull/3902) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add iterators to BooleanBuffer and NullBuffer [\#3901](https://github.com/apache/arrow-rs/pull/3901) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Array equality for &dyn Array \(\#3880\) [\#3899](https://github.com/apache/arrow-rs/pull/3899) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add BooleanArray::new \(\#3879\) [\#3898](https://github.com/apache/arrow-rs/pull/3898) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Revert structured ArrayData \(\#3877\) [\#3894](https://github.com/apache/arrow-rs/pull/3894) ([tustvold](https://github.com/tustvold)) +- Fix pyarrow memory leak \(\#3683\) [\#3893](https://github.com/apache/arrow-rs/pull/3893) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Minor: add examples for `ListBuilder` and `GenericListBuilder` [\#3891](https://github.com/apache/arrow-rs/pull/3891) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Update syn requirement from 1.0 to 2.0 [\#3890](https://github.com/apache/arrow-rs/pull/3890) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Use of `mul_checked` to avoid silent overflow in interval arithmetic [\#3886](https://github.com/apache/arrow-rs/pull/3886) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Flesh out NullBuffer abstraction \(\#3880\) [\#3885](https://github.com/apache/arrow-rs/pull/3885) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Implement Bit Operations for i256 [\#3884](https://github.com/apache/arrow-rs/pull/3884) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Flatten arrow\_buffer [\#3883](https://github.com/apache/arrow-rs/pull/3883) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add Array::to\_data and Array::nulls \(\#3880\) [\#3881](https://github.com/apache/arrow-rs/pull/3881) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Added support for byte vectors and slices to parquet\_derive \(\#3864\) [\#3878](https://github.com/apache/arrow-rs/pull/3878) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([waymost](https://github.com/waymost)) +- chore: remove LevelDecoder [\#3872](https://github.com/apache/arrow-rs/pull/3872) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Weijun-H](https://github.com/Weijun-H)) +- Parse timestamps with leap seconds \(\#3861\) [\#3862](https://github.com/apache/arrow-rs/pull/3862) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Faster time parsing \(~93% faster\) [\#3860](https://github.com/apache/arrow-rs/pull/3860) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Parse timestamps with arbitrary seconds fraction [\#3858](https://github.com/apache/arrow-rs/pull/3858) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add BitIterator [\#3856](https://github.com/apache/arrow-rs/pull/3856) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Improve decimal parsing performance [\#3854](https://github.com/apache/arrow-rs/pull/3854) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([spebern](https://github.com/spebern)) +- Update proc-macro2 requirement from =1.0.51 to =1.0.52 [\#3853](https://github.com/apache/arrow-rs/pull/3853) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update bitflags requirement from 1.2.1 to 2.0.0 [\#3852](https://github.com/apache/arrow-rs/pull/3852) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add offset pushdown to parquet [\#3848](https://github.com/apache/arrow-rs/pull/3848) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add timezone support to JSON reader [\#3845](https://github.com/apache/arrow-rs/pull/3845) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Allow precision loss on multiplying decimal arrays [\#3690](https://github.com/apache/arrow-rs/pull/3690) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) + +## [35.0.0](https://github.com/apache/arrow-rs/tree/35.0.0) (2023-03-10) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/34.0.0...35.0.0) + +**Breaking changes:** + +- Add RunEndBuffer \(\#1799\) [\#3817](https://github.com/apache/arrow-rs/pull/3817) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Restrict DictionaryArray to ArrowDictionaryKeyType [\#3813](https://github.com/apache/arrow-rs/pull/3813) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- refactor: assorted `FlightSqlServiceClient` improvements [\#3788](https://github.com/apache/arrow-rs/pull/3788) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- minor: make Parquet CLI input args consistent [\#3786](https://github.com/apache/arrow-rs/pull/3786) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XinyuZeng](https://github.com/XinyuZeng)) +- Return Buffers from ArrayData::buffers instead of slice \(\#1799\) [\#3783](https://github.com/apache/arrow-rs/pull/3783) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use NullBuffer in ArrayData \(\#3775\) [\#3778](https://github.com/apache/arrow-rs/pull/3778) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Support timestamp/time and date types in json decoder [\#3834](https://github.com/apache/arrow-rs/issues/3834) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support decoding decimals in new raw json decoder [\#3819](https://github.com/apache/arrow-rs/issues/3819) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Timezone Aware Timestamp Parsing [\#3794](https://github.com/apache/arrow-rs/issues/3794) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Preallocate buffers for FixedSizeBinary array creation [\#3792](https://github.com/apache/arrow-rs/issues/3792) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Make Parquet CLI args consistent [\#3785](https://github.com/apache/arrow-rs/issues/3785) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Creates PrimitiveDictionaryBuilder from provided keys and values builders [\#3776](https://github.com/apache/arrow-rs/issues/3776) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use NullBuffer in ArrayData [\#3775](https://github.com/apache/arrow-rs/issues/3775) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support unary\_dict\_mut in arth [\#3710](https://github.com/apache/arrow-rs/issues/3710) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support cast \<\> String to interval [\#3643](https://github.com/apache/arrow-rs/issues/3643) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support Zero-Copy Conversion from Vec to/from MutableBuffer [\#3516](https://github.com/apache/arrow-rs/issues/3516) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Timestamp Unit Casts are Unchecked [\#3833](https://github.com/apache/arrow-rs/issues/3833) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- regexp\_match skips first match when returning match [\#3803](https://github.com/apache/arrow-rs/issues/3803) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Cast to timestamp with time zone returns timestamp [\#3800](https://github.com/apache/arrow-rs/issues/3800) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Schema-level metadata is not encoded in Flight responses [\#3779](https://github.com/apache/arrow-rs/issues/3779) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Closed issues:** + +- FlightSQL CLI client: simple test [\#3814](https://github.com/apache/arrow-rs/issues/3814) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Merged pull requests:** + +- refactor: timestamp overflow check [\#3840](https://github.com/apache/arrow-rs/pull/3840) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Prep for 35.0.0 [\#3836](https://github.com/apache/arrow-rs/pull/3836) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([iajoiner](https://github.com/iajoiner)) +- Support timestamp/time and date json decoding [\#3835](https://github.com/apache/arrow-rs/pull/3835) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([spebern](https://github.com/spebern)) +- Make dictionary preservation optional in row encoding [\#3831](https://github.com/apache/arrow-rs/pull/3831) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Move prettyprint to arrow-cast [\#3828](https://github.com/apache/arrow-rs/pull/3828) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Support decoding decimals in raw decoder [\#3820](https://github.com/apache/arrow-rs/pull/3820) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([spebern](https://github.com/spebern)) +- Add ArrayDataLayout, port validation \(\#1799\) [\#3818](https://github.com/apache/arrow-rs/pull/3818) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- test: add test for FlightSQL CLI client [\#3816](https://github.com/apache/arrow-rs/pull/3816) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- Add regexp\_match docs [\#3812](https://github.com/apache/arrow-rs/pull/3812) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- fix: Ensure Flight schema includes parent metadata [\#3811](https://github.com/apache/arrow-rs/pull/3811) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([stuartcarnie](https://github.com/stuartcarnie)) +- fix: regexp\_match skips first match [\#3807](https://github.com/apache/arrow-rs/pull/3807) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- fix: change uft8 to timestamp with timezone [\#3806](https://github.com/apache/arrow-rs/pull/3806) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Support reading decimal arrays from json [\#3805](https://github.com/apache/arrow-rs/pull/3805) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([spebern](https://github.com/spebern)) +- Add unary\_dict\_mut [\#3804](https://github.com/apache/arrow-rs/pull/3804) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Faster timestamp parsing \(~70-90% faster\) [\#3801](https://github.com/apache/arrow-rs/pull/3801) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add concat\_elements\_bytes [\#3798](https://github.com/apache/arrow-rs/pull/3798) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Timezone aware timestamp parsing \(\#3794\) [\#3795](https://github.com/apache/arrow-rs/pull/3795) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Preallocate buffers for FixedSizeBinary array creation [\#3793](https://github.com/apache/arrow-rs/pull/3793) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([maxburke](https://github.com/maxburke)) +- feat: simple flight sql CLI client [\#3789](https://github.com/apache/arrow-rs/pull/3789) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- Creates PrimitiveDictionaryBuilder from provided keys and values builders [\#3777](https://github.com/apache/arrow-rs/pull/3777) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- ArrayData Enumeration for Remaining Layouts [\#3769](https://github.com/apache/arrow-rs/pull/3769) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update prost-build requirement from =0.11.7 to =0.11.8 [\#3767](https://github.com/apache/arrow-rs/pull/3767) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Implement concat\_elements\_dyn kernel [\#3763](https://github.com/apache/arrow-rs/pull/3763) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Support for casting `Utf8` and `LargeUtf8` --\> `Interval` [\#3762](https://github.com/apache/arrow-rs/pull/3762) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([doki23](https://github.com/doki23)) +- into\_inner\(\) for CSV Writer [\#3759](https://github.com/apache/arrow-rs/pull/3759) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Zero-copy Vec conversion \(\#3516\) \(\#1176\) [\#3756](https://github.com/apache/arrow-rs/pull/3756) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- ArrayData Enumeration for Primitive, Binary and UTF8 [\#3749](https://github.com/apache/arrow-rs/pull/3749) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add `into_primitive_dict_builder` to `DictionaryArray` [\#3715](https://github.com/apache/arrow-rs/pull/3715) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) + +## [34.0.0](https://github.com/apache/arrow-rs/tree/34.0.0) (2023-02-24) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/33.0.0...34.0.0) + +**Breaking changes:** + +- Infer 2020-03-19 00:00:00 as timestamp not Date64 in CSV \(\#3744\) [\#3746](https://github.com/apache/arrow-rs/pull/3746) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Implement fallible streams for `FlightClient::do_put` [\#3464](https://github.com/apache/arrow-rs/pull/3464) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) + +**Implemented enhancements:** + +- Support casting string to timestamp with microsecond resolution [\#3751](https://github.com/apache/arrow-rs/issues/3751) +- Add datatime/interval/duration into comparison kernels [\#3729](https://github.com/apache/arrow-rs/issues/3729) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- ! \(not\) operator overload for SortOptions [\#3726](https://github.com/apache/arrow-rs/issues/3726) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet: convert Bytes to ByteArray directly [\#3719](https://github.com/apache/arrow-rs/issues/3719) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Implement simple RecordBatchReader [\#3704](https://github.com/apache/arrow-rs/issues/3704) +- Is possible to implement GenericListArray::from\_iter ? [\#3702](https://github.com/apache/arrow-rs/issues/3702) +- `take_run` improvements [\#3701](https://github.com/apache/arrow-rs/issues/3701) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `as_mut_any` in Array trait [\#3655](https://github.com/apache/arrow-rs/issues/3655) +- `Array` --\> `Display` formatter that supports more options and is configurable [\#3638](https://github.com/apache/arrow-rs/issues/3638) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- arrow-csv: support decimal256 [\#3474](https://github.com/apache/arrow-rs/issues/3474) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- CSV reader infers Date64 type for fields like "2020-03-19 00:00:00" that it can't parse to Date64 [\#3744](https://github.com/apache/arrow-rs/issues/3744) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Update to 34.0.0 and update changelog [\#3757](https://github.com/apache/arrow-rs/pull/3757) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([iajoiner](https://github.com/iajoiner)) +- Update MIRI for split crates \(\#2594\) [\#3754](https://github.com/apache/arrow-rs/pull/3754) ([tustvold](https://github.com/tustvold)) +- Update prost-build requirement from =0.11.6 to =0.11.7 [\#3753](https://github.com/apache/arrow-rs/pull/3753) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Enable casting of string to timestamp with microsecond resolution [\#3752](https://github.com/apache/arrow-rs/pull/3752) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gruuya](https://github.com/gruuya)) +- Use Typed Buffers in Arrays \(\#1811\) \(\#1176\) [\#3743](https://github.com/apache/arrow-rs/pull/3743) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Cleanup arithmetic kernel type constraints [\#3739](https://github.com/apache/arrow-rs/pull/3739) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Make dictionary kernels optional for comparison benchmark [\#3738](https://github.com/apache/arrow-rs/pull/3738) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support String Coercion in Raw JSON Reader [\#3736](https://github.com/apache/arrow-rs/pull/3736) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rguerreiromsft](https://github.com/rguerreiromsft)) +- replace for loop by try\_for\_each [\#3734](https://github.com/apache/arrow-rs/pull/3734) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([suxiaogang223](https://github.com/suxiaogang223)) +- feat: implement generic record batch reader [\#3733](https://github.com/apache/arrow-rs/pull/3733) ([wjones127](https://github.com/wjones127)) +- \[minor\] fix doc test fail [\#3732](https://github.com/apache/arrow-rs/pull/3732) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Add datetime/interval/duration into dyn scalar comparison [\#3730](https://github.com/apache/arrow-rs/pull/3730) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Using Borrow\ on infer\_json\_schema\_from\_iterator [\#3728](https://github.com/apache/arrow-rs/pull/3728) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rguerreiromsft](https://github.com/rguerreiromsft)) +- Not operator overload for SortOptions [\#3727](https://github.com/apache/arrow-rs/pull/3727) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([berkaysynnada](https://github.com/berkaysynnada)) +- fix: encoding batch with no columns [\#3724](https://github.com/apache/arrow-rs/pull/3724) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([wangrunji0408](https://github.com/wangrunji0408)) +- feat: impl `Ord`/`PartialOrd` for `SortOptions` [\#3723](https://github.com/apache/arrow-rs/pull/3723) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([crepererum](https://github.com/crepererum)) +- Add From\ for ByteArray [\#3720](https://github.com/apache/arrow-rs/pull/3720) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Deprecate old JSON reader \(\#3610\) [\#3718](https://github.com/apache/arrow-rs/pull/3718) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add pretty format with options [\#3717](https://github.com/apache/arrow-rs/pull/3717) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Remove unreachable decimal take [\#3716](https://github.com/apache/arrow-rs/pull/3716) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Feat: arrow csv decimal256 [\#3711](https://github.com/apache/arrow-rs/pull/3711) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([suxiaogang223](https://github.com/suxiaogang223)) +- perf: `take_run` improvements [\#3705](https://github.com/apache/arrow-rs/pull/3705) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Add raw MapArrayReader [\#3703](https://github.com/apache/arrow-rs/pull/3703) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: Sort kernel for `RunArray` [\#3695](https://github.com/apache/arrow-rs/pull/3695) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- perf: Remove sorting to yield sorted\_rank [\#3693](https://github.com/apache/arrow-rs/pull/3693) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- fix: Handle sliced array in run array iterator [\#3681](https://github.com/apache/arrow-rs/pull/3681) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) + +## [33.0.0](https://github.com/apache/arrow-rs/tree/33.0.0) (2023-02-10) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/32.0.0...33.0.0) + +**Breaking changes:** + +- Use ArrayFormatter in Cast Kernel [\#3668](https://github.com/apache/arrow-rs/pull/3668) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use dyn Array in cast kernels [\#3667](https://github.com/apache/arrow-rs/pull/3667) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Return references from FixedSizeListArray and MapArray [\#3652](https://github.com/apache/arrow-rs/pull/3652) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Lazy array display \(\#3638\) [\#3647](https://github.com/apache/arrow-rs/pull/3647) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use array\_value\_to\_string in arrow-csv [\#3514](https://github.com/apache/arrow-rs/pull/3514) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([JayjeetAtGithub](https://github.com/JayjeetAtGithub)) + +**Implemented enhancements:** + +- Support UTF8 cast to Timestamp with timezone [\#3664](https://github.com/apache/arrow-rs/issues/3664) +- Add modulus\_dyn and modulus\_scalar\_dyn [\#3648](https://github.com/apache/arrow-rs/issues/3648) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- A trait for append\_value and append\_null on ArrayBuilders [\#3644](https://github.com/apache/arrow-rs/issues/3644) +- Improve error message "batches\[0\] schema is different with argument schema" [\#3628](https://github.com/apache/arrow-rs/issues/3628) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Specified version of helper function to cast binary to string [\#3623](https://github.com/apache/arrow-rs/issues/3623) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Casting generic binary to generic string [\#3606](https://github.com/apache/arrow-rs/issues/3606) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use `array_value_to_string` in `arrow-csv` [\#3483](https://github.com/apache/arrow-rs/issues/3483) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- ArrowArray::try\_from\_raw Misleading Signature [\#3684](https://github.com/apache/arrow-rs/issues/3684) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- PyArrowConvert Leaks Memory [\#3683](https://github.com/apache/arrow-rs/issues/3683) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Arrow-csv reader cannot produce RecordBatch even if the bytes are necessary [\#3674](https://github.com/apache/arrow-rs/issues/3674) +- FFI Fails to Account For Offsets [\#3671](https://github.com/apache/arrow-rs/issues/3671) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Regression in CSV reader error handling [\#3656](https://github.com/apache/arrow-rs/issues/3656) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- UnionArray Child and Value Fail to Account for non-contiguous Type IDs [\#3653](https://github.com/apache/arrow-rs/issues/3653) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Panic when accessing RecordBatch from pyarrow [\#3646](https://github.com/apache/arrow-rs/issues/3646) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Multiplication for decimals is incorrect [\#3645](https://github.com/apache/arrow-rs/issues/3645) +- Inconsistent output between pretty print and CSV writer for Arrow [\#3513](https://github.com/apache/arrow-rs/issues/3513) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Release 33.0.0 of arrow/arrow-flight/parquet/parquet-derive \(next release after 32.0.0\) [\#3682](https://github.com/apache/arrow-rs/issues/3682) +- Release `32.0.0` of `arrow`/`arrow-flight`/`parquet`/`parquet-derive` \(next release after `31.0.0`\) [\#3584](https://github.com/apache/arrow-rs/issues/3584) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Merged pull requests:** + +- Move FFI to sub-crates [\#3687](https://github.com/apache/arrow-rs/pull/3687) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update to 33.0.0 and update changelog [\#3686](https://github.com/apache/arrow-rs/pull/3686) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([iajoiner](https://github.com/iajoiner)) +- Cleanup FFI interface \(\#3684\) \(\#3683\) [\#3685](https://github.com/apache/arrow-rs/pull/3685) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- fix: take\_run benchmark parameter [\#3679](https://github.com/apache/arrow-rs/pull/3679) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Minor: Add some examples to Date\*Array and Time\*Array [\#3678](https://github.com/apache/arrow-rs/pull/3678) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add CSV Decoder::capacity \(\#3674\) [\#3677](https://github.com/apache/arrow-rs/pull/3677) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add ArrayData::new\_null and DataType::primitive\_width [\#3676](https://github.com/apache/arrow-rs/pull/3676) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix FFI which fails to account for offsets [\#3675](https://github.com/apache/arrow-rs/pull/3675) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support UTF8 cast to Timestamp with timezone [\#3673](https://github.com/apache/arrow-rs/pull/3673) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) +- Fix Date64Array docs [\#3670](https://github.com/apache/arrow-rs/pull/3670) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update proc-macro2 requirement from =1.0.50 to =1.0.51 [\#3669](https://github.com/apache/arrow-rs/pull/3669) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add timezone accessor for Timestamp\*Array [\#3666](https://github.com/apache/arrow-rs/pull/3666) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Faster timezone cast [\#3665](https://github.com/apache/arrow-rs/pull/3665) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat + fix: IPC support for run encoded array. [\#3662](https://github.com/apache/arrow-rs/pull/3662) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Implement std::fmt::Write for StringBuilder \(\#3638\) [\#3659](https://github.com/apache/arrow-rs/pull/3659) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Include line and field number in CSV UTF-8 error \(\#3656\) [\#3657](https://github.com/apache/arrow-rs/pull/3657) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Handle non-contiguous type\_ids in UnionArray \(\#3653\) [\#3654](https://github.com/apache/arrow-rs/pull/3654) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add modulus\_dyn and modulus\_scalar\_dyn [\#3649](https://github.com/apache/arrow-rs/pull/3649) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Improve error message with detailed schema [\#3637](https://github.com/apache/arrow-rs/pull/3637) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Veeupup](https://github.com/Veeupup)) +- Add limit to ArrowReaderBuilder to push limit down to parquet reader [\#3633](https://github.com/apache/arrow-rs/pull/3633) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) +- chore: delete wrong comment and refactor set\_metadata in `Field` [\#3630](https://github.com/apache/arrow-rs/pull/3630) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([chunshao90](https://github.com/chunshao90)) +- Fix typo in comment [\#3627](https://github.com/apache/arrow-rs/pull/3627) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([kjschiroo](https://github.com/kjschiroo)) +- Minor: Update doc strings about Page Index / Column Index [\#3625](https://github.com/apache/arrow-rs/pull/3625) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Specified version of helper function to cast binary to string [\#3624](https://github.com/apache/arrow-rs/pull/3624) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- feat: take kernel for RunArray [\#3622](https://github.com/apache/arrow-rs/pull/3622) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Remove BitSliceIterator specialization from try\_for\_each\_valid\_idx [\#3621](https://github.com/apache/arrow-rs/pull/3621) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Reduce PrimitiveArray::try\_unary codegen [\#3619](https://github.com/apache/arrow-rs/pull/3619) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Reduce Dictionary Builder Codegen [\#3616](https://github.com/apache/arrow-rs/pull/3616) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Minor: Add test for dictionary encoding of batches [\#3608](https://github.com/apache/arrow-rs/pull/3608) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Casting generic binary to generic string [\#3607](https://github.com/apache/arrow-rs/pull/3607) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add ArrayAccessor, Iterator, Extend and benchmarks for RunArray [\#3603](https://github.com/apache/arrow-rs/pull/3603) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) + +## [32.0.0](https://github.com/apache/arrow-rs/tree/32.0.0) (2023-01-27) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/31.0.0...32.0.0) + +**Breaking changes:** + +- Allow `StringArray` construction with `Vec>` [\#3602](https://github.com/apache/arrow-rs/pull/3602) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([sinistersnare](https://github.com/sinistersnare)) +- Use native types in PageIndex \(\#3575\) [\#3578](https://github.com/apache/arrow-rs/pull/3578) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add external variant to ParquetError \(\#3285\) [\#3574](https://github.com/apache/arrow-rs/pull/3574) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Return reference from ListArray::values [\#3561](https://github.com/apache/arrow-rs/pull/3561) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: Add `RunEndEncodedArray` [\#3553](https://github.com/apache/arrow-rs/pull/3553) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) + +**Implemented enhancements:** + +- There should be a `From>>` impl for `GenericStringArray` [\#3599](https://github.com/apache/arrow-rs/issues/3599) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- FlightDataEncoder Optionally send Schema even when no record batches [\#3591](https://github.com/apache/arrow-rs/issues/3591) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Use Native Types in PageIndex [\#3575](https://github.com/apache/arrow-rs/issues/3575) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Packing array into dictionary of generic byte array [\#3571](https://github.com/apache/arrow-rs/issues/3571) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement `Error::Source` for ArrowError and FlightError [\#3566](https://github.com/apache/arrow-rs/issues/3566) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- \[FlightSQL\] Allow access to underlying FlightClient [\#3551](https://github.com/apache/arrow-rs/issues/3551) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Arrow CSV writer should not fail when cannot cast the value [\#3547](https://github.com/apache/arrow-rs/issues/3547) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Write Deprecated Min Max Statistics When ColumnOrder Signed [\#3526](https://github.com/apache/arrow-rs/issues/3526) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Improve Performance of JSON Reader [\#3441](https://github.com/apache/arrow-rs/issues/3441) +- Support footer kv metadata for IPC file [\#3432](https://github.com/apache/arrow-rs/issues/3432) +- Add `External` variant to ParquetError [\#3285](https://github.com/apache/arrow-rs/issues/3285) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Fixed bugs:** + +- Nullif of NULL Predicate is not NULL [\#3589](https://github.com/apache/arrow-rs/issues/3589) +- BooleanBufferBuilder Fails to Clear Set Bits On Truncate [\#3587](https://github.com/apache/arrow-rs/issues/3587) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `nullif` incorrectly calculates `null_count`, sometimes panics with subtraction overflow error [\#3579](https://github.com/apache/arrow-rs/issues/3579) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Meet warning when use pyarrow [\#3543](https://github.com/apache/arrow-rs/issues/3543) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Incorrect row group total\_byte\_size written to parquet file [\#3530](https://github.com/apache/arrow-rs/issues/3530) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Overflow when casting timestamps prior to the epoch [\#3512](https://github.com/apache/arrow-rs/issues/3512) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Panic on Key Overflow in Dictionary Builders [\#3562](https://github.com/apache/arrow-rs/issues/3562) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Bumping version gives compilation error \(arrow-array\) [\#3525](https://github.com/apache/arrow-rs/issues/3525) + +**Merged pull requests:** + +- Add Push-Based CSV Decoder [\#3604](https://github.com/apache/arrow-rs/pull/3604) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update to flatbuffers 23.1.21 [\#3597](https://github.com/apache/arrow-rs/pull/3597) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Faster BooleanBufferBuilder::append\_n for true values [\#3596](https://github.com/apache/arrow-rs/pull/3596) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support sending schemas for empty streams [\#3594](https://github.com/apache/arrow-rs/pull/3594) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Faster ListArray to StringArray conversion [\#3593](https://github.com/apache/arrow-rs/pull/3593) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add conversion from StringArray to BinaryArray [\#3592](https://github.com/apache/arrow-rs/pull/3592) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix nullif null count \(\#3579\) [\#3590](https://github.com/apache/arrow-rs/pull/3590) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Clear bits in BooleanBufferBuilder \(\#3587\) [\#3588](https://github.com/apache/arrow-rs/pull/3588) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Iterate all dictionary key types in cast test [\#3585](https://github.com/apache/arrow-rs/pull/3585) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Propagate EOF Error from AsyncRead [\#3576](https://github.com/apache/arrow-rs/pull/3576) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Sach1nAgarwal](https://github.com/Sach1nAgarwal)) +- Show row\_counts also for \(FixedLen\)ByteArray [\#3573](https://github.com/apache/arrow-rs/pull/3573) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([bmmeijers](https://github.com/bmmeijers)) +- Packing array into dictionary of generic byte array [\#3572](https://github.com/apache/arrow-rs/pull/3572) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Remove unwrap on datetime cast for CSV writer [\#3570](https://github.com/apache/arrow-rs/pull/3570) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) +- Implement `std::error::Error::source` for `ArrowError` and `FlightError` [\#3567](https://github.com/apache/arrow-rs/pull/3567) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Improve GenericBytesBuilder offset overflow panic message \(\#139\) [\#3564](https://github.com/apache/arrow-rs/pull/3564) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Implement Extend for ArrayBuilder \(\#1841\) [\#3563](https://github.com/apache/arrow-rs/pull/3563) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update pyarrow method call with kwargs [\#3560](https://github.com/apache/arrow-rs/pull/3560) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Frankonly](https://github.com/Frankonly)) +- Update pyo3 requirement from 0.17 to 0.18 [\#3557](https://github.com/apache/arrow-rs/pull/3557) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Expose Inner FlightServiceClient on FlightSqlServiceClient \(\#3551\) [\#3556](https://github.com/apache/arrow-rs/pull/3556) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Fix final page row count in parquet-index binary [\#3554](https://github.com/apache/arrow-rs/pull/3554) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Parquet Avoid Reading 8 Byte Footer Twice from AsyncRead [\#3550](https://github.com/apache/arrow-rs/pull/3550) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Sach1nAgarwal](https://github.com/Sach1nAgarwal)) +- Improve concat kernel capacity estimation [\#3546](https://github.com/apache/arrow-rs/pull/3546) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update proc-macro2 requirement from =1.0.49 to =1.0.50 [\#3545](https://github.com/apache/arrow-rs/pull/3545) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update pyarrow method call to avoid warning [\#3544](https://github.com/apache/arrow-rs/pull/3544) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Frankonly](https://github.com/Frankonly)) +- Enable casting between Utf8/LargeUtf8 and Binary/LargeBinary [\#3542](https://github.com/apache/arrow-rs/pull/3542) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Use GHA concurrency groups \(\#3495\) [\#3538](https://github.com/apache/arrow-rs/pull/3538) ([tustvold](https://github.com/tustvold)) +- set sum of uncompressed column size as row group size for parquet files [\#3531](https://github.com/apache/arrow-rs/pull/3531) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([sidred](https://github.com/sidred)) +- Minor: Add documentation about memory use for ArrayData [\#3529](https://github.com/apache/arrow-rs/pull/3529) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Upgrade to clap 4.1 + fix test [\#3528](https://github.com/apache/arrow-rs/pull/3528) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Write backwards compatible row group statistics \(\#3526\) [\#3527](https://github.com/apache/arrow-rs/pull/3527) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- No panic on timestamp buffer overflow [\#3519](https://github.com/apache/arrow-rs/pull/3519) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) +- Support casting from binary to dictionary of binary [\#3482](https://github.com/apache/arrow-rs/pull/3482) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add Raw JSON Reader \(~2.5x faster\) [\#3479](https://github.com/apache/arrow-rs/pull/3479) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +## [31.0.0](https://github.com/apache/arrow-rs/tree/31.0.0) (2023-01-13) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/30.0.1...31.0.0) + +**Breaking changes:** + +- support RFC3339 style timestamps in `arrow-json` [\#3449](https://github.com/apache/arrow-rs/pull/3449) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([JayjeetAtGithub](https://github.com/JayjeetAtGithub)) +- Improve arrow flight batch splitting and naming [\#3444](https://github.com/apache/arrow-rs/pull/3444) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Parquet record API: timestamp as signed integer [\#3437](https://github.com/apache/arrow-rs/pull/3437) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([ByteBaker](https://github.com/ByteBaker)) +- Support decimal int32/64 for writer [\#3431](https://github.com/apache/arrow-rs/pull/3431) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([liukun4515](https://github.com/liukun4515)) + +**Implemented enhancements:** + +- Support casting Date32 to timestamp [\#3504](https://github.com/apache/arrow-rs/issues/3504) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support casting strings like `'2001-01-01'` to timestamp [\#3492](https://github.com/apache/arrow-rs/issues/3492) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- CLI to "rewrite" parquet files [\#3476](https://github.com/apache/arrow-rs/issues/3476) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add more dictionary value type support to `build_compare` [\#3465](https://github.com/apache/arrow-rs/issues/3465) +- Allow `concat_batches` to take non owned RecordBatch [\#3456](https://github.com/apache/arrow-rs/issues/3456) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Release Arrow `30.0.1` \(maintenance release for `30.0.0`\) [\#3455](https://github.com/apache/arrow-rs/issues/3455) +- Add string comparisons \(starts\_with, ends\_with, and contains\) to kernel [\#3442](https://github.com/apache/arrow-rs/issues/3442) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- make\_builder Loses Timezone and Decimal Scale Information [\#3435](https://github.com/apache/arrow-rs/issues/3435) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use RFC3339 style timestamps in arrow-json [\#3416](https://github.com/apache/arrow-rs/issues/3416) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- ArrayData`get_slice_memory_size` or similar [\#3407](https://github.com/apache/arrow-rs/issues/3407) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Fixed bugs:** + +- Unable to read CSV with null boolean value [\#3521](https://github.com/apache/arrow-rs/issues/3521) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Make consistent behavior on zeros equality on floating point types [\#3509](https://github.com/apache/arrow-rs/issues/3509) +- Sliced batch w/ bool column doesn't roundtrip through IPC [\#3496](https://github.com/apache/arrow-rs/issues/3496) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- take kernel on List array introduces nulls instead of empty lists [\#3471](https://github.com/apache/arrow-rs/issues/3471) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Infinite Loop If Skipping More CSV Lines than Present [\#3469](https://github.com/apache/arrow-rs/issues/3469) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Fix reading null booleans from CSV [\#3523](https://github.com/apache/arrow-rs/pull/3523) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- minor fix: use the unified decimal type builder [\#3522](https://github.com/apache/arrow-rs/pull/3522) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([liukun4515](https://github.com/liukun4515)) +- Update version to `31.0.0` and add changelog [\#3518](https://github.com/apache/arrow-rs/pull/3518) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([iajoiner](https://github.com/iajoiner)) +- Additional nullif re-export [\#3515](https://github.com/apache/arrow-rs/pull/3515) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Make consistent behavior on zeros equality on floating point types [\#3510](https://github.com/apache/arrow-rs/pull/3510) ([viirya](https://github.com/viirya)) +- Enable cast Date32 to Timestamp [\#3508](https://github.com/apache/arrow-rs/pull/3508) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) +- Update prost-build requirement from =0.11.5 to =0.11.6 [\#3507](https://github.com/apache/arrow-rs/pull/3507) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- minor fix for the comments [\#3505](https://github.com/apache/arrow-rs/pull/3505) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Fix DataTypeLayout for LargeList [\#3503](https://github.com/apache/arrow-rs/pull/3503) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add string comparisons \(starts\_with, ends\_with, and contains\) to kernel [\#3502](https://github.com/apache/arrow-rs/pull/3502) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([snmvaughan](https://github.com/snmvaughan)) +- Add a function to get memory size of array slice [\#3501](https://github.com/apache/arrow-rs/pull/3501) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Fix IPCWriter for Sliced BooleanArray [\#3498](https://github.com/apache/arrow-rs/pull/3498) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([crepererum](https://github.com/crepererum)) +- Fix: Added support to cast string without time [\#3494](https://github.com/apache/arrow-rs/pull/3494) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gaelwjl](https://github.com/gaelwjl)) +- Fix negative interval prettyprint [\#3491](https://github.com/apache/arrow-rs/pull/3491) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Fixes a broken link in the arrow lib.rs rustdoc [\#3487](https://github.com/apache/arrow-rs/pull/3487) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([AdamGS](https://github.com/AdamGS)) +- Refactoring build\_compare for decimal and using downcast\_primitive [\#3484](https://github.com/apache/arrow-rs/pull/3484) ([viirya](https://github.com/viirya)) +- Add tests for record batch size splitting logic in FlightClient [\#3481](https://github.com/apache/arrow-rs/pull/3481) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- change `concat_batches` parameter to non owned reference [\#3480](https://github.com/apache/arrow-rs/pull/3480) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- feat: add `parquet-rewrite` CLI [\#3477](https://github.com/apache/arrow-rs/pull/3477) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([crepererum](https://github.com/crepererum)) +- Preserve empty list array elements in take kernel [\#3473](https://github.com/apache/arrow-rs/pull/3473) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jonmmease](https://github.com/jonmmease)) +- Add a test for stream writer for writing sliced array [\#3472](https://github.com/apache/arrow-rs/pull/3472) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Fix CSV infinite loop and improve error messages [\#3470](https://github.com/apache/arrow-rs/pull/3470) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add more dictionary value type support to `build_compare` [\#3466](https://github.com/apache/arrow-rs/pull/3466) ([viirya](https://github.com/viirya)) +- Add tests for `FlightClient::{list_flights, list_actions, do_action, get_schema}` [\#3463](https://github.com/apache/arrow-rs/pull/3463) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Minor: add ticket links to failing ipc integration tests [\#3461](https://github.com/apache/arrow-rs/pull/3461) ([alamb](https://github.com/alamb)) +- feat: `column_name` based index access for `RecordBatch` and `StructArray` [\#3458](https://github.com/apache/arrow-rs/pull/3458) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Support Decimal256 in FFI [\#3453](https://github.com/apache/arrow-rs/pull/3453) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Remove multiversion dependency [\#3452](https://github.com/apache/arrow-rs/pull/3452) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Re-export nullif kernel [\#3451](https://github.com/apache/arrow-rs/pull/3451) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Meaningful error message for map builder with null keys [\#3450](https://github.com/apache/arrow-rs/pull/3450) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Parquet writer v2: clear buffer after page flush [\#3447](https://github.com/apache/arrow-rs/pull/3447) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([askoa](https://github.com/askoa)) +- Verify ArrayData::data\_type compatible in PrimitiveArray::from [\#3440](https://github.com/apache/arrow-rs/pull/3440) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Preserve DataType metadata in make\_builder [\#3438](https://github.com/apache/arrow-rs/pull/3438) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Consolidate arrow ipc tests and increase coverage [\#3427](https://github.com/apache/arrow-rs/pull/3427) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Generic bytes dictionary builder [\#3426](https://github.com/apache/arrow-rs/pull/3426) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Minor: Improve docs for arrow-ipc, remove clippy ignore [\#3421](https://github.com/apache/arrow-rs/pull/3421) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- refactor: convert `*like_dyn`, `*like_utf8_scalar_dyn` and `*like_dict` functions to macros [\#3411](https://github.com/apache/arrow-rs/pull/3411) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Add parquet-index binary [\#3405](https://github.com/apache/arrow-rs/pull/3405) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Complete mid-level `FlightClient` [\#3402](https://github.com/apache/arrow-rs/pull/3402) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Implement `RecordBatch` \<--\> `FlightData` encode/decode + tests [\#3391](https://github.com/apache/arrow-rs/pull/3391) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Provide `into_builder` for bytearray [\#3326](https://github.com/apache/arrow-rs/pull/3326) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) + +## [30.0.1](https://github.com/apache/arrow-rs/tree/30.0.1) (2023-01-04) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/30.0.0...30.0.1) + +**Implemented enhancements:** + +- Generic bytes dictionary builder [\#3425](https://github.com/apache/arrow-rs/issues/3425) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Derive Clone for the builders in object-store. [\#3419](https://github.com/apache/arrow-rs/issues/3419) +- Mid-level `ArrowFlight` Client [\#3371](https://github.com/apache/arrow-rs/issues/3371) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Improve performance of the CSV parser [\#3338](https://github.com/apache/arrow-rs/issues/3338) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- `nullif` kernel no longer exported [\#3454](https://github.com/apache/arrow-rs/issues/3454) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- PrimitiveArray from ArrayData Unsound For IntervalArray [\#3439](https://github.com/apache/arrow-rs/issues/3439) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- LZ4-compressed PQ files unreadable by Pandas and ClickHouse [\#3433](https://github.com/apache/arrow-rs/issues/3433) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Parquet Record API: Cannot convert date before Unix epoch to json [\#3430](https://github.com/apache/arrow-rs/issues/3430) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- parquet-fromcsv with writer version v2 does not stop [\#3408](https://github.com/apache/arrow-rs/issues/3408) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +## [30.0.0](https://github.com/apache/arrow-rs/tree/30.0.0) (2022-12-29) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/29.0.0...30.0.0) + +**Breaking changes:** + +- Infer Parquet JSON Logical and Converted Type as UTF-8 [\#3376](https://github.com/apache/arrow-rs/pull/3376) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Use custom Any instead of prost\_types [\#3360](https://github.com/apache/arrow-rs/pull/3360) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Use bytes in arrow-flight [\#3359](https://github.com/apache/arrow-rs/pull/3359) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Add derived implementations of Clone and Debug for `ParquetObjectReader` [\#3381](https://github.com/apache/arrow-rs/issues/3381) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Speed up TrackedWrite [\#3366](https://github.com/apache/arrow-rs/issues/3366) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Is it possible for ArrowWriter to write key\_value\_metadata after write all records [\#3356](https://github.com/apache/arrow-rs/issues/3356) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add UnionArray test to arrow-pyarrow integration test [\#3346](https://github.com/apache/arrow-rs/issues/3346) +- Document / Deprecate arrow\_flight::utils::flight\_data\_from\_arrow\_batch [\#3312](https://github.com/apache/arrow-rs/issues/3312) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- \[FlightSQL\] Support HTTPs [\#3309](https://github.com/apache/arrow-rs/issues/3309) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Support UnionArray in ffi [\#3304](https://github.com/apache/arrow-rs/issues/3304) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add support for Azure Data Lake Storage Gen2 \(aka: ADLS Gen2\) in Object Store library [\#3283](https://github.com/apache/arrow-rs/issues/3283) +- Support casting from String to Decimal [\#3280](https://github.com/apache/arrow-rs/issues/3280) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Allow ArrowCSV writer to control the display of NULL values [\#3268](https://github.com/apache/arrow-rs/issues/3268) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- FlightSQL example is broken [\#3386](https://github.com/apache/arrow-rs/issues/3386) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- CSV Reader Bounds Incorrectly Handles Header [\#3364](https://github.com/apache/arrow-rs/issues/3364) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Incorrect output string from `try_to_type` [\#3350](https://github.com/apache/arrow-rs/issues/3350) +- Decimal arithmetic computation fails to run because decimal type equality [\#3344](https://github.com/apache/arrow-rs/issues/3344) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Pretty print not implemented for Map [\#3322](https://github.com/apache/arrow-rs/issues/3322) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- ILIKE Kernels Inconsistent Case Folding [\#3311](https://github.com/apache/arrow-rs/issues/3311) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- minor: Improve arrow-flight docs [\#3372](https://github.com/apache/arrow-rs/pull/3372) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) + +**Merged pull requests:** + +- Version 30.0.0 release notes and changelog [\#3406](https://github.com/apache/arrow-rs/pull/3406) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Ends ParquetRecordBatchStream when polling on StreamState::Error [\#3404](https://github.com/apache/arrow-rs/pull/3404) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- fix clippy issues [\#3398](https://github.com/apache/arrow-rs/pull/3398) ([Jimexist](https://github.com/Jimexist)) +- Upgrade multiversion to 0.7.1 [\#3396](https://github.com/apache/arrow-rs/pull/3396) ([viirya](https://github.com/viirya)) +- Make FlightSQL Support HTTPs [\#3388](https://github.com/apache/arrow-rs/pull/3388) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([viirya](https://github.com/viirya)) +- Fix broken FlightSQL example [\#3387](https://github.com/apache/arrow-rs/pull/3387) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([viirya](https://github.com/viirya)) +- Update prost-build [\#3385](https://github.com/apache/arrow-rs/pull/3385) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Split out arrow-arith \(\#2594\) [\#3384](https://github.com/apache/arrow-rs/pull/3384) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add derive for Clone and Debug for `ParquetObjectReader` [\#3382](https://github.com/apache/arrow-rs/pull/3382) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([kszlim](https://github.com/kszlim)) +- Initial Mid-level `FlightClient` [\#3378](https://github.com/apache/arrow-rs/pull/3378) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Document all features on docs.rs [\#3377](https://github.com/apache/arrow-rs/pull/3377) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Split out arrow-row \(\#2594\) [\#3375](https://github.com/apache/arrow-rs/pull/3375) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Remove unnecessary flush calls on TrackedWrite [\#3374](https://github.com/apache/arrow-rs/pull/3374) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- Update proc-macro2 requirement from =1.0.47 to =1.0.49 [\#3369](https://github.com/apache/arrow-rs/pull/3369) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add CSV build\_buffered \(\#3338\) [\#3368](https://github.com/apache/arrow-rs/pull/3368) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: add append\_key\_value\_metadata [\#3367](https://github.com/apache/arrow-rs/pull/3367) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jiacai2050](https://github.com/jiacai2050)) +- Add csv-core based reader \(\#3338\) [\#3365](https://github.com/apache/arrow-rs/pull/3365) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Put BufWriter into TrackedWrite [\#3361](https://github.com/apache/arrow-rs/pull/3361) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- Add CSV reader benchmark \(\#3338\) [\#3357](https://github.com/apache/arrow-rs/pull/3357) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use ArrayData::ptr\_eq in DictionaryTracker [\#3354](https://github.com/apache/arrow-rs/pull/3354) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Deprecate flight\_data\_from\_arrow\_batch [\#3353](https://github.com/apache/arrow-rs/pull/3353) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([Dandandan](https://github.com/Dandandan)) +- Fix incorrect output string from try\_to\_type [\#3351](https://github.com/apache/arrow-rs/pull/3351) ([viirya](https://github.com/viirya)) +- Fix unary\_dyn for decimal scalar arithmetic computation [\#3345](https://github.com/apache/arrow-rs/pull/3345) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add UnionArray test to arrow-pyarrow integration test [\#3343](https://github.com/apache/arrow-rs/pull/3343) ([viirya](https://github.com/viirya)) +- feat: configure null value in arrow csv writer [\#3342](https://github.com/apache/arrow-rs/pull/3342) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Optimize bulk writing of all blocks of bloom filter [\#3340](https://github.com/apache/arrow-rs/pull/3340) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- Add MapArray to pretty print [\#3339](https://github.com/apache/arrow-rs/pull/3339) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Update prost-build 0.11.4 [\#3334](https://github.com/apache/arrow-rs/pull/3334) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Faster Parquet Bloom Writer [\#3333](https://github.com/apache/arrow-rs/pull/3333) ([tustvold](https://github.com/tustvold)) +- Add bloom filter benchmark for parquet writer [\#3323](https://github.com/apache/arrow-rs/pull/3323) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- Add ASCII fast path for ILIKE scalar \(90% faster\) [\#3306](https://github.com/apache/arrow-rs/pull/3306) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support UnionArray in ffi [\#3305](https://github.com/apache/arrow-rs/pull/3305) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support casting from String to Decimal [\#3281](https://github.com/apache/arrow-rs/pull/3281) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- add more integration test for parquet bloom filter round trip tests [\#3210](https://github.com/apache/arrow-rs/pull/3210) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jimexist](https://github.com/Jimexist)) +## [29.0.0](https://github.com/apache/arrow-rs/tree/29.0.0) (2022-12-09) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/28.0.0...29.0.0) + +**Breaking changes:** + +- Minor: Allow `Field::new` and `Field::new_with_dict` to take existing `String` as well as `&str` [\#3288](https://github.com/apache/arrow-rs/pull/3288) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- update `&Option` to `Option<&T>` [\#3249](https://github.com/apache/arrow-rs/pull/3249) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jimexist](https://github.com/Jimexist)) +- Hide `*_dict_scalar` kernels behind `*_dyn` kernels [\#3202](https://github.com/apache/arrow-rs/pull/3202) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) + +**Implemented enhancements:** + +- Support writing BloomFilter in arrow\_writer [\#3275](https://github.com/apache/arrow-rs/issues/3275) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support casting from unsigned numeric to Decimal256 [\#3272](https://github.com/apache/arrow-rs/issues/3272) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support casting from Decimal256 to float types [\#3266](https://github.com/apache/arrow-rs/issues/3266) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Make arithmetic kernels supports DictionaryArray of DecimalType [\#3254](https://github.com/apache/arrow-rs/issues/3254) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Casting from Decimal256 to unsigned numeric [\#3239](https://github.com/apache/arrow-rs/issues/3239) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- precision is not considered when cast value to decimal [\#3223](https://github.com/apache/arrow-rs/issues/3223) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use RegexSet in arrow\_csv::infer\_field\_schema [\#3211](https://github.com/apache/arrow-rs/issues/3211) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement FlightSQL Client [\#3206](https://github.com/apache/arrow-rs/issues/3206) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Add binary\_mut and try\_binary\_mut [\#3143](https://github.com/apache/arrow-rs/issues/3143) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add try\_unary\_mut [\#3133](https://github.com/apache/arrow-rs/issues/3133) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Skip null buffer when importing FFI ArrowArray struct if no null buffer in the spec [\#3290](https://github.com/apache/arrow-rs/issues/3290) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- using ahash `compile-time-rng` kills reproducible builds [\#3271](https://github.com/apache/arrow-rs/issues/3271) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Decimal128 to Decimal256 Overflows [\#3265](https://github.com/apache/arrow-rs/issues/3265) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `nullif` panics on empty array [\#3261](https://github.com/apache/arrow-rs/issues/3261) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Some more inconsistency between can\_cast\_types and cast\_with\_options [\#3250](https://github.com/apache/arrow-rs/issues/3250) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Enable casting between Dictionary of DecimalArray and DecimalArray [\#3237](https://github.com/apache/arrow-rs/issues/3237) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- new\_null\_array Panics creating StructArray with non-nullable fields [\#3226](https://github.com/apache/arrow-rs/issues/3226) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- bool should cast from/to Float16Type as `can_cast_types` returns true [\#3221](https://github.com/apache/arrow-rs/issues/3221) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Utf8 and LargeUtf8 cannot cast from/to Float16 but can\_cast\_types returns true [\#3220](https://github.com/apache/arrow-rs/issues/3220) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Re-enable some tests in `arrow-cast` crate [\#3219](https://github.com/apache/arrow-rs/issues/3219) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Off-by-one buffer size error triggers Panic when constructing RecordBatch from IPC bytes \(should return an Error\) [\#3215](https://github.com/apache/arrow-rs/issues/3215) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- arrow to and from pyarrow conversion results in changes in schema [\#3136](https://github.com/apache/arrow-rs/issues/3136) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- better document when we need `LargeUtf8` instead of `Utf8` [\#3228](https://github.com/apache/arrow-rs/issues/3228) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Use BufWriter when writing bloom filters and limit tests \(\#3318\) [\#3319](https://github.com/apache/arrow-rs/pull/3319) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Use take for dictionary like comparisons [\#3313](https://github.com/apache/arrow-rs/pull/3313) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update versions to 29.0.0 and update CHANGELOG [\#3315](https://github.com/apache/arrow-rs/pull/3315) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- refactor: Merge similar functions `ilike_scalar` and `nilike_scalar` [\#3303](https://github.com/apache/arrow-rs/pull/3303) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Split out arrow-ord \(\#2594\) [\#3299](https://github.com/apache/arrow-rs/pull/3299) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Split out arrow-string \(\#2594\) [\#3295](https://github.com/apache/arrow-rs/pull/3295) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Skip null buffer when importing FFI ArrowArray struct if no null buffer in the spec [\#3293](https://github.com/apache/arrow-rs/pull/3293) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Don't use dangling NonNull as sentinel [\#3289](https://github.com/apache/arrow-rs/pull/3289) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Set bloom filter on byte array [\#3284](https://github.com/apache/arrow-rs/pull/3284) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- Fix ipc schema custom\_metadata serialization [\#3282](https://github.com/apache/arrow-rs/pull/3282) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Disable const-random ahash feature on non-WASM \(\#3271\) [\#3277](https://github.com/apache/arrow-rs/pull/3277) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- fix\(ffi\): handle null data buffers from empty arrays [\#3276](https://github.com/apache/arrow-rs/pull/3276) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([wjones127](https://github.com/wjones127)) +- Support casting from unsigned numeric to Decimal256 [\#3273](https://github.com/apache/arrow-rs/pull/3273) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add parquet-layout binary [\#3269](https://github.com/apache/arrow-rs/pull/3269) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Support casting from Decimal256 to float types [\#3267](https://github.com/apache/arrow-rs/pull/3267) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Simplify decimal cast logic [\#3264](https://github.com/apache/arrow-rs/pull/3264) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix panic on nullif empty array \(\#3261\) [\#3263](https://github.com/apache/arrow-rs/pull/3263) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add BooleanArray::from\_unary and BooleanArray::from\_binary [\#3258](https://github.com/apache/arrow-rs/pull/3258) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Minor: Remove parquet build script [\#3257](https://github.com/apache/arrow-rs/pull/3257) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Make arithmetic kernels supports DictionaryArray of DecimalType [\#3255](https://github.com/apache/arrow-rs/pull/3255) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support List and LargeList in Row format \(\#3159\) [\#3251](https://github.com/apache/arrow-rs/pull/3251) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Don't recurse to children in ArrayData::try\_new [\#3248](https://github.com/apache/arrow-rs/pull/3248) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Validate dictionaries read over IPC [\#3247](https://github.com/apache/arrow-rs/pull/3247) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix MapBuilder example [\#3246](https://github.com/apache/arrow-rs/pull/3246) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Loosen nullability restrictions added in \#3205 \(\#3226\) [\#3244](https://github.com/apache/arrow-rs/pull/3244) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Better document implications of offsets \(\#3228\) [\#3243](https://github.com/apache/arrow-rs/pull/3243) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add new API to validate the precision for decimal array [\#3242](https://github.com/apache/arrow-rs/pull/3242) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Move nullif to arrow-select \(\#2594\) [\#3241](https://github.com/apache/arrow-rs/pull/3241) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Casting from Decimal256 to unsigned numeric [\#3240](https://github.com/apache/arrow-rs/pull/3240) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Enable casting between Dictionary of DecimalArray and DecimalArray [\#3238](https://github.com/apache/arrow-rs/pull/3238) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Remove unwraps from 'create\_primitive\_array' [\#3232](https://github.com/apache/arrow-rs/pull/3232) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([aarashy](https://github.com/aarashy)) +- Fix CI build by upgrading tonic-build to 0.8.4 [\#3231](https://github.com/apache/arrow-rs/pull/3231) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([viirya](https://github.com/viirya)) +- Remove negative scale check [\#3230](https://github.com/apache/arrow-rs/pull/3230) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Update prost-build requirement from =0.11.2 to =0.11.3 [\#3225](https://github.com/apache/arrow-rs/pull/3225) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Get the round result for decimal to a decimal with smaller scale [\#3224](https://github.com/apache/arrow-rs/pull/3224) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Move tests which require chrono-tz feature from `arrow-cast` to `arrow` [\#3222](https://github.com/apache/arrow-rs/pull/3222) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- add test cases for extracting week with/without timezone [\#3218](https://github.com/apache/arrow-rs/pull/3218) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([waitingkuo](https://github.com/waitingkuo)) +- Use RegexSet for matching DataType [\#3217](https://github.com/apache/arrow-rs/pull/3217) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Update tonic-build to 0.8.3 [\#3214](https://github.com/apache/arrow-rs/pull/3214) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Support StructArray in Row Format \(\#3159\) [\#3212](https://github.com/apache/arrow-rs/pull/3212) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Infer timestamps from CSV files [\#3209](https://github.com/apache/arrow-rs/pull/3209) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- fix bug: cast decimal256 to other decimal with no-safe [\#3208](https://github.com/apache/arrow-rs/pull/3208) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- FlightSQL Client & integration test [\#3207](https://github.com/apache/arrow-rs/pull/3207) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([avantgardnerio](https://github.com/avantgardnerio)) +- Ensure StructArrays check nullability of fields [\#3205](https://github.com/apache/arrow-rs/pull/3205) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Remove special case ArrayData equality for decimals [\#3204](https://github.com/apache/arrow-rs/pull/3204) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add a cast test case for decimal negative scale [\#3203](https://github.com/apache/arrow-rs/pull/3203) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Move zip and shift kernels to arrow-select [\#3201](https://github.com/apache/arrow-rs/pull/3201) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Deprecate limit kernel [\#3200](https://github.com/apache/arrow-rs/pull/3200) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use SlicesIterator for ArrayData Equality [\#3198](https://github.com/apache/arrow-rs/pull/3198) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add \_dyn kernels of like, ilike, nlike, nilike kernels for dictionary support [\#3197](https://github.com/apache/arrow-rs/pull/3197) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Adding scalar nlike\_dyn, ilike\_dyn, nilike\_dyn kernels [\#3195](https://github.com/apache/arrow-rs/pull/3195) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Use self capture in DataType [\#3190](https://github.com/apache/arrow-rs/pull/3190) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- To pyarrow with schema [\#3188](https://github.com/apache/arrow-rs/pull/3188) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([doki23](https://github.com/doki23)) +- Support Duration in array\_value\_to\_string [\#3183](https://github.com/apache/arrow-rs/pull/3183) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Support `FixedSizeBinary` in Row format [\#3182](https://github.com/apache/arrow-rs/pull/3182) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add binary\_mut and try\_binary\_mut [\#3144](https://github.com/apache/arrow-rs/pull/3144) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add try\_unary\_mut [\#3134](https://github.com/apache/arrow-rs/pull/3134) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +## [28.0.0](https://github.com/apache/arrow-rs/tree/28.0.0) (2022-11-25) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/27.0.0...28.0.0) + +**Breaking changes:** + +- StructArray::columns return slice [\#3186](https://github.com/apache/arrow-rs/pull/3186) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Return slice from GenericByteArray::value\_data [\#3171](https://github.com/apache/arrow-rs/pull/3171) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support decimal negative scale [\#3152](https://github.com/apache/arrow-rs/pull/3152) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- refactor: convert `Field::metadata` to `HashMap` [\#3148](https://github.com/apache/arrow-rs/pull/3148) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([crepererum](https://github.com/crepererum)) +- Don't Skip Serializing Empty Metadata \(\#3082\) [\#3126](https://github.com/apache/arrow-rs/pull/3126) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Add Decimal128, Decimal256, Float16 to DataType::is\_numeric [\#3121](https://github.com/apache/arrow-rs/pull/3121) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Upgrade to thrift 0.17 and fix issues [\#3104](https://github.com/apache/arrow-rs/pull/3104) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jimexist](https://github.com/Jimexist)) +- Fix prettyprint for Interval second fractions [\#3093](https://github.com/apache/arrow-rs/pull/3093) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Remove Option from `Field::metadata` [\#3091](https://github.com/apache/arrow-rs/pull/3091) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) + +**Implemented enhancements:** + +- Add iterator to RowSelection [\#3172](https://github.com/apache/arrow-rs/issues/3172) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- create an integration test set for parquet crate against pyspark for working with bloom filters [\#3167](https://github.com/apache/arrow-rs/issues/3167) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Row Format Size Tracking [\#3160](https://github.com/apache/arrow-rs/issues/3160) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add ArrayBuilder::finish\_cloned\(\) [\#3154](https://github.com/apache/arrow-rs/issues/3154) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Optimize memory usage of json reader [\#3150](https://github.com/apache/arrow-rs/issues/3150) +- Add `Field::size` and `DataType::size` [\#3147](https://github.com/apache/arrow-rs/issues/3147) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add like\_utf8\_scalar\_dyn kernel [\#3145](https://github.com/apache/arrow-rs/issues/3145) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- support comparison for decimal128 array with scalar in kernel [\#3140](https://github.com/apache/arrow-rs/issues/3140) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- audit and create a document for bloom filter configurations [\#3138](https://github.com/apache/arrow-rs/issues/3138) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Should be the rounding vs truncation when cast decimal to smaller scale [\#3137](https://github.com/apache/arrow-rs/issues/3137) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Upgrade chrono to 0.4.23 [\#3120](https://github.com/apache/arrow-rs/issues/3120) +- Implements more temporal kernels using time\_fraction\_dyn [\#3108](https://github.com/apache/arrow-rs/issues/3108) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Upgrade to thrift 0.17 [\#3105](https://github.com/apache/arrow-rs/issues/3105) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Be able to parse time formatted strings [\#3100](https://github.com/apache/arrow-rs/issues/3100) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve "Fail to merge schema" error messages [\#3095](https://github.com/apache/arrow-rs/issues/3095) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Expose `SortingColumn` when reading and writing parquet metadata [\#3090](https://github.com/apache/arrow-rs/issues/3090) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Change Field::metadata to HashMap [\#3086](https://github.com/apache/arrow-rs/issues/3086) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support bloom filter reading and writing for parquet [\#3023](https://github.com/apache/arrow-rs/issues/3023) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- API to take back ownership of an ArrayRef [\#2901](https://github.com/apache/arrow-rs/issues/2901) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Specialized Interleave Kernel [\#2864](https://github.com/apache/arrow-rs/issues/2864) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- arithmetic overflow leads to segfault in `concat_batches` [\#3123](https://github.com/apache/arrow-rs/issues/3123) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Clippy failing on master : error: use of deprecated associated function chrono::NaiveDate::from\_ymd: use from\_ymd\_opt\(\) instead [\#3097](https://github.com/apache/arrow-rs/issues/3097) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Pretty print for interval types has wrong formatting [\#3092](https://github.com/apache/arrow-rs/issues/3092) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Field is not serializable with binary formats [\#3082](https://github.com/apache/arrow-rs/issues/3082) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Decimal Casts are Unchecked [\#2986](https://github.com/apache/arrow-rs/issues/2986) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Release Arrow `27.0.0` \(next release after `26.0.0`\) [\#3045](https://github.com/apache/arrow-rs/issues/3045) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Perf about ParquetRecordBatchStream vs ParquetRecordBatchReader [\#2916](https://github.com/apache/arrow-rs/issues/2916) + +**Merged pull requests:** + +- Improve regex related kernels by upto 85% [\#3192](https://github.com/apache/arrow-rs/pull/3192) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Derive clone for arrays [\#3184](https://github.com/apache/arrow-rs/pull/3184) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Row decode cleanups [\#3180](https://github.com/apache/arrow-rs/pull/3180) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update zstd requirement from 0.11.1 to 0.12.0 [\#3178](https://github.com/apache/arrow-rs/pull/3178) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Move decimal constants from `arrow-data` to `arrow-schema` crate [\#3177](https://github.com/apache/arrow-rs/pull/3177) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- bloom filter part V: add an integration with pytest against pyspark [\#3176](https://github.com/apache/arrow-rs/pull/3176) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jimexist](https://github.com/Jimexist)) +- Bloom filter config tweaks \(\#3023\) [\#3175](https://github.com/apache/arrow-rs/pull/3175) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add RowParser [\#3174](https://github.com/apache/arrow-rs/pull/3174) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add `RowSelection::iter()`, `Into>` and example [\#3173](https://github.com/apache/arrow-rs/pull/3173) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Add read parquet examples [\#3170](https://github.com/apache/arrow-rs/pull/3170) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([xudong963](https://github.com/xudong963)) +- Faster BinaryArray to StringArray conversion \(~67%\) [\#3168](https://github.com/apache/arrow-rs/pull/3168) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Remove unnecessary downcasts in builders [\#3166](https://github.com/apache/arrow-rs/pull/3166) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- bloom filter part IV: adjust writer properties, bloom filter properties, and incorporate into column encoder [\#3165](https://github.com/apache/arrow-rs/pull/3165) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jimexist](https://github.com/Jimexist)) +- Fix parquet decimal precision [\#3164](https://github.com/apache/arrow-rs/pull/3164) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([psvri](https://github.com/psvri)) +- Add Row size methods \(\#3160\) [\#3163](https://github.com/apache/arrow-rs/pull/3163) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Prevent precision=0 for decimal type [\#3162](https://github.com/apache/arrow-rs/pull/3162) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Remove unnecessary Buffer::from\_slice\_ref reference [\#3161](https://github.com/apache/arrow-rs/pull/3161) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add finish\_cloned to ArrayBuilder [\#3158](https://github.com/apache/arrow-rs/pull/3158) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Check overflow in MutableArrayData extend offsets \(\#3123\) [\#3157](https://github.com/apache/arrow-rs/pull/3157) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Extend Decimal256 as Primitive [\#3156](https://github.com/apache/arrow-rs/pull/3156) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Doc improvements [\#3155](https://github.com/apache/arrow-rs/pull/3155) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Add collect.rs example [\#3153](https://github.com/apache/arrow-rs/pull/3153) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Implement Neg for i256 [\#3151](https://github.com/apache/arrow-rs/pull/3151) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: `{Field,DataType}::size` [\#3149](https://github.com/apache/arrow-rs/pull/3149) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([crepererum](https://github.com/crepererum)) +- Add like\_utf8\_scalar\_dyn kernel [\#3146](https://github.com/apache/arrow-rs/pull/3146) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- comparison op: decimal128 array with scalar [\#3141](https://github.com/apache/arrow-rs/pull/3141) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Cast: should get the round result for decimal to a decimal with smaller scale [\#3139](https://github.com/apache/arrow-rs/pull/3139) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Fix Panic on Reading Corrupt Parquet Schema \(\#2855\) [\#3130](https://github.com/apache/arrow-rs/pull/3130) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([psvri](https://github.com/psvri)) +- Clippy parquet fixes [\#3124](https://github.com/apache/arrow-rs/pull/3124) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Add GenericByteBuilder \(\#2969\) [\#3122](https://github.com/apache/arrow-rs/pull/3122) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- parquet bloom filter part III: add sbbf writer, remove `bloom` default feature, add reader properties [\#3119](https://github.com/apache/arrow-rs/pull/3119) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jimexist](https://github.com/Jimexist)) +- Add downcast\_array \(\#2901\) [\#3117](https://github.com/apache/arrow-rs/pull/3117) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add COW conversion for Buffer and PrimitiveArray and unary\_mut [\#3115](https://github.com/apache/arrow-rs/pull/3115) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Include field name in merge error message [\#3113](https://github.com/apache/arrow-rs/pull/3113) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([andygrove](https://github.com/andygrove)) +- Add PrimitiveArray::unary\_opt [\#3110](https://github.com/apache/arrow-rs/pull/3110) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Implements more temporal kernels using time\_fraction\_dyn [\#3107](https://github.com/apache/arrow-rs/pull/3107) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- cast: support unsigned numeric type to decimal128 [\#3106](https://github.com/apache/arrow-rs/pull/3106) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Expose `SortingColumn` in parquet files [\#3103](https://github.com/apache/arrow-rs/pull/3103) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([askoa](https://github.com/askoa)) +- parquet bloom filter part II: read sbbf bitset from row group reader, update API, and add cli demo [\#3102](https://github.com/apache/arrow-rs/pull/3102) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jimexist](https://github.com/Jimexist)) +- Parse Time32/Time64 from formatted string [\#3101](https://github.com/apache/arrow-rs/pull/3101) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Cleanup temporal \_internal functions [\#3099](https://github.com/apache/arrow-rs/pull/3099) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Improve schema mismatch error message [\#3098](https://github.com/apache/arrow-rs/pull/3098) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- Fix clippy by avoiding deprecated functions in chrono [\#3096](https://github.com/apache/arrow-rs/pull/3096) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Minor: Add diagrams and documentation to row format [\#3094](https://github.com/apache/arrow-rs/pull/3094) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Minor: Use ArrowNativeTypeOp instead of total\_cmp directly [\#3087](https://github.com/apache/arrow-rs/pull/3087) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Check overflow while casting between decimal types [\#3076](https://github.com/apache/arrow-rs/pull/3076) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- add bloom filter implementation based on split block \(sbbf\) spec [\#3057](https://github.com/apache/arrow-rs/pull/3057) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jimexist](https://github.com/Jimexist)) +- Add FixedSizeBinaryArray::try\_from\_sparse\_iter\_with\_size [\#3054](https://github.com/apache/arrow-rs/pull/3054) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([maxburke](https://github.com/maxburke)) +## [27.0.0](https://github.com/apache/arrow-rs/tree/27.0.0) (2022-11-11) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/26.0.0...27.0.0) + +**Breaking changes:** + +- Recurse into Dictionary value type in DataType::is\_nested [\#3083](https://github.com/apache/arrow-rs/pull/3083) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- early type checks in `RowConverter` [\#3080](https://github.com/apache/arrow-rs/pull/3080) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([crepererum](https://github.com/crepererum)) +- Add Decimal128 and Decimal256 to downcast\_primitive [\#3056](https://github.com/apache/arrow-rs/pull/3056) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Replace remaining \_generic temporal kernels with \_dyn kernels [\#3046](https://github.com/apache/arrow-rs/pull/3046) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Replace year\_generic with year\_dyn [\#3041](https://github.com/apache/arrow-rs/pull/3041) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Validate decimal256 with i256 directly [\#3025](https://github.com/apache/arrow-rs/pull/3025) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Hadoop LZ4 Support for LZ4 Codec [\#3013](https://github.com/apache/arrow-rs/pull/3013) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([marioloko](https://github.com/marioloko)) +- Replace hour\_generic with hour\_dyn [\#3006](https://github.com/apache/arrow-rs/pull/3006) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Accept any &dyn Array in nullif kernel [\#2940](https://github.com/apache/arrow-rs/pull/2940) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Row Format: Option to detach/own a row [\#3078](https://github.com/apache/arrow-rs/issues/3078) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Row Format: API to check if datatypes are supported [\#3077](https://github.com/apache/arrow-rs/issues/3077) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Deprecate Buffer::count\_set\_bits [\#3067](https://github.com/apache/arrow-rs/issues/3067) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add Decimal128 and Decimal256 to downcast\_primitive [\#3055](https://github.com/apache/arrow-rs/issues/3055) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improved UX of creating `TimestampNanosecondArray` with timezones [\#3042](https://github.com/apache/arrow-rs/issues/3042) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Cast decimal256 to signed integer [\#3039](https://github.com/apache/arrow-rs/issues/3039) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support casting Date64 to Timestamp [\#3037](https://github.com/apache/arrow-rs/issues/3037) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Check overflow when casting floating point value to decimal256 [\#3032](https://github.com/apache/arrow-rs/issues/3032) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Compare i256 in validate\_decimal256\_precision [\#3024](https://github.com/apache/arrow-rs/issues/3024) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Check overflow when casting floating point value to decimal128 [\#3020](https://github.com/apache/arrow-rs/issues/3020) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add macro downcast\_temporal\_array [\#3008](https://github.com/apache/arrow-rs/issues/3008) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Replace hour\_generic with hour\_dyn [\#3005](https://github.com/apache/arrow-rs/issues/3005) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Replace temporal \_generic kernels with dyn [\#3004](https://github.com/apache/arrow-rs/issues/3004) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `RowSelection::intersection` [\#3003](https://github.com/apache/arrow-rs/issues/3003) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- I would like to round rather than truncate when casting f64 to decimal [\#2997](https://github.com/apache/arrow-rs/issues/2997) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- arrow::compute::kernels::temporal should support nanoseconds [\#2995](https://github.com/apache/arrow-rs/issues/2995) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Release Arrow `26.0.0` \(next release after `25.0.0`\) [\#2953](https://github.com/apache/arrow-rs/issues/2953) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Add timezone offset for debug format of Timestamp with Timezone [\#2917](https://github.com/apache/arrow-rs/issues/2917) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support merge RowSelectors when creating RowSelection [\#2858](https://github.com/apache/arrow-rs/issues/2858) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Fixed bugs:** + +- Inconsistent Nan Handling Between Scalar and Non-Scalar Comparison Kernels [\#3074](https://github.com/apache/arrow-rs/issues/3074) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Debug format for timestamp ignores timezone [\#3069](https://github.com/apache/arrow-rs/issues/3069) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Row format decode loses timezone [\#3063](https://github.com/apache/arrow-rs/issues/3063) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- binary operator produces incorrect result on arrays with resized null buffer [\#3061](https://github.com/apache/arrow-rs/issues/3061) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- RLEDecoder Panics on Null Padded Pages [\#3035](https://github.com/apache/arrow-rs/issues/3035) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Nullif with incorrect valid\_count [\#3031](https://github.com/apache/arrow-rs/issues/3031) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- RLEDecoder::get\_batch\_with\_dict may panic on bit-packed runs longer than 1024 [\#3029](https://github.com/apache/arrow-rs/issues/3029) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Converted type is None according to Parquet Tools then utilizing logical types [\#3017](https://github.com/apache/arrow-rs/issues/3017) +- CompressionCodec LZ4 incompatible with C++ implementation [\#2988](https://github.com/apache/arrow-rs/issues/2988) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Documentation updates:** + +- Mark parquet predicate pushdown as complete [\#2987](https://github.com/apache/arrow-rs/pull/2987) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) + +**Merged pull requests:** + +- Improved UX of creating `TimestampNanosecondArray` with timezones [\#3088](https://github.com/apache/arrow-rs/pull/3088) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([src255](https://github.com/src255)) +- Remove unused range module [\#3085](https://github.com/apache/arrow-rs/pull/3085) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Make intersect\_row\_selections a member function [\#3084](https://github.com/apache/arrow-rs/pull/3084) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Update hashbrown requirement from 0.12 to 0.13 [\#3081](https://github.com/apache/arrow-rs/pull/3081) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- feat: add `OwnedRow` [\#3079](https://github.com/apache/arrow-rs/pull/3079) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([crepererum](https://github.com/crepererum)) +- Use ArrowNativeTypeOp on non-scalar comparison kernels [\#3075](https://github.com/apache/arrow-rs/pull/3075) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add missing inline to ArrowNativeTypeOp [\#3073](https://github.com/apache/arrow-rs/pull/3073) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- fix debug information for Timestamp with Timezone [\#3072](https://github.com/apache/arrow-rs/pull/3072) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([waitingkuo](https://github.com/waitingkuo)) +- Deprecate Buffer::count\_set\_bits \(\#3067\) [\#3071](https://github.com/apache/arrow-rs/pull/3071) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add compare to ArrowNativeTypeOp [\#3070](https://github.com/apache/arrow-rs/pull/3070) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Minor: Improve docstrings on WriterPropertiesBuilder [\#3068](https://github.com/apache/arrow-rs/pull/3068) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Faster f64 inequality [\#3065](https://github.com/apache/arrow-rs/pull/3065) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix row format decode loses timezone \(\#3063\) [\#3064](https://github.com/apache/arrow-rs/pull/3064) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix null\_count computation in binary [\#3062](https://github.com/apache/arrow-rs/pull/3062) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Faster f64 equality [\#3060](https://github.com/apache/arrow-rs/pull/3060) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update arrow-flight subcrates \(\#3044\) [\#3052](https://github.com/apache/arrow-rs/pull/3052) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Minor: Remove cloning ArrayData in with\_precision\_and\_scale [\#3050](https://github.com/apache/arrow-rs/pull/3050) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Split out arrow-json \(\#3044\) [\#3049](https://github.com/apache/arrow-rs/pull/3049) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Move `intersect_row_selections` from datafusion to arrow-rs. [\#3047](https://github.com/apache/arrow-rs/pull/3047) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Split out arrow-csv \(\#2594\) [\#3044](https://github.com/apache/arrow-rs/pull/3044) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Move reader\_parser to arrow-cast \(\#3022\) [\#3043](https://github.com/apache/arrow-rs/pull/3043) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Cast decimal256 to signed integer [\#3040](https://github.com/apache/arrow-rs/pull/3040) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Enable casting from Date64 to Timestamp [\#3038](https://github.com/apache/arrow-rs/pull/3038) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gruuya](https://github.com/gruuya)) +- Fix decoding long and/or padded RLE data \(\#3029\) \(\#3035\) [\#3036](https://github.com/apache/arrow-rs/pull/3036) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Fix nullif when existing array has no nulls [\#3034](https://github.com/apache/arrow-rs/pull/3034) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Check overflow when casting floating point value to decimal256 [\#3033](https://github.com/apache/arrow-rs/pull/3033) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Update parquet to depend on arrow subcrates [\#3028](https://github.com/apache/arrow-rs/pull/3028) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Make various i256 methods const [\#3026](https://github.com/apache/arrow-rs/pull/3026) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Split out arrow-ipc [\#3022](https://github.com/apache/arrow-rs/pull/3022) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Check overflow while casting floating point value to decimal128 [\#3021](https://github.com/apache/arrow-rs/pull/3021) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Update arrow-flight [\#3019](https://github.com/apache/arrow-rs/pull/3019) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Move ArrowNativeTypeOp to arrow-array \(\#2594\) [\#3018](https://github.com/apache/arrow-rs/pull/3018) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support cast timestamp to time [\#3016](https://github.com/apache/arrow-rs/pull/3016) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([naosense](https://github.com/naosense)) +- Add filter example [\#3014](https://github.com/apache/arrow-rs/pull/3014) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Check overflow when casting integer to decimal [\#3009](https://github.com/apache/arrow-rs/pull/3009) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add macro downcast\_temporal\_array [\#3007](https://github.com/apache/arrow-rs/pull/3007) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Parquet Writer: Make column descriptor public on the writer [\#3002](https://github.com/apache/arrow-rs/pull/3002) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([pier-oliviert](https://github.com/pier-oliviert)) +- Update chrono-tz requirement from 0.7 to 0.8 [\#3001](https://github.com/apache/arrow-rs/pull/3001) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Round instead of Truncate while casting float to decimal [\#3000](https://github.com/apache/arrow-rs/pull/3000) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([waitingkuo](https://github.com/waitingkuo)) +- Support Predicate Pushdown for Parquet Lists \(\#2108\) [\#2999](https://github.com/apache/arrow-rs/pull/2999) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Split out arrow-cast \(\#2594\) [\#2998](https://github.com/apache/arrow-rs/pull/2998) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- `arrow::compute::kernels::temporal` should support nanoseconds [\#2996](https://github.com/apache/arrow-rs/pull/2996) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) +- Add `RowSelection::from_selectors_and_combine` to merge RowSelectors [\#2994](https://github.com/apache/arrow-rs/pull/2994) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Simplify Single-Column Dictionary Sort [\#2993](https://github.com/apache/arrow-rs/pull/2993) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Minor: Add entry to changelog for 26.0.0 RC2 fix [\#2992](https://github.com/apache/arrow-rs/pull/2992) ([alamb](https://github.com/alamb)) +- Fix ignored limit on `lexsort_to_indices` [\#2991](https://github.com/apache/arrow-rs/pull/2991) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add clone and equal functions for CastOptions [\#2985](https://github.com/apache/arrow-rs/pull/2985) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- minor: remove redundant prefix [\#2983](https://github.com/apache/arrow-rs/pull/2983) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([jackwener](https://github.com/jackwener)) +- Compare dictionary decimal arrays [\#2982](https://github.com/apache/arrow-rs/pull/2982) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Compare dictionary and non-dictionary decimal arrays [\#2980](https://github.com/apache/arrow-rs/pull/2980) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add decimal comparison kernel support [\#2978](https://github.com/apache/arrow-rs/pull/2978) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Move concat kernel to arrow-select \(\#2594\) [\#2976](https://github.com/apache/arrow-rs/pull/2976) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Specialize interleave for byte arrays \(\#2864\) [\#2975](https://github.com/apache/arrow-rs/pull/2975) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use unary function for numeric to decimal cast [\#2973](https://github.com/apache/arrow-rs/pull/2973) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Specialize filter kernel for binary arrays \(\#2969\) [\#2971](https://github.com/apache/arrow-rs/pull/2971) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Combine take\_utf8 and take\_binary \(\#2969\) [\#2970](https://github.com/apache/arrow-rs/pull/2970) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Faster Scalar Dictionary Comparison ~10% [\#2968](https://github.com/apache/arrow-rs/pull/2968) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Move `byte_size` from datafusion::physical\_expr [\#2965](https://github.com/apache/arrow-rs/pull/2965) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([avantgardnerio](https://github.com/avantgardnerio)) +- Pass decompressed size to parquet Codec::decompress \(\#2956\) [\#2959](https://github.com/apache/arrow-rs/pull/2959) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([marioloko](https://github.com/marioloko)) +- Add Decimal Arithmetic [\#2881](https://github.com/apache/arrow-rs/pull/2881) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +## [26.0.0](https://github.com/apache/arrow-rs/tree/26.0.0) (2022-10-28) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/25.0.0...26.0.0) + +**Breaking changes:** + +- Cast Timestamps to RFC3339 strings [\#2934](https://github.com/apache/arrow-rs/issues/2934) +- Remove Unused NativeDecimalType [\#2945](https://github.com/apache/arrow-rs/pull/2945) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Format Timestamps as RFC3339 [\#2939](https://github.com/apache/arrow-rs/pull/2939) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([waitingkuo](https://github.com/waitingkuo)) +- Update flatbuffers to resolve RUSTSEC-2021-0122 [\#2895](https://github.com/apache/arrow-rs/pull/2895) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- replace `from_timestamp` by `from_timestamp_opt` [\#2894](https://github.com/apache/arrow-rs/pull/2894) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([waitingkuo](https://github.com/waitingkuo)) + +**Implemented enhancements:** + +- Optimized way to count the numbers of `true` and `false` values in a BooleanArray [\#2963](https://github.com/apache/arrow-rs/issues/2963) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add pow to i256 [\#2954](https://github.com/apache/arrow-rs/issues/2954) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Write Generic Code over \[Large\]BinaryArray and \[Large\]StringArray [\#2946](https://github.com/apache/arrow-rs/issues/2946) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add Page Row Count Limit [\#2941](https://github.com/apache/arrow-rs/issues/2941) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- prettyprint to show timezone offset for timestamp with timezone [\#2937](https://github.com/apache/arrow-rs/issues/2937) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Cast numeric to decimal256 [\#2922](https://github.com/apache/arrow-rs/issues/2922) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `freeze_with_dictionary` API to `MutableArrayData` [\#2914](https://github.com/apache/arrow-rs/issues/2914) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support decimal256 array in sort kernels [\#2911](https://github.com/apache/arrow-rs/issues/2911) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- support `[+/-]hhmm` and `[+/-]hh` as fixedoffset timezone format [\#2910](https://github.com/apache/arrow-rs/issues/2910) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Cleanup decimal sort function [\#2907](https://github.com/apache/arrow-rs/issues/2907) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- replace `from_timestamp` by `from_timestamp_opt` [\#2892](https://github.com/apache/arrow-rs/issues/2892) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Move Primitive arity kernels to arrow-array [\#2787](https://github.com/apache/arrow-rs/issues/2787) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- add overflow-checking for negative arithmetic kernel [\#2662](https://github.com/apache/arrow-rs/issues/2662) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Subtle compatibility issue with serve\_arrow [\#2952](https://github.com/apache/arrow-rs/issues/2952) +- error\[E0599\]: no method named `total_cmp` found for struct `f16` in the current scope [\#2926](https://github.com/apache/arrow-rs/issues/2926) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Fail at rowSelection `and_then` method [\#2925](https://github.com/apache/arrow-rs/issues/2925) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Ordering not implemented for FixedSizeBinary types [\#2904](https://github.com/apache/arrow-rs/issues/2904) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet API: Could not convert timestamp before unix epoch to string/json [\#2897](https://github.com/apache/arrow-rs/issues/2897) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Overly Pessimistic RLE Size Estimation [\#2889](https://github.com/apache/arrow-rs/issues/2889) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Memory alignment error in `RawPtrBox::new` [\#2882](https://github.com/apache/arrow-rs/issues/2882) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Compilation error under chrono-tz feature [\#2878](https://github.com/apache/arrow-rs/issues/2878) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- AHash Statically Allocates 64 bytes [\#2875](https://github.com/apache/arrow-rs/issues/2875) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `parquet::arrow::arrow_writer::ArrowWriter` ignores page size properties [\#2853](https://github.com/apache/arrow-rs/issues/2853) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Documentation updates:** + +- Document crate topology \(\#2594\) [\#2913](https://github.com/apache/arrow-rs/pull/2913) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Closed issues:** + +- SerializedFileWriter comments about multiple call on consumed self [\#2935](https://github.com/apache/arrow-rs/issues/2935) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Pointer freed error when deallocating ArrayData with shared memory buffer [\#2874](https://github.com/apache/arrow-rs/issues/2874) +- Release Arrow `25.0.0` \(next release after `24.0.0`\) [\#2820](https://github.com/apache/arrow-rs/issues/2820) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Replace DecimalArray with PrimitiveArray [\#2637](https://github.com/apache/arrow-rs/issues/2637) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Fix ignored limit on lexsort\_to\_indices (#2991) [\#2991](https://github.com/apache/arrow-rs/pull/2991) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Fix GenericListArray::try\_new\_from\_array\_data error message \(\#526\) [\#2961](https://github.com/apache/arrow-rs/pull/2961) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix take string on sliced indices [\#2960](https://github.com/apache/arrow-rs/pull/2960) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add BooleanArray::true\_count and BooleanArray::false\_count [\#2957](https://github.com/apache/arrow-rs/pull/2957) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add pow to i256 [\#2955](https://github.com/apache/arrow-rs/pull/2955) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- fix datatype for timestamptz debug fmt [\#2948](https://github.com/apache/arrow-rs/pull/2948) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([waitingkuo](https://github.com/waitingkuo)) +- Add GenericByteArray \(\#2946\) [\#2947](https://github.com/apache/arrow-rs/pull/2947) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Specialize interleave string ~2-3x faster [\#2944](https://github.com/apache/arrow-rs/pull/2944) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Added support for LZ4\_RAW compression. \(\#1604\) [\#2943](https://github.com/apache/arrow-rs/pull/2943) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([marioloko](https://github.com/marioloko)) +- Add optional page row count limit for parquet `WriterProperties` \(\#2941\) [\#2942](https://github.com/apache/arrow-rs/pull/2942) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Cleanup orphaned doc comments \(\#2935\) [\#2938](https://github.com/apache/arrow-rs/pull/2938) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- support more fixedoffset tz format [\#2936](https://github.com/apache/arrow-rs/pull/2936) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([waitingkuo](https://github.com/waitingkuo)) +- Benchmark with prepared row converter [\#2930](https://github.com/apache/arrow-rs/pull/2930) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add lexsort benchmark \(\#2871\) [\#2929](https://github.com/apache/arrow-rs/pull/2929) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Improve panic messages for RowSelection::and\_then \(\#2925\) [\#2928](https://github.com/apache/arrow-rs/pull/2928) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Update required half from 2.0 --\> 2.1 [\#2927](https://github.com/apache/arrow-rs/pull/2927) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Cast numeric to decimal256 [\#2923](https://github.com/apache/arrow-rs/pull/2923) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Cleanup generated proto code [\#2921](https://github.com/apache/arrow-rs/pull/2921) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Deprecate TimestampArray from\_vec and from\_opt\_vec [\#2919](https://github.com/apache/arrow-rs/pull/2919) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support decimal256 array in sort kernels [\#2912](https://github.com/apache/arrow-rs/pull/2912) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add timezone abstraction [\#2909](https://github.com/apache/arrow-rs/pull/2909) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Cleanup decimal sort function [\#2908](https://github.com/apache/arrow-rs/pull/2908) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Simplify TimestampArray from\_vec with timezone [\#2906](https://github.com/apache/arrow-rs/pull/2906) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Implement ord for FixedSizeBinary types [\#2905](https://github.com/apache/arrow-rs/pull/2905) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([maxburke](https://github.com/maxburke)) +- Update chrono-tz requirement from 0.6 to 0.7 [\#2903](https://github.com/apache/arrow-rs/pull/2903) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Parquet record api support timestamp before epoch [\#2899](https://github.com/apache/arrow-rs/pull/2899) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([AnthonyPoncet](https://github.com/AnthonyPoncet)) +- Specialize interleave integer [\#2898](https://github.com/apache/arrow-rs/pull/2898) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support overflow-checking variant of negate kernel [\#2893](https://github.com/apache/arrow-rs/pull/2893) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Respect Page Size Limits in ArrowWriter \(\#2853\) [\#2890](https://github.com/apache/arrow-rs/pull/2890) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Improve row format docs [\#2888](https://github.com/apache/arrow-rs/pull/2888) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add FixedSizeList::from\_iter\_primitive [\#2887](https://github.com/apache/arrow-rs/pull/2887) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Simplify ListArray::from\_iter\_primitive [\#2886](https://github.com/apache/arrow-rs/pull/2886) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Split out value selection kernels into arrow-select \(\#2594\) [\#2885](https://github.com/apache/arrow-rs/pull/2885) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Increase default IPC alignment to 64 \(\#2883\) [\#2884](https://github.com/apache/arrow-rs/pull/2884) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Copying inappropriately aligned buffer in ipc reader [\#2883](https://github.com/apache/arrow-rs/pull/2883) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Validate decimal IPC read \(\#2387\) [\#2880](https://github.com/apache/arrow-rs/pull/2880) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix compilation error under `chrono-tz` feature [\#2879](https://github.com/apache/arrow-rs/pull/2879) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Don't validate decimal precision in ArrayData \(\#2637\) [\#2873](https://github.com/apache/arrow-rs/pull/2873) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add downcast\_integer and downcast\_primitive [\#2872](https://github.com/apache/arrow-rs/pull/2872) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Filter DecimalArray as PrimitiveArray ~5x Faster \(\#2637\) [\#2870](https://github.com/apache/arrow-rs/pull/2870) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Treat DecimalArray as PrimitiveArray in row format [\#2866](https://github.com/apache/arrow-rs/pull/2866) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +## [25.0.0](https://github.com/apache/arrow-rs/tree/25.0.0) (2022-10-14) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/24.0.0...25.0.0) + +**Breaking changes:** + +- Make DecimalArray as PrimitiveArray [\#2857](https://github.com/apache/arrow-rs/pull/2857) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- fix timestamp parsing while no explicit timezone given [\#2814](https://github.com/apache/arrow-rs/pull/2814) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([waitingkuo](https://github.com/waitingkuo)) +- Support Arbitrary Number of Arrays in downcast\_primitive\_array [\#2809](https://github.com/apache/arrow-rs/pull/2809) ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Restore Integration test JSON schema serialization [\#2876](https://github.com/apache/arrow-rs/issues/2876) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Fix various invalid\_html\_tags clippy error [\#2861](https://github.com/apache/arrow-rs/issues/2861) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Replace complicated temporal macro with generic functions [\#2851](https://github.com/apache/arrow-rs/issues/2851) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add NaN handling in dyn scalar comparison kernels [\#2829](https://github.com/apache/arrow-rs/issues/2829) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add overflow-checking variant of sum kernel [\#2821](https://github.com/apache/arrow-rs/issues/2821) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Update to Clap 4 [\#2817](https://github.com/apache/arrow-rs/issues/2817) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Safe API to Operate on Dictionary Values [\#2797](https://github.com/apache/arrow-rs/issues/2797) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add modulus op into `ArrowNativeTypeOp` [\#2753](https://github.com/apache/arrow-rs/issues/2753) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Allow creating of TimeUnit instances without direct dependency on parquet-format [\#2708](https://github.com/apache/arrow-rs/issues/2708) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Arrow Row Format [\#2677](https://github.com/apache/arrow-rs/issues/2677) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Don't try to infer nulls in CSV schema inference [\#2859](https://github.com/apache/arrow-rs/issues/2859) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `parquet::arrow::arrow_writer::ArrowWriter` ignores page size properties [\#2853](https://github.com/apache/arrow-rs/issues/2853) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Introducing ArrowNativeTypeOp made it impossible to call kernels from generics [\#2839](https://github.com/apache/arrow-rs/issues/2839) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Unsound ArrayData to Array Conversions [\#2834](https://github.com/apache/arrow-rs/issues/2834) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Regression: `the trait bound for<'de> arrow::datatypes::Schema: serde::de::Deserialize<'de> is not satisfied` [\#2825](https://github.com/apache/arrow-rs/issues/2825) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- convert string to timestamp shouldn't apply local timezone offset if there's no explicit timezone info in the string [\#2813](https://github.com/apache/arrow-rs/issues/2813) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Add pub api for checking column index is sorted [\#2848](https://github.com/apache/arrow-rs/issues/2848) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Merged pull requests:** + +- Take decimal as primitive \(\#2637\) [\#2869](https://github.com/apache/arrow-rs/pull/2869) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Split out arrow-integration-test crate [\#2868](https://github.com/apache/arrow-rs/pull/2868) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Decimal cleanup \(\#2637\) [\#2865](https://github.com/apache/arrow-rs/pull/2865) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix various invalid\_html\_tags clippy errors [\#2862](https://github.com/apache/arrow-rs/pull/2862) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([viirya](https://github.com/viirya)) +- Don't try to infer nullability in CSV reader [\#2860](https://github.com/apache/arrow-rs/pull/2860) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Fix page size on dictionary fallback [\#2854](https://github.com/apache/arrow-rs/pull/2854) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) +- Replace complicated temporal macro with generic functions [\#2850](https://github.com/apache/arrow-rs/pull/2850) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- \[feat\] Add pub api for checking column index is sorted. [\#2849](https://github.com/apache/arrow-rs/pull/2849) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- parquet: Add `snap` option to README [\#2847](https://github.com/apache/arrow-rs/pull/2847) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([exyi](https://github.com/exyi)) +- Cleanup cast kernel [\#2846](https://github.com/apache/arrow-rs/pull/2846) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Simplify ArrowNativeType [\#2841](https://github.com/apache/arrow-rs/pull/2841) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Expose ArrowNativeTypeOp trait to make it useful for type bound [\#2840](https://github.com/apache/arrow-rs/pull/2840) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add `interleave` kernel \(\#1523\) [\#2838](https://github.com/apache/arrow-rs/pull/2838) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Handle empty offsets buffer \(\#1824\) [\#2836](https://github.com/apache/arrow-rs/pull/2836) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Validate ArrayData type when converting to Array \(\#2834\) [\#2835](https://github.com/apache/arrow-rs/pull/2835) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Derive ArrowPrimitiveType for Decimal128Type and Decimal256Type \(\#2637\) [\#2833](https://github.com/apache/arrow-rs/pull/2833) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add NaN handling in dyn scalar comparison kernels [\#2830](https://github.com/apache/arrow-rs/pull/2830) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Simplify OrderPreservingInterner allocation strategy ~97% faster \(\#2677\) [\#2827](https://github.com/apache/arrow-rs/pull/2827) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Convert rows to arrays \(\#2677\) [\#2826](https://github.com/apache/arrow-rs/pull/2826) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add overflow-checking variant of sum kernel [\#2822](https://github.com/apache/arrow-rs/pull/2822) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Update Clap dependency to version 4 [\#2819](https://github.com/apache/arrow-rs/pull/2819) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jgoday](https://github.com/jgoday)) +- Fix i256 checked multiplication [\#2818](https://github.com/apache/arrow-rs/pull/2818) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add string\_dictionary benches for row format \(\#2677\) [\#2816](https://github.com/apache/arrow-rs/pull/2816) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add OrderPreservingInterner::lookup \(\#2677\) [\#2815](https://github.com/apache/arrow-rs/pull/2815) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Simplify FixedLengthEncoding [\#2812](https://github.com/apache/arrow-rs/pull/2812) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Implement ArrowNumericType for Float16Type [\#2810](https://github.com/apache/arrow-rs/pull/2810) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add DictionaryArray::with\_values to make it easier to operate on dictionary values [\#2798](https://github.com/apache/arrow-rs/pull/2798) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add i256 \(\#2637\) [\#2781](https://github.com/apache/arrow-rs/pull/2781) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add modulus ops into `ArrowNativeTypeOp` [\#2756](https://github.com/apache/arrow-rs/pull/2756) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- feat: cast List / LargeList to Utf8 / LargeUtf8 [\#2588](https://github.com/apache/arrow-rs/pull/2588) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gandronchik](https://github.com/gandronchik)) + +## [24.0.0](https://github.com/apache/arrow-rs/tree/24.0.0) (2022-09-30) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/23.0.0...24.0.0) + +**Breaking changes:** + +- Cleanup `ArrowNativeType` \(\#1918\) [\#2793](https://github.com/apache/arrow-rs/pull/2793) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Remove `ArrowNativeType::FromStr` [\#2775](https://github.com/apache/arrow-rs/pull/2775) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Split out `arrow-array` crate \(\#2594\) [\#2769](https://github.com/apache/arrow-rs/pull/2769) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add `dyn_arith_dict` feature flag [\#2760](https://github.com/apache/arrow-rs/pull/2760) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Split out `arrow-data` into a separate crate [\#2746](https://github.com/apache/arrow-rs/pull/2746) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Split out arrow-schema \(\#2594\) [\#2711](https://github.com/apache/arrow-rs/pull/2711) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Include field name in Parquet PrimitiveTypeBuilder error messages [\#2804](https://github.com/apache/arrow-rs/issues/2804) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add PrimitiveArray::reinterpret\_cast [\#2785](https://github.com/apache/arrow-rs/issues/2785) +- BinaryBuilder and StringBuilder initialization parameters in struct\_builder may be wrong [\#2783](https://github.com/apache/arrow-rs/issues/2783) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add divide scalar dyn kernel which produces null for division by zero [\#2767](https://github.com/apache/arrow-rs/issues/2767) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add divide dyn kernel which produces null for division by zero [\#2763](https://github.com/apache/arrow-rs/issues/2763) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve performance of checked kernels on non-null data [\#2747](https://github.com/apache/arrow-rs/issues/2747) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add overflow-checking variants of arithmetic dyn kernels [\#2739](https://github.com/apache/arrow-rs/issues/2739) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- The `binary` function should not panic on unequaled array length. [\#2721](https://github.com/apache/arrow-rs/issues/2721) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- min compute kernel is incorrect with sliced buffers in arrow 23 [\#2779](https://github.com/apache/arrow-rs/issues/2779) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `try_unary_dict` should check value type of dictionary array [\#2754](https://github.com/apache/arrow-rs/issues/2754) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Add back JSON import/export for schema [\#2762](https://github.com/apache/arrow-rs/issues/2762) +- null casting and coercion for Decimal128 [\#2761](https://github.com/apache/arrow-rs/issues/2761) +- Json decoder behavior changed from versions 21 to 21 and returns non-sensical num\_rows for RecordBatch [\#2722](https://github.com/apache/arrow-rs/issues/2722) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Release Arrow `23.0.0` \(next release after `22.0.0`\) [\#2665](https://github.com/apache/arrow-rs/issues/2665) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Merged pull requests:** + +- add field name to parquet PrimitiveTypeBuilder error messages [\#2805](https://github.com/apache/arrow-rs/pull/2805) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([andygrove](https://github.com/andygrove)) +- Add struct equality test case \(\#514\) [\#2791](https://github.com/apache/arrow-rs/pull/2791) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Move unary kernels to arrow-array \(\#2787\) [\#2789](https://github.com/apache/arrow-rs/pull/2789) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Disable test harness for string\_dictionary\_builder benchmark [\#2788](https://github.com/apache/arrow-rs/pull/2788) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add PrimitiveArray::reinterpret\_cast \(\#2785\) [\#2786](https://github.com/apache/arrow-rs/pull/2786) ([tustvold](https://github.com/tustvold)) +- Fix BinaryBuilder and StringBuilder Capacity Allocation in StructBuilder [\#2784](https://github.com/apache/arrow-rs/pull/2784) ([chunshao90](https://github.com/chunshao90)) +- Fix min/max computation for sliced arrays \(\#2779\) [\#2780](https://github.com/apache/arrow-rs/pull/2780) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix Backwards Compatible Parquet List Encodings \(\#1915\) [\#2774](https://github.com/apache/arrow-rs/pull/2774) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- MINOR: Fix clippy for rust 1.64.0 [\#2772](https://github.com/apache/arrow-rs/pull/2772) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- MINOR: Fix clippy for rust 1.64.0 [\#2771](https://github.com/apache/arrow-rs/pull/2771) ([viirya](https://github.com/viirya)) +- Add divide scalar dyn kernel which produces null for division by zero [\#2768](https://github.com/apache/arrow-rs/pull/2768) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add divide dyn kernel which produces null for division by zero [\#2764](https://github.com/apache/arrow-rs/pull/2764) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add value type check in try\_unary\_dict [\#2755](https://github.com/apache/arrow-rs/pull/2755) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Fix `verify_release_candidate.sh` for new arrow subcrates [\#2752](https://github.com/apache/arrow-rs/pull/2752) ([alamb](https://github.com/alamb)) +- Fix: Issue 2721 : binary function should not panic but return error w… [\#2750](https://github.com/apache/arrow-rs/pull/2750) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([aksharau](https://github.com/aksharau)) +- Speed up checked kernels for non-null data \(~1.4-5x faster\) [\#2749](https://github.com/apache/arrow-rs/pull/2749) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Add overflow-checking variants of arithmetic dyn kernels [\#2740](https://github.com/apache/arrow-rs/pull/2740) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Trim parquet row selection [\#2705](https://github.com/apache/arrow-rs/pull/2705) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) + +## [23.0.0](https://github.com/apache/arrow-rs/tree/24.0.0) (2022-09-16) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/22.0.0...23.0.0) + +**Breaking changes:** + +- Move JSON Test Format To integration-testing [\#2724](https://github.com/apache/arrow-rs/pull/2724) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Split out arrow-buffer crate \(\#2594\) [\#2693](https://github.com/apache/arrow-rs/pull/2693) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Simplify DictionaryBuilder constructors \(\#2684\) \(\#2054\) [\#2685](https://github.com/apache/arrow-rs/pull/2685) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Deprecate RecordBatch::concat replace with concat\_batches \(\#2594\) [\#2683](https://github.com/apache/arrow-rs/pull/2683) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add overflow-checking variant for primitive arithmetic kernels and explicitly define overflow behavior [\#2643](https://github.com/apache/arrow-rs/pull/2643) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Update thrift v0.16 and vendor parquet-format \(\#2502\) [\#2626](https://github.com/apache/arrow-rs/pull/2626) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Update flight definitions including backwards-incompatible change to GetSchema [\#2586](https://github.com/apache/arrow-rs/pull/2586) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([liukun4515](https://github.com/liukun4515)) + +**Implemented enhancements:** + +- Cleanup like and nlike utf8 kernels [\#2744](https://github.com/apache/arrow-rs/issues/2744) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Speedup eq and neq kernels for utf8 arrays [\#2742](https://github.com/apache/arrow-rs/issues/2742) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- API for more ergonomic construction of `RecordBatchOptions` [\#2728](https://github.com/apache/arrow-rs/issues/2728) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Automate updates to `CHANGELOG-old.md` [\#2726](https://github.com/apache/arrow-rs/issues/2726) +- Don't check the `DivideByZero` error for float modulus [\#2720](https://github.com/apache/arrow-rs/issues/2720) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `try_binary` should not panic on unequaled array length. [\#2715](https://github.com/apache/arrow-rs/issues/2715) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add benchmark for bitwise operation [\#2714](https://github.com/apache/arrow-rs/issues/2714) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add overflow-checking variants of arithmetic scalar dyn kernels [\#2712](https://github.com/apache/arrow-rs/issues/2712) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add divide\_opt kernel which produce null values on division by zero error [\#2709](https://github.com/apache/arrow-rs/issues/2709) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `DataType` function to detect nested types [\#2704](https://github.com/apache/arrow-rs/issues/2704) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add support of sorting dictionary of other primitive types [\#2700](https://github.com/apache/arrow-rs/issues/2700) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Sort indices of dictionary string values [\#2697](https://github.com/apache/arrow-rs/issues/2697) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support empty projection in `RecordBatch::project` [\#2690](https://github.com/apache/arrow-rs/issues/2690) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support sorting dictionary encoded primitive integer arrays [\#2679](https://github.com/apache/arrow-rs/issues/2679) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use BitIndexIterator in min\_max\_helper [\#2674](https://github.com/apache/arrow-rs/issues/2674) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support building comparator for dictionaries of primitive integer values [\#2672](https://github.com/apache/arrow-rs/issues/2672) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Change max/min string macro to generic helper function `min_max_helper` [\#2657](https://github.com/apache/arrow-rs/issues/2657) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add overflow-checking variant of arithmetic scalar kernels [\#2651](https://github.com/apache/arrow-rs/issues/2651) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Compare dictionary with binary array [\#2644](https://github.com/apache/arrow-rs/issues/2644) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add overflow-checking variant for primitive arithmetic kernels [\#2642](https://github.com/apache/arrow-rs/issues/2642) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use `downcast_primitive_array` in arithmetic kernels [\#2639](https://github.com/apache/arrow-rs/issues/2639) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support DictionaryArray in temporal kernels [\#2622](https://github.com/apache/arrow-rs/issues/2622) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Inline Generated Thift Code Into Parquet Crate [\#2502](https://github.com/apache/arrow-rs/issues/2502) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Fixed bugs:** + +- Escape contains patterns for utf8 like kernels [\#2745](https://github.com/apache/arrow-rs/issues/2745) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Float Array should not panic on `DivideByZero` in the `Divide` kernel [\#2719](https://github.com/apache/arrow-rs/issues/2719) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- DictionaryBuilders can Create Invalid DictionaryArrays [\#2684](https://github.com/apache/arrow-rs/issues/2684) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `arrow` crate does not build with `features = ["ffi"]` and `default_features = false`. [\#2670](https://github.com/apache/arrow-rs/issues/2670) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Invalid results with `RowSelector` having `row_count` of 0 [\#2669](https://github.com/apache/arrow-rs/issues/2669) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- clippy error: unresolved import `crate::array::layout` [\#2659](https://github.com/apache/arrow-rs/issues/2659) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Cast the numeric without the `CastOptions` [\#2648](https://github.com/apache/arrow-rs/issues/2648) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Explicitly define overflow behavior for primitive arithmetic kernels [\#2641](https://github.com/apache/arrow-rs/issues/2641) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- update the `flight.proto` and fix schema to SchemaResult [\#2571](https://github.com/apache/arrow-rs/issues/2571) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Panic when first data page is skipped using ColumnChunkData::Sparse [\#2543](https://github.com/apache/arrow-rs/issues/2543) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `SchemaResult` in IPC deviates from other implementations [\#2445](https://github.com/apache/arrow-rs/issues/2445) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] + +**Closed issues:** + +- Implement collect for int values [\#2696](https://github.com/apache/arrow-rs/issues/2696) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Speedup string equal/not equal to empty string, cleanup like/ilike kernels, fix escape bug [\#2743](https://github.com/apache/arrow-rs/pull/2743) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Partially flatten arrow-buffer [\#2737](https://github.com/apache/arrow-rs/pull/2737) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Automate updates to `CHANGELOG-old.md` [\#2732](https://github.com/apache/arrow-rs/pull/2732) ([iajoiner](https://github.com/iajoiner)) +- Update read parquet example in parquet/arrow home [\#2730](https://github.com/apache/arrow-rs/pull/2730) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([datapythonista](https://github.com/datapythonista)) +- Better construction of RecordBatchOptions [\#2729](https://github.com/apache/arrow-rs/pull/2729) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([askoa](https://github.com/askoa)) +- benchmark: bitwise operation [\#2718](https://github.com/apache/arrow-rs/pull/2718) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Update `try_binary` and `checked_ops`, and remove `math_checked_op` [\#2717](https://github.com/apache/arrow-rs/pull/2717) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Support bitwise op in kernel: or,xor,not [\#2716](https://github.com/apache/arrow-rs/pull/2716) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Add overflow-checking variants of arithmetic scalar dyn kernels [\#2713](https://github.com/apache/arrow-rs/pull/2713) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add divide\_opt kernel which produce null values on division by zero error [\#2710](https://github.com/apache/arrow-rs/pull/2710) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add DataType::is\_nested\(\) [\#2707](https://github.com/apache/arrow-rs/pull/2707) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kfastov](https://github.com/kfastov)) +- Update criterion requirement from 0.3 to 0.4 [\#2706](https://github.com/apache/arrow-rs/pull/2706) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Support bitwise and operation in the kernel [\#2703](https://github.com/apache/arrow-rs/pull/2703) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Add support of sorting dictionary of other primitive arrays [\#2701](https://github.com/apache/arrow-rs/pull/2701) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Clarify docs of binary and string builders [\#2699](https://github.com/apache/arrow-rs/pull/2699) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([datapythonista](https://github.com/datapythonista)) +- Sort indices of dictionary string values [\#2698](https://github.com/apache/arrow-rs/pull/2698) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add support for empty projection in RecordBatch::project [\#2691](https://github.com/apache/arrow-rs/pull/2691) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Temporarily disable Golang integration tests re-enable JS [\#2689](https://github.com/apache/arrow-rs/pull/2689) ([tustvold](https://github.com/tustvold)) +- Verify valid UTF-8 when converting byte array \(\#2205\) [\#2686](https://github.com/apache/arrow-rs/pull/2686) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support sorting dictionary encoded primitive integer arrays [\#2680](https://github.com/apache/arrow-rs/pull/2680) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Skip RowSelectors with zero rows [\#2678](https://github.com/apache/arrow-rs/pull/2678) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([askoa](https://github.com/askoa)) +- Faster Null Path Selection in ArrayData Equality [\#2676](https://github.com/apache/arrow-rs/pull/2676) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dhruv9vats](https://github.com/dhruv9vats)) +- Use BitIndexIterator in min\_max\_helper [\#2675](https://github.com/apache/arrow-rs/pull/2675) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support building comparator for dictionaries of primitive integer values [\#2673](https://github.com/apache/arrow-rs/pull/2673) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- json feature always requires base64 feature [\#2668](https://github.com/apache/arrow-rs/pull/2668) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([eagletmt](https://github.com/eagletmt)) +- Add try\_unary, binary, try\_binary kernels ~90% faster [\#2666](https://github.com/apache/arrow-rs/pull/2666) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use downcast\_dictionary\_array in unary\_dyn [\#2663](https://github.com/apache/arrow-rs/pull/2663) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- optimize the `numeric_cast_with_error` [\#2661](https://github.com/apache/arrow-rs/pull/2661) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- ffi feature also requires layout [\#2660](https://github.com/apache/arrow-rs/pull/2660) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Change max/min string macro to generic helper function min\_max\_helper [\#2658](https://github.com/apache/arrow-rs/pull/2658) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Fix flaky test `test_fuzz_async_reader_selection` [\#2656](https://github.com/apache/arrow-rs/pull/2656) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) +- MINOR: Ignore flaky test test\_fuzz\_async\_reader\_selection [\#2655](https://github.com/apache/arrow-rs/pull/2655) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- MutableBuffer::typed\_data - shared ref access to the typed slice [\#2652](https://github.com/apache/arrow-rs/pull/2652) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([medwards](https://github.com/medwards)) +- Overflow-checking variant of arithmetic scalar kernels [\#2650](https://github.com/apache/arrow-rs/pull/2650) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- support `CastOption` for casting numeric [\#2649](https://github.com/apache/arrow-rs/pull/2649) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Help LLVM vectorize comparison kernel ~50-80% faster [\#2646](https://github.com/apache/arrow-rs/pull/2646) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support comparison between dictionary array and binary array [\#2645](https://github.com/apache/arrow-rs/pull/2645) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Use `downcast_primitive_array` in arithmetic kernels [\#2640](https://github.com/apache/arrow-rs/pull/2640) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Fully qualifying parquet items [\#2638](https://github.com/apache/arrow-rs/pull/2638) ([dingxiangfei2009](https://github.com/dingxiangfei2009)) +- Support DictionaryArray in temporal kernels [\#2623](https://github.com/apache/arrow-rs/pull/2623) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Comparable Row Format [\#2593](https://github.com/apache/arrow-rs/pull/2593) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix bug in page skipping [\#2552](https://github.com/apache/arrow-rs/pull/2552) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) + +## [22.0.0](https://github.com/apache/arrow-rs/tree/22.0.0) (2022-09-02) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/21.0.0...22.0.0) + +**Breaking changes:** + +- Use `total_cmp` for floating value ordering and remove `nan_ordering` feature flag [\#2614](https://github.com/apache/arrow-rs/pull/2614) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Gate dyn comparison of dictionary arrays behind `dyn_cmp_dict` [\#2597](https://github.com/apache/arrow-rs/pull/2597) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Move JsonSerializable to json module \(\#2300\) [\#2595](https://github.com/apache/arrow-rs/pull/2595) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Decimal precision scale datatype change [\#2532](https://github.com/apache/arrow-rs/pull/2532) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Refactor PrimitiveBuilder Constructors [\#2518](https://github.com/apache/arrow-rs/pull/2518) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Refactoring DecimalBuilder constructors [\#2517](https://github.com/apache/arrow-rs/pull/2517) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Refactor FixedSizeBinaryBuilder Constructors [\#2516](https://github.com/apache/arrow-rs/pull/2516) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Refactor BooleanBuilder Constructors [\#2515](https://github.com/apache/arrow-rs/pull/2515) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Refactor UnionBuilder Constructors [\#2488](https://github.com/apache/arrow-rs/pull/2488) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) + +**Implemented enhancements:** + +- Add Macros to assist with static dispatch [\#2635](https://github.com/apache/arrow-rs/issues/2635) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support comparison between DictionaryArray and BooleanArray [\#2617](https://github.com/apache/arrow-rs/issues/2617) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use `total_cmp` for floating value ordering and remove `nan_ordering` feature flag [\#2613](https://github.com/apache/arrow-rs/issues/2613) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support empty projection in CSV, JSON readers [\#2603](https://github.com/apache/arrow-rs/issues/2603) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support SQL-compliant NaN ordering between for DictionaryArray and non-DictionaryArray [\#2599](https://github.com/apache/arrow-rs/issues/2599) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `dyn_cmp_dict` feature flag to gate dyn comparison of dictionary arrays [\#2596](https://github.com/apache/arrow-rs/issues/2596) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add max\_dyn and min\_dyn for max/min for dictionary array [\#2584](https://github.com/apache/arrow-rs/issues/2584) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Allow FlightSQL implementers to extend `do_get()` [\#2581](https://github.com/apache/arrow-rs/issues/2581) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Support SQL-compliant behavior on `eq_dyn`, `neq_dyn`, `lt_dyn`, `lt_eq_dyn`, `gt_dyn`, `gt_eq_dyn` [\#2569](https://github.com/apache/arrow-rs/issues/2569) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add sql-compliant feature for enabling sql-compliant kernel behavior [\#2568](https://github.com/apache/arrow-rs/issues/2568) +- Calculate `sum` for dictionary array [\#2565](https://github.com/apache/arrow-rs/issues/2565) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add test for float nan comparison [\#2556](https://github.com/apache/arrow-rs/issues/2556) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Compare dictionary with string array [\#2548](https://github.com/apache/arrow-rs/issues/2548) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Compare dictionary with primitive array in `lt_dyn`, `lt_eq_dyn`, `gt_dyn`, `gt_eq_dyn` [\#2538](https://github.com/apache/arrow-rs/issues/2538) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Compare dictionary with primitive array in `eq_dyn` and `neq_dyn` [\#2535](https://github.com/apache/arrow-rs/issues/2535) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- UnionBuilder Create Children With Capacity [\#2523](https://github.com/apache/arrow-rs/issues/2523) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Speed up `like_utf8_scalar` for `%pat%` [\#2519](https://github.com/apache/arrow-rs/issues/2519) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Replace macro with TypedDictionaryArray in comparison kernels [\#2513](https://github.com/apache/arrow-rs/issues/2513) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use same codebase for boolean kernels [\#2507](https://github.com/apache/arrow-rs/issues/2507) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use u8 for Decimal Precision and Scale [\#2496](https://github.com/apache/arrow-rs/issues/2496) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Integrate skip row without pageIndex in SerializedPageReader in Fuzz Test [\#2475](https://github.com/apache/arrow-rs/issues/2475) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Avoid unnecessary copies in Arrow IPC reader [\#2437](https://github.com/apache/arrow-rs/issues/2437) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add GenericColumnReader::skip\_records Missing OffsetIndex Fallback [\#2433](https://github.com/apache/arrow-rs/issues/2433) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support Reading PageIndex with ParquetRecordBatchStream [\#2430](https://github.com/apache/arrow-rs/issues/2430) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Specialize FixedLenByteArrayReader for Parquet [\#2318](https://github.com/apache/arrow-rs/issues/2318) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Make JSON support Optional via Feature Flag [\#2300](https://github.com/apache/arrow-rs/issues/2300) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Casting timestamp array to string should not ignore timezone [\#2607](https://github.com/apache/arrow-rs/issues/2607) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Ilike\_ut8\_scalar kernels have incorrect logic [\#2544](https://github.com/apache/arrow-rs/issues/2544) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Always validate the array data when creating array in IPC reader [\#2541](https://github.com/apache/arrow-rs/issues/2541) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Int96Converter Truncates Timestamps [\#2480](https://github.com/apache/arrow-rs/issues/2480) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Error Reading Page Index When Not Available [\#2434](https://github.com/apache/arrow-rs/issues/2434) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `ParquetFileArrowReader::get_record_reader[_by_column]` `batch_size` overallocates [\#2321](https://github.com/apache/arrow-rs/issues/2321) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Documentation updates:** + +- Document All Arrow Features in docs.rs [\#2633](https://github.com/apache/arrow-rs/issues/2633) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Add support for CAST from `Interval(DayTime)` to `Timestamp(Nanosecond, None)` [\#2606](https://github.com/apache/arrow-rs/issues/2606) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Why do we check for null in TypedDictionaryArray value function [\#2564](https://github.com/apache/arrow-rs/issues/2564) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add the `length` field for `Buffer` [\#2524](https://github.com/apache/arrow-rs/issues/2524) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Avoid large over allocate buffer in async reader [\#2512](https://github.com/apache/arrow-rs/issues/2512) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Rewriting Decimal Builders using `const_generic`. [\#2390](https://github.com/apache/arrow-rs/issues/2390) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Rewrite Decimal Array using `const_generic` [\#2384](https://github.com/apache/arrow-rs/issues/2384) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Add downcast macros \(\#2635\) [\#2636](https://github.com/apache/arrow-rs/pull/2636) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Document all arrow features in docs.rs \(\#2633\) [\#2634](https://github.com/apache/arrow-rs/pull/2634) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Document dyn\_cmp\_dict [\#2624](https://github.com/apache/arrow-rs/pull/2624) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support comparison between DictionaryArray and BooleanArray [\#2618](https://github.com/apache/arrow-rs/pull/2618) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Cast timestamp array to string array with timezone [\#2608](https://github.com/apache/arrow-rs/pull/2608) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support empty projection in CSV and JSON readers [\#2604](https://github.com/apache/arrow-rs/pull/2604) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Make JSON support optional via a feature flag \(\#2300\) [\#2601](https://github.com/apache/arrow-rs/pull/2601) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support SQL-compliant NaN ordering for DictionaryArray and non-DictionaryArray [\#2600](https://github.com/apache/arrow-rs/pull/2600) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Split out integration test plumbing \(\#2594\) \(\#2300\) [\#2598](https://github.com/apache/arrow-rs/pull/2598) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Refactor Binary Builder and String Builder Constructors [\#2592](https://github.com/apache/arrow-rs/pull/2592) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Dictionary like scalar kernels [\#2591](https://github.com/apache/arrow-rs/pull/2591) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Validate dictionary key in TypedDictionaryArray \(\#2578\) [\#2589](https://github.com/apache/arrow-rs/pull/2589) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add max\_dyn and min\_dyn for max/min for dictionary array [\#2585](https://github.com/apache/arrow-rs/pull/2585) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Code cleanup of array value functions [\#2583](https://github.com/apache/arrow-rs/pull/2583) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Allow overriding of do\_get & export useful macro [\#2582](https://github.com/apache/arrow-rs/pull/2582) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([avantgardnerio](https://github.com/avantgardnerio)) +- MINOR: Upgrade to pyo3 0.17 [\#2576](https://github.com/apache/arrow-rs/pull/2576) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([andygrove](https://github.com/andygrove)) +- Support SQL-compliant NaN behavior on eq\_dyn, neq\_dyn, lt\_dyn, lt\_eq\_dyn, gt\_dyn, gt\_eq\_dyn [\#2570](https://github.com/apache/arrow-rs/pull/2570) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add sum\_dyn to calculate sum for dictionary array [\#2566](https://github.com/apache/arrow-rs/pull/2566) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- struct UnionBuilder will create child buffers with capacity [\#2560](https://github.com/apache/arrow-rs/pull/2560) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kastolars](https://github.com/kastolars)) +- Don't panic on RleValueEncoder::flush\_buffer if empty \(\#2558\) [\#2559](https://github.com/apache/arrow-rs/pull/2559) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add the `length` field for Buffer and use more `Buffer` in IPC reader to avoid memory copy. [\#2557](https://github.com/apache/arrow-rs/pull/2557) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([HaoYang670](https://github.com/HaoYang670)) +- Add test for float nan comparison [\#2555](https://github.com/apache/arrow-rs/pull/2555) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Compare dictionary array with string array [\#2549](https://github.com/apache/arrow-rs/pull/2549) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Always validate the array data \(except the `Decimal`\) when creating array in IPC reader [\#2547](https://github.com/apache/arrow-rs/pull/2547) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- MINOR: Fix test\_row\_type\_validation test [\#2546](https://github.com/apache/arrow-rs/pull/2546) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Fix ilike\_utf8\_scalar kernels [\#2545](https://github.com/apache/arrow-rs/pull/2545) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- fix typo [\#2540](https://github.com/apache/arrow-rs/pull/2540) ([00Masato](https://github.com/00Masato)) +- Compare dictionary array and primitive array in lt\_dyn, lt\_eq\_dyn, gt\_dyn, gt\_eq\_dyn kernels [\#2539](https://github.com/apache/arrow-rs/pull/2539) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- \[MINOR\]Avoid large over allocate buffer in async reader [\#2537](https://github.com/apache/arrow-rs/pull/2537) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Compare dictionary with primitive array in `eq_dyn` and `neq_dyn` [\#2533](https://github.com/apache/arrow-rs/pull/2533) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add iterator for FixedSizeBinaryArray [\#2531](https://github.com/apache/arrow-rs/pull/2531) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- add bench: decimal with byte array and fixed length byte array [\#2529](https://github.com/apache/arrow-rs/pull/2529) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([liukun4515](https://github.com/liukun4515)) +- Add FixedLengthByteArrayReader Remove ComplexObjectArrayReader [\#2528](https://github.com/apache/arrow-rs/pull/2528) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Split out byte array decoders \(\#2318\) [\#2527](https://github.com/apache/arrow-rs/pull/2527) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Use offset index in ParquetRecordBatchStream [\#2526](https://github.com/apache/arrow-rs/pull/2526) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) +- Clean the `create_array` in IPC reader. [\#2525](https://github.com/apache/arrow-rs/pull/2525) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Remove DecimalByteArrayConvert \(\#2480\) [\#2522](https://github.com/apache/arrow-rs/pull/2522) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Improve performance of `%pat%` \(\>3x speedup\) [\#2521](https://github.com/apache/arrow-rs/pull/2521) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- remove len field from MapBuilder [\#2520](https://github.com/apache/arrow-rs/pull/2520) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Replace macro with TypedDictionaryArray in comparison kernels [\#2514](https://github.com/apache/arrow-rs/pull/2514) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Avoid large over allocate buffer in sync reader [\#2511](https://github.com/apache/arrow-rs/pull/2511) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Avoid useless memory copies in IPC reader. [\#2510](https://github.com/apache/arrow-rs/pull/2510) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Refactor boolean kernels to use same codebase [\#2508](https://github.com/apache/arrow-rs/pull/2508) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Remove Int96Converter \(\#2480\) [\#2481](https://github.com/apache/arrow-rs/pull/2481) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) + ## [21.0.0](https://github.com/apache/arrow-rs/tree/21.0.0) (2022-08-18) [Full Changelog](https://github.com/apache/arrow-rs/compare/20.0.0...21.0.0) @@ -430,7 +2537,7 @@ - Incorrect `null_count` of DictionaryArray [\#1962](https://github.com/apache/arrow-rs/issues/1962) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - Support multi diskRanges for ChunkReader [\#1955](https://github.com/apache/arrow-rs/issues/1955) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] - Persisting Arrow timestamps with Parquet produces missing `TIMESTAMP` in schema [\#1920](https://github.com/apache/arrow-rs/issues/1920) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Sperate get\_next\_page\_header from get\_next\_page in PageReader [\#1834](https://github.com/apache/arrow-rs/issues/1834) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Separate get\_next\_page\_header from get\_next\_page in PageReader [\#1834](https://github.com/apache/arrow-rs/issues/1834) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] **Merged pull requests:** @@ -487,7 +2594,7 @@ - `PrimitiveArray::from_iter` should omit validity buffer if all values are valid [\#1856](https://github.com/apache/arrow-rs/issues/1856) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - Add `from(v: Vec>)` and `from(v: Vec<&[u8]>)` for `FixedSizedBInaryArray` [\#1852](https://github.com/apache/arrow-rs/issues/1852) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - Add `Vec`-inspired APIs to `BufferBuilder` [\#1850](https://github.com/apache/arrow-rs/issues/1850) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- PyArrow intergation test for C Stream Interface [\#1847](https://github.com/apache/arrow-rs/issues/1847) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- PyArrow integration test for C Stream Interface [\#1847](https://github.com/apache/arrow-rs/issues/1847) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - Add `nilike` support in `comparison` [\#1845](https://github.com/apache/arrow-rs/issues/1845) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - Split up `arrow::array::builder` module [\#1843](https://github.com/apache/arrow-rs/issues/1843) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - Add `quarter` support in `temporal` kernels [\#1835](https://github.com/apache/arrow-rs/issues/1835) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] @@ -884,7 +2991,7 @@ **Fixed bugs:** -- Error Infering Schema for LogicalType::UNKNOWN [\#1557](https://github.com/apache/arrow-rs/issues/1557) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Error Inferring Schema for LogicalType::UNKNOWN [\#1557](https://github.com/apache/arrow-rs/issues/1557) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] - Read dictionary from nested struct in ipc stream reader panics [\#1549](https://github.com/apache/arrow-rs/issues/1549) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - `filter` produces invalid sparse `UnionArray`s [\#1547](https://github.com/apache/arrow-rs/issues/1547) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - Documentation for `GenericListBuilder` is not exposed. [\#1518](https://github.com/apache/arrow-rs/issues/1518) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] @@ -1410,7 +3517,7 @@ * [094037d418381584178db1d886cad3b5024b414a](https://github.com/apache/arrow-rs/commit/094037d418381584178db1d886cad3b5024b414a) Update comfy-table to 5.0 ([#957](https://github.com/apache/arrow-rs/pull/957)) ([#964](https://github.com/apache/arrow-rs/pull/964)) * [9f635021eee6786c5377c891218c5f88ebce07c3](https://github.com/apache/arrow-rs/commit/9f635021eee6786c5377c891218c5f88ebce07c3) Fix csv writing of timestamps to show timezone. ([#849](https://github.com/apache/arrow-rs/pull/849)) ([#963](https://github.com/apache/arrow-rs/pull/963)) * [f7deba4c3a050a52608462ee8a827bb8f6364140](https://github.com/apache/arrow-rs/commit/f7deba4c3a050a52608462ee8a827bb8f6364140) Adding ability to parse float from number with leading decimal ([#831](https://github.com/apache/arrow-rs/pull/831)) ([#962](https://github.com/apache/arrow-rs/pull/962)) -* [59f96e842d05b63882f7ba285c66a9739761cf84](https://github.com/apache/arrow-rs/commit/59f96e842d05b63882f7ba285c66a9739761cf84) add ilike comparitor ([#874](https://github.com/apache/arrow-rs/pull/874)) ([#961](https://github.com/apache/arrow-rs/pull/961)) +* [59f96e842d05b63882f7ba285c66a9739761cf84](https://github.com/apache/arrow-rs/commit/59f96e842d05b63882f7ba285c66a9739761cf84) add ilike comparator ([#874](https://github.com/apache/arrow-rs/pull/874)) ([#961](https://github.com/apache/arrow-rs/pull/961)) * [54023c8a5543c9f9fa4955afa01189029f3e96f5](https://github.com/apache/arrow-rs/commit/54023c8a5543c9f9fa4955afa01189029f3e96f5) Remove unpassable cargo publish check from verify-release-candidate.sh ([#882](https://github.com/apache/arrow-rs/pull/882)) ([#949](https://github.com/apache/arrow-rs/pull/949)) @@ -1507,7 +3614,7 @@ **Fixed bugs:** - Converting from string to timestamp uses microseconds instead of milliseconds [\#780](https://github.com/apache/arrow-rs/issues/780) -- Document has no link to `RowColumIter` [\#762](https://github.com/apache/arrow-rs/issues/762) +- Document has no link to `RowColumnIter` [\#762](https://github.com/apache/arrow-rs/issues/762) - length on slices with null doesn't work [\#744](https://github.com/apache/arrow-rs/issues/744) ## [5.4.0](https://github.com/apache/arrow-rs/tree/5.4.0) (2021-09-10) @@ -1565,7 +3672,7 @@ - Remove undefined behavior in `value` method of boolean and primitive arrays [\#645](https://github.com/apache/arrow-rs/issues/645) - Avoid materialization of indices in filter\_record\_batch for single arrays [\#636](https://github.com/apache/arrow-rs/issues/636) - Add a note about arrow crate security / safety [\#627](https://github.com/apache/arrow-rs/issues/627) -- Allow the creation of String arrays from an interator of &Option\<&str\> [\#598](https://github.com/apache/arrow-rs/issues/598) +- Allow the creation of String arrays from an iterator of &Option\<&str\> [\#598](https://github.com/apache/arrow-rs/issues/598) - Support arrow map datatype [\#395](https://github.com/apache/arrow-rs/issues/395) **Fixed bugs:** @@ -1694,7 +3801,7 @@ - Add C data interface for decimal128 and timestamp [\#453](https://github.com/apache/arrow-rs/pull/453) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alippai](https://github.com/alippai)) - Implement the Iterator trait for the json Reader. [\#451](https://github.com/apache/arrow-rs/pull/451) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([LaurentMazare](https://github.com/LaurentMazare)) - Update release docs + release email template [\#450](https://github.com/apache/arrow-rs/pull/450) ([alamb](https://github.com/alamb)) -- remove clippy unnecessary wraps suppresions in cast kernel [\#449](https://github.com/apache/arrow-rs/pull/449) ([Jimexist](https://github.com/Jimexist)) +- remove clippy unnecessary wraps suppression in cast kernel [\#449](https://github.com/apache/arrow-rs/pull/449) ([Jimexist](https://github.com/Jimexist)) - Use partition for bool sort [\#448](https://github.com/apache/arrow-rs/pull/448) ([Jimexist](https://github.com/Jimexist)) - remove unnecessary wraps in sort [\#445](https://github.com/apache/arrow-rs/pull/445) ([Jimexist](https://github.com/Jimexist)) - Python FFI bridge for Schema, Field and DataType [\#439](https://github.com/apache/arrow-rs/pull/439) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kszucs](https://github.com/kszucs)) @@ -1767,7 +3874,7 @@ - ARROW-12504: Buffer::from\_slice\_ref set correct capacity [\#18](https://github.com/apache/arrow-rs/pull/18) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) - Add GitHub templates [\#17](https://github.com/apache/arrow-rs/pull/17) ([andygrove](https://github.com/andygrove)) - ARROW-12493: Add support for writing dictionary arrays to CSV and JSON [\#16](https://github.com/apache/arrow-rs/pull/16) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- ARROW-12426: \[Rust\] Fix concatentation of arrow dictionaries [\#15](https://github.com/apache/arrow-rs/pull/15) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- ARROW-12426: \[Rust\] Fix concatenation of arrow dictionaries [\#15](https://github.com/apache/arrow-rs/pull/15) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) - Update repository and homepage urls [\#14](https://github.com/apache/arrow-rs/pull/14) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Dandandan](https://github.com/Dandandan)) - Added rebase-needed bot [\#13](https://github.com/apache/arrow-rs/pull/13) ([jorgecarleitao](https://github.com/jorgecarleitao)) - Added Integration tests against arrow [\#10](https://github.com/apache/arrow-rs/pull/10) ([jorgecarleitao](https://github.com/jorgecarleitao)) @@ -1911,7 +4018,7 @@ - Support sort [\#215](https://github.com/apache/arrow-rs/issues/215) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - Support stable Rust [\#214](https://github.com/apache/arrow-rs/issues/214) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - Remove Rust and point integration tests to arrow-rs repo [\#211](https://github.com/apache/arrow-rs/issues/211) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- ArrayData buffers are inconsistent accross implementations [\#207](https://github.com/apache/arrow-rs/issues/207) +- ArrayData buffers are inconsistent across implementations [\#207](https://github.com/apache/arrow-rs/issues/207) - 3.0.1 patch release [\#204](https://github.com/apache/arrow-rs/issues/204) - Document patch release process [\#202](https://github.com/apache/arrow-rs/issues/202) - Simplify Offset [\#186](https://github.com/apache/arrow-rs/issues/186) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] diff --git a/CHANGELOG.md b/CHANGELOG.md index 69f2b8af6cf8..ba27d6679ffe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,119 +19,54 @@ # Changelog -## [22.0.0](https://github.com/apache/arrow-rs/tree/22.0.0) (2022-09-02) +## [49.0.0](https://github.com/apache/arrow-rs/tree/49.0.0) (2023-11-07) -[Full Changelog](https://github.com/apache/arrow-rs/compare/21.0.0...22.0.0) +[Full Changelog](https://github.com/apache/arrow-rs/compare/48.0.0...49.0.0) **Breaking changes:** -- Use `total_cmp` for floating value ordering and remove `nan_ordering` feature flag [\#2614](https://github.com/apache/arrow-rs/pull/2614) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Gate dyn comparison of dictionary arrays behind `dyn_cmp_dict` [\#2597](https://github.com/apache/arrow-rs/pull/2597) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Move JsonSerializable to json module \(\#2300\) [\#2595](https://github.com/apache/arrow-rs/pull/2595) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Decimal precision scale datatype change [\#2532](https://github.com/apache/arrow-rs/pull/2532) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) -- Refactor PrimitiveBuilder Constructors [\#2518](https://github.com/apache/arrow-rs/pull/2518) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) -- Refactoring DecimalBuilder constructors [\#2517](https://github.com/apache/arrow-rs/pull/2517) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) -- Refactor FixedSizeBinaryBuilder Constructors [\#2516](https://github.com/apache/arrow-rs/pull/2516) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) -- Refactor BooleanBuilder Constructors [\#2515](https://github.com/apache/arrow-rs/pull/2515) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) -- Refactor UnionBuilder Constructors [\#2488](https://github.com/apache/arrow-rs/pull/2488) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Return row count when inferring schema from JSON [\#5008](https://github.com/apache/arrow-rs/pull/5008) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([asayers](https://github.com/asayers)) +- Update object\_store 0.8.0 [\#5043](https://github.com/apache/arrow-rs/pull/5043) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) **Implemented enhancements:** -- Add Macros to assist with static dispatch [\#2635](https://github.com/apache/arrow-rs/issues/2635) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support comparison between DictionaryArray and BooleanArray [\#2617](https://github.com/apache/arrow-rs/issues/2617) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Use `total_cmp` for floating value ordering and remove `nan_ordering` feature flag [\#2613](https://github.com/apache/arrow-rs/issues/2613) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support empty projection in CSV, JSON readers [\#2603](https://github.com/apache/arrow-rs/issues/2603) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support SQL-compliant NaN ordering between for DictionaryArray and non-DictionaryArray [\#2599](https://github.com/apache/arrow-rs/issues/2599) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add `dyn_cmp_dict` feature flag to gate dyn comparison of dictionary arrays [\#2596](https://github.com/apache/arrow-rs/issues/2596) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add max\_dyn and min\_dyn for max/min for dictionary array [\#2584](https://github.com/apache/arrow-rs/issues/2584) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Allow FlightSQL implementers to extend `do_get()` [\#2581](https://github.com/apache/arrow-rs/issues/2581) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] -- Support SQL-compliant behavior on `eq_dyn`, `neq_dyn`, `lt_dyn`, `lt_eq_dyn`, `gt_dyn`, `gt_eq_dyn` [\#2569](https://github.com/apache/arrow-rs/issues/2569) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add sql-compliant feature for enabling sql-compliant kernel behavior [\#2568](https://github.com/apache/arrow-rs/issues/2568) -- Calculate `sum` for dictionary array [\#2565](https://github.com/apache/arrow-rs/issues/2565) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add test for float nan comparison [\#2556](https://github.com/apache/arrow-rs/issues/2556) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Compare dictionary with string array [\#2548](https://github.com/apache/arrow-rs/issues/2548) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Compare dictionary with primitive array in `lt_dyn`, `lt_eq_dyn`, `gt_dyn`, `gt_eq_dyn` [\#2538](https://github.com/apache/arrow-rs/issues/2538) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Compare dictionary with primitive array in `eq_dyn` and `neq_dyn` [\#2535](https://github.com/apache/arrow-rs/issues/2535) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- UnionBuilder Create Children With Capacity [\#2523](https://github.com/apache/arrow-rs/issues/2523) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Speed up `like_utf8_scalar` for `%pat%` [\#2519](https://github.com/apache/arrow-rs/issues/2519) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Replace macro with TypedDictionaryArray in comparison kernels [\#2513](https://github.com/apache/arrow-rs/issues/2513) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Use same codebase for boolean kernels [\#2507](https://github.com/apache/arrow-rs/issues/2507) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Use u8 for Decimal Precision and Scale [\#2496](https://github.com/apache/arrow-rs/issues/2496) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Integrate skip row without pageIndex in SerializedPageReader in Fuzz Test [\#2475](https://github.com/apache/arrow-rs/issues/2475) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Avoid unecessary copies in Arrow IPC reader [\#2437](https://github.com/apache/arrow-rs/issues/2437) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add GenericColumnReader::skip\_records Missing OffsetIndex Fallback [\#2433](https://github.com/apache/arrow-rs/issues/2433) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Support Reading PageIndex with ParquetRecordBatchStream [\#2430](https://github.com/apache/arrow-rs/issues/2430) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Specialize FixedLenByteArrayReader for Parquet [\#2318](https://github.com/apache/arrow-rs/issues/2318) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Make JSON support Optional via Feature Flag [\#2300](https://github.com/apache/arrow-rs/issues/2300) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Cast from integer/timestamp to timestamp/integer [\#5039](https://github.com/apache/arrow-rs/issues/5039) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support casting from integer to binary [\#5014](https://github.com/apache/arrow-rs/issues/5014) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Return row count when inferring schema from JSON [\#5007](https://github.com/apache/arrow-rs/issues/5007) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[FlightSQL\] Allow custom commands in get-flight-info [\#4996](https://github.com/apache/arrow-rs/issues/4996) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Support `RecordBatch::remove_column()` and `Schema::remove_field()` [\#4952](https://github.com/apache/arrow-rs/issues/4952) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `arrow_json`: support `binary` deserialization [\#4945](https://github.com/apache/arrow-rs/issues/4945) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support StructArray in Cast Kernel [\#4908](https://github.com/apache/arrow-rs/issues/4908) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- There exists a `ParquetRecordWriter` proc macro in `parquet_derive`, but `ParquetRecordReader` is missing [\#4772](https://github.com/apache/arrow-rs/issues/4772) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] **Fixed bugs:** -- Casting timestamp array to string should not ignore timezone [\#2607](https://github.com/apache/arrow-rs/issues/2607) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Ilike\_ut8\_scalar kernals have incorrect logic [\#2544](https://github.com/apache/arrow-rs/issues/2544) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Always validate the array data when creating array in IPC reader [\#2541](https://github.com/apache/arrow-rs/issues/2541) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Int96Converter Truncates Timestamps [\#2480](https://github.com/apache/arrow-rs/issues/2480) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Error Reading Page Index When Not Available [\#2434](https://github.com/apache/arrow-rs/issues/2434) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- `ParquetFileArrowReader::get_record_reader[_by_colum]` `batch_size` overallocates [\#2321](https://github.com/apache/arrow-rs/issues/2321) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Regression when serializing large json numbers [\#5038](https://github.com/apache/arrow-rs/issues/5038) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- RowSelection::intersection Produces Invalid RowSelection [\#5036](https://github.com/apache/arrow-rs/issues/5036) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Incorrect comment on arrow::compute::kernels::sort::sort\_to\_indices [\#5029](https://github.com/apache/arrow-rs/issues/5029) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] **Documentation updates:** -- Document All Arrow Features in docs.rs [\#2633](https://github.com/apache/arrow-rs/issues/2633) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - -**Closed issues:** - -- Add support for CAST from `Interval(DayTime)` to `Timestamp(Nanosecond, None)` [\#2606](https://github.com/apache/arrow-rs/issues/2606) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Why do we check for null in TypedDictionaryArray value function [\#2564](https://github.com/apache/arrow-rs/issues/2564) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add the `length` field for `Buffer` [\#2524](https://github.com/apache/arrow-rs/issues/2524) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Avoid large over allocate buffer in async reader [\#2512](https://github.com/apache/arrow-rs/issues/2512) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Rewriting Decimal Builders using `const_generic`. [\#2390](https://github.com/apache/arrow-rs/issues/2390) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Rewrite Decimal Array using `const_generic` [\#2384](https://github.com/apache/arrow-rs/issues/2384) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- chore: Update docs to refer to non deprecated function \(`partition`\) [\#5027](https://github.com/apache/arrow-rs/pull/5027) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) **Merged pull requests:** -- Add downcast macros \(\#2635\) [\#2636](https://github.com/apache/arrow-rs/pull/2636) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Document all arrow features in docs.rs \(\#2633\) [\#2634](https://github.com/apache/arrow-rs/pull/2634) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Document dyn\_cmp\_dict [\#2624](https://github.com/apache/arrow-rs/pull/2624) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Support comparison between DictionaryArray and BooleanArray [\#2618](https://github.com/apache/arrow-rs/pull/2618) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Cast timestamp array to string array with timezone [\#2608](https://github.com/apache/arrow-rs/pull/2608) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Support empty projection in CSV and JSON readers [\#2604](https://github.com/apache/arrow-rs/pull/2604) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) -- Make JSON support optional via a feature flag \(\#2300\) [\#2601](https://github.com/apache/arrow-rs/pull/2601) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Support SQL-compliant NaN ordering for DictionaryArray and non-DictionaryArray [\#2600](https://github.com/apache/arrow-rs/pull/2600) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Split out integration test plumbing \(\#2594\) \(\#2300\) [\#2598](https://github.com/apache/arrow-rs/pull/2598) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Refactor Binary Builder and String Builder Constructors [\#2592](https://github.com/apache/arrow-rs/pull/2592) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) -- Dictionary like scalar kernels [\#2591](https://github.com/apache/arrow-rs/pull/2591) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) -- Validate dictionary key in TypedDictionaryArray \(\#2578\) [\#2589](https://github.com/apache/arrow-rs/pull/2589) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Add max\_dyn and min\_dyn for max/min for dictionary array [\#2585](https://github.com/apache/arrow-rs/pull/2585) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Code cleanup of array value functions [\#2583](https://github.com/apache/arrow-rs/pull/2583) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) -- Allow overriding of do\_get & export useful macro [\#2582](https://github.com/apache/arrow-rs/pull/2582) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([avantgardnerio](https://github.com/avantgardnerio)) -- MINOR: Upgrade to pyo3 0.17 [\#2576](https://github.com/apache/arrow-rs/pull/2576) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([andygrove](https://github.com/andygrove)) -- Support SQL-compliant NaN behavior on eq\_dyn, neq\_dyn, lt\_dyn, lt\_eq\_dyn, gt\_dyn, gt\_eq\_dyn [\#2570](https://github.com/apache/arrow-rs/pull/2570) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Add sum\_dyn to calculate sum for dictionary array [\#2566](https://github.com/apache/arrow-rs/pull/2566) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- struct UnionBuilder will create child buffers with capacity [\#2560](https://github.com/apache/arrow-rs/pull/2560) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kastolars](https://github.com/kastolars)) -- Don't panic on RleValueEncoder::flush\_buffer if empty \(\#2558\) [\#2559](https://github.com/apache/arrow-rs/pull/2559) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) -- Add the `length` field for Buffer and use more `Buffer` in IPC reader to avoid memory copy. [\#2557](https://github.com/apache/arrow-rs/pull/2557) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([HaoYang670](https://github.com/HaoYang670)) -- Add test for float nan comparison [\#2555](https://github.com/apache/arrow-rs/pull/2555) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Compare dictionary array with string array [\#2549](https://github.com/apache/arrow-rs/pull/2549) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Always validate the array data \(except the `Decimal`\) when creating array in IPC reader [\#2547](https://github.com/apache/arrow-rs/pull/2547) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) -- MINOR: Fix test\_row\_type\_validation test [\#2546](https://github.com/apache/arrow-rs/pull/2546) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Fix ilike\_utf8\_scalar kernals [\#2545](https://github.com/apache/arrow-rs/pull/2545) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) -- fix typo [\#2540](https://github.com/apache/arrow-rs/pull/2540) ([00Masato](https://github.com/00Masato)) -- Compare dictionary array and primitive array in lt\_dyn, lt\_eq\_dyn, gt\_dyn, gt\_eq\_dyn kernels [\#2539](https://github.com/apache/arrow-rs/pull/2539) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- \[MINOR\]Avoid large over allocate buffer in async reader [\#2537](https://github.com/apache/arrow-rs/pull/2537) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) -- Compare dictionary with primitive array in `eq_dyn` and `neq_dyn` [\#2533](https://github.com/apache/arrow-rs/pull/2533) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Add iterator for FixedSizeBinaryArray [\#2531](https://github.com/apache/arrow-rs/pull/2531) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- add bench: decimal with byte array and fixed length byte array [\#2529](https://github.com/apache/arrow-rs/pull/2529) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([liukun4515](https://github.com/liukun4515)) -- Add FixedLengthByteArrayReader Remove ComplexObjectArrayReader [\#2528](https://github.com/apache/arrow-rs/pull/2528) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) -- Split out byte array decoders \(\#2318\) [\#2527](https://github.com/apache/arrow-rs/pull/2527) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) -- Use offset index in ParquetRecordBatchStream [\#2526](https://github.com/apache/arrow-rs/pull/2526) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) -- Clean the `create_array` in IPC reader. [\#2525](https://github.com/apache/arrow-rs/pull/2525) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) -- Remove DecimalByteArrayConvert \(\#2480\) [\#2522](https://github.com/apache/arrow-rs/pull/2522) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) -- Improve performance of `%pat%` \(\>3x speedup\) [\#2521](https://github.com/apache/arrow-rs/pull/2521) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) -- remove len field from MapBuilder [\#2520](https://github.com/apache/arrow-rs/pull/2520) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) -- Replace macro with TypedDictionaryArray in comparison kernels [\#2514](https://github.com/apache/arrow-rs/pull/2514) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Avoid large over allocate buffer in sync reader [\#2511](https://github.com/apache/arrow-rs/pull/2511) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) -- Avoid useless memory copies in IPC reader. [\#2510](https://github.com/apache/arrow-rs/pull/2510) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) -- Refactor boolean kernels to use same codebase [\#2508](https://github.com/apache/arrow-rs/pull/2508) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Remove Int96Converter \(\#2480\) [\#2481](https://github.com/apache/arrow-rs/pull/2481) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Parquet f32/f64 handle signed zeros in statistics [\#5048](https://github.com/apache/arrow-rs/pull/5048) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jefffrey](https://github.com/Jefffrey)) +- Fix serialization of large integers in JSON \(\#5038\) [\#5042](https://github.com/apache/arrow-rs/pull/5042) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix RowSelection::intersection \(\#5036\) [\#5041](https://github.com/apache/arrow-rs/pull/5041) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Cast from integer/timestamp to timestamp/integer [\#5040](https://github.com/apache/arrow-rs/pull/5040) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- doc: update comment on sort\_to\_indices to reflect correct ordering [\#5033](https://github.com/apache/arrow-rs/pull/5033) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([westonpace](https://github.com/westonpace)) +- Support casting from integer to binary [\#5015](https://github.com/apache/arrow-rs/pull/5015) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Update tracing-log requirement from 0.1 to 0.2 [\#4998](https://github.com/apache/arrow-rs/pull/4998) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- feat\(flight-sql\): Allow custom commands in get-flight-info [\#4997](https://github.com/apache/arrow-rs/pull/4997) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([amartins23](https://github.com/amartins23)) +- \[MINOR\] No need to jump to web pages [\#4994](https://github.com/apache/arrow-rs/pull/4994) ([smallzhongfeng](https://github.com/smallzhongfeng)) +- Support metadata in SchemaBuilder [\#4987](https://github.com/apache/arrow-rs/pull/4987) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat: support schema change by idx and reverse [\#4985](https://github.com/apache/arrow-rs/pull/4985) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([fansehep](https://github.com/fansehep)) +- Bump actions/setup-node from 3 to 4 [\#4982](https://github.com/apache/arrow-rs/pull/4982) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add arrow\_cast::base64 and document usage in arrow\_json [\#4975](https://github.com/apache/arrow-rs/pull/4975) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add SchemaBuilder::remove \(\#4952\) [\#4964](https://github.com/apache/arrow-rs/pull/4964) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add `Field::remove()`, `Schema::remove()`, and `RecordBatch::remove_column()` APIs [\#4959](https://github.com/apache/arrow-rs/pull/4959) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Folyd](https://github.com/Folyd)) +- Add `RecordReader` trait and proc macro to implement it for a struct [\#4773](https://github.com/apache/arrow-rs/pull/4773) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Joseph-Rance](https://github.com/Joseph-Rance)) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 67121f6cd5a3..9614ed2e5688 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -150,7 +150,7 @@ If the file already exists, to avoid mistakenly **overriding**, you MAY have to the link source or file content. Else if not exist, let's safely soft link [pre-commit.sh](pre-commit.sh) as file `.git/hooks/pre-commit`: ```bash -ln -s ../../rust/pre-commit.sh .git/hooks/pre-commit +ln -s ../../pre-commit.sh .git/hooks/pre-commit ``` If sometimes you want to commit without checking, just run `git commit` with `--no-verify`: diff --git a/Cargo.toml b/Cargo.toml index 9bf55c0f2360..d5e834316b91 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,15 +16,32 @@ # under the License. [workspace] + members = [ - "arrow", - "parquet", - "parquet_derive", - "parquet_derive_test", - "arrow-flight", - "integration-testing", - "object_store", + "arrow", + "arrow-arith", + "arrow-array", + "arrow-avro", + "arrow-buffer", + "arrow-cast", + "arrow-csv", + "arrow-data", + "arrow-flight", + "arrow-flight/gen", + "arrow-integration-test", + "arrow-integration-testing", + "arrow-ipc", + "arrow-json", + "arrow-ord", + "arrow-row", + "arrow-schema", + "arrow-select", + "arrow-string", + "parquet", + "parquet_derive", + "parquet_derive_test", ] + # Enable the version 2 feature resolver, which avoids unifying features for targets that are not being built # # Critically this prevents dev-dependencies from enabling features even when not building a target that @@ -35,7 +52,45 @@ members = [ # resolver = "2" -# this package is excluded because it requires different compilation flags, thereby significantly changing -# how it is compiled within the workspace, causing the whole workspace to be compiled from scratch -# this way, this is a stand-alone package that compiles independently of the others. -exclude = ["arrow-pyarrow-integration-testing"] +exclude = [ + # arrow-pyarrow-integration-testing is excluded because it requires different compilation flags, thereby + # significantly changing how it is compiled within the workspace, causing the whole workspace to be compiled from + # scratch this way, this is a stand-alone package that compiles independently of the others. + "arrow-pyarrow-integration-testing", + # object_store is excluded because it follows a separate release cycle from the other arrow crates + "object_store" +] + +[workspace.package] +version = "49.0.0" +homepage = "https://github.com/apache/arrow-rs" +repository = "https://github.com/apache/arrow-rs" +authors = ["Apache Arrow "] +license = "Apache-2.0" +keywords = ["arrow"] +include = [ + "benches/*.rs", + "src/**/*.rs", + "Cargo.toml", +] +edition = "2021" +rust-version = "1.62" + +[workspace.dependencies] +arrow = { version = "49.0.0", path = "./arrow", default-features = false } +arrow-arith = { version = "49.0.0", path = "./arrow-arith" } +arrow-array = { version = "49.0.0", path = "./arrow-array" } +arrow-buffer = { version = "49.0.0", path = "./arrow-buffer" } +arrow-cast = { version = "49.0.0", path = "./arrow-cast" } +arrow-csv = { version = "49.0.0", path = "./arrow-csv" } +arrow-data = { version = "49.0.0", path = "./arrow-data" } +arrow-ipc = { version = "49.0.0", path = "./arrow-ipc" } +arrow-json = { version = "49.0.0", path = "./arrow-json" } +arrow-ord = { version = "49.0.0", path = "./arrow-ord" } +arrow-row = { version = "49.0.0", path = "./arrow-row" } +arrow-schema = { version = "49.0.0", path = "./arrow-schema" } +arrow-select = { version = "49.0.0", path = "./arrow-select" } +arrow-string = { version = "49.0.0", path = "./arrow-string" } +parquet = { version = "49.0.0", path = "./parquet", default-features = false } + +chrono = { version = "0.4.31", default-features = false, features = ["clock"] } diff --git a/README.md b/README.md index 55bdad6cb55c..8cd3ec970b53 100644 --- a/README.md +++ b/README.md @@ -25,12 +25,14 @@ Welcome to the implementation of Arrow, the popular in-memory columnar format, i This repo contains the following main components: -| Crate | Description | Documentation | -| ------------ | ------------------------------------------------------------------------- | ------------------------------ | -| arrow | Core functionality (memory layout, arrays, low level computations) | [(README)][arrow-readme] | -| parquet | Support for Parquet columnar file format | [(README)][parquet-readme] | -| arrow-flight | Support for Arrow-Flight IPC protocol | [(README)][flight-readme] | -| object-store | Support for object store interactions (aws, azure, gcp, local, in-memory) | [(README)][objectstore-readme] | +| Crate | Description | Latest API Docs | README | +| ------------ | ------------------------------------------------------------------------- | ---------------------------------------------- | ------------------------------ | +| arrow | Core functionality (memory layout, arrays, low level computations) | [docs.rs](https://docs.rs/arrow/latest) | [(README)][arrow-readme] | +| parquet | Support for Parquet columnar file format | [docs.rs](https://docs.rs/parquet/latest) | [(README)][parquet-readme] | +| arrow-flight | Support for Arrow-Flight IPC protocol | [docs.rs](https://docs.rs/arrow-flight/latest) | [(README)][flight-readme] | +| object-store | Support for object store interactions (aws, azure, gcp, local, in-memory) | [docs.rs](https://docs.rs/object_store/latest) | [(README)][objectstore-readme] | + +The current development version the API documentation in this repo can be found [here](https://arrow.apache.org/rust). There are two related crates in a different repository @@ -58,8 +60,8 @@ a great place to meet other contributors and get guidance on where to contribute 2. the [GitHub Discussions][discussions] 3. the [Discord channel](https://discord.gg/YAb2TdazKQ) -Unlike other parts of the Arrow ecosystem, the Rust implementation uses [GitHub issues][issues] as the system of record for new features -and bug fixes and this plays a critical role in the release process. +The Rust implementation uses [GitHub issues][issues] as the system of record for new features and bug fixes and +this plays a critical role in the release process. For design discussions we generally collaborate on Google documents and file a GitHub issue linking to the document. @@ -72,6 +74,6 @@ There is more information in the [contributing] guide. [flight-readme]: arrow-flight/README.md [datafusion-readme]: https://github.com/apache/arrow-datafusion/blob/master/README.md [ballista-readme]: https://github.com/apache/arrow-ballista/blob/master/README.md -[objectstore-readme]: https://github.com/apache/arrow-rs/blob/master/object_store/README.md +[objectstore-readme]: object_store/README.md [issues]: https://github.com/apache/arrow-rs/issues [discussions]: https://github.com/apache/arrow-rs/discussions diff --git a/arrow-arith/Cargo.toml b/arrow-arith/Cargo.toml new file mode 100644 index 000000000000..d2ee0b9e2c72 --- /dev/null +++ b/arrow-arith/Cargo.toml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "arrow-arith" +version = { workspace = true } +description = "Arrow arithmetic kernels" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "arrow_arith" +path = "src/lib.rs" +bench = false + +[dependencies] +arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +arrow-data = { workspace = true } +arrow-schema = { workspace = true } +chrono = { workspace = true } +half = { version = "2.1", default-features = false } +num = { version = "0.4", default-features = false, features = ["std"] } + +[dev-dependencies] diff --git a/arrow-arith/src/aggregate.rs b/arrow-arith/src/aggregate.rs new file mode 100644 index 000000000000..20ff0711d735 --- /dev/null +++ b/arrow-arith/src/aggregate.rs @@ -0,0 +1,1420 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines aggregations over Arrow arrays. + +use arrow_array::cast::*; +use arrow_array::iterator::ArrayIter; +use arrow_array::*; +use arrow_buffer::{ArrowNativeType, NullBuffer}; +use arrow_data::bit_iterator::try_for_each_valid_idx; +use arrow_schema::ArrowError; +use arrow_schema::*; +use std::borrow::BorrowMut; +use std::ops::{BitAnd, BitOr, BitXor}; + +/// An accumulator for primitive numeric values. +trait NumericAccumulator: Copy + Default { + /// Accumulate a non-null value. + fn accumulate(&mut self, value: T); + /// Accumulate a nullable values. + /// If `valid` is false the `value` should not affect the accumulator state. + fn accumulate_nullable(&mut self, value: T, valid: bool); + /// Merge another accumulator into this accumulator + fn merge(&mut self, other: Self); + /// Return the aggregated value. + fn finish(&mut self) -> T; +} + +/// Helper for branchlessly selecting either `a` or `b` based on the boolean `m`. +/// After verifying the generated assembly this can be a simple `if`. +#[inline(always)] +fn select(m: bool, a: T, b: T) -> T { + if m { + a + } else { + b + } +} + +#[derive(Clone, Copy)] +struct SumAccumulator { + sum: T, +} + +impl Default for SumAccumulator { + fn default() -> Self { + Self { sum: T::ZERO } + } +} + +impl NumericAccumulator for SumAccumulator { + fn accumulate(&mut self, value: T) { + self.sum = self.sum.add_wrapping(value); + } + + fn accumulate_nullable(&mut self, value: T, valid: bool) { + let sum = self.sum; + self.sum = select(valid, sum.add_wrapping(value), sum) + } + + fn merge(&mut self, other: Self) { + self.sum = self.sum.add_wrapping(other.sum); + } + + fn finish(&mut self) -> T { + self.sum + } +} + +#[derive(Clone, Copy)] +struct MinAccumulator { + min: T, +} + +impl Default for MinAccumulator { + fn default() -> Self { + Self { + min: T::MAX_TOTAL_ORDER, + } + } +} + +impl NumericAccumulator for MinAccumulator { + fn accumulate(&mut self, value: T) { + let min = self.min; + self.min = select(value.is_lt(min), value, min); + } + + fn accumulate_nullable(&mut self, value: T, valid: bool) { + let min = self.min; + let is_lt = valid & value.is_lt(min); + self.min = select(is_lt, value, min); + } + + fn merge(&mut self, other: Self) { + self.accumulate(other.min) + } + + fn finish(&mut self) -> T { + self.min + } +} + +#[derive(Clone, Copy)] +struct MaxAccumulator { + max: T, +} + +impl Default for MaxAccumulator { + fn default() -> Self { + Self { + max: T::MIN_TOTAL_ORDER, + } + } +} + +impl NumericAccumulator for MaxAccumulator { + fn accumulate(&mut self, value: T) { + let max = self.max; + self.max = select(value.is_gt(max), value, max); + } + + fn accumulate_nullable(&mut self, value: T, valid: bool) { + let max = self.max; + let is_gt = value.is_gt(max) & valid; + self.max = select(is_gt, value, max); + } + + fn merge(&mut self, other: Self) { + self.accumulate(other.max) + } + + fn finish(&mut self) -> T { + self.max + } +} + +fn reduce_accumulators, const LANES: usize>( + mut acc: [A; LANES], +) -> A { + assert!(LANES > 0 && LANES.is_power_of_two()); + let mut len = LANES; + + // attempt at tree reduction, unfortunately llvm does not fully recognize this pattern, + // but the generated code is still a little faster than purely sequential reduction for floats. + while len >= 2 { + let mid = len / 2; + let (h, t) = acc[..len].split_at_mut(mid); + + for i in 0..mid { + h[i].merge(t[i]); + } + len /= 2; + } + acc[0] +} + +#[inline(always)] +fn aggregate_nonnull_chunk, const LANES: usize>( + acc: &mut [A; LANES], + values: &[T; LANES], +) { + for i in 0..LANES { + acc[i].accumulate(values[i]); + } +} + +#[inline(always)] +fn aggregate_nullable_chunk, const LANES: usize>( + acc: &mut [A; LANES], + values: &[T; LANES], + validity: u64, +) { + let mut bit = 1; + for i in 0..LANES { + acc[i].accumulate_nullable(values[i], (validity & bit) != 0); + bit <<= 1; + } +} + +fn aggregate_nonnull_simple>(values: &[T]) -> T { + return values + .iter() + .copied() + .fold(A::default(), |mut a, b| { + a.accumulate(b); + a + }) + .finish(); +} + +#[inline(never)] +fn aggregate_nonnull_lanes, const LANES: usize>( + values: &[T], +) -> T { + // aggregating into multiple independent accumulators allows the compiler to use vector registers + // with a single accumulator the compiler would not be allowed to reorder floating point addition + let mut acc = [A::default(); LANES]; + let mut chunks = values.chunks_exact(LANES); + chunks.borrow_mut().for_each(|chunk| { + aggregate_nonnull_chunk(&mut acc, chunk[..LANES].try_into().unwrap()); + }); + + let remainder = chunks.remainder(); + for i in 0..remainder.len() { + acc[i].accumulate(remainder[i]); + } + + reduce_accumulators(acc).finish() +} + +#[inline(never)] +fn aggregate_nullable_lanes, const LANES: usize>( + values: &[T], + validity: &NullBuffer, +) -> T { + assert!(LANES > 0 && 64 % LANES == 0); + assert_eq!(values.len(), validity.len()); + + // aggregating into multiple independent accumulators allows the compiler to use vector registers + let mut acc = [A::default(); LANES]; + // we process 64 bits of validity at a time + let mut values_chunks = values.chunks_exact(64); + let validity_chunks = validity.inner().bit_chunks(); + let mut validity_chunks_iter = validity_chunks.iter(); + + values_chunks.borrow_mut().for_each(|chunk| { + // Safety: we asserted that values and validity have the same length and trust the iterator impl + let mut validity = unsafe { validity_chunks_iter.next().unwrap_unchecked() }; + // chunk further based on the number of vector lanes + chunk.chunks_exact(LANES).for_each(|chunk| { + aggregate_nullable_chunk(&mut acc, chunk[..LANES].try_into().unwrap(), validity); + validity >>= LANES; + }); + }); + + let remainder = values_chunks.remainder(); + if !remainder.is_empty() { + let mut validity = validity_chunks.remainder_bits(); + + let mut remainder_chunks = remainder.chunks_exact(LANES); + remainder_chunks.borrow_mut().for_each(|chunk| { + aggregate_nullable_chunk(&mut acc, chunk[..LANES].try_into().unwrap(), validity); + validity >>= LANES; + }); + + let remainder = remainder_chunks.remainder(); + if !remainder.is_empty() { + let mut bit = 1; + for i in 0..remainder.len() { + acc[i].accumulate_nullable(remainder[i], (validity & bit) != 0); + bit <<= 1; + } + } + } + + reduce_accumulators(acc).finish() +} + +/// The preferred vector size in bytes for the target platform. +/// Note that the avx512 target feature is still unstable and this also means it is not detected on stable rust. +const PREFERRED_VECTOR_SIZE: usize = + if cfg!(all(target_arch = "x86_64", target_feature = "avx512f")) { + 64 + } else if cfg!(all(target_arch = "x86_64", target_feature = "avx")) { + 32 + } else { + 16 + }; + +/// non-nullable aggregation requires fewer temporary registers so we can use more of them for accumulators +const PREFERRED_VECTOR_SIZE_NON_NULL: usize = PREFERRED_VECTOR_SIZE * 2; + +/// Generic aggregation for any primitive type. +/// Returns None if there are no non-null values in `array`. +fn aggregate, A: NumericAccumulator>( + array: &PrimitiveArray

, +) -> Option { + let null_count = array.null_count(); + if null_count == array.len() { + return None; + } + let values = array.values().as_ref(); + match array.nulls() { + Some(nulls) if null_count > 0 => { + // const generics depending on a generic type parameter are not supported + // so we have to match and call aggregate with the corresponding constant + match PREFERRED_VECTOR_SIZE / std::mem::size_of::() { + 64 => Some(aggregate_nullable_lanes::(values, nulls)), + 32 => Some(aggregate_nullable_lanes::(values, nulls)), + 16 => Some(aggregate_nullable_lanes::(values, nulls)), + 8 => Some(aggregate_nullable_lanes::(values, nulls)), + 4 => Some(aggregate_nullable_lanes::(values, nulls)), + 2 => Some(aggregate_nullable_lanes::(values, nulls)), + _ => Some(aggregate_nullable_lanes::(values, nulls)), + } + } + _ => { + let is_float = matches!( + array.data_type(), + DataType::Float16 | DataType::Float32 | DataType::Float64 + ); + if is_float { + match PREFERRED_VECTOR_SIZE_NON_NULL / std::mem::size_of::() { + 64 => Some(aggregate_nonnull_lanes::(values)), + 32 => Some(aggregate_nonnull_lanes::(values)), + 16 => Some(aggregate_nonnull_lanes::(values)), + 8 => Some(aggregate_nonnull_lanes::(values)), + 4 => Some(aggregate_nonnull_lanes::(values)), + 2 => Some(aggregate_nonnull_lanes::(values)), + _ => Some(aggregate_nonnull_simple::(values)), + } + } else { + // for non-null integers its better to not chunk ourselves and instead + // let llvm fully handle loop unrolling and vectorization + Some(aggregate_nonnull_simple::(values)) + } + } + } +} + +/// Returns the minimum value in the boolean array. +/// +/// ``` +/// # use arrow_array::BooleanArray; +/// # use arrow_arith::aggregate::min_boolean; +/// +/// let a = BooleanArray::from(vec![Some(true), None, Some(false)]); +/// assert_eq!(min_boolean(&a), Some(false)) +/// ``` +pub fn min_boolean(array: &BooleanArray) -> Option { + // short circuit if all nulls / zero length array + if array.null_count() == array.len() { + return None; + } + + // Note the min bool is false (0), so short circuit as soon as we see it + array + .iter() + .find(|&b| b == Some(false)) + .flatten() + .or(Some(true)) +} + +/// Returns the maximum value in the boolean array +/// +/// ``` +/// # use arrow_array::BooleanArray; +/// # use arrow_arith::aggregate::max_boolean; +/// +/// let a = BooleanArray::from(vec![Some(true), None, Some(false)]); +/// assert_eq!(max_boolean(&a), Some(true)) +/// ``` +pub fn max_boolean(array: &BooleanArray) -> Option { + // short circuit if all nulls / zero length array + if array.null_count() == array.len() { + return None; + } + + // Note the max bool is true (1), so short circuit as soon as we see it + array + .iter() + .find(|&b| b == Some(true)) + .flatten() + .or(Some(false)) +} + +/// Helper to compute min/max of [`ArrayAccessor`]. +fn min_max_helper, F>(array: A, cmp: F) -> Option +where + F: Fn(&T, &T) -> bool, +{ + let null_count = array.null_count(); + if null_count == array.len() { + None + } else if null_count == 0 { + // JUSTIFICATION + // Benefit: ~8% speedup + // Soundness: `i` is always within the array bounds + (0..array.len()) + .map(|i| unsafe { array.value_unchecked(i) }) + .reduce(|acc, item| if cmp(&acc, &item) { item } else { acc }) + } else { + let nulls = array.nulls().unwrap(); + unsafe { + let idx = nulls.valid_indices().reduce(|acc_idx, idx| { + let acc = array.value_unchecked(acc_idx); + let item = array.value_unchecked(idx); + if cmp(&acc, &item) { + idx + } else { + acc_idx + } + }); + idx.map(|idx| array.value_unchecked(idx)) + } + } +} + +/// Returns the maximum value in the binary array, according to the natural order. +pub fn max_binary(array: &GenericBinaryArray) -> Option<&[u8]> { + min_max_helper::<&[u8], _, _>(array, |a, b| *a < *b) +} + +/// Returns the minimum value in the binary array, according to the natural order. +pub fn min_binary(array: &GenericBinaryArray) -> Option<&[u8]> { + min_max_helper::<&[u8], _, _>(array, |a, b| *a > *b) +} + +/// Returns the maximum value in the string array, according to the natural order. +pub fn max_string(array: &GenericStringArray) -> Option<&str> { + min_max_helper::<&str, _, _>(array, |a, b| *a < *b) +} + +/// Returns the minimum value in the string array, according to the natural order. +pub fn min_string(array: &GenericStringArray) -> Option<&str> { + min_max_helper::<&str, _, _>(array, |a, b| *a > *b) +} + +/// Returns the sum of values in the array. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `sum_array_checked` instead. +pub fn sum_array>(array: A) -> Option +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + match array.data_type() { + DataType::Dictionary(_, _) => { + let null_count = array.null_count(); + + if null_count == array.len() { + return None; + } + + let iter = ArrayIter::new(array); + let sum = iter + .into_iter() + .fold(T::default_value(), |accumulator, value| { + if let Some(value) = value { + accumulator.add_wrapping(value) + } else { + accumulator + } + }); + + Some(sum) + } + _ => sum::(as_primitive_array(&array)), + } +} + +/// Returns the sum of values in the array. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `sum_array` instead. +pub fn sum_array_checked>( + array: A, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + match array.data_type() { + DataType::Dictionary(_, _) => { + let null_count = array.null_count(); + + if null_count == array.len() { + return Ok(None); + } + + let iter = ArrayIter::new(array); + let sum = iter + .into_iter() + .try_fold(T::default_value(), |accumulator, value| { + if let Some(value) = value { + accumulator.add_checked(value) + } else { + Ok(accumulator) + } + })?; + + Ok(Some(sum)) + } + _ => sum_checked::(as_primitive_array(&array)), + } +} + +/// Returns the min of values in the array of `ArrowNumericType` type, or dictionary +/// array with value of `ArrowNumericType` type. +pub fn min_array>(array: A) -> Option +where + T: ArrowNumericType, + T::Native: ArrowNativeType, +{ + min_max_array_helper::(array, |a, b| a.is_gt(*b), min) +} + +/// Returns the max of values in the array of `ArrowNumericType` type, or dictionary +/// array with value of `ArrowNumericType` type. +pub fn max_array>(array: A) -> Option +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + min_max_array_helper::(array, |a, b| a.is_lt(*b), max) +} + +fn min_max_array_helper, F, M>( + array: A, + cmp: F, + m: M, +) -> Option +where + T: ArrowNumericType, + F: Fn(&T::Native, &T::Native) -> bool, + M: Fn(&PrimitiveArray) -> Option, +{ + match array.data_type() { + DataType::Dictionary(_, _) => min_max_helper::(array, cmp), + _ => m(as_primitive_array(&array)), + } +} + +macro_rules! bit_operation { + ($NAME:ident, $OP:ident, $NATIVE:ident, $DEFAULT:expr, $DOC:expr) => { + #[doc = $DOC] + /// + /// Returns `None` if the array is empty or only contains null values. + pub fn $NAME(array: &PrimitiveArray) -> Option + where + T: ArrowNumericType, + T::Native: $NATIVE + ArrowNativeTypeOp, + { + let default; + if $DEFAULT == -1 { + default = T::Native::ONE.neg_wrapping(); + } else { + default = T::default_value(); + } + + let null_count = array.null_count(); + + if null_count == array.len() { + return None; + } + + let data: &[T::Native] = array.values(); + + match array.nulls() { + None => { + let result = data + .iter() + .fold(default, |accumulator, value| accumulator.$OP(*value)); + + Some(result) + } + Some(nulls) => { + let mut result = default; + let data_chunks = data.chunks_exact(64); + let remainder = data_chunks.remainder(); + + let bit_chunks = nulls.inner().bit_chunks(); + data_chunks + .zip(bit_chunks.iter()) + .for_each(|(chunk, mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + chunk.iter().for_each(|value| { + if (mask & index_mask) != 0 { + result = result.$OP(*value); + } + index_mask <<= 1; + }); + }); + + let remainder_bits = bit_chunks.remainder_bits(); + + remainder.iter().enumerate().for_each(|(i, value)| { + if remainder_bits & (1 << i) != 0 { + result = result.$OP(*value); + } + }); + + Some(result) + } + } + } + }; +} + +bit_operation!( + bit_and, + bitand, + BitAnd, + -1, + "Returns the bitwise and of all non-null input values." +); +bit_operation!( + bit_or, + bitor, + BitOr, + 0, + "Returns the bitwise or of all non-null input values." +); +bit_operation!( + bit_xor, + bitxor, + BitXor, + 0, + "Returns the bitwise xor of all non-null input values." +); + +/// Returns true if all non-null input values are true, otherwise false. +/// +/// Returns `None` if the array is empty or only contains null values. +pub fn bool_and(array: &BooleanArray) -> Option { + if array.null_count() == array.len() { + return None; + } + Some(array.false_count() == 0) +} + +/// Returns true if any non-null input value is true, otherwise false. +/// +/// Returns `None` if the array is empty or only contains null values. +pub fn bool_or(array: &BooleanArray) -> Option { + if array.null_count() == array.len() { + return None; + } + Some(array.true_count() != 0) +} + +/// Returns the sum of values in the primitive array. +/// +/// Returns `Ok(None)` if the array is empty or only contains null values. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use `sum` instead. +pub fn sum_checked(array: &PrimitiveArray) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + let null_count = array.null_count(); + + if null_count == array.len() { + return Ok(None); + } + + let data: &[T::Native] = array.values(); + + match array.nulls() { + None => { + let sum = data + .iter() + .try_fold(T::default_value(), |accumulator, value| { + accumulator.add_checked(*value) + })?; + + Ok(Some(sum)) + } + Some(nulls) => { + let mut sum = T::default_value(); + + try_for_each_valid_idx( + nulls.len(), + nulls.offset(), + nulls.null_count(), + Some(nulls.validity()), + |idx| { + unsafe { sum = sum.add_checked(array.value_unchecked(idx))? }; + Ok::<_, ArrowError>(()) + }, + )?; + + Ok(Some(sum)) + } + } +} + +/// Returns the sum of values in the primitive array. +/// +/// Returns `None` if the array is empty or only contains null values. +/// +/// This doesn't detect overflow in release mode by default. Once overflowing, the result will +/// wrap around. For an overflow-checking variant, use `sum_checked` instead. +pub fn sum(array: &PrimitiveArray) -> Option +where + T::Native: ArrowNativeTypeOp, +{ + aggregate::>(array) +} + +/// Returns the minimum value in the array, according to the natural order. +/// For floating point arrays any NaN values are considered to be greater than any other non-null value +pub fn min(array: &PrimitiveArray) -> Option +where + T::Native: PartialOrd, +{ + aggregate::>(array) +} + +/// Returns the maximum value in the array, according to the natural order. +/// For floating point arrays any NaN values are considered to be greater than any other non-null value +pub fn max(array: &PrimitiveArray) -> Option +where + T::Native: PartialOrd, +{ + aggregate::>(array) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::types::*; + use arrow_buffer::NullBuffer; + use std::sync::Arc; + + #[test] + fn test_primitive_array_sum() { + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + assert_eq!(15, sum(&a).unwrap()); + } + + #[test] + fn test_primitive_array_float_sum() { + let a = Float64Array::from(vec![1.1, 2.2, 3.3, 4.4, 5.5]); + assert_eq!(16.5, sum(&a).unwrap()); + } + + #[test] + fn test_primitive_array_sum_with_nulls() { + let a = Int32Array::from(vec![None, Some(2), Some(3), None, Some(5)]); + assert_eq!(10, sum(&a).unwrap()); + } + + #[test] + fn test_primitive_array_sum_all_nulls() { + let a = Int32Array::from(vec![None, None, None]); + assert_eq!(None, sum(&a)); + } + + #[test] + fn test_primitive_array_sum_large_float_64() { + let c = Float64Array::new((1..=100).map(|x| x as f64).collect(), None); + assert_eq!(Some((1..=100).sum::() as f64), sum(&c)); + + // create an array that actually has non-zero values at the invalid indices + let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect()); + let c = Float64Array::new((1..=100).map(|x| x as f64).collect(), Some(validity)); + + assert_eq!( + Some((1..=100).filter(|i| i % 3 == 0).sum::() as f64), + sum(&c) + ); + } + + #[test] + fn test_primitive_array_sum_large_float_32() { + let c = Float32Array::new((1..=100).map(|x| x as f32).collect(), None); + assert_eq!(Some((1..=100).sum::() as f32), sum(&c)); + + // create an array that actually has non-zero values at the invalid indices + let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect()); + let c = Float32Array::new((1..=100).map(|x| x as f32).collect(), Some(validity)); + + assert_eq!( + Some((1..=100).filter(|i| i % 3 == 0).sum::() as f32), + sum(&c) + ); + } + + #[test] + fn test_primitive_array_sum_large_64() { + let c = Int64Array::new((1..=100).collect(), None); + assert_eq!(Some((1..=100).sum()), sum(&c)); + + // create an array that actually has non-zero values at the invalid indices + let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect()); + let c = Int64Array::new((1..=100).collect(), Some(validity)); + + assert_eq!(Some((1..=100).filter(|i| i % 3 == 0).sum()), sum(&c)); + } + + #[test] + fn test_primitive_array_sum_large_32() { + let c = Int32Array::new((1..=100).collect(), None); + assert_eq!(Some((1..=100).sum()), sum(&c)); + + // create an array that actually has non-zero values at the invalid indices + let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect()); + let c = Int32Array::new((1..=100).collect(), Some(validity)); + assert_eq!(Some((1..=100).filter(|i| i % 3 == 0).sum()), sum(&c)); + } + + #[test] + fn test_primitive_array_sum_large_16() { + let c = Int16Array::new((1..=100).collect(), None); + assert_eq!(Some((1..=100).sum()), sum(&c)); + + // create an array that actually has non-zero values at the invalid indices + let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect()); + let c = Int16Array::new((1..=100).collect(), Some(validity)); + assert_eq!(Some((1..=100).filter(|i| i % 3 == 0).sum()), sum(&c)); + } + + #[test] + fn test_primitive_array_sum_large_8() { + let c = UInt8Array::new((1..=100).collect(), None); + assert_eq!( + Some((1..=100).fold(0_u8, |a, x| a.wrapping_add(x))), + sum(&c) + ); + + // create an array that actually has non-zero values at the invalid indices + let validity = NullBuffer::new((1..=100).map(|x| x % 3 == 0).collect()); + let c = UInt8Array::new((1..=100).collect(), Some(validity)); + assert_eq!( + Some( + (1..=100) + .filter(|i| i % 3 == 0) + .fold(0_u8, |a, x| a.wrapping_add(x)) + ), + sum(&c) + ); + } + + #[test] + fn test_primitive_array_bit_and() { + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + assert_eq!(0, bit_and(&a).unwrap()); + } + + #[test] + fn test_primitive_array_bit_and_with_nulls() { + let a = Int32Array::from(vec![None, Some(2), Some(3), None, None]); + assert_eq!(2, bit_and(&a).unwrap()); + } + + #[test] + fn test_primitive_array_bit_and_all_nulls() { + let a = Int32Array::from(vec![None, None, None]); + assert_eq!(None, bit_and(&a)); + } + + #[test] + fn test_primitive_array_bit_or() { + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + assert_eq!(7, bit_or(&a).unwrap()); + } + + #[test] + fn test_primitive_array_bit_or_with_nulls() { + let a = Int32Array::from(vec![None, Some(2), Some(3), None, Some(5)]); + assert_eq!(7, bit_or(&a).unwrap()); + } + + #[test] + fn test_primitive_array_bit_or_all_nulls() { + let a = Int32Array::from(vec![None, None, None]); + assert_eq!(None, bit_or(&a)); + } + + #[test] + fn test_primitive_array_bit_xor() { + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + assert_eq!(1, bit_xor(&a).unwrap()); + } + + #[test] + fn test_primitive_array_bit_xor_with_nulls() { + let a = Int32Array::from(vec![None, Some(2), Some(3), None, Some(5)]); + assert_eq!(4, bit_xor(&a).unwrap()); + } + + #[test] + fn test_primitive_array_bit_xor_all_nulls() { + let a = Int32Array::from(vec![None, None, None]); + assert_eq!(None, bit_xor(&a)); + } + + #[test] + fn test_primitive_array_bool_and() { + let a = BooleanArray::from(vec![true, false, true, false, true]); + assert!(!bool_and(&a).unwrap()); + } + + #[test] + fn test_primitive_array_bool_and_with_nulls() { + let a = BooleanArray::from(vec![None, Some(true), Some(true), None, Some(true)]); + assert!(bool_and(&a).unwrap()); + } + + #[test] + fn test_primitive_array_bool_and_all_nulls() { + let a = BooleanArray::from(vec![None, None, None]); + assert_eq!(None, bool_and(&a)); + } + + #[test] + fn test_primitive_array_bool_or() { + let a = BooleanArray::from(vec![true, false, true, false, true]); + assert!(bool_or(&a).unwrap()); + } + + #[test] + fn test_primitive_array_bool_or_with_nulls() { + let a = BooleanArray::from(vec![None, Some(false), Some(false), None, Some(false)]); + assert!(!bool_or(&a).unwrap()); + } + + #[test] + fn test_primitive_array_bool_or_all_nulls() { + let a = BooleanArray::from(vec![None, None, None]); + assert_eq!(None, bool_or(&a)); + } + + #[test] + fn test_primitive_array_min_max() { + let a = Int32Array::from(vec![5, 6, 7, 8, 9]); + assert_eq!(5, min(&a).unwrap()); + assert_eq!(9, max(&a).unwrap()); + } + + #[test] + fn test_primitive_array_min_max_with_nulls() { + let a = Int32Array::from(vec![Some(5), None, None, Some(8), Some(9)]); + assert_eq!(5, min(&a).unwrap()); + assert_eq!(9, max(&a).unwrap()); + } + + #[test] + fn test_primitive_min_max_1() { + let a = Int32Array::from(vec![None, None, Some(5), Some(2)]); + assert_eq!(Some(2), min(&a)); + assert_eq!(Some(5), max(&a)); + } + + #[test] + fn test_primitive_min_max_float_large_nonnull_array() { + let a: Float64Array = (0..256).map(|i| Some((i + 1) as f64)).collect(); + // min/max are on boundaries of chunked data + assert_eq!(Some(1.0), min(&a)); + assert_eq!(Some(256.0), max(&a)); + + // max is last value in remainder after chunking + let a: Float64Array = (0..255).map(|i| Some((i + 1) as f64)).collect(); + assert_eq!(Some(255.0), max(&a)); + + // max is first value in remainder after chunking + let a: Float64Array = (0..257).map(|i| Some((i + 1) as f64)).collect(); + assert_eq!(Some(257.0), max(&a)); + } + + #[test] + fn test_primitive_min_max_float_large_nullable_array() { + let a: Float64Array = (0..256) + .map(|i| { + if (i + 1) % 3 == 0 { + None + } else { + Some((i + 1) as f64) + } + }) + .collect(); + // min/max are on boundaries of chunked data + assert_eq!(Some(1.0), min(&a)); + assert_eq!(Some(256.0), max(&a)); + + let a: Float64Array = (0..256) + .map(|i| { + if i == 0 || i == 255 { + None + } else { + Some((i + 1) as f64) + } + }) + .collect(); + // boundaries of chunked data are null + assert_eq!(Some(2.0), min(&a)); + assert_eq!(Some(255.0), max(&a)); + + let a: Float64Array = (0..256) + .map(|i| if i != 100 { None } else { Some((i) as f64) }) + .collect(); + // a single non-null value somewhere in the middle + assert_eq!(Some(100.0), min(&a)); + assert_eq!(Some(100.0), max(&a)); + + // max is last value in remainder after chunking + let a: Float64Array = (0..255).map(|i| Some((i + 1) as f64)).collect(); + assert_eq!(Some(255.0), max(&a)); + + // max is first value in remainder after chunking + let a: Float64Array = (0..257).map(|i| Some((i + 1) as f64)).collect(); + assert_eq!(Some(257.0), max(&a)); + } + + #[test] + fn test_primitive_min_max_float_edge_cases() { + let a: Float64Array = (0..100).map(|_| Some(f64::NEG_INFINITY)).collect(); + assert_eq!(Some(f64::NEG_INFINITY), min(&a)); + assert_eq!(Some(f64::NEG_INFINITY), max(&a)); + + let a: Float64Array = (0..100).map(|_| Some(f64::MIN)).collect(); + assert_eq!(Some(f64::MIN), min(&a)); + assert_eq!(Some(f64::MIN), max(&a)); + + let a: Float64Array = (0..100).map(|_| Some(f64::MAX)).collect(); + assert_eq!(Some(f64::MAX), min(&a)); + assert_eq!(Some(f64::MAX), max(&a)); + + let a: Float64Array = (0..100).map(|_| Some(f64::INFINITY)).collect(); + assert_eq!(Some(f64::INFINITY), min(&a)); + assert_eq!(Some(f64::INFINITY), max(&a)); + } + + #[test] + fn test_primitive_min_max_float_all_nans_non_null() { + let a: Float64Array = (0..100).map(|_| Some(f64::NAN)).collect(); + assert!(max(&a).unwrap().is_nan()); + assert!(min(&a).unwrap().is_nan()); + } + + #[test] + fn test_primitive_min_max_float_negative_nan() { + let a: Float64Array = + Float64Array::from(vec![f64::NEG_INFINITY, f64::NAN, f64::INFINITY, -f64::NAN]); + let max = max(&a).unwrap(); + let min = min(&a).unwrap(); + assert!(max.is_nan()); + assert!(max.is_sign_positive()); + + assert!(min.is_nan()); + assert!(min.is_sign_negative()); + } + + #[test] + fn test_primitive_min_max_float_first_nan_nonnull() { + let a: Float64Array = (0..100) + .map(|i| { + if i == 0 { + Some(f64::NAN) + } else { + Some(i as f64) + } + }) + .collect(); + assert_eq!(Some(1.0), min(&a)); + assert!(max(&a).unwrap().is_nan()); + } + + #[test] + fn test_primitive_min_max_float_last_nan_nonnull() { + let a: Float64Array = (0..100) + .map(|i| { + if i == 99 { + Some(f64::NAN) + } else { + Some((i + 1) as f64) + } + }) + .collect(); + assert_eq!(Some(1.0), min(&a)); + assert!(max(&a).unwrap().is_nan()); + } + + #[test] + fn test_primitive_min_max_float_first_nan_nullable() { + let a: Float64Array = (0..100) + .map(|i| { + if i == 0 { + Some(f64::NAN) + } else if i % 2 == 0 { + None + } else { + Some(i as f64) + } + }) + .collect(); + assert_eq!(Some(1.0), min(&a)); + assert!(max(&a).unwrap().is_nan()); + } + + #[test] + fn test_primitive_min_max_float_last_nan_nullable() { + let a: Float64Array = (0..100) + .map(|i| { + if i == 99 { + Some(f64::NAN) + } else if i % 2 == 0 { + None + } else { + Some(i as f64) + } + }) + .collect(); + assert_eq!(Some(1.0), min(&a)); + assert!(max(&a).unwrap().is_nan()); + } + + #[test] + fn test_primitive_min_max_float_inf_and_nans() { + let a: Float64Array = (0..100) + .map(|i| { + let x = match i % 10 { + 0 => f64::NEG_INFINITY, + 1 => f64::MIN, + 2 => f64::MAX, + 4 => f64::INFINITY, + 5 => f64::NAN, + _ => i as f64, + }; + Some(x) + }) + .collect(); + assert_eq!(Some(f64::NEG_INFINITY), min(&a)); + assert!(max(&a).unwrap().is_nan()); + } + + #[test] + fn test_binary_min_max_with_nulls() { + let a = BinaryArray::from(vec![ + Some("b".as_bytes()), + None, + None, + Some(b"a"), + Some(b"c"), + ]); + assert_eq!(Some("a".as_bytes()), min_binary(&a)); + assert_eq!(Some("c".as_bytes()), max_binary(&a)); + } + + #[test] + fn test_binary_min_max_no_null() { + let a = BinaryArray::from(vec![Some("b".as_bytes()), Some(b"a"), Some(b"c")]); + assert_eq!(Some("a".as_bytes()), min_binary(&a)); + assert_eq!(Some("c".as_bytes()), max_binary(&a)); + } + + #[test] + fn test_binary_min_max_all_nulls() { + let a = BinaryArray::from(vec![None, None]); + assert_eq!(None, min_binary(&a)); + assert_eq!(None, max_binary(&a)); + } + + #[test] + fn test_binary_min_max_1() { + let a = BinaryArray::from(vec![None, None, Some("b".as_bytes()), Some(b"a")]); + assert_eq!(Some("a".as_bytes()), min_binary(&a)); + assert_eq!(Some("b".as_bytes()), max_binary(&a)); + } + + #[test] + fn test_string_min_max_with_nulls() { + let a = StringArray::from(vec![Some("b"), None, None, Some("a"), Some("c")]); + assert_eq!(Some("a"), min_string(&a)); + assert_eq!(Some("c"), max_string(&a)); + } + + #[test] + fn test_string_min_max_all_nulls() { + let v: Vec> = vec![None, None]; + let a = StringArray::from(v); + assert_eq!(None, min_string(&a)); + assert_eq!(None, max_string(&a)); + } + + #[test] + fn test_string_min_max_1() { + let a = StringArray::from(vec![None, None, Some("b"), Some("a")]); + assert_eq!(Some("a"), min_string(&a)); + assert_eq!(Some("b"), max_string(&a)); + } + + #[test] + fn test_boolean_min_max_empty() { + let a = BooleanArray::from(vec![] as Vec>); + assert_eq!(None, min_boolean(&a)); + assert_eq!(None, max_boolean(&a)); + } + + #[test] + fn test_boolean_min_max_all_null() { + let a = BooleanArray::from(vec![None, None]); + assert_eq!(None, min_boolean(&a)); + assert_eq!(None, max_boolean(&a)); + } + + #[test] + fn test_boolean_min_max_no_null() { + let a = BooleanArray::from(vec![Some(true), Some(false), Some(true)]); + assert_eq!(Some(false), min_boolean(&a)); + assert_eq!(Some(true), max_boolean(&a)); + } + + #[test] + fn test_boolean_min_max() { + let a = BooleanArray::from(vec![Some(true), Some(true), None, Some(false), None]); + assert_eq!(Some(false), min_boolean(&a)); + assert_eq!(Some(true), max_boolean(&a)); + + let a = BooleanArray::from(vec![None, Some(true), None, Some(false), None]); + assert_eq!(Some(false), min_boolean(&a)); + assert_eq!(Some(true), max_boolean(&a)); + + let a = BooleanArray::from(vec![Some(false), Some(true), None, Some(false), None]); + assert_eq!(Some(false), min_boolean(&a)); + assert_eq!(Some(true), max_boolean(&a)); + } + + #[test] + fn test_boolean_min_max_smaller() { + let a = BooleanArray::from(vec![Some(false)]); + assert_eq!(Some(false), min_boolean(&a)); + assert_eq!(Some(false), max_boolean(&a)); + + let a = BooleanArray::from(vec![None, Some(false)]); + assert_eq!(Some(false), min_boolean(&a)); + assert_eq!(Some(false), max_boolean(&a)); + + let a = BooleanArray::from(vec![None, Some(true)]); + assert_eq!(Some(true), min_boolean(&a)); + assert_eq!(Some(true), max_boolean(&a)); + + let a = BooleanArray::from(vec![Some(true)]); + assert_eq!(Some(true), min_boolean(&a)); + assert_eq!(Some(true), max_boolean(&a)); + } + + #[test] + fn test_sum_dyn() { + let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + let values = Arc::new(values) as ArrayRef; + let keys = Int8Array::from_iter_values([2_i8, 3, 4]); + + let dict_array = DictionaryArray::new(keys, values.clone()); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(39, sum_array::(array).unwrap()); + + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + assert_eq!(15, sum_array::(&a).unwrap()); + + let keys = Int8Array::from(vec![Some(2_i8), None, Some(4)]); + let dict_array = DictionaryArray::new(keys, values.clone()); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(26, sum_array::(array).unwrap()); + + let keys = Int8Array::from(vec![None, None, None]); + let dict_array = DictionaryArray::new(keys, values.clone()); + let array = dict_array.downcast_dict::().unwrap(); + assert!(sum_array::(array).is_none()); + } + + #[test] + fn test_max_min_dyn() { + let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + let keys = Int8Array::from_iter_values([2_i8, 3, 4]); + let values = Arc::new(values) as ArrayRef; + + let dict_array = DictionaryArray::new(keys, values.clone()); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(14, max_array::(array).unwrap()); + + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(12, min_array::(array).unwrap()); + + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + assert_eq!(5, max_array::(&a).unwrap()); + assert_eq!(1, min_array::(&a).unwrap()); + + let keys = Int8Array::from(vec![Some(2_i8), None, Some(7)]); + let dict_array = DictionaryArray::new(keys, values.clone()); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(17, max_array::(array).unwrap()); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(12, min_array::(array).unwrap()); + + let keys = Int8Array::from(vec![None, None, None]); + let dict_array = DictionaryArray::new(keys, values.clone()); + let array = dict_array.downcast_dict::().unwrap(); + assert!(max_array::(array).is_none()); + let array = dict_array.downcast_dict::().unwrap(); + assert!(min_array::(array).is_none()); + } + + #[test] + fn test_max_min_dyn_nan() { + let values = Float32Array::from(vec![5.0_f32, 2.0_f32, f32::NAN]); + let keys = Int8Array::from_iter_values([0_i8, 1, 2]); + + let dict_array = DictionaryArray::new(keys, Arc::new(values)); + let array = dict_array.downcast_dict::().unwrap(); + assert!(max_array::(array).unwrap().is_nan()); + + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(2.0_f32, min_array::(array).unwrap()); + } + + #[test] + fn test_min_max_sliced_primitive() { + let expected = Some(4.0); + let input: Float64Array = vec![None, Some(4.0)].into_iter().collect(); + let actual = min(&input); + assert_eq!(actual, expected); + let actual = max(&input); + assert_eq!(actual, expected); + + let sliced_input: Float64Array = vec![None, None, None, None, None, Some(4.0)] + .into_iter() + .collect(); + let sliced_input = sliced_input.slice(4, 2); + + assert_eq!(&sliced_input, &input); + + let actual = min(&sliced_input); + assert_eq!(actual, expected); + let actual = max(&sliced_input); + assert_eq!(actual, expected); + } + + #[test] + fn test_min_max_sliced_boolean() { + let expected = Some(true); + let input: BooleanArray = vec![None, Some(true)].into_iter().collect(); + let actual = min_boolean(&input); + assert_eq!(actual, expected); + let actual = max_boolean(&input); + assert_eq!(actual, expected); + + let sliced_input: BooleanArray = vec![None, None, None, None, None, Some(true)] + .into_iter() + .collect(); + let sliced_input = sliced_input.slice(4, 2); + + assert_eq!(sliced_input, input); + + let actual = min_boolean(&sliced_input); + assert_eq!(actual, expected); + let actual = max_boolean(&sliced_input); + assert_eq!(actual, expected); + } + + #[test] + fn test_min_max_sliced_string() { + let expected = Some("foo"); + let input: StringArray = vec![None, Some("foo")].into_iter().collect(); + let actual = min_string(&input); + assert_eq!(actual, expected); + let actual = max_string(&input); + assert_eq!(actual, expected); + + let sliced_input: StringArray = vec![None, None, None, None, None, Some("foo")] + .into_iter() + .collect(); + let sliced_input = sliced_input.slice(4, 2); + + assert_eq!(&sliced_input, &input); + + let actual = min_string(&sliced_input); + assert_eq!(actual, expected); + let actual = max_string(&sliced_input); + assert_eq!(actual, expected); + } + + #[test] + fn test_min_max_sliced_binary() { + let expected: Option<&[u8]> = Some(&[5]); + let input: BinaryArray = vec![None, Some(&[5])].into_iter().collect(); + let actual = min_binary(&input); + assert_eq!(actual, expected); + let actual = max_binary(&input); + assert_eq!(actual, expected); + + let sliced_input: BinaryArray = vec![None, None, None, None, None, Some(&[5])] + .into_iter() + .collect(); + let sliced_input = sliced_input.slice(4, 2); + + assert_eq!(&sliced_input, &input); + + let actual = min_binary(&sliced_input); + assert_eq!(actual, expected); + let actual = max_binary(&sliced_input); + assert_eq!(actual, expected); + } + + #[test] + fn test_sum_overflow() { + let a = Int32Array::from(vec![i32::MAX, 1]); + + assert_eq!(sum(&a).unwrap(), -2147483648); + assert_eq!(sum_array::(&a).unwrap(), -2147483648); + } + + #[test] + fn test_sum_checked_overflow() { + let a = Int32Array::from(vec![i32::MAX, 1]); + + sum_checked(&a).expect_err("overflow should be detected"); + sum_array_checked::(&a).expect_err("overflow should be detected"); + } +} diff --git a/arrow-arith/src/arithmetic.rs b/arrow-arith/src/arithmetic.rs new file mode 100644 index 000000000000..124614d77f97 --- /dev/null +++ b/arrow-arith/src/arithmetic.rs @@ -0,0 +1,341 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines basic arithmetic kernels for `PrimitiveArrays`. +//! +//! These kernels can leverage SIMD if available on your system. Currently no runtime +//! detection is provided, you should enable the specific SIMD intrinsics using +//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation +//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. + +use crate::arity::*; +use arrow_array::types::*; +use arrow_array::*; +use arrow_buffer::i256; +use arrow_buffer::ArrowNativeType; +use arrow_schema::*; +use std::cmp::min; +use std::sync::Arc; + +/// Returns the precision and scale of the result of a multiplication of two decimal types, +/// and the divisor for fixed point multiplication. +fn get_fixed_point_info( + left: (u8, i8), + right: (u8, i8), + required_scale: i8, +) -> Result<(u8, i8, i256), ArrowError> { + let product_scale = left.1 + right.1; + let precision = min(left.0 + right.0 + 1, DECIMAL128_MAX_PRECISION); + + if required_scale > product_scale { + return Err(ArrowError::ComputeError(format!( + "Required scale {} is greater than product scale {}", + required_scale, product_scale + ))); + } + + let divisor = i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32); + + Ok((precision, product_scale, divisor)) +} + +/// Perform `left * right` operation on two decimal arrays. If either left or right value is +/// null then the result is also null. +/// +/// This performs decimal multiplication which allows precision loss if an exact representation +/// is not possible for the result, according to the required scale. In the case, the result +/// will be rounded to the required scale. +/// +/// If the required scale is greater than the product scale, an error is returned. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// +/// It is implemented for compatibility with precision loss `multiply` function provided by +/// other data processing engines. For multiplication with precision loss detection, use +/// `multiply_dyn` or `multiply_dyn_checked` instead. +pub fn multiply_fixed_point_dyn( + left: &dyn Array, + right: &dyn Array, + required_scale: i8, +) -> Result { + match (left.data_type(), right.data_type()) { + (DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => { + let left = left.as_any().downcast_ref::().unwrap(); + let right = right.as_any().downcast_ref::().unwrap(); + + multiply_fixed_point(left, right, required_scale).map(|a| Arc::new(a) as ArrayRef) + } + (_, _) => Err(ArrowError::CastError(format!( + "Unsupported data type {}, {}", + left.data_type(), + right.data_type() + ))), + } +} + +/// Perform `left * right` operation on two decimal arrays. If either left or right value is +/// null then the result is also null. +/// +/// This performs decimal multiplication which allows precision loss if an exact representation +/// is not possible for the result, according to the required scale. In the case, the result +/// will be rounded to the required scale. +/// +/// If the required scale is greater than the product scale, an error is returned. +/// +/// It is implemented for compatibility with precision loss `multiply` function provided by +/// other data processing engines. For multiplication with precision loss detection, use +/// `multiply` or `multiply_checked` instead. +pub fn multiply_fixed_point_checked( + left: &PrimitiveArray, + right: &PrimitiveArray, + required_scale: i8, +) -> Result, ArrowError> { + let (precision, product_scale, divisor) = get_fixed_point_info( + (left.precision(), left.scale()), + (right.precision(), right.scale()), + required_scale, + )?; + + if required_scale == product_scale { + return try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| a.mul_checked(b))? + .with_precision_and_scale(precision, required_scale); + } + + try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| { + let a = i256::from_i128(a); + let b = i256::from_i128(b); + + let mut mul = a.wrapping_mul(b); + mul = divide_and_round::(mul, divisor); + mul.to_i128().ok_or_else(|| { + ArrowError::ComputeError(format!("Overflow happened on: {:?} * {:?}", a, b)) + }) + }) + .and_then(|a| a.with_precision_and_scale(precision, required_scale)) +} + +/// Perform `left * right` operation on two decimal arrays. If either left or right value is +/// null then the result is also null. +/// +/// This performs decimal multiplication which allows precision loss if an exact representation +/// is not possible for the result, according to the required scale. In the case, the result +/// will be rounded to the required scale. +/// +/// If the required scale is greater than the product scale, an error is returned. +/// +/// This doesn't detect overflow. Once overflowing, the result will wrap around. +/// For an overflow-checking variant, use `multiply_fixed_point_checked` instead. +/// +/// It is implemented for compatibility with precision loss `multiply` function provided by +/// other data processing engines. For multiplication with precision loss detection, use +/// `multiply` or `multiply_checked` instead. +pub fn multiply_fixed_point( + left: &PrimitiveArray, + right: &PrimitiveArray, + required_scale: i8, +) -> Result, ArrowError> { + let (precision, product_scale, divisor) = get_fixed_point_info( + (left.precision(), left.scale()), + (right.precision(), right.scale()), + required_scale, + )?; + + if required_scale == product_scale { + return binary(left, right, |a, b| a.mul_wrapping(b))? + .with_precision_and_scale(precision, required_scale); + } + + binary::<_, _, _, Decimal128Type>(left, right, |a, b| { + let a = i256::from_i128(a); + let b = i256::from_i128(b); + + let mut mul = a.wrapping_mul(b); + mul = divide_and_round::(mul, divisor); + mul.as_i128() + }) + .and_then(|a| a.with_precision_and_scale(precision, required_scale)) +} + +/// Divide a decimal native value by given divisor and round the result. +fn divide_and_round(input: I::Native, div: I::Native) -> I::Native +where + I: DecimalType, + I::Native: ArrowNativeTypeOp, +{ + let d = input.div_wrapping(div); + let r = input.mod_wrapping(div); + + let half = div.div_wrapping(I::Native::from_usize(2).unwrap()); + let half_neg = half.neg_wrapping(); + + // Round result + match input >= I::Native::ZERO { + true if r >= half => d.add_wrapping(I::Native::ONE), + false if r <= half_neg => d.sub_wrapping(I::Native::ONE), + _ => d, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::numeric::mul; + + #[test] + fn test_decimal_multiply_allow_precision_loss() { + // Overflow happening as i128 cannot hold multiplying result. + // [123456789] + let a = Decimal128Array::from(vec![123456789000000000000000000]) + .with_precision_and_scale(38, 18) + .unwrap(); + + // [10] + let b = Decimal128Array::from(vec![10000000000000000000]) + .with_precision_and_scale(38, 18) + .unwrap(); + + let err = mul(&a, &b).unwrap_err(); + assert!(err + .to_string() + .contains("Overflow happened on: 123456789000000000000000000 * 10000000000000000000")); + + // Allow precision loss. + let result = multiply_fixed_point_checked(&a, &b, 28).unwrap(); + // [1234567890] + let expected = Decimal128Array::from(vec![12345678900000000000000000000000000000]) + .with_precision_and_scale(38, 28) + .unwrap(); + + assert_eq!(&expected, &result); + assert_eq!( + result.value_as_string(0), + "1234567890.0000000000000000000000000000" + ); + + // Rounding case + // [0.000000000000000001, 123456789.555555555555555555, 1.555555555555555555] + let a = Decimal128Array::from(vec![1, 123456789555555555555555555, 1555555555555555555]) + .with_precision_and_scale(38, 18) + .unwrap(); + + // [1.555555555555555555, 11.222222222222222222, 0.000000000000000001] + let b = Decimal128Array::from(vec![1555555555555555555, 11222222222222222222, 1]) + .with_precision_and_scale(38, 18) + .unwrap(); + + let result = multiply_fixed_point_checked(&a, &b, 28).unwrap(); + // [ + // 0.0000000000000000015555555556, + // 1385459527.2345679012071330528765432099, + // 0.0000000000000000015555555556 + // ] + let expected = Decimal128Array::from(vec![ + 15555555556, + 13854595272345679012071330528765432099, + 15555555556, + ]) + .with_precision_and_scale(38, 28) + .unwrap(); + + assert_eq!(&expected, &result); + + // Rounded the value "1385459527.234567901207133052876543209876543210". + assert_eq!( + result.value_as_string(1), + "1385459527.2345679012071330528765432099" + ); + assert_eq!(result.value_as_string(0), "0.0000000000000000015555555556"); + assert_eq!(result.value_as_string(2), "0.0000000000000000015555555556"); + + let a = Decimal128Array::from(vec![1230]) + .with_precision_and_scale(4, 2) + .unwrap(); + + let b = Decimal128Array::from(vec![1000]) + .with_precision_and_scale(4, 2) + .unwrap(); + + // Required scale is same as the product of the input scales. Behavior is same as multiply. + let result = multiply_fixed_point_checked(&a, &b, 4).unwrap(); + assert_eq!(result.precision(), 9); + assert_eq!(result.scale(), 4); + + let expected = mul(&a, &b).unwrap(); + assert_eq!(expected.as_ref(), &result); + + // Required scale cannot be larger than the product of the input scales. + let result = multiply_fixed_point_checked(&a, &b, 5).unwrap_err(); + assert!(result + .to_string() + .contains("Required scale 5 is greater than product scale 4")); + } + + #[test] + fn test_decimal_multiply_allow_precision_loss_overflow() { + // [99999999999123456789] + let a = Decimal128Array::from(vec![99999999999123456789000000000000000000]) + .with_precision_and_scale(38, 18) + .unwrap(); + + // [9999999999910] + let b = Decimal128Array::from(vec![9999999999910000000000000000000]) + .with_precision_and_scale(38, 18) + .unwrap(); + + let err = multiply_fixed_point_checked(&a, &b, 28).unwrap_err(); + assert!(err.to_string().contains( + "Overflow happened on: 99999999999123456789000000000000000000 * 9999999999910000000000000000000" + )); + + let result = multiply_fixed_point(&a, &b, 28).unwrap(); + let expected = Decimal128Array::from(vec![62946009661555981610246871926660136960]) + .with_precision_and_scale(38, 28) + .unwrap(); + + assert_eq!(&expected, &result); + } + + #[test] + fn test_decimal_multiply_fixed_point() { + // [123456789] + let a = Decimal128Array::from(vec![123456789000000000000000000]) + .with_precision_and_scale(38, 18) + .unwrap(); + + // [10] + let b = Decimal128Array::from(vec![10000000000000000000]) + .with_precision_and_scale(38, 18) + .unwrap(); + + // `multiply` overflows on this case. + let err = mul(&a, &b).unwrap_err(); + assert_eq!(err.to_string(), "Compute error: Overflow happened on: 123456789000000000000000000 * 10000000000000000000"); + + // Avoid overflow by reducing the scale. + let result = multiply_fixed_point(&a, &b, 28).unwrap(); + // [1234567890] + let expected = Decimal128Array::from(vec![12345678900000000000000000000000000000]) + .with_precision_and_scale(38, 28) + .unwrap(); + + assert_eq!(&expected, &result); + assert_eq!( + result.value_as_string(0), + "1234567890.0000000000000000000000000000" + ); + } +} diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs new file mode 100644 index 000000000000..ff8b82a5d943 --- /dev/null +++ b/arrow-arith/src/arity.rs @@ -0,0 +1,545 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines kernels suitable to perform operations to primitive arrays. + +use arrow_array::builder::BufferBuilder; +use arrow_array::types::ArrowDictionaryKeyType; +use arrow_array::*; +use arrow_buffer::buffer::NullBuffer; +use arrow_buffer::{Buffer, MutableBuffer}; +use arrow_data::ArrayData; +use arrow_schema::ArrowError; +use std::sync::Arc; + +/// See [`PrimitiveArray::unary`] +pub fn unary(array: &PrimitiveArray, op: F) -> PrimitiveArray +where + I: ArrowPrimitiveType, + O: ArrowPrimitiveType, + F: Fn(I::Native) -> O::Native, +{ + array.unary(op) +} + +/// See [`PrimitiveArray::unary_mut`] +pub fn unary_mut( + array: PrimitiveArray, + op: F, +) -> Result, PrimitiveArray> +where + I: ArrowPrimitiveType, + F: Fn(I::Native) -> I::Native, +{ + array.unary_mut(op) +} + +/// See [`PrimitiveArray::try_unary`] +pub fn try_unary(array: &PrimitiveArray, op: F) -> Result, ArrowError> +where + I: ArrowPrimitiveType, + O: ArrowPrimitiveType, + F: Fn(I::Native) -> Result, +{ + array.try_unary(op) +} + +/// See [`PrimitiveArray::try_unary_mut`] +pub fn try_unary_mut( + array: PrimitiveArray, + op: F, +) -> Result, ArrowError>, PrimitiveArray> +where + I: ArrowPrimitiveType, + F: Fn(I::Native) -> Result, +{ + array.try_unary_mut(op) +} + +/// A helper function that applies an infallible unary function to a dictionary array with primitive value type. +fn unary_dict(array: &DictionaryArray, op: F) -> Result +where + K: ArrowDictionaryKeyType + ArrowNumericType, + T: ArrowPrimitiveType, + F: Fn(T::Native) -> T::Native, +{ + let dict_values = array.values().as_any().downcast_ref().unwrap(); + let values = unary::(dict_values, op); + Ok(Arc::new(array.with_values(Arc::new(values)))) +} + +/// A helper function that applies a fallible unary function to a dictionary array with primitive value type. +fn try_unary_dict(array: &DictionaryArray, op: F) -> Result +where + K: ArrowDictionaryKeyType + ArrowNumericType, + T: ArrowPrimitiveType, + F: Fn(T::Native) -> Result, +{ + if !PrimitiveArray::::is_compatible(&array.value_type()) { + return Err(ArrowError::CastError(format!( + "Cannot perform the unary operation of type {} on dictionary array of value type {}", + T::DATA_TYPE, + array.value_type() + ))); + } + + let dict_values = array.values().as_any().downcast_ref().unwrap(); + let values = try_unary::(dict_values, op)?; + Ok(Arc::new(array.with_values(Arc::new(values)))) +} + +/// Applies an infallible unary function to an array with primitive values. +#[deprecated(note = "Use arrow_array::AnyDictionaryArray")] +pub fn unary_dyn(array: &dyn Array, op: F) -> Result +where + T: ArrowPrimitiveType, + F: Fn(T::Native) -> T::Native, +{ + downcast_dictionary_array! { + array => unary_dict::<_, F, T>(array, op), + t => { + if PrimitiveArray::::is_compatible(t) { + Ok(Arc::new(unary::( + array.as_any().downcast_ref::>().unwrap(), + op, + ))) + } else { + Err(ArrowError::NotYetImplemented(format!( + "Cannot perform unary operation of type {} on array of type {}", + T::DATA_TYPE, + t + ))) + } + } + } +} + +/// Applies a fallible unary function to an array with primitive values. +#[deprecated(note = "Use arrow_array::AnyDictionaryArray")] +pub fn try_unary_dyn(array: &dyn Array, op: F) -> Result +where + T: ArrowPrimitiveType, + F: Fn(T::Native) -> Result, +{ + downcast_dictionary_array! { + array => if array.values().data_type() == &T::DATA_TYPE { + try_unary_dict::<_, F, T>(array, op) + } else { + Err(ArrowError::NotYetImplemented(format!( + "Cannot perform unary operation on dictionary array of type {}", + array.data_type() + ))) + }, + t => { + if PrimitiveArray::::is_compatible(t) { + Ok(Arc::new(try_unary::( + array.as_any().downcast_ref::>().unwrap(), + op, + )?)) + } else { + Err(ArrowError::NotYetImplemented(format!( + "Cannot perform unary operation of type {} on array of type {}", + T::DATA_TYPE, + t + ))) + } + } + } +} + +/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len`, collecting +/// the results in a [`PrimitiveArray`]. If any index is null in either `a` or `b`, the +/// corresponding index in the result will also be null +/// +/// Like [`unary`] the provided function is evaluated for every index, ignoring validity. This +/// is beneficial when the cost of the operation is low compared to the cost of branching, and +/// especially when the operation can be vectorised, however, requires `op` to be infallible +/// for all possible values of its inputs +/// +/// # Error +/// +/// This function gives error if the arrays have different lengths +pub fn binary( + a: &PrimitiveArray, + b: &PrimitiveArray, + op: F, +) -> Result, ArrowError> +where + A: ArrowPrimitiveType, + B: ArrowPrimitiveType, + O: ArrowPrimitiveType, + F: Fn(A::Native, B::Native) -> O::Native, +{ + if a.len() != b.len() { + return Err(ArrowError::ComputeError( + "Cannot perform binary operation on arrays of different length".to_string(), + )); + } + + if a.is_empty() { + return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE))); + } + + let nulls = NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()); + + let values = a.values().iter().zip(b.values()).map(|(l, r)| op(*l, *r)); + // JUSTIFICATION + // Benefit + // ~60% speedup + // Soundness + // `values` is an iterator with a known size from a PrimitiveArray + let buffer = unsafe { Buffer::from_trusted_len_iter(values) }; + Ok(PrimitiveArray::new(buffer.into(), nulls)) +} + +/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len`, mutating +/// the mutable [`PrimitiveArray`] `a`. If any index is null in either `a` or `b`, the +/// corresponding index in the result will also be null. +/// +/// Mutable primitive array means that the buffer is not shared with other arrays. +/// As a result, this mutates the buffer directly without allocating new buffer. +/// +/// Like [`unary`] the provided function is evaluated for every index, ignoring validity. This +/// is beneficial when the cost of the operation is low compared to the cost of branching, and +/// especially when the operation can be vectorised, however, requires `op` to be infallible +/// for all possible values of its inputs +/// +/// # Error +/// +/// This function gives error if the arrays have different lengths. +/// This function gives error of original [`PrimitiveArray`] `a` if it is not a mutable +/// primitive array. +pub fn binary_mut( + a: PrimitiveArray, + b: &PrimitiveArray, + op: F, +) -> Result, ArrowError>, PrimitiveArray> +where + T: ArrowPrimitiveType, + F: Fn(T::Native, T::Native) -> T::Native, +{ + if a.len() != b.len() { + return Ok(Err(ArrowError::ComputeError( + "Cannot perform binary operation on arrays of different length".to_string(), + ))); + } + + if a.is_empty() { + return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty( + &T::DATA_TYPE, + )))); + } + + let nulls = NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()); + + let mut builder = a.into_builder()?; + + builder + .values_slice_mut() + .iter_mut() + .zip(b.values()) + .for_each(|(l, r)| *l = op(*l, *r)); + + let array_builder = builder.finish().into_data().into_builder().nulls(nulls); + + let array_data = unsafe { array_builder.build_unchecked() }; + Ok(Ok(PrimitiveArray::::from(array_data))) +} + +/// Applies the provided fallible binary operation across `a` and `b`, returning any error, +/// and collecting the results into a [`PrimitiveArray`]. If any index is null in either `a` +/// or `b`, the corresponding index in the result will also be null +/// +/// Like [`try_unary`] the function is only evaluated for non-null indices +/// +/// # Error +/// +/// Return an error if the arrays have different lengths or +/// the operation is under erroneous +pub fn try_binary( + a: A, + b: B, + op: F, +) -> Result, ArrowError> +where + O: ArrowPrimitiveType, + F: Fn(A::Item, B::Item) -> Result, +{ + if a.len() != b.len() { + return Err(ArrowError::ComputeError( + "Cannot perform a binary operation on arrays of different length".to_string(), + )); + } + if a.is_empty() { + return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE))); + } + let len = a.len(); + + if a.null_count() == 0 && b.null_count() == 0 { + try_binary_no_nulls(len, a, b, op) + } else { + let nulls = + NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()).unwrap(); + + let mut buffer = BufferBuilder::::new(len); + buffer.append_n_zeroed(len); + let slice = buffer.as_slice_mut(); + + nulls.try_for_each_valid_idx(|idx| { + unsafe { + *slice.get_unchecked_mut(idx) = op(a.value_unchecked(idx), b.value_unchecked(idx))? + }; + Ok::<_, ArrowError>(()) + })?; + + let values = buffer.finish().into(); + Ok(PrimitiveArray::new(values, Some(nulls))) + } +} + +/// Applies the provided fallible binary operation across `a` and `b` by mutating the mutable +/// [`PrimitiveArray`] `a` with the results, returning any error. If any index is null in +/// either `a` or `b`, the corresponding index in the result will also be null +/// +/// Like [`try_unary`] the function is only evaluated for non-null indices +/// +/// Mutable primitive array means that the buffer is not shared with other arrays. +/// As a result, this mutates the buffer directly without allocating new buffer. +/// +/// # Error +/// +/// Return an error if the arrays have different lengths or +/// the operation is under erroneous. +/// This function gives error of original [`PrimitiveArray`] `a` if it is not a mutable +/// primitive array. +pub fn try_binary_mut( + a: PrimitiveArray, + b: &PrimitiveArray, + op: F, +) -> Result, ArrowError>, PrimitiveArray> +where + T: ArrowPrimitiveType, + F: Fn(T::Native, T::Native) -> Result, +{ + if a.len() != b.len() { + return Ok(Err(ArrowError::ComputeError( + "Cannot perform binary operation on arrays of different length".to_string(), + ))); + } + let len = a.len(); + + if a.is_empty() { + return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty( + &T::DATA_TYPE, + )))); + } + + if a.null_count() == 0 && b.null_count() == 0 { + try_binary_no_nulls_mut(len, a, b, op) + } else { + let nulls = + NullBuffer::union(a.logical_nulls().as_ref(), b.logical_nulls().as_ref()).unwrap(); + + let mut builder = a.into_builder()?; + + let slice = builder.values_slice_mut(); + + match nulls.try_for_each_valid_idx(|idx| { + unsafe { + *slice.get_unchecked_mut(idx) = + op(*slice.get_unchecked(idx), b.value_unchecked(idx))? + }; + Ok::<_, ArrowError>(()) + }) { + Ok(_) => {} + Err(err) => return Ok(Err(err)), + }; + + let array_builder = builder.finish().into_data().into_builder(); + let array_data = unsafe { array_builder.nulls(Some(nulls)).build_unchecked() }; + Ok(Ok(PrimitiveArray::::from(array_data))) + } +} + +/// This intentional inline(never) attribute helps LLVM optimize the loop. +#[inline(never)] +fn try_binary_no_nulls( + len: usize, + a: A, + b: B, + op: F, +) -> Result, ArrowError> +where + O: ArrowPrimitiveType, + F: Fn(A::Item, B::Item) -> Result, +{ + let mut buffer = MutableBuffer::new(len * O::get_byte_width()); + for idx in 0..len { + unsafe { + buffer.push_unchecked(op(a.value_unchecked(idx), b.value_unchecked(idx))?); + }; + } + Ok(PrimitiveArray::new(buffer.into(), None)) +} + +/// This intentional inline(never) attribute helps LLVM optimize the loop. +#[inline(never)] +fn try_binary_no_nulls_mut( + len: usize, + a: PrimitiveArray, + b: &PrimitiveArray, + op: F, +) -> Result, ArrowError>, PrimitiveArray> +where + T: ArrowPrimitiveType, + F: Fn(T::Native, T::Native) -> Result, +{ + let mut builder = a.into_builder()?; + let slice = builder.values_slice_mut(); + + for idx in 0..len { + unsafe { + match op(*slice.get_unchecked(idx), b.value_unchecked(idx)) { + Ok(value) => *slice.get_unchecked_mut(idx) = value, + Err(err) => return Ok(Err(err)), + }; + }; + } + Ok(Ok(builder.finish())) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::builder::*; + use arrow_array::types::*; + + #[test] + #[allow(deprecated)] + fn test_unary_f64_slice() { + let input = Float64Array::from(vec![Some(5.1f64), None, Some(6.8), None, Some(7.2)]); + let input_slice = input.slice(1, 4); + let result = unary(&input_slice, |n| n.round()); + assert_eq!( + result, + Float64Array::from(vec![None, Some(7.0), None, Some(7.0)]) + ); + + let result = unary_dyn::<_, Float64Type>(&input_slice, |n| n + 1.0).unwrap(); + + assert_eq!( + result.as_any().downcast_ref::().unwrap(), + &Float64Array::from(vec![None, Some(7.8), None, Some(8.2)]) + ); + } + + #[test] + #[allow(deprecated)] + fn test_unary_dict_and_unary_dyn() { + let mut builder = PrimitiveDictionaryBuilder::::new(); + builder.append(5).unwrap(); + builder.append(6).unwrap(); + builder.append(7).unwrap(); + builder.append(8).unwrap(); + builder.append_null(); + builder.append(9).unwrap(); + let dictionary_array = builder.finish(); + + let mut builder = PrimitiveDictionaryBuilder::::new(); + builder.append(6).unwrap(); + builder.append(7).unwrap(); + builder.append(8).unwrap(); + builder.append(9).unwrap(); + builder.append_null(); + builder.append(10).unwrap(); + let expected = builder.finish(); + + let result = unary_dict::<_, _, Int32Type>(&dictionary_array, |n| n + 1).unwrap(); + assert_eq!( + result + .as_any() + .downcast_ref::>() + .unwrap(), + &expected + ); + + let result = unary_dyn::<_, Int32Type>(&dictionary_array, |n| n + 1).unwrap(); + assert_eq!( + result + .as_any() + .downcast_ref::>() + .unwrap(), + &expected + ); + } + + #[test] + fn test_binary_mut() { + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]); + let c = binary_mut(a, &b, |l, r| l + r).unwrap().unwrap(); + + let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]); + assert_eq!(c, expected); + } + + #[test] + fn test_try_binary_mut() { + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]); + let c = try_binary_mut(a, &b, |l, r| Ok(l + r)).unwrap().unwrap(); + + let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]); + assert_eq!(c, expected); + + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = Int32Array::from(vec![1, 2, 3, 4, 5]); + let c = try_binary_mut(a, &b, |l, r| Ok(l + r)).unwrap().unwrap(); + let expected = Int32Array::from(vec![16, 16, 12, 12, 6]); + assert_eq!(c, expected); + + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]); + let _ = try_binary_mut(a, &b, |l, r| { + if l == 1 { + Err(ArrowError::InvalidArgumentError( + "got error".parse().unwrap(), + )) + } else { + Ok(l + r) + } + }) + .unwrap() + .expect_err("should got error"); + } + + #[test] + fn test_unary_dict_mut() { + let values = Int32Array::from(vec![Some(10), Some(20), None]); + let keys = Int8Array::from_iter_values([0, 0, 1, 2]); + let dictionary = DictionaryArray::new(keys, Arc::new(values)); + + let updated = dictionary.unary_mut::<_, Int32Type>(|x| x + 1).unwrap(); + let typed = updated.downcast_dict::().unwrap(); + assert_eq!(typed.value(0), 11); + assert_eq!(typed.value(1), 11); + assert_eq!(typed.value(2), 21); + + let values = updated.values(); + assert!(values.is_null(2)); + } +} diff --git a/arrow-arith/src/bitwise.rs b/arrow-arith/src/bitwise.rs new file mode 100644 index 000000000000..c7885952f8ba --- /dev/null +++ b/arrow-arith/src/bitwise.rs @@ -0,0 +1,351 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::arity::{binary, unary}; +use arrow_array::*; +use arrow_buffer::ArrowNativeType; +use arrow_schema::ArrowError; +use num::traits::{WrappingShl, WrappingShr}; +use std::ops::{BitAnd, BitOr, BitXor, Not}; + +/// The helper function for bitwise operation with two array +fn bitwise_op( + left: &PrimitiveArray, + right: &PrimitiveArray, + op: F, +) -> Result, ArrowError> +where + T: ArrowNumericType, + F: Fn(T::Native, T::Native) -> T::Native, +{ + binary(left, right, op) +} + +/// Perform `left & right` operation on two arrays. If either left or right value is null +/// then the result is also null. +pub fn bitwise_and( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: BitAnd, +{ + bitwise_op(left, right, |a, b| a & b) +} + +/// Perform `left | right` operation on two arrays. If either left or right value is null +/// then the result is also null. +pub fn bitwise_or( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: BitOr, +{ + bitwise_op(left, right, |a, b| a | b) +} + +/// Perform `left ^ right` operation on two arrays. If either left or right value is null +/// then the result is also null. +pub fn bitwise_xor( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: BitXor, +{ + bitwise_op(left, right, |a, b| a ^ b) +} + +/// Perform bitwise `left << right` operation on two arrays. If either left or right value is null +/// then the result is also null. +pub fn bitwise_shift_left( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: WrappingShl, +{ + bitwise_op(left, right, |a, b| { + let b = b.as_usize(); + a.wrapping_shl(b as u32) + }) +} + +/// Perform bitwise `left >> right` operation on two arrays. If either left or right value is null +/// then the result is also null. +pub fn bitwise_shift_right( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: WrappingShr, +{ + bitwise_op(left, right, |a, b| { + let b = b.as_usize(); + a.wrapping_shr(b as u32) + }) +} + +/// Perform `!array` operation on array. If array value is null +/// then the result is also null. +pub fn bitwise_not(array: &PrimitiveArray) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: Not, +{ + Ok(unary(array, |value| !value)) +} + +/// Perform bitwise `and` every value in an array with the scalar. If any value in the array is null then the +/// result is also null. +pub fn bitwise_and_scalar( + array: &PrimitiveArray, + scalar: T::Native, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: BitAnd, +{ + Ok(unary(array, |value| value & scalar)) +} + +/// Perform bitwise `or` every value in an array with the scalar. If any value in the array is null then the +/// result is also null. +pub fn bitwise_or_scalar( + array: &PrimitiveArray, + scalar: T::Native, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: BitOr, +{ + Ok(unary(array, |value| value | scalar)) +} + +/// Perform bitwise `xor` every value in an array with the scalar. If any value in the array is null then the +/// result is also null. +pub fn bitwise_xor_scalar( + array: &PrimitiveArray, + scalar: T::Native, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: BitXor, +{ + Ok(unary(array, |value| value ^ scalar)) +} + +/// Perform bitwise `left << right` every value in an array with the scalar. If any value in the array is null then the +/// result is also null. +pub fn bitwise_shift_left_scalar( + array: &PrimitiveArray, + scalar: T::Native, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: WrappingShl, +{ + Ok(unary(array, |value| { + let scalar = scalar.as_usize(); + value.wrapping_shl(scalar as u32) + })) +} + +/// Perform bitwise `left >> right` every value in an array with the scalar. If any value in the array is null then the +/// result is also null. +pub fn bitwise_shift_right_scalar( + array: &PrimitiveArray, + scalar: T::Native, +) -> Result, ArrowError> +where + T: ArrowNumericType, + T::Native: WrappingShr, +{ + Ok(unary(array, |value| { + let scalar = scalar.as_usize(); + value.wrapping_shr(scalar as u32) + })) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_bitwise_and_array() -> Result<(), ArrowError> { + // unsigned value + let left = UInt64Array::from(vec![Some(1), Some(2), None, Some(4)]); + let right = UInt64Array::from(vec![Some(5), Some(10), Some(8), Some(12)]); + let expected = UInt64Array::from(vec![Some(1), Some(2), None, Some(4)]); + let result = bitwise_and(&left, &right)?; + assert_eq!(expected, result); + + // signed value + let left = Int32Array::from(vec![Some(1), Some(2), None, Some(4)]); + let right = Int32Array::from(vec![Some(5), Some(-10), Some(8), Some(12)]); + let expected = Int32Array::from(vec![Some(1), Some(2), None, Some(4)]); + let result = bitwise_and(&left, &right)?; + assert_eq!(expected, result); + Ok(()) + } + + #[test] + fn test_bitwise_shift_left() { + let left = UInt64Array::from(vec![Some(1), Some(2), None, Some(4), Some(8)]); + let right = UInt64Array::from(vec![Some(5), Some(10), Some(8), Some(12), Some(u64::MAX)]); + let expected = UInt64Array::from(vec![Some(32), Some(2048), None, Some(16384), Some(0)]); + let result = bitwise_shift_left(&left, &right).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn test_bitwise_shift_left_scalar() { + let left = UInt64Array::from(vec![Some(1), Some(2), None, Some(4), Some(8)]); + let scalar = 2; + let expected = UInt64Array::from(vec![Some(4), Some(8), None, Some(16), Some(32)]); + let result = bitwise_shift_left_scalar(&left, scalar).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn test_bitwise_shift_right() { + let left = UInt64Array::from(vec![Some(32), Some(2048), None, Some(16384), Some(3)]); + let right = UInt64Array::from(vec![Some(5), Some(10), Some(8), Some(12), Some(65)]); + let expected = UInt64Array::from(vec![Some(1), Some(2), None, Some(4), Some(1)]); + let result = bitwise_shift_right(&left, &right).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn test_bitwise_shift_right_scalar() { + let left = UInt64Array::from(vec![Some(32), Some(2048), None, Some(16384), Some(3)]); + let scalar = 2; + let expected = UInt64Array::from(vec![Some(8), Some(512), None, Some(4096), Some(0)]); + let result = bitwise_shift_right_scalar(&left, scalar).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn test_bitwise_and_array_scalar() { + // unsigned value + let left = UInt64Array::from(vec![Some(15), Some(2), None, Some(4)]); + let scalar = 7; + let expected = UInt64Array::from(vec![Some(7), Some(2), None, Some(4)]); + let result = bitwise_and_scalar(&left, scalar).unwrap(); + assert_eq!(expected, result); + + // signed value + let left = Int32Array::from(vec![Some(1), Some(2), None, Some(4)]); + let scalar = -20; + let expected = Int32Array::from(vec![Some(0), Some(0), None, Some(4)]); + let result = bitwise_and_scalar(&left, scalar).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn test_bitwise_or_array() { + // unsigned value + let left = UInt64Array::from(vec![Some(1), Some(2), None, Some(4)]); + let right = UInt64Array::from(vec![Some(7), Some(5), Some(8), Some(13)]); + let expected = UInt64Array::from(vec![Some(7), Some(7), None, Some(13)]); + let result = bitwise_or(&left, &right).unwrap(); + assert_eq!(expected, result); + + // signed value + let left = Int32Array::from(vec![Some(1), Some(2), None, Some(4)]); + let right = Int32Array::from(vec![Some(-7), Some(-5), Some(8), Some(13)]); + let expected = Int32Array::from(vec![Some(-7), Some(-5), None, Some(13)]); + let result = bitwise_or(&left, &right).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn test_bitwise_not_array() { + // unsigned value + let array = UInt64Array::from(vec![Some(1), Some(2), None, Some(4)]); + let expected = UInt64Array::from(vec![ + Some(18446744073709551614), + Some(18446744073709551613), + None, + Some(18446744073709551611), + ]); + let result = bitwise_not(&array).unwrap(); + assert_eq!(expected, result); + // signed value + let array = Int32Array::from(vec![Some(1), Some(2), None, Some(4)]); + let expected = Int32Array::from(vec![Some(-2), Some(-3), None, Some(-5)]); + let result = bitwise_not(&array).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn test_bitwise_or_array_scalar() { + // unsigned value + let left = UInt64Array::from(vec![Some(15), Some(2), None, Some(4)]); + let scalar = 7; + let expected = UInt64Array::from(vec![Some(15), Some(7), None, Some(7)]); + let result = bitwise_or_scalar(&left, scalar).unwrap(); + assert_eq!(expected, result); + + // signed value + let left = Int32Array::from(vec![Some(1), Some(2), None, Some(4)]); + let scalar = 20; + let expected = Int32Array::from(vec![Some(21), Some(22), None, Some(20)]); + let result = bitwise_or_scalar(&left, scalar).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn test_bitwise_xor_array() { + // unsigned value + let left = UInt64Array::from(vec![Some(1), Some(2), None, Some(4)]); + let right = UInt64Array::from(vec![Some(7), Some(5), Some(8), Some(13)]); + let expected = UInt64Array::from(vec![Some(6), Some(7), None, Some(9)]); + let result = bitwise_xor(&left, &right).unwrap(); + assert_eq!(expected, result); + + // signed value + let left = Int32Array::from(vec![Some(1), Some(2), None, Some(4)]); + let right = Int32Array::from(vec![Some(-7), Some(5), Some(8), Some(-13)]); + let expected = Int32Array::from(vec![Some(-8), Some(7), None, Some(-9)]); + let result = bitwise_xor(&left, &right).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn test_bitwise_xor_array_scalar() { + // unsigned value + let left = UInt64Array::from(vec![Some(15), Some(2), None, Some(4)]); + let scalar = 7; + let expected = UInt64Array::from(vec![Some(8), Some(5), None, Some(3)]); + let result = bitwise_xor_scalar(&left, scalar).unwrap(); + assert_eq!(expected, result); + + // signed value + let left = Int32Array::from(vec![Some(1), Some(2), None, Some(4)]); + let scalar = -20; + let expected = Int32Array::from(vec![Some(-19), Some(-18), None, Some(-24)]); + let result = bitwise_xor_scalar(&left, scalar).unwrap(); + assert_eq!(expected, result); + } +} diff --git a/arrow/src/compute/kernels/boolean.rs b/arrow-arith/src/boolean.rs similarity index 56% rename from arrow/src/compute/kernels/boolean.rs rename to arrow-arith/src/boolean.rs index c51953a7540c..269a36d66c2b 100644 --- a/arrow/src/compute/kernels/boolean.rs +++ b/arrow-arith/src/boolean.rs @@ -22,36 +22,52 @@ //! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. -use std::ops::Not; - -use crate::array::{Array, ArrayData, BooleanArray, PrimitiveArray}; -use crate::buffer::{ - bitwise_bin_op_helper, bitwise_quaternary_op_helper, buffer_bin_and, buffer_bin_or, - buffer_unary_not, Buffer, MutableBuffer, -}; -use crate::compute::util::combine_option_bitmap; -use crate::datatypes::{ArrowNumericType, DataType}; -use crate::error::{ArrowError, Result}; -use crate::util::bit_util::ceil; - -/// Updates null buffer based on data buffer and null buffer of the operand at other side -/// in boolean AND kernel with Kleene logic. In short, because for AND kernel, null AND false -/// results false. So we cannot simply AND two null buffers. This function updates null buffer -/// of one side if other side is a false value. -pub(crate) fn build_null_buffer_for_and_kleene( - left_data: &ArrayData, - left_offset: usize, - right_data: &ArrayData, - right_offset: usize, - len_in_bits: usize, -) -> Option { - let left_buffer = &left_data.buffers()[0]; - let right_buffer = &right_data.buffers()[0]; - - let left_null_buffer = left_data.null_buffer(); - let right_null_buffer = right_data.null_buffer(); - - match (left_null_buffer, right_null_buffer) { +use arrow_array::*; +use arrow_buffer::buffer::{bitwise_bin_op_helper, bitwise_quaternary_op_helper}; +use arrow_buffer::{BooleanBuffer, NullBuffer}; +use arrow_schema::ArrowError; + +/// Logical 'and' boolean values with Kleene logic +/// +/// # Behavior +/// +/// This function behaves as follows with nulls: +/// +/// * `true` and `null` = `null` +/// * `null` and `true` = `null` +/// * `false` and `null` = `false` +/// * `null` and `false` = `false` +/// * `null` and `null` = `null` +/// +/// In other words, in this context a null value really means \"unknown\", +/// and an unknown value 'and' false is always false. +/// For a different null behavior, see function \"and\". +/// +/// # Example +/// +/// ```rust +/// # use arrow_array::BooleanArray; +/// # use arrow_arith::boolean::and_kleene; +/// let a = BooleanArray::from(vec![Some(true), Some(false), None]); +/// let b = BooleanArray::from(vec![None, None, None]); +/// let and_ab = and_kleene(&a, &b).unwrap(); +/// assert_eq!(and_ab, BooleanArray::from(vec![None, Some(false), None])); +/// ``` +/// +/// # Fails +/// +/// If the operands have different lengths +pub fn and_kleene(left: &BooleanArray, right: &BooleanArray) -> Result { + if left.len() != right.len() { + return Err(ArrowError::ComputeError( + "Cannot perform bitwise operation on arrays of different length".to_string(), + )); + } + + let left_values = left.values(); + let right_values = right.values(); + + let buffer = match (left.nulls(), right.nulls()) { (None, None) => None, (Some(left_null_buffer), None) => { // The right side has no null values. @@ -59,22 +75,22 @@ pub(crate) fn build_null_buffer_for_and_kleene( // 1. left null bit is set, or // 2. right data bit is false (because null AND false = false). Some(bitwise_bin_op_helper( - left_null_buffer, - left_offset, - right_buffer, - right_offset, - len_in_bits, + left_null_buffer.buffer(), + left_null_buffer.offset(), + right_values.inner(), + right_values.offset(), + left.len(), |a, b| a | !b, )) } (None, Some(right_null_buffer)) => { // Same as above Some(bitwise_bin_op_helper( - right_null_buffer, - right_offset, - left_buffer, - left_offset, - len_in_bits, + right_null_buffer.buffer(), + right_null_buffer.offset(), + left_values.inner(), + left_values.offset(), + left.len(), |a, b| a | !b, )) } @@ -85,109 +101,131 @@ pub(crate) fn build_null_buffer_for_and_kleene( // The final null bits are: // (a | (c & !d)) & (c | (a & !b)) Some(bitwise_quaternary_op_helper( - left_null_buffer, - left_offset, - left_buffer, - left_offset, - right_null_buffer, - right_offset, - right_buffer, - right_offset, - len_in_bits, + [ + left_null_buffer.buffer(), + left_values.inner(), + right_null_buffer.buffer(), + right_values.inner(), + ], + [ + left_null_buffer.offset(), + left_values.offset(), + right_null_buffer.offset(), + right_values.offset(), + ], + left.len(), |a, b, c, d| (a | (c & !d)) & (c | (a & !b)), )) } - } + }; + let nulls = buffer.map(|b| NullBuffer::new(BooleanBuffer::new(b, 0, left.len()))); + Ok(BooleanArray::new(left_values & right_values, nulls)) } -/// For AND/OR kernels, the result of null buffer is simply a bitwise `and` operation. -pub(crate) fn build_null_buffer_for_and_or( - left_data: &ArrayData, - _left_offset: usize, - right_data: &ArrayData, - _right_offset: usize, - len_in_bits: usize, -) -> Option { - // `arrays` are not empty, so safely do `unwrap` directly. - combine_option_bitmap(&[left_data, right_data], len_in_bits).unwrap() -} +/// Logical 'or' boolean values with Kleene logic +/// +/// # Behavior +/// +/// This function behaves as follows with nulls: +/// +/// * `true` or `null` = `true` +/// * `null` or `true` = `true` +/// * `false` or `null` = `null` +/// * `null` or `false` = `null` +/// * `null` or `null` = `null` +/// +/// In other words, in this context a null value really means \"unknown\", +/// and an unknown value 'or' true is always true. +/// For a different null behavior, see function \"or\". +/// +/// # Example +/// +/// ```rust +/// # use arrow_array::BooleanArray; +/// # use arrow_arith::boolean::or_kleene; +/// let a = BooleanArray::from(vec![Some(true), Some(false), None]); +/// let b = BooleanArray::from(vec![None, None, None]); +/// let or_ab = or_kleene(&a, &b).unwrap(); +/// assert_eq!(or_ab, BooleanArray::from(vec![Some(true), None, None])); +/// ``` +/// +/// # Fails +/// +/// If the operands have different lengths +pub fn or_kleene(left: &BooleanArray, right: &BooleanArray) -> Result { + if left.len() != right.len() { + return Err(ArrowError::ComputeError( + "Cannot perform bitwise operation on arrays of different length".to_string(), + )); + } + + let left_values = left.values(); + let right_values = right.values(); -/// Updates null buffer based on data buffer and null buffer of the operand at other side -/// in boolean OR kernel with Kleene logic. In short, because for OR kernel, null OR true -/// results true. So we cannot simply AND two null buffers. This function updates null -/// buffer of one side if other side is a true value. -pub(crate) fn build_null_buffer_for_or_kleene( - left_data: &ArrayData, - left_offset: usize, - right_data: &ArrayData, - right_offset: usize, - len_in_bits: usize, -) -> Option { - let left_buffer = &left_data.buffers()[0]; - let right_buffer = &right_data.buffers()[0]; - - let left_null_buffer = left_data.null_buffer(); - let right_null_buffer = right_data.null_buffer(); - - match (left_null_buffer, right_null_buffer) { + let buffer = match (left.nulls(), right.nulls()) { (None, None) => None, - (Some(left_null_buffer), None) => { + (Some(left_nulls), None) => { // The right side has no null values. // The final null bit is set only if: // 1. left null bit is set, or // 2. right data bit is true (because null OR true = true). Some(bitwise_bin_op_helper( - left_null_buffer, - left_offset, - right_buffer, - right_offset, - len_in_bits, + left_nulls.buffer(), + left_nulls.offset(), + right_values.inner(), + right_values.offset(), + left.len(), |a, b| a | b, )) } - (None, Some(right_null_buffer)) => { + (None, Some(right_nulls)) => { // Same as above Some(bitwise_bin_op_helper( - right_null_buffer, - right_offset, - left_buffer, - left_offset, - len_in_bits, + right_nulls.buffer(), + right_nulls.offset(), + left_values.inner(), + left_values.offset(), + left.len(), |a, b| a | b, )) } - (Some(left_null_buffer), Some(right_null_buffer)) => { + (Some(left_nulls), Some(right_nulls)) => { // Follow the same logic above. Both sides have null values. // Assume a is left null bits, b is left data bits, c is right null bits, // d is right data bits. // The final null bits are: // (a | (c & d)) & (c | (a & b)) Some(bitwise_quaternary_op_helper( - left_null_buffer, - left_offset, - left_buffer, - left_offset, - right_null_buffer, - right_offset, - right_buffer, - right_offset, - len_in_bits, + [ + left_nulls.buffer(), + left_values.inner(), + right_nulls.buffer(), + right_values.inner(), + ], + [ + left_nulls.offset(), + left_values.offset(), + right_nulls.offset(), + right_values.offset(), + ], + left.len(), |a, b, c, d| (a | (c & d)) & (c | (a & b)), )) } - } + }; + + let nulls = buffer.map(|b| NullBuffer::new(BooleanBuffer::new(b, 0, left.len()))); + Ok(BooleanArray::new(left_values | right_values, nulls)) } /// Helper function to implement binary kernels -pub(crate) fn binary_boolean_kernel( +pub(crate) fn binary_boolean_kernel( left: &BooleanArray, right: &BooleanArray, op: F, - null_op: U, -) -> Result +) -> Result where - F: Fn(&Buffer, usize, &Buffer, usize, usize) -> Buffer, - U: Fn(&ArrayData, usize, &ArrayData, usize, usize) -> Option, + F: Fn(&BooleanBuffer, &BooleanBuffer) -> BooleanBuffer, { if left.len() != right.len() { return Err(ArrowError::ComputeError( @@ -195,32 +233,9 @@ where )); } - let len = left.len(); - - let left_data = left.data_ref(); - let right_data = right.data_ref(); - - let left_buffer = &left_data.buffers()[0]; - let right_buffer = &right_data.buffers()[0]; - let left_offset = left.offset(); - let right_offset = right.offset(); - - let null_bit_buffer = null_op(left_data, left_offset, right_data, right_offset, len); - - let values = op(left_buffer, left_offset, right_buffer, right_offset, len); - - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - len, - None, - null_bit_buffer, - 0, - vec![values], - vec![], - ) - }; - Ok(BooleanArray::from(data)) + let nulls = NullBuffer::union(left.nulls(), right.nulls()); + let values = op(left.values(), right.values()); + Ok(BooleanArray::new(values, nulls)) } /// Performs `AND` operation on two arrays. If either left or right value is null then the @@ -229,62 +244,15 @@ where /// This function errors when the arrays have different lengths. /// # Example /// ```rust -/// use arrow::array::BooleanArray; -/// use arrow::error::Result; -/// use arrow::compute::kernels::boolean::and; -/// # fn main() -> Result<()> { +/// # use arrow_array::BooleanArray; +/// # use arrow_arith::boolean::and; /// let a = BooleanArray::from(vec![Some(false), Some(true), None]); /// let b = BooleanArray::from(vec![Some(true), Some(true), Some(false)]); -/// let and_ab = and(&a, &b)?; +/// let and_ab = and(&a, &b).unwrap(); /// assert_eq!(and_ab, BooleanArray::from(vec![Some(false), Some(true), None])); -/// # Ok(()) -/// # } /// ``` -pub fn and(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_kernel(left, right, buffer_bin_and, build_null_buffer_for_and_or) -} - -/// Logical 'and' boolean values with Kleene logic -/// -/// # Behavior -/// -/// This function behaves as follows with nulls: -/// -/// * `true` and `null` = `null` -/// * `null` and `true` = `null` -/// * `false` and `null` = `false` -/// * `null` and `false` = `false` -/// * `null` and `null` = `null` -/// -/// In other words, in this context a null value really means \"unknown\", -/// and an unknown value 'and' false is always false. -/// For a different null behavior, see function \"and\". -/// -/// # Example -/// -/// ```rust -/// use arrow::array::BooleanArray; -/// use arrow::error::Result; -/// use arrow::compute::kernels::boolean::and_kleene; -/// # fn main() -> Result<()> { -/// let a = BooleanArray::from(vec![Some(true), Some(false), None]); -/// let b = BooleanArray::from(vec![None, None, None]); -/// let and_ab = and_kleene(&a, &b)?; -/// assert_eq!(and_ab, BooleanArray::from(vec![None, Some(false), None])); -/// # Ok(()) -/// # } -/// ``` -/// -/// # Fails -/// -/// If the operands have different lengths -pub fn and_kleene(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_kernel( - left, - right, - buffer_bin_and, - build_null_buffer_for_and_kleene, - ) +pub fn and(left: &BooleanArray, right: &BooleanArray) -> Result { + binary_boolean_kernel(left, right, |a, b| a & b) } /// Performs `OR` operation on two arrays. If either left or right value is null then the @@ -293,57 +261,15 @@ pub fn and_kleene(left: &BooleanArray, right: &BooleanArray) -> Result Result<()> { +/// # use arrow_array::BooleanArray; +/// # use arrow_arith::boolean::or; /// let a = BooleanArray::from(vec![Some(false), Some(true), None]); /// let b = BooleanArray::from(vec![Some(true), Some(true), Some(false)]); -/// let or_ab = or(&a, &b)?; +/// let or_ab = or(&a, &b).unwrap(); /// assert_eq!(or_ab, BooleanArray::from(vec![Some(true), Some(true), None])); -/// # Ok(()) -/// # } /// ``` -pub fn or(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_kernel(left, right, buffer_bin_or, build_null_buffer_for_and_or) -} - -/// Logical 'or' boolean values with Kleene logic -/// -/// # Behavior -/// -/// This function behaves as follows with nulls: -/// -/// * `true` or `null` = `true` -/// * `null` or `true` = `true` -/// * `false` or `null` = `null` -/// * `null` or `false` = `null` -/// * `null` or `null` = `null` -/// -/// In other words, in this context a null value really means \"unknown\", -/// and an unknown value 'or' true is always true. -/// For a different null behavior, see function \"or\". -/// -/// # Example -/// -/// ```rust -/// use arrow::array::BooleanArray; -/// use arrow::error::Result; -/// use arrow::compute::kernels::boolean::or_kleene; -/// # fn main() -> Result<()> { -/// let a = BooleanArray::from(vec![Some(true), Some(false), None]); -/// let b = BooleanArray::from(vec![None, None, None]); -/// let or_ab = or_kleene(&a, &b)?; -/// assert_eq!(or_ab, BooleanArray::from(vec![Some(true), None, None])); -/// # Ok(()) -/// # } -/// ``` -/// -/// # Fails -/// -/// If the operands have different lengths -pub fn or_kleene(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_kernel(left, right, buffer_bin_or, build_null_buffer_for_or_kleene) +pub fn or(left: &BooleanArray, right: &BooleanArray) -> Result { + binary_boolean_kernel(left, right, |a, b| a | b) } /// Performs unary `NOT` operation on an arrays. If value is null then the result is also @@ -352,40 +278,16 @@ pub fn or_kleene(left: &BooleanArray, right: &BooleanArray) -> Result Result<()> { +/// # use arrow_array::BooleanArray; +/// # use arrow_arith::boolean::not; /// let a = BooleanArray::from(vec![Some(false), Some(true), None]); -/// let not_a = not(&a)?; +/// let not_a = not(&a).unwrap(); /// assert_eq!(not_a, BooleanArray::from(vec![Some(true), Some(false), None])); -/// # Ok(()) -/// # } /// ``` -pub fn not(left: &BooleanArray) -> Result { - let left_offset = left.offset(); - let len = left.len(); - - let data = left.data_ref(); - let null_bit_buffer = data - .null_bitmap() - .as_ref() - .map(|b| b.bits.bit_slice(left_offset, len)); - - let values = buffer_unary_not(&data.buffers()[0], left_offset, len); - - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - len, - None, - null_bit_buffer, - 0, - vec![values], - vec![], - ) - }; - Ok(BooleanArray::from(data)) +pub fn not(left: &BooleanArray) -> Result { + let nulls = left.nulls().cloned(); + let values = !left.values(); + Ok(BooleanArray::new(values, nulls)) } /// Returns a non-null [BooleanArray] with whether each value of the array is null. @@ -393,40 +295,19 @@ pub fn not(left: &BooleanArray) -> Result { /// This function never errors. /// # Example /// ```rust -/// # use arrow::error::Result; -/// use arrow::array::BooleanArray; -/// use arrow::compute::kernels::boolean::is_null; -/// # fn main() -> Result<()> { +/// # use arrow_array::BooleanArray; +/// # use arrow_arith::boolean::is_null; /// let a = BooleanArray::from(vec![Some(false), Some(true), None]); -/// let a_is_null = is_null(&a)?; +/// let a_is_null = is_null(&a).unwrap(); /// assert_eq!(a_is_null, BooleanArray::from(vec![false, false, true])); -/// # Ok(()) -/// # } /// ``` -pub fn is_null(input: &dyn Array) -> Result { - let len = input.len(); - - let output = match input.data_ref().null_buffer() { - None => { - let len_bytes = ceil(len, 8); - MutableBuffer::from_len_zeroed(len_bytes).into() - } - Some(buffer) => buffer_unary_not(buffer, input.offset(), len), - }; - - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - len, - None, - None, - 0, - vec![output], - vec![], - ) +pub fn is_null(input: &dyn Array) -> Result { + let values = match input.logical_nulls() { + None => BooleanBuffer::new_unset(input.len()), + Some(nulls) => !nulls.inner(), }; - Ok(BooleanArray::from(data)) + Ok(BooleanArray::new(values, None)) } /// Returns a non-null [BooleanArray] with whether each value of the array is not null. @@ -434,141 +315,23 @@ pub fn is_null(input: &dyn Array) -> Result { /// This function never errors. /// # Example /// ```rust -/// # use arrow::error::Result; -/// use arrow::array::BooleanArray; -/// use arrow::compute::kernels::boolean::is_not_null; -/// # fn main() -> Result<()> { +/// # use arrow_array::BooleanArray; +/// # use arrow_arith::boolean::is_not_null; /// let a = BooleanArray::from(vec![Some(false), Some(true), None]); -/// let a_is_not_null = is_not_null(&a)?; +/// let a_is_not_null = is_not_null(&a).unwrap(); /// assert_eq!(a_is_not_null, BooleanArray::from(vec![true, true, false])); -/// # Ok(()) -/// # } /// ``` -pub fn is_not_null(input: &dyn Array) -> Result { - let len = input.len(); - - let output = match input.data_ref().null_buffer() { - None => { - let len_bytes = ceil(len, 8); - MutableBuffer::new(len_bytes) - .with_bitset(len_bytes, true) - .into() - } - Some(buffer) => buffer.bit_slice(input.offset(), len), +pub fn is_not_null(input: &dyn Array) -> Result { + let values = match input.logical_nulls() { + None => BooleanBuffer::new_set(input.len()), + Some(n) => n.inner().clone(), }; - - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - len, - None, - None, - 0, - vec![output], - vec![], - ) - }; - - Ok(BooleanArray::from(data)) -} - -/// Copies original array, setting null bit to true if a secondary comparison boolean array is set to true. -/// Typically used to implement NULLIF. -// NOTE: For now this only supports Primitive Arrays. Although the code could be made generic, the issue -// is that currently the bitmap operations result in a final bitmap which is aligned to bit 0, and thus -// the left array's data needs to be sliced to a new offset, and for non-primitive arrays shifting the -// data might be too complicated. In the future, to avoid shifting left array's data, we could instead -// shift the final bitbuffer to the right, prepending with 0's instead. -pub fn nullif( - left: &PrimitiveArray, - right: &BooleanArray, -) -> Result> -where - T: ArrowNumericType, -{ - if left.len() != right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform comparison operation on arrays of different length" - .to_string(), - )); - } - let left_data = left.data(); - let right_data = right.data(); - - // If left has no bitmap, create a new one with all values set for nullity op later - // left=0 (null) right=null output bitmap=null - // left=0 right=1 output bitmap=null - // left=1 (set) right=null output bitmap=set (passthrough) - // left=1 right=1 & comp=true output bitmap=null - // left=1 right=1 & comp=false output bitmap=set - // - // Thus: result = left null bitmap & (!right_values | !right_bitmap) - // OR left null bitmap & !(right_values & right_bitmap) - // - // Do the right expression !(right_values & right_bitmap) first since there are two steps - // TRICK: convert BooleanArray buffer as a bitmap for faster operation - let right_combo_buffer = match right.data().null_bitmap() { - Some(right_bitmap) => { - // NOTE: right values and bitmaps are combined and stay at bit offset right.offset() - (right.values() & &right_bitmap.bits).ok().map(|b| b.not()) - } - None => Some(!right.values()), - }; - - // AND of original left null bitmap with right expression - // Here we take care of the possible offsets of the left and right arrays all at once. - let modified_null_buffer = match left_data.null_bitmap() { - Some(left_null_bitmap) => match right_combo_buffer { - Some(rcb) => Some(buffer_bin_and( - &left_null_bitmap.bits, - left_data.offset(), - &rcb, - right_data.offset(), - left_data.len(), - )), - None => Some( - left_null_bitmap - .bits - .bit_slice(left_data.offset(), left.len()), - ), - }, - None => right_combo_buffer - .map(|rcb| rcb.bit_slice(right_data.offset(), right_data.len())), - }; - - // Align/shift left data on offset as needed, since new bitmaps are shifted and aligned to 0 already - // NOTE: this probably only works for primitive arrays. - let data_buffers = if left.offset() == 0 { - left_data.buffers().to_vec() - } else { - // Shift each data buffer by type's bit_width * offset. - left_data - .buffers() - .iter() - .map(|buf| buf.slice(left.offset() * T::get_byte_width())) - .collect::>() - }; - - // Construct new array with same values but modified null bitmap - // TODO: shift data buffer as needed - let data = unsafe { - ArrayData::new_unchecked( - T::DATA_TYPE, - left.len(), - None, // force new to compute the number of null bits - modified_null_buffer, - 0, // No need for offset since left data has been shifted - data_buffers, - left_data.child_data().to_vec(), - ) - }; - Ok(PrimitiveArray::::from(data)) + Ok(BooleanArray::new(values, None)) } #[cfg(test)] mod tests { use super::*; - use crate::array::{ArrayRef, Int32Array}; use std::sync::Arc; #[test] @@ -731,7 +494,7 @@ mod tests { let a = BooleanArray::from(vec![false, false, false, true, true, true]); // ensure null bitmap of a is absent - assert!(a.data_ref().null_bitmap().is_none()); + assert!(a.nulls().is_none()); let b = BooleanArray::from(vec![ Some(true), @@ -743,7 +506,7 @@ mod tests { ]); // ensure null bitmap of b is present - assert!(b.data_ref().null_bitmap().is_some()); + assert!(b.nulls().is_some()); let c = or_kleene(&a, &b).unwrap(); @@ -771,12 +534,12 @@ mod tests { ]); // ensure null bitmap of b is absent - assert!(a.data_ref().null_bitmap().is_some()); + assert!(a.nulls().is_some()); let b = BooleanArray::from(vec![false, false, false, true, true, true]); // ensure null bitmap of a is present - assert!(b.data_ref().null_bitmap().is_none()); + assert!(b.nulls().is_none()); let c = or_kleene(&a, &b).unwrap(); @@ -809,8 +572,7 @@ mod tests { let a = a.as_any().downcast_ref::().unwrap(); let c = not(a).unwrap(); - let expected = - BooleanArray::from(vec![Some(false), Some(true), None, Some(false)]); + let expected = BooleanArray::from(vec![Some(false), Some(true), None, Some(false)]); assert_eq!(c, expected); } @@ -859,12 +621,10 @@ mod tests { #[test] fn test_bool_array_and_sliced_same_offset() { let a = BooleanArray::from(vec![ - false, false, false, false, false, false, false, false, false, false, true, - true, + false, false, false, false, false, false, false, false, false, false, true, true, ]); let b = BooleanArray::from(vec![ - false, false, false, false, false, false, false, false, false, true, false, - true, + false, false, false, false, false, false, false, false, false, true, false, true, ]); let a = a.slice(8, 4); @@ -882,12 +642,10 @@ mod tests { #[test] fn test_bool_array_and_sliced_same_offset_mod8() { let a = BooleanArray::from(vec![ - false, false, true, true, false, false, false, false, false, false, false, - false, + false, false, true, true, false, false, false, false, false, false, false, false, ]); let b = BooleanArray::from(vec![ - false, false, false, false, false, false, false, false, false, true, false, - true, + false, false, false, false, false, false, false, false, false, true, false, true, ]); let a = a.slice(0, 4); @@ -905,8 +663,7 @@ mod tests { #[test] fn test_bool_array_and_sliced_offset1() { let a = BooleanArray::from(vec![ - false, false, false, false, false, false, false, false, false, false, true, - true, + false, false, false, false, false, false, false, false, false, false, true, true, ]); let b = BooleanArray::from(vec![false, true, false, true]); @@ -924,8 +681,7 @@ mod tests { fn test_bool_array_and_sliced_offset2() { let a = BooleanArray::from(vec![false, false, true, true]); let b = BooleanArray::from(vec![ - false, false, false, false, false, false, false, false, false, true, false, - true, + false, false, false, false, false, false, false, false, false, true, false, true, ]); let b = b.slice(8, 4); @@ -958,8 +714,7 @@ mod tests { let c = and(a, b).unwrap(); - let expected = - BooleanArray::from(vec![Some(false), Some(false), None, Some(true)]); + let expected = BooleanArray::from(vec![Some(false), Some(false), None, Some(true)]); assert_eq!(expected, c); } @@ -973,7 +728,7 @@ mod tests { let expected = BooleanArray::from(vec![false, false, false, false]); assert_eq!(expected, res); - assert_eq!(None, res.data_ref().null_bitmap()); + assert!(res.nulls().is_none()); } #[test] @@ -981,12 +736,12 @@ mod tests { let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]); let a = a.slice(8, 4); - let res = is_null(a.as_ref()).unwrap(); + let res = is_null(&a).unwrap(); let expected = BooleanArray::from(vec![false, false, false, false]); assert_eq!(expected, res); - assert_eq!(None, res.data_ref().null_bitmap()); + assert!(res.nulls().is_none()); } #[test] @@ -998,7 +753,7 @@ mod tests { let expected = BooleanArray::from(vec![true, true, true, true]); assert_eq!(expected, res); - assert_eq!(None, res.data_ref().null_bitmap()); + assert!(res.nulls().is_none()); } #[test] @@ -1006,12 +761,12 @@ mod tests { let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]); let a = a.slice(8, 4); - let res = is_not_null(a.as_ref()).unwrap(); + let res = is_not_null(&a).unwrap(); let expected = BooleanArray::from(vec![true, true, true, true]); assert_eq!(expected, res); - assert_eq!(None, res.data_ref().null_bitmap()); + assert!(res.nulls().is_none()); } #[test] @@ -1023,7 +778,7 @@ mod tests { let expected = BooleanArray::from(vec![false, true, false, true]); assert_eq!(expected, res); - assert_eq!(None, res.data_ref().null_bitmap()); + assert!(res.nulls().is_none()); } #[test] @@ -1049,12 +804,12 @@ mod tests { ]); let a = a.slice(8, 4); - let res = is_null(a.as_ref()).unwrap(); + let res = is_null(&a).unwrap(); let expected = BooleanArray::from(vec![false, true, false, true]); assert_eq!(expected, res); - assert_eq!(None, res.data_ref().null_bitmap()); + assert!(res.nulls().is_none()); } #[test] @@ -1066,7 +821,7 @@ mod tests { let expected = BooleanArray::from(vec![true, false, true, false]); assert_eq!(expected, res); - assert_eq!(None, res.data_ref().null_bitmap()); + assert!(res.nulls().is_none()); } #[test] @@ -1092,57 +847,35 @@ mod tests { ]); let a = a.slice(8, 4); - let res = is_not_null(a.as_ref()).unwrap(); + let res = is_not_null(&a).unwrap(); let expected = BooleanArray::from(vec![true, false, true, false]); assert_eq!(expected, res); - assert_eq!(None, res.data_ref().null_bitmap()); + assert!(res.nulls().is_none()); } #[test] - fn test_nullif_int_array() { - let a = Int32Array::from(vec![Some(15), None, Some(8), Some(1), Some(9)]); - let comp = - BooleanArray::from(vec![Some(false), None, Some(true), Some(false), None]); - let res = nullif(&a, &comp).unwrap(); + fn test_null_array_is_null() { + let a = NullArray::new(3); - let expected = Int32Array::from(vec![ - Some(15), - None, - None, // comp true, slot 2 turned into null - Some(1), - // Even though comp array / right is null, should still pass through original value - // comp true, slot 2 turned into null - Some(9), - ]); + let res = is_null(&a).unwrap(); + + let expected = BooleanArray::from(vec![true, true, true]); assert_eq!(expected, res); + assert!(res.nulls().is_none()); } #[test] - fn test_nullif_int_array_offset() { - let a = Int32Array::from(vec![None, Some(15), Some(8), Some(1), Some(9)]); - let a = a.slice(1, 3); // Some(15), Some(8), Some(1) - let a = a.as_any().downcast_ref::().unwrap(); - let comp = BooleanArray::from(vec![ - Some(false), - Some(false), - Some(false), - None, - Some(true), - Some(false), - None, - ]); - let comp = comp.slice(2, 3); // Some(false), None, Some(true) - let comp = comp.as_any().downcast_ref::().unwrap(); - let res = nullif(a, comp).unwrap(); - - let expected = Int32Array::from(vec![ - Some(15), // False => keep it - Some(8), // None => keep it - None, // true => None - ]); - assert_eq!(&expected, &res) + fn test_null_array_is_not_null() { + let a = NullArray::new(3); + + let res = is_not_null(&a).unwrap(); + + let expected = BooleanArray::from(vec![false, false, false]); + + assert_eq!(expected, res); + assert!(res.nulls().is_none()); } } diff --git a/arrow/src/ipc/compression/mod.rs b/arrow-arith/src/lib.rs similarity index 75% rename from arrow/src/ipc/compression/mod.rs rename to arrow-arith/src/lib.rs index 666fa6d86a27..2d5451e04dd2 100644 --- a/arrow/src/ipc/compression/mod.rs +++ b/arrow-arith/src/lib.rs @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. -#[cfg(feature = "ipc_compression")] -mod codec; -#[cfg(feature = "ipc_compression")] -pub(crate) use codec::CompressionCodec; +//! Arrow arithmetic and aggregation kernels -#[cfg(not(feature = "ipc_compression"))] -mod stub; -#[cfg(not(feature = "ipc_compression"))] -pub(crate) use stub::CompressionCodec; +pub mod aggregate; +#[doc(hidden)] // Kernels to be removed in a future release +pub mod arithmetic; +pub mod arity; +pub mod bitwise; +pub mod boolean; +pub mod numeric; +pub mod temporal; diff --git a/arrow-arith/src/numeric.rs b/arrow-arith/src/numeric.rs new file mode 100644 index 000000000000..b2c87bba5143 --- /dev/null +++ b/arrow-arith/src/numeric.rs @@ -0,0 +1,1520 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines numeric arithmetic kernels on [`PrimitiveArray`], such as [`add`] + +use std::cmp::Ordering; +use std::fmt::Formatter; +use std::sync::Arc; + +use arrow_array::cast::AsArray; +use arrow_array::timezone::Tz; +use arrow_array::types::*; +use arrow_array::*; +use arrow_buffer::ArrowNativeType; +use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit}; + +use crate::arity::{binary, try_binary}; + +/// Perform `lhs + rhs`, returning an error on overflow +pub fn add(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Add, lhs, rhs) +} + +/// Perform `lhs + rhs`, wrapping on overflow for [`DataType::is_integer`] +pub fn add_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::AddWrapping, lhs, rhs) +} + +/// Perform `lhs - rhs`, returning an error on overflow +pub fn sub(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Sub, lhs, rhs) +} + +/// Perform `lhs - rhs`, wrapping on overflow for [`DataType::is_integer`] +pub fn sub_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::SubWrapping, lhs, rhs) +} + +/// Perform `lhs * rhs`, returning an error on overflow +pub fn mul(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Mul, lhs, rhs) +} + +/// Perform `lhs * rhs`, wrapping on overflow for [`DataType::is_integer`] +pub fn mul_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::MulWrapping, lhs, rhs) +} + +/// Perform `lhs / rhs` +/// +/// Overflow or division by zero will result in an error, with exception to +/// floating point numbers, which instead follow the IEEE 754 rules +pub fn div(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Div, lhs, rhs) +} + +/// Perform `lhs % rhs` +/// +/// Overflow or division by zero will result in an error, with exception to +/// floating point numbers, which instead follow the IEEE 754 rules +pub fn rem(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + arithmetic_op(Op::Rem, lhs, rhs) +} + +macro_rules! neg_checked { + ($t:ty, $a:ident) => {{ + let array = $a + .as_primitive::<$t>() + .try_unary::<_, $t, _>(|x| x.neg_checked())?; + Ok(Arc::new(array)) + }}; +} + +macro_rules! neg_wrapping { + ($t:ty, $a:ident) => {{ + let array = $a.as_primitive::<$t>().unary::<_, $t>(|x| x.neg_wrapping()); + Ok(Arc::new(array)) + }}; +} + +/// Negates each element of `array`, returning an error on overflow +/// +/// Note: negation of unsigned arrays is not supported and will return in an error, +/// for wrapping unsigned negation consider using [`neg_wrapping`][neg_wrapping()] +pub fn neg(array: &dyn Array) -> Result { + use DataType::*; + use IntervalUnit::*; + use TimeUnit::*; + + match array.data_type() { + Int8 => neg_checked!(Int8Type, array), + Int16 => neg_checked!(Int16Type, array), + Int32 => neg_checked!(Int32Type, array), + Int64 => neg_checked!(Int64Type, array), + Float16 => neg_wrapping!(Float16Type, array), + Float32 => neg_wrapping!(Float32Type, array), + Float64 => neg_wrapping!(Float64Type, array), + Decimal128(p, s) => { + let a = array + .as_primitive::() + .try_unary::<_, Decimal128Type, _>(|x| x.neg_checked())?; + + Ok(Arc::new(a.with_precision_and_scale(*p, *s)?)) + } + Decimal256(p, s) => { + let a = array + .as_primitive::() + .try_unary::<_, Decimal256Type, _>(|x| x.neg_checked())?; + + Ok(Arc::new(a.with_precision_and_scale(*p, *s)?)) + } + Duration(Second) => neg_checked!(DurationSecondType, array), + Duration(Millisecond) => neg_checked!(DurationMillisecondType, array), + Duration(Microsecond) => neg_checked!(DurationMicrosecondType, array), + Duration(Nanosecond) => neg_checked!(DurationNanosecondType, array), + Interval(YearMonth) => neg_checked!(IntervalYearMonthType, array), + Interval(DayTime) => { + let a = array + .as_primitive::() + .try_unary::<_, IntervalDayTimeType, ArrowError>(|x| { + let (days, ms) = IntervalDayTimeType::to_parts(x); + Ok(IntervalDayTimeType::make_value( + days.neg_checked()?, + ms.neg_checked()?, + )) + })?; + Ok(Arc::new(a)) + } + Interval(MonthDayNano) => { + let a = array + .as_primitive::() + .try_unary::<_, IntervalMonthDayNanoType, ArrowError>(|x| { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(x); + Ok(IntervalMonthDayNanoType::make_value( + months.neg_checked()?, + days.neg_checked()?, + nanos.neg_checked()?, + )) + })?; + Ok(Arc::new(a)) + } + t => Err(ArrowError::InvalidArgumentError(format!( + "Invalid arithmetic operation: !{t}" + ))), + } +} + +/// Negates each element of `array`, wrapping on overflow for [`DataType::is_integer`] +pub fn neg_wrapping(array: &dyn Array) -> Result { + downcast_integer! { + array.data_type() => (neg_wrapping, array), + _ => neg(array), + } +} + +/// An enumeration of arithmetic operations +/// +/// This allows sharing the type dispatch logic across the various kernels +#[derive(Debug, Copy, Clone)] +enum Op { + AddWrapping, + Add, + SubWrapping, + Sub, + MulWrapping, + Mul, + Div, + Rem, +} + +impl std::fmt::Display for Op { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Op::AddWrapping | Op::Add => write!(f, "+"), + Op::SubWrapping | Op::Sub => write!(f, "-"), + Op::MulWrapping | Op::Mul => write!(f, "*"), + Op::Div => write!(f, "/"), + Op::Rem => write!(f, "%"), + } + } +} + +impl Op { + fn commutative(&self) -> bool { + matches!(self, Self::Add | Self::AddWrapping) + } +} + +/// Dispatch the given `op` to the appropriate specialized kernel +fn arithmetic_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + use DataType::*; + use IntervalUnit::*; + use TimeUnit::*; + + macro_rules! integer_helper { + ($t:ty, $op:ident, $l:ident, $l_scalar:ident, $r:ident, $r_scalar:ident) => { + integer_op::<$t>($op, $l, $l_scalar, $r, $r_scalar) + }; + } + + let (l, l_scalar) = lhs.get(); + let (r, r_scalar) = rhs.get(); + downcast_integer! { + l.data_type(), r.data_type() => (integer_helper, op, l, l_scalar, r, r_scalar), + (Float16, Float16) => float_op::(op, l, l_scalar, r, r_scalar), + (Float32, Float32) => float_op::(op, l, l_scalar, r, r_scalar), + (Float64, Float64) => float_op::(op, l, l_scalar, r, r_scalar), + (Timestamp(Second, _), _) => timestamp_op::(op, l, l_scalar, r, r_scalar), + (Timestamp(Millisecond, _), _) => timestamp_op::(op, l, l_scalar, r, r_scalar), + (Timestamp(Microsecond, _), _) => timestamp_op::(op, l, l_scalar, r, r_scalar), + (Timestamp(Nanosecond, _), _) => timestamp_op::(op, l, l_scalar, r, r_scalar), + (Duration(Second), Duration(Second)) => duration_op::(op, l, l_scalar, r, r_scalar), + (Duration(Millisecond), Duration(Millisecond)) => duration_op::(op, l, l_scalar, r, r_scalar), + (Duration(Microsecond), Duration(Microsecond)) => duration_op::(op, l, l_scalar, r, r_scalar), + (Duration(Nanosecond), Duration(Nanosecond)) => duration_op::(op, l, l_scalar, r, r_scalar), + (Interval(YearMonth), Interval(YearMonth)) => interval_op::(op, l, l_scalar, r, r_scalar), + (Interval(DayTime), Interval(DayTime)) => interval_op::(op, l, l_scalar, r, r_scalar), + (Interval(MonthDayNano), Interval(MonthDayNano)) => interval_op::(op, l, l_scalar, r, r_scalar), + (Date32, _) => date_op::(op, l, l_scalar, r, r_scalar), + (Date64, _) => date_op::(op, l, l_scalar, r, r_scalar), + (Decimal128(_, _), Decimal128(_, _)) => decimal_op::(op, l, l_scalar, r, r_scalar), + (Decimal256(_, _), Decimal256(_, _)) => decimal_op::(op, l, l_scalar, r, r_scalar), + (l_t, r_t) => match (l_t, r_t) { + (Duration(_) | Interval(_), Date32 | Date64 | Timestamp(_, _)) if op.commutative() => { + arithmetic_op(op, rhs, lhs) + } + _ => Err(ArrowError::InvalidArgumentError( + format!("Invalid arithmetic operation: {l_t} {op} {r_t}") + )) + } + } +} + +/// Perform an infallible binary operation on potentially scalar inputs +macro_rules! op { + ($l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => { + match ($l_s, $r_s) { + (true, true) | (false, false) => binary($l, $r, |$l, $r| $op)?, + (true, false) => match ($l.null_count() == 0).then(|| $l.value(0)) { + None => PrimitiveArray::new_null($r.len()), + Some($l) => $r.unary(|$r| $op), + }, + (false, true) => match ($r.null_count() == 0).then(|| $r.value(0)) { + None => PrimitiveArray::new_null($l.len()), + Some($r) => $l.unary(|$l| $op), + }, + } + }; +} + +/// Same as `op` but with a type hint for the returned array +macro_rules! op_ref { + ($t:ty, $l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {{ + let array: PrimitiveArray<$t> = op!($l, $l_s, $r, $r_s, $op); + Arc::new(array) + }}; +} + +/// Perform a fallible binary operation on potentially scalar inputs +macro_rules! try_op { + ($l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => { + match ($l_s, $r_s) { + (true, true) | (false, false) => try_binary($l, $r, |$l, $r| $op)?, + (true, false) => match ($l.null_count() == 0).then(|| $l.value(0)) { + None => PrimitiveArray::new_null($r.len()), + Some($l) => $r.try_unary(|$r| $op)?, + }, + (false, true) => match ($r.null_count() == 0).then(|| $r.value(0)) { + None => PrimitiveArray::new_null($l.len()), + Some($r) => $l.try_unary(|$l| $op)?, + }, + } + }; +} + +/// Same as `try_op` but with a type hint for the returned array +macro_rules! try_op_ref { + ($t:ty, $l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {{ + let array: PrimitiveArray<$t> = try_op!($l, $l_s, $r, $r_s, $op); + Arc::new(array) + }}; +} + +/// Perform an arithmetic operation on integers +fn integer_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + let array: PrimitiveArray = match op { + Op::AddWrapping => op!(l, l_s, r, r_s, l.add_wrapping(r)), + Op::Add => try_op!(l, l_s, r, r_s, l.add_checked(r)), + Op::SubWrapping => op!(l, l_s, r, r_s, l.sub_wrapping(r)), + Op::Sub => try_op!(l, l_s, r, r_s, l.sub_checked(r)), + Op::MulWrapping => op!(l, l_s, r, r_s, l.mul_wrapping(r)), + Op::Mul => try_op!(l, l_s, r, r_s, l.mul_checked(r)), + Op::Div => try_op!(l, l_s, r, r_s, l.div_checked(r)), + Op::Rem => try_op!(l, l_s, r, r_s, l.mod_checked(r)), + }; + Ok(Arc::new(array)) +} + +/// Perform an arithmetic operation on floats +fn float_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + let array: PrimitiveArray = match op { + Op::AddWrapping | Op::Add => op!(l, l_s, r, r_s, l.add_wrapping(r)), + Op::SubWrapping | Op::Sub => op!(l, l_s, r, r_s, l.sub_wrapping(r)), + Op::MulWrapping | Op::Mul => op!(l, l_s, r, r_s, l.mul_wrapping(r)), + Op::Div => op!(l, l_s, r, r_s, l.div_wrapping(r)), + Op::Rem => op!(l, l_s, r, r_s, l.mod_wrapping(r)), + }; + Ok(Arc::new(array)) +} + +/// Arithmetic trait for timestamp arrays +trait TimestampOp: ArrowTimestampType { + type Duration: ArrowPrimitiveType; + + fn add_year_month(timestamp: i64, delta: i32, tz: Tz) -> Option; + fn add_day_time(timestamp: i64, delta: i64, tz: Tz) -> Option; + fn add_month_day_nano(timestamp: i64, delta: i128, tz: Tz) -> Option; + + fn sub_year_month(timestamp: i64, delta: i32, tz: Tz) -> Option; + fn sub_day_time(timestamp: i64, delta: i64, tz: Tz) -> Option; + fn sub_month_day_nano(timestamp: i64, delta: i128, tz: Tz) -> Option; +} + +macro_rules! timestamp { + ($t:ty, $d:ty) => { + impl TimestampOp for $t { + type Duration = $d; + + fn add_year_month(left: i64, right: i32, tz: Tz) -> Option { + Self::add_year_months(left, right, tz) + } + + fn add_day_time(left: i64, right: i64, tz: Tz) -> Option { + Self::add_day_time(left, right, tz) + } + + fn add_month_day_nano(left: i64, right: i128, tz: Tz) -> Option { + Self::add_month_day_nano(left, right, tz) + } + + fn sub_year_month(left: i64, right: i32, tz: Tz) -> Option { + Self::subtract_year_months(left, right, tz) + } + + fn sub_day_time(left: i64, right: i64, tz: Tz) -> Option { + Self::subtract_day_time(left, right, tz) + } + + fn sub_month_day_nano(left: i64, right: i128, tz: Tz) -> Option { + Self::subtract_month_day_nano(left, right, tz) + } + } + }; +} +timestamp!(TimestampSecondType, DurationSecondType); +timestamp!(TimestampMillisecondType, DurationMillisecondType); +timestamp!(TimestampMicrosecondType, DurationMicrosecondType); +timestamp!(TimestampNanosecondType, DurationNanosecondType); + +/// Perform arithmetic operation on a timestamp array +fn timestamp_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + use DataType::*; + use IntervalUnit::*; + + let l = l.as_primitive::(); + let l_tz: Tz = l.timezone().unwrap_or("+00:00").parse()?; + + let array: PrimitiveArray = match (op, r.data_type()) { + (Op::Sub | Op::SubWrapping, Timestamp(unit, _)) if unit == &T::UNIT => { + let r = r.as_primitive::(); + return Ok(try_op_ref!(T::Duration, l, l_s, r, r_s, l.sub_checked(r))); + } + + (Op::Add | Op::AddWrapping, Duration(unit)) if unit == &T::UNIT => { + let r = r.as_primitive::(); + try_op!(l, l_s, r, r_s, l.add_checked(r)) + } + (Op::Sub | Op::SubWrapping, Duration(unit)) if unit == &T::UNIT => { + let r = r.as_primitive::(); + try_op!(l, l_s, r, r_s, l.sub_checked(r)) + } + + (Op::Add | Op::AddWrapping, Interval(YearMonth)) => { + let r = r.as_primitive::(); + try_op!( + l, + l_s, + r, + r_s, + T::add_year_month(l, r, l_tz).ok_or(ArrowError::ComputeError( + "Timestamp out of range".to_string() + )) + ) + } + (Op::Sub | Op::SubWrapping, Interval(YearMonth)) => { + let r = r.as_primitive::(); + try_op!( + l, + l_s, + r, + r_s, + T::sub_year_month(l, r, l_tz).ok_or(ArrowError::ComputeError( + "Timestamp out of range".to_string() + )) + ) + } + + (Op::Add | Op::AddWrapping, Interval(DayTime)) => { + let r = r.as_primitive::(); + try_op!( + l, + l_s, + r, + r_s, + T::add_day_time(l, r, l_tz).ok_or(ArrowError::ComputeError( + "Timestamp out of range".to_string() + )) + ) + } + (Op::Sub | Op::SubWrapping, Interval(DayTime)) => { + let r = r.as_primitive::(); + try_op!( + l, + l_s, + r, + r_s, + T::sub_day_time(l, r, l_tz).ok_or(ArrowError::ComputeError( + "Timestamp out of range".to_string() + )) + ) + } + + (Op::Add | Op::AddWrapping, Interval(MonthDayNano)) => { + let r = r.as_primitive::(); + try_op!( + l, + l_s, + r, + r_s, + T::add_month_day_nano(l, r, l_tz).ok_or(ArrowError::ComputeError( + "Timestamp out of range".to_string() + )) + ) + } + (Op::Sub | Op::SubWrapping, Interval(MonthDayNano)) => { + let r = r.as_primitive::(); + try_op!( + l, + l_s, + r, + r_s, + T::sub_month_day_nano(l, r, l_tz).ok_or(ArrowError::ComputeError( + "Timestamp out of range".to_string() + )) + ) + } + _ => { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid timestamp arithmetic operation: {} {op} {}", + l.data_type(), + r.data_type() + ))) + } + }; + Ok(Arc::new(array.with_timezone_opt(l.timezone()))) +} + +/// Arithmetic trait for date arrays +/// +/// Note: these should be fallible (#4456) +trait DateOp: ArrowTemporalType { + fn add_year_month(timestamp: Self::Native, delta: i32) -> Self::Native; + fn add_day_time(timestamp: Self::Native, delta: i64) -> Self::Native; + fn add_month_day_nano(timestamp: Self::Native, delta: i128) -> Self::Native; + + fn sub_year_month(timestamp: Self::Native, delta: i32) -> Self::Native; + fn sub_day_time(timestamp: Self::Native, delta: i64) -> Self::Native; + fn sub_month_day_nano(timestamp: Self::Native, delta: i128) -> Self::Native; +} + +macro_rules! date { + ($t:ty) => { + impl DateOp for $t { + fn add_year_month(left: Self::Native, right: i32) -> Self::Native { + Self::add_year_months(left, right) + } + + fn add_day_time(left: Self::Native, right: i64) -> Self::Native { + Self::add_day_time(left, right) + } + + fn add_month_day_nano(left: Self::Native, right: i128) -> Self::Native { + Self::add_month_day_nano(left, right) + } + + fn sub_year_month(left: Self::Native, right: i32) -> Self::Native { + Self::subtract_year_months(left, right) + } + + fn sub_day_time(left: Self::Native, right: i64) -> Self::Native { + Self::subtract_day_time(left, right) + } + + fn sub_month_day_nano(left: Self::Native, right: i128) -> Self::Native { + Self::subtract_month_day_nano(left, right) + } + } + }; +} +date!(Date32Type); +date!(Date64Type); + +/// Arithmetic trait for interval arrays +trait IntervalOp: ArrowPrimitiveType { + fn add(left: Self::Native, right: Self::Native) -> Result; + fn sub(left: Self::Native, right: Self::Native) -> Result; +} + +impl IntervalOp for IntervalYearMonthType { + fn add(left: Self::Native, right: Self::Native) -> Result { + left.add_checked(right) + } + + fn sub(left: Self::Native, right: Self::Native) -> Result { + left.sub_checked(right) + } +} + +impl IntervalOp for IntervalDayTimeType { + fn add(left: Self::Native, right: Self::Native) -> Result { + let (l_days, l_ms) = Self::to_parts(left); + let (r_days, r_ms) = Self::to_parts(right); + let days = l_days.add_checked(r_days)?; + let ms = l_ms.add_checked(r_ms)?; + Ok(Self::make_value(days, ms)) + } + + fn sub(left: Self::Native, right: Self::Native) -> Result { + let (l_days, l_ms) = Self::to_parts(left); + let (r_days, r_ms) = Self::to_parts(right); + let days = l_days.sub_checked(r_days)?; + let ms = l_ms.sub_checked(r_ms)?; + Ok(Self::make_value(days, ms)) + } +} + +impl IntervalOp for IntervalMonthDayNanoType { + fn add(left: Self::Native, right: Self::Native) -> Result { + let (l_months, l_days, l_nanos) = Self::to_parts(left); + let (r_months, r_days, r_nanos) = Self::to_parts(right); + let months = l_months.add_checked(r_months)?; + let days = l_days.add_checked(r_days)?; + let nanos = l_nanos.add_checked(r_nanos)?; + Ok(Self::make_value(months, days, nanos)) + } + + fn sub(left: Self::Native, right: Self::Native) -> Result { + let (l_months, l_days, l_nanos) = Self::to_parts(left); + let (r_months, r_days, r_nanos) = Self::to_parts(right); + let months = l_months.sub_checked(r_months)?; + let days = l_days.sub_checked(r_days)?; + let nanos = l_nanos.sub_checked(r_nanos)?; + Ok(Self::make_value(months, days, nanos)) + } +} + +/// Perform arithmetic operation on an interval array +fn interval_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + match op { + Op::Add | Op::AddWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, T::add(l, r))), + Op::Sub | Op::SubWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, T::sub(l, r))), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Invalid interval arithmetic operation: {} {op} {}", + l.data_type(), + r.data_type() + ))), + } +} + +fn duration_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + match op { + Op::Add | Op::AddWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, l.add_checked(r))), + Op::Sub | Op::SubWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, l.sub_checked(r))), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Invalid duration arithmetic operation: {} {op} {}", + l.data_type(), + r.data_type() + ))), + } +} + +/// Perform arithmetic operation on a date array +fn date_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + use DataType::*; + use IntervalUnit::*; + + const NUM_SECONDS_IN_DAY: i64 = 60 * 60 * 24; + + let r_t = r.data_type(); + match (T::DATA_TYPE, op, r_t) { + (Date32, Op::Sub | Op::SubWrapping, Date32) => { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + return Ok(op_ref!( + DurationSecondType, + l, + l_s, + r, + r_s, + ((l as i64) - (r as i64)) * NUM_SECONDS_IN_DAY + )); + } + (Date64, Op::Sub | Op::SubWrapping, Date64) => { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + let result = try_op_ref!(DurationMillisecondType, l, l_s, r, r_s, l.sub_checked(r)); + return Ok(result); + } + _ => {} + } + + let l = l.as_primitive::(); + match (op, r_t) { + (Op::Add | Op::AddWrapping, Interval(YearMonth)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::add_year_month(l, r))) + } + (Op::Sub | Op::SubWrapping, Interval(YearMonth)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::sub_year_month(l, r))) + } + + (Op::Add | Op::AddWrapping, Interval(DayTime)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::add_day_time(l, r))) + } + (Op::Sub | Op::SubWrapping, Interval(DayTime)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::sub_day_time(l, r))) + } + + (Op::Add | Op::AddWrapping, Interval(MonthDayNano)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::add_month_day_nano(l, r))) + } + (Op::Sub | Op::SubWrapping, Interval(MonthDayNano)) => { + let r = r.as_primitive::(); + Ok(op_ref!(T, l, l_s, r, r_s, T::sub_month_day_nano(l, r))) + } + + _ => Err(ArrowError::InvalidArgumentError(format!( + "Invalid date arithmetic operation: {} {op} {}", + l.data_type(), + r.data_type() + ))), + } +} + +/// Perform arithmetic operation on decimal arrays +fn decimal_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + let l = l.as_primitive::(); + let r = r.as_primitive::(); + + let (p1, s1, p2, s2) = match (l.data_type(), r.data_type()) { + (DataType::Decimal128(p1, s1), DataType::Decimal128(p2, s2)) => (p1, s1, p2, s2), + (DataType::Decimal256(p1, s1), DataType::Decimal256(p2, s2)) => (p1, s1, p2, s2), + _ => unreachable!(), + }; + + // Follow the Hive decimal arithmetic rules + // https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + let array: PrimitiveArray = match op { + Op::Add | Op::AddWrapping | Op::Sub | Op::SubWrapping => { + // max(s1, s2) + let result_scale = *s1.max(s2); + + // max(s1, s2) + max(p1-s1, p2-s2) + 1 + let result_precision = + (result_scale.saturating_add((*p1 as i8 - s1).max(*p2 as i8 - s2)) as u8) + .saturating_add(1) + .min(T::MAX_PRECISION); + + let l_mul = T::Native::usize_as(10).pow_checked((result_scale - s1) as _)?; + let r_mul = T::Native::usize_as(10).pow_checked((result_scale - s2) as _)?; + + match op { + Op::Add | Op::AddWrapping => { + try_op!( + l, + l_s, + r, + r_s, + l.mul_checked(l_mul)?.add_checked(r.mul_checked(r_mul)?) + ) + } + Op::Sub | Op::SubWrapping => { + try_op!( + l, + l_s, + r, + r_s, + l.mul_checked(l_mul)?.sub_checked(r.mul_checked(r_mul)?) + ) + } + _ => unreachable!(), + } + .with_precision_and_scale(result_precision, result_scale)? + } + Op::Mul | Op::MulWrapping => { + let result_precision = p1.saturating_add(p2 + 1).min(T::MAX_PRECISION); + let result_scale = s1.saturating_add(*s2); + if result_scale > T::MAX_SCALE { + // SQL standard says that if the resulting scale of a multiply operation goes + // beyond the maximum, rounding is not acceptable and thus an error occurs + return Err(ArrowError::InvalidArgumentError(format!( + "Output scale of {} {op} {} would exceed max scale of {}", + l.data_type(), + r.data_type(), + T::MAX_SCALE + ))); + } + + try_op!(l, l_s, r, r_s, l.mul_checked(r)) + .with_precision_and_scale(result_precision, result_scale)? + } + + Op::Div => { + // Follow postgres and MySQL adding a fixed scale increment of 4 + // s1 + 4 + let result_scale = s1.saturating_add(4).min(T::MAX_SCALE); + let mul_pow = result_scale - s1 + s2; + + // p1 - s1 + s2 + result_scale + let result_precision = (mul_pow.saturating_add(*p1 as i8) as u8).min(T::MAX_PRECISION); + + let (l_mul, r_mul) = match mul_pow.cmp(&0) { + Ordering::Greater => ( + T::Native::usize_as(10).pow_checked(mul_pow as _)?, + T::Native::ONE, + ), + Ordering::Equal => (T::Native::ONE, T::Native::ONE), + Ordering::Less => ( + T::Native::ONE, + T::Native::usize_as(10).pow_checked(mul_pow.neg_wrapping() as _)?, + ), + }; + + try_op!( + l, + l_s, + r, + r_s, + l.mul_checked(l_mul)?.div_checked(r.mul_checked(r_mul)?) + ) + .with_precision_and_scale(result_precision, result_scale)? + } + + Op::Rem => { + // max(s1, s2) + let result_scale = *s1.max(s2); + // min(p1-s1, p2 -s2) + max( s1,s2 ) + let result_precision = + (result_scale.saturating_add((*p1 as i8 - s1).min(*p2 as i8 - s2)) as u8) + .min(T::MAX_PRECISION); + + let l_mul = T::Native::usize_as(10).pow_wrapping((result_scale - s1) as _); + let r_mul = T::Native::usize_as(10).pow_wrapping((result_scale - s2) as _); + + try_op!( + l, + l_s, + r, + r_s, + l.mul_checked(l_mul)?.mod_checked(r.mul_checked(r_mul)?) + ) + .with_precision_and_scale(result_precision, result_scale)? + } + }; + + Ok(Arc::new(array)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::temporal_conversions::{as_date, as_datetime}; + use arrow_buffer::{i256, ScalarBuffer}; + use chrono::{DateTime, NaiveDate}; + + fn test_neg_primitive( + input: &[T::Native], + out: Result<&[T::Native], &str>, + ) { + let a = PrimitiveArray::::new(ScalarBuffer::from(input.to_vec()), None); + match out { + Ok(expected) => { + let result = neg(&a).unwrap(); + assert_eq!(result.as_primitive::().values(), expected); + } + Err(e) => { + let err = neg(&a).unwrap_err().to_string(); + assert_eq!(e, err); + } + } + } + + #[test] + fn test_neg() { + let input = &[1, -5, 2, 693, 3929]; + let output = &[-1, 5, -2, -693, -3929]; + test_neg_primitive::(input, Ok(output)); + + let input = &[1, -5, 2, 693, 3929]; + let output = &[-1, 5, -2, -693, -3929]; + test_neg_primitive::(input, Ok(output)); + test_neg_primitive::(input, Ok(output)); + test_neg_primitive::(input, Ok(output)); + test_neg_primitive::(input, Ok(output)); + test_neg_primitive::(input, Ok(output)); + + let input = &[f32::MAX, f32::MIN, f32::INFINITY, 1.3, 0.5]; + let output = &[f32::MIN, f32::MAX, f32::NEG_INFINITY, -1.3, -0.5]; + test_neg_primitive::(input, Ok(output)); + + test_neg_primitive::( + &[i32::MIN], + Err("Compute error: Overflow happened on: -2147483648"), + ); + test_neg_primitive::( + &[i64::MIN], + Err("Compute error: Overflow happened on: -9223372036854775808"), + ); + test_neg_primitive::( + &[i64::MIN], + Err("Compute error: Overflow happened on: -9223372036854775808"), + ); + + let r = neg_wrapping(&Int32Array::from(vec![i32::MIN])).unwrap(); + assert_eq!(r.as_primitive::().value(0), i32::MIN); + + let r = neg_wrapping(&Int64Array::from(vec![i64::MIN])).unwrap(); + assert_eq!(r.as_primitive::().value(0), i64::MIN); + + let err = neg_wrapping(&DurationSecondArray::from(vec![i64::MIN])) + .unwrap_err() + .to_string(); + + assert_eq!( + err, + "Compute error: Overflow happened on: -9223372036854775808" + ); + + let a = Decimal128Array::from(vec![1, 3, -44, 2, 4]) + .with_precision_and_scale(9, 6) + .unwrap(); + + let r = neg(&a).unwrap(); + assert_eq!(r.data_type(), a.data_type()); + assert_eq!( + r.as_primitive::().values(), + &[-1, -3, 44, -2, -4] + ); + + let a = Decimal256Array::from(vec![ + i256::from_i128(342), + i256::from_i128(-4949), + i256::from_i128(3), + ]) + .with_precision_and_scale(9, 6) + .unwrap(); + + let r = neg(&a).unwrap(); + assert_eq!(r.data_type(), a.data_type()); + assert_eq!( + r.as_primitive::().values(), + &[ + i256::from_i128(-342), + i256::from_i128(4949), + i256::from_i128(-3), + ] + ); + + let a = IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(2, 4), + IntervalYearMonthType::make_value(2, -4), + IntervalYearMonthType::make_value(-3, -5), + ]); + let r = neg(&a).unwrap(); + assert_eq!( + r.as_primitive::().values(), + &[ + IntervalYearMonthType::make_value(-2, -4), + IntervalYearMonthType::make_value(-2, 4), + IntervalYearMonthType::make_value(3, 5), + ] + ); + + let a = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(2, 4), + IntervalDayTimeType::make_value(2, -4), + IntervalDayTimeType::make_value(-3, -5), + ]); + let r = neg(&a).unwrap(); + assert_eq!( + r.as_primitive::().values(), + &[ + IntervalDayTimeType::make_value(-2, -4), + IntervalDayTimeType::make_value(-2, 4), + IntervalDayTimeType::make_value(3, 5), + ] + ); + + let a = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(2, 4, 5953394), + IntervalMonthDayNanoType::make_value(2, -4, -45839), + IntervalMonthDayNanoType::make_value(-3, -5, 6944), + ]); + let r = neg(&a).unwrap(); + assert_eq!( + r.as_primitive::().values(), + &[ + IntervalMonthDayNanoType::make_value(-2, -4, -5953394), + IntervalMonthDayNanoType::make_value(-2, 4, 45839), + IntervalMonthDayNanoType::make_value(3, 5, -6944), + ] + ); + } + + #[test] + fn test_integer() { + let a = Int32Array::from(vec![4, 3, 5, -6, 100]); + let b = Int32Array::from(vec![6, 2, 5, -7, 3]); + let result = add(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &Int32Array::from(vec![10, 5, 10, -13, 103]) + ); + let result = sub(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &Int32Array::from(vec![-2, 1, 0, 1, 97])); + let result = div(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &Int32Array::from(vec![0, 1, 1, 0, 33])); + let result = mul(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &Int32Array::from(vec![24, 6, 25, 42, 300])); + let result = rem(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &Int32Array::from(vec![4, 1, 0, -6, 1])); + + let a = Int8Array::from(vec![Some(2), None, Some(45)]); + let b = Int8Array::from(vec![Some(5), Some(3), None]); + let result = add(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &Int8Array::from(vec![Some(7), None, None])); + + let a = UInt8Array::from(vec![56, 5, 3]); + let b = UInt8Array::from(vec![200, 2, 5]); + let err = add(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Compute error: Overflow happened on: 56 + 200"); + let result = add_wrapping(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &UInt8Array::from(vec![0, 7, 8])); + + let a = UInt8Array::from(vec![34, 5, 3]); + let b = UInt8Array::from(vec![200, 2, 5]); + let err = sub(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Compute error: Overflow happened on: 34 - 200"); + let result = sub_wrapping(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &UInt8Array::from(vec![90, 3, 254])); + + let a = UInt8Array::from(vec![34, 5, 3]); + let b = UInt8Array::from(vec![200, 2, 5]); + let err = mul(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Compute error: Overflow happened on: 34 * 200"); + let result = mul_wrapping(&a, &b).unwrap(); + assert_eq!(result.as_ref(), &UInt8Array::from(vec![144, 10, 15])); + + let a = Int16Array::from(vec![i16::MIN]); + let b = Int16Array::from(vec![-1]); + let err = div(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Compute error: Overflow happened on: -32768 / -1"); + + let a = Int16Array::from(vec![21]); + let b = Int16Array::from(vec![0]); + let err = div(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Divide by zero error"); + + let a = Int16Array::from(vec![21]); + let b = Int16Array::from(vec![0]); + let err = rem(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Divide by zero error"); + } + + #[test] + fn test_float() { + let a = Float32Array::from(vec![1., f32::MAX, 6., -4., -1., 0.]); + let b = Float32Array::from(vec![1., f32::MAX, f32::MAX, -3., 45., 0.]); + let result = add(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &Float32Array::from(vec![2., f32::INFINITY, f32::MAX, -7., 44.0, 0.]) + ); + + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &Float32Array::from(vec![0., 0., f32::MIN, -1., -46., 0.]) + ); + + let result = mul(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &Float32Array::from(vec![1., f32::INFINITY, f32::INFINITY, 12., -45., 0.]) + ); + + let result = div(&a, &b).unwrap(); + let r = result.as_primitive::(); + assert_eq!(r.value(0), 1.); + assert_eq!(r.value(1), 1.); + assert!(r.value(2) < f32::EPSILON); + assert_eq!(r.value(3), -4. / -3.); + assert!(r.value(5).is_nan()); + + let result = rem(&a, &b).unwrap(); + let r = result.as_primitive::(); + assert_eq!(&r.values()[..5], &[0., 0., 6., -1., -1.]); + assert!(r.value(5).is_nan()); + } + + #[test] + fn test_decimal() { + // 0.015 7.842 -0.577 0.334 -0.078 0.003 + let a = Decimal128Array::from(vec![15, 0, -577, 334, -78, 3]) + .with_precision_and_scale(12, 3) + .unwrap(); + + // 5.4 0 -35.6 0.3 0.6 7.45 + let b = Decimal128Array::from(vec![54, 34, -356, 3, 6, 745]) + .with_precision_and_scale(12, 1) + .unwrap(); + + let result = add(&a, &b).unwrap(); + assert_eq!(result.data_type(), &DataType::Decimal128(15, 3)); + assert_eq!( + result.as_primitive::().values(), + &[5415, 3400, -36177, 634, 522, 74503] + ); + + let result = sub(&a, &b).unwrap(); + assert_eq!(result.data_type(), &DataType::Decimal128(15, 3)); + assert_eq!( + result.as_primitive::().values(), + &[-5385, -3400, 35023, 34, -678, -74497] + ); + + let result = mul(&a, &b).unwrap(); + assert_eq!(result.data_type(), &DataType::Decimal128(25, 4)); + assert_eq!( + result.as_primitive::().values(), + &[810, 0, 205412, 1002, -468, 2235] + ); + + let result = div(&a, &b).unwrap(); + assert_eq!(result.data_type(), &DataType::Decimal128(17, 7)); + assert_eq!( + result.as_primitive::().values(), + &[27777, 0, 162078, 11133333, -1300000, 402] + ); + + let result = rem(&a, &b).unwrap(); + assert_eq!(result.data_type(), &DataType::Decimal128(12, 3)); + assert_eq!( + result.as_primitive::().values(), + &[15, 0, -577, 34, -78, 3] + ); + + let a = Decimal128Array::from(vec![1]) + .with_precision_and_scale(3, 3) + .unwrap(); + let b = Decimal128Array::from(vec![1]) + .with_precision_and_scale(37, 37) + .unwrap(); + let err = mul(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Invalid argument error: Output scale of Decimal128(3, 3) * Decimal128(37, 37) would exceed max scale of 38"); + + let a = Decimal128Array::from(vec![1]) + .with_precision_and_scale(3, -2) + .unwrap(); + let err = add(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Compute error: Overflow happened on: 10 ^ 39"); + + let a = Decimal128Array::from(vec![10]) + .with_precision_and_scale(3, -1) + .unwrap(); + let err = add(&a, &b).unwrap_err().to_string(); + assert_eq!( + err, + "Compute error: Overflow happened on: 10 * 100000000000000000000000000000000000000" + ); + + let b = Decimal128Array::from(vec![0]) + .with_precision_and_scale(1, 1) + .unwrap(); + let err = div(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Divide by zero error"); + let err = rem(&a, &b).unwrap_err().to_string(); + assert_eq!(err, "Divide by zero error"); + } + + fn test_timestamp_impl() { + let a = PrimitiveArray::::new(vec![2000000, 434030324, 53943340].into(), None); + let b = PrimitiveArray::::new(vec![329593, 59349, 694994].into(), None); + + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_primitive::().values(), + &[1670407, 433970975, 53248346] + ); + + let r2 = add(&b, &result.as_ref()).unwrap(); + assert_eq!(r2.as_ref(), &a); + + let r3 = add(&result.as_ref(), &b).unwrap(); + assert_eq!(r3.as_ref(), &a); + + let format_array = |x: &dyn Array| -> Vec { + x.as_primitive::() + .values() + .into_iter() + .map(|x| as_datetime::(*x).unwrap().to_string()) + .collect() + }; + + let values = vec![ + "1970-01-01T00:00:00Z", + "2010-04-01T04:00:20Z", + "1960-01-30T04:23:20Z", + ] + .into_iter() + .map(|x| T::make_value(DateTime::parse_from_rfc3339(x).unwrap().naive_utc()).unwrap()) + .collect(); + + let a = PrimitiveArray::::new(values, None); + let b = IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(5, 34), + IntervalYearMonthType::make_value(-2, 4), + IntervalYearMonthType::make_value(7, -4), + ]); + let r4 = add(&a, &b).unwrap(); + assert_eq!( + &format_array(r4.as_ref()), + &[ + "1977-11-01 00:00:00".to_string(), + "2008-08-01 04:00:20".to_string(), + "1966-09-30 04:23:20".to_string() + ] + ); + + let r5 = sub(&r4, &b).unwrap(); + assert_eq!(r5.as_ref(), &a); + + let b = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(5, 454000), + IntervalDayTimeType::make_value(-34, 0), + IntervalDayTimeType::make_value(7, -4000), + ]); + let r6 = add(&a, &b).unwrap(); + assert_eq!( + &format_array(r6.as_ref()), + &[ + "1970-01-06 00:07:34".to_string(), + "2010-02-26 04:00:20".to_string(), + "1960-02-06 04:23:16".to_string() + ] + ); + + let r7 = sub(&r6, &b).unwrap(); + assert_eq!(r7.as_ref(), &a); + + let b = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000), + IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000), + IntervalMonthDayNanoType::make_value(5, 2, 493_000_000_000), + ]); + let r8 = add(&a, &b).unwrap(); + assert_eq!( + &format_array(r8.as_ref()), + &[ + "1998-10-04 23:59:17".to_string(), + "1960-09-29 04:00:33".to_string(), + "1960-07-02 04:31:33".to_string() + ] + ); + + let r9 = sub(&r8, &b).unwrap(); + // Note: subtraction is not the inverse of addition for intervals + assert_eq!( + &format_array(r9.as_ref()), + &[ + "1970-01-02 00:00:00".to_string(), + "2010-04-02 04:00:20".to_string(), + "1960-01-31 04:23:20".to_string() + ] + ); + } + + #[test] + fn test_timestamp() { + test_timestamp_impl::(); + test_timestamp_impl::(); + test_timestamp_impl::(); + test_timestamp_impl::(); + } + + #[test] + fn test_interval() { + let a = IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(32, 4), + IntervalYearMonthType::make_value(32, 4), + ]); + let b = IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(-4, 6), + IntervalYearMonthType::make_value(-3, 23), + ]); + let result = add(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(28, 10), + IntervalYearMonthType::make_value(29, 27) + ]) + ); + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(36, -2), + IntervalYearMonthType::make_value(35, -19) + ]) + ); + + let a = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(32, 4), + IntervalDayTimeType::make_value(32, 4), + ]); + let b = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(-4, 6), + IntervalDayTimeType::make_value(-3, 23), + ]); + let result = add(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(28, 10), + IntervalDayTimeType::make_value(29, 27) + ]) + ); + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(36, -2), + IntervalDayTimeType::make_value(35, -19) + ]) + ); + let a = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(32, 4, 4000000000000), + IntervalMonthDayNanoType::make_value(32, 4, 45463000000000000), + ]); + let b = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(-4, 6, 46000000000000), + IntervalMonthDayNanoType::make_value(-3, 23, 3564000000000000), + ]); + let result = add(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(28, 10, 50000000000000), + IntervalMonthDayNanoType::make_value(29, 27, 49027000000000000) + ]) + ); + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(36, -2, -42000000000000), + IntervalMonthDayNanoType::make_value(35, -19, 41899000000000000) + ]) + ); + let a = IntervalMonthDayNanoArray::from(vec![i64::MAX as i128]); + let b = IntervalMonthDayNanoArray::from(vec![1]); + let err = add(&a, &b).unwrap_err().to_string(); + assert_eq!( + err, + "Compute error: Overflow happened on: 9223372036854775807 + 1" + ); + } + + fn test_duration_impl>() { + let a = PrimitiveArray::::new(vec![1000, 4394, -3944].into(), None); + let b = PrimitiveArray::::new(vec![4, -5, -243].into(), None); + + let result = add(&a, &b).unwrap(); + assert_eq!(result.as_primitive::().values(), &[1004, 4389, -4187]); + let result = sub(&a, &b).unwrap(); + assert_eq!(result.as_primitive::().values(), &[996, 4399, -3701]); + + let err = mul(&a, &b).unwrap_err().to_string(); + assert!( + err.contains("Invalid duration arithmetic operation"), + "{err}" + ); + + let err = div(&a, &b).unwrap_err().to_string(); + assert!( + err.contains("Invalid duration arithmetic operation"), + "{err}" + ); + + let err = rem(&a, &b).unwrap_err().to_string(); + assert!( + err.contains("Invalid duration arithmetic operation"), + "{err}" + ); + + let a = PrimitiveArray::::new(vec![i64::MAX].into(), None); + let b = PrimitiveArray::::new(vec![1].into(), None); + let err = add(&a, &b).unwrap_err().to_string(); + assert_eq!( + err, + "Compute error: Overflow happened on: 9223372036854775807 + 1" + ); + } + + #[test] + fn test_duration() { + test_duration_impl::(); + test_duration_impl::(); + test_duration_impl::(); + test_duration_impl::(); + } + + fn test_date_impl(f: F) + where + F: Fn(NaiveDate) -> T::Native, + T::Native: TryInto, + { + let a = PrimitiveArray::::new( + vec![ + f(NaiveDate::from_ymd_opt(1979, 1, 30).unwrap()), + f(NaiveDate::from_ymd_opt(2010, 4, 3).unwrap()), + f(NaiveDate::from_ymd_opt(2008, 2, 29).unwrap()), + ] + .into(), + None, + ); + + let b = IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(34, 2), + IntervalYearMonthType::make_value(3, -3), + IntervalYearMonthType::make_value(-12, 4), + ]); + + let format_array = |x: &dyn Array| -> Vec { + x.as_primitive::() + .values() + .into_iter() + .map(|x| { + as_date::((*x).try_into().ok().unwrap()) + .unwrap() + .to_string() + }) + .collect() + }; + + let result = add(&a, &b).unwrap(); + assert_eq!( + &format_array(result.as_ref()), + &[ + "2013-03-30".to_string(), + "2013-01-03".to_string(), + "1996-06-29".to_string(), + ] + ); + let result = sub(&result, &b).unwrap(); + assert_eq!(result.as_ref(), &a); + + let b = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(34, 2), + IntervalDayTimeType::make_value(3, -3), + IntervalDayTimeType::make_value(-12, 4), + ]); + + let result = add(&a, &b).unwrap(); + assert_eq!( + &format_array(result.as_ref()), + &[ + "1979-03-05".to_string(), + "2010-04-06".to_string(), + "2008-02-17".to_string(), + ] + ); + let result = sub(&result, &b).unwrap(); + assert_eq!(result.as_ref(), &a); + + let b = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(34, 2, -34353534), + IntervalMonthDayNanoType::make_value(3, -3, 2443), + IntervalMonthDayNanoType::make_value(-12, 4, 2323242423232), + ]); + + let result = add(&a, &b).unwrap(); + assert_eq!( + &format_array(result.as_ref()), + &[ + "1981-12-02".to_string(), + "2010-06-30".to_string(), + "2007-03-04".to_string(), + ] + ); + let result = sub(&result, &b).unwrap(); + assert_eq!( + &format_array(result.as_ref()), + &[ + "1979-01-31".to_string(), + "2010-04-02".to_string(), + "2008-02-29".to_string(), + ] + ); + } + + #[test] + fn test_date() { + test_date_impl::(Date32Type::from_naive_date); + test_date_impl::(Date64Type::from_naive_date); + + let a = Date32Array::from(vec![i32::MIN, i32::MAX, 23, 7684]); + let b = Date32Array::from(vec![i32::MIN, i32::MIN, -2, 45]); + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_primitive::().values(), + &[0, 371085174288000, 2160000, 660009600] + ); + + let a = Date64Array::from(vec![4343, 76676, 3434]); + let b = Date64Array::from(vec![3, -5, 5]); + let result = sub(&a, &b).unwrap(); + assert_eq!( + result.as_primitive::().values(), + &[4340, 76681, 3429] + ); + + let a = Date64Array::from(vec![i64::MAX]); + let b = Date64Array::from(vec![-1]); + let err = sub(&a, &b).unwrap_err().to_string(); + assert_eq!( + err, + "Compute error: Overflow happened on: 9223372036854775807 - -1" + ); + } +} diff --git a/arrow-arith/src/temporal.rs b/arrow-arith/src/temporal.rs new file mode 100644 index 000000000000..a9c3de5401c1 --- /dev/null +++ b/arrow-arith/src/temporal.rs @@ -0,0 +1,1100 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines temporal kernels for time and date related functions. + +use std::sync::Arc; + +use chrono::{DateTime, Datelike, NaiveDateTime, NaiveTime, Offset, Timelike}; + +use arrow_array::builder::*; +use arrow_array::iterator::ArrayIter; +use arrow_array::temporal_conversions::{as_datetime, as_datetime_with_timezone, as_time}; +use arrow_array::timezone::Tz; +use arrow_array::types::*; +use arrow_array::*; +use arrow_buffer::ArrowNativeType; +use arrow_schema::{ArrowError, DataType}; + +/// This function takes an `ArrayIter` of input array and an extractor `op` which takes +/// an input `NaiveTime` and returns time component (e.g. hour) as `i32` value. +/// The extracted values are built by the given `builder` to be an `Int32Array`. +fn as_time_with_op, T: ArrowTemporalType, F>( + iter: ArrayIter, + mut builder: PrimitiveBuilder, + op: F, +) -> Int32Array +where + F: Fn(NaiveTime) -> i32, + i64: From, +{ + iter.into_iter().for_each(|value| { + if let Some(value) = value { + match as_time::(i64::from(value)) { + Some(dt) => builder.append_value(op(dt)), + None => builder.append_null(), + } + } else { + builder.append_null(); + } + }); + + builder.finish() +} + +/// This function takes an `ArrayIter` of input array and an extractor `op` which takes +/// an input `NaiveDateTime` and returns data time component (e.g. hour) as `i32` value. +/// The extracted values are built by the given `builder` to be an `Int32Array`. +fn as_datetime_with_op, T: ArrowTemporalType, F>( + iter: ArrayIter, + mut builder: PrimitiveBuilder, + op: F, +) -> Int32Array +where + F: Fn(NaiveDateTime) -> i32, + i64: From, +{ + iter.into_iter().for_each(|value| { + if let Some(value) = value { + match as_datetime::(i64::from(value)) { + Some(dt) => builder.append_value(op(dt)), + None => builder.append_null(), + } + } else { + builder.append_null(); + } + }); + + builder.finish() +} + +/// This function extracts date time component (e.g. hour) from an array of datatime. +/// `iter` is the `ArrayIter` of input datatime array. `builder` is used to build the +/// returned `Int32Array` containing the extracted components. `tz` is timezone string +/// which will be added to datetime values in the input array. `parsed` is a `Parsed` +/// object used to parse timezone string. `op` is the extractor closure which takes +/// data time object of `NaiveDateTime` type and returns `i32` value of extracted +/// component. +fn extract_component_from_datetime_array< + A: ArrayAccessor, + T: ArrowTemporalType, + F, +>( + iter: ArrayIter, + mut builder: PrimitiveBuilder, + tz: &str, + op: F, +) -> Result +where + F: Fn(DateTime) -> i32, + i64: From, +{ + let tz: Tz = tz.parse()?; + for value in iter { + match value { + Some(value) => match as_datetime_with_timezone::(value.into(), tz) { + Some(time) => builder.append_value(op(time)), + _ => { + return Err(ArrowError::ComputeError( + "Unable to read value as datetime".to_string(), + )) + } + }, + None => builder.append_null(), + } + } + Ok(builder.finish()) +} + +macro_rules! return_compute_error_with { + ($msg:expr, $param:expr) => { + return { Err(ArrowError::ComputeError(format!("{}: {:?}", $msg, $param))) } + }; +} + +pub(crate) use return_compute_error_with; + +// Internal trait, which is used for mapping values from DateLike structures +trait ChronoDateExt { + /// Returns a value in range `1..=4` indicating the quarter this date falls into + fn quarter(&self) -> u32; + + /// Returns a value in range `0..=3` indicating the quarter (zero-based) this date falls into + fn quarter0(&self) -> u32; + + /// Returns the day of week; Monday is encoded as `0`, Tuesday as `1`, etc. + fn num_days_from_monday(&self) -> i32; + + /// Returns the day of week; Sunday is encoded as `0`, Monday as `1`, etc. + fn num_days_from_sunday(&self) -> i32; +} + +impl ChronoDateExt for T { + fn quarter(&self) -> u32 { + self.quarter0() + 1 + } + + fn quarter0(&self) -> u32 { + self.month0() / 3 + } + + fn num_days_from_monday(&self) -> i32 { + self.weekday().num_days_from_monday() as i32 + } + + fn num_days_from_sunday(&self) -> i32 { + self.weekday().num_days_from_sunday() as i32 + } +} + +/// Parse the given string into a string representing fixed-offset that is correct as of the given +/// UTC NaiveDateTime. +/// Note that the offset is function of time and can vary depending on whether daylight savings is +/// in effect or not. e.g. Australia/Sydney is +10:00 or +11:00 depending on DST. +#[deprecated(note = "Use arrow_array::timezone::Tz instead")] +pub fn using_chrono_tz_and_utc_naive_date_time( + tz: &str, + utc: NaiveDateTime, +) -> Option { + use chrono::TimeZone; + let tz: Tz = tz.parse().ok()?; + Some(tz.offset_from_utc_datetime(&utc).fix()) +} + +/// Extracts the hours of a given array as an array of integers within +/// the range of [0, 23]. If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +pub fn hour_dyn(array: &dyn Array) -> Result { + time_fraction_dyn(array, "hour", |t| t.hour() as i32) +} + +/// Extracts the hours of a given temporal primitive array as an array of integers within +/// the range of [0, 23]. +pub fn hour(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + let b = Int32Builder::with_capacity(array.len()); + match array.data_type() { + DataType::Time32(_) | DataType::Time64(_) => { + let iter = ArrayIter::new(array); + Ok(as_time_with_op::<&PrimitiveArray, T, _>(iter, b, |t| { + t.hour() as i32 + })) + } + DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, None) => { + let iter = ArrayIter::new(array); + Ok(as_datetime_with_op::<&PrimitiveArray, T, _>( + iter, + b, + |t| t.hour() as i32, + )) + } + DataType::Timestamp(_, Some(tz)) => { + let iter = ArrayIter::new(array); + extract_component_from_datetime_array::<&PrimitiveArray, T, _>(iter, b, tz, |t| { + t.hour() as i32 + }) + } + _ => return_compute_error_with!("hour does not support", array.data_type()), + } +} + +/// Extracts the years of a given temporal array as an array of integers. +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +pub fn year_dyn(array: &dyn Array) -> Result { + time_fraction_dyn(array, "year", |t| t.year()) +} + +/// Extracts the years of a given temporal primitive array as an array of integers +pub fn year(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + time_fraction_internal(array, "year", |t| t.year()) +} + +/// Extracts the quarter of a given temporal array as an array of integersa within +/// the range of [1, 4]. If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +pub fn quarter_dyn(array: &dyn Array) -> Result { + time_fraction_dyn(array, "quarter", |t| t.quarter() as i32) +} + +/// Extracts the quarter of a given temporal primitive array as an array of integers within +/// the range of [1, 4]. +pub fn quarter(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + time_fraction_internal(array, "quarter", |t| t.quarter() as i32) +} + +/// Extracts the month of a given temporal array as an array of integers. +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +pub fn month_dyn(array: &dyn Array) -> Result { + time_fraction_dyn(array, "month", |t| t.month() as i32) +} + +/// Extracts the month of a given temporal primitive array as an array of integers within +/// the range of [1, 12]. +pub fn month(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + time_fraction_internal(array, "month", |t| t.month() as i32) +} + +/// Extracts the day of week of a given temporal array as an array of +/// integers. +/// +/// Monday is encoded as `0`, Tuesday as `1`, etc. +/// +/// See also [`num_days_from_sunday`] which starts at Sunday. +/// +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +pub fn num_days_from_monday_dyn(array: &dyn Array) -> Result { + time_fraction_dyn(array, "num_days_from_monday", |t| t.num_days_from_monday()) +} + +/// Extracts the day of week of a given temporal primitive array as an array of +/// integers. +/// +/// Monday is encoded as `0`, Tuesday as `1`, etc. +/// +/// See also [`num_days_from_sunday`] which starts at Sunday. +pub fn num_days_from_monday(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + time_fraction_internal(array, "num_days_from_monday", |t| t.num_days_from_monday()) +} + +/// Extracts the day of week of a given temporal array as an array of +/// integers, starting at Sunday. +/// +/// Sunday is encoded as `0`, Monday as `1`, etc. +/// +/// See also [`num_days_from_monday`] which starts at Monday. +/// +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +pub fn num_days_from_sunday_dyn(array: &dyn Array) -> Result { + time_fraction_dyn(array, "num_days_from_sunday", |t| t.num_days_from_sunday()) +} + +/// Extracts the day of week of a given temporal primitive array as an array of +/// integers, starting at Sunday. +/// +/// Sunday is encoded as `0`, Monday as `1`, etc. +/// +/// See also [`num_days_from_monday`] which starts at Monday. +pub fn num_days_from_sunday(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + time_fraction_internal(array, "num_days_from_sunday", |t| t.num_days_from_sunday()) +} + +/// Extracts the day of a given temporal array as an array of integers. +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +pub fn day_dyn(array: &dyn Array) -> Result { + time_fraction_dyn(array, "day", |t| t.day() as i32) +} + +/// Extracts the day of a given temporal primitive array as an array of integers +pub fn day(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + time_fraction_internal(array, "day", |t| t.day() as i32) +} + +/// Extracts the day of year of a given temporal array as an array of integers +/// The day of year that ranges from 1 to 366. +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +pub fn doy_dyn(array: &dyn Array) -> Result { + time_fraction_dyn(array, "doy", |t| t.ordinal() as i32) +} + +/// Extracts the day of year of a given temporal primitive array as an array of integers +/// The day of year that ranges from 1 to 366 +pub fn doy(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + T::Native: ArrowNativeType, + i64: From, +{ + time_fraction_internal(array, "doy", |t| t.ordinal() as i32) +} + +/// Extracts the minutes of a given temporal primitive array as an array of integers +pub fn minute(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + time_fraction_internal(array, "minute", |t| t.minute() as i32) +} + +/// Extracts the week of a given temporal array as an array of integers. +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +pub fn week_dyn(array: &dyn Array) -> Result { + time_fraction_dyn(array, "week", |t| t.iso_week().week() as i32) +} + +/// Extracts the week of a given temporal primitive array as an array of integers +pub fn week(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + time_fraction_internal(array, "week", |t| t.iso_week().week() as i32) +} + +/// Extracts the seconds of a given temporal primitive array as an array of integers +pub fn second(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + time_fraction_internal(array, "second", |t| t.second() as i32) +} + +/// Extracts the nanoseconds of a given temporal primitive array as an array of integers +pub fn nanosecond(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + time_fraction_internal(array, "nanosecond", |t| t.nanosecond() as i32) +} + +/// Extracts the nanoseconds of a given temporal primitive array as an array of integers. +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +pub fn nanosecond_dyn(array: &dyn Array) -> Result { + time_fraction_dyn(array, "nanosecond", |t| t.nanosecond() as i32) +} + +/// Extracts the microseconds of a given temporal primitive array as an array of integers +pub fn microsecond(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + time_fraction_internal(array, "microsecond", |t| (t.nanosecond() / 1_000) as i32) +} + +/// Extracts the microseconds of a given temporal primitive array as an array of integers. +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +pub fn microsecond_dyn(array: &dyn Array) -> Result { + time_fraction_dyn(array, "microsecond", |t| (t.nanosecond() / 1_000) as i32) +} + +/// Extracts the milliseconds of a given temporal primitive array as an array of integers +pub fn millisecond(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + time_fraction_internal(array, "millisecond", |t| { + (t.nanosecond() / 1_000_000) as i32 + }) +} +/// Extracts the milliseconds of a given temporal primitive array as an array of integers. +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +pub fn millisecond_dyn(array: &dyn Array) -> Result { + time_fraction_dyn(array, "millisecond", |t| { + (t.nanosecond() / 1_000_000) as i32 + }) +} + +/// Extracts the time fraction of a given temporal array as an array of integers +fn time_fraction_dyn(array: &dyn Array, name: &str, op: F) -> Result +where + F: Fn(NaiveDateTime) -> i32, +{ + match array.data_type().clone() { + DataType::Dictionary(_, _) => { + downcast_dictionary_array!( + array => { + let values = time_fraction_dyn(array.values(), name, op)?; + Ok(Arc::new(array.with_values(values))) + } + dt => return_compute_error_with!(format!("{name} does not support"), dt), + ) + } + _ => { + downcast_temporal_array!( + array => { + time_fraction_internal(array, name, op) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!(format!("{name} does not support"), dt), + ) + } + } +} + +/// Extracts the time fraction of a given temporal array as an array of integers +fn time_fraction_internal( + array: &PrimitiveArray, + name: &str, + op: F, +) -> Result +where + F: Fn(NaiveDateTime) -> i32, + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + let b = Int32Builder::with_capacity(array.len()); + match array.data_type() { + DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, None) => { + let iter = ArrayIter::new(array); + Ok(as_datetime_with_op::<_, T, _>(iter, b, op)) + } + DataType::Timestamp(_, Some(tz)) => { + let iter = ArrayIter::new(array); + extract_component_from_datetime_array::<_, T, _>(iter, b, tz, |t| op(t.naive_local())) + } + _ => return_compute_error_with!(format!("{name} does not support"), array.data_type()), + } +} + +/// Extracts the minutes of a given temporal array as an array of integers. +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +pub fn minute_dyn(array: &dyn Array) -> Result { + time_fraction_dyn(array, "minute", |t| t.minute() as i32) +} + +/// Extracts the seconds of a given temporal array as an array of integers. +/// If the given array isn't temporal primitive or dictionary array, +/// an `Err` will be returned. +pub fn second_dyn(array: &dyn Array) -> Result { + time_fraction_dyn(array, "second", |t| t.second() as i32) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_temporal_array_date64_hour() { + let a: PrimitiveArray = + vec![Some(1514764800000), None, Some(1550636625000)].into(); + + let b = hour(&a).unwrap(); + assert_eq!(0, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(4, b.value(2)); + } + + #[test] + fn test_temporal_array_date32_hour() { + let a: PrimitiveArray = vec![Some(15147), None, Some(15148)].into(); + + let b = hour(&a).unwrap(); + assert_eq!(0, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(0, b.value(2)); + } + + #[test] + fn test_temporal_array_time32_second_hour() { + let a: PrimitiveArray = vec![37800, 86339].into(); + + let b = hour(&a).unwrap(); + assert_eq!(10, b.value(0)); + assert_eq!(23, b.value(1)); + } + + #[test] + fn test_temporal_array_time64_micro_hour() { + let a: PrimitiveArray = vec![37800000000, 86339000000].into(); + + let b = hour(&a).unwrap(); + assert_eq!(10, b.value(0)); + assert_eq!(23, b.value(1)); + } + + #[test] + fn test_temporal_array_timestamp_micro_hour() { + let a: TimestampMicrosecondArray = vec![37800000000, 86339000000].into(); + + let b = hour(&a).unwrap(); + assert_eq!(10, b.value(0)); + assert_eq!(23, b.value(1)); + } + + #[test] + fn test_temporal_array_date64_year() { + let a: PrimitiveArray = + vec![Some(1514764800000), None, Some(1550636625000)].into(); + + let b = year(&a).unwrap(); + assert_eq!(2018, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(2019, b.value(2)); + } + + #[test] + fn test_temporal_array_date32_year() { + let a: PrimitiveArray = vec![Some(15147), None, Some(15448)].into(); + + let b = year(&a).unwrap(); + assert_eq!(2011, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(2012, b.value(2)); + } + + #[test] + fn test_temporal_array_date64_quarter() { + //1514764800000 -> 2018-01-01 + //1566275025000 -> 2019-08-20 + let a: PrimitiveArray = + vec![Some(1514764800000), None, Some(1566275025000)].into(); + + let b = quarter(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(3, b.value(2)); + } + + #[test] + fn test_temporal_array_date32_quarter() { + let a: PrimitiveArray = vec![Some(1), None, Some(300)].into(); + + let b = quarter(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(4, b.value(2)); + } + + #[test] + fn test_temporal_array_timestamp_quarter_with_timezone() { + // 24 * 60 * 60 = 86400 + let a = TimestampSecondArray::from(vec![86400 * 90]).with_timezone("+00:00".to_string()); + let b = quarter(&a).unwrap(); + assert_eq!(2, b.value(0)); + let a = TimestampSecondArray::from(vec![86400 * 90]).with_timezone("-10:00".to_string()); + let b = quarter(&a).unwrap(); + assert_eq!(1, b.value(0)); + } + + #[test] + fn test_temporal_array_date64_month() { + //1514764800000 -> 2018-01-01 + //1550636625000 -> 2019-02-20 + let a: PrimitiveArray = + vec![Some(1514764800000), None, Some(1550636625000)].into(); + + let b = month(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(2, b.value(2)); + } + + #[test] + fn test_temporal_array_date32_month() { + let a: PrimitiveArray = vec![Some(1), None, Some(31)].into(); + + let b = month(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(2, b.value(2)); + } + + #[test] + fn test_temporal_array_timestamp_month_with_timezone() { + // 24 * 60 * 60 = 86400 + let a = TimestampSecondArray::from(vec![86400 * 31]).with_timezone("+00:00".to_string()); + let b = month(&a).unwrap(); + assert_eq!(2, b.value(0)); + let a = TimestampSecondArray::from(vec![86400 * 31]).with_timezone("-10:00".to_string()); + let b = month(&a).unwrap(); + assert_eq!(1, b.value(0)); + } + + #[test] + fn test_temporal_array_timestamp_day_with_timezone() { + // 24 * 60 * 60 = 86400 + let a = TimestampSecondArray::from(vec![86400]).with_timezone("+00:00".to_string()); + let b = day(&a).unwrap(); + assert_eq!(2, b.value(0)); + let a = TimestampSecondArray::from(vec![86400]).with_timezone("-10:00".to_string()); + let b = day(&a).unwrap(); + assert_eq!(1, b.value(0)); + } + + #[test] + fn test_temporal_array_date64_weekday() { + //1514764800000 -> 2018-01-01 (Monday) + //1550636625000 -> 2019-02-20 (Wednesday) + let a: PrimitiveArray = + vec![Some(1514764800000), None, Some(1550636625000)].into(); + + let b = num_days_from_monday(&a).unwrap(); + assert_eq!(0, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(2, b.value(2)); + } + + #[test] + fn test_temporal_array_date64_weekday0() { + //1483228800000 -> 2017-01-01 (Sunday) + //1514764800000 -> 2018-01-01 (Monday) + //1550636625000 -> 2019-02-20 (Wednesday) + let a: PrimitiveArray = vec![ + Some(1483228800000), + None, + Some(1514764800000), + Some(1550636625000), + ] + .into(); + + let b = num_days_from_sunday(&a).unwrap(); + assert_eq!(0, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(1, b.value(2)); + assert_eq!(3, b.value(3)); + } + + #[test] + fn test_temporal_array_date64_day() { + //1514764800000 -> 2018-01-01 + //1550636625000 -> 2019-02-20 + let a: PrimitiveArray = + vec![Some(1514764800000), None, Some(1550636625000)].into(); + + let b = day(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(20, b.value(2)); + } + + #[test] + fn test_temporal_array_date32_day() { + let a: PrimitiveArray = vec![Some(0), None, Some(31)].into(); + + let b = day(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(1, b.value(2)); + } + + #[test] + fn test_temporal_array_date64_doy() { + //1483228800000 -> 2017-01-01 (Sunday) + //1514764800000 -> 2018-01-01 + //1550636625000 -> 2019-02-20 + let a: PrimitiveArray = vec![ + Some(1483228800000), + Some(1514764800000), + None, + Some(1550636625000), + ] + .into(); + + let b = doy(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert_eq!(1, b.value(1)); + assert!(!b.is_valid(2)); + assert_eq!(51, b.value(3)); + } + + #[test] + fn test_temporal_array_timestamp_micro_year() { + let a: TimestampMicrosecondArray = + vec![Some(1612025847000000), None, Some(1722015847000000)].into(); + + let b = year(&a).unwrap(); + assert_eq!(2021, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(2024, b.value(2)); + } + + #[test] + fn test_temporal_array_date64_minute() { + let a: PrimitiveArray = + vec![Some(1514764800000), None, Some(1550636625000)].into(); + + let b = minute(&a).unwrap(); + assert_eq!(0, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(23, b.value(2)); + } + + #[test] + fn test_temporal_array_timestamp_micro_minute() { + let a: TimestampMicrosecondArray = + vec![Some(1612025847000000), None, Some(1722015847000000)].into(); + + let b = minute(&a).unwrap(); + assert_eq!(57, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(44, b.value(2)); + } + + #[test] + fn test_temporal_array_date32_week() { + let a: PrimitiveArray = vec![Some(0), None, Some(7)].into(); + + let b = week(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(2, b.value(2)); + } + + #[test] + fn test_temporal_array_date64_week() { + // 1646116175000 -> 2022.03.01 , 1641171600000 -> 2022.01.03 + // 1640998800000 -> 2022.01.01 + let a: PrimitiveArray = vec![ + Some(1646116175000), + None, + Some(1641171600000), + Some(1640998800000), + ] + .into(); + + let b = week(&a).unwrap(); + assert_eq!(9, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(1, b.value(2)); + assert_eq!(52, b.value(3)); + } + + #[test] + fn test_temporal_array_timestamp_micro_week() { + //1612025847000000 -> 2021.1.30 + //1722015847000000 -> 2024.7.27 + let a: TimestampMicrosecondArray = + vec![Some(1612025847000000), None, Some(1722015847000000)].into(); + + let b = week(&a).unwrap(); + assert_eq!(4, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(30, b.value(2)); + } + + #[test] + fn test_temporal_array_date64_second() { + let a: PrimitiveArray = + vec![Some(1514764800000), None, Some(1550636625000)].into(); + + let b = second(&a).unwrap(); + assert_eq!(0, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(45, b.value(2)); + } + + #[test] + fn test_temporal_array_timestamp_micro_second() { + let a: TimestampMicrosecondArray = + vec![Some(1612025847000000), None, Some(1722015847000000)].into(); + + let b = second(&a).unwrap(); + assert_eq!(27, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(7, b.value(2)); + } + + #[test] + fn test_temporal_array_timestamp_second_with_timezone() { + let a = TimestampSecondArray::from(vec![10, 20]).with_timezone("+00:00".to_string()); + let b = second(&a).unwrap(); + assert_eq!(10, b.value(0)); + assert_eq!(20, b.value(1)); + } + + #[test] + fn test_temporal_array_timestamp_minute_with_timezone() { + let a = TimestampSecondArray::from(vec![0, 60]).with_timezone("+00:50".to_string()); + let b = minute(&a).unwrap(); + assert_eq!(50, b.value(0)); + assert_eq!(51, b.value(1)); + } + + #[test] + fn test_temporal_array_timestamp_minute_with_negative_timezone() { + let a = TimestampSecondArray::from(vec![60 * 55]).with_timezone("-00:50".to_string()); + let b = minute(&a).unwrap(); + assert_eq!(5, b.value(0)); + } + + #[test] + fn test_temporal_array_timestamp_hour_with_timezone() { + let a = TimestampSecondArray::from(vec![60 * 60 * 10]).with_timezone("+01:00".to_string()); + let b = hour(&a).unwrap(); + assert_eq!(11, b.value(0)); + } + + #[test] + fn test_temporal_array_timestamp_hour_with_timezone_without_colon() { + let a = TimestampSecondArray::from(vec![60 * 60 * 10]).with_timezone("+0100".to_string()); + let b = hour(&a).unwrap(); + assert_eq!(11, b.value(0)); + } + + #[test] + fn test_temporal_array_timestamp_hour_with_timezone_without_minutes() { + let a = TimestampSecondArray::from(vec![60 * 60 * 10]).with_timezone("+01".to_string()); + let b = hour(&a).unwrap(); + assert_eq!(11, b.value(0)); + } + + #[test] + fn test_temporal_array_timestamp_hour_with_timezone_without_initial_sign() { + let a = TimestampSecondArray::from(vec![60 * 60 * 10]).with_timezone("0100".to_string()); + let err = hour(&a).unwrap_err().to_string(); + assert!(err.contains("Invalid timezone"), "{}", err); + } + + #[test] + fn test_temporal_array_timestamp_hour_with_timezone_with_only_colon() { + let a = TimestampSecondArray::from(vec![60 * 60 * 10]).with_timezone("01:00".to_string()); + let err = hour(&a).unwrap_err().to_string(); + assert!(err.contains("Invalid timezone"), "{}", err); + } + + #[test] + fn test_temporal_array_timestamp_week_without_timezone() { + // 1970-01-01T00:00:00 -> 1970-01-01T00:00:00 Thursday (week 1) + // 1970-01-01T00:00:00 + 4 days -> 1970-01-05T00:00:00 Monday (week 2) + // 1970-01-01T00:00:00 + 4 days - 1 second -> 1970-01-04T23:59:59 Sunday (week 1) + let a = TimestampSecondArray::from(vec![0, 86400 * 4, 86400 * 4 - 1]); + let b = week(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert_eq!(2, b.value(1)); + assert_eq!(1, b.value(2)); + } + + #[test] + fn test_temporal_array_timestamp_week_with_timezone() { + // 1970-01-01T01:00:00+01:00 -> 1970-01-01T01:00:00+01:00 Thursday (week 1) + // 1970-01-01T01:00:00+01:00 + 4 days -> 1970-01-05T01:00:00+01:00 Monday (week 2) + // 1970-01-01T01:00:00+01:00 + 4 days - 1 second -> 1970-01-05T00:59:59+01:00 Monday (week 2) + let a = TimestampSecondArray::from(vec![0, 86400 * 4, 86400 * 4 - 1]) + .with_timezone("+01:00".to_string()); + let b = week(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert_eq!(2, b.value(1)); + assert_eq!(2, b.value(2)); + } + + #[test] + fn test_hour_minute_second_dictionary_array() { + let a = TimestampSecondArray::from(vec![ + 60 * 60 * 10 + 61, + 60 * 60 * 20 + 122, + 60 * 60 * 30 + 183, + ]) + .with_timezone("+01:00".to_string()); + + let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2, 1]); + let dict = DictionaryArray::try_new(keys.clone(), Arc::new(a)).unwrap(); + + let b = hour_dyn(&dict).unwrap(); + + let expected_dict = + DictionaryArray::new(keys.clone(), Arc::new(Int32Array::from(vec![11, 21, 7]))); + let expected = Arc::new(expected_dict) as ArrayRef; + assert_eq!(&expected, &b); + + let b = time_fraction_dyn(&dict, "minute", |t| t.minute() as i32).unwrap(); + + let b_old = minute_dyn(&dict).unwrap(); + + let expected_dict = + DictionaryArray::new(keys.clone(), Arc::new(Int32Array::from(vec![1, 2, 3]))); + let expected = Arc::new(expected_dict) as ArrayRef; + assert_eq!(&expected, &b); + assert_eq!(&expected, &b_old); + + let b = time_fraction_dyn(&dict, "second", |t| t.second() as i32).unwrap(); + + let b_old = second_dyn(&dict).unwrap(); + + let expected_dict = + DictionaryArray::new(keys.clone(), Arc::new(Int32Array::from(vec![1, 2, 3]))); + let expected = Arc::new(expected_dict) as ArrayRef; + assert_eq!(&expected, &b); + assert_eq!(&expected, &b_old); + + let b = time_fraction_dyn(&dict, "nanosecond", |t| t.nanosecond() as i32).unwrap(); + + let expected_dict = + DictionaryArray::new(keys, Arc::new(Int32Array::from(vec![0, 0, 0, 0, 0]))); + let expected = Arc::new(expected_dict) as ArrayRef; + assert_eq!(&expected, &b); + } + + #[test] + fn test_year_dictionary_array() { + let a: PrimitiveArray = vec![Some(1514764800000), Some(1550636625000)].into(); + + let keys = Int8Array::from_iter_values([0_i8, 1, 1, 0]); + let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); + + let b = year_dyn(&dict).unwrap(); + + let expected_dict = DictionaryArray::new( + keys, + Arc::new(Int32Array::from(vec![2018, 2019, 2019, 2018])), + ); + let expected = Arc::new(expected_dict) as ArrayRef; + assert_eq!(&expected, &b); + } + + #[test] + fn test_quarter_month_dictionary_array() { + //1514764800000 -> 2018-01-01 + //1566275025000 -> 2019-08-20 + let a: PrimitiveArray = vec![Some(1514764800000), Some(1566275025000)].into(); + + let keys = Int8Array::from_iter_values([0_i8, 1, 1, 0]); + let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); + + let b = quarter_dyn(&dict).unwrap(); + + let expected = + DictionaryArray::new(keys.clone(), Arc::new(Int32Array::from(vec![1, 3, 3, 1]))); + assert_eq!(b.as_ref(), &expected); + + let b = month_dyn(&dict).unwrap(); + + let expected = DictionaryArray::new(keys, Arc::new(Int32Array::from(vec![1, 8, 8, 1]))); + assert_eq!(b.as_ref(), &expected); + } + + #[test] + fn test_num_days_from_monday_sunday_day_doy_week_dictionary_array() { + //1514764800000 -> 2018-01-01 (Monday) + //1550636625000 -> 2019-02-20 (Wednesday) + let a: PrimitiveArray = vec![Some(1514764800000), Some(1550636625000)].into(); + + let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1), Some(0), None]); + let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); + + let b = num_days_from_monday_dyn(&dict).unwrap(); + + let a = Int32Array::from(vec![Some(0), Some(2), Some(2), Some(0), None]); + let expected = DictionaryArray::new(keys.clone(), Arc::new(a)); + assert_eq!(b.as_ref(), &expected); + + let b = num_days_from_sunday_dyn(&dict).unwrap(); + + let a = Int32Array::from(vec![Some(1), Some(3), Some(3), Some(1), None]); + let expected = DictionaryArray::new(keys.clone(), Arc::new(a)); + assert_eq!(b.as_ref(), &expected); + + let b = day_dyn(&dict).unwrap(); + + let a = Int32Array::from(vec![Some(1), Some(20), Some(20), Some(1), None]); + let expected = DictionaryArray::new(keys.clone(), Arc::new(a)); + assert_eq!(b.as_ref(), &expected); + + let b = doy_dyn(&dict).unwrap(); + + let a = Int32Array::from(vec![Some(1), Some(51), Some(51), Some(1), None]); + let expected = DictionaryArray::new(keys.clone(), Arc::new(a)); + assert_eq!(b.as_ref(), &expected); + + let b = week_dyn(&dict).unwrap(); + + let a = Int32Array::from(vec![Some(1), Some(8), Some(8), Some(1), None]); + let expected = DictionaryArray::new(keys, Arc::new(a)); + assert_eq!(b.as_ref(), &expected); + } + + #[test] + fn test_temporal_array_date64_nanosecond() { + // new Date(1667328721453) + // Tue Nov 01 2022 11:52:01 GMT-0700 (Pacific Daylight Time) + // + // new Date(1667328721453).getMilliseconds() + // 453 + + let a: PrimitiveArray = vec![None, Some(1667328721453)].into(); + + let b = nanosecond(&a).unwrap(); + assert!(!b.is_valid(0)); + assert_eq!(453_000_000, b.value(1)); + + let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1)]); + let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); + let b = nanosecond_dyn(&dict).unwrap(); + + let a = Int32Array::from(vec![None, Some(453_000_000)]); + let expected_dict = DictionaryArray::new(keys, Arc::new(a)); + let expected = Arc::new(expected_dict) as ArrayRef; + assert_eq!(&expected, &b); + } + + #[test] + fn test_temporal_array_date64_microsecond() { + let a: PrimitiveArray = vec![None, Some(1667328721453)].into(); + + let b = microsecond(&a).unwrap(); + assert!(!b.is_valid(0)); + assert_eq!(453_000, b.value(1)); + + let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1)]); + let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); + let b = microsecond_dyn(&dict).unwrap(); + + let a = Int32Array::from(vec![None, Some(453_000)]); + let expected_dict = DictionaryArray::new(keys, Arc::new(a)); + let expected = Arc::new(expected_dict) as ArrayRef; + assert_eq!(&expected, &b); + } + + #[test] + fn test_temporal_array_date64_millisecond() { + let a: PrimitiveArray = vec![None, Some(1667328721453)].into(); + + let b = millisecond(&a).unwrap(); + assert!(!b.is_valid(0)); + assert_eq!(453, b.value(1)); + + let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1)]); + let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); + let b = millisecond_dyn(&dict).unwrap(); + + let a = Int32Array::from(vec![None, Some(453)]); + let expected_dict = DictionaryArray::new(keys, Arc::new(a)); + let expected = Arc::new(expected_dict) as ArrayRef; + assert_eq!(&expected, &b); + } +} diff --git a/arrow-array/Cargo.toml b/arrow-array/Cargo.toml new file mode 100644 index 000000000000..04eec8df6379 --- /dev/null +++ b/arrow-array/Cargo.toml @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "arrow-array" +version = { workspace = true } +description = "Array abstractions for Apache Arrow" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "arrow_array" +path = "src/lib.rs" +bench = false + + +[target.'cfg(target_arch = "wasm32")'.dependencies] +ahash = { version = "0.8", default-features = false, features = ["compile-time-rng"] } + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } + +[dependencies] +arrow-buffer = { workspace = true } +arrow-schema = { workspace = true } +arrow-data = { workspace = true } +chrono = { workspace = true } +chrono-tz = { version = "0.8", optional = true } +num = { version = "0.4.1", default-features = false, features = ["std"] } +half = { version = "2.1", default-features = false, features = ["num-traits"] } +hashbrown = { version = "0.14", default-features = false } + +[dev-dependencies] +rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } +criterion = { version = "0.5", default-features = false } + +[build-dependencies] + +[[bench]] +name = "occupancy" +harness = false diff --git a/arrow-array/benches/occupancy.rs b/arrow-array/benches/occupancy.rs new file mode 100644 index 000000000000..ed4b94351c28 --- /dev/null +++ b/arrow-array/benches/occupancy.rs @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::types::Int32Type; +use arrow_array::{DictionaryArray, Int32Array}; +use arrow_buffer::NullBuffer; +use criterion::*; +use rand::{thread_rng, Rng}; +use std::sync::Arc; + +fn gen_dict( + len: usize, + values_len: usize, + occupancy: f64, + null_percent: f64, +) -> DictionaryArray { + let mut rng = thread_rng(); + let values = Int32Array::from(vec![0; values_len]); + let max_key = (values_len as f64 * occupancy) as i32; + let keys = (0..len).map(|_| rng.gen_range(0..max_key)).collect(); + let nulls = (0..len).map(|_| !rng.gen_bool(null_percent)).collect(); + + let keys = Int32Array::new(keys, Some(NullBuffer::new(nulls))); + DictionaryArray::new(keys, Arc::new(values)) +} + +fn criterion_benchmark(c: &mut Criterion) { + for values in [10, 100, 512] { + for occupancy in [1., 0.5, 0.1] { + for null_percent in [0.0, 0.1, 0.5, 0.9] { + let dict = gen_dict(1024, values, occupancy, null_percent); + c.bench_function(&format!("occupancy(values: {values}, occupancy: {occupancy}, null_percent: {null_percent})"), |b| { + b.iter(|| { + black_box(&dict).occupancy() + }); + }); + } + } + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-array/src/arithmetic.rs b/arrow-array/src/arithmetic.rs new file mode 100644 index 000000000000..590536190309 --- /dev/null +++ b/arrow-array/src/arithmetic.rs @@ -0,0 +1,854 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_buffer::{i256, ArrowNativeType}; +use arrow_schema::ArrowError; +use half::f16; +use num::complex::ComplexFloat; +use std::cmp::Ordering; + +/// Trait for [`ArrowNativeType`] that adds checked and unchecked arithmetic operations, +/// and totally ordered comparison operations +/// +/// The APIs with `_wrapping` suffix do not perform overflow-checking. For integer +/// types they will wrap around the boundary of the type. For floating point types they +/// will overflow to INF or -INF preserving the expected sign value +/// +/// Note `div_wrapping` and `mod_wrapping` will panic for integer types if `rhs` is zero +/// although this may be subject to change +/// +/// The APIs with `_checked` suffix perform overflow-checking. For integer types +/// these will return `Err` instead of wrapping. For floating point types they will +/// overflow to INF or -INF preserving the expected sign value +/// +/// Comparison of integer types is as per normal integer comparison rules, floating +/// point values are compared as per IEEE 754's totalOrder predicate see [`f32::total_cmp`] +/// +pub trait ArrowNativeTypeOp: ArrowNativeType { + /// The additive identity + const ZERO: Self; + + /// The multiplicative identity + const ONE: Self; + + /// The minimum value and identity for the `max` aggregation. + /// Note that the aggregation uses the total order predicate for floating point values, + /// which means that this value is a negative NaN. + const MIN_TOTAL_ORDER: Self; + + /// The maximum value and identity for the `min` aggregation. + /// Note that the aggregation uses the total order predicate for floating point values, + /// which means that this value is a positive NaN. + const MAX_TOTAL_ORDER: Self; + + /// Checked addition operation + fn add_checked(self, rhs: Self) -> Result; + + /// Wrapping addition operation + fn add_wrapping(self, rhs: Self) -> Self; + + /// Checked subtraction operation + fn sub_checked(self, rhs: Self) -> Result; + + /// Wrapping subtraction operation + fn sub_wrapping(self, rhs: Self) -> Self; + + /// Checked multiplication operation + fn mul_checked(self, rhs: Self) -> Result; + + /// Wrapping multiplication operation + fn mul_wrapping(self, rhs: Self) -> Self; + + /// Checked division operation + fn div_checked(self, rhs: Self) -> Result; + + /// Wrapping division operation + fn div_wrapping(self, rhs: Self) -> Self; + + /// Checked remainder operation + fn mod_checked(self, rhs: Self) -> Result; + + /// Wrapping remainder operation + fn mod_wrapping(self, rhs: Self) -> Self; + + /// Checked negation operation + fn neg_checked(self) -> Result; + + /// Wrapping negation operation + fn neg_wrapping(self) -> Self; + + /// Checked exponentiation operation + fn pow_checked(self, exp: u32) -> Result; + + /// Wrapping exponentiation operation + fn pow_wrapping(self, exp: u32) -> Self; + + /// Returns true if zero else false + fn is_zero(self) -> bool; + + /// Compare operation + fn compare(self, rhs: Self) -> Ordering; + + /// Equality operation + fn is_eq(self, rhs: Self) -> bool; + + /// Not equal operation + #[inline] + fn is_ne(self, rhs: Self) -> bool { + !self.is_eq(rhs) + } + + /// Less than operation + #[inline] + fn is_lt(self, rhs: Self) -> bool { + self.compare(rhs).is_lt() + } + + /// Less than equals operation + #[inline] + fn is_le(self, rhs: Self) -> bool { + self.compare(rhs).is_le() + } + + /// Greater than operation + #[inline] + fn is_gt(self, rhs: Self) -> bool { + self.compare(rhs).is_gt() + } + + /// Greater than equals operation + #[inline] + fn is_ge(self, rhs: Self) -> bool { + self.compare(rhs).is_ge() + } +} + +macro_rules! native_type_op { + ($t:tt) => { + native_type_op!($t, 0, 1, $t::MIN, $t::MAX); + }; + ($t:tt, $zero:expr, $one: expr, $min: expr, $max: expr) => { + impl ArrowNativeTypeOp for $t { + const ZERO: Self = $zero; + const ONE: Self = $one; + const MIN_TOTAL_ORDER: Self = $min; + const MAX_TOTAL_ORDER: Self = $max; + + #[inline] + fn add_checked(self, rhs: Self) -> Result { + self.checked_add(rhs).ok_or_else(|| { + ArrowError::ComputeError(format!( + "Overflow happened on: {:?} + {:?}", + self, rhs + )) + }) + } + + #[inline] + fn add_wrapping(self, rhs: Self) -> Self { + self.wrapping_add(rhs) + } + + #[inline] + fn sub_checked(self, rhs: Self) -> Result { + self.checked_sub(rhs).ok_or_else(|| { + ArrowError::ComputeError(format!( + "Overflow happened on: {:?} - {:?}", + self, rhs + )) + }) + } + + #[inline] + fn sub_wrapping(self, rhs: Self) -> Self { + self.wrapping_sub(rhs) + } + + #[inline] + fn mul_checked(self, rhs: Self) -> Result { + self.checked_mul(rhs).ok_or_else(|| { + ArrowError::ComputeError(format!( + "Overflow happened on: {:?} * {:?}", + self, rhs + )) + }) + } + + #[inline] + fn mul_wrapping(self, rhs: Self) -> Self { + self.wrapping_mul(rhs) + } + + #[inline] + fn div_checked(self, rhs: Self) -> Result { + if rhs.is_zero() { + Err(ArrowError::DivideByZero) + } else { + self.checked_div(rhs).ok_or_else(|| { + ArrowError::ComputeError(format!( + "Overflow happened on: {:?} / {:?}", + self, rhs + )) + }) + } + } + + #[inline] + fn div_wrapping(self, rhs: Self) -> Self { + self.wrapping_div(rhs) + } + + #[inline] + fn mod_checked(self, rhs: Self) -> Result { + if rhs.is_zero() { + Err(ArrowError::DivideByZero) + } else { + self.checked_rem(rhs).ok_or_else(|| { + ArrowError::ComputeError(format!( + "Overflow happened on: {:?} % {:?}", + self, rhs + )) + }) + } + } + + #[inline] + fn mod_wrapping(self, rhs: Self) -> Self { + self.wrapping_rem(rhs) + } + + #[inline] + fn neg_checked(self) -> Result { + self.checked_neg().ok_or_else(|| { + ArrowError::ComputeError(format!("Overflow happened on: {:?}", self)) + }) + } + + #[inline] + fn pow_checked(self, exp: u32) -> Result { + self.checked_pow(exp).ok_or_else(|| { + ArrowError::ComputeError(format!("Overflow happened on: {:?} ^ {exp:?}", self)) + }) + } + + #[inline] + fn pow_wrapping(self, exp: u32) -> Self { + self.wrapping_pow(exp) + } + + #[inline] + fn neg_wrapping(self) -> Self { + self.wrapping_neg() + } + + #[inline] + fn is_zero(self) -> bool { + self == Self::ZERO + } + + #[inline] + fn compare(self, rhs: Self) -> Ordering { + self.cmp(&rhs) + } + + #[inline] + fn is_eq(self, rhs: Self) -> bool { + self == rhs + } + } + }; +} + +native_type_op!(i8); +native_type_op!(i16); +native_type_op!(i32); +native_type_op!(i64); +native_type_op!(i128); +native_type_op!(u8); +native_type_op!(u16); +native_type_op!(u32); +native_type_op!(u64); +native_type_op!(i256, i256::ZERO, i256::ONE, i256::MIN, i256::MAX); + +macro_rules! native_type_float_op { + ($t:tt, $zero:expr, $one:expr, $min:expr, $max:expr) => { + impl ArrowNativeTypeOp for $t { + const ZERO: Self = $zero; + const ONE: Self = $one; + const MIN_TOTAL_ORDER: Self = $min; + const MAX_TOTAL_ORDER: Self = $max; + + #[inline] + fn add_checked(self, rhs: Self) -> Result { + Ok(self + rhs) + } + + #[inline] + fn add_wrapping(self, rhs: Self) -> Self { + self + rhs + } + + #[inline] + fn sub_checked(self, rhs: Self) -> Result { + Ok(self - rhs) + } + + #[inline] + fn sub_wrapping(self, rhs: Self) -> Self { + self - rhs + } + + #[inline] + fn mul_checked(self, rhs: Self) -> Result { + Ok(self * rhs) + } + + #[inline] + fn mul_wrapping(self, rhs: Self) -> Self { + self * rhs + } + + #[inline] + fn div_checked(self, rhs: Self) -> Result { + if rhs.is_zero() { + Err(ArrowError::DivideByZero) + } else { + Ok(self / rhs) + } + } + + #[inline] + fn div_wrapping(self, rhs: Self) -> Self { + self / rhs + } + + #[inline] + fn mod_checked(self, rhs: Self) -> Result { + if rhs.is_zero() { + Err(ArrowError::DivideByZero) + } else { + Ok(self % rhs) + } + } + + #[inline] + fn mod_wrapping(self, rhs: Self) -> Self { + self % rhs + } + + #[inline] + fn neg_checked(self) -> Result { + Ok(-self) + } + + #[inline] + fn neg_wrapping(self) -> Self { + -self + } + + #[inline] + fn pow_checked(self, exp: u32) -> Result { + Ok(self.powi(exp as i32)) + } + + #[inline] + fn pow_wrapping(self, exp: u32) -> Self { + self.powi(exp as i32) + } + + #[inline] + fn is_zero(self) -> bool { + self == $zero + } + + #[inline] + fn compare(self, rhs: Self) -> Ordering { + <$t>::total_cmp(&self, &rhs) + } + + #[inline] + fn is_eq(self, rhs: Self) -> bool { + // Equivalent to `self.total_cmp(&rhs).is_eq()` + // but LLVM isn't able to realise this is bitwise equality + // https://rust.godbolt.org/z/347nWGxoW + self.to_bits() == rhs.to_bits() + } + } + }; +} + +// the smallest/largest bit patterns for floating point numbers are NaN, but differ from the canonical NAN constants. +// See test_float_total_order_min_max for details. +native_type_float_op!( + f16, + f16::ZERO, + f16::ONE, + f16::from_bits(-1 as _), + f16::from_bits(i16::MAX as _) +); +// from_bits is not yet stable as const fn, see https://github.com/rust-lang/rust/issues/72447 +native_type_float_op!( + f32, + 0., + 1., + unsafe { std::mem::transmute(-1_i32) }, + unsafe { std::mem::transmute(i32::MAX) } +); +native_type_float_op!( + f64, + 0., + 1., + unsafe { std::mem::transmute(-1_i64) }, + unsafe { std::mem::transmute(i64::MAX) } +); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_native_type_is_zero() { + assert!(0_i8.is_zero()); + assert!(0_i16.is_zero()); + assert!(0_i32.is_zero()); + assert!(0_i64.is_zero()); + assert!(0_i128.is_zero()); + assert!(i256::ZERO.is_zero()); + assert!(0_u8.is_zero()); + assert!(0_u16.is_zero()); + assert!(0_u32.is_zero()); + assert!(0_u64.is_zero()); + assert!(f16::ZERO.is_zero()); + assert!(0.0_f32.is_zero()); + assert!(0.0_f64.is_zero()); + } + + #[test] + fn test_native_type_comparison() { + // is_eq + assert!(8_i8.is_eq(8_i8)); + assert!(8_i16.is_eq(8_i16)); + assert!(8_i32.is_eq(8_i32)); + assert!(8_i64.is_eq(8_i64)); + assert!(8_i128.is_eq(8_i128)); + assert!(i256::from_parts(8, 0).is_eq(i256::from_parts(8, 0))); + assert!(8_u8.is_eq(8_u8)); + assert!(8_u16.is_eq(8_u16)); + assert!(8_u32.is_eq(8_u32)); + assert!(8_u64.is_eq(8_u64)); + assert!(f16::from_f32(8.0).is_eq(f16::from_f32(8.0))); + assert!(8.0_f32.is_eq(8.0_f32)); + assert!(8.0_f64.is_eq(8.0_f64)); + + // is_ne + assert!(8_i8.is_ne(1_i8)); + assert!(8_i16.is_ne(1_i16)); + assert!(8_i32.is_ne(1_i32)); + assert!(8_i64.is_ne(1_i64)); + assert!(8_i128.is_ne(1_i128)); + assert!(i256::from_parts(8, 0).is_ne(i256::from_parts(1, 0))); + assert!(8_u8.is_ne(1_u8)); + assert!(8_u16.is_ne(1_u16)); + assert!(8_u32.is_ne(1_u32)); + assert!(8_u64.is_ne(1_u64)); + assert!(f16::from_f32(8.0).is_ne(f16::from_f32(1.0))); + assert!(8.0_f32.is_ne(1.0_f32)); + assert!(8.0_f64.is_ne(1.0_f64)); + + // is_lt + assert!(8_i8.is_lt(10_i8)); + assert!(8_i16.is_lt(10_i16)); + assert!(8_i32.is_lt(10_i32)); + assert!(8_i64.is_lt(10_i64)); + assert!(8_i128.is_lt(10_i128)); + assert!(i256::from_parts(8, 0).is_lt(i256::from_parts(10, 0))); + assert!(8_u8.is_lt(10_u8)); + assert!(8_u16.is_lt(10_u16)); + assert!(8_u32.is_lt(10_u32)); + assert!(8_u64.is_lt(10_u64)); + assert!(f16::from_f32(8.0).is_lt(f16::from_f32(10.0))); + assert!(8.0_f32.is_lt(10.0_f32)); + assert!(8.0_f64.is_lt(10.0_f64)); + + // is_gt + assert!(8_i8.is_gt(1_i8)); + assert!(8_i16.is_gt(1_i16)); + assert!(8_i32.is_gt(1_i32)); + assert!(8_i64.is_gt(1_i64)); + assert!(8_i128.is_gt(1_i128)); + assert!(i256::from_parts(8, 0).is_gt(i256::from_parts(1, 0))); + assert!(8_u8.is_gt(1_u8)); + assert!(8_u16.is_gt(1_u16)); + assert!(8_u32.is_gt(1_u32)); + assert!(8_u64.is_gt(1_u64)); + assert!(f16::from_f32(8.0).is_gt(f16::from_f32(1.0))); + assert!(8.0_f32.is_gt(1.0_f32)); + assert!(8.0_f64.is_gt(1.0_f64)); + } + + #[test] + fn test_native_type_add() { + // add_wrapping + assert_eq!(8_i8.add_wrapping(2_i8), 10_i8); + assert_eq!(8_i16.add_wrapping(2_i16), 10_i16); + assert_eq!(8_i32.add_wrapping(2_i32), 10_i32); + assert_eq!(8_i64.add_wrapping(2_i64), 10_i64); + assert_eq!(8_i128.add_wrapping(2_i128), 10_i128); + assert_eq!( + i256::from_parts(8, 0).add_wrapping(i256::from_parts(2, 0)), + i256::from_parts(10, 0) + ); + assert_eq!(8_u8.add_wrapping(2_u8), 10_u8); + assert_eq!(8_u16.add_wrapping(2_u16), 10_u16); + assert_eq!(8_u32.add_wrapping(2_u32), 10_u32); + assert_eq!(8_u64.add_wrapping(2_u64), 10_u64); + assert_eq!( + f16::from_f32(8.0).add_wrapping(f16::from_f32(2.0)), + f16::from_f32(10.0) + ); + assert_eq!(8.0_f32.add_wrapping(2.0_f32), 10_f32); + assert_eq!(8.0_f64.add_wrapping(2.0_f64), 10_f64); + + // add_checked + assert_eq!(8_i8.add_checked(2_i8).unwrap(), 10_i8); + assert_eq!(8_i16.add_checked(2_i16).unwrap(), 10_i16); + assert_eq!(8_i32.add_checked(2_i32).unwrap(), 10_i32); + assert_eq!(8_i64.add_checked(2_i64).unwrap(), 10_i64); + assert_eq!(8_i128.add_checked(2_i128).unwrap(), 10_i128); + assert_eq!( + i256::from_parts(8, 0) + .add_checked(i256::from_parts(2, 0)) + .unwrap(), + i256::from_parts(10, 0) + ); + assert_eq!(8_u8.add_checked(2_u8).unwrap(), 10_u8); + assert_eq!(8_u16.add_checked(2_u16).unwrap(), 10_u16); + assert_eq!(8_u32.add_checked(2_u32).unwrap(), 10_u32); + assert_eq!(8_u64.add_checked(2_u64).unwrap(), 10_u64); + assert_eq!( + f16::from_f32(8.0).add_checked(f16::from_f32(2.0)).unwrap(), + f16::from_f32(10.0) + ); + assert_eq!(8.0_f32.add_checked(2.0_f32).unwrap(), 10_f32); + assert_eq!(8.0_f64.add_checked(2.0_f64).unwrap(), 10_f64); + } + + #[test] + fn test_native_type_sub() { + // sub_wrapping + assert_eq!(8_i8.sub_wrapping(2_i8), 6_i8); + assert_eq!(8_i16.sub_wrapping(2_i16), 6_i16); + assert_eq!(8_i32.sub_wrapping(2_i32), 6_i32); + assert_eq!(8_i64.sub_wrapping(2_i64), 6_i64); + assert_eq!(8_i128.sub_wrapping(2_i128), 6_i128); + assert_eq!( + i256::from_parts(8, 0).sub_wrapping(i256::from_parts(2, 0)), + i256::from_parts(6, 0) + ); + assert_eq!(8_u8.sub_wrapping(2_u8), 6_u8); + assert_eq!(8_u16.sub_wrapping(2_u16), 6_u16); + assert_eq!(8_u32.sub_wrapping(2_u32), 6_u32); + assert_eq!(8_u64.sub_wrapping(2_u64), 6_u64); + assert_eq!( + f16::from_f32(8.0).sub_wrapping(f16::from_f32(2.0)), + f16::from_f32(6.0) + ); + assert_eq!(8.0_f32.sub_wrapping(2.0_f32), 6_f32); + assert_eq!(8.0_f64.sub_wrapping(2.0_f64), 6_f64); + + // sub_checked + assert_eq!(8_i8.sub_checked(2_i8).unwrap(), 6_i8); + assert_eq!(8_i16.sub_checked(2_i16).unwrap(), 6_i16); + assert_eq!(8_i32.sub_checked(2_i32).unwrap(), 6_i32); + assert_eq!(8_i64.sub_checked(2_i64).unwrap(), 6_i64); + assert_eq!(8_i128.sub_checked(2_i128).unwrap(), 6_i128); + assert_eq!( + i256::from_parts(8, 0) + .sub_checked(i256::from_parts(2, 0)) + .unwrap(), + i256::from_parts(6, 0) + ); + assert_eq!(8_u8.sub_checked(2_u8).unwrap(), 6_u8); + assert_eq!(8_u16.sub_checked(2_u16).unwrap(), 6_u16); + assert_eq!(8_u32.sub_checked(2_u32).unwrap(), 6_u32); + assert_eq!(8_u64.sub_checked(2_u64).unwrap(), 6_u64); + assert_eq!( + f16::from_f32(8.0).sub_checked(f16::from_f32(2.0)).unwrap(), + f16::from_f32(6.0) + ); + assert_eq!(8.0_f32.sub_checked(2.0_f32).unwrap(), 6_f32); + assert_eq!(8.0_f64.sub_checked(2.0_f64).unwrap(), 6_f64); + } + + #[test] + fn test_native_type_mul() { + // mul_wrapping + assert_eq!(8_i8.mul_wrapping(2_i8), 16_i8); + assert_eq!(8_i16.mul_wrapping(2_i16), 16_i16); + assert_eq!(8_i32.mul_wrapping(2_i32), 16_i32); + assert_eq!(8_i64.mul_wrapping(2_i64), 16_i64); + assert_eq!(8_i128.mul_wrapping(2_i128), 16_i128); + assert_eq!( + i256::from_parts(8, 0).mul_wrapping(i256::from_parts(2, 0)), + i256::from_parts(16, 0) + ); + assert_eq!(8_u8.mul_wrapping(2_u8), 16_u8); + assert_eq!(8_u16.mul_wrapping(2_u16), 16_u16); + assert_eq!(8_u32.mul_wrapping(2_u32), 16_u32); + assert_eq!(8_u64.mul_wrapping(2_u64), 16_u64); + assert_eq!( + f16::from_f32(8.0).mul_wrapping(f16::from_f32(2.0)), + f16::from_f32(16.0) + ); + assert_eq!(8.0_f32.mul_wrapping(2.0_f32), 16_f32); + assert_eq!(8.0_f64.mul_wrapping(2.0_f64), 16_f64); + + // mul_checked + assert_eq!(8_i8.mul_checked(2_i8).unwrap(), 16_i8); + assert_eq!(8_i16.mul_checked(2_i16).unwrap(), 16_i16); + assert_eq!(8_i32.mul_checked(2_i32).unwrap(), 16_i32); + assert_eq!(8_i64.mul_checked(2_i64).unwrap(), 16_i64); + assert_eq!(8_i128.mul_checked(2_i128).unwrap(), 16_i128); + assert_eq!( + i256::from_parts(8, 0) + .mul_checked(i256::from_parts(2, 0)) + .unwrap(), + i256::from_parts(16, 0) + ); + assert_eq!(8_u8.mul_checked(2_u8).unwrap(), 16_u8); + assert_eq!(8_u16.mul_checked(2_u16).unwrap(), 16_u16); + assert_eq!(8_u32.mul_checked(2_u32).unwrap(), 16_u32); + assert_eq!(8_u64.mul_checked(2_u64).unwrap(), 16_u64); + assert_eq!( + f16::from_f32(8.0).mul_checked(f16::from_f32(2.0)).unwrap(), + f16::from_f32(16.0) + ); + assert_eq!(8.0_f32.mul_checked(2.0_f32).unwrap(), 16_f32); + assert_eq!(8.0_f64.mul_checked(2.0_f64).unwrap(), 16_f64); + } + + #[test] + fn test_native_type_div() { + // div_wrapping + assert_eq!(8_i8.div_wrapping(2_i8), 4_i8); + assert_eq!(8_i16.div_wrapping(2_i16), 4_i16); + assert_eq!(8_i32.div_wrapping(2_i32), 4_i32); + assert_eq!(8_i64.div_wrapping(2_i64), 4_i64); + assert_eq!(8_i128.div_wrapping(2_i128), 4_i128); + assert_eq!( + i256::from_parts(8, 0).div_wrapping(i256::from_parts(2, 0)), + i256::from_parts(4, 0) + ); + assert_eq!(8_u8.div_wrapping(2_u8), 4_u8); + assert_eq!(8_u16.div_wrapping(2_u16), 4_u16); + assert_eq!(8_u32.div_wrapping(2_u32), 4_u32); + assert_eq!(8_u64.div_wrapping(2_u64), 4_u64); + assert_eq!( + f16::from_f32(8.0).div_wrapping(f16::from_f32(2.0)), + f16::from_f32(4.0) + ); + assert_eq!(8.0_f32.div_wrapping(2.0_f32), 4_f32); + assert_eq!(8.0_f64.div_wrapping(2.0_f64), 4_f64); + + // div_checked + assert_eq!(8_i8.div_checked(2_i8).unwrap(), 4_i8); + assert_eq!(8_i16.div_checked(2_i16).unwrap(), 4_i16); + assert_eq!(8_i32.div_checked(2_i32).unwrap(), 4_i32); + assert_eq!(8_i64.div_checked(2_i64).unwrap(), 4_i64); + assert_eq!(8_i128.div_checked(2_i128).unwrap(), 4_i128); + assert_eq!( + i256::from_parts(8, 0) + .div_checked(i256::from_parts(2, 0)) + .unwrap(), + i256::from_parts(4, 0) + ); + assert_eq!(8_u8.div_checked(2_u8).unwrap(), 4_u8); + assert_eq!(8_u16.div_checked(2_u16).unwrap(), 4_u16); + assert_eq!(8_u32.div_checked(2_u32).unwrap(), 4_u32); + assert_eq!(8_u64.div_checked(2_u64).unwrap(), 4_u64); + assert_eq!( + f16::from_f32(8.0).div_checked(f16::from_f32(2.0)).unwrap(), + f16::from_f32(4.0) + ); + assert_eq!(8.0_f32.div_checked(2.0_f32).unwrap(), 4_f32); + assert_eq!(8.0_f64.div_checked(2.0_f64).unwrap(), 4_f64); + } + + #[test] + fn test_native_type_mod() { + // mod_wrapping + assert_eq!(9_i8.mod_wrapping(2_i8), 1_i8); + assert_eq!(9_i16.mod_wrapping(2_i16), 1_i16); + assert_eq!(9_i32.mod_wrapping(2_i32), 1_i32); + assert_eq!(9_i64.mod_wrapping(2_i64), 1_i64); + assert_eq!(9_i128.mod_wrapping(2_i128), 1_i128); + assert_eq!( + i256::from_parts(9, 0).mod_wrapping(i256::from_parts(2, 0)), + i256::from_parts(1, 0) + ); + assert_eq!(9_u8.mod_wrapping(2_u8), 1_u8); + assert_eq!(9_u16.mod_wrapping(2_u16), 1_u16); + assert_eq!(9_u32.mod_wrapping(2_u32), 1_u32); + assert_eq!(9_u64.mod_wrapping(2_u64), 1_u64); + assert_eq!( + f16::from_f32(9.0).mod_wrapping(f16::from_f32(2.0)), + f16::from_f32(1.0) + ); + assert_eq!(9.0_f32.mod_wrapping(2.0_f32), 1_f32); + assert_eq!(9.0_f64.mod_wrapping(2.0_f64), 1_f64); + + // mod_checked + assert_eq!(9_i8.mod_checked(2_i8).unwrap(), 1_i8); + assert_eq!(9_i16.mod_checked(2_i16).unwrap(), 1_i16); + assert_eq!(9_i32.mod_checked(2_i32).unwrap(), 1_i32); + assert_eq!(9_i64.mod_checked(2_i64).unwrap(), 1_i64); + assert_eq!(9_i128.mod_checked(2_i128).unwrap(), 1_i128); + assert_eq!( + i256::from_parts(9, 0) + .mod_checked(i256::from_parts(2, 0)) + .unwrap(), + i256::from_parts(1, 0) + ); + assert_eq!(9_u8.mod_checked(2_u8).unwrap(), 1_u8); + assert_eq!(9_u16.mod_checked(2_u16).unwrap(), 1_u16); + assert_eq!(9_u32.mod_checked(2_u32).unwrap(), 1_u32); + assert_eq!(9_u64.mod_checked(2_u64).unwrap(), 1_u64); + assert_eq!( + f16::from_f32(9.0).mod_checked(f16::from_f32(2.0)).unwrap(), + f16::from_f32(1.0) + ); + assert_eq!(9.0_f32.mod_checked(2.0_f32).unwrap(), 1_f32); + assert_eq!(9.0_f64.mod_checked(2.0_f64).unwrap(), 1_f64); + } + + #[test] + fn test_native_type_neg() { + // neg_wrapping + assert_eq!(8_i8.neg_wrapping(), -8_i8); + assert_eq!(8_i16.neg_wrapping(), -8_i16); + assert_eq!(8_i32.neg_wrapping(), -8_i32); + assert_eq!(8_i64.neg_wrapping(), -8_i64); + assert_eq!(8_i128.neg_wrapping(), -8_i128); + assert_eq!(i256::from_parts(8, 0).neg_wrapping(), i256::from_i128(-8)); + assert_eq!(8_u8.neg_wrapping(), u8::MAX - 7_u8); + assert_eq!(8_u16.neg_wrapping(), u16::MAX - 7_u16); + assert_eq!(8_u32.neg_wrapping(), u32::MAX - 7_u32); + assert_eq!(8_u64.neg_wrapping(), u64::MAX - 7_u64); + assert_eq!(f16::from_f32(8.0).neg_wrapping(), f16::from_f32(-8.0)); + assert_eq!(8.0_f32.neg_wrapping(), -8_f32); + assert_eq!(8.0_f64.neg_wrapping(), -8_f64); + + // neg_checked + assert_eq!(8_i8.neg_checked().unwrap(), -8_i8); + assert_eq!(8_i16.neg_checked().unwrap(), -8_i16); + assert_eq!(8_i32.neg_checked().unwrap(), -8_i32); + assert_eq!(8_i64.neg_checked().unwrap(), -8_i64); + assert_eq!(8_i128.neg_checked().unwrap(), -8_i128); + assert_eq!( + i256::from_parts(8, 0).neg_checked().unwrap(), + i256::from_i128(-8) + ); + assert!(8_u8.neg_checked().is_err()); + assert!(8_u16.neg_checked().is_err()); + assert!(8_u32.neg_checked().is_err()); + assert!(8_u64.neg_checked().is_err()); + assert_eq!( + f16::from_f32(8.0).neg_checked().unwrap(), + f16::from_f32(-8.0) + ); + assert_eq!(8.0_f32.neg_checked().unwrap(), -8_f32); + assert_eq!(8.0_f64.neg_checked().unwrap(), -8_f64); + } + + #[test] + fn test_native_type_pow() { + // pow_wrapping + assert_eq!(8_i8.pow_wrapping(2_u32), 64_i8); + assert_eq!(8_i16.pow_wrapping(2_u32), 64_i16); + assert_eq!(8_i32.pow_wrapping(2_u32), 64_i32); + assert_eq!(8_i64.pow_wrapping(2_u32), 64_i64); + assert_eq!(8_i128.pow_wrapping(2_u32), 64_i128); + assert_eq!( + i256::from_parts(8, 0).pow_wrapping(2_u32), + i256::from_parts(64, 0) + ); + assert_eq!(8_u8.pow_wrapping(2_u32), 64_u8); + assert_eq!(8_u16.pow_wrapping(2_u32), 64_u16); + assert_eq!(8_u32.pow_wrapping(2_u32), 64_u32); + assert_eq!(8_u64.pow_wrapping(2_u32), 64_u64); + assert_eq!(f16::from_f32(8.0).pow_wrapping(2_u32), f16::from_f32(64.0)); + assert_eq!(8.0_f32.pow_wrapping(2_u32), 64_f32); + assert_eq!(8.0_f64.pow_wrapping(2_u32), 64_f64); + + // pow_checked + assert_eq!(8_i8.pow_checked(2_u32).unwrap(), 64_i8); + assert_eq!(8_i16.pow_checked(2_u32).unwrap(), 64_i16); + assert_eq!(8_i32.pow_checked(2_u32).unwrap(), 64_i32); + assert_eq!(8_i64.pow_checked(2_u32).unwrap(), 64_i64); + assert_eq!(8_i128.pow_checked(2_u32).unwrap(), 64_i128); + assert_eq!( + i256::from_parts(8, 0).pow_checked(2_u32).unwrap(), + i256::from_parts(64, 0) + ); + assert_eq!(8_u8.pow_checked(2_u32).unwrap(), 64_u8); + assert_eq!(8_u16.pow_checked(2_u32).unwrap(), 64_u16); + assert_eq!(8_u32.pow_checked(2_u32).unwrap(), 64_u32); + assert_eq!(8_u64.pow_checked(2_u32).unwrap(), 64_u64); + assert_eq!( + f16::from_f32(8.0).pow_checked(2_u32).unwrap(), + f16::from_f32(64.0) + ); + assert_eq!(8.0_f32.pow_checked(2_u32).unwrap(), 64_f32); + assert_eq!(8.0_f64.pow_checked(2_u32).unwrap(), 64_f64); + } + + #[test] + fn test_float_total_order_min_max() { + assert!(::MIN_TOTAL_ORDER.is_lt(f64::NEG_INFINITY)); + assert!(::MAX_TOTAL_ORDER.is_gt(f64::INFINITY)); + + assert!(::MIN_TOTAL_ORDER.is_nan()); + assert!(::MIN_TOTAL_ORDER.is_sign_negative()); + assert!(::MIN_TOTAL_ORDER.is_lt(-f64::NAN)); + + assert!(::MAX_TOTAL_ORDER.is_nan()); + assert!(::MAX_TOTAL_ORDER.is_sign_positive()); + assert!(::MAX_TOTAL_ORDER.is_gt(f64::NAN)); + + assert!(::MIN_TOTAL_ORDER.is_lt(f32::NEG_INFINITY)); + assert!(::MAX_TOTAL_ORDER.is_gt(f32::INFINITY)); + + assert!(::MIN_TOTAL_ORDER.is_nan()); + assert!(::MIN_TOTAL_ORDER.is_sign_negative()); + assert!(::MIN_TOTAL_ORDER.is_lt(-f32::NAN)); + + assert!(::MAX_TOTAL_ORDER.is_nan()); + assert!(::MAX_TOTAL_ORDER.is_sign_positive()); + assert!(::MAX_TOTAL_ORDER.is_gt(f32::NAN)); + + assert!(::MIN_TOTAL_ORDER.is_lt(f16::NEG_INFINITY)); + assert!(::MAX_TOTAL_ORDER.is_gt(f16::INFINITY)); + + assert!(::MIN_TOTAL_ORDER.is_nan()); + assert!(::MIN_TOTAL_ORDER.is_sign_negative()); + assert!(::MIN_TOTAL_ORDER.is_lt(-f16::NAN)); + + assert!(::MAX_TOTAL_ORDER.is_nan()); + assert!(::MAX_TOTAL_ORDER.is_sign_positive()); + assert!(::MAX_TOTAL_ORDER.is_gt(f16::NAN)); + } +} diff --git a/arrow/src/array/array_binary.rs b/arrow-array/src/array/binary_array.rs similarity index 60% rename from arrow/src/array/array_binary.rs rename to arrow-array/src/array/binary_array.rs index 1c63e8e24b29..6b18cbc2d9f7 100644 --- a/arrow/src/array/array_binary.rs +++ b/arrow-array/src/array/binary_array.rs @@ -15,120 +15,21 @@ // specific language governing permissions and limitations // under the License. -use std::convert::From; -use std::fmt; -use std::{any::Any, iter::FromIterator}; - -use super::{ - array::print_long_array, raw_pointer::RawPtrBox, Array, ArrayData, GenericBinaryIter, - GenericListArray, OffsetSizeTrait, -}; -use crate::array::array::ArrayAccessor; -use crate::buffer::Buffer; -use crate::util::bit_util; -use crate::{buffer::MutableBuffer, datatypes::DataType}; - -/// See [`BinaryArray`] and [`LargeBinaryArray`] for storing -/// binary data. -pub struct GenericBinaryArray { - data: ArrayData, - value_offsets: RawPtrBox, - value_data: RawPtrBox, -} +use crate::types::{ByteArrayType, GenericBinaryType}; +use crate::{Array, GenericByteArray, GenericListArray, GenericStringArray, OffsetSizeTrait}; +use arrow_data::ArrayData; +use arrow_schema::DataType; -impl GenericBinaryArray { - /// Data type of the array. - pub const DATA_TYPE: DataType = if OffsetSize::IS_LARGE { - DataType::LargeBinary - } else { - DataType::Binary - }; +/// A [`GenericBinaryArray`] for storing `[u8]` +pub type GenericBinaryArray = GenericByteArray>; +impl GenericBinaryArray { /// Get the data type of the array. #[deprecated(note = "please use `Self::DATA_TYPE` instead")] pub const fn get_data_type() -> DataType { Self::DATA_TYPE } - /// Returns the length for value at index `i`. - #[inline] - pub fn value_length(&self, i: usize) -> OffsetSize { - let offsets = self.value_offsets(); - offsets[i + 1] - offsets[i] - } - - /// Returns a clone of the value data buffer - pub fn value_data(&self) -> Buffer { - self.data.buffers()[1].clone() - } - - /// Returns the offset values in the offsets buffer - #[inline] - pub fn value_offsets(&self) -> &[OffsetSize] { - // Soundness - // pointer alignment & location is ensured by RawPtrBox - // buffer bounds/offset is ensured by the ArrayData instance. - unsafe { - std::slice::from_raw_parts( - self.value_offsets.as_ptr().add(self.data.offset()), - self.len() + 1, - ) - } - } - - /// Returns the element at index `i` as bytes slice - /// # Safety - /// Caller is responsible for ensuring that the index is within the bounds of the array - pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { - let end = *self.value_offsets().get_unchecked(i + 1); - let start = *self.value_offsets().get_unchecked(i); - - // Soundness - // pointer alignment & location is ensured by RawPtrBox - // buffer bounds/offset is ensured by the value_offset invariants - - // Safety of `to_isize().unwrap()` - // `start` and `end` are &OffsetSize, which is a generic type that implements the - // OffsetSizeTrait. Currently, only i32 and i64 implement OffsetSizeTrait, - // both of which should cleanly cast to isize on an architecture that supports - // 32/64-bit offsets - std::slice::from_raw_parts( - self.value_data.as_ptr().offset(start.to_isize().unwrap()), - (end - start).to_usize().unwrap(), - ) - } - - /// Returns the element at index `i` as bytes slice - /// # Panics - /// Panics if index `i` is out of bounds. - pub fn value(&self, i: usize) -> &[u8] { - assert!( - i < self.data.len(), - "Trying to access an element at index {} from a BinaryArray of length {}", - i, - self.len() - ); - //Soundness: length checked above, offset buffer length is 1 larger than logical array length - let end = unsafe { self.value_offsets().get_unchecked(i + 1) }; - let start = unsafe { self.value_offsets().get_unchecked(i) }; - - // Soundness - // pointer alignment & location is ensured by RawPtrBox - // buffer bounds/offset is ensured by the value_offset invariants - - // Safety of `to_isize().unwrap()` - // `start` and `end` are &OffsetSize, which is a generic type that implements the - // OffsetSizeTrait. Currently, only i32 and i64 implement OffsetSizeTrait, - // both of which should cleanly cast to isize on an architecture that supports - // 32/64-bit offsets - unsafe { - std::slice::from_raw_parts( - self.value_data.as_ptr().offset(start.to_isize().unwrap()), - (*end - *start).to_usize().unwrap(), - ) - } - } - /// Creates a [GenericBinaryArray] from a vector of byte slices /// /// See also [`Self::from_iter_values`] @@ -142,13 +43,14 @@ impl GenericBinaryArray { } fn from_list(v: GenericListArray) -> Self { + let v = v.into_data(); assert_eq!( - v.data_ref().child_data().len(), + v.child_data().len(), 1, "BinaryArray can only be created from list array of u8 values \ (i.e. List>)." ); - let child_data = &v.data_ref().child_data()[0]; + let child_data = &v.child_data()[0]; assert_eq!( child_data.child_data().len(), @@ -170,50 +72,14 @@ impl GenericBinaryArray { let builder = ArrayData::builder(Self::DATA_TYPE) .len(v.len()) .offset(v.offset()) - .add_buffer(v.data_ref().buffers()[0].clone()) + .add_buffer(v.buffers()[0].clone()) .add_buffer(child_data.buffers()[0].slice(child_data.offset())) - .null_bit_buffer(v.data_ref().null_buffer().cloned()); + .nulls(v.nulls().cloned()); let data = unsafe { builder.build_unchecked() }; Self::from(data) } - /// Creates a [`GenericBinaryArray`] based on an iterator of values without nulls - pub fn from_iter_values(iter: I) -> Self - where - Ptr: AsRef<[u8]>, - I: IntoIterator, - { - let iter = iter.into_iter(); - let (_, data_len) = iter.size_hint(); - let data_len = data_len.expect("Iterator must be sized"); // panic if no upper bound. - - let mut offsets = - MutableBuffer::new((data_len + 1) * std::mem::size_of::()); - let mut values = MutableBuffer::new(0); - - let mut length_so_far = OffsetSize::zero(); - offsets.push(length_so_far); - - for s in iter { - let s = s.as_ref(); - length_so_far += OffsetSize::from_usize(s.len()).unwrap(); - offsets.push(length_so_far); - values.extend_from_slice(s); - } - - // iterator size hint may not be correct so compute the actual number of offsets - assert!(!offsets.is_empty()); // wrote at least one - let actual_len = (offsets.len() / std::mem::size_of::()) - 1; - - let array_data = ArrayData::builder(Self::DATA_TYPE) - .len(actual_len) - .add_buffer(offsets.into()) - .add_buffer(values.into()); - let array_data = unsafe { array_data.build_unchecked() }; - Self::from(array_data) - } - /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` pub fn take_iter<'a>( &'a self, @@ -232,84 +98,9 @@ impl GenericBinaryArray { ) -> impl Iterator> + 'a { indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) } - - /// constructs a new iterator - pub fn iter(&self) -> GenericBinaryIter<'_, OffsetSize> { - GenericBinaryIter::<'_, OffsetSize>::new(self) - } -} - -impl fmt::Debug for GenericBinaryArray { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let prefix = OffsetSize::PREFIX; - - write!(f, "{}BinaryArray\n[\n", prefix)?; - print_long_array(self, f, |array, index, f| { - fmt::Debug::fmt(&array.value(index), f) - })?; - write!(f, "]") - } -} - -impl Array for GenericBinaryArray { - fn as_any(&self) -> &dyn Any { - self - } - - fn data(&self) -> &ArrayData { - &self.data - } - - fn into_data(self) -> ArrayData { - self.into() - } } -impl<'a, OffsetSize: OffsetSizeTrait> ArrayAccessor - for &'a GenericBinaryArray -{ - type Item = &'a [u8]; - - fn value(&self, index: usize) -> Self::Item { - GenericBinaryArray::value(self, index) - } - - unsafe fn value_unchecked(&self, index: usize) -> Self::Item { - GenericBinaryArray::value_unchecked(self, index) - } -} - -impl From for GenericBinaryArray { - fn from(data: ArrayData) -> Self { - assert_eq!( - data.data_type(), - &Self::DATA_TYPE, - "[Large]BinaryArray expects Datatype::[Large]Binary" - ); - assert_eq!( - data.buffers().len(), - 2, - "BinaryArray data should contain 2 buffers only (offsets and values)" - ); - let offsets = data.buffers()[0].as_ptr(); - let values = data.buffers()[1].as_ptr(); - Self { - data, - value_offsets: unsafe { RawPtrBox::new(offsets) }, - value_data: unsafe { RawPtrBox::new(values) }, - } - } -} - -impl From> for ArrayData { - fn from(array: GenericBinaryArray) -> Self { - array.data - } -} - -impl From>> - for GenericBinaryArray -{ +impl From>> for GenericBinaryArray { fn from(v: Vec>) -> Self { Self::from_opt_vec(v) } @@ -327,59 +118,23 @@ impl From> for GenericBinaryArray { } } -impl FromIterator> +impl From> for GenericBinaryArray -where - Ptr: AsRef<[u8]>, { - fn from_iter>>(iter: I) -> Self { - let iter = iter.into_iter(); - let (_, data_len) = iter.size_hint(); - let data_len = data_len.expect("Iterator must be sized"); // panic if no upper bound. - - let mut offsets = Vec::with_capacity(data_len + 1); - let mut values = Vec::new(); - let mut null_buf = MutableBuffer::new_null(data_len); - let mut length_so_far: OffsetSize = OffsetSize::zero(); - offsets.push(length_so_far); - - { - let null_slice = null_buf.as_slice_mut(); - - for (i, s) in iter.enumerate() { - if let Some(s) = s { - let s = s.as_ref(); - bit_util::set_bit(null_slice, i); - length_so_far += OffsetSize::from_usize(s.len()).unwrap(); - values.extend_from_slice(s); - } - // always add an element in offsets - offsets.push(length_so_far); - } - } + fn from(value: GenericStringArray) -> Self { + let builder = value + .into_data() + .into_builder() + .data_type(GenericBinaryType::::DATA_TYPE); - // calculate actual data_len, which may be different from the iterator's upper bound - let data_len = offsets.len() - 1; - let array_data = ArrayData::builder(Self::DATA_TYPE) - .len(data_len) - .add_buffer(Buffer::from_slice_ref(&offsets)) - .add_buffer(Buffer::from_slice_ref(&values)) - .null_bit_buffer(Some(null_buf.into())); - let array_data = unsafe { array_data.build_unchecked() }; - Self::from(array_data) + // Safety: + // A StringArray is a valid BinaryArray + Self::from(unsafe { builder.build_unchecked() }) } } -impl<'a, T: OffsetSizeTrait> IntoIterator for &'a GenericBinaryArray { - type Item = Option<&'a [u8]>; - type IntoIter = GenericBinaryIter<'a, T>; - - fn into_iter(self) -> Self::IntoIter { - GenericBinaryIter::<'a, T>::new(self) - } -} - -/// An array where each element contains 0 or more bytes. +/// A [`GenericBinaryArray`] of `[u8]` using `i32` offsets +/// /// The byte length of each element is represented by an i32. /// /// # Examples @@ -387,7 +142,7 @@ impl<'a, T: OffsetSizeTrait> IntoIterator for &'a GenericBinaryArray { /// Create a BinaryArray from a vector of byte slices. /// /// ``` -/// use arrow::array::{Array, BinaryArray}; +/// use arrow_array::{Array, BinaryArray}; /// let values: Vec<&[u8]> = /// vec![b"one", b"two", b"", b"three"]; /// let array = BinaryArray::from_vec(values); @@ -401,7 +156,7 @@ impl<'a, T: OffsetSizeTrait> IntoIterator for &'a GenericBinaryArray { /// Create a BinaryArray from a vector of Optional (null) byte slices. /// /// ``` -/// use arrow::array::{Array, BinaryArray}; +/// use arrow_array::{Array, BinaryArray}; /// let values: Vec> = /// vec![Some(b"one"), Some(b"two"), None, Some(b""), Some(b"three")]; /// let array = BinaryArray::from_opt_vec(values); @@ -417,17 +172,17 @@ impl<'a, T: OffsetSizeTrait> IntoIterator for &'a GenericBinaryArray { /// assert!(!array.is_null(4)); /// ``` /// +/// See [`GenericByteArray`] for more information and examples pub type BinaryArray = GenericBinaryArray; -/// An array where each element contains 0 or more bytes. -/// The byte length of each element is represented by an i64. +/// A [`GenericBinaryArray`] of `[u8]` using `i64` offsets /// /// # Examples /// /// Create a LargeBinaryArray from a vector of byte slices. /// /// ``` -/// use arrow::array::{Array, LargeBinaryArray}; +/// use arrow_array::{Array, LargeBinaryArray}; /// let values: Vec<&[u8]> = /// vec![b"one", b"two", b"", b"three"]; /// let array = LargeBinaryArray::from_vec(values); @@ -441,7 +196,7 @@ pub type BinaryArray = GenericBinaryArray; /// Create a LargeBinaryArray from a vector of Optional (null) byte slices. /// /// ``` -/// use arrow::array::{Array, LargeBinaryArray}; +/// use arrow_array::{Array, LargeBinaryArray}; /// let values: Vec> = /// vec![Some(b"one"), Some(b"two"), None, Some(b""), Some(b"three")]; /// let array = LargeBinaryArray::from_opt_vec(values); @@ -457,12 +212,16 @@ pub type BinaryArray = GenericBinaryArray; /// assert!(!array.is_null(4)); /// ``` /// +/// See [`GenericByteArray`] for more information and examples pub type LargeBinaryArray = GenericBinaryArray; #[cfg(test)] mod tests { use super::*; - use crate::{array::ListArray, datatypes::Field}; + use crate::{ListArray, StringArray}; + use arrow_buffer::Buffer; + use arrow_schema::Field; + use std::sync::Arc; #[test] fn test_binary_array() { @@ -474,8 +233,8 @@ mod tests { // Array data: ["hello", "", "parquet"] let array_data = ArrayData::builder(DataType::Binary) .len(3) - .add_buffer(Buffer::from_slice_ref(&offsets)) - .add_buffer(Buffer::from_slice_ref(&values)) + .add_buffer(Buffer::from_slice_ref(offsets)) + .add_buffer(Buffer::from_slice_ref(values)) .build() .unwrap(); let binary_array = BinaryArray::from(array_data); @@ -513,8 +272,8 @@ mod tests { let array_data = ArrayData::builder(DataType::Binary) .len(2) .offset(1) - .add_buffer(Buffer::from_slice_ref(&offsets)) - .add_buffer(Buffer::from_slice_ref(&values)) + .add_buffer(Buffer::from_slice_ref(offsets)) + .add_buffer(Buffer::from_slice_ref(values)) .build() .unwrap(); let binary_array = BinaryArray::from(array_data); @@ -538,8 +297,8 @@ mod tests { // Array data: ["hello", "", "parquet"] let array_data = ArrayData::builder(DataType::LargeBinary) .len(3) - .add_buffer(Buffer::from_slice_ref(&offsets)) - .add_buffer(Buffer::from_slice_ref(&values)) + .add_buffer(Buffer::from_slice_ref(offsets)) + .add_buffer(Buffer::from_slice_ref(values)) .build() .unwrap(); let binary_array = LargeBinaryArray::from(array_data); @@ -577,8 +336,8 @@ mod tests { let array_data = ArrayData::builder(DataType::LargeBinary) .len(2) .offset(1) - .add_buffer(Buffer::from_slice_ref(&offsets)) - .add_buffer(Buffer::from_slice_ref(&values)) + .add_buffer(Buffer::from_slice_ref(offsets)) + .add_buffer(Buffer::from_slice_ref(values)) .build() .unwrap(); let binary_array = LargeBinaryArray::from(array_data); @@ -607,28 +366,27 @@ mod tests { // Array data: ["hello", "", "parquet"] let array_data1 = ArrayData::builder(GenericBinaryArray::::DATA_TYPE) .len(3) - .add_buffer(Buffer::from_slice_ref(&offsets)) - .add_buffer(Buffer::from_slice_ref(&values)) + .add_buffer(Buffer::from_slice_ref(offsets)) + .add_buffer(Buffer::from_slice_ref(values)) .build() .unwrap(); let binary_array1 = GenericBinaryArray::::from(array_data1); - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Box::new( - Field::new("item", DataType::UInt8, false), - )); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( + "item", + DataType::UInt8, + false, + ))); let array_data2 = ArrayData::builder(data_type) .len(3) - .add_buffer(Buffer::from_slice_ref(&offsets)) + .add_buffer(Buffer::from_slice_ref(offsets)) .add_child_data(child_data) .build() .unwrap(); let list_array = GenericListArray::::from(array_data2); let binary_array2 = GenericBinaryArray::::from(list_array); - assert_eq!(2, binary_array2.data().buffers().len()); - assert_eq!(0, binary_array2.data().child_data().len()); - assert_eq!(binary_array1.len(), binary_array2.len()); assert_eq!(binary_array1.null_count(), binary_array2.null_count()); assert_eq!(binary_array1.value_offsets(), binary_array2.value_offsets()); @@ -662,16 +420,18 @@ mod tests { .unwrap(); let offsets = [0, 5, 8, 15].map(|n| O::from_usize(n).unwrap()); - let null_buffer = Buffer::from_slice_ref(&[0b101]); - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Box::new( - Field::new("item", DataType::UInt8, false), - )); + let null_buffer = Buffer::from_slice_ref([0b101]); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( + "item", + DataType::UInt8, + false, + ))); // [None, Some(b"Parquet")] let array_data = ArrayData::builder(data_type) .len(2) .offset(1) - .add_buffer(Buffer::from_slice_ref(&offsets)) + .add_buffer(Buffer::from_slice_ref(offsets)) .null_bit_buffer(Some(null_buffer)) .add_child_data(child_data) .build() @@ -696,26 +456,26 @@ mod tests { _test_generic_binary_array_from_list_array_with_offset::(); } - fn _test_generic_binary_array_from_list_array_with_child_nulls_failed< - O: OffsetSizeTrait, - >() { + fn _test_generic_binary_array_from_list_array_with_child_nulls_failed() { let values = b"HelloArrow"; let child_data = ArrayData::builder(DataType::UInt8) .len(10) .add_buffer(Buffer::from(&values[..])) - .null_bit_buffer(Some(Buffer::from_slice_ref(&[0b1010101010]))) + .null_bit_buffer(Some(Buffer::from_slice_ref([0b1010101010]))) .build() .unwrap(); let offsets = [0, 5, 10].map(|n| O::from_usize(n).unwrap()); - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Box::new( - Field::new("item", DataType::UInt8, false), - )); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( + "item", + DataType::UInt8, + true, + ))); // [None, Some(b"Parquet")] let array_data = ArrayData::builder(data_type) .len(2) - .add_buffer(Buffer::from_slice_ref(&offsets)) + .add_buffer(Buffer::from_slice_ref(offsets)) .add_child_data(child_data) .build() .unwrap(); @@ -768,7 +528,7 @@ mod tests { .scan(0usize, |pos, i| { if *pos < 10 { *pos += 1; - Some(Some(format!("value {}", i))) + Some(Some(format!("value {i}"))) } else { // actually returns up to 10 values None @@ -787,24 +547,21 @@ mod tests { #[test] #[should_panic( - expected = "assertion failed: `(left == right)`\n left: `UInt32`,\n \ - right: `UInt8`: BinaryArray can only be created from List arrays, \ - mismatched data types." + expected = "BinaryArray can only be created from List arrays, mismatched data types." )] fn test_binary_array_from_incorrect_list_array() { let values: [u32; 12] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; let values_data = ArrayData::builder(DataType::UInt32) .len(12) - .add_buffer(Buffer::from_slice_ref(&values)) + .add_buffer(Buffer::from_slice_ref(values)) .build() .unwrap(); let offsets: [i32; 4] = [0, 5, 5, 12]; - let data_type = - DataType::List(Box::new(Field::new("item", DataType::UInt32, false))); + let data_type = DataType::List(Arc::new(Field::new("item", DataType::UInt32, false))); let array_data = ArrayData::builder(data_type) .len(3) - .add_buffer(Buffer::from_slice_ref(&offsets)) + .add_buffer(Buffer::from_slice_ref(offsets)) .add_child_data(values_data) .build() .unwrap(); @@ -817,25 +574,31 @@ mod tests { expected = "Trying to access an element at index 4 from a BinaryArray of length 3" )] fn test_binary_array_get_value_index_out_of_bound() { - let values: [u8; 12] = - [104, 101, 108, 108, 111, 112, 97, 114, 113, 117, 101, 116]; + let values: [u8; 12] = [104, 101, 108, 108, 111, 112, 97, 114, 113, 117, 101, 116]; let offsets: [i32; 4] = [0, 5, 5, 12]; let array_data = ArrayData::builder(DataType::Binary) .len(3) - .add_buffer(Buffer::from_slice_ref(&offsets)) - .add_buffer(Buffer::from_slice_ref(&values)) + .add_buffer(Buffer::from_slice_ref(offsets)) + .add_buffer(Buffer::from_slice_ref(values)) .build() .unwrap(); let binary_array = BinaryArray::from(array_data); binary_array.value(4); } + #[test] + #[should_panic(expected = "LargeBinaryArray expects DataType::LargeBinary")] + fn test_binary_array_validation() { + let array = BinaryArray::from_iter_values([&[1, 2]]); + let _ = LargeBinaryArray::from(array.into_data()); + } + #[test] fn test_binary_array_all_null() { let data = vec![None]; let array = BinaryArray::from(data); array - .data() + .into_data() .validate_full() .expect("All null array has valid array data"); } @@ -845,8 +608,36 @@ mod tests { let data = vec![None]; let array = LargeBinaryArray::from(data); array - .data() + .into_data() .validate_full() .expect("All null array has valid array data"); } + + #[test] + fn test_empty_offsets() { + let string = BinaryArray::from( + ArrayData::builder(DataType::Binary) + .buffers(vec![Buffer::from(&[]), Buffer::from(&[])]) + .build() + .unwrap(), + ); + assert_eq!(string.value_offsets(), &[0]); + let string = LargeBinaryArray::from( + ArrayData::builder(DataType::LargeBinary) + .buffers(vec![Buffer::from(&[]), Buffer::from(&[])]) + .build() + .unwrap(), + ); + assert_eq!(string.len(), 0); + assert_eq!(string.value_offsets(), &[0]); + } + + #[test] + fn test_to_from_string() { + let s = StringArray::from_iter_values(["a", "b", "c", "d"]); + let b = BinaryArray::from(s.clone()); + let sa = StringArray::from(b); // Performs UTF-8 validation again + + assert_eq!(s, sa); + } } diff --git a/arrow-array/src/array/boolean_array.rs b/arrow-array/src/array/boolean_array.rs new file mode 100644 index 000000000000..fe374d965714 --- /dev/null +++ b/arrow-array/src/array/boolean_array.rs @@ -0,0 +1,643 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::array::print_long_array; +use crate::builder::BooleanBuilder; +use crate::iterator::BooleanIter; +use crate::{Array, ArrayAccessor, ArrayRef, Scalar}; +use arrow_buffer::{bit_util, BooleanBuffer, MutableBuffer, NullBuffer}; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::DataType; +use std::any::Any; +use std::sync::Arc; + +/// An array of [boolean values](https://arrow.apache.org/docs/format/Columnar.html#fixed-size-primitive-layout) +/// +/// # Example: From a Vec +/// +/// ``` +/// # use arrow_array::{Array, BooleanArray}; +/// let arr: BooleanArray = vec![true, true, false].into(); +/// ``` +/// +/// # Example: From an optional Vec +/// +/// ``` +/// # use arrow_array::{Array, BooleanArray}; +/// let arr: BooleanArray = vec![Some(true), None, Some(false)].into(); +/// ``` +/// +/// # Example: From an iterator +/// +/// ``` +/// # use arrow_array::{Array, BooleanArray}; +/// let arr: BooleanArray = (0..5).map(|x| (x % 2 == 0).then(|| x % 3 == 0)).collect(); +/// let values: Vec<_> = arr.iter().collect(); +/// assert_eq!(&values, &[Some(true), None, Some(false), None, Some(false)]) +/// ``` +/// +/// # Example: Using Builder +/// +/// ``` +/// # use arrow_array::Array; +/// # use arrow_array::builder::BooleanBuilder; +/// let mut builder = BooleanBuilder::new(); +/// builder.append_value(true); +/// builder.append_null(); +/// builder.append_value(false); +/// let array = builder.finish(); +/// let values: Vec<_> = array.iter().collect(); +/// assert_eq!(&values, &[Some(true), None, Some(false)]) +/// ``` +/// +#[derive(Clone)] +pub struct BooleanArray { + values: BooleanBuffer, + nulls: Option, +} + +impl std::fmt::Debug for BooleanArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "BooleanArray\n[\n")?; + print_long_array(self, f, |array, index, f| { + std::fmt::Debug::fmt(&array.value(index), f) + })?; + write!(f, "]") + } +} + +impl BooleanArray { + /// Create a new [`BooleanArray`] from the provided values and nulls + /// + /// # Panics + /// + /// Panics if `values.len() != nulls.len()` + pub fn new(values: BooleanBuffer, nulls: Option) -> Self { + if let Some(n) = nulls.as_ref() { + assert_eq!(values.len(), n.len()); + } + Self { values, nulls } + } + + /// Create a new [`BooleanArray`] with length `len` consisting only of nulls + pub fn new_null(len: usize) -> Self { + Self { + values: BooleanBuffer::new_unset(len), + nulls: Some(NullBuffer::new_null(len)), + } + } + + /// Create a new [`Scalar`] from `value` + pub fn new_scalar(value: bool) -> Scalar { + let values = match value { + true => BooleanBuffer::new_set(1), + false => BooleanBuffer::new_unset(1), + }; + Scalar::new(Self::new(values, None)) + } + + /// Returns the length of this array. + pub fn len(&self) -> usize { + self.values.len() + } + + /// Returns whether this array is empty. + pub fn is_empty(&self) -> bool { + self.values.is_empty() + } + + /// Returns a zero-copy slice of this array with the indicated offset and length. + pub fn slice(&self, offset: usize, length: usize) -> Self { + Self { + values: self.values.slice(offset, length), + nulls: self.nulls.as_ref().map(|n| n.slice(offset, length)), + } + } + + /// Returns a new boolean array builder + pub fn builder(capacity: usize) -> BooleanBuilder { + BooleanBuilder::with_capacity(capacity) + } + + /// Returns the underlying [`BooleanBuffer`] holding all the values of this array + pub fn values(&self) -> &BooleanBuffer { + &self.values + } + + /// Returns the number of non null, true values within this array + pub fn true_count(&self) -> usize { + match self.nulls() { + Some(nulls) => { + let null_chunks = nulls.inner().bit_chunks().iter_padded(); + let value_chunks = self.values().bit_chunks().iter_padded(); + null_chunks + .zip(value_chunks) + .map(|(a, b)| (a & b).count_ones() as usize) + .sum() + } + None => self.values().count_set_bits(), + } + } + + /// Returns the number of non null, false values within this array + pub fn false_count(&self) -> usize { + self.len() - self.null_count() - self.true_count() + } + + /// Returns the boolean value at index `i`. + /// + /// # Safety + /// This doesn't check bounds, the caller must ensure that index < self.len() + pub unsafe fn value_unchecked(&self, i: usize) -> bool { + self.values.value_unchecked(i) + } + + /// Returns the boolean value at index `i`. + /// # Panics + /// Panics if index `i` is out of bounds + pub fn value(&self, i: usize) -> bool { + assert!( + i < self.len(), + "Trying to access an element at index {} from a BooleanArray of length {}", + i, + self.len() + ); + // Safety: + // `i < self.len() + unsafe { self.value_unchecked(i) } + } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + pub fn take_iter<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> + 'a { + indexes.map(|opt_index| opt_index.map(|index| self.value(index))) + } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + /// # Safety + /// + /// caller must ensure that the offsets in the iterator are less than the array len() + pub unsafe fn take_iter_unchecked<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> + 'a { + indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) + } + + /// Create a [`BooleanArray`] by evaluating the operation for + /// each element of the provided array + /// + /// ``` + /// # use arrow_array::{BooleanArray, Int32Array}; + /// + /// let array = Int32Array::from(vec![1, 2, 3, 4, 5]); + /// let r = BooleanArray::from_unary(&array, |x| x > 2); + /// assert_eq!(&r, &BooleanArray::from(vec![false, false, true, true, true])); + /// ``` + pub fn from_unary(left: T, mut op: F) -> Self + where + F: FnMut(T::Item) -> bool, + { + let nulls = left.logical_nulls(); + let values = BooleanBuffer::collect_bool(left.len(), |i| unsafe { + // SAFETY: i in range 0..len + op(left.value_unchecked(i)) + }); + Self::new(values, nulls) + } + + /// Create a [`BooleanArray`] by evaluating the binary operation for + /// each element of the provided arrays + /// + /// ``` + /// # use arrow_array::{BooleanArray, Int32Array}; + /// + /// let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + /// let b = Int32Array::from(vec![1, 2, 0, 2, 5]); + /// let r = BooleanArray::from_binary(&a, &b, |a, b| a == b); + /// assert_eq!(&r, &BooleanArray::from(vec![true, true, false, false, true])); + /// ``` + /// + /// # Panics + /// + /// This function panics if left and right are not the same length + /// + pub fn from_binary(left: T, right: S, mut op: F) -> Self + where + F: FnMut(T::Item, S::Item) -> bool, + { + assert_eq!(left.len(), right.len()); + + let nulls = NullBuffer::union( + left.logical_nulls().as_ref(), + right.logical_nulls().as_ref(), + ); + let values = BooleanBuffer::collect_bool(left.len(), |i| unsafe { + // SAFETY: i in range 0..len + op(left.value_unchecked(i), right.value_unchecked(i)) + }); + Self::new(values, nulls) + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts(self) -> (BooleanBuffer, Option) { + (self.values, self.nulls) + } +} + +impl Array for BooleanArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + self.clone().into() + } + + fn into_data(self) -> ArrayData { + self.into() + } + + fn data_type(&self) -> &DataType { + &DataType::Boolean + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.values.len() + } + + fn is_empty(&self) -> bool { + self.values.is_empty() + } + + fn offset(&self) -> usize { + self.values.offset() + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.nulls.as_ref() + } + + fn get_buffer_memory_size(&self) -> usize { + let mut sum = self.values.inner().capacity(); + if let Some(x) = &self.nulls { + sum += x.buffer().capacity() + } + sum + } + + fn get_array_memory_size(&self) -> usize { + std::mem::size_of::() + self.get_buffer_memory_size() + } +} + +impl<'a> ArrayAccessor for &'a BooleanArray { + type Item = bool; + + fn value(&self, index: usize) -> Self::Item { + BooleanArray::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + BooleanArray::value_unchecked(self, index) + } +} + +impl From> for BooleanArray { + fn from(data: Vec) -> Self { + let mut mut_buf = MutableBuffer::new_null(data.len()); + { + let mut_slice = mut_buf.as_slice_mut(); + for (i, b) in data.iter().enumerate() { + if *b { + bit_util::set_bit(mut_slice, i); + } + } + } + let array_data = ArrayData::builder(DataType::Boolean) + .len(data.len()) + .add_buffer(mut_buf.into()); + + let array_data = unsafe { array_data.build_unchecked() }; + BooleanArray::from(array_data) + } +} + +impl From>> for BooleanArray { + fn from(data: Vec>) -> Self { + data.iter().collect() + } +} + +impl From for BooleanArray { + fn from(data: ArrayData) -> Self { + assert_eq!( + data.data_type(), + &DataType::Boolean, + "BooleanArray expected ArrayData with type {} got {}", + DataType::Boolean, + data.data_type() + ); + assert_eq!( + data.buffers().len(), + 1, + "BooleanArray data should contain a single buffer only (values buffer)" + ); + let values = BooleanBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()); + + Self { + values, + nulls: data.nulls().cloned(), + } + } +} + +impl From for ArrayData { + fn from(array: BooleanArray) -> Self { + let builder = ArrayDataBuilder::new(DataType::Boolean) + .len(array.values.len()) + .offset(array.values.offset()) + .nulls(array.nulls) + .buffers(vec![array.values.into_inner()]); + + unsafe { builder.build_unchecked() } + } +} + +impl<'a> IntoIterator for &'a BooleanArray { + type Item = Option; + type IntoIter = BooleanIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + BooleanIter::<'a>::new(self) + } +} + +impl<'a> BooleanArray { + /// constructs a new iterator + pub fn iter(&'a self) -> BooleanIter<'a> { + BooleanIter::<'a>::new(self) + } +} + +impl>> FromIterator for BooleanArray { + fn from_iter>(iter: I) -> Self { + let iter = iter.into_iter(); + let (_, data_len) = iter.size_hint(); + let data_len = data_len.expect("Iterator must be sized"); // panic if no upper bound. + + let num_bytes = bit_util::ceil(data_len, 8); + let mut null_builder = MutableBuffer::from_len_zeroed(num_bytes); + let mut val_builder = MutableBuffer::from_len_zeroed(num_bytes); + + let data = val_builder.as_slice_mut(); + + let null_slice = null_builder.as_slice_mut(); + iter.enumerate().for_each(|(i, item)| { + if let Some(a) = item.borrow() { + bit_util::set_bit(null_slice, i); + if *a { + bit_util::set_bit(data, i); + } + } + }); + + let data = unsafe { + ArrayData::new_unchecked( + DataType::Boolean, + data_len, + None, + Some(null_builder.into()), + 0, + vec![val_builder.into()], + vec![], + ) + }; + BooleanArray::from(data) + } +} + +impl From for BooleanArray { + fn from(values: BooleanBuffer) -> Self { + Self { + values, + nulls: None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_buffer::Buffer; + use rand::{thread_rng, Rng}; + + #[test] + fn test_boolean_fmt_debug() { + let arr = BooleanArray::from(vec![true, false, false]); + assert_eq!( + "BooleanArray\n[\n true,\n false,\n false,\n]", + format!("{arr:?}") + ); + } + + #[test] + fn test_boolean_with_null_fmt_debug() { + let mut builder = BooleanArray::builder(3); + builder.append_value(true); + builder.append_null(); + builder.append_value(false); + let arr = builder.finish(); + assert_eq!( + "BooleanArray\n[\n true,\n null,\n false,\n]", + format!("{arr:?}") + ); + } + + #[test] + fn test_boolean_array_from_vec() { + let buf = Buffer::from([10_u8]); + let arr = BooleanArray::from(vec![false, true, false, true]); + assert_eq!(&buf, arr.values().inner()); + assert_eq!(4, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + for i in 0..4 { + assert!(!arr.is_null(i)); + assert!(arr.is_valid(i)); + assert_eq!(i == 1 || i == 3, arr.value(i), "failed at {i}") + } + } + + #[test] + fn test_boolean_array_from_vec_option() { + let buf = Buffer::from([10_u8]); + let arr = BooleanArray::from(vec![Some(false), Some(true), None, Some(true)]); + assert_eq!(&buf, arr.values().inner()); + assert_eq!(4, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(1, arr.null_count()); + for i in 0..4 { + if i == 2 { + assert!(arr.is_null(i)); + assert!(!arr.is_valid(i)); + } else { + assert!(!arr.is_null(i)); + assert!(arr.is_valid(i)); + assert_eq!(i == 1 || i == 3, arr.value(i), "failed at {i}") + } + } + } + + #[test] + fn test_boolean_array_from_iter() { + let v = vec![Some(false), Some(true), Some(false), Some(true)]; + let arr = v.into_iter().collect::(); + assert_eq!(4, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + assert!(arr.nulls().is_none()); + for i in 0..3 { + assert!(!arr.is_null(i)); + assert!(arr.is_valid(i)); + assert_eq!(i == 1 || i == 3, arr.value(i), "failed at {i}") + } + } + + #[test] + fn test_boolean_array_from_nullable_iter() { + let v = vec![Some(true), None, Some(false), None]; + let arr = v.into_iter().collect::(); + assert_eq!(4, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(2, arr.null_count()); + assert!(arr.nulls().is_some()); + + assert!(arr.is_valid(0)); + assert!(arr.is_null(1)); + assert!(arr.is_valid(2)); + assert!(arr.is_null(3)); + + assert!(arr.value(0)); + assert!(!arr.value(2)); + } + + #[test] + fn test_boolean_array_builder() { + // Test building a boolean array with ArrayData builder and offset + // 000011011 + let buf = Buffer::from([27_u8]); + let buf2 = buf.clone(); + let data = ArrayData::builder(DataType::Boolean) + .len(5) + .offset(2) + .add_buffer(buf) + .build() + .unwrap(); + let arr = BooleanArray::from(data); + assert_eq!(&buf2, arr.values().inner()); + assert_eq!(5, arr.len()); + assert_eq!(2, arr.offset()); + assert_eq!(0, arr.null_count()); + for i in 0..3 { + assert_eq!(i != 0, arr.value(i), "failed at {i}"); + } + } + + #[test] + #[should_panic( + expected = "Trying to access an element at index 4 from a BooleanArray of length 3" + )] + fn test_fixed_size_binary_array_get_value_index_out_of_bound() { + let v = vec![Some(true), None, Some(false)]; + let array = v.into_iter().collect::(); + + array.value(4); + } + + #[test] + #[should_panic(expected = "BooleanArray data should contain a single buffer only \ + (values buffer)")] + // Different error messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] + fn test_boolean_array_invalid_buffer_len() { + let data = unsafe { + ArrayData::builder(DataType::Boolean) + .len(5) + .build_unchecked() + }; + drop(BooleanArray::from(data)); + } + + #[test] + #[should_panic(expected = "BooleanArray expected ArrayData with type Boolean got Int32")] + fn test_from_array_data_validation() { + let _ = BooleanArray::from(ArrayData::new_empty(&DataType::Int32)); + } + + #[test] + #[cfg_attr(miri, ignore)] // Takes too long + fn test_true_false_count() { + let mut rng = thread_rng(); + + for _ in 0..10 { + // No nulls + let d: Vec<_> = (0..2000).map(|_| rng.gen_bool(0.5)).collect(); + let b = BooleanArray::from(d.clone()); + + let expected_true = d.iter().filter(|x| **x).count(); + assert_eq!(b.true_count(), expected_true); + assert_eq!(b.false_count(), d.len() - expected_true); + + // With nulls + let d: Vec<_> = (0..2000) + .map(|_| rng.gen_bool(0.5).then(|| rng.gen_bool(0.5))) + .collect(); + let b = BooleanArray::from(d.clone()); + + let expected_true = d.iter().filter(|x| matches!(x, Some(true))).count(); + assert_eq!(b.true_count(), expected_true); + + let expected_false = d.iter().filter(|x| matches!(x, Some(false))).count(); + assert_eq!(b.false_count(), expected_false); + } + } + + #[test] + fn test_into_parts() { + let boolean_array = [Some(true), None, Some(false)] + .into_iter() + .collect::(); + let (values, nulls) = boolean_array.into_parts(); + assert_eq!(values.values(), &[0b0000_0001]); + assert!(nulls.is_some()); + assert_eq!(nulls.unwrap().buffer().as_slice(), &[0b0000_0101]); + + let boolean_array = + BooleanArray::from(vec![false, false, false, false, false, false, false, true]); + let (values, nulls) = boolean_array.into_parts(); + assert_eq!(values.values(), &[0b1000_0000]); + assert!(nulls.is_none()); + } +} diff --git a/arrow-array/src/array/byte_array.rs b/arrow-array/src/array/byte_array.rs new file mode 100644 index 000000000000..db825bbea97d --- /dev/null +++ b/arrow-array/src/array/byte_array.rs @@ -0,0 +1,617 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::array::{get_offsets, print_long_array}; +use crate::builder::GenericByteBuilder; +use crate::iterator::ArrayIter; +use crate::types::bytes::ByteArrayNativeType; +use crate::types::ByteArrayType; +use crate::{Array, ArrayAccessor, ArrayRef, OffsetSizeTrait, Scalar}; +use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer}; +use arrow_buffer::{NullBuffer, OffsetBuffer}; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::{ArrowError, DataType}; +use std::any::Any; +use std::sync::Arc; + +/// An array of [variable length byte arrays](https://arrow.apache.org/docs/format/Columnar.html#variable-size-binary-layout) +/// +/// See [`StringArray`] and [`LargeStringArray`] for storing utf8 encoded string data +/// +/// See [`BinaryArray`] and [`LargeBinaryArray`] for storing arbitrary bytes +/// +/// # Example: From a Vec +/// +/// ``` +/// # use arrow_array::{Array, GenericByteArray, types::Utf8Type}; +/// let arr: GenericByteArray = vec!["hello", "world", ""].into(); +/// assert_eq!(arr.value_data(), b"helloworld"); +/// assert_eq!(arr.value_offsets(), &[0, 5, 10, 10]); +/// let values: Vec<_> = arr.iter().collect(); +/// assert_eq!(values, &[Some("hello"), Some("world"), Some("")]); +/// ``` +/// +/// # Example: From an optional Vec +/// +/// ``` +/// # use arrow_array::{Array, GenericByteArray, types::Utf8Type}; +/// let arr: GenericByteArray = vec![Some("hello"), Some("world"), Some(""), None].into(); +/// assert_eq!(arr.value_data(), b"helloworld"); +/// assert_eq!(arr.value_offsets(), &[0, 5, 10, 10, 10]); +/// let values: Vec<_> = arr.iter().collect(); +/// assert_eq!(values, &[Some("hello"), Some("world"), Some(""), None]); +/// ``` +/// +/// # Example: From an iterator of option +/// +/// ``` +/// # use arrow_array::{Array, GenericByteArray, types::Utf8Type}; +/// let arr: GenericByteArray = (0..5).map(|x| (x % 2 == 0).then(|| x.to_string())).collect(); +/// let values: Vec<_> = arr.iter().collect(); +/// assert_eq!(values, &[Some("0"), None, Some("2"), None, Some("4")]); +/// ``` +/// +/// # Example: Using Builder +/// +/// ``` +/// # use arrow_array::Array; +/// # use arrow_array::builder::GenericByteBuilder; +/// # use arrow_array::types::Utf8Type; +/// let mut builder = GenericByteBuilder::::new(); +/// builder.append_value("hello"); +/// builder.append_null(); +/// builder.append_value("world"); +/// let array = builder.finish(); +/// let values: Vec<_> = array.iter().collect(); +/// assert_eq!(values, &[Some("hello"), None, Some("world")]); +/// ``` +/// +/// [`StringArray`]: crate::StringArray +/// [`LargeStringArray`]: crate::LargeStringArray +/// [`BinaryArray`]: crate::BinaryArray +/// [`LargeBinaryArray`]: crate::LargeBinaryArray +pub struct GenericByteArray { + data_type: DataType, + value_offsets: OffsetBuffer, + value_data: Buffer, + nulls: Option, +} + +impl Clone for GenericByteArray { + fn clone(&self) -> Self { + Self { + data_type: self.data_type.clone(), + value_offsets: self.value_offsets.clone(), + value_data: self.value_data.clone(), + nulls: self.nulls.clone(), + } + } +} + +impl GenericByteArray { + /// Data type of the array. + pub const DATA_TYPE: DataType = T::DATA_TYPE; + + /// Create a new [`GenericByteArray`] from the provided parts, panicking on failure + /// + /// # Panics + /// + /// Panics if [`GenericByteArray::try_new`] returns an error + pub fn new( + offsets: OffsetBuffer, + values: Buffer, + nulls: Option, + ) -> Self { + Self::try_new(offsets, values, nulls).unwrap() + } + + /// Create a new [`GenericByteArray`] from the provided parts, returning an error on failure + /// + /// # Errors + /// + /// * `offsets.len() - 1 != nulls.len()` + /// * Any consecutive pair of `offsets` does not denote a valid slice of `values` + pub fn try_new( + offsets: OffsetBuffer, + values: Buffer, + nulls: Option, + ) -> Result { + let len = offsets.len() - 1; + + // Verify that each pair of offsets is a valid slices of values + T::validate(&offsets, &values)?; + + if let Some(n) = nulls.as_ref() { + if n.len() != len { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect length of null buffer for {}{}Array, expected {len} got {}", + T::Offset::PREFIX, + T::PREFIX, + n.len(), + ))); + } + } + + Ok(Self { + data_type: T::DATA_TYPE, + value_offsets: offsets, + value_data: values, + nulls, + }) + } + + /// Create a new [`GenericByteArray`] from the provided parts, without validation + /// + /// # Safety + /// + /// Safe if [`Self::try_new`] would not error + pub unsafe fn new_unchecked( + offsets: OffsetBuffer, + values: Buffer, + nulls: Option, + ) -> Self { + Self { + data_type: T::DATA_TYPE, + value_offsets: offsets, + value_data: values, + nulls, + } + } + + /// Create a new [`GenericByteArray`] of length `len` where all values are null + pub fn new_null(len: usize) -> Self { + Self { + data_type: T::DATA_TYPE, + value_offsets: OffsetBuffer::new_zeroed(len), + value_data: MutableBuffer::new(0).into(), + nulls: Some(NullBuffer::new_null(len)), + } + } + + /// Create a new [`Scalar`] from `v` + pub fn new_scalar(value: impl AsRef) -> Scalar { + Scalar::new(Self::from_iter_values(std::iter::once(value))) + } + + /// Creates a [`GenericByteArray`] based on an iterator of values without nulls + pub fn from_iter_values(iter: I) -> Self + where + Ptr: AsRef, + I: IntoIterator, + { + let iter = iter.into_iter(); + let (_, data_len) = iter.size_hint(); + let data_len = data_len.expect("Iterator must be sized"); // panic if no upper bound. + + let mut offsets = MutableBuffer::new((data_len + 1) * std::mem::size_of::()); + offsets.push(T::Offset::usize_as(0)); + + let mut values = MutableBuffer::new(0); + for s in iter { + let s: &[u8] = s.as_ref().as_ref(); + values.extend_from_slice(s); + offsets.push(T::Offset::usize_as(values.len())); + } + + T::Offset::from_usize(values.len()).expect("offset overflow"); + let offsets = Buffer::from(offsets); + + // Safety: valid by construction + let value_offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) }; + + Self { + data_type: T::DATA_TYPE, + value_data: values.into(), + value_offsets, + nulls: None, + } + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts(self) -> (OffsetBuffer, Buffer, Option) { + (self.value_offsets, self.value_data, self.nulls) + } + + /// Returns the length for value at index `i`. + /// # Panics + /// Panics if index `i` is out of bounds. + #[inline] + pub fn value_length(&self, i: usize) -> T::Offset { + let offsets = self.value_offsets(); + offsets[i + 1] - offsets[i] + } + + /// Returns a reference to the offsets of this array + /// + /// Unlike [`Self::value_offsets`] this returns the [`OffsetBuffer`] + /// allowing for zero-copy cloning + #[inline] + pub fn offsets(&self) -> &OffsetBuffer { + &self.value_offsets + } + + /// Returns the values of this array + /// + /// Unlike [`Self::value_data`] this returns the [`Buffer`] + /// allowing for zero-copy cloning + #[inline] + pub fn values(&self) -> &Buffer { + &self.value_data + } + + /// Returns the raw value data + pub fn value_data(&self) -> &[u8] { + self.value_data.as_slice() + } + + /// Returns true if all data within this array is ASCII + pub fn is_ascii(&self) -> bool { + let offsets = self.value_offsets(); + let start = offsets.first().unwrap(); + let end = offsets.last().unwrap(); + self.value_data()[start.as_usize()..end.as_usize()].is_ascii() + } + + /// Returns the offset values in the offsets buffer + #[inline] + pub fn value_offsets(&self) -> &[T::Offset] { + &self.value_offsets + } + + /// Returns the element at index `i` + /// # Safety + /// Caller is responsible for ensuring that the index is within the bounds of the array + pub unsafe fn value_unchecked(&self, i: usize) -> &T::Native { + let end = *self.value_offsets().get_unchecked(i + 1); + let start = *self.value_offsets().get_unchecked(i); + + // Soundness + // pointer alignment & location is ensured by RawPtrBox + // buffer bounds/offset is ensured by the value_offset invariants + + // Safety of `to_isize().unwrap()` + // `start` and `end` are &OffsetSize, which is a generic type that implements the + // OffsetSizeTrait. Currently, only i32 and i64 implement OffsetSizeTrait, + // both of which should cleanly cast to isize on an architecture that supports + // 32/64-bit offsets + let b = std::slice::from_raw_parts( + self.value_data.as_ptr().offset(start.to_isize().unwrap()), + (end - start).to_usize().unwrap(), + ); + + // SAFETY: + // ArrayData is valid + T::Native::from_bytes_unchecked(b) + } + + /// Returns the element at index `i` + /// # Panics + /// Panics if index `i` is out of bounds. + pub fn value(&self, i: usize) -> &T::Native { + assert!( + i < self.len(), + "Trying to access an element at index {} from a {}{}Array of length {}", + i, + T::Offset::PREFIX, + T::PREFIX, + self.len() + ); + // SAFETY: + // Verified length above + unsafe { self.value_unchecked(i) } + } + + /// constructs a new iterator + pub fn iter(&self) -> ArrayIter<&Self> { + ArrayIter::new(self) + } + + /// Returns a zero-copy slice of this array with the indicated offset and length. + pub fn slice(&self, offset: usize, length: usize) -> Self { + Self { + data_type: self.data_type.clone(), + value_offsets: self.value_offsets.slice(offset, length), + value_data: self.value_data.clone(), + nulls: self.nulls.as_ref().map(|n| n.slice(offset, length)), + } + } + + /// Returns `GenericByteBuilder` of this byte array for mutating its values if the underlying + /// offset and data buffers are not shared by others. + pub fn into_builder(self) -> Result, Self> { + let len = self.len(); + let value_len = T::Offset::as_usize(self.value_offsets()[len] - self.value_offsets()[0]); + + let data = self.into_data(); + let null_bit_buffer = data.nulls().map(|b| b.inner().sliced()); + + let element_len = std::mem::size_of::(); + let offset_buffer = data.buffers()[0] + .slice_with_length(data.offset() * element_len, (len + 1) * element_len); + + let element_len = std::mem::size_of::(); + let value_buffer = data.buffers()[1] + .slice_with_length(data.offset() * element_len, value_len * element_len); + + drop(data); + + let try_mutable_null_buffer = match null_bit_buffer { + None => Ok(None), + Some(null_buffer) => { + // Null buffer exists, tries to make it mutable + null_buffer.into_mutable().map(Some) + } + }; + + let try_mutable_buffers = match try_mutable_null_buffer { + Ok(mutable_null_buffer) => { + // Got mutable null buffer, tries to get mutable value buffer + let try_mutable_offset_buffer = offset_buffer.into_mutable(); + let try_mutable_value_buffer = value_buffer.into_mutable(); + + // try_mutable_offset_buffer.map(...).map_err(...) doesn't work as the compiler complains + // mutable_null_buffer is moved into map closure. + match (try_mutable_offset_buffer, try_mutable_value_buffer) { + (Ok(mutable_offset_buffer), Ok(mutable_value_buffer)) => unsafe { + Ok(GenericByteBuilder::::new_from_buffer( + mutable_offset_buffer, + mutable_value_buffer, + mutable_null_buffer, + )) + }, + (Ok(mutable_offset_buffer), Err(value_buffer)) => Err(( + mutable_offset_buffer.into(), + value_buffer, + mutable_null_buffer.map(|b| b.into()), + )), + (Err(offset_buffer), Ok(mutable_value_buffer)) => Err(( + offset_buffer, + mutable_value_buffer.into(), + mutable_null_buffer.map(|b| b.into()), + )), + (Err(offset_buffer), Err(value_buffer)) => Err(( + offset_buffer, + value_buffer, + mutable_null_buffer.map(|b| b.into()), + )), + } + } + Err(mutable_null_buffer) => { + // Unable to get mutable null buffer + Err((offset_buffer, value_buffer, Some(mutable_null_buffer))) + } + }; + + match try_mutable_buffers { + Ok(builder) => Ok(builder), + Err((offset_buffer, value_buffer, null_bit_buffer)) => { + let builder = ArrayData::builder(T::DATA_TYPE) + .len(len) + .add_buffer(offset_buffer) + .add_buffer(value_buffer) + .null_bit_buffer(null_bit_buffer); + + let array_data = unsafe { builder.build_unchecked() }; + let array = GenericByteArray::::from(array_data); + + Err(array) + } + } + } +} + +impl std::fmt::Debug for GenericByteArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}{}Array\n[\n", T::Offset::PREFIX, T::PREFIX)?; + print_long_array(self, f, |array, index, f| { + std::fmt::Debug::fmt(&array.value(index), f) + })?; + write!(f, "]") + } +} + +impl Array for GenericByteArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + self.clone().into() + } + + fn into_data(self) -> ArrayData { + self.into() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.value_offsets.len() - 1 + } + + fn is_empty(&self) -> bool { + self.value_offsets.len() <= 1 + } + + fn offset(&self) -> usize { + 0 + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.nulls.as_ref() + } + + fn get_buffer_memory_size(&self) -> usize { + let mut sum = self.value_offsets.inner().inner().capacity(); + sum += self.value_data.capacity(); + if let Some(x) = &self.nulls { + sum += x.buffer().capacity() + } + sum + } + + fn get_array_memory_size(&self) -> usize { + std::mem::size_of::() + self.get_buffer_memory_size() + } +} + +impl<'a, T: ByteArrayType> ArrayAccessor for &'a GenericByteArray { + type Item = &'a T::Native; + + fn value(&self, index: usize) -> Self::Item { + GenericByteArray::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + GenericByteArray::value_unchecked(self, index) + } +} + +impl From for GenericByteArray { + fn from(data: ArrayData) -> Self { + assert_eq!( + data.data_type(), + &Self::DATA_TYPE, + "{}{}Array expects DataType::{}", + T::Offset::PREFIX, + T::PREFIX, + Self::DATA_TYPE + ); + assert_eq!( + data.buffers().len(), + 2, + "{}{}Array data should contain 2 buffers only (offsets and values)", + T::Offset::PREFIX, + T::PREFIX, + ); + // SAFETY: + // ArrayData is valid, and verified type above + let value_offsets = unsafe { get_offsets(&data) }; + let value_data = data.buffers()[1].clone(); + Self { + value_offsets, + value_data, + data_type: data.data_type().clone(), + nulls: data.nulls().cloned(), + } + } +} + +impl From> for ArrayData { + fn from(array: GenericByteArray) -> Self { + let len = array.len(); + + let offsets = array.value_offsets.into_inner().into_inner(); + let builder = ArrayDataBuilder::new(array.data_type) + .len(len) + .buffers(vec![offsets, array.value_data]) + .nulls(array.nulls); + + unsafe { builder.build_unchecked() } + } +} + +impl<'a, T: ByteArrayType> IntoIterator for &'a GenericByteArray { + type Item = Option<&'a T::Native>; + type IntoIter = ArrayIter; + + fn into_iter(self) -> Self::IntoIter { + ArrayIter::new(self) + } +} + +impl<'a, Ptr, T: ByteArrayType> FromIterator<&'a Option> for GenericByteArray +where + Ptr: AsRef + 'a, +{ + fn from_iter>>(iter: I) -> Self { + iter.into_iter() + .map(|o| o.as_ref().map(|p| p.as_ref())) + .collect() + } +} + +impl FromIterator> for GenericByteArray +where + Ptr: AsRef, +{ + fn from_iter>>(iter: I) -> Self { + let iter = iter.into_iter(); + let mut builder = GenericByteBuilder::with_capacity(iter.size_hint().0, 1024); + builder.extend(iter); + builder.finish() + } +} + +#[cfg(test)] +mod tests { + use crate::{BinaryArray, StringArray}; + use arrow_buffer::{Buffer, NullBuffer, OffsetBuffer}; + + #[test] + fn try_new() { + let data = Buffer::from_slice_ref("helloworld"); + let offsets = OffsetBuffer::new(vec![0, 5, 10].into()); + StringArray::new(offsets.clone(), data.clone(), None); + + let nulls = NullBuffer::new_null(3); + let err = + StringArray::try_new(offsets.clone(), data.clone(), Some(nulls.clone())).unwrap_err(); + assert_eq!(err.to_string(), "Invalid argument error: Incorrect length of null buffer for StringArray, expected 2 got 3"); + + let err = BinaryArray::try_new(offsets.clone(), data.clone(), Some(nulls)).unwrap_err(); + assert_eq!(err.to_string(), "Invalid argument error: Incorrect length of null buffer for BinaryArray, expected 2 got 3"); + + let non_utf8_data = Buffer::from_slice_ref(b"he\xFFloworld"); + let err = StringArray::try_new(offsets.clone(), non_utf8_data.clone(), None).unwrap_err(); + assert_eq!(err.to_string(), "Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2"); + + BinaryArray::new(offsets, non_utf8_data, None); + + let offsets = OffsetBuffer::new(vec![0, 5, 11].into()); + let err = StringArray::try_new(offsets.clone(), data.clone(), None).unwrap_err(); + assert_eq!( + err.to_string(), + "Invalid argument error: Offset of 11 exceeds length of values 10" + ); + + let err = BinaryArray::try_new(offsets.clone(), data, None).unwrap_err(); + assert_eq!( + err.to_string(), + "Invalid argument error: Maximum offset of 11 is larger than values of length 10" + ); + + let non_ascii_data = Buffer::from_slice_ref("heìloworld"); + StringArray::new(offsets.clone(), non_ascii_data.clone(), None); + BinaryArray::new(offsets, non_ascii_data.clone(), None); + + let offsets = OffsetBuffer::new(vec![0, 3, 10].into()); + let err = StringArray::try_new(offsets.clone(), non_ascii_data.clone(), None).unwrap_err(); + assert_eq!( + err.to_string(), + "Invalid argument error: Split UTF-8 codepoint at offset 3" + ); + + BinaryArray::new(offsets, non_ascii_data, None); + } +} diff --git a/arrow-array/src/array/dictionary_array.rs b/arrow-array/src/array/dictionary_array.rs new file mode 100644 index 000000000000..1f4d83b1c5d0 --- /dev/null +++ b/arrow-array/src/array/dictionary_array.rs @@ -0,0 +1,1378 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::builder::{PrimitiveDictionaryBuilder, StringDictionaryBuilder}; +use crate::cast::AsArray; +use crate::iterator::ArrayIter; +use crate::types::*; +use crate::{ + make_array, Array, ArrayAccessor, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, + PrimitiveArray, StringArray, +}; +use arrow_buffer::bit_util::set_bit; +use arrow_buffer::buffer::NullBuffer; +use arrow_buffer::{ArrowNativeType, BooleanBuffer, BooleanBufferBuilder}; +use arrow_data::ArrayData; +use arrow_schema::{ArrowError, DataType}; +use std::any::Any; +use std::sync::Arc; + +/// A [`DictionaryArray`] indexed by `i8` +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, Int8DictionaryArray, Int8Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: Int8DictionaryArray = vec!["a", "a", "b", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.keys(), &Int8Array::from(vec![0, 0, 1, 2])); +/// assert_eq!(array.values(), &values); +/// ``` +/// +/// See [`DictionaryArray`] for more information and examples +pub type Int8DictionaryArray = DictionaryArray; + +/// A [`DictionaryArray`] indexed by `i16` +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, Int16DictionaryArray, Int16Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: Int16DictionaryArray = vec!["a", "a", "b", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.keys(), &Int16Array::from(vec![0, 0, 1, 2])); +/// assert_eq!(array.values(), &values); +/// ``` +/// +/// See [`DictionaryArray`] for more information and examples +pub type Int16DictionaryArray = DictionaryArray; + +/// A [`DictionaryArray`] indexed by `i32` +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, Int32DictionaryArray, Int32Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: Int32DictionaryArray = vec!["a", "a", "b", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.keys(), &Int32Array::from(vec![0, 0, 1, 2])); +/// assert_eq!(array.values(), &values); +/// ``` +/// +/// See [`DictionaryArray`] for more information and examples +pub type Int32DictionaryArray = DictionaryArray; + +/// A [`DictionaryArray`] indexed by `i64` +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, Int64DictionaryArray, Int64Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: Int64DictionaryArray = vec!["a", "a", "b", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.keys(), &Int64Array::from(vec![0, 0, 1, 2])); +/// assert_eq!(array.values(), &values); +/// ``` +/// +/// See [`DictionaryArray`] for more information and examples +pub type Int64DictionaryArray = DictionaryArray; + +/// A [`DictionaryArray`] indexed by `u8` +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, UInt8DictionaryArray, UInt8Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: UInt8DictionaryArray = vec!["a", "a", "b", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.keys(), &UInt8Array::from(vec![0, 0, 1, 2])); +/// assert_eq!(array.values(), &values); +/// ``` +/// +/// See [`DictionaryArray`] for more information and examples +pub type UInt8DictionaryArray = DictionaryArray; + +/// A [`DictionaryArray`] indexed by `u16` +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, UInt16DictionaryArray, UInt16Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: UInt16DictionaryArray = vec!["a", "a", "b", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.keys(), &UInt16Array::from(vec![0, 0, 1, 2])); +/// assert_eq!(array.values(), &values); +/// ``` +/// +/// See [`DictionaryArray`] for more information and examples +pub type UInt16DictionaryArray = DictionaryArray; + +/// A [`DictionaryArray`] indexed by `u32` +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, UInt32DictionaryArray, UInt32Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: UInt32DictionaryArray = vec!["a", "a", "b", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.keys(), &UInt32Array::from(vec![0, 0, 1, 2])); +/// assert_eq!(array.values(), &values); +/// ``` +/// +/// See [`DictionaryArray`] for more information and examples +pub type UInt32DictionaryArray = DictionaryArray; + +/// A [`DictionaryArray`] indexed by `u64` +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, UInt64DictionaryArray, UInt64Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: UInt64DictionaryArray = vec!["a", "a", "b", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.keys(), &UInt64Array::from(vec![0, 0, 1, 2])); +/// assert_eq!(array.values(), &values); +/// ``` +/// +/// See [`DictionaryArray`] for more information and examples +pub type UInt64DictionaryArray = DictionaryArray; + +/// An array of [dictionary encoded values](https://arrow.apache.org/docs/format/Columnar.html#dictionary-encoded-layout) +/// +/// This is mostly used to represent strings or a limited set of primitive types as integers, +/// for example when doing NLP analysis or representing chromosomes by name. +/// +/// [`DictionaryArray`] are represented using a `keys` array and a +/// `values` array, which may be different lengths. The `keys` array +/// stores indexes in the `values` array which holds +/// the corresponding logical value, as shown here: +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// ┌─────────────────┐ ┌─────────┐ │ ┌─────────────────┐ +/// │ │ A │ │ 0 │ │ A │ values[keys[0]] +/// ├─────────────────┤ ├─────────┤ │ ├─────────────────┤ +/// │ │ D │ │ 2 │ │ B │ values[keys[1]] +/// ├─────────────────┤ ├─────────┤ │ ├─────────────────┤ +/// │ │ B │ │ 2 │ │ B │ values[keys[2]] +/// └─────────────────┘ ├─────────┤ │ ├─────────────────┤ +/// │ │ 1 │ │ D │ values[keys[3]] +/// ├─────────┤ │ ├─────────────────┤ +/// │ │ 1 │ │ D │ values[keys[4]] +/// ├─────────┤ │ ├─────────────────┤ +/// │ │ 0 │ │ A │ values[keys[5]] +/// └─────────┘ │ └─────────────────┘ +/// │ values keys +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ +/// Logical array +/// Contents +/// DictionaryArray +/// length = 6 +/// ``` +/// +/// # Example: From Nullable Data +/// +/// ``` +/// # use arrow_array::{DictionaryArray, Int8Array, types::Int8Type}; +/// let test = vec!["a", "a", "b", "c"]; +/// let array : DictionaryArray = test.iter().map(|&x| if x == "b" {None} else {Some(x)}).collect(); +/// assert_eq!(array.keys(), &Int8Array::from(vec![Some(0), Some(0), None, Some(1)])); +/// ``` +/// +/// # Example: From Non-Nullable Data +/// +/// ``` +/// # use arrow_array::{DictionaryArray, Int8Array, types::Int8Type}; +/// let test = vec!["a", "a", "b", "c"]; +/// let array : DictionaryArray = test.into_iter().collect(); +/// assert_eq!(array.keys(), &Int8Array::from(vec![0, 0, 1, 2])); +/// ``` +/// +/// # Example: From Existing Arrays +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{DictionaryArray, Int8Array, StringArray, types::Int8Type}; +/// // You can form your own DictionaryArray by providing the +/// // values (dictionary) and keys (indexes into the dictionary): +/// let values = StringArray::from_iter_values(["a", "b", "c"]); +/// let keys = Int8Array::from_iter_values([0, 0, 1, 2]); +/// let array = DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); +/// let expected: DictionaryArray:: = vec!["a", "a", "b", "c"].into_iter().collect(); +/// assert_eq!(&array, &expected); +/// ``` +/// +/// # Example: Using Builder +/// +/// ``` +/// # use arrow_array::{Array, StringArray}; +/// # use arrow_array::builder::StringDictionaryBuilder; +/// # use arrow_array::types::Int32Type; +/// let mut builder = StringDictionaryBuilder::::new(); +/// builder.append_value("a"); +/// builder.append_null(); +/// builder.append_value("a"); +/// builder.append_value("b"); +/// let array = builder.finish(); +/// +/// let values: Vec<_> = array.downcast_dict::().unwrap().into_iter().collect(); +/// assert_eq!(&values, &[Some("a"), None, Some("a"), Some("b")]); +/// ``` +pub struct DictionaryArray { + data_type: DataType, + + /// The keys of this dictionary. These are constructed from the + /// buffer and null bitmap of `data`. Also, note that these do + /// not correspond to the true values of this array. Rather, they + /// map to the real values. + keys: PrimitiveArray, + + /// Array of dictionary values (can by any DataType). + values: ArrayRef, + + /// Values are ordered. + is_ordered: bool, +} + +impl Clone for DictionaryArray { + fn clone(&self) -> Self { + Self { + data_type: self.data_type.clone(), + keys: self.keys.clone(), + values: self.values.clone(), + is_ordered: self.is_ordered, + } + } +} + +impl DictionaryArray { + /// Attempt to create a new DictionaryArray with a specified keys + /// (indexes into the dictionary) and values (dictionary) + /// array. + /// + /// # Panics + /// + /// Panics if [`Self::try_new`] returns an error + pub fn new(keys: PrimitiveArray, values: ArrayRef) -> Self { + Self::try_new(keys, values).unwrap() + } + + /// Attempt to create a new DictionaryArray with a specified keys + /// (indexes into the dictionary) and values (dictionary) + /// array. + /// + /// # Errors + /// + /// Returns an error if any `keys[i] >= values.len() || keys[i] < 0` + pub fn try_new(keys: PrimitiveArray, values: ArrayRef) -> Result { + let data_type = DataType::Dictionary( + Box::new(keys.data_type().clone()), + Box::new(values.data_type().clone()), + ); + + let zero = K::Native::usize_as(0); + let values_len = values.len(); + + if let Some((idx, v)) = + keys.values().iter().enumerate().find(|(idx, v)| { + (v.is_lt(zero) || v.as_usize() >= values_len) && keys.is_valid(*idx) + }) + { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid dictionary key {v:?} at index {idx}, expected 0 <= key < {values_len}", + ))); + } + + Ok(Self { + data_type, + keys, + values, + is_ordered: false, + }) + } + + /// Create a new [`DictionaryArray`] without performing validation + /// + /// # Safety + /// + /// Safe provided [`Self::try_new`] would not return an error + pub unsafe fn new_unchecked(keys: PrimitiveArray, values: ArrayRef) -> Self { + let data_type = DataType::Dictionary( + Box::new(keys.data_type().clone()), + Box::new(values.data_type().clone()), + ); + + Self { + data_type, + keys, + values, + is_ordered: false, + } + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts(self) -> (PrimitiveArray, ArrayRef) { + (self.keys, self.values) + } + + /// Return an array view of the keys of this dictionary as a PrimitiveArray. + pub fn keys(&self) -> &PrimitiveArray { + &self.keys + } + + /// If `value` is present in `values` (aka the dictionary), + /// returns the corresponding key (index into the `values` + /// array). Otherwise returns `None`. + /// + /// Panics if `values` is not a [`StringArray`]. + pub fn lookup_key(&self, value: &str) -> Option { + let rd_buf: &StringArray = self.values.as_any().downcast_ref::().unwrap(); + + (0..rd_buf.len()) + .position(|i| rd_buf.value(i) == value) + .and_then(K::Native::from_usize) + } + + /// Returns a reference to the dictionary values array + pub fn values(&self) -> &ArrayRef { + &self.values + } + + /// Returns a clone of the value type of this list. + pub fn value_type(&self) -> DataType { + self.values.data_type().clone() + } + + /// The length of the dictionary is the length of the keys array. + pub fn len(&self) -> usize { + self.keys.len() + } + + /// Whether this dictionary is empty + pub fn is_empty(&self) -> bool { + self.keys.is_empty() + } + + /// Currently exists for compatibility purposes with Arrow IPC. + pub fn is_ordered(&self) -> bool { + self.is_ordered + } + + /// Return an iterator over the keys (indexes into the dictionary) + pub fn keys_iter(&self) -> impl Iterator> + '_ { + self.keys.iter().map(|key| key.map(|k| k.as_usize())) + } + + /// Return the value of `keys` (the dictionary key) at index `i`, + /// cast to `usize`, `None` if the value at `i` is `NULL`. + pub fn key(&self, i: usize) -> Option { + self.keys.is_valid(i).then(|| self.keys.value(i).as_usize()) + } + + /// Returns a zero-copy slice of this array with the indicated offset and length. + pub fn slice(&self, offset: usize, length: usize) -> Self { + Self { + data_type: self.data_type.clone(), + keys: self.keys.slice(offset, length), + values: self.values.clone(), + is_ordered: self.is_ordered, + } + } + + /// Downcast this dictionary to a [`TypedDictionaryArray`] + /// + /// ``` + /// use arrow_array::{Array, ArrayAccessor, DictionaryArray, StringArray, types::Int32Type}; + /// + /// let orig = [Some("a"), Some("b"), None]; + /// let dictionary = DictionaryArray::::from_iter(orig); + /// let typed = dictionary.downcast_dict::().unwrap(); + /// assert_eq!(typed.value(0), "a"); + /// assert_eq!(typed.value(1), "b"); + /// assert!(typed.is_null(2)); + /// ``` + /// + pub fn downcast_dict(&self) -> Option> { + let values = self.values.as_any().downcast_ref()?; + Some(TypedDictionaryArray { + dictionary: self, + values, + }) + } + + /// Returns a new dictionary with the same keys as the current instance + /// but with a different set of dictionary values + /// + /// This can be used to perform an operation on the values of a dictionary + /// + /// # Panics + /// + /// Panics if `values` has a length less than the current values + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::builder::PrimitiveDictionaryBuilder; + /// # use arrow_array::{Int8Array, Int64Array, ArrayAccessor}; + /// # use arrow_array::types::{Int32Type, Int8Type}; + /// + /// // Construct a Dict(Int32, Int8) + /// let mut builder = PrimitiveDictionaryBuilder::::with_capacity(2, 200); + /// for i in 0..100 { + /// builder.append(i % 2).unwrap(); + /// } + /// + /// let dictionary = builder.finish(); + /// + /// // Perform a widening cast of dictionary values + /// let typed_dictionary = dictionary.downcast_dict::().unwrap(); + /// let values: Int64Array = typed_dictionary.values().unary(|x| x as i64); + /// + /// // Create a Dict(Int32, + /// let new = dictionary.with_values(Arc::new(values)); + /// + /// // Verify values are as expected + /// let new_typed = new.downcast_dict::().unwrap(); + /// for i in 0..100 { + /// assert_eq!(new_typed.value(i), (i % 2) as i64) + /// } + /// ``` + /// + pub fn with_values(&self, values: ArrayRef) -> Self { + assert!(values.len() >= self.values.len()); + let data_type = + DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(values.data_type().clone())); + Self { + data_type, + keys: self.keys.clone(), + values, + is_ordered: false, + } + } + + /// Returns `PrimitiveDictionaryBuilder` of this dictionary array for mutating + /// its keys and values if the underlying data buffer is not shared by others. + pub fn into_primitive_dict_builder(self) -> Result, Self> + where + V: ArrowPrimitiveType, + { + if !self.value_type().is_primitive() { + return Err(self); + } + + let key_array = self.keys().clone(); + let value_array = self.values().as_primitive::().clone(); + + drop(self.keys); + drop(self.values); + + let key_builder = key_array.into_builder(); + let value_builder = value_array.into_builder(); + + match (key_builder, value_builder) { + (Ok(key_builder), Ok(value_builder)) => Ok(unsafe { + PrimitiveDictionaryBuilder::new_from_builders(key_builder, value_builder) + }), + (Err(key_array), Ok(mut value_builder)) => { + Err(Self::try_new(key_array, Arc::new(value_builder.finish())).unwrap()) + } + (Ok(mut key_builder), Err(value_array)) => { + Err(Self::try_new(key_builder.finish(), Arc::new(value_array)).unwrap()) + } + (Err(key_array), Err(value_array)) => { + Err(Self::try_new(key_array, Arc::new(value_array)).unwrap()) + } + } + } + + /// Applies an unary and infallible function to a mutable dictionary array. + /// Mutable dictionary array means that the buffers are not shared with other arrays. + /// As a result, this mutates the buffers directly without allocating new buffers. + /// + /// # Implementation + /// + /// This will apply the function for all dictionary values, including those on null slots. + /// This implies that the operation must be infallible for any value of the corresponding type + /// or this function may panic. + /// # Example + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::{Array, ArrayAccessor, DictionaryArray, StringArray, types::{Int8Type, Int32Type}}; + /// # use arrow_array::{Int8Array, Int32Array}; + /// let values = Int32Array::from(vec![Some(10), Some(20), None]); + /// let keys = Int8Array::from_iter_values([0, 0, 1, 2]); + /// let dictionary = DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + /// let c = dictionary.unary_mut::<_, Int32Type>(|x| x + 1).unwrap(); + /// let typed = c.downcast_dict::().unwrap(); + /// assert_eq!(typed.value(0), 11); + /// assert_eq!(typed.value(1), 11); + /// assert_eq!(typed.value(2), 21); + /// ``` + pub fn unary_mut(self, op: F) -> Result, DictionaryArray> + where + V: ArrowPrimitiveType, + F: Fn(V::Native) -> V::Native, + { + let mut builder: PrimitiveDictionaryBuilder = self.into_primitive_dict_builder()?; + builder + .values_slice_mut() + .iter_mut() + .for_each(|v| *v = op(*v)); + Ok(builder.finish()) + } + + /// Computes an occupancy mask for this dictionary's values + /// + /// For each value in [`Self::values`] the corresponding bit will be set in the + /// returned mask if it is referenced by a key in this [`DictionaryArray`] + pub fn occupancy(&self) -> BooleanBuffer { + let len = self.values.len(); + let mut builder = BooleanBufferBuilder::new(len); + builder.resize(len); + let slice = builder.as_slice_mut(); + match self.keys.nulls().filter(|n| n.null_count() > 0) { + Some(n) => { + let v = self.keys.values(); + n.valid_indices() + .for_each(|idx| set_bit(slice, v[idx].as_usize())) + } + None => { + let v = self.keys.values(); + v.iter().for_each(|v| set_bit(slice, v.as_usize())) + } + } + builder.finish() + } +} + +/// Constructs a `DictionaryArray` from an array data reference. +impl From for DictionaryArray { + fn from(data: ArrayData) -> Self { + assert_eq!( + data.buffers().len(), + 1, + "DictionaryArray data should contain a single buffer only (keys)." + ); + assert_eq!( + data.child_data().len(), + 1, + "DictionaryArray should contain a single child array (values)." + ); + + if let DataType::Dictionary(key_data_type, _) = data.data_type() { + assert_eq!( + &T::DATA_TYPE, + key_data_type.as_ref(), + "DictionaryArray's data type must match, expected {} got {}", + T::DATA_TYPE, + key_data_type + ); + + let values = make_array(data.child_data()[0].clone()); + let data_type = data.data_type().clone(); + + // create a zero-copy of the keys' data + // SAFETY: + // ArrayData is valid and verified type above + + let keys = PrimitiveArray::::from(unsafe { + data.into_builder() + .data_type(T::DATA_TYPE) + .child_data(vec![]) + .build_unchecked() + }); + + Self { + data_type, + keys, + values, + is_ordered: false, + } + } else { + panic!("DictionaryArray must have Dictionary data type.") + } + } +} + +impl From> for ArrayData { + fn from(array: DictionaryArray) -> Self { + let builder = array + .keys + .into_data() + .into_builder() + .data_type(array.data_type) + .child_data(vec![array.values.to_data()]); + + unsafe { builder.build_unchecked() } + } +} + +/// Constructs a `DictionaryArray` from an iterator of optional strings. +/// +/// # Example: +/// ``` +/// use arrow_array::{DictionaryArray, PrimitiveArray, StringArray, types::Int8Type}; +/// +/// let test = vec!["a", "a", "b", "c"]; +/// let array: DictionaryArray = test +/// .iter() +/// .map(|&x| if x == "b" { None } else { Some(x) }) +/// .collect(); +/// assert_eq!( +/// "DictionaryArray {keys: PrimitiveArray\n[\n 0,\n 0,\n null,\n 1,\n] values: StringArray\n[\n \"a\",\n \"c\",\n]}\n", +/// format!("{:?}", array) +/// ); +/// ``` +impl<'a, T: ArrowDictionaryKeyType> FromIterator> for DictionaryArray { + fn from_iter>>(iter: I) -> Self { + let it = iter.into_iter(); + let (lower, _) = it.size_hint(); + let mut builder = StringDictionaryBuilder::with_capacity(lower, 256, 1024); + builder.extend(it); + builder.finish() + } +} + +/// Constructs a `DictionaryArray` from an iterator of strings. +/// +/// # Example: +/// +/// ``` +/// use arrow_array::{DictionaryArray, PrimitiveArray, StringArray, types::Int8Type}; +/// +/// let test = vec!["a", "a", "b", "c"]; +/// let array: DictionaryArray = test.into_iter().collect(); +/// assert_eq!( +/// "DictionaryArray {keys: PrimitiveArray\n[\n 0,\n 0,\n 1,\n 2,\n] values: StringArray\n[\n \"a\",\n \"b\",\n \"c\",\n]}\n", +/// format!("{:?}", array) +/// ); +/// ``` +impl<'a, T: ArrowDictionaryKeyType> FromIterator<&'a str> for DictionaryArray { + fn from_iter>(iter: I) -> Self { + let it = iter.into_iter(); + let (lower, _) = it.size_hint(); + let mut builder = StringDictionaryBuilder::with_capacity(lower, 256, 1024); + it.for_each(|i| { + builder + .append(i) + .expect("Unable to append a value to a dictionary array."); + }); + + builder.finish() + } +} + +impl Array for DictionaryArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + self.clone().into() + } + + fn into_data(self) -> ArrayData { + self.into() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.keys.len() + } + + fn is_empty(&self) -> bool { + self.keys.is_empty() + } + + fn offset(&self) -> usize { + self.keys.offset() + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.keys.nulls() + } + + fn logical_nulls(&self) -> Option { + match self.values.nulls() { + None => self.nulls().cloned(), + Some(value_nulls) => { + let mut builder = BooleanBufferBuilder::new(self.len()); + match self.keys.nulls() { + Some(n) => builder.append_buffer(n.inner()), + None => builder.append_n(self.len(), true), + } + for (idx, k) in self.keys.values().iter().enumerate() { + let k = k.as_usize(); + // Check range to allow for nulls + if k < value_nulls.len() && value_nulls.is_null(k) { + builder.set_bit(idx, false); + } + } + Some(builder.finish().into()) + } + } + } + + fn is_nullable(&self) -> bool { + !self.is_empty() && (self.nulls().is_some() || self.values.is_nullable()) + } + + fn get_buffer_memory_size(&self) -> usize { + self.keys.get_buffer_memory_size() + self.values.get_buffer_memory_size() + } + + fn get_array_memory_size(&self) -> usize { + std::mem::size_of::() + + self.keys.get_buffer_memory_size() + + self.values.get_array_memory_size() + } +} + +impl std::fmt::Debug for DictionaryArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + writeln!( + f, + "DictionaryArray {{keys: {:?} values: {:?}}}", + self.keys, self.values + ) + } +} + +/// A [`DictionaryArray`] typed on its child values array +/// +/// Implements [`ArrayAccessor`] allowing fast access to its elements +/// +/// ``` +/// use arrow_array::{DictionaryArray, StringArray, types::Int32Type}; +/// +/// let orig = ["a", "b", "a", "b"]; +/// let dictionary = DictionaryArray::::from_iter(orig); +/// +/// // `TypedDictionaryArray` allows you to access the values directly +/// let typed = dictionary.downcast_dict::().unwrap(); +/// +/// for (maybe_val, orig) in typed.into_iter().zip(orig) { +/// assert_eq!(maybe_val.unwrap(), orig) +/// } +/// ``` +pub struct TypedDictionaryArray<'a, K: ArrowDictionaryKeyType, V> { + /// The dictionary array + dictionary: &'a DictionaryArray, + /// The values of the dictionary + values: &'a V, +} + +// Manually implement `Clone` to avoid `V: Clone` type constraint +impl<'a, K: ArrowDictionaryKeyType, V> Clone for TypedDictionaryArray<'a, K, V> { + fn clone(&self) -> Self { + *self + } +} + +impl<'a, K: ArrowDictionaryKeyType, V> Copy for TypedDictionaryArray<'a, K, V> {} + +impl<'a, K: ArrowDictionaryKeyType, V> std::fmt::Debug for TypedDictionaryArray<'a, K, V> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + writeln!(f, "TypedDictionaryArray({:?})", self.dictionary) + } +} + +impl<'a, K: ArrowDictionaryKeyType, V> TypedDictionaryArray<'a, K, V> { + /// Returns the keys of this [`TypedDictionaryArray`] + pub fn keys(&self) -> &'a PrimitiveArray { + self.dictionary.keys() + } + + /// Returns the values of this [`TypedDictionaryArray`] + pub fn values(&self) -> &'a V { + self.values + } +} + +impl<'a, K: ArrowDictionaryKeyType, V: Sync> Array for TypedDictionaryArray<'a, K, V> { + fn as_any(&self) -> &dyn Any { + self.dictionary + } + + fn to_data(&self) -> ArrayData { + self.dictionary.to_data() + } + + fn into_data(self) -> ArrayData { + self.dictionary.into_data() + } + + fn data_type(&self) -> &DataType { + self.dictionary.data_type() + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.dictionary.slice(offset, length)) + } + + fn len(&self) -> usize { + self.dictionary.len() + } + + fn is_empty(&self) -> bool { + self.dictionary.is_empty() + } + + fn offset(&self) -> usize { + self.dictionary.offset() + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.dictionary.nulls() + } + + fn logical_nulls(&self) -> Option { + self.dictionary.logical_nulls() + } + + fn is_nullable(&self) -> bool { + self.dictionary.is_nullable() + } + + fn get_buffer_memory_size(&self) -> usize { + self.dictionary.get_buffer_memory_size() + } + + fn get_array_memory_size(&self) -> usize { + self.dictionary.get_array_memory_size() + } +} + +impl<'a, K, V> IntoIterator for TypedDictionaryArray<'a, K, V> +where + K: ArrowDictionaryKeyType, + Self: ArrayAccessor, +{ + type Item = Option<::Item>; + type IntoIter = ArrayIter; + + fn into_iter(self) -> Self::IntoIter { + ArrayIter::new(self) + } +} + +impl<'a, K, V> ArrayAccessor for TypedDictionaryArray<'a, K, V> +where + K: ArrowDictionaryKeyType, + V: Sync + Send, + &'a V: ArrayAccessor, + <&'a V as ArrayAccessor>::Item: Default, +{ + type Item = <&'a V as ArrayAccessor>::Item; + + fn value(&self, index: usize) -> Self::Item { + assert!( + index < self.len(), + "Trying to access an element at index {} from a TypedDictionaryArray of length {}", + index, + self.len() + ); + unsafe { self.value_unchecked(index) } + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + let val = self.dictionary.keys.value_unchecked(index); + let value_idx = val.as_usize(); + + // As dictionary keys are only verified for non-null indexes + // we must check the value is within bounds + match value_idx < self.values.len() { + true => self.values.value_unchecked(value_idx), + false => Default::default(), + } + } +} + +/// A [`DictionaryArray`] with the key type erased +/// +/// This can be used to efficiently implement kernels for all possible dictionary +/// keys without needing to create specialized implementations for each key type +/// +/// For example +/// +/// ``` +/// # use arrow_array::*; +/// # use arrow_array::cast::AsArray; +/// # use arrow_array::builder::PrimitiveDictionaryBuilder; +/// # use arrow_array::types::*; +/// # use arrow_schema::ArrowError; +/// # use std::sync::Arc; +/// +/// fn to_string(a: &dyn Array) -> Result { +/// if let Some(d) = a.as_any_dictionary_opt() { +/// // Recursively handle dictionary input +/// let r = to_string(d.values().as_ref())?; +/// return Ok(d.with_values(r)); +/// } +/// downcast_primitive_array! { +/// a => Ok(Arc::new(a.iter().map(|x| x.map(|x| x.to_string())).collect::())), +/// d => Err(ArrowError::InvalidArgumentError(format!("{d:?} not supported"))) +/// } +/// } +/// +/// let result = to_string(&Int32Array::from(vec![1, 2, 3])).unwrap(); +/// let actual = result.as_string::().iter().map(Option::unwrap).collect::>(); +/// assert_eq!(actual, &["1", "2", "3"]); +/// +/// let mut dict = PrimitiveDictionaryBuilder::::new(); +/// dict.extend([Some(1), Some(1), Some(2), Some(3), Some(2)]); +/// let dict = dict.finish(); +/// +/// let r = to_string(&dict).unwrap(); +/// let r = r.as_dictionary::().downcast_dict::().unwrap(); +/// assert_eq!(r.keys(), dict.keys()); // Keys are the same +/// +/// let actual = r.into_iter().map(Option::unwrap).collect::>(); +/// assert_eq!(actual, &["1", "1", "2", "3", "2"]); +/// ``` +/// +/// See [`AsArray::as_any_dictionary_opt`] and [`AsArray::as_any_dictionary`] +pub trait AnyDictionaryArray: Array { + /// Returns the primitive keys of this dictionary as an [`Array`] + fn keys(&self) -> &dyn Array; + + /// Returns the values of this dictionary + fn values(&self) -> &ArrayRef; + + /// Returns the keys of this dictionary as usize + /// + /// The values for nulls will be arbitrary, but are guaranteed + /// to be in the range `0..self.values.len()` + /// + /// # Panic + /// + /// Panics if `values.len() == 0` + fn normalized_keys(&self) -> Vec; + + /// Create a new [`DictionaryArray`] replacing `values` with the new values + /// + /// See [`DictionaryArray::with_values`] + fn with_values(&self, values: ArrayRef) -> ArrayRef; +} + +impl AnyDictionaryArray for DictionaryArray { + fn keys(&self) -> &dyn Array { + &self.keys + } + + fn values(&self) -> &ArrayRef { + self.values() + } + + fn normalized_keys(&self) -> Vec { + let v_len = self.values().len(); + assert_ne!(v_len, 0); + let iter = self.keys().values().iter(); + iter.map(|x| x.as_usize().min(v_len - 1)).collect() + } + + fn with_values(&self, values: ArrayRef) -> ArrayRef { + Arc::new(self.with_values(values)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::builder::PrimitiveDictionaryBuilder; + use crate::cast::as_dictionary_array; + use crate::types::{Int32Type, Int8Type, UInt32Type, UInt8Type}; + use crate::{Int16Array, Int32Array, Int8Array}; + use arrow_buffer::{Buffer, ToByteSlice}; + use std::sync::Arc; + + #[test] + fn test_dictionary_array() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int8) + .len(8) + .add_buffer(Buffer::from( + &[10_i8, 11, 12, 13, 14, 15, 16, 17].to_byte_slice(), + )) + .build() + .unwrap(); + + // Construct a buffer for value offsets, for the nested array: + let keys = Buffer::from(&[2_i16, 3, 4].to_byte_slice()); + + // Construct a dictionary array from the above two + let key_type = DataType::Int16; + let value_type = DataType::Int8; + let dict_data_type = DataType::Dictionary(Box::new(key_type), Box::new(value_type)); + let dict_data = ArrayData::builder(dict_data_type.clone()) + .len(3) + .add_buffer(keys.clone()) + .add_child_data(value_data.clone()) + .build() + .unwrap(); + let dict_array = Int16DictionaryArray::from(dict_data); + + let values = dict_array.values(); + assert_eq!(value_data, values.to_data()); + assert_eq!(DataType::Int8, dict_array.value_type()); + assert_eq!(3, dict_array.len()); + + // Null count only makes sense in terms of the component arrays. + assert_eq!(0, dict_array.null_count()); + assert_eq!(0, dict_array.values().null_count()); + assert_eq!(dict_array.keys(), &Int16Array::from(vec![2_i16, 3, 4])); + + // Now test with a non-zero offset + let dict_data = ArrayData::builder(dict_data_type) + .len(2) + .offset(1) + .add_buffer(keys) + .add_child_data(value_data.clone()) + .build() + .unwrap(); + let dict_array = Int16DictionaryArray::from(dict_data); + + let values = dict_array.values(); + assert_eq!(value_data, values.to_data()); + assert_eq!(DataType::Int8, dict_array.value_type()); + assert_eq!(2, dict_array.len()); + assert_eq!(dict_array.keys(), &Int16Array::from(vec![3_i16, 4])); + } + + #[test] + fn test_dictionary_array_fmt_debug() { + let mut builder = PrimitiveDictionaryBuilder::::with_capacity(3, 2); + builder.append(12345678).unwrap(); + builder.append_null(); + builder.append(22345678).unwrap(); + let array = builder.finish(); + assert_eq!( + "DictionaryArray {keys: PrimitiveArray\n[\n 0,\n null,\n 1,\n] values: PrimitiveArray\n[\n 12345678,\n 22345678,\n]}\n", + format!("{array:?}") + ); + + let mut builder = PrimitiveDictionaryBuilder::::with_capacity(20, 2); + for _ in 0..20 { + builder.append(1).unwrap(); + } + let array = builder.finish(); + assert_eq!( + "DictionaryArray {keys: PrimitiveArray\n[\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n] values: PrimitiveArray\n[\n 1,\n]}\n", + format!("{array:?}") + ); + } + + #[test] + fn test_dictionary_array_from_iter() { + let test = vec!["a", "a", "b", "c"]; + let array: DictionaryArray = test + .iter() + .map(|&x| if x == "b" { None } else { Some(x) }) + .collect(); + assert_eq!( + "DictionaryArray {keys: PrimitiveArray\n[\n 0,\n 0,\n null,\n 1,\n] values: StringArray\n[\n \"a\",\n \"c\",\n]}\n", + format!("{array:?}") + ); + + let array: DictionaryArray = test.into_iter().collect(); + assert_eq!( + "DictionaryArray {keys: PrimitiveArray\n[\n 0,\n 0,\n 1,\n 2,\n] values: StringArray\n[\n \"a\",\n \"b\",\n \"c\",\n]}\n", + format!("{array:?}") + ); + } + + #[test] + fn test_dictionary_array_reverse_lookup_key() { + let test = vec!["a", "a", "b", "c"]; + let array: DictionaryArray = test.into_iter().collect(); + + assert_eq!(array.lookup_key("c"), Some(2)); + + // Direction of building a dictionary is the iterator direction + let test = vec!["t3", "t3", "t2", "t2", "t1", "t3", "t4", "t1", "t0"]; + let array: DictionaryArray = test.into_iter().collect(); + + assert_eq!(array.lookup_key("t1"), Some(2)); + assert_eq!(array.lookup_key("non-existent"), None); + } + + #[test] + fn test_dictionary_keys_as_primitive_array() { + let test = vec!["a", "b", "c", "a"]; + let array: DictionaryArray = test.into_iter().collect(); + + let keys = array.keys(); + assert_eq!(&DataType::Int8, keys.data_type()); + assert_eq!(0, keys.null_count()); + assert_eq!(&[0, 1, 2, 0], keys.values()); + } + + #[test] + fn test_dictionary_keys_as_primitive_array_with_null() { + let test = vec![Some("a"), None, Some("b"), None, None, Some("a")]; + let array: DictionaryArray = test.into_iter().collect(); + + let keys = array.keys(); + assert_eq!(&DataType::Int32, keys.data_type()); + assert_eq!(3, keys.null_count()); + + assert!(keys.is_valid(0)); + assert!(!keys.is_valid(1)); + assert!(keys.is_valid(2)); + assert!(!keys.is_valid(3)); + assert!(!keys.is_valid(4)); + assert!(keys.is_valid(5)); + + assert_eq!(0, keys.value(0)); + assert_eq!(1, keys.value(2)); + assert_eq!(0, keys.value(5)); + } + + #[test] + fn test_dictionary_all_nulls() { + let test = vec![None, None, None]; + let array: DictionaryArray = test.into_iter().collect(); + array + .into_data() + .validate_full() + .expect("All null array has valid array data"); + } + + #[test] + fn test_dictionary_iter() { + // Construct a value array + let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + let keys = Int16Array::from_iter_values([2_i16, 3, 4]); + + // Construct a dictionary array from the above two + let dict_array = DictionaryArray::new(keys, Arc::new(values)); + + let mut key_iter = dict_array.keys_iter(); + assert_eq!(2, key_iter.next().unwrap().unwrap()); + assert_eq!(3, key_iter.next().unwrap().unwrap()); + assert_eq!(4, key_iter.next().unwrap().unwrap()); + assert!(key_iter.next().is_none()); + + let mut iter = dict_array + .values() + .as_any() + .downcast_ref::() + .unwrap() + .take_iter(dict_array.keys_iter()); + + assert_eq!(12, iter.next().unwrap().unwrap()); + assert_eq!(13, iter.next().unwrap().unwrap()); + assert_eq!(14, iter.next().unwrap().unwrap()); + assert!(iter.next().is_none()); + } + + #[test] + fn test_dictionary_iter_with_null() { + let test = vec![Some("a"), None, Some("b"), None, None, Some("a")]; + let array: DictionaryArray = test.into_iter().collect(); + + let mut iter = array + .values() + .as_any() + .downcast_ref::() + .unwrap() + .take_iter(array.keys_iter()); + + assert_eq!("a", iter.next().unwrap().unwrap()); + assert!(iter.next().unwrap().is_none()); + assert_eq!("b", iter.next().unwrap().unwrap()); + assert!(iter.next().unwrap().is_none()); + assert!(iter.next().unwrap().is_none()); + assert_eq!("a", iter.next().unwrap().unwrap()); + assert!(iter.next().is_none()); + } + + #[test] + fn test_dictionary_key() { + let keys = Int8Array::from(vec![Some(2), None, Some(1)]); + let values = StringArray::from(vec!["foo", "bar", "baz", "blarg"]); + + let array = DictionaryArray::new(keys, Arc::new(values)); + assert_eq!(array.key(0), Some(2)); + assert_eq!(array.key(1), None); + assert_eq!(array.key(2), Some(1)); + } + + #[test] + fn test_try_new() { + let values: StringArray = [Some("foo"), Some("bar"), Some("baz")] + .into_iter() + .collect(); + let keys: Int32Array = [Some(0), Some(2), None, Some(1)].into_iter().collect(); + + let array = DictionaryArray::new(keys, Arc::new(values)); + assert_eq!(array.keys().data_type(), &DataType::Int32); + assert_eq!(array.values().data_type(), &DataType::Utf8); + + assert_eq!(array.null_count(), 1); + + assert!(array.keys().is_valid(0)); + assert!(array.keys().is_valid(1)); + assert!(array.keys().is_null(2)); + assert!(array.keys().is_valid(3)); + + assert_eq!(array.keys().value(0), 0); + assert_eq!(array.keys().value(1), 2); + assert_eq!(array.keys().value(3), 1); + + assert_eq!( + "DictionaryArray {keys: PrimitiveArray\n[\n 0,\n 2,\n null,\n 1,\n] values: StringArray\n[\n \"foo\",\n \"bar\",\n \"baz\",\n]}\n", + format!("{array:?}") + ); + } + + #[test] + #[should_panic(expected = "Invalid dictionary key 3 at index 1, expected 0 <= key < 2")] + fn test_try_new_index_too_large() { + let values: StringArray = [Some("foo"), Some("bar")].into_iter().collect(); + // dictionary only has 2 values, so offset 3 is out of bounds + let keys: Int32Array = [Some(0), Some(3)].into_iter().collect(); + DictionaryArray::new(keys, Arc::new(values)); + } + + #[test] + #[should_panic(expected = "Invalid dictionary key -100 at index 0, expected 0 <= key < 2")] + fn test_try_new_index_too_small() { + let values: StringArray = [Some("foo"), Some("bar")].into_iter().collect(); + let keys: Int32Array = [Some(-100)].into_iter().collect(); + DictionaryArray::new(keys, Arc::new(values)); + } + + #[test] + #[should_panic(expected = "DictionaryArray's data type must match, expected Int64 got Int32")] + fn test_from_array_data_validation() { + let a = DictionaryArray::::from_iter(["32"]); + let _ = DictionaryArray::::from(a.into_data()); + } + + #[test] + fn test_into_primitive_dict_builder() { + let values = Int32Array::from_iter_values([10_i32, 12, 15]); + let keys = Int8Array::from_iter_values([1_i8, 0, 2, 0]); + + let dict_array = DictionaryArray::new(keys, Arc::new(values)); + + let boxed: ArrayRef = Arc::new(dict_array); + let col: DictionaryArray = as_dictionary_array(&boxed).clone(); + + drop(boxed); + + let mut builder = col.into_primitive_dict_builder::().unwrap(); + + let slice = builder.values_slice_mut(); + assert_eq!(slice, &[10, 12, 15]); + + slice[0] = 4; + slice[1] = 2; + slice[2] = 1; + + let values = Int32Array::from_iter_values([4_i32, 2, 1]); + let keys = Int8Array::from_iter_values([1_i8, 0, 2, 0]); + + let expected = DictionaryArray::new(keys, Arc::new(values)); + + let new_array = builder.finish(); + assert_eq!(expected, new_array); + } + + #[test] + fn test_into_primitive_dict_builder_cloned_array() { + let values = Int32Array::from_iter_values([10_i32, 12, 15]); + let keys = Int8Array::from_iter_values([1_i8, 0, 2, 0]); + + let dict_array = DictionaryArray::new(keys, Arc::new(values)); + + let boxed: ArrayRef = Arc::new(dict_array); + + let col: DictionaryArray = DictionaryArray::::from(boxed.to_data()); + let err = col.into_primitive_dict_builder::(); + + let returned = err.unwrap_err(); + + let values = Int32Array::from_iter_values([10_i32, 12, 15]); + let keys = Int8Array::from_iter_values([1_i8, 0, 2, 0]); + + let expected = DictionaryArray::new(keys, Arc::new(values)); + assert_eq!(expected, returned); + } + + #[test] + fn test_occupancy() { + let keys = Int32Array::new((100..200).collect(), None); + let values = Int32Array::from(vec![0; 1024]); + let dict = DictionaryArray::new(keys, Arc::new(values)); + for (idx, v) in dict.occupancy().iter().enumerate() { + let expected = (100..200).contains(&idx); + assert_eq!(v, expected, "{idx}"); + } + + let keys = Int32Array::new( + (0..100).collect(), + Some((0..100).map(|x| x % 4 == 0).collect()), + ); + let values = Int32Array::from(vec![0; 1024]); + let dict = DictionaryArray::new(keys, Arc::new(values)); + for (idx, v) in dict.occupancy().iter().enumerate() { + let expected = idx % 4 == 0 && idx < 100; + assert_eq!(v, expected, "{idx}"); + } + } + + #[test] + fn test_iterator_nulls() { + let keys = Int32Array::new( + vec![0, 700, 1, 2].into(), + Some(NullBuffer::from(vec![true, false, true, true])), + ); + let values = Int32Array::from(vec![Some(50), None, Some(2)]); + let dict = DictionaryArray::new(keys, Arc::new(values)); + let values: Vec<_> = dict + .downcast_dict::() + .unwrap() + .into_iter() + .collect(); + assert_eq!(values, &[Some(50), None, None, Some(2)]) + } + + #[test] + fn test_normalized_keys() { + let values = vec![132, 0, 1].into(); + let nulls = NullBuffer::from(vec![false, true, true]); + let keys = Int32Array::new(values, Some(nulls)); + let dictionary = DictionaryArray::new(keys, Arc::new(Int32Array::new_null(2))); + assert_eq!(&dictionary.normalized_keys(), &[1, 0, 1]) + } +} diff --git a/arrow/src/array/array_fixed_size_binary.rs b/arrow-array/src/array/fixed_size_binary_array.rs similarity index 57% rename from arrow/src/array/array_fixed_size_binary.rs rename to arrow-array/src/array/fixed_size_binary_array.rs index 22eac1435a8d..d89bbd5ad084 100644 --- a/arrow/src/array/array_fixed_size_binary.rs +++ b/arrow-array/src/array/fixed_size_binary_array.rs @@ -15,27 +15,24 @@ // specific language governing permissions and limitations // under the License. +use crate::array::print_long_array; +use crate::iterator::FixedSizeBinaryIter; +use crate::{Array, ArrayAccessor, ArrayRef, FixedSizeListArray}; +use arrow_buffer::buffer::NullBuffer; +use arrow_buffer::{bit_util, ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer}; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::{ArrowError, DataType}; use std::any::Any; -use std::convert::From; -use std::fmt; - -use super::{ - array::print_long_array, raw_pointer::RawPtrBox, Array, ArrayData, FixedSizeListArray, -}; -use crate::array::{ArrayAccessor, FixedSizeBinaryIter}; -use crate::buffer::Buffer; -use crate::error::{ArrowError, Result}; -use crate::util::bit_util; -use crate::{buffer::MutableBuffer, datatypes::DataType}; - -/// An array where each element is a fixed-size sequence of bytes. +use std::sync::Arc; + +/// An array of [fixed size binary arrays](https://arrow.apache.org/docs/format/Columnar.html#fixed-size-primitive-layout) /// /// # Examples /// /// Create an array from an iterable argument of byte slices. /// /// ``` -/// use arrow::array::{Array, FixedSizeBinaryArray}; +/// use arrow_array::{Array, FixedSizeBinaryArray}; /// let input_arg = vec![ vec![1, 2], vec![3, 4], vec![5, 6] ]; /// let arr = FixedSizeBinaryArray::try_from_iter(input_arg.into_iter()).unwrap(); /// @@ -45,31 +42,103 @@ use crate::{buffer::MutableBuffer, datatypes::DataType}; /// Create an array from an iterable argument of sparse byte slices. /// Sparsity means that the input argument can contain `None` items. /// ``` -/// use arrow::array::{Array, FixedSizeBinaryArray}; +/// use arrow_array::{Array, FixedSizeBinaryArray}; /// let input_arg = vec![ None, Some(vec![7, 8]), Some(vec![9, 10]), None, Some(vec![13, 14]) ]; -/// let arr = FixedSizeBinaryArray::try_from_sparse_iter(input_arg.into_iter()).unwrap(); +/// let arr = FixedSizeBinaryArray::try_from_sparse_iter_with_size(input_arg.into_iter(), 2).unwrap(); /// assert_eq!(5, arr.len()) /// /// ``` /// +#[derive(Clone)] pub struct FixedSizeBinaryArray { - data: ArrayData, - value_data: RawPtrBox, - length: i32, + data_type: DataType, // Must be DataType::FixedSizeBinary(value_length) + value_data: Buffer, + nulls: Option, + len: usize, + value_length: i32, } impl FixedSizeBinaryArray { + /// Create a new [`FixedSizeBinaryArray`] with `size` element size, panicking on failure + /// + /// # Panics + /// + /// Panics if [`Self::try_new`] returns an error + pub fn new(size: i32, values: Buffer, nulls: Option) -> Self { + Self::try_new(size, values, nulls).unwrap() + } + + /// Create a new [`FixedSizeBinaryArray`] from the provided parts, returning an error on failure + /// + /// # Errors + /// + /// * `size < 0` + /// * `values.len() / size != nulls.len()` + pub fn try_new( + size: i32, + values: Buffer, + nulls: Option, + ) -> Result { + let data_type = DataType::FixedSizeBinary(size); + let s = size.to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!("Size cannot be negative, got {}", size)) + })?; + + let len = values.len() / s; + if let Some(n) = nulls.as_ref() { + if n.len() != len { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect length of null buffer for FixedSizeBinaryArray, expected {} got {}", + len, + n.len(), + ))); + } + } + + Ok(Self { + data_type, + value_data: values, + value_length: size, + nulls, + len, + }) + } + + /// Create a new [`FixedSizeBinaryArray`] of length `len` where all values are null + /// + /// # Panics + /// + /// Panics if + /// + /// * `size < 0` + /// * `size * len` would overflow `usize` + pub fn new_null(size: i32, len: usize) -> Self { + let capacity = size.to_usize().unwrap().checked_mul(len).unwrap(); + Self { + data_type: DataType::FixedSizeBinary(size), + value_data: MutableBuffer::new(capacity).into(), + nulls: Some(NullBuffer::new_null(len)), + value_length: size, + len, + } + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts(self) -> (i32, Buffer, Option) { + (self.value_length, self.value_data, self.nulls) + } + /// Returns the element at index `i` as a byte slice. /// # Panics /// Panics if index `i` is out of bounds. pub fn value(&self, i: usize) -> &[u8] { assert!( - i < self.data.len(), + i < self.len(), "Trying to access an element at index {} from a FixedSizeBinaryArray of length {}", i, self.len() ); - let offset = i + self.data.offset(); + let offset = i + self.offset(); unsafe { let pos = self.value_offset_at(offset); std::slice::from_raw_parts( @@ -83,7 +152,7 @@ impl FixedSizeBinaryArray { /// # Safety /// Caller is responsible for ensuring that the index is within the bounds of the array pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { - let offset = i + self.data.offset(); + let offset = i + self.offset(); let pos = self.value_offset_at(offset); std::slice::from_raw_parts( self.value_data.as_ptr().offset(pos as isize), @@ -96,7 +165,7 @@ impl FixedSizeBinaryArray { /// Note this doesn't do any bound checking, for performance reason. #[inline] pub fn value_offset(&self, i: usize) -> i32 { - self.value_offset_at(self.data.offset() + i) + self.value_offset_at(self.offset() + i) } /// Returns the length for an element. @@ -104,12 +173,39 @@ impl FixedSizeBinaryArray { /// All elements have the same length as the array is a fixed size. #[inline] pub fn value_length(&self) -> i32 { - self.length + self.value_length + } + + /// Returns the values of this array. + /// + /// Unlike [`Self::value_data`] this returns the [`Buffer`] + /// allowing for zero-copy cloning. + #[inline] + pub fn values(&self) -> &Buffer { + &self.value_data } - /// Returns a clone of the value data buffer - pub fn value_data(&self) -> Buffer { - self.data.buffers()[0].clone() + /// Returns the raw value data. + pub fn value_data(&self) -> &[u8] { + self.value_data.as_slice() + } + + /// Returns a zero-copy slice of this array with the indicated offset and length. + pub fn slice(&self, offset: usize, len: usize) -> Self { + assert!( + offset.saturating_add(len) <= self.len, + "the length + offset of the sliced FixedSizeBinaryArray cannot exceed the existing length" + ); + + let size = self.value_length as usize; + + Self { + data_type: self.data_type.clone(), + nulls: self.nulls.as_ref().map(|n| n.slice(offset, len)), + value_length: self.value_length, + value_data: self.value_data.slice_with_length(offset * size, len * size), + len, + } } /// Create an array from an iterable argument of sparse byte slices. @@ -119,7 +215,7 @@ impl FixedSizeBinaryArray { /// # Examples /// /// ``` - /// use arrow::array::FixedSizeBinaryArray; + /// use arrow_array::FixedSizeBinaryArray; /// let input_arg = vec![ /// None, /// Some(vec![7, 8]), @@ -134,7 +230,10 @@ impl FixedSizeBinaryArray { /// # Errors /// /// Returns error if argument has length zero, or sizes of nested slices don't match. - pub fn try_from_sparse_iter(mut iter: T) -> Result + #[deprecated( + note = "This function will fail if the iterator produces only None values; prefer `try_from_sparse_iter_with_size`" + )] + pub fn try_from_sparse_iter(mut iter: T) -> Result where T: Iterator>, U: AsRef<[u8]>, @@ -142,10 +241,13 @@ impl FixedSizeBinaryArray { let mut len = 0; let mut size = None; let mut byte = 0; - let mut null_buf = MutableBuffer::from_len_zeroed(0); - let mut buffer = MutableBuffer::from_len_zeroed(0); + + let iter_size_hint = iter.size_hint().0; + let mut null_buf = MutableBuffer::new(bit_util::ceil(iter_size_hint, 8)); + let mut buffer = MutableBuffer::new(0); + let mut prepend = 0; - iter.try_for_each(|item| -> Result<()> { + iter.try_for_each(|item| -> Result<(), ArrowError> { // extend null bitmask by one byte per each 8 items if byte == 0 { null_buf.push(0u8); @@ -164,7 +266,12 @@ impl FixedSizeBinaryArray { ))); } } else { - size = Some(slice.len()); + let len = slice.len(); + size = Some(len); + // Now that we know how large each element is we can reserve + // sufficient capacity in the underlying mutable buffer for + // the data. + buffer.reserve(iter_size_hint * len); buffer.extend_zeros(slice.len() * prepend); } bit_util::set_bit(null_buf.as_slice_mut(), len); @@ -186,19 +293,94 @@ impl FixedSizeBinaryArray { )); } - let size = size.unwrap_or(0); - let array_data = unsafe { - ArrayData::new_unchecked( - DataType::FixedSizeBinary(size as i32), - len, - None, - Some(null_buf.into()), - 0, - vec![buffer.into()], - vec![], - ) - }; - Ok(FixedSizeBinaryArray::from(array_data)) + let null_buf = BooleanBuffer::new(null_buf.into(), 0, len); + let nulls = Some(NullBuffer::new(null_buf)).filter(|n| n.null_count() > 0); + + let size = size.unwrap_or(0) as i32; + Ok(Self { + data_type: DataType::FixedSizeBinary(size), + value_data: buffer.into(), + nulls, + value_length: size, + len, + }) + } + + /// Create an array from an iterable argument of sparse byte slices. + /// Sparsity means that items returned by the iterator are optional, i.e input argument can + /// contain `None` items. In cases where the iterator returns only `None` values, this + /// also takes a size parameter to ensure that the a valid FixedSizeBinaryArray is still + /// created. + /// + /// # Examples + /// + /// ``` + /// use arrow_array::FixedSizeBinaryArray; + /// let input_arg = vec![ + /// None, + /// Some(vec![7, 8]), + /// Some(vec![9, 10]), + /// None, + /// Some(vec![13, 14]), + /// None, + /// ]; + /// let array = FixedSizeBinaryArray::try_from_sparse_iter_with_size(input_arg.into_iter(), 2).unwrap(); + /// ``` + /// + /// # Errors + /// + /// Returns error if argument has length zero, or sizes of nested slices don't match. + pub fn try_from_sparse_iter_with_size(mut iter: T, size: i32) -> Result + where + T: Iterator>, + U: AsRef<[u8]>, + { + let mut len = 0; + let mut byte = 0; + + let iter_size_hint = iter.size_hint().0; + let mut null_buf = MutableBuffer::new(bit_util::ceil(iter_size_hint, 8)); + let mut buffer = MutableBuffer::new(iter_size_hint * (size as usize)); + + iter.try_for_each(|item| -> Result<(), ArrowError> { + // extend null bitmask by one byte per each 8 items + if byte == 0 { + null_buf.push(0u8); + byte = 8; + } + byte -= 1; + + if let Some(slice) = item { + let slice = slice.as_ref(); + if size as usize != slice.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Nested array size mismatch: one is {}, and the other is {}", + size, + slice.len() + ))); + } + + bit_util::set_bit(null_buf.as_slice_mut(), len); + buffer.extend_from_slice(slice); + } else { + buffer.extend_zeros(size as usize); + } + + len += 1; + + Ok(()) + })?; + + let null_buf = BooleanBuffer::new(null_buf.into(), 0, len); + let nulls = Some(NullBuffer::new(null_buf)).filter(|n| n.null_count() > 0); + + Ok(Self { + data_type: DataType::FixedSizeBinary(size), + value_data: buffer.into(), + nulls, + len, + value_length: size, + }) } /// Create an array from an iterable argument of byte slices. @@ -206,7 +388,7 @@ impl FixedSizeBinaryArray { /// # Examples /// /// ``` - /// use arrow::array::FixedSizeBinaryArray; + /// use arrow_array::FixedSizeBinaryArray; /// let input_arg = vec![ /// vec![1, 2], /// vec![3, 4], @@ -218,15 +400,17 @@ impl FixedSizeBinaryArray { /// # Errors /// /// Returns error if argument has length zero, or sizes of nested slices don't match. - pub fn try_from_iter(mut iter: T) -> Result + pub fn try_from_iter(mut iter: T) -> Result where T: Iterator, U: AsRef<[u8]>, { let mut len = 0; let mut size = None; - let mut buffer = MutableBuffer::from_len_zeroed(0); - iter.try_for_each(|item| -> Result<()> { + let iter_size_hint = iter.size_hint().0; + let mut buffer = MutableBuffer::new(0); + + iter.try_for_each(|item| -> Result<(), ArrowError> { let slice = item.as_ref(); if let Some(size) = size { if size != slice.len() { @@ -237,8 +421,11 @@ impl FixedSizeBinaryArray { ))); } } else { - size = Some(slice.len()); + let len = slice.len(); + size = Some(len); + buffer.reserve(iter_size_hint * len); } + buffer.extend_from_slice(slice); len += 1; @@ -252,17 +439,19 @@ impl FixedSizeBinaryArray { )); } - let size = size.unwrap_or(0); - let array_data = ArrayData::builder(DataType::FixedSizeBinary(size as i32)) - .len(len) - .add_buffer(buffer.into()); - let array_data = unsafe { array_data.build_unchecked() }; - Ok(FixedSizeBinaryArray::from(array_data)) + let size = size.unwrap_or(0).try_into().unwrap(); + Ok(Self { + data_type: DataType::FixedSizeBinary(size), + value_data: buffer.into(), + nulls: None, + value_length: size, + len, + }) } #[inline] fn value_offset_at(&self, i: usize) -> i32 { - self.length * i as i32 + self.value_length * i as i32 } /// constructs a new iterator @@ -278,35 +467,48 @@ impl From for FixedSizeBinaryArray { 1, "FixedSizeBinaryArray data should contain 1 buffer only (values)" ); - let value_data = data.buffers()[0].as_ptr(); - let length = match data.data_type() { + let value_length = match data.data_type() { DataType::FixedSizeBinary(len) => *len, _ => panic!("Expected data type to be FixedSizeBinary"), }; + + let size = value_length as usize; + let value_data = + data.buffers()[0].slice_with_length(data.offset() * size, data.len() * size); + Self { - data, - value_data: unsafe { RawPtrBox::new(value_data) }, - length, + data_type: data.data_type().clone(), + nulls: data.nulls().cloned(), + len: data.len(), + value_data, + value_length, } } } impl From for ArrayData { fn from(array: FixedSizeBinaryArray) -> Self { - array.data + let builder = ArrayDataBuilder::new(array.data_type) + .len(array.len) + .buffers(vec![array.value_data]) + .nulls(array.nulls); + + unsafe { builder.build_unchecked() } } } /// Creates a `FixedSizeBinaryArray` from `FixedSizeList` array impl From for FixedSizeBinaryArray { fn from(v: FixedSizeListArray) -> Self { + let value_len = v.value_length(); + let v = v.into_data(); assert_eq!( - v.data_ref().child_data().len(), + v.child_data().len(), 1, "FixedSizeBinaryArray can only be created from list array of u8 values \ (i.e. FixedSizeList>)." ); - let child_data = &v.data_ref().child_data()[0]; + let child_data = &v.child_data()[0]; assert_eq!( child_data.child_data().len(), @@ -325,11 +527,11 @@ impl From for FixedSizeBinaryArray { "The child array cannot contain null values." ); - let builder = ArrayData::builder(DataType::FixedSizeBinary(v.value_length())) + let builder = ArrayData::builder(DataType::FixedSizeBinary(value_len)) .len(v.len()) .offset(v.offset()) .add_buffer(child_data.buffers()[0].slice(child_data.offset())) - .null_bit_buffer(v.data_ref().null_buffer().cloned()); + .nulls(v.nulls().cloned()); let data = unsafe { builder.build_unchecked() }; Self::from(data) @@ -338,6 +540,7 @@ impl From for FixedSizeBinaryArray { impl From>> for FixedSizeBinaryArray { fn from(v: Vec>) -> Self { + #[allow(deprecated)] Self::try_from_sparse_iter(v.into_iter()).unwrap() } } @@ -348,11 +551,11 @@ impl From> for FixedSizeBinaryArray { } } -impl fmt::Debug for FixedSizeBinaryArray { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +impl std::fmt::Debug for FixedSizeBinaryArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "FixedSizeBinaryArray<{}>\n[\n", self.value_length())?; print_long_array(self, f, |array, index, f| { - fmt::Debug::fmt(&array.value(index), f) + std::fmt::Debug::fmt(&array.value(index), f) })?; write!(f, "]") } @@ -363,13 +566,49 @@ impl Array for FixedSizeBinaryArray { self } - fn data(&self) -> &ArrayData { - &self.data + fn to_data(&self) -> ArrayData { + self.clone().into() } fn into_data(self) -> ArrayData { self.into() } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.len + } + + fn is_empty(&self) -> bool { + self.len == 0 + } + + fn offset(&self) -> usize { + 0 + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.nulls.as_ref() + } + + fn get_buffer_memory_size(&self) -> usize { + let mut sum = self.value_data.capacity(); + if let Some(n) = &self.nulls { + sum += n.buffer().capacity(); + } + sum + } + + fn get_array_memory_size(&self) -> usize { + std::mem::size_of::() + self.get_buffer_memory_size() + } } impl<'a> ArrayAccessor for &'a FixedSizeBinaryArray { @@ -395,13 +634,10 @@ impl<'a> IntoIterator for &'a FixedSizeBinaryArray { #[cfg(test)] mod tests { + use crate::RecordBatch; + use arrow_schema::{Field, Schema}; use std::sync::Arc; - use crate::{ - datatypes::{Field, Schema}, - record_batch::RecordBatch, - }; - use super::*; #[test] @@ -452,9 +688,9 @@ mod tests { fixed_size_binary_array.value(1) ); assert_eq!(2, fixed_size_binary_array.len()); - assert_eq!(5, fixed_size_binary_array.value_offset(0)); + assert_eq!(0, fixed_size_binary_array.value_offset(0)); assert_eq!(5, fixed_size_binary_array.value_length()); - assert_eq!(10, fixed_size_binary_array.value_offset(1)); + assert_eq!(5, fixed_size_binary_array.value_offset(1)); } #[test] @@ -463,19 +699,19 @@ mod tests { let values_data = ArrayData::builder(DataType::UInt8) .len(12) .offset(2) - .add_buffer(Buffer::from_slice_ref(&values)) + .add_buffer(Buffer::from_slice_ref(values)) .build() .unwrap(); // [null, [10, 11, 12, 13]] let array_data = unsafe { ArrayData::builder(DataType::FixedSizeList( - Box::new(Field::new("item", DataType::UInt8, false)), + Arc::new(Field::new("item", DataType::UInt8, false)), 4, )) .len(2) .offset(1) .add_child_data(values_data) - .null_bit_buffer(Some(Buffer::from_slice_ref(&[0b101]))) + .null_bit_buffer(Some(Buffer::from_slice_ref([0b101]))) .build_unchecked() }; let list_array = FixedSizeListArray::from(array_data); @@ -499,13 +735,13 @@ mod tests { let values: [u32; 12] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; let values_data = ArrayData::builder(DataType::UInt32) .len(12) - .add_buffer(Buffer::from_slice_ref(&values)) + .add_buffer(Buffer::from_slice_ref(values)) .build() .unwrap(); let array_data = unsafe { ArrayData::builder(DataType::FixedSizeList( - Box::new(Field::new("item", DataType::Binary, false)), + Arc::new(Field::new("item", DataType::Binary, false)), 4, )) .len(3) @@ -522,14 +758,14 @@ mod tests { let values = [0_u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; let values_data = ArrayData::builder(DataType::UInt8) .len(12) - .add_buffer(Buffer::from_slice_ref(&values)) - .null_bit_buffer(Some(Buffer::from_slice_ref(&[0b101010101010]))) + .add_buffer(Buffer::from_slice_ref(values)) + .null_bit_buffer(Some(Buffer::from_slice_ref([0b101010101010]))) .build() .unwrap(); let array_data = unsafe { ArrayData::builder(DataType::FixedSizeList( - Box::new(Field::new("item", DataType::UInt8, false)), + Arc::new(Field::new("item", DataType::UInt8, false)), 4, )) .len(3) @@ -552,7 +788,7 @@ mod tests { let arr = FixedSizeBinaryArray::from(array_data); assert_eq!( "FixedSizeBinaryArray<5>\n[\n [104, 101, 108, 108, 111],\n [116, 104, 101, 114, 101],\n [97, 114, 114, 111, 119],\n]", - format!("{:?}", arr) + format!("{arr:?}") ); } @@ -569,8 +805,8 @@ mod tests { fn test_all_none_fixed_size_binary_array_from_sparse_iter() { let none_option: Option<[u8; 32]> = None; let input_arg = vec![none_option, none_option, none_option]; - let arr = - FixedSizeBinaryArray::try_from_sparse_iter(input_arg.into_iter()).unwrap(); + #[allow(deprecated)] + let arr = FixedSizeBinaryArray::try_from_sparse_iter(input_arg.into_iter()).unwrap(); assert_eq!(0, arr.value_length()); assert_eq!(3, arr.len()) } @@ -584,9 +820,24 @@ mod tests { None, Some(vec![13, 14]), ]; + #[allow(deprecated)] + let arr = FixedSizeBinaryArray::try_from_sparse_iter(input_arg.iter().cloned()).unwrap(); + assert_eq!(2, arr.value_length()); + assert_eq!(5, arr.len()); + let arr = - FixedSizeBinaryArray::try_from_sparse_iter(input_arg.into_iter()).unwrap(); + FixedSizeBinaryArray::try_from_sparse_iter_with_size(input_arg.into_iter(), 2).unwrap(); assert_eq!(2, arr.value_length()); + assert_eq!(5, arr.len()); + } + + #[test] + fn test_fixed_size_binary_array_from_sparse_iter_with_size_all_none() { + let input_arg = vec![None, None, None, None, None] as Vec>>; + + let arr = FixedSizeBinaryArray::try_from_sparse_iter_with_size(input_arg.into_iter(), 16) + .unwrap(); + assert_eq!(16, arr.value_length()); assert_eq!(5, arr.len()) } @@ -651,25 +902,23 @@ mod tests { #[test] fn fixed_size_binary_array_all_null() { let data = vec![None] as Vec>; - let array = FixedSizeBinaryArray::try_from_sparse_iter(data.into_iter()).unwrap(); + let array = + FixedSizeBinaryArray::try_from_sparse_iter_with_size(data.into_iter(), 0).unwrap(); array - .data() + .into_data() .validate_full() .expect("All null array has valid array data"); } #[test] // Test for https://github.com/apache/arrow-rs/issues/1390 - #[should_panic( - expected = "column types must match schema types, expected FixedSizeBinary(2) but found FixedSizeBinary(0) at column index 0" - )] fn fixed_size_binary_array_all_null_in_batch_with_schema() { - let schema = - Schema::new(vec![Field::new("a", DataType::FixedSizeBinary(2), true)]); + let schema = Schema::new(vec![Field::new("a", DataType::FixedSizeBinary(2), true)]); let none_option: Option<[u8; 2]> = None; - let item = FixedSizeBinaryArray::try_from_sparse_iter( + let item = FixedSizeBinaryArray::try_from_sparse_iter_with_size( vec![none_option, none_option, none_option].into_iter(), + 2, ) .unwrap(); @@ -687,4 +936,31 @@ mod tests { array.value(4); } + + #[test] + fn test_constructors() { + let buffer = Buffer::from_vec(vec![0_u8; 10]); + let a = FixedSizeBinaryArray::new(2, buffer.clone(), None); + assert_eq!(a.len(), 5); + + let nulls = NullBuffer::new_null(5); + FixedSizeBinaryArray::new(2, buffer.clone(), Some(nulls)); + + let a = FixedSizeBinaryArray::new(3, buffer.clone(), None); + assert_eq!(a.len(), 3); + + let nulls = NullBuffer::new_null(3); + FixedSizeBinaryArray::new(3, buffer.clone(), Some(nulls)); + + let err = FixedSizeBinaryArray::try_new(-1, buffer.clone(), None).unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: Size cannot be negative, got -1" + ); + + let nulls = NullBuffer::new_null(3); + let err = FixedSizeBinaryArray::try_new(2, buffer, Some(nulls)).unwrap_err(); + assert_eq!(err.to_string(), "Invalid argument error: Incorrect length of null buffer for FixedSizeBinaryArray, expected 5 got 3"); + } } diff --git a/arrow-array/src/array/fixed_size_list_array.rs b/arrow-array/src/array/fixed_size_list_array.rs new file mode 100644 index 000000000000..f8f01516e3d4 --- /dev/null +++ b/arrow-array/src/array/fixed_size_list_array.rs @@ -0,0 +1,677 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::array::print_long_array; +use crate::builder::{FixedSizeListBuilder, PrimitiveBuilder}; +use crate::iterator::FixedSizeListIter; +use crate::{make_array, Array, ArrayAccessor, ArrayRef, ArrowPrimitiveType}; +use arrow_buffer::buffer::NullBuffer; +use arrow_buffer::ArrowNativeType; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::{ArrowError, DataType, FieldRef}; +use std::any::Any; +use std::sync::Arc; + +/// An array of [fixed length lists], similar to JSON arrays +/// (e.g. `["A", "B"]`). +/// +/// Lists are represented using a `values` child +/// array where each list has a fixed size of `value_length`. +/// +/// Use [`FixedSizeListBuilder`] to construct a [`FixedSizeListArray`]. +/// +/// # Representation +/// +/// A [`FixedSizeListArray`] can represent a list of values of any other +/// supported Arrow type. Each element of the `FixedSizeListArray` itself is +/// a list which may contain NULL and non-null values, +/// or may itself be NULL. +/// +/// For example, this `FixedSizeListArray` stores lists of strings: +/// +/// ```text +/// ┌─────────────┐ +/// │ [A,B] │ +/// ├─────────────┤ +/// │ NULL │ +/// ├─────────────┤ +/// │ [C,NULL] │ +/// └─────────────┘ +/// ``` +/// +/// The `values` of this `FixedSizeListArray`s are stored in a child +/// [`StringArray`] where logical null values take up `values_length` slots in the array +/// as shown in the following diagram. The logical values +/// are shown on the left, and the actual `FixedSizeListArray` encoding on the right +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─┐ +/// ┌─────────────┐ │ ┌───┐ ┌───┐ ┌──────┐ │ +/// │ [A,B] │ │ 1 │ │ │ 1 │ │ A │ │ 0 +/// ├─────────────┤ │ ├───┤ ├───┤ ├──────┤ │ +/// │ NULL │ │ 0 │ │ │ 1 │ │ B │ │ 1 +/// ├─────────────┤ │ ├───┤ ├───┤ ├──────┤ │ +/// │ [C,NULL] │ │ 1 │ │ │ 0 │ │ ???? │ │ 2 +/// └─────────────┘ │ └───┘ ├───┤ ├──────┤ │ +/// | │ 0 │ │ ???? │ │ 3 +/// Logical Values │ Validity ├───┤ ├──────┤ │ +/// (nulls) │ │ 1 │ │ C │ │ 4 +/// │ ├───┤ ├──────┤ │ +/// │ │ 0 │ │ ???? │ │ 5 +/// │ └───┘ └──────┘ │ +/// │ Values │ +/// │ FixedSizeListArray (Array) │ +/// └ ─ ─ ─ ─ ─ ─ ─ ─┘ +/// └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ +/// ``` +/// +/// # Example +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{Array, FixedSizeListArray, Int32Array}; +/// # use arrow_data::ArrayData; +/// # use arrow_schema::{DataType, Field}; +/// # use arrow_buffer::Buffer; +/// // Construct a value array +/// let value_data = ArrayData::builder(DataType::Int32) +/// .len(9) +/// .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7, 8])) +/// .build() +/// .unwrap(); +/// let list_data_type = DataType::FixedSizeList( +/// Arc::new(Field::new("item", DataType::Int32, false)), +/// 3, +/// ); +/// let list_data = ArrayData::builder(list_data_type.clone()) +/// .len(3) +/// .add_child_data(value_data.clone()) +/// .build() +/// .unwrap(); +/// let list_array = FixedSizeListArray::from(list_data); +/// let list0 = list_array.value(0); +/// let list1 = list_array.value(1); +/// let list2 = list_array.value(2); +/// +/// assert_eq!( &[0, 1, 2], list0.as_any().downcast_ref::().unwrap().values()); +/// assert_eq!( &[3, 4, 5], list1.as_any().downcast_ref::().unwrap().values()); +/// assert_eq!( &[6, 7, 8], list2.as_any().downcast_ref::().unwrap().values()); +/// ``` +/// +/// [`StringArray`]: crate::array::StringArray +/// [fixed size arrays](https://arrow.apache.org/docs/format/Columnar.html#fixed-size-list-layout) +#[derive(Clone)] +pub struct FixedSizeListArray { + data_type: DataType, // Must be DataType::FixedSizeList(value_length) + values: ArrayRef, + nulls: Option, + value_length: i32, + len: usize, +} + +impl FixedSizeListArray { + /// Create a new [`FixedSizeListArray`] with `size` element size, panicking on failure + /// + /// # Panics + /// + /// Panics if [`Self::try_new`] returns an error + pub fn new(field: FieldRef, size: i32, values: ArrayRef, nulls: Option) -> Self { + Self::try_new(field, size, values, nulls).unwrap() + } + + /// Create a new [`FixedSizeListArray`] from the provided parts, returning an error on failure + /// + /// # Errors + /// + /// * `size < 0` + /// * `values.len() / size != nulls.len()` + /// * `values.data_type() != field.data_type()` + /// * `!field.is_nullable() && !nulls.expand(size).contains(values.logical_nulls())` + pub fn try_new( + field: FieldRef, + size: i32, + values: ArrayRef, + nulls: Option, + ) -> Result { + let s = size.to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!("Size cannot be negative, got {}", size)) + })?; + + let len = values.len() / s.max(1); + if let Some(n) = nulls.as_ref() { + if n.len() != len { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect length of null buffer for FixedSizeListArray, expected {} got {}", + len, + n.len(), + ))); + } + } + + if field.data_type() != values.data_type() { + return Err(ArrowError::InvalidArgumentError(format!( + "FixedSizeListArray expected data type {} got {} for {:?}", + field.data_type(), + values.data_type(), + field.name() + ))); + } + + if let Some(a) = values.logical_nulls() { + let nulls_valid = field.is_nullable() + || nulls + .as_ref() + .map(|n| n.expand(size as _).contains(&a)) + .unwrap_or_default(); + + if !nulls_valid { + return Err(ArrowError::InvalidArgumentError(format!( + "Found unmasked nulls for non-nullable FixedSizeListArray field {:?}", + field.name() + ))); + } + } + + let data_type = DataType::FixedSizeList(field, size); + Ok(Self { + data_type, + values, + value_length: size, + nulls, + len, + }) + } + + /// Create a new [`FixedSizeListArray`] of length `len` where all values are null + /// + /// # Panics + /// + /// Panics if + /// + /// * `size < 0` + /// * `size * len` would overflow `usize` + pub fn new_null(field: FieldRef, size: i32, len: usize) -> Self { + let capacity = size.to_usize().unwrap().checked_mul(len).unwrap(); + Self { + values: make_array(ArrayData::new_null(field.data_type(), capacity)), + data_type: DataType::FixedSizeList(field, size), + nulls: Some(NullBuffer::new_null(len)), + value_length: size, + len, + } + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts(self) -> (FieldRef, i32, ArrayRef, Option) { + let f = match self.data_type { + DataType::FixedSizeList(f, _) => f, + _ => unreachable!(), + }; + (f, self.value_length, self.values, self.nulls) + } + + /// Returns a reference to the values of this list. + pub fn values(&self) -> &ArrayRef { + &self.values + } + + /// Returns a clone of the value type of this list. + pub fn value_type(&self) -> DataType { + self.values.data_type().clone() + } + + /// Returns ith value of this list array. + pub fn value(&self, i: usize) -> ArrayRef { + self.values + .slice(self.value_offset(i) as usize, self.value_length() as usize) + } + + /// Returns the offset for value at index `i`. + /// + /// Note this doesn't do any bound checking, for performance reason. + #[inline] + pub fn value_offset(&self, i: usize) -> i32 { + self.value_offset_at(i) + } + + /// Returns the length for an element. + /// + /// All elements have the same length as the array is a fixed size. + #[inline] + pub const fn value_length(&self) -> i32 { + self.value_length + } + + #[inline] + const fn value_offset_at(&self, i: usize) -> i32 { + i as i32 * self.value_length + } + + /// Returns a zero-copy slice of this array with the indicated offset and length. + pub fn slice(&self, offset: usize, len: usize) -> Self { + assert!( + offset.saturating_add(len) <= self.len, + "the length + offset of the sliced FixedSizeListArray cannot exceed the existing length" + ); + let size = self.value_length as usize; + + Self { + data_type: self.data_type.clone(), + values: self.values.slice(offset * size, len * size), + nulls: self.nulls.as_ref().map(|n| n.slice(offset, len)), + value_length: self.value_length, + len, + } + } + + /// Creates a [`FixedSizeListArray`] from an iterator of primitive values + /// # Example + /// ``` + /// # use arrow_array::FixedSizeListArray; + /// # use arrow_array::types::Int32Type; + /// + /// let data = vec![ + /// Some(vec![Some(0), Some(1), Some(2)]), + /// None, + /// Some(vec![Some(3), None, Some(5)]), + /// Some(vec![Some(6), Some(7), Some(45)]), + /// ]; + /// let list_array = FixedSizeListArray::from_iter_primitive::(data, 3); + /// println!("{:?}", list_array); + /// ``` + pub fn from_iter_primitive(iter: I, length: i32) -> Self + where + T: ArrowPrimitiveType, + P: IntoIterator::Native>>, + I: IntoIterator>, + { + let l = length as usize; + let iter = iter.into_iter(); + let size_hint = iter.size_hint().0; + let mut builder = FixedSizeListBuilder::with_capacity( + PrimitiveBuilder::::with_capacity(size_hint * l), + length, + size_hint, + ); + + for i in iter { + match i { + Some(p) => { + for t in p { + builder.values().append_option(t); + } + builder.append(true); + } + None => { + builder.values().append_nulls(l); + builder.append(false) + } + } + } + builder.finish() + } + + /// constructs a new iterator + pub fn iter(&self) -> FixedSizeListIter<'_> { + FixedSizeListIter::new(self) + } +} + +impl From for FixedSizeListArray { + fn from(data: ArrayData) -> Self { + let value_length = match data.data_type() { + DataType::FixedSizeList(_, len) => *len, + _ => { + panic!("FixedSizeListArray data should contain a FixedSizeList data type") + } + }; + + let size = value_length as usize; + let values = + make_array(data.child_data()[0].slice(data.offset() * size, data.len() * size)); + Self { + data_type: data.data_type().clone(), + values, + nulls: data.nulls().cloned(), + value_length, + len: data.len(), + } + } +} + +impl From for ArrayData { + fn from(array: FixedSizeListArray) -> Self { + let builder = ArrayDataBuilder::new(array.data_type) + .len(array.len) + .nulls(array.nulls) + .child_data(vec![array.values.to_data()]); + + unsafe { builder.build_unchecked() } + } +} + +impl Array for FixedSizeListArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + self.clone().into() + } + + fn into_data(self) -> ArrayData { + self.into() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.len + } + + fn is_empty(&self) -> bool { + self.len == 0 + } + + fn offset(&self) -> usize { + 0 + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.nulls.as_ref() + } + + fn get_buffer_memory_size(&self) -> usize { + let mut size = self.values.get_buffer_memory_size(); + if let Some(n) = self.nulls.as_ref() { + size += n.buffer().capacity(); + } + size + } + + fn get_array_memory_size(&self) -> usize { + let mut size = std::mem::size_of::() + self.values.get_array_memory_size(); + if let Some(n) = self.nulls.as_ref() { + size += n.buffer().capacity(); + } + size + } +} + +impl ArrayAccessor for FixedSizeListArray { + type Item = ArrayRef; + + fn value(&self, index: usize) -> Self::Item { + FixedSizeListArray::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + FixedSizeListArray::value(self, index) + } +} + +impl std::fmt::Debug for FixedSizeListArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "FixedSizeListArray<{}>\n[\n", self.value_length())?; + print_long_array(self, f, |array, index, f| { + std::fmt::Debug::fmt(&array.value(index), f) + })?; + write!(f, "]") + } +} + +impl<'a> ArrayAccessor for &'a FixedSizeListArray { + type Item = ArrayRef; + + fn value(&self, index: usize) -> Self::Item { + FixedSizeListArray::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + FixedSizeListArray::value(self, index) + } +} + +#[cfg(test)] +mod tests { + use arrow_buffer::{bit_util, BooleanBuffer, Buffer}; + use arrow_schema::Field; + + use crate::cast::AsArray; + use crate::types::Int32Type; + use crate::Int32Array; + + use super::*; + + #[test] + fn test_fixed_size_list_array() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(9) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8])) + .build() + .unwrap(); + + // Construct a list array from the above two + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 3); + let list_data = ArrayData::builder(list_data_type.clone()) + .len(3) + .add_child_data(value_data.clone()) + .build() + .unwrap(); + let list_array = FixedSizeListArray::from(list_data); + + assert_eq!(value_data, list_array.values().to_data()); + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(3, list_array.len()); + assert_eq!(0, list_array.null_count()); + assert_eq!(6, list_array.value_offset(2)); + assert_eq!(3, list_array.value_length()); + assert_eq!(0, list_array.value(0).as_primitive::().value(0)); + for i in 0..3 { + assert!(list_array.is_valid(i)); + assert!(!list_array.is_null(i)); + } + + // Now test with a non-zero offset + let list_data = ArrayData::builder(list_data_type) + .len(2) + .offset(1) + .add_child_data(value_data.clone()) + .build() + .unwrap(); + let list_array = FixedSizeListArray::from(list_data); + + assert_eq!(value_data.slice(3, 6), list_array.values().to_data()); + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(2, list_array.len()); + assert_eq!(0, list_array.null_count()); + assert_eq!(3, list_array.value(0).as_primitive::().value(0)); + assert_eq!(3, list_array.value_offset(1)); + assert_eq!(3, list_array.value_length()); + } + + #[test] + #[should_panic(expected = "assertion failed: (offset + length) <= self.len()")] + // Different error messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] + fn test_fixed_size_list_array_unequal_children() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) + .build() + .unwrap(); + + // Construct a list array from the above two + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 3); + let list_data = unsafe { + ArrayData::builder(list_data_type) + .len(3) + .add_child_data(value_data) + .build_unchecked() + }; + drop(FixedSizeListArray::from(list_data)); + } + + #[test] + fn test_fixed_size_list_array_slice() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(10) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) + .build() + .unwrap(); + + // Set null buts for the nested array: + // [[0, 1], null, null, [6, 7], [8, 9]] + // 01011001 00000001 + let mut null_bits: [u8; 1] = [0; 1]; + bit_util::set_bit(&mut null_bits, 0); + bit_util::set_bit(&mut null_bits, 3); + bit_util::set_bit(&mut null_bits, 4); + + // Construct a fixed size list array from the above two + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 2); + let list_data = ArrayData::builder(list_data_type) + .len(5) + .add_child_data(value_data.clone()) + .null_bit_buffer(Some(Buffer::from(null_bits))) + .build() + .unwrap(); + let list_array = FixedSizeListArray::from(list_data); + + assert_eq!(value_data, list_array.values().to_data()); + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(5, list_array.len()); + assert_eq!(2, list_array.null_count()); + assert_eq!(6, list_array.value_offset(3)); + assert_eq!(2, list_array.value_length()); + + let sliced_array = list_array.slice(1, 4); + assert_eq!(4, sliced_array.len()); + assert_eq!(2, sliced_array.null_count()); + + for i in 0..sliced_array.len() { + if bit_util::get_bit(&null_bits, 1 + i) { + assert!(sliced_array.is_valid(i)); + } else { + assert!(sliced_array.is_null(i)); + } + } + + // Check offset and length for each non-null value. + let sliced_list_array = sliced_array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(2, sliced_list_array.value_length()); + assert_eq!(4, sliced_list_array.value_offset(2)); + assert_eq!(6, sliced_list_array.value_offset(3)); + } + + #[test] + #[should_panic(expected = "the offset of the new Buffer cannot exceed the existing length")] + fn test_fixed_size_list_array_index_out_of_bound() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(10) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) + .build() + .unwrap(); + + // Set null buts for the nested array: + // [[0, 1], null, null, [6, 7], [8, 9]] + // 01011001 00000001 + let mut null_bits: [u8; 1] = [0; 1]; + bit_util::set_bit(&mut null_bits, 0); + bit_util::set_bit(&mut null_bits, 3); + bit_util::set_bit(&mut null_bits, 4); + + // Construct a fixed size list array from the above two + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 2); + let list_data = ArrayData::builder(list_data_type) + .len(5) + .add_child_data(value_data) + .null_bit_buffer(Some(Buffer::from(null_bits))) + .build() + .unwrap(); + let list_array = FixedSizeListArray::from(list_data); + + list_array.value(10); + } + + #[test] + fn test_fixed_size_list_constructors() { + let values = Arc::new(Int32Array::from_iter([ + Some(1), + Some(2), + None, + None, + Some(3), + Some(4), + ])); + + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let list = FixedSizeListArray::new(field.clone(), 2, values.clone(), None); + assert_eq!(list.len(), 3); + + let nulls = NullBuffer::new_null(3); + let list = FixedSizeListArray::new(field.clone(), 2, values.clone(), Some(nulls)); + assert_eq!(list.len(), 3); + + let list = FixedSizeListArray::new(field.clone(), 4, values.clone(), None); + assert_eq!(list.len(), 1); + + let err = FixedSizeListArray::try_new(field.clone(), -1, values.clone(), None).unwrap_err(); + assert_eq!( + err.to_string(), + "Invalid argument error: Size cannot be negative, got -1" + ); + + let list = FixedSizeListArray::new(field.clone(), 0, values.clone(), None); + assert_eq!(list.len(), 6); + + let nulls = NullBuffer::new_null(2); + let err = FixedSizeListArray::try_new(field, 2, values.clone(), Some(nulls)).unwrap_err(); + assert_eq!(err.to_string(), "Invalid argument error: Incorrect length of null buffer for FixedSizeListArray, expected 3 got 2"); + + let field = Arc::new(Field::new("item", DataType::Int32, false)); + let err = FixedSizeListArray::try_new(field.clone(), 2, values.clone(), None).unwrap_err(); + assert_eq!(err.to_string(), "Invalid argument error: Found unmasked nulls for non-nullable FixedSizeListArray field \"item\""); + + // Valid as nulls in child masked by parent + let nulls = NullBuffer::new(BooleanBuffer::new(vec![0b0000101].into(), 0, 3)); + FixedSizeListArray::new(field, 2, values.clone(), Some(nulls)); + + let field = Arc::new(Field::new("item", DataType::Int64, true)); + let err = FixedSizeListArray::try_new(field, 2, values, None).unwrap_err(); + assert_eq!(err.to_string(), "Invalid argument error: FixedSizeListArray expected data type Int64 got Int32 for \"item\""); + } +} diff --git a/arrow/src/array/array_list.rs b/arrow-array/src/array/list_array.rs similarity index 50% rename from arrow/src/array/array_list.rs rename to arrow-array/src/array/list_array.rs index b9c05014c3f7..9758c112a1ef 100644 --- a/arrow/src/array/array_list.rs +++ b/arrow-array/src/array/list_array.rs @@ -15,25 +15,32 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::fmt; - -use num::Integer; - -use super::{ - array::print_long_array, make_array, raw_pointer::RawPtrBox, Array, ArrayData, - ArrayRef, BooleanBufferBuilder, GenericListArrayIter, PrimitiveArray, -}; -use crate::array::array::ArrayAccessor; +use crate::array::{get_offsets, make_array, print_long_array}; +use crate::builder::{GenericListBuilder, PrimitiveBuilder}; use crate::{ - buffer::MutableBuffer, - datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType, Field}, - error::ArrowError, + iterator::GenericListArrayIter, new_empty_array, Array, ArrayAccessor, ArrayRef, + ArrowPrimitiveType, FixedSizeListArray, }; +use arrow_buffer::{ArrowNativeType, NullBuffer, OffsetBuffer}; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::{ArrowError, DataType, FieldRef}; +use num::Integer; +use std::any::Any; +use std::sync::Arc; -/// trait declaring an offset size, relevant for i32 vs i64 array types. +/// A type that can be used within a variable-size array to encode offset information +/// +/// See [`ListArray`], [`LargeListArray`], [`BinaryArray`], [`LargeBinaryArray`], +/// [`StringArray`] and [`LargeStringArray`] +/// +/// [`BinaryArray`]: crate::array::BinaryArray +/// [`LargeBinaryArray`]: crate::array::LargeBinaryArray +/// [`StringArray`]: crate::array::StringArray +/// [`LargeStringArray`]: crate::array::LargeStringArray pub trait OffsetSizeTrait: ArrowNativeType + std::ops::AddAssign + Integer { + /// True for 64 bit offset size and false for 32 bit offset size const IS_LARGE: bool; + /// Prefix for the offset size const PREFIX: &'static str; } @@ -47,69 +54,247 @@ impl OffsetSizeTrait for i64 { const PREFIX: &'static str = "Large"; } -/// Generic struct for a variable-size list array. +/// An array of [variable length lists], similar to JSON arrays +/// (e.g. `["A", "B", "C"]`). +/// +/// Lists are represented using `offsets` into a `values` child +/// array. Offsets are stored in two adjacent entries of an +/// [`OffsetBuffer`]. +/// +/// Arrow defines [`ListArray`] with `i32` offsets and +/// [`LargeListArray`] with `i64` offsets. +/// +/// Use [`GenericListBuilder`] to construct a [`GenericListArray`]. /// -/// Columnar format in Apache Arrow: -/// +/// # Representation /// -/// For non generic lists, you may wish to consider using [`ListArray`] or [`LargeListArray`]` -pub struct GenericListArray { - data: ArrayData, +/// A [`ListArray`] can represent a list of values of any other +/// supported Arrow type. Each element of the `ListArray` itself is +/// a list which may be empty, may contain NULL and non-null values, +/// or may itself be NULL. +/// +/// For example, the `ListArray` shown in the following diagram stores +/// lists of strings. Note that `[]` represents an empty (length +/// 0), but non NULL list. +/// +/// ```text +/// ┌─────────────┐ +/// │ [A,B,C] │ +/// ├─────────────┤ +/// │ [] │ +/// ├─────────────┤ +/// │ NULL │ +/// ├─────────────┤ +/// │ [D] │ +/// ├─────────────┤ +/// │ [NULL, F] │ +/// └─────────────┘ +/// ``` +/// +/// The `values` are stored in a child [`StringArray`] and the offsets +/// are stored in an [`OffsetBuffer`] as shown in the following +/// diagram. The logical values and offsets are shown on the left, and +/// the actual `ListArray` encoding on the right. +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// ┌ ─ ─ ─ ─ ─ ─ ┐ │ +/// ┌─────────────┐ ┌───────┐ │ ┌───┐ ┌───┐ ┌───┐ ┌───┐ +/// │ [A,B,C] │ │ (0,3) │ │ 1 │ │ 0 │ │ │ 1 │ │ A │ │ 0 │ +/// ├─────────────┤ ├───────┤ │ ├───┤ ├───┤ ├───┤ ├───┤ +/// │ [] │ │ (3,3) │ │ 1 │ │ 3 │ │ │ 1 │ │ B │ │ 1 │ +/// ├─────────────┤ ├───────┤ │ ├───┤ ├───┤ ├───┤ ├───┤ +/// │ NULL │ │ (3,4) │ │ 0 │ │ 3 │ │ │ 1 │ │ C │ │ 2 │ +/// ├─────────────┤ ├───────┤ │ ├───┤ ├───┤ ├───┤ ├───┤ +/// │ [D] │ │ (4,5) │ │ 1 │ │ 4 │ │ │ ? │ │ ? │ │ 3 │ +/// ├─────────────┤ ├───────┤ │ ├───┤ ├───┤ ├───┤ ├───┤ +/// │ [NULL, F] │ │ (5,7) │ │ 1 │ │ 5 │ │ │ 1 │ │ D │ │ 4 │ +/// └─────────────┘ └───────┘ │ └───┘ ├───┤ ├───┤ ├───┤ +/// │ 7 │ │ │ 0 │ │ ? │ │ 5 │ +/// │ Validity └───┘ ├───┤ ├───┤ +/// Logical Logical (nulls) Offsets │ │ 1 │ │ F │ │ 6 │ +/// Values Offsets │ └───┘ └───┘ +/// │ Values │ │ +/// (offsets[i], │ ListArray (Array) +/// offsets[i+1]) └ ─ ─ ─ ─ ─ ─ ┘ │ +/// └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// +/// +/// ``` +/// +/// [`StringArray`]: crate::array::StringArray +/// [variable length lists]: https://arrow.apache.org/docs/format/Columnar.html#variable-size-list-layout +pub struct GenericListArray { + data_type: DataType, + nulls: Option, values: ArrayRef, - value_offsets: RawPtrBox, + value_offsets: OffsetBuffer, +} + +impl Clone for GenericListArray { + fn clone(&self) -> Self { + Self { + data_type: self.data_type.clone(), + nulls: self.nulls.clone(), + values: self.values.clone(), + value_offsets: self.value_offsets.clone(), + } + } } impl GenericListArray { /// The data type constructor of list array. /// The input is the schema of the child array and /// the output is the [`DataType`], List or LargeList. - pub const DATA_TYPE_CONSTRUCTOR: fn(Box) -> DataType = if OffsetSize::IS_LARGE - { + pub const DATA_TYPE_CONSTRUCTOR: fn(FieldRef) -> DataType = if OffsetSize::IS_LARGE { DataType::LargeList } else { DataType::List }; - /// Returns a reference to the values of this list. - pub fn values(&self) -> ArrayRef { - self.values.clone() + /// Create a new [`GenericListArray`] from the provided parts + /// + /// # Errors + /// + /// Errors if + /// + /// * `offsets.len() - 1 != nulls.len()` + /// * `offsets.last() > values.len()` + /// * `!field.is_nullable() && values.is_nullable()` + /// * `field.data_type() != values.data_type()` + pub fn try_new( + field: FieldRef, + offsets: OffsetBuffer, + values: ArrayRef, + nulls: Option, + ) -> Result { + let len = offsets.len() - 1; // Offsets guaranteed to not be empty + let end_offset = offsets.last().unwrap().as_usize(); + // don't need to check other values of `offsets` because they are checked + // during construction of `OffsetBuffer` + if end_offset > values.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Max offset of {end_offset} exceeds length of values {}", + values.len() + ))); + } + + if let Some(n) = nulls.as_ref() { + if n.len() != len { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect length of null buffer for {}ListArray, expected {len} got {}", + OffsetSize::PREFIX, + n.len(), + ))); + } + } + if !field.is_nullable() && values.is_nullable() { + return Err(ArrowError::InvalidArgumentError(format!( + "Non-nullable field of {}ListArray {:?} cannot contain nulls", + OffsetSize::PREFIX, + field.name() + ))); + } + + if field.data_type() != values.data_type() { + return Err(ArrowError::InvalidArgumentError(format!( + "{}ListArray expected data type {} got {} for {:?}", + OffsetSize::PREFIX, + field.data_type(), + values.data_type(), + field.name() + ))); + } + + Ok(Self { + data_type: Self::DATA_TYPE_CONSTRUCTOR(field), + nulls, + values, + value_offsets: offsets, + }) + } + + /// Create a new [`GenericListArray`] from the provided parts + /// + /// # Panics + /// + /// Panics if [`Self::try_new`] returns an error + pub fn new( + field: FieldRef, + offsets: OffsetBuffer, + values: ArrayRef, + nulls: Option, + ) -> Self { + Self::try_new(field, offsets, values, nulls).unwrap() + } + + /// Create a new [`GenericListArray`] of length `len` where all values are null + pub fn new_null(field: FieldRef, len: usize) -> Self { + let values = new_empty_array(field.data_type()); + Self { + data_type: Self::DATA_TYPE_CONSTRUCTOR(field), + nulls: Some(NullBuffer::new_null(len)), + value_offsets: OffsetBuffer::new_zeroed(len), + values, + } + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts( + self, + ) -> ( + FieldRef, + OffsetBuffer, + ArrayRef, + Option, + ) { + let f = match self.data_type { + DataType::List(f) | DataType::LargeList(f) => f, + _ => unreachable!(), + }; + (f, self.value_offsets, self.values, self.nulls) + } + + /// Returns a reference to the offsets of this list + /// + /// Unlike [`Self::value_offsets`] this returns the [`OffsetBuffer`] + /// allowing for zero-copy cloning + #[inline] + pub fn offsets(&self) -> &OffsetBuffer { + &self.value_offsets + } + + /// Returns a reference to the values of this list + #[inline] + pub fn values(&self) -> &ArrayRef { + &self.values } /// Returns a clone of the value type of this list. pub fn value_type(&self) -> DataType { - self.values.data_ref().data_type().clone() + self.values.data_type().clone() } /// Returns ith value of this list array. /// # Safety /// Caller must ensure that the index is within the array bounds pub unsafe fn value_unchecked(&self, i: usize) -> ArrayRef { - let end = *self.value_offsets().get_unchecked(i + 1); - let start = *self.value_offsets().get_unchecked(i); - self.values - .slice(start.to_usize().unwrap(), (end - start).to_usize().unwrap()) + let end = self.value_offsets().get_unchecked(i + 1).as_usize(); + let start = self.value_offsets().get_unchecked(i).as_usize(); + self.values.slice(start, end - start) } /// Returns ith value of this list array. pub fn value(&self, i: usize) -> ArrayRef { - let end = self.value_offsets()[i + 1]; - let start = self.value_offsets()[i]; - self.values - .slice(start.to_usize().unwrap(), (end - start).to_usize().unwrap()) + let end = self.value_offsets()[i + 1].as_usize(); + let start = self.value_offsets()[i].as_usize(); + self.values.slice(start, end - start) } /// Returns the offset values in the offsets buffer #[inline] pub fn value_offsets(&self) -> &[OffsetSize] { - // Soundness - // pointer alignment & location is ensured by RawPtrBox - // buffer bounds/offset is ensured by the ArrayData instance. - unsafe { - std::slice::from_raw_parts( - self.value_offsets.as_ptr().add(self.data.offset()), - self.len() + 1, - ) - } + &self.value_offsets } /// Returns the length for value at index `i`. @@ -134,11 +319,22 @@ impl GenericListArray { } } + /// Returns a zero-copy slice of this array with the indicated offset and length. + pub fn slice(&self, offset: usize, length: usize) -> Self { + Self { + data_type: self.data_type.clone(), + nulls: self.nulls.as_ref().map(|n| n.slice(offset, length)), + values: self.values.clone(), + value_offsets: self.value_offsets.slice(offset, length), + } + } + /// Creates a [`GenericListArray`] from an iterator of primitive values /// # Example /// ``` - /// # use arrow::array::ListArray; - /// # use arrow::datatypes::Int32Type; + /// # use arrow_array::ListArray; + /// # use arrow_array::types::Int32Type; + /// /// let data = vec![ /// Some(vec![Some(0), Some(1), Some(2)]), /// None, @@ -151,72 +347,74 @@ impl GenericListArray { pub fn from_iter_primitive(iter: I) -> Self where T: ArrowPrimitiveType, - P: AsRef<[Option<::Native>]> - + IntoIterator::Native>>, + P: IntoIterator::Native>>, I: IntoIterator>, { - let iterator = iter.into_iter(); - let (lower, _) = iterator.size_hint(); - - let mut offsets = - MutableBuffer::new((lower + 1) * std::mem::size_of::()); - let mut length_so_far = OffsetSize::zero(); - offsets.push(length_so_far); - - let mut null_buf = BooleanBufferBuilder::new(lower); - - let values: PrimitiveArray = iterator - .filter_map(|maybe_slice| { - // regardless of whether the item is Some, the offsets and null buffers must be updated. - match &maybe_slice { - Some(x) => { - length_so_far += - OffsetSize::from_usize(x.as_ref().len()).unwrap(); - null_buf.append(true); + let iter = iter.into_iter(); + let size_hint = iter.size_hint().0; + let mut builder = + GenericListBuilder::with_capacity(PrimitiveBuilder::::new(), size_hint); + + for i in iter { + match i { + Some(p) => { + for t in p { + builder.values().append_option(t); } - None => null_buf.append(false), - }; - offsets.push(length_so_far); - maybe_slice - }) - .flatten() - .collect(); - - let field = Box::new(Field::new("item", T::DATA_TYPE, true)); - let data_type = Self::DATA_TYPE_CONSTRUCTOR(field); - let array_data = ArrayData::builder(data_type) - .len(null_buf.len()) - .add_buffer(offsets.into()) - .add_child_data(values.into_data()) - .null_bit_buffer(Some(null_buf.into())); - let array_data = unsafe { array_data.build_unchecked() }; - - Self::from(array_data) + builder.append(true); + } + None => builder.append(false), + } + } + builder.finish() } } impl From for GenericListArray { fn from(data: ArrayData) -> Self { - Self::try_new_from_array_data(data).expect( - "Expected infallable creation of GenericListArray from ArrayDataRef failed", - ) + Self::try_new_from_array_data(data) + .expect("Expected infallible creation of GenericListArray from ArrayDataRef failed") } } -impl From> - for ArrayData -{ +impl From> for ArrayData { fn from(array: GenericListArray) -> Self { - array.data + let len = array.len(); + let builder = ArrayDataBuilder::new(array.data_type) + .len(len) + .nulls(array.nulls) + .buffers(vec![array.value_offsets.into_inner().into_inner()]) + .child_data(vec![array.values.to_data()]); + + unsafe { builder.build_unchecked() } + } +} + +impl From for GenericListArray { + fn from(value: FixedSizeListArray) -> Self { + let (field, size) = match value.data_type() { + DataType::FixedSizeList(f, size) => (f, *size as usize), + _ => unreachable!(), + }; + + let offsets = OffsetBuffer::from_lengths(std::iter::repeat(size).take(value.len())); + + Self { + data_type: Self::DATA_TYPE_CONSTRUCTOR(field.clone()), + nulls: value.nulls().cloned(), + values: value.values().clone(), + value_offsets: offsets, + } } } impl GenericListArray { fn try_new_from_array_data(data: ArrayData) -> Result { if data.buffers().len() != 1 { - return Err(ArrowError::InvalidArgumentError( - format!("ListArray data should contain a single buffer only (value offsets), had {}", - data.len()))); + return Err(ArrowError::InvalidArgumentError(format!( + "ListArray data should contain a single buffer only (value offsets), had {}", + data.buffers().len() + ))); } if data.child_data().len() != 1 { @@ -245,10 +443,13 @@ impl GenericListArray { } let values = make_array(values); - let value_offsets = data.buffers()[0].as_ptr(); - let value_offsets = unsafe { RawPtrBox::::new(value_offsets) }; + // SAFETY: + // ArrayData is valid, and verified type above + let value_offsets = unsafe { get_offsets(&data) }; + Ok(Self { - data, + data_type: data.data_type().clone(), + nulls: data.nulls().cloned(), values, value_offsets, }) @@ -260,13 +461,55 @@ impl Array for GenericListArray { self } - fn data(&self) -> &ArrayData { - &self.data + fn to_data(&self) -> ArrayData { + self.clone().into() } fn into_data(self) -> ArrayData { self.into() } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.value_offsets.len() - 1 + } + + fn is_empty(&self) -> bool { + self.value_offsets.len() <= 1 + } + + fn offset(&self) -> usize { + 0 + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.nulls.as_ref() + } + + fn get_buffer_memory_size(&self) -> usize { + let mut size = self.values.get_buffer_memory_size(); + size += self.value_offsets.inner().inner().capacity(); + if let Some(n) = self.nulls.as_ref() { + size += n.buffer().capacity(); + } + size + } + + fn get_array_memory_size(&self) -> usize { + let mut size = std::mem::size_of::() + self.values.get_array_memory_size(); + size += self.value_offsets.inner().inner().capacity(); + if let Some(n) = self.nulls.as_ref() { + size += n.buffer().capacity(); + } + size + } } impl<'a, OffsetSize: OffsetSizeTrait> ArrayAccessor for &'a GenericListArray { @@ -281,109 +524,44 @@ impl<'a, OffsetSize: OffsetSizeTrait> ArrayAccessor for &'a GenericListArray fmt::Debug for GenericListArray { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +impl std::fmt::Debug for GenericListArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { let prefix = OffsetSize::PREFIX; - write!(f, "{}ListArray\n[\n", prefix)?; + write!(f, "{prefix}ListArray\n[\n")?; print_long_array(self, f, |array, index, f| { - fmt::Debug::fmt(&array.value(index), f) + std::fmt::Debug::fmt(&array.value(index), f) })?; write!(f, "]") } } -/// A list array where each element is a variable-sized sequence of values with the same -/// type whose memory offsets between elements are represented by a i32. -/// -/// # Example -/// -/// ``` -/// # use arrow::array::{Array, ListArray, Int32Array}; -/// # use arrow::datatypes::{DataType, Int32Type}; -/// let data = vec![ -/// Some(vec![]), -/// None, -/// Some(vec![Some(3), None, Some(5), Some(19)]), -/// Some(vec![Some(6), Some(7)]), -/// ]; -/// let list_array = ListArray::from_iter_primitive::(data); +/// A [`GenericListArray`] of variable size lists, storing offsets as `i32`. /// -/// assert_eq!(false, list_array.is_valid(1)); -/// -/// let list0 = list_array.value(0); -/// let list2 = list_array.value(2); -/// let list3 = list_array.value(3); -/// -/// assert_eq!(&[] as &[i32], list0.as_any().downcast_ref::().unwrap().values()); -/// assert_eq!(false, list2.as_any().downcast_ref::().unwrap().is_valid(1)); -/// assert_eq!(&[6, 7], list3.as_any().downcast_ref::().unwrap().values()); -/// ``` +// See [`ListBuilder`](crate::builder::ListBuilder) for how to construct a [`ListArray`] pub type ListArray = GenericListArray; -/// A list array where each element is a variable-sized sequence of values with the same -/// type whose memory offsets between elements are represented by a i64. -/// # Example -/// -/// ``` -/// # use arrow::array::{Array, LargeListArray, Int32Array}; -/// # use arrow::datatypes::{DataType, Int32Type}; -/// let data = vec![ -/// Some(vec![]), -/// None, -/// Some(vec![Some(3), None, Some(5), Some(19)]), -/// Some(vec![Some(6), Some(7)]), -/// ]; -/// let list_array = LargeListArray::from_iter_primitive::(data); -/// -/// assert_eq!(false, list_array.is_valid(1)); +/// A [`GenericListArray`] of variable size lists, storing offsets as `i64`. /// -/// let list0 = list_array.value(0); -/// let list2 = list_array.value(2); -/// let list3 = list_array.value(3); -/// -/// assert_eq!(&[] as &[i32], list0.as_any().downcast_ref::().unwrap().values()); -/// assert_eq!(false, list2.as_any().downcast_ref::().unwrap().is_valid(1)); -/// assert_eq!(&[6, 7], list3.as_any().downcast_ref::().unwrap().values()); -/// ``` +// See [`LargeListBuilder`](crate::builder::LargeListBuilder) for how to construct a [`LargeListArray`] pub type LargeListArray = GenericListArray; #[cfg(test)] mod tests { - use crate::{ - alloc, - array::ArrayData, - array::Int32Array, - buffer::Buffer, - datatypes::Field, - datatypes::{Int32Type, ToByteSlice}, - util::bit_util, - }; - use super::*; + use crate::builder::{FixedSizeListBuilder, Int32Builder, ListBuilder}; + use crate::cast::AsArray; + use crate::types::Int32Type; + use crate::{Int32Array, Int64Array}; + use arrow_buffer::{bit_util, Buffer, ScalarBuffer}; + use arrow_schema::Field; fn create_from_buffers() -> ListArray { - // Construct a value array - let value_data = ArrayData::builder(DataType::Int32) - .len(8) - .add_buffer(Buffer::from(&[0, 1, 2, 3, 4, 5, 6, 7].to_byte_slice())) - .build() - .unwrap(); - - // Construct a buffer for value offsets, for the nested array: // [[0, 1, 2], [3, 4, 5], [6, 7]] - let value_offsets = Buffer::from(&[0, 3, 6, 8].to_byte_slice()); - - // Construct a list array from the above two - let list_data_type = - DataType::List(Box::new(Field::new("item", DataType::Int32, true))); - let list_data = ArrayData::builder(list_data_type) - .len(3) - .add_buffer(value_offsets) - .add_child_data(value_data) - .build() - .unwrap(); - ListArray::from(list_data) + let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7]); + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 3, 6, 8])); + let field = Arc::new(Field::new("item", DataType::Int32, true)); + ListArray::new(field, offsets, Arc::new(values), None) } #[test] @@ -412,8 +590,7 @@ mod tests { let value_offsets = Buffer::from([]); // Construct a list array from the above two - let list_data_type = - DataType::List(Box::new(Field::new("item", DataType::Int32, false))); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); let list_data = ArrayData::builder(list_data_type) .len(0) .add_buffer(value_offsets) @@ -430,17 +607,16 @@ mod tests { // Construct a value array let value_data = ArrayData::builder(DataType::Int32) .len(8) - .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7])) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) .build() .unwrap(); // Construct a buffer for value offsets, for the nested array: // [[0, 1, 2], [3, 4, 5], [6, 7]] - let value_offsets = Buffer::from_slice_ref(&[0, 3, 6, 8]); + let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); // Construct a list array from the above two - let list_data_type = - DataType::List(Box::new(Field::new("item", DataType::Int32, false))); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); let list_data = ArrayData::builder(list_data_type.clone()) .len(3) .add_buffer(value_offsets.clone()) @@ -450,7 +626,7 @@ mod tests { let list_array = ListArray::from(list_data); let values = list_array.values(); - assert_eq!(&value_data, values.data()); + assert_eq!(value_data, values.to_data()); assert_eq!(DataType::Int32, list_array.value_type()); assert_eq!(3, list_array.len()); assert_eq!(0, list_array.null_count()); @@ -490,7 +666,7 @@ mod tests { let list_array = ListArray::from(list_data); let values = list_array.values(); - assert_eq!(&value_data, values.data()); + assert_eq!(value_data, values.to_data()); assert_eq!(DataType::Int32, list_array.value_type()); assert_eq!(2, list_array.len()); assert_eq!(0, list_array.null_count()); @@ -520,17 +696,17 @@ mod tests { // Construct a value array let value_data = ArrayData::builder(DataType::Int32) .len(8) - .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7])) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) .build() .unwrap(); // Construct a buffer for value offsets, for the nested array: // [[0, 1, 2], [3, 4, 5], [6, 7]] - let value_offsets = Buffer::from_slice_ref(&[0i64, 3, 6, 8]); + let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8]); // Construct a list array from the above two let list_data_type = - DataType::LargeList(Box::new(Field::new("item", DataType::Int32, false))); + DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, false))); let list_data = ArrayData::builder(list_data_type.clone()) .len(3) .add_buffer(value_offsets.clone()) @@ -540,7 +716,7 @@ mod tests { let list_array = LargeListArray::from(list_data); let values = list_array.values(); - assert_eq!(&value_data, values.data()); + assert_eq!(value_data, values.to_data()); assert_eq!(DataType::Int32, list_array.value_type()); assert_eq!(3, list_array.len()); assert_eq!(0, list_array.null_count()); @@ -580,7 +756,7 @@ mod tests { let list_array = LargeListArray::from(list_data); let values = list_array.values(); - assert_eq!(&value_data, values.data()); + assert_eq!(value_data, values.to_data()); assert_eq!(DataType::Int32, list_array.value_type()); assert_eq!(2, list_array.len()); assert_eq!(0, list_array.null_count()); @@ -610,13 +786,13 @@ mod tests { // Construct a value array let value_data = ArrayData::builder(DataType::Int32) .len(10) - .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) .build() .unwrap(); // Construct a buffer for value offsets, for the nested array: // [[0, 1], null, null, [2, 3], [4, 5], null, [6, 7, 8], null, [9]] - let value_offsets = Buffer::from_slice_ref(&[0, 2, 2, 2, 4, 6, 6, 9, 9, 10]); + let value_offsets = Buffer::from_slice_ref([0, 2, 2, 2, 4, 6, 6, 9, 9, 10]); // 01011001 00000001 let mut null_bits: [u8; 2] = [0; 2]; bit_util::set_bit(&mut null_bits, 0); @@ -626,8 +802,7 @@ mod tests { bit_util::set_bit(&mut null_bits, 8); // Construct a list array from the above two - let list_data_type = - DataType::List(Box::new(Field::new("item", DataType::Int32, false))); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); let list_data = ArrayData::builder(list_data_type) .len(9) .add_buffer(value_offsets) @@ -638,7 +813,7 @@ mod tests { let list_array = ListArray::from(list_data); let values = list_array.values(); - assert_eq!(&value_data, values.data()); + assert_eq!(value_data, values.to_data()); assert_eq!(DataType::Int32, list_array.value_type()); assert_eq!(9, list_array.len()); assert_eq!(4, list_array.null_count()); @@ -647,11 +822,10 @@ mod tests { let sliced_array = list_array.slice(1, 6); assert_eq!(6, sliced_array.len()); - assert_eq!(1, sliced_array.offset()); assert_eq!(3, sliced_array.null_count()); for i in 0..sliced_array.len() { - if bit_util::get_bit(&null_bits, sliced_array.offset() + i) { + if bit_util::get_bit(&null_bits, 1 + i) { assert!(sliced_array.is_valid(i)); } else { assert!(sliced_array.is_null(i)); @@ -659,8 +833,7 @@ mod tests { } // Check offset and length for each non-null value. - let sliced_list_array = - sliced_array.as_any().downcast_ref::().unwrap(); + let sliced_list_array = sliced_array.as_any().downcast_ref::().unwrap(); assert_eq!(2, sliced_list_array.value_offsets()[2]); assert_eq!(2, sliced_list_array.value_length(2)); assert_eq!(4, sliced_list_array.value_offsets()[3]); @@ -674,13 +847,13 @@ mod tests { // Construct a value array let value_data = ArrayData::builder(DataType::Int32) .len(10) - .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) .build() .unwrap(); // Construct a buffer for value offsets, for the nested array: // [[0, 1], null, null, [2, 3], [4, 5], null, [6, 7, 8], null, [9]] - let value_offsets = Buffer::from_slice_ref(&[0i64, 2, 2, 2, 4, 6, 6, 9, 9, 10]); + let value_offsets = Buffer::from_slice_ref([0i64, 2, 2, 2, 4, 6, 6, 9, 9, 10]); // 01011001 00000001 let mut null_bits: [u8; 2] = [0; 2]; bit_util::set_bit(&mut null_bits, 0); @@ -691,7 +864,7 @@ mod tests { // Construct a list array from the above two let list_data_type = - DataType::LargeList(Box::new(Field::new("item", DataType::Int32, false))); + DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, false))); let list_data = ArrayData::builder(list_data_type) .len(9) .add_buffer(value_offsets) @@ -702,7 +875,7 @@ mod tests { let list_array = LargeListArray::from(list_data); let values = list_array.values(); - assert_eq!(&value_data, values.data()); + assert_eq!(value_data, values.to_data()); assert_eq!(DataType::Int32, list_array.value_type()); assert_eq!(9, list_array.len()); assert_eq!(4, list_array.null_count()); @@ -711,11 +884,10 @@ mod tests { let sliced_array = list_array.slice(1, 6); assert_eq!(6, sliced_array.len()); - assert_eq!(1, sliced_array.offset()); assert_eq!(3, sliced_array.null_count()); for i in 0..sliced_array.len() { - if bit_util::get_bit(&null_bits, sliced_array.offset() + i) { + if bit_util::get_bit(&null_bits, 1 + i) { assert!(sliced_array.is_valid(i)); } else { assert!(sliced_array.is_null(i)); @@ -741,13 +913,13 @@ mod tests { // Construct a value array let value_data = ArrayData::builder(DataType::Int32) .len(10) - .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) .build() .unwrap(); // Construct a buffer for value offsets, for the nested array: // [[0, 1], null, null, [2, 3], [4, 5], null, [6, 7, 8], null, [9]] - let value_offsets = Buffer::from_slice_ref(&[0i64, 2, 2, 2, 4, 6, 6, 9, 9, 10]); + let value_offsets = Buffer::from_slice_ref([0i64, 2, 2, 2, 4, 6, 6, 9, 9, 10]); // 01011001 00000001 let mut null_bits: [u8; 2] = [0; 2]; bit_util::set_bit(&mut null_bits, 0); @@ -758,7 +930,7 @@ mod tests { // Construct a list array from the above two let list_data_type = - DataType::LargeList(Box::new(Field::new("item", DataType::Int32, false))); + DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, false))); let list_data = ArrayData::builder(list_data_type) .len(9) .add_buffer(value_offsets) @@ -772,9 +944,7 @@ mod tests { list_array.value(10); } #[test] - #[should_panic( - expected = "ListArray data should contain a single buffer only (value offsets)" - )] + #[should_panic(expected = "ListArray data should contain a single buffer only (value offsets)")] // Different error messages, so skip for now // https://github.com/apache/arrow-rs/issues/1545 #[cfg(not(feature = "force_validate"))] @@ -782,11 +952,10 @@ mod tests { let value_data = unsafe { ArrayData::builder(DataType::Int32) .len(8) - .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7])) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) .build_unchecked() }; - let list_data_type = - DataType::List(Box::new(Field::new("item", DataType::Int32, false))); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); let list_data = unsafe { ArrayData::builder(list_data_type) .len(3) @@ -797,16 +966,13 @@ mod tests { } #[test] - #[should_panic( - expected = "ListArray should contain a single child array (values array)" - )] + #[should_panic(expected = "ListArray should contain a single child array (values array)")] // Different error messages, so skip for now // https://github.com/apache/arrow-rs/issues/1545 #[cfg(not(feature = "force_validate"))] fn test_list_array_invalid_child_array_len() { - let value_offsets = Buffer::from_slice_ref(&[0, 2, 5, 7]); - let list_data_type = - DataType::List(Box::new(Field::new("item", DataType::Int32, false))); + let value_offsets = Buffer::from_slice_ref([0, 2, 5, 7]); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); let list_data = unsafe { ArrayData::builder(list_data_type) .len(3) @@ -816,18 +982,27 @@ mod tests { drop(ListArray::from(list_data)); } + #[test] + #[should_panic(expected = "[Large]ListArray's datatype must be [Large]ListArray(). It is List")] + fn test_from_array_data_validation() { + let mut builder = ListBuilder::new(Int32Builder::new()); + builder.values().append_value(1); + builder.append(true); + let array = builder.finish(); + let _ = LargeListArray::from(array.into_data()); + } + #[test] fn test_list_array_offsets_need_not_start_at_zero() { let value_data = ArrayData::builder(DataType::Int32) .len(8) - .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7])) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) .build() .unwrap(); - let value_offsets = Buffer::from_slice_ref(&[2, 2, 5, 7]); + let value_offsets = Buffer::from_slice_ref([2, 2, 5, 7]); - let list_data_type = - DataType::List(Box::new(Field::new("item", DataType::Int32, false))); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); let list_data = ArrayData::builder(list_data_type) .len(3) .add_buffer(value_offsets) @@ -842,37 +1017,38 @@ mod tests { } #[test] - #[should_panic(expected = "memory is not aligned")] + #[should_panic(expected = "Memory pointer is not aligned with the specified scalar type")] + // Different error messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] fn test_primitive_array_alignment() { - let ptr = alloc::allocate_aligned::(8); - let buf = unsafe { Buffer::from_raw_parts(ptr, 8, 8) }; + let buf = Buffer::from_slice_ref([0_u64]); let buf2 = buf.slice(1); - let array_data = ArrayData::builder(DataType::Int32) - .add_buffer(buf2) - .build() - .unwrap(); + let array_data = unsafe { + ArrayData::builder(DataType::Int32) + .add_buffer(buf2) + .build_unchecked() + }; drop(Int32Array::from(array_data)); } #[test] - #[should_panic(expected = "memory is not aligned")] + #[should_panic(expected = "Memory pointer is not aligned with the specified scalar type")] // Different error messages, so skip for now // https://github.com/apache/arrow-rs/issues/1545 #[cfg(not(feature = "force_validate"))] fn test_list_array_alignment() { - let ptr = alloc::allocate_aligned::(8); - let buf = unsafe { Buffer::from_raw_parts(ptr, 8, 8) }; + let buf = Buffer::from_slice_ref([0_u64]); let buf2 = buf.slice(1); let values: [i32; 8] = [0; 8]; let value_data = unsafe { ArrayData::builder(DataType::Int32) - .add_buffer(Buffer::from_slice_ref(&values)) + .add_buffer(Buffer::from_slice_ref(values)) .build_unchecked() }; - let list_data_type = - DataType::List(Box::new(Field::new("item", DataType::Int32, false))); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); let list_data = unsafe { ArrayData::builder(list_data_type) .add_buffer(buf2) @@ -953,4 +1129,99 @@ mod tests { false, ); } + + #[test] + fn test_empty_offsets() { + let f = Arc::new(Field::new("element", DataType::Int32, true)); + let string = ListArray::from( + ArrayData::builder(DataType::List(f.clone())) + .buffers(vec![Buffer::from(&[])]) + .add_child_data(ArrayData::new_empty(&DataType::Int32)) + .build() + .unwrap(), + ); + assert_eq!(string.value_offsets(), &[0]); + let string = LargeListArray::from( + ArrayData::builder(DataType::LargeList(f)) + .buffers(vec![Buffer::from(&[])]) + .add_child_data(ArrayData::new_empty(&DataType::Int32)) + .build() + .unwrap(), + ); + assert_eq!(string.len(), 0); + assert_eq!(string.value_offsets(), &[0]); + } + + #[test] + fn test_try_new() { + let offsets = OffsetBuffer::new(vec![0, 1, 4, 5].into()); + let values = Int32Array::new(vec![1, 2, 3, 4, 5].into(), None); + let values = Arc::new(values) as ArrayRef; + + let field = Arc::new(Field::new("element", DataType::Int32, false)); + ListArray::new(field.clone(), offsets.clone(), values.clone(), None); + + let nulls = NullBuffer::new_null(3); + ListArray::new(field.clone(), offsets, values.clone(), Some(nulls)); + + let nulls = NullBuffer::new_null(3); + let offsets = OffsetBuffer::new(vec![0, 1, 2, 4, 5].into()); + let err = LargeListArray::try_new(field, offsets.clone(), values.clone(), Some(nulls)) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: Incorrect length of null buffer for LargeListArray, expected 4 got 3" + ); + + let field = Arc::new(Field::new("element", DataType::Int64, false)); + let err = LargeListArray::try_new(field.clone(), offsets.clone(), values.clone(), None) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: LargeListArray expected data type Int64 got Int32 for \"element\"" + ); + + let nulls = NullBuffer::new_null(7); + let values = Int64Array::new(vec![0; 7].into(), Some(nulls)); + let values = Arc::new(values); + + let err = + LargeListArray::try_new(field, offsets.clone(), values.clone(), None).unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: Non-nullable field of LargeListArray \"element\" cannot contain nulls" + ); + + let field = Arc::new(Field::new("element", DataType::Int64, true)); + LargeListArray::new(field.clone(), offsets.clone(), values, None); + + let values = Int64Array::new(vec![0; 2].into(), None); + let err = LargeListArray::try_new(field, offsets, Arc::new(values), None).unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: Max offset of 5 exceeds length of values 2" + ); + } + + #[test] + fn test_from_fixed_size_list() { + let mut builder = FixedSizeListBuilder::new(Int32Builder::new(), 3); + builder.values().append_slice(&[1, 2, 3]); + builder.append(true); + builder.values().append_slice(&[0, 0, 0]); + builder.append(false); + builder.values().append_slice(&[4, 5, 6]); + builder.append(true); + let list: ListArray = builder.finish().into(); + + let values: Vec<_> = list + .iter() + .map(|x| x.map(|x| x.as_primitive::().values().to_vec())) + .collect(); + assert_eq!(values, vec![Some(vec![1, 2, 3]), None, Some(vec![4, 5, 6])]) + } } diff --git a/arrow-array/src/array/map_array.rs b/arrow-array/src/array/map_array.rs new file mode 100644 index 000000000000..bde7fdd5a953 --- /dev/null +++ b/arrow-array/src/array/map_array.rs @@ -0,0 +1,802 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::array::{get_offsets, print_long_array}; +use crate::iterator::MapArrayIter; +use crate::{make_array, Array, ArrayAccessor, ArrayRef, ListArray, StringArray, StructArray}; +use arrow_buffer::{ArrowNativeType, Buffer, NullBuffer, OffsetBuffer, ToByteSlice}; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::{ArrowError, DataType, Field, FieldRef}; +use std::any::Any; +use std::sync::Arc; + +/// An array of key-value maps +/// +/// Keys should always be non-null, but values can be null. +/// +/// [`MapArray`] is physically a [`ListArray`] of key values pairs stored as an `entries` +/// [`StructArray`] with 2 child fields. +/// +/// See [`MapBuilder`](crate::builder::MapBuilder) for how to construct a [`MapArray`] +#[derive(Clone)] +pub struct MapArray { + data_type: DataType, + nulls: Option, + /// The [`StructArray`] that is the direct child of this array + entries: StructArray, + /// The start and end offsets of each entry + value_offsets: OffsetBuffer, +} + +impl MapArray { + /// Create a new [`MapArray`] from the provided parts + /// + /// See [`MapBuilder`](crate::builder::MapBuilder) for a higher-level interface + /// to construct a [`MapArray`] + /// + /// # Errors + /// + /// Errors if + /// + /// * `offsets.len() - 1 != nulls.len()` + /// * `offsets.last() > entries.len()` + /// * `field.is_nullable()` + /// * `entries.null_count() != 0` + /// * `entries.columns().len() != 2` + /// * `field.data_type() != entries.data_type()` + pub fn try_new( + field: FieldRef, + offsets: OffsetBuffer, + entries: StructArray, + nulls: Option, + ordered: bool, + ) -> Result { + let len = offsets.len() - 1; // Offsets guaranteed to not be empty + let end_offset = offsets.last().unwrap().as_usize(); + // don't need to check other values of `offsets` because they are checked + // during construction of `OffsetBuffer` + if end_offset > entries.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Max offset of {end_offset} exceeds length of entries {}", + entries.len() + ))); + } + + if let Some(n) = nulls.as_ref() { + if n.len() != len { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect length of null buffer for MapArray, expected {len} got {}", + n.len(), + ))); + } + } + if field.is_nullable() || entries.null_count() != 0 { + return Err(ArrowError::InvalidArgumentError( + "MapArray entries cannot contain nulls".to_string(), + )); + } + + if field.data_type() != entries.data_type() { + return Err(ArrowError::InvalidArgumentError(format!( + "MapArray expected data type {} got {} for {:?}", + field.data_type(), + entries.data_type(), + field.name() + ))); + } + + if entries.columns().len() != 2 { + return Err(ArrowError::InvalidArgumentError(format!( + "MapArray entries must contain two children, got {}", + entries.columns().len() + ))); + } + + Ok(Self { + data_type: DataType::Map(field, ordered), + nulls, + entries, + value_offsets: offsets, + }) + } + + /// Create a new [`MapArray`] from the provided parts + /// + /// See [`MapBuilder`](crate::builder::MapBuilder) for a higher-level interface + /// to construct a [`MapArray`] + /// + /// # Panics + /// + /// Panics if [`Self::try_new`] returns an error + pub fn new( + field: FieldRef, + offsets: OffsetBuffer, + entries: StructArray, + nulls: Option, + ordered: bool, + ) -> Self { + Self::try_new(field, offsets, entries, nulls, ordered).unwrap() + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts( + self, + ) -> ( + FieldRef, + OffsetBuffer, + StructArray, + Option, + bool, + ) { + let (f, ordered) = match self.data_type { + DataType::Map(f, ordered) => (f, ordered), + _ => unreachable!(), + }; + (f, self.value_offsets, self.entries, self.nulls, ordered) + } + + /// Returns a reference to the offsets of this map + /// + /// Unlike [`Self::value_offsets`] this returns the [`OffsetBuffer`] + /// allowing for zero-copy cloning + #[inline] + pub fn offsets(&self) -> &OffsetBuffer { + &self.value_offsets + } + + /// Returns a reference to the keys of this map + pub fn keys(&self) -> &ArrayRef { + self.entries.column(0) + } + + /// Returns a reference to the values of this map + pub fn values(&self) -> &ArrayRef { + self.entries.column(1) + } + + /// Returns a reference to the [`StructArray`] entries of this map + pub fn entries(&self) -> &StructArray { + &self.entries + } + + /// Returns the data type of the map's keys. + pub fn key_type(&self) -> &DataType { + self.keys().data_type() + } + + /// Returns the data type of the map's values. + pub fn value_type(&self) -> &DataType { + self.values().data_type() + } + + /// Returns ith value of this map array. + /// + /// # Safety + /// Caller must ensure that the index is within the array bounds + pub unsafe fn value_unchecked(&self, i: usize) -> StructArray { + let end = *self.value_offsets().get_unchecked(i + 1); + let start = *self.value_offsets().get_unchecked(i); + self.entries + .slice(start.to_usize().unwrap(), (end - start).to_usize().unwrap()) + } + + /// Returns ith value of this map array. + /// + /// This is a [`StructArray`] containing two fields + pub fn value(&self, i: usize) -> StructArray { + let end = self.value_offsets()[i + 1] as usize; + let start = self.value_offsets()[i] as usize; + self.entries.slice(start, end - start) + } + + /// Returns the offset values in the offsets buffer + #[inline] + pub fn value_offsets(&self) -> &[i32] { + &self.value_offsets + } + + /// Returns the length for value at index `i`. + #[inline] + pub fn value_length(&self, i: usize) -> i32 { + let offsets = self.value_offsets(); + offsets[i + 1] - offsets[i] + } + + /// Returns a zero-copy slice of this array with the indicated offset and length. + pub fn slice(&self, offset: usize, length: usize) -> Self { + Self { + data_type: self.data_type.clone(), + nulls: self.nulls.as_ref().map(|n| n.slice(offset, length)), + entries: self.entries.clone(), + value_offsets: self.value_offsets.slice(offset, length), + } + } + + /// constructs a new iterator + pub fn iter(&self) -> MapArrayIter<'_> { + MapArrayIter::new(self) + } +} + +impl From for MapArray { + fn from(data: ArrayData) -> Self { + Self::try_new_from_array_data(data) + .expect("Expected infallible creation of MapArray from ArrayData failed") + } +} + +impl From for ArrayData { + fn from(array: MapArray) -> Self { + let len = array.len(); + let builder = ArrayDataBuilder::new(array.data_type) + .len(len) + .nulls(array.nulls) + .buffers(vec![array.value_offsets.into_inner().into_inner()]) + .child_data(vec![array.entries.to_data()]); + + unsafe { builder.build_unchecked() } + } +} + +impl MapArray { + fn try_new_from_array_data(data: ArrayData) -> Result { + if !matches!(data.data_type(), DataType::Map(_, _)) { + return Err(ArrowError::InvalidArgumentError(format!( + "MapArray expected ArrayData with DataType::Map got {}", + data.data_type() + ))); + } + + if data.buffers().len() != 1 { + return Err(ArrowError::InvalidArgumentError(format!( + "MapArray data should contain a single buffer only (value offsets), had {}", + data.len() + ))); + } + + if data.child_data().len() != 1 { + return Err(ArrowError::InvalidArgumentError(format!( + "MapArray should contain a single child array (values array), had {}", + data.child_data().len() + ))); + } + + let entries = data.child_data()[0].clone(); + + if let DataType::Struct(fields) = entries.data_type() { + if fields.len() != 2 { + return Err(ArrowError::InvalidArgumentError(format!( + "MapArray should contain a struct array with 2 fields, have {} fields", + fields.len() + ))); + } + } else { + return Err(ArrowError::InvalidArgumentError(format!( + "MapArray should contain a struct array child, found {:?}", + entries.data_type() + ))); + } + let entries = entries.into(); + + // SAFETY: + // ArrayData is valid, and verified type above + let value_offsets = unsafe { get_offsets(&data) }; + + Ok(Self { + data_type: data.data_type().clone(), + nulls: data.nulls().cloned(), + entries, + value_offsets, + }) + } + + /// Creates map array from provided keys, values and entry_offsets. + pub fn new_from_strings<'a>( + keys: impl Iterator, + values: &dyn Array, + entry_offsets: &[u32], + ) -> Result { + let entry_offsets_buffer = Buffer::from(entry_offsets.to_byte_slice()); + let keys_data = StringArray::from_iter_values(keys); + + let keys_field = Arc::new(Field::new("keys", DataType::Utf8, false)); + let values_field = Arc::new(Field::new( + "values", + values.data_type().clone(), + values.null_count() > 0, + )); + + let entry_struct = StructArray::from(vec![ + (keys_field, Arc::new(keys_data) as ArrayRef), + (values_field, make_array(values.to_data())), + ]); + + let map_data_type = DataType::Map( + Arc::new(Field::new( + "entries", + entry_struct.data_type().clone(), + false, + )), + false, + ); + let map_data = ArrayData::builder(map_data_type) + .len(entry_offsets.len() - 1) + .add_buffer(entry_offsets_buffer) + .add_child_data(entry_struct.into_data()) + .build()?; + + Ok(MapArray::from(map_data)) + } +} + +impl Array for MapArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + self.clone().into_data() + } + + fn into_data(self) -> ArrayData { + self.into() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.value_offsets.len() - 1 + } + + fn is_empty(&self) -> bool { + self.value_offsets.len() <= 1 + } + + fn offset(&self) -> usize { + 0 + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.nulls.as_ref() + } + + fn get_buffer_memory_size(&self) -> usize { + let mut size = self.entries.get_buffer_memory_size(); + size += self.value_offsets.inner().inner().capacity(); + if let Some(n) = self.nulls.as_ref() { + size += n.buffer().capacity(); + } + size + } + + fn get_array_memory_size(&self) -> usize { + let mut size = std::mem::size_of::() + self.entries.get_array_memory_size(); + size += self.value_offsets.inner().inner().capacity(); + if let Some(n) = self.nulls.as_ref() { + size += n.buffer().capacity(); + } + size + } +} + +impl<'a> ArrayAccessor for &'a MapArray { + type Item = StructArray; + + fn value(&self, index: usize) -> Self::Item { + MapArray::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + MapArray::value(self, index) + } +} + +impl std::fmt::Debug for MapArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "MapArray\n[\n")?; + print_long_array(self, f, |array, index, f| { + std::fmt::Debug::fmt(&array.value(index), f) + })?; + write!(f, "]") + } +} + +impl From for ListArray { + fn from(value: MapArray) -> Self { + let field = match value.data_type() { + DataType::Map(field, _) => field, + _ => unreachable!("This should be a map type."), + }; + let data_type = DataType::List(field.clone()); + let builder = value.into_data().into_builder().data_type(data_type); + let array_data = unsafe { builder.build_unchecked() }; + + ListArray::from(array_data) + } +} + +#[cfg(test)] +mod tests { + use crate::cast::AsArray; + use crate::types::UInt32Type; + use crate::{Int32Array, UInt32Array}; + use arrow_schema::Fields; + use std::sync::Arc; + + use super::*; + + fn create_from_buffers() -> MapArray { + // Construct key and values + let keys_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from(&[0, 1, 2, 3, 4, 5, 6, 7].to_byte_slice())) + .build() + .unwrap(); + let values_data = ArrayData::builder(DataType::UInt32) + .len(8) + .add_buffer(Buffer::from( + &[0u32, 10, 20, 30, 40, 50, 60, 70].to_byte_slice(), + )) + .build() + .unwrap(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let entry_offsets = Buffer::from(&[0, 3, 6, 8].to_byte_slice()); + + let keys = Arc::new(Field::new("keys", DataType::Int32, false)); + let values = Arc::new(Field::new("values", DataType::UInt32, false)); + let entry_struct = StructArray::from(vec![ + (keys, make_array(keys_data)), + (values, make_array(values_data)), + ]); + + // Construct a map array from the above two + let map_data_type = DataType::Map( + Arc::new(Field::new( + "entries", + entry_struct.data_type().clone(), + false, + )), + false, + ); + let map_data = ArrayData::builder(map_data_type) + .len(3) + .add_buffer(entry_offsets) + .add_child_data(entry_struct.into_data()) + .build() + .unwrap(); + MapArray::from(map_data) + } + + #[test] + fn test_map_array() { + // Construct key and values + let key_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from(&[0, 1, 2, 3, 4, 5, 6, 7].to_byte_slice())) + .build() + .unwrap(); + let value_data = ArrayData::builder(DataType::UInt32) + .len(8) + .add_buffer(Buffer::from( + &[0u32, 10, 20, 0, 40, 0, 60, 70].to_byte_slice(), + )) + .null_bit_buffer(Some(Buffer::from(&[0b11010110]))) + .build() + .unwrap(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let entry_offsets = Buffer::from(&[0, 3, 6, 8].to_byte_slice()); + + let keys_field = Arc::new(Field::new("keys", DataType::Int32, false)); + let values_field = Arc::new(Field::new("values", DataType::UInt32, true)); + let entry_struct = StructArray::from(vec![ + (keys_field.clone(), make_array(key_data)), + (values_field.clone(), make_array(value_data.clone())), + ]); + + // Construct a map array from the above two + let map_data_type = DataType::Map( + Arc::new(Field::new( + "entries", + entry_struct.data_type().clone(), + false, + )), + false, + ); + let map_data = ArrayData::builder(map_data_type) + .len(3) + .add_buffer(entry_offsets) + .add_child_data(entry_struct.into_data()) + .build() + .unwrap(); + let map_array = MapArray::from(map_data); + + assert_eq!(value_data, map_array.values().to_data()); + assert_eq!(&DataType::UInt32, map_array.value_type()); + assert_eq!(3, map_array.len()); + assert_eq!(0, map_array.null_count()); + assert_eq!(6, map_array.value_offsets()[2]); + assert_eq!(2, map_array.value_length(2)); + + let key_array = Arc::new(Int32Array::from(vec![0, 1, 2])) as ArrayRef; + let value_array = + Arc::new(UInt32Array::from(vec![None, Some(10u32), Some(20)])) as ArrayRef; + let struct_array = StructArray::from(vec![ + (keys_field.clone(), key_array), + (values_field.clone(), value_array), + ]); + assert_eq!( + struct_array, + StructArray::from(map_array.value(0).into_data()) + ); + assert_eq!( + &struct_array, + unsafe { map_array.value_unchecked(0) } + .as_any() + .downcast_ref::() + .unwrap() + ); + for i in 0..3 { + assert!(map_array.is_valid(i)); + assert!(!map_array.is_null(i)); + } + + // Now test with a non-zero offset + let map_array = map_array.slice(1, 2); + + assert_eq!(value_data, map_array.values().to_data()); + assert_eq!(&DataType::UInt32, map_array.value_type()); + assert_eq!(2, map_array.len()); + assert_eq!(0, map_array.null_count()); + assert_eq!(6, map_array.value_offsets()[1]); + assert_eq!(2, map_array.value_length(1)); + + let key_array = Arc::new(Int32Array::from(vec![3, 4, 5])) as ArrayRef; + let value_array = Arc::new(UInt32Array::from(vec![None, Some(40), None])) as ArrayRef; + let struct_array = + StructArray::from(vec![(keys_field, key_array), (values_field, value_array)]); + assert_eq!( + &struct_array, + map_array + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + ); + assert_eq!( + &struct_array, + unsafe { map_array.value_unchecked(0) } + .as_any() + .downcast_ref::() + .unwrap() + ); + } + + #[test] + #[ignore = "Test fails because slice of > is still buggy"] + fn test_map_array_slice() { + let map_array = create_from_buffers(); + + let sliced_array = map_array.slice(1, 2); + assert_eq!(2, sliced_array.len()); + assert_eq!(1, sliced_array.offset()); + let sliced_array_data = sliced_array.to_data(); + for array_data in sliced_array_data.child_data() { + assert_eq!(array_data.offset(), 1); + } + + // Check offset and length for each non-null value. + let sliced_map_array = sliced_array.as_any().downcast_ref::().unwrap(); + assert_eq!(3, sliced_map_array.value_offsets()[0]); + assert_eq!(3, sliced_map_array.value_length(0)); + assert_eq!(6, sliced_map_array.value_offsets()[1]); + assert_eq!(2, sliced_map_array.value_length(1)); + + // Construct key and values + let keys_data = ArrayData::builder(DataType::Int32) + .len(5) + .add_buffer(Buffer::from(&[3, 4, 5, 6, 7].to_byte_slice())) + .build() + .unwrap(); + let values_data = ArrayData::builder(DataType::UInt32) + .len(5) + .add_buffer(Buffer::from(&[30u32, 40, 50, 60, 70].to_byte_slice())) + .build() + .unwrap(); + + // Construct a buffer for value offsets, for the nested array: + // [[3, 4, 5], [6, 7]] + let entry_offsets = Buffer::from(&[0, 3, 5].to_byte_slice()); + + let keys = Arc::new(Field::new("keys", DataType::Int32, false)); + let values = Arc::new(Field::new("values", DataType::UInt32, false)); + let entry_struct = StructArray::from(vec![ + (keys, make_array(keys_data)), + (values, make_array(values_data)), + ]); + + // Construct a map array from the above two + let map_data_type = DataType::Map( + Arc::new(Field::new( + "entries", + entry_struct.data_type().clone(), + false, + )), + false, + ); + let expected_map_data = ArrayData::builder(map_data_type) + .len(2) + .add_buffer(entry_offsets) + .add_child_data(entry_struct.into_data()) + .build() + .unwrap(); + let expected_map_array = MapArray::from(expected_map_data); + + assert_eq!(&expected_map_array, sliced_map_array) + } + + #[test] + #[should_panic(expected = "index out of bounds: the len is ")] + fn test_map_array_index_out_of_bound() { + let map_array = create_from_buffers(); + + map_array.value(map_array.len()); + } + + #[test] + #[should_panic(expected = "MapArray expected ArrayData with DataType::Map got Dictionary")] + fn test_from_array_data_validation() { + // A DictionaryArray has similar buffer layout to a MapArray + // but the meaning of the values differs + let struct_t = DataType::Struct(Fields::from(vec![ + Field::new("keys", DataType::Int32, true), + Field::new("values", DataType::UInt32, true), + ])); + let dict_t = DataType::Dictionary(Box::new(DataType::Int32), Box::new(struct_t)); + let _ = MapArray::from(ArrayData::new_empty(&dict_t)); + } + + #[test] + fn test_new_from_strings() { + let keys = vec!["a", "b", "c", "d", "e", "f", "g", "h"]; + let values_data = UInt32Array::from(vec![0u32, 10, 20, 30, 40, 50, 60, 70]); + + // Construct a buffer for value offsets, for the nested array: + // [[a, b, c], [d, e, f], [g, h]] + let entry_offsets = [0, 3, 6, 8]; + + let map_array = + MapArray::new_from_strings(keys.clone().into_iter(), &values_data, &entry_offsets) + .unwrap(); + + assert_eq!( + &values_data, + map_array.values().as_primitive::() + ); + assert_eq!(&DataType::UInt32, map_array.value_type()); + assert_eq!(3, map_array.len()); + assert_eq!(0, map_array.null_count()); + assert_eq!(6, map_array.value_offsets()[2]); + assert_eq!(2, map_array.value_length(2)); + + let key_array = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef; + let value_array = Arc::new(UInt32Array::from(vec![0u32, 10, 20])) as ArrayRef; + let keys_field = Arc::new(Field::new("keys", DataType::Utf8, false)); + let values_field = Arc::new(Field::new("values", DataType::UInt32, false)); + let struct_array = + StructArray::from(vec![(keys_field, key_array), (values_field, value_array)]); + assert_eq!( + struct_array, + StructArray::from(map_array.value(0).into_data()) + ); + assert_eq!( + &struct_array, + unsafe { map_array.value_unchecked(0) } + .as_any() + .downcast_ref::() + .unwrap() + ); + for i in 0..3 { + assert!(map_array.is_valid(i)); + assert!(!map_array.is_null(i)); + } + } + + #[test] + fn test_try_new() { + let offsets = OffsetBuffer::new(vec![0, 1, 4, 5].into()); + let fields = Fields::from(vec![ + Field::new("key", DataType::Int32, false), + Field::new("values", DataType::Int32, false), + ]); + let columns = vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as _, + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as _, + ]; + + let entries = StructArray::new(fields.clone(), columns, None); + let field = Arc::new(Field::new("entries", DataType::Struct(fields), false)); + + MapArray::new(field.clone(), offsets.clone(), entries.clone(), None, false); + + let nulls = NullBuffer::new_null(3); + MapArray::new(field.clone(), offsets, entries.clone(), Some(nulls), false); + + let nulls = NullBuffer::new_null(3); + let offsets = OffsetBuffer::new(vec![0, 1, 2, 4, 5].into()); + let err = MapArray::try_new( + field.clone(), + offsets.clone(), + entries.clone(), + Some(nulls), + false, + ) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: Incorrect length of null buffer for MapArray, expected 4 got 3" + ); + + let err = MapArray::try_new(field, offsets.clone(), entries.slice(0, 2), None, false) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: Max offset of 5 exceeds length of entries 2" + ); + + let field = Arc::new(Field::new("element", DataType::Int64, false)); + let err = MapArray::try_new(field, offsets.clone(), entries, None, false) + .unwrap_err() + .to_string(); + + assert!( + err.starts_with("Invalid argument error: MapArray expected data type Int64 got Struct"), + "{err}" + ); + + let fields = Fields::from(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ]); + let columns = vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as _, + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as _, + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as _, + ]; + + let s = StructArray::new(fields.clone(), columns, None); + let field = Arc::new(Field::new("entries", DataType::Struct(fields), false)); + let err = MapArray::try_new(field, offsets, s, None, false).unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: MapArray entries must contain two children, got 3" + ); + } +} diff --git a/arrow/src/array/array.rs b/arrow-array/src/array/mod.rs similarity index 56% rename from arrow/src/array/array.rs rename to arrow-array/src/array/mod.rs index 38ba2025a2e3..f19406c1610b 100644 --- a/arrow/src/array/array.rs +++ b/arrow-array/src/array/mod.rs @@ -15,118 +15,146 @@ // specific language governing permissions and limitations // under the License. +//! The concrete array definitions + +mod binary_array; + +use crate::types::*; +use arrow_buffer::{ArrowNativeType, NullBuffer, OffsetBuffer, ScalarBuffer}; +use arrow_data::ArrayData; +use arrow_schema::{DataType, IntervalUnit, TimeUnit}; use std::any::Any; -use std::convert::From; -use std::fmt; use std::sync::Arc; -use super::*; -use crate::buffer::{Buffer, MutableBuffer}; +pub use binary_array::*; + +mod boolean_array; +pub use boolean_array::*; + +mod byte_array; +pub use byte_array::*; + +mod dictionary_array; +pub use dictionary_array::*; + +mod fixed_size_binary_array; +pub use fixed_size_binary_array::*; + +mod fixed_size_list_array; +pub use fixed_size_list_array::*; + +mod list_array; +pub use list_array::*; + +mod map_array; +pub use map_array::*; + +mod null_array; +pub use null_array::*; + +mod primitive_array; +pub use primitive_array::*; -/// Trait for dealing with different types of array at runtime when the type of the -/// array is not known in advance. -pub trait Array: fmt::Debug + Send + Sync { - /// Returns the array as [`Any`](std::any::Any) so that it can be +mod string_array; +pub use string_array::*; + +mod struct_array; +pub use struct_array::*; + +mod union_array; +pub use union_array::*; + +mod run_array; +pub use run_array::*; + +/// An array in the [arrow columnar format](https://arrow.apache.org/docs/format/Columnar.html) +pub trait Array: std::fmt::Debug + Send + Sync { + /// Returns the array as [`Any`] so that it can be /// downcasted to a specific implementation. /// /// # Example: /// /// ``` - /// use std::sync::Arc; - /// use arrow::array::Int32Array; - /// use arrow::datatypes::{Schema, Field, DataType}; - /// use arrow::record_batch::RecordBatch; + /// # use std::sync::Arc; + /// # use arrow_array::{Int32Array, RecordBatch}; + /// # use arrow_schema::{Schema, Field, DataType, ArrowError}; /// - /// # fn main() -> arrow::error::Result<()> { /// let id = Int32Array::from(vec![1, 2, 3, 4, 5]); /// let batch = RecordBatch::try_new( /// Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])), /// vec![Arc::new(id)] - /// )?; + /// ).unwrap(); /// /// let int32array = batch /// .column(0) /// .as_any() /// .downcast_ref::() /// .expect("Failed to downcast"); - /// # Ok(()) - /// # } /// ``` fn as_any(&self) -> &dyn Any; - /// Returns a reference to the underlying data of this array. - fn data(&self) -> &ArrayData; + /// Returns the underlying data of this array + fn to_data(&self) -> ArrayData; - /// Returns the underlying data of this array. + /// Returns the underlying data of this array + /// + /// Unlike [`Array::to_data`] this consumes self, allowing it avoid unnecessary clones fn into_data(self) -> ArrayData; - /// Returns a reference-counted pointer to the underlying data of this array. - fn data_ref(&self) -> &ArrayData { - self.data() - } - - /// Returns a reference to the [`DataType`](crate::datatypes::DataType) of this array. + /// Returns a reference to the [`DataType`] of this array. /// /// # Example: /// /// ``` - /// use arrow::datatypes::DataType; - /// use arrow::array::{Array, Int32Array}; + /// use arrow_schema::DataType; + /// use arrow_array::{Array, Int32Array}; /// /// let array = Int32Array::from(vec![1, 2, 3, 4, 5]); /// /// assert_eq!(*array.data_type(), DataType::Int32); /// ``` - fn data_type(&self) -> &DataType { - self.data_ref().data_type() - } + fn data_type(&self) -> &DataType; /// Returns a zero-copy slice of this array with the indicated offset and length. /// /// # Example: /// /// ``` - /// use arrow::array::{Array, Int32Array}; + /// use arrow_array::{Array, Int32Array}; /// /// let array = Int32Array::from(vec![1, 2, 3, 4, 5]); /// // Make slice over the values [2, 3, 4] /// let array_slice = array.slice(1, 3); /// - /// assert_eq!(array_slice.as_ref(), &Int32Array::from(vec![2, 3, 4])); + /// assert_eq!(&array_slice, &Int32Array::from(vec![2, 3, 4])); /// ``` - fn slice(&self, offset: usize, length: usize) -> ArrayRef { - make_array(self.data_ref().slice(offset, length)) - } + fn slice(&self, offset: usize, length: usize) -> ArrayRef; /// Returns the length (i.e., number of elements) of this array. /// /// # Example: /// /// ``` - /// use arrow::array::{Array, Int32Array}; + /// use arrow_array::{Array, Int32Array}; /// /// let array = Int32Array::from(vec![1, 2, 3, 4, 5]); /// /// assert_eq!(array.len(), 5); /// ``` - fn len(&self) -> usize { - self.data_ref().len() - } + fn len(&self) -> usize; /// Returns whether this array is empty. /// /// # Example: /// /// ``` - /// use arrow::array::{Array, Int32Array}; + /// use arrow_array::{Array, Int32Array}; /// /// let array = Int32Array::from(vec![1, 2, 3, 4, 5]); /// /// assert_eq!(array.is_empty(), false); /// ``` - fn is_empty(&self) -> bool { - self.data_ref().is_empty() - } + fn is_empty(&self) -> bool; /// Returns the offset into the underlying data used by this array(-slice). /// Note that the underlying data can be shared by many arrays. @@ -135,43 +163,71 @@ pub trait Array: fmt::Debug + Send + Sync { /// # Example: /// /// ``` - /// use arrow::array::{Array, Int32Array}; + /// use arrow_array::{Array, BooleanArray}; /// - /// let array = Int32Array::from(vec![1, 2, 3, 4, 5]); - /// // Make slice over the values [2, 3, 4] + /// let array = BooleanArray::from(vec![false, false, true, true]); /// let array_slice = array.slice(1, 3); /// /// assert_eq!(array.offset(), 0); /// assert_eq!(array_slice.offset(), 1); /// ``` - fn offset(&self) -> usize { - self.data_ref().offset() + fn offset(&self) -> usize; + + /// Returns the null buffer of this array if any. + /// + /// The null buffer encodes the "physical" nulls of an array. + /// However, some arrays can also encode nullability in their children, for example, + /// [`DictionaryArray::values`] values or [`RunArray::values`], or without a null buffer, + /// such as [`NullArray`]. To determine if each element of such an array is logically null, + /// you can use the slower [`Array::logical_nulls`] to obtain a computed mask . + fn nulls(&self) -> Option<&NullBuffer>; + + /// Returns a potentially computed [`NullBuffer`] that represent the logical null values of this array, if any. + /// + /// In most cases this will be the same as [`Array::nulls`], except for: + /// + /// * [`DictionaryArray`] where [`DictionaryArray::values`] contains nulls + /// * [`RunArray`] where [`RunArray::values`] contains nulls + /// * [`NullArray`] where all indices are nulls + /// + /// In these cases a logical [`NullBuffer`] will be computed, encoding the logical nullability + /// of these arrays, beyond what is encoded in [`Array::nulls`] + fn logical_nulls(&self) -> Option { + self.nulls().cloned() } - /// Returns whether the element at `index` is null. - /// When using this function on a slice, the index is relative to the slice. + /// Returns whether the element at `index` is null according to [`Array::nulls`] + /// + /// Note: For performance reasons, this method returns nullability solely as determined by the + /// null buffer. This difference can lead to surprising results, for example, [`NullArray::is_null`] always + /// returns `false` as the array lacks a null buffer. Similarly [`DictionaryArray`] and [`RunArray`] may + /// encode nullability in their children. See [`Self::logical_nulls`] for more information. /// /// # Example: /// /// ``` - /// use arrow::array::{Array, Int32Array}; + /// use arrow_array::{Array, Int32Array, NullArray}; /// /// let array = Int32Array::from(vec![Some(1), None]); - /// /// assert_eq!(array.is_null(0), false); /// assert_eq!(array.is_null(1), true); + /// + /// // NullArrays do not have a null buffer, and therefore always + /// // return false for is_null. + /// let array = NullArray::new(1); + /// assert_eq!(array.is_null(0), false); /// ``` fn is_null(&self, index: usize) -> bool { - self.data_ref().is_null(index) + self.nulls().map(|n| n.is_null(index)).unwrap_or_default() } - /// Returns whether the element at `index` is not null. - /// When using this function on a slice, the index is relative to the slice. + /// Returns whether the element at `index` is *not* null, the + /// opposite of [`Self::is_null`]. /// /// # Example: /// /// ``` - /// use arrow::array::{Array, Int32Array}; + /// use arrow_array::{Array, Int32Array}; /// /// let array = Int32Array::from(vec![Some(1), None]); /// @@ -179,15 +235,18 @@ pub trait Array: fmt::Debug + Send + Sync { /// assert_eq!(array.is_valid(1), false); /// ``` fn is_valid(&self, index: usize) -> bool { - self.data_ref().is_valid(index) + !self.is_null(index) } - /// Returns the total number of null values in this array. + /// Returns the total number of physical null values in this array. + /// + /// Note: this method returns the physical null count, i.e. that encoded in [`Array::nulls`], + /// see [`Array::logical_nulls`] for logical nullability /// /// # Example: /// /// ``` - /// use arrow::array::{Array, Int32Array}; + /// use arrow_array::{Array, Int32Array}; /// /// // Construct an array with values [1, NULL, NULL] /// let array = Int32Array::from(vec![Some(1), None, None]); @@ -195,27 +254,33 @@ pub trait Array: fmt::Debug + Send + Sync { /// assert_eq!(array.null_count(), 2); /// ``` fn null_count(&self) -> usize { - self.data_ref().null_count() + self.nulls().map(|n| n.null_count()).unwrap_or_default() + } + + /// Returns `false` if the array is guaranteed to not contain any logical nulls + /// + /// In general this will be equivalent to `Array::null_count() != 0` but may differ in the + /// presence of logical nullability, see [`Array::logical_nulls`]. + /// + /// Implementations will return `true` unless they can cheaply prove no logical nulls + /// are present. For example a [`DictionaryArray`] with nullable values will still return true, + /// even if the nulls present in [`DictionaryArray::values`] are not referenced by any key, + /// and therefore would not appear in [`Array::logical_nulls`]. + fn is_nullable(&self) -> bool { + self.null_count() != 0 } /// Returns the total number of bytes of memory pointed to by this array. /// The buffers store bytes in the Arrow memory format, and include the data as well as the validity map. - fn get_buffer_memory_size(&self) -> usize { - self.data_ref().get_buffer_memory_size() - } + fn get_buffer_memory_size(&self) -> usize; /// Returns the total number of bytes of memory occupied physically by this array. /// This value will always be greater than returned by `get_buffer_memory_size()` and /// includes the overhead of the data structures that contain the pointers to the various buffers. - fn get_array_memory_size(&self) -> usize { - // both data.get_array_memory_size and size_of_val(self) include ArrayData fields, - // to only count additional fields of this array substract size_of(ArrayData) - self.data_ref().get_array_memory_size() + std::mem::size_of_val(self) - - std::mem::size_of::() - } + fn get_array_memory_size(&self) -> usize; } -/// A reference-counted reference to a generic `Array`. +/// A reference-counted reference to a generic `Array` pub type ArrayRef = Arc; /// Ergonomics: Allow use of an ArrayRef as an `&dyn Array` @@ -224,16 +289,12 @@ impl Array for ArrayRef { self.as_ref().as_any() } - fn data(&self) -> &ArrayData { - self.as_ref().data() + fn to_data(&self) -> ArrayData { + self.as_ref().to_data() } fn into_data(self) -> ArrayData { - self.into() - } - - fn data_ref(&self) -> &ArrayData { - self.as_ref().data_ref() + self.to_data() } fn data_type(&self) -> &DataType { @@ -256,6 +317,14 @@ impl Array for ArrayRef { self.as_ref().offset() } + fn nulls(&self) -> Option<&NullBuffer> { + self.as_ref().nulls() + } + + fn logical_nulls(&self) -> Option { + self.as_ref().logical_nulls() + } + fn is_null(&self, index: usize) -> bool { self.as_ref().is_null(index) } @@ -268,6 +337,10 @@ impl Array for ArrayRef { self.as_ref().null_count() } + fn is_nullable(&self) -> bool { + self.as_ref().is_nullable() + } + fn get_buffer_memory_size(&self) -> usize { self.as_ref().get_buffer_memory_size() } @@ -282,16 +355,12 @@ impl<'a, T: Array> Array for &'a T { T::as_any(self) } - fn data(&self) -> &ArrayData { - T::data(self) + fn to_data(&self) -> ArrayData { + T::to_data(self) } fn into_data(self) -> ArrayData { - self.data().clone() - } - - fn data_ref(&self) -> &ArrayData { - T::data_ref(self) + self.to_data() } fn data_type(&self) -> &DataType { @@ -314,6 +383,14 @@ impl<'a, T: Array> Array for &'a T { T::offset(self) } + fn nulls(&self) -> Option<&NullBuffer> { + T::nulls(self) + } + + fn logical_nulls(&self) -> Option { + T::logical_nulls(self) + } + fn is_null(&self, index: usize) -> bool { T::is_null(self, index) } @@ -326,6 +403,10 @@ impl<'a, T: Array> Array for &'a T { T::null_count(self) } + fn is_nullable(&self) -> bool { + T::is_nullable(self) + } + fn get_buffer_memory_size(&self) -> usize { T::get_buffer_memory_size(self) } @@ -345,6 +426,7 @@ impl<'a, T: Array> Array for &'a T { /// The value at null indexes is unspecified, and implementations must not rely on a specific /// value such as [`Default::default`] being returned, however, it must not be undefined pub trait ArrayAccessor: Array { + /// The Arrow type of the element being accessed. type Item: Send + Sync; /// Returns the element at index `i` @@ -358,6 +440,84 @@ pub trait ArrayAccessor: Array { unsafe fn value_unchecked(&self, index: usize) -> Self::Item; } +impl PartialEq for dyn Array + '_ { + fn eq(&self, other: &Self) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for dyn Array + '_ { + fn eq(&self, other: &T) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for NullArray { + fn eq(&self, other: &NullArray) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for PrimitiveArray { + fn eq(&self, other: &PrimitiveArray) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for DictionaryArray { + fn eq(&self, other: &Self) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for BooleanArray { + fn eq(&self, other: &BooleanArray) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for GenericStringArray { + fn eq(&self, other: &Self) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for GenericBinaryArray { + fn eq(&self, other: &Self) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for FixedSizeBinaryArray { + fn eq(&self, other: &Self) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for GenericListArray { + fn eq(&self, other: &Self) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for MapArray { + fn eq(&self, other: &Self) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for FixedSizeListArray { + fn eq(&self, other: &Self) -> bool { + self.to_data().eq(&other.to_data()) + } +} + +impl PartialEq for StructArray { + fn eq(&self, other: &Self) -> bool { + self.to_data().eq(&other.to_data()) + } +} + /// Constructs an array using the input `data`. /// Returns a reference-counted `Array` instance. pub fn make_array(data: ArrayData) -> ArrayRef { @@ -376,9 +536,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef { DataType::Float64 => Arc::new(Float64Array::from(data)) as ArrayRef, DataType::Date32 => Arc::new(Date32Array::from(data)) as ArrayRef, DataType::Date64 => Arc::new(Date64Array::from(data)) as ArrayRef, - DataType::Time32(TimeUnit::Second) => { - Arc::new(Time32SecondArray::from(data)) as ArrayRef - } + DataType::Time32(TimeUnit::Second) => Arc::new(Time32SecondArray::from(data)) as ArrayRef, DataType::Time32(TimeUnit::Millisecond) => { Arc::new(Time32MillisecondArray::from(data)) as ArrayRef } @@ -423,62 +581,36 @@ pub fn make_array(data: ArrayData) -> ArrayRef { } DataType::Binary => Arc::new(BinaryArray::from(data)) as ArrayRef, DataType::LargeBinary => Arc::new(LargeBinaryArray::from(data)) as ArrayRef, - DataType::FixedSizeBinary(_) => { - Arc::new(FixedSizeBinaryArray::from(data)) as ArrayRef - } + DataType::FixedSizeBinary(_) => Arc::new(FixedSizeBinaryArray::from(data)) as ArrayRef, DataType::Utf8 => Arc::new(StringArray::from(data)) as ArrayRef, DataType::LargeUtf8 => Arc::new(LargeStringArray::from(data)) as ArrayRef, DataType::List(_) => Arc::new(ListArray::from(data)) as ArrayRef, DataType::LargeList(_) => Arc::new(LargeListArray::from(data)) as ArrayRef, DataType::Struct(_) => Arc::new(StructArray::from(data)) as ArrayRef, DataType::Map(_, _) => Arc::new(MapArray::from(data)) as ArrayRef, - DataType::Union(_, _, _) => Arc::new(UnionArray::from(data)) as ArrayRef, - DataType::FixedSizeList(_, _) => { - Arc::new(FixedSizeListArray::from(data)) as ArrayRef - } + DataType::Union(_, _) => Arc::new(UnionArray::from(data)) as ArrayRef, + DataType::FixedSizeList(_, _) => Arc::new(FixedSizeListArray::from(data)) as ArrayRef, DataType::Dictionary(ref key_type, _) => match key_type.as_ref() { - DataType::Int8 => { - Arc::new(DictionaryArray::::from(data)) as ArrayRef - } - DataType::Int16 => { - Arc::new(DictionaryArray::::from(data)) as ArrayRef - } - DataType::Int32 => { - Arc::new(DictionaryArray::::from(data)) as ArrayRef - } - DataType::Int64 => { - Arc::new(DictionaryArray::::from(data)) as ArrayRef - } - DataType::UInt8 => { - Arc::new(DictionaryArray::::from(data)) as ArrayRef - } - DataType::UInt16 => { - Arc::new(DictionaryArray::::from(data)) as ArrayRef - } - DataType::UInt32 => { - Arc::new(DictionaryArray::::from(data)) as ArrayRef - } - DataType::UInt64 => { - Arc::new(DictionaryArray::::from(data)) as ArrayRef - } - dt => panic!("Unexpected dictionary key type {:?}", dt), + DataType::Int8 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + DataType::Int16 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + DataType::Int32 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + DataType::Int64 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + DataType::UInt8 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + DataType::UInt16 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + DataType::UInt32 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + DataType::UInt64 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, + dt => panic!("Unexpected dictionary key type {dt:?}"), + }, + DataType::RunEndEncoded(ref run_ends_type, _) => match run_ends_type.data_type() { + DataType::Int16 => Arc::new(RunArray::::from(data)) as ArrayRef, + DataType::Int32 => Arc::new(RunArray::::from(data)) as ArrayRef, + DataType::Int64 => Arc::new(RunArray::::from(data)) as ArrayRef, + dt => panic!("Unexpected data type for run_ends array {dt:?}"), }, DataType::Null => Arc::new(NullArray::from(data)) as ArrayRef, DataType::Decimal128(_, _) => Arc::new(Decimal128Array::from(data)) as ArrayRef, DataType::Decimal256(_, _) => Arc::new(Decimal256Array::from(data)) as ArrayRef, - dt => panic!("Unexpected data type {:?}", dt), - } -} - -impl From for ArrayRef { - fn from(data: ArrayData) -> Self { - make_array(data) - } -} - -impl From for ArrayData { - fn from(array: ArrayRef) -> Self { - array.data().clone() + dt => panic!("Unexpected data type {dt:?}"), } } @@ -486,8 +618,8 @@ impl From for ArrayData { /// /// ``` /// use std::sync::Arc; -/// use arrow::datatypes::DataType; -/// use arrow::array::{ArrayRef, Int32Array, new_empty_array}; +/// use arrow_schema::DataType; +/// use arrow_array::{ArrayRef, Int32Array, new_empty_array}; /// /// let empty_array = new_empty_array(&DataType::Int32); /// let array: ArrayRef = Arc::new(Int32Array::from(vec![] as Vec)); @@ -504,8 +636,8 @@ pub fn new_empty_array(data_type: &DataType) -> ArrayRef { /// /// ``` /// use std::sync::Arc; -/// use arrow::datatypes::DataType; -/// use arrow::array::{ArrayRef, Int32Array, new_null_array}; +/// use arrow_schema::DataType; +/// use arrow_array::{ArrayRef, Int32Array, new_null_array}; /// /// let null_array = new_null_array(&DataType::Int32, 3); /// let array: ArrayRef = Arc::new(Int32Array::from(vec![None, None, None])); @@ -513,217 +645,32 @@ pub fn new_empty_array(data_type: &DataType) -> ArrayRef { /// assert_eq!(&array, &null_array); /// ``` pub fn new_null_array(data_type: &DataType, length: usize) -> ArrayRef { - // context: https://github.com/apache/arrow/pull/9469#discussion_r574761687 - match data_type { - DataType::Null => Arc::new(NullArray::new(length)), - DataType::Boolean => { - let null_buf: Buffer = MutableBuffer::new_null(length).into(); - make_array(unsafe { - ArrayData::new_unchecked( - data_type.clone(), - length, - Some(length), - Some(null_buf.clone()), - 0, - vec![null_buf], - vec![], - ) - }) - } - DataType::Int8 => new_null_sized_array::(data_type, length), - DataType::UInt8 => new_null_sized_array::(data_type, length), - DataType::Int16 => new_null_sized_array::(data_type, length), - DataType::UInt16 => new_null_sized_array::(data_type, length), - DataType::Float16 => new_null_sized_array::(data_type, length), - DataType::Int32 => new_null_sized_array::(data_type, length), - DataType::UInt32 => new_null_sized_array::(data_type, length), - DataType::Float32 => new_null_sized_array::(data_type, length), - DataType::Date32 => new_null_sized_array::(data_type, length), - // expanding this into Date23{unit}Type results in needless branching - DataType::Time32(_) => new_null_sized_array::(data_type, length), - DataType::Int64 => new_null_sized_array::(data_type, length), - DataType::UInt64 => new_null_sized_array::(data_type, length), - DataType::Float64 => new_null_sized_array::(data_type, length), - DataType::Date64 => new_null_sized_array::(data_type, length), - // expanding this into Timestamp{unit}Type results in needless branching - DataType::Timestamp(_, _) => new_null_sized_array::(data_type, length), - DataType::Time64(_) => new_null_sized_array::(data_type, length), - DataType::Duration(_) => new_null_sized_array::(data_type, length), - DataType::Interval(unit) => match unit { - IntervalUnit::YearMonth => { - new_null_sized_array::(data_type, length) - } - IntervalUnit::DayTime => { - new_null_sized_array::(data_type, length) - } - IntervalUnit::MonthDayNano => { - new_null_sized_array::(data_type, length) - } - }, - DataType::FixedSizeBinary(value_len) => make_array(unsafe { - ArrayData::new_unchecked( - data_type.clone(), - length, - Some(length), - Some(MutableBuffer::new_null(length).into()), - 0, - vec![Buffer::from(vec![0u8; *value_len as usize * length])], - vec![], - ) - }), - DataType::Binary | DataType::Utf8 => { - new_null_binary_array::(data_type, length) - } - DataType::LargeBinary | DataType::LargeUtf8 => { - new_null_binary_array::(data_type, length) - } - DataType::List(field) => { - new_null_list_array::(data_type, field.data_type(), length) - } - DataType::LargeList(field) => { - new_null_list_array::(data_type, field.data_type(), length) - } - DataType::FixedSizeList(field, value_len) => make_array(unsafe { - ArrayData::new_unchecked( - data_type.clone(), - length, - Some(length), - Some(MutableBuffer::new_null(length).into()), - 0, - vec![], - vec![ - new_null_array(field.data_type(), *value_len as usize * length) - .data() - .clone(), - ], - ) - }), - DataType::Struct(fields) => { - let fields: Vec<_> = fields - .iter() - .map(|field| (field.clone(), new_null_array(field.data_type(), length))) - .collect(); - - let null_buffer = MutableBuffer::new_null(length); - Arc::new(StructArray::from((fields, null_buffer.into()))) - } - DataType::Map(field, _keys_sorted) => { - new_null_list_array::(data_type, field.data_type(), length) - } - DataType::Union(_, _, _) => { - unimplemented!("Creating null Union array not yet supported") - } - DataType::Dictionary(key, value) => { - let keys = new_null_array(key, length); - let keys = keys.data(); - - make_array(unsafe { - ArrayData::new_unchecked( - data_type.clone(), - length, - Some(length), - keys.null_buffer().cloned(), - 0, - keys.buffers().into(), - vec![new_empty_array(value.as_ref()).into_data()], - ) - }) - } - DataType::Decimal128(_, _) => { - new_null_sized_decimal(data_type, length, std::mem::size_of::()) - } - DataType::Decimal256(_, _) => new_null_sized_decimal(data_type, length, 32), - } -} - -#[inline] -fn new_null_list_array( - data_type: &DataType, - child_data_type: &DataType, - length: usize, -) -> ArrayRef { - make_array(unsafe { - ArrayData::new_unchecked( - data_type.clone(), - length, - Some(length), - Some(MutableBuffer::new_null(length).into()), - 0, - vec![Buffer::from( - vec![OffsetSize::zero(); length + 1].to_byte_slice(), - )], - vec![ArrayData::new_empty(child_data_type)], - ) - }) + make_array(ArrayData::new_null(data_type, length)) } -#[inline] -fn new_null_binary_array( - data_type: &DataType, - length: usize, -) -> ArrayRef { - make_array(unsafe { - ArrayData::new_unchecked( - data_type.clone(), - length, - Some(length), - Some(MutableBuffer::new_null(length).into()), - 0, - vec![ - Buffer::from(vec![OffsetSize::zero(); length + 1].to_byte_slice()), - MutableBuffer::new(0).into(), - ], - vec![], - ) - }) -} - -#[inline] -fn new_null_sized_array( - data_type: &DataType, - length: usize, -) -> ArrayRef { - make_array(unsafe { - ArrayData::new_unchecked( - data_type.clone(), - length, - Some(length), - Some(MutableBuffer::new_null(length).into()), - 0, - vec![Buffer::from(vec![0u8; length * T::get_byte_width()])], - vec![], - ) - }) -} - -#[inline] -fn new_null_sized_decimal( - data_type: &DataType, - length: usize, - byte_width: usize, -) -> ArrayRef { - make_array(unsafe { - ArrayData::new_unchecked( - data_type.clone(), - length, - Some(length), - Some(MutableBuffer::new_null(length).into()), - 0, - vec![Buffer::from(vec![0u8; length * byte_width])], - vec![], - ) - }) +/// Helper function that gets offset from an [`ArrayData`] +/// +/// # Safety +/// +/// - ArrayData must contain a valid [`OffsetBuffer`] as its first buffer +unsafe fn get_offsets(data: &ArrayData) -> OffsetBuffer { + match data.is_empty() && data.buffers()[0].is_empty() { + true => OffsetBuffer::new_empty(), + false => { + let buffer = + ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len() + 1); + // Safety: + // ArrayData is valid + unsafe { OffsetBuffer::new_unchecked(buffer) } + } + } } -// Helper function for printing potentially long arrays. -pub(super) fn print_long_array( - array: &A, - f: &mut fmt::Formatter, - print_item: F, -) -> fmt::Result +/// Helper function for printing potentially long arrays. +fn print_long_array(array: &A, f: &mut std::fmt::Formatter, print_item: F) -> std::fmt::Result where A: Array, - F: Fn(&A, usize, &mut fmt::Formatter) -> fmt::Result, + F: Fn(&A, usize, &mut std::fmt::Formatter) -> std::fmt::Result, { let head = std::cmp::min(10, array.len()); @@ -759,6 +706,10 @@ where #[cfg(test)] mod tests { use super::*; + use crate::cast::{as_union_array, downcast_array}; + use crate::downcast_run_array; + use arrow_buffer::MutableBuffer; + use arrow_schema::{Field, Fields, UnionFields, UnionMode}; #[test] fn test_empty_primitive() { @@ -779,8 +730,7 @@ mod tests { #[test] fn test_empty_list_primitive() { - let data_type = - DataType::List(Box::new(Field::new("item", DataType::Int32, false))); + let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, false))); let array = new_empty_array(&data_type); let a = array.as_any().downcast_ref::().unwrap(); assert_eq!(a.len(), 0); @@ -809,8 +759,9 @@ mod tests { #[test] fn test_null_struct() { - let struct_type = - DataType::Struct(vec![Field::new("data", DataType::Int64, false)]); + // It is possible to create a null struct containing a non-nullable child + // see https://github.com/apache/arrow-rs/pull/3244 for details + let struct_type = DataType::Struct(vec![Field::new("data", DataType::Int64, false)].into()); let array = new_null_array(&struct_type, 9); let a = array.as_any().downcast_ref::().unwrap(); @@ -837,8 +788,7 @@ mod tests { #[test] fn test_null_list_primitive() { - let data_type = - DataType::List(Box::new(Field::new("item", DataType::Int32, true))); + let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); let array = new_null_array(&data_type, 9); let a = array.as_any().downcast_ref::().unwrap(); assert_eq!(a.len(), 9); @@ -851,12 +801,12 @@ mod tests { #[test] fn test_null_map() { let data_type = DataType::Map( - Box::new(Field::new( + Arc::new(Field::new( "entry", - DataType::Struct(vec![ + DataType::Struct(Fields::from(vec![ Field::new("key", DataType::Utf8, false), Field::new("value", DataType::Int32, true), - ]), + ])), false, )), false, @@ -872,8 +822,8 @@ mod tests { #[test] fn test_null_dictionary() { - let values = vec![None, None, None, None, None, None, None, None, None] - as Vec>; + let values = + vec![None, None, None, None, None, None, None, None, None] as Vec>; let array: DictionaryArray = values.into_iter().collect(); let array = Arc::new(array) as ArrayRef; @@ -881,33 +831,103 @@ mod tests { let null_array = new_null_array(array.data_type(), 9); assert_eq!(&array, &null_array); assert_eq!( - array.data().buffers()[0].len(), - null_array.data().buffers()[0].len() + array.to_data().buffers()[0].len(), + null_array.to_data().buffers()[0].len() ); } + #[test] + fn test_null_union() { + for mode in [UnionMode::Sparse, UnionMode::Dense] { + let data_type = DataType::Union( + UnionFields::new( + vec![2, 1], + vec![ + Field::new("foo", DataType::Int32, true), + Field::new("bar", DataType::Int64, true), + ], + ), + mode, + ); + let array = new_null_array(&data_type, 4); + + let array = as_union_array(array.as_ref()); + assert_eq!(array.len(), 4); + assert_eq!(array.null_count(), 0); + + for i in 0..4 { + let a = array.value(i); + assert_eq!(a.len(), 1); + assert_eq!(a.null_count(), 1); + assert!(a.is_null(0)) + } + + array.to_data().validate_full().unwrap(); + } + } + + #[test] + #[allow(unused_parens)] + fn test_null_runs() { + for r in [DataType::Int16, DataType::Int32, DataType::Int64] { + let data_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", r, false)), + Arc::new(Field::new("values", DataType::Utf8, true)), + ); + + let array = new_null_array(&data_type, 4); + let array = array.as_ref(); + + downcast_run_array! { + array => { + assert_eq!(array.len(), 4); + assert_eq!(array.null_count(), 0); + assert_eq!(array.values().len(), 1); + assert_eq!(array.values().null_count(), 1); + assert_eq!(array.run_ends().len(), 4); + assert_eq!(array.run_ends().values(), &[4]); + + let idx = array.get_physical_indices(&[0, 1, 2, 3]).unwrap(); + assert_eq!(idx, &[0,0,0,0]); + } + d => unreachable!("{d}") + } + } + } + + #[test] + fn test_null_fixed_size_binary() { + for size in [1, 2, 7] { + let array = new_null_array(&DataType::FixedSizeBinary(size), 6); + let array = array + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(array.len(), 6); + assert_eq!(array.null_count(), 6); + array.iter().for_each(|x| assert!(x.is_none())); + } + } + #[test] fn test_memory_size_null() { let null_arr = NullArray::new(32); assert_eq!(0, null_arr.get_buffer_memory_size()); assert_eq!( - std::mem::size_of::(), + std::mem::size_of::(), null_arr.get_array_memory_size() ); - assert_eq!( - std::mem::size_of::(), - std::mem::size_of::(), - ); } #[test] fn test_memory_size_primitive() { let arr = PrimitiveArray::::from_iter_values(0..128); - let empty = - PrimitiveArray::::from(ArrayData::new_empty(arr.data_type())); + let empty = PrimitiveArray::::from(ArrayData::new_empty(arr.data_type())); - // substract empty array to avoid magic numbers for the size of additional fields + // subtract empty array to avoid magic numbers for the size of additional fields assert_eq!( arr.get_array_memory_size() - empty.get_array_memory_size(), 128 * std::mem::size_of::() @@ -931,12 +951,11 @@ mod tests { // which includes the optional validity buffer // plus one buffer on the heap assert_eq!( - std::mem::size_of::>() - + std::mem::size_of::(), + std::mem::size_of::>(), empty_with_bitmap.get_array_memory_size() ); - // substract empty array to avoid magic numbers for the size of additional fields + // subtract empty array to avoid magic numbers for the size of additional fields // the size of the validity bitmap is rounded up to 64 bytes assert_eq!( arr.get_array_memory_size() - empty_with_bitmap.get_array_memory_size(), @@ -951,19 +970,17 @@ mod tests { (0..256).map(|i| (i % values.len()) as i16), ); - let dict_data = ArrayData::builder(DataType::Dictionary( + let dict_data_type = DataType::Dictionary( Box::new(keys.data_type().clone()), Box::new(values.data_type().clone()), - )) - .len(keys.len()) - .buffers(keys.data_ref().buffers().to_vec()) - .child_data(vec![ArrayData::builder(DataType::Int64) - .len(values.len()) - .buffers(values.data_ref().buffers().to_vec()) + ); + let dict_data = keys + .into_data() + .into_builder() + .data_type(dict_data_type) + .child_data(vec![values.into_data()]) .build() - .unwrap()]) - .build() - .unwrap(); + .unwrap(); let empty_data = ArrayData::new_empty(&DataType::Dictionary( Box::new(DataType::Int16), @@ -1009,4 +1026,15 @@ mod tests { assert!(compute_my_thing(&arr)); assert!(compute_my_thing(arr.as_ref())); } + + #[test] + fn test_downcast_array() { + let array: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); + + let boxed: ArrayRef = Arc::new(array); + let array: Int32Array = downcast_array(&boxed); + + let expected: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); + assert_eq!(array, expected); + } } diff --git a/arrow/src/array/null.rs b/arrow-array/src/array/null_array.rs similarity index 50% rename from arrow/src/array/null.rs rename to arrow-array/src/array/null_array.rs index 467121f6ccfa..af3ec0b57d27 100644 --- a/arrow/src/array/null.rs +++ b/arrow-array/src/array/null_array.rs @@ -17,32 +17,33 @@ //! Contains the `NullArray` type. +use crate::builder::NullBuilder; +use crate::{Array, ArrayRef}; +use arrow_buffer::buffer::NullBuffer; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::DataType; use std::any::Any; -use std::fmt; +use std::sync::Arc; -use crate::array::{Array, ArrayData}; -use crate::datatypes::*; - -/// An Array where all elements are nulls +/// An array of [null values](https://arrow.apache.org/docs/format/Columnar.html#null-layout) /// /// A `NullArray` is a simplified array where all values are null. /// /// # Example: Create an array /// /// ``` -/// use arrow::array::{Array, NullArray}; +/// use arrow_array::{Array, NullArray}; /// -/// # fn main() -> arrow::error::Result<()> { /// let array = NullArray::new(10); /// +/// assert!(array.is_nullable()); /// assert_eq!(array.len(), 10); -/// assert_eq!(array.null_count(), 10); -/// -/// # Ok(()) -/// # } +/// assert_eq!(array.null_count(), 0); +/// assert_eq!(array.logical_nulls().unwrap().null_count(), 10); /// ``` +#[derive(Clone)] pub struct NullArray { - data: ArrayData, + len: usize, } impl NullArray { @@ -52,9 +53,22 @@ impl NullArray { /// other [`DataType`]. /// pub fn new(length: usize) -> Self { - let array_data = ArrayData::builder(DataType::Null).len(length); - let array_data = unsafe { array_data.build_unchecked() }; - NullArray::from(array_data) + Self { len: length } + } + + /// Returns a zero-copy slice of this array with the indicated offset and length. + pub fn slice(&self, offset: usize, len: usize) -> Self { + assert!( + offset.saturating_add(len) <= self.len, + "the length + offset of the sliced BooleanBuffer cannot exceed the existing length" + ); + + Self { len } + } + + /// Returns a new null array builder + pub fn builder(capacity: usize) -> NullBuilder { + NullBuilder::with_capacity(capacity) } } @@ -63,30 +77,52 @@ impl Array for NullArray { self } - fn data(&self) -> &ArrayData { - &self.data + fn to_data(&self) -> ArrayData { + self.clone().into() } fn into_data(self) -> ArrayData { self.into() } - /// Returns whether the element at `index` is null. - /// All elements of a `NullArray` are always null. - fn is_null(&self, _index: usize) -> bool { - true + fn data_type(&self) -> &DataType { + &DataType::Null + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.len + } + + fn is_empty(&self) -> bool { + self.len == 0 + } + + fn offset(&self) -> usize { + 0 + } + + fn nulls(&self) -> Option<&NullBuffer> { + None + } + + fn logical_nulls(&self) -> Option { + (self.len != 0).then(|| NullBuffer::new_null(self.len)) + } + + fn is_nullable(&self) -> bool { + !self.is_empty() } - /// Returns whether the element at `index` is valid. - /// All elements of a `NullArray` are always invalid. - fn is_valid(&self, _index: usize) -> bool { - false + fn get_buffer_memory_size(&self) -> usize { + 0 } - /// Returns the total number of null values in this array. - /// The null count of a `NullArray` always equals its length. - fn null_count(&self) -> usize { - self.data_ref().len() + fn get_array_memory_size(&self) -> usize { + std::mem::size_of::() } } @@ -103,21 +139,22 @@ impl From for NullArray { "NullArray data should contain 0 buffers" ); assert!( - data.null_buffer().is_none(), + data.nulls().is_none(), "NullArray data should not contain a null buffer, as no buffers are required" ); - Self { data } + Self { len: data.len() } } } impl From for ArrayData { fn from(array: NullArray) -> Self { - array.data + let builder = ArrayDataBuilder::new(DataType::Null).len(array.len); + unsafe { builder.build_unchecked() } } } -impl fmt::Debug for NullArray { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +impl std::fmt::Debug for NullArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "NullArray({})", self.len()) } } @@ -131,8 +168,10 @@ mod tests { let null_arr = NullArray::new(32); assert_eq!(null_arr.len(), 32); - assert_eq!(null_arr.null_count(), 32); - assert!(!null_arr.is_valid(0)); + assert_eq!(null_arr.null_count(), 0); + assert_eq!(null_arr.logical_nulls().unwrap().null_count(), 32); + assert!(null_arr.is_valid(0)); + assert!(null_arr.is_nullable()); } #[test] @@ -141,13 +180,15 @@ mod tests { let array2 = array1.slice(8, 16); assert_eq!(array2.len(), 16); - assert_eq!(array2.null_count(), 16); - assert_eq!(array2.offset(), 8); + assert_eq!(array2.null_count(), 0); + assert_eq!(array2.logical_nulls().unwrap().null_count(), 16); + assert!(array2.is_valid(0)); + assert!(array2.is_nullable()); } #[test] fn test_debug_null_array() { let array = NullArray::new(1024 * 1024); - assert_eq!(format!("{:?}", array), "NullArray(1048576)"); + assert_eq!(format!("{array:?}"), "NullArray(1048576)"); } } diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs new file mode 100644 index 000000000000..2296cebd4681 --- /dev/null +++ b/arrow-array/src/array/primitive_array.rs @@ -0,0 +1,2485 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::array::print_long_array; +use crate::builder::{BooleanBufferBuilder, BufferBuilder, PrimitiveBuilder}; +use crate::iterator::PrimitiveIter; +use crate::temporal_conversions::{ + as_date, as_datetime, as_datetime_with_timezone, as_duration, as_time, +}; +use crate::timezone::Tz; +use crate::trusted_len::trusted_len_unzip; +use crate::types::*; +use crate::{Array, ArrayAccessor, ArrayRef, Scalar}; +use arrow_buffer::{i256, ArrowNativeType, Buffer, NullBuffer, ScalarBuffer}; +use arrow_data::bit_iterator::try_for_each_valid_idx; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::{ArrowError, DataType}; +use chrono::{DateTime, Duration, NaiveDate, NaiveDateTime, NaiveTime}; +use half::f16; +use std::any::Any; +use std::sync::Arc; + +/// A [`PrimitiveArray`] of `i8` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::Int8Array; +/// // Create from Vec> +/// let arr = Int8Array::from(vec![Some(1), None, Some(2)]); +/// // Create from Vec +/// let arr = Int8Array::from(vec![1, 2, 3]); +/// // Create iter/collect +/// let arr: Int8Array = std::iter::repeat(42).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type Int8Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of `i16` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::Int16Array; +/// // Create from Vec> +/// let arr = Int16Array::from(vec![Some(1), None, Some(2)]); +/// // Create from Vec +/// let arr = Int16Array::from(vec![1, 2, 3]); +/// // Create iter/collect +/// let arr: Int16Array = std::iter::repeat(42).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type Int16Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of `i32` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::Int32Array; +/// // Create from Vec> +/// let arr = Int32Array::from(vec![Some(1), None, Some(2)]); +/// // Create from Vec +/// let arr = Int32Array::from(vec![1, 2, 3]); +/// // Create iter/collect +/// let arr: Int32Array = std::iter::repeat(42).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type Int32Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of `i64` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::Int64Array; +/// // Create from Vec> +/// let arr = Int64Array::from(vec![Some(1), None, Some(2)]); +/// // Create from Vec +/// let arr = Int64Array::from(vec![1, 2, 3]); +/// // Create iter/collect +/// let arr: Int64Array = std::iter::repeat(42).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type Int64Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of `u8` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::UInt8Array; +/// // Create from Vec> +/// let arr = UInt8Array::from(vec![Some(1), None, Some(2)]); +/// // Create from Vec +/// let arr = UInt8Array::from(vec![1, 2, 3]); +/// // Create iter/collect +/// let arr: UInt8Array = std::iter::repeat(42).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type UInt8Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of `u16` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::UInt16Array; +/// // Create from Vec> +/// let arr = UInt16Array::from(vec![Some(1), None, Some(2)]); +/// // Create from Vec +/// let arr = UInt16Array::from(vec![1, 2, 3]); +/// // Create iter/collect +/// let arr: UInt16Array = std::iter::repeat(42).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type UInt16Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of `u32` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::UInt32Array; +/// // Create from Vec> +/// let arr = UInt32Array::from(vec![Some(1), None, Some(2)]); +/// // Create from Vec +/// let arr = UInt32Array::from(vec![1, 2, 3]); +/// // Create iter/collect +/// let arr: UInt32Array = std::iter::repeat(42).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type UInt32Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of `u64` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::UInt64Array; +/// // Create from Vec> +/// let arr = UInt64Array::from(vec![Some(1), None, Some(2)]); +/// // Create from Vec +/// let arr = UInt64Array::from(vec![1, 2, 3]); +/// // Create iter/collect +/// let arr: UInt64Array = std::iter::repeat(42).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type UInt64Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of `f16` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::Float16Array; +/// use half::f16; +/// // Create from Vec> +/// let arr = Float16Array::from(vec![Some(f16::from_f64(1.0)), Some(f16::from_f64(2.0))]); +/// // Create from Vec +/// let arr = Float16Array::from(vec![f16::from_f64(1.0), f16::from_f64(2.0), f16::from_f64(3.0)]); +/// // Create iter/collect +/// let arr: Float16Array = std::iter::repeat(f16::from_f64(1.0)).take(10).collect(); +/// ``` +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::Float16Array; +/// use half::f16; +/// let arr : Float16Array = [Some(f16::from_f64(1.0)), Some(f16::from_f64(2.0))].into_iter().collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type Float16Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of `f32` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::Float32Array; +/// // Create from Vec> +/// let arr = Float32Array::from(vec![Some(1.0), None, Some(2.0)]); +/// // Create from Vec +/// let arr = Float32Array::from(vec![1.0, 2.0, 3.0]); +/// // Create iter/collect +/// let arr: Float32Array = std::iter::repeat(42.0).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type Float32Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of `f64` +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::Float32Array; +/// // Create from Vec> +/// let arr = Float32Array::from(vec![Some(1.0), None, Some(2.0)]); +/// // Create from Vec +/// let arr = Float32Array::from(vec![1.0, 2.0, 3.0]); +/// // Create iter/collect +/// let arr: Float32Array = std::iter::repeat(42.0).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type Float64Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of seconds since UNIX epoch stored as `i64` +/// +/// This type is similar to the [`chrono::DateTime`] type and can hold +/// values such as `1970-05-09 14:25:11 +01:00` +/// +/// See also [`Timestamp`](arrow_schema::DataType::Timestamp). +/// +/// # Example: UTC timestamps post epoch +/// ``` +/// # use arrow_array::TimestampSecondArray; +/// use arrow_array::timezone::Tz; +/// // Corresponds to single element array with entry 1970-05-09T14:25:11+0:00 +/// let arr = TimestampSecondArray::from(vec![11111111]); +/// // OR +/// let arr = TimestampSecondArray::from(vec![Some(11111111)]); +/// let utc_tz: Tz = "+00:00".parse().unwrap(); +/// +/// assert_eq!(arr.value_as_datetime_with_tz(0, utc_tz).map(|v| v.to_string()).unwrap(), "1970-05-09 14:25:11 +00:00") +/// ``` +/// +/// # Example: UTC timestamps pre epoch +/// ``` +/// # use arrow_array::TimestampSecondArray; +/// use arrow_array::timezone::Tz; +/// // Corresponds to single element array with entry 1969-08-25T09:34:49+0:00 +/// let arr = TimestampSecondArray::from(vec![-11111111]); +/// // OR +/// let arr = TimestampSecondArray::from(vec![Some(-11111111)]); +/// let utc_tz: Tz = "+00:00".parse().unwrap(); +/// +/// assert_eq!(arr.value_as_datetime_with_tz(0, utc_tz).map(|v| v.to_string()).unwrap(), "1969-08-25 09:34:49 +00:00") +/// ``` +/// +/// # Example: With timezone specified +/// ``` +/// # use arrow_array::TimestampSecondArray; +/// use arrow_array::timezone::Tz; +/// // Corresponds to single element array with entry 1970-05-10T00:25:11+10:00 +/// let arr = TimestampSecondArray::from(vec![11111111]).with_timezone("+10:00".to_string()); +/// // OR +/// let arr = TimestampSecondArray::from(vec![Some(11111111)]).with_timezone("+10:00".to_string()); +/// let sydney_tz: Tz = "+10:00".parse().unwrap(); +/// +/// assert_eq!(arr.value_as_datetime_with_tz(0, sydney_tz).map(|v| v.to_string()).unwrap(), "1970-05-10 00:25:11 +10:00") +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type TimestampSecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of milliseconds since UNIX epoch stored as `i64` +/// +/// See examples for [`TimestampSecondArray`] +pub type TimestampMillisecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of microseconds since UNIX epoch stored as `i64` +/// +/// See examples for [`TimestampSecondArray`] +pub type TimestampMicrosecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of nanoseconds since UNIX epoch stored as `i64` +/// +/// See examples for [`TimestampSecondArray`] +pub type TimestampNanosecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of days since UNIX epoch stored as `i32` +/// +/// This type is similar to the [`chrono::NaiveDate`] type and can hold +/// values such as `2018-11-13` +pub type Date32Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of milliseconds since UNIX epoch stored as `i64` +/// +/// This type is similar to the [`chrono::NaiveDate`] type and can hold +/// values such as `2018-11-13` +pub type Date64Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of seconds since midnight stored as `i32` +/// +/// This type is similar to the [`chrono::NaiveTime`] type and can +/// hold values such as `00:02:00` +pub type Time32SecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of milliseconds since midnight stored as `i32` +/// +/// This type is similar to the [`chrono::NaiveTime`] type and can +/// hold values such as `00:02:00.123` +pub type Time32MillisecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of microseconds since midnight stored as `i64` +/// +/// This type is similar to the [`chrono::NaiveTime`] type and can +/// hold values such as `00:02:00.123456` +pub type Time64MicrosecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of nanoseconds since midnight stored as `i64` +/// +/// This type is similar to the [`chrono::NaiveTime`] type and can +/// hold values such as `00:02:00.123456789` +pub type Time64NanosecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of “calendar” intervals in months +/// +/// See [`IntervalYearMonthType`] for details on representation and caveats. +pub type IntervalYearMonthArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of “calendar” intervals in days and milliseconds +/// +/// See [`IntervalDayTimeType`] for details on representation and caveats. +pub type IntervalDayTimeArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of “calendar” intervals in months, days, and nanoseconds. +/// +/// See [`IntervalMonthDayNanoType`] for details on representation and caveats. +pub type IntervalMonthDayNanoArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of elapsed durations in seconds +pub type DurationSecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of elapsed durations in milliseconds +pub type DurationMillisecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of elapsed durations in microseconds +pub type DurationMicrosecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of elapsed durations in nanoseconds +pub type DurationNanosecondArray = PrimitiveArray; + +/// A [`PrimitiveArray`] of 128-bit fixed point decimals +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::Decimal128Array; +/// // Create from Vec> +/// let arr = Decimal128Array::from(vec![Some(1), None, Some(2)]); +/// // Create from Vec +/// let arr = Decimal128Array::from(vec![1, 2, 3]); +/// // Create iter/collect +/// let arr: Decimal128Array = std::iter::repeat(42).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type Decimal128Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of 256-bit fixed point decimals +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::Decimal256Array; +/// use arrow_buffer::i256; +/// // Create from Vec> +/// let arr = Decimal256Array::from(vec![Some(i256::from(1)), None, Some(i256::from(2))]); +/// // Create from Vec +/// let arr = Decimal256Array::from(vec![i256::from(1), i256::from(2), i256::from(3)]); +/// // Create iter/collect +/// let arr: Decimal256Array = std::iter::repeat(i256::from(42)).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type Decimal256Array = PrimitiveArray; + +pub use crate::types::ArrowPrimitiveType; + +/// An array of [primitive values](https://arrow.apache.org/docs/format/Columnar.html#fixed-size-primitive-layout) +/// +/// # Example: From a Vec +/// +/// ``` +/// # use arrow_array::{Array, PrimitiveArray, types::Int32Type}; +/// let arr: PrimitiveArray = vec![1, 2, 3, 4].into(); +/// assert_eq!(4, arr.len()); +/// assert_eq!(0, arr.null_count()); +/// assert_eq!(arr.values(), &[1, 2, 3, 4]) +/// ``` +/// +/// # Example: From an optional Vec +/// +/// ``` +/// # use arrow_array::{Array, PrimitiveArray, types::Int32Type}; +/// let arr: PrimitiveArray = vec![Some(1), None, Some(3), None].into(); +/// assert_eq!(4, arr.len()); +/// assert_eq!(2, arr.null_count()); +/// // Note: values for null indexes are arbitrary +/// assert_eq!(arr.values(), &[1, 0, 3, 0]) +/// ``` +/// +/// # Example: From an iterator of values +/// +/// ``` +/// # use arrow_array::{Array, PrimitiveArray, types::Int32Type}; +/// let arr: PrimitiveArray = (0..10).map(|x| x + 1).collect(); +/// assert_eq!(10, arr.len()); +/// assert_eq!(0, arr.null_count()); +/// for i in 0..10i32 { +/// assert_eq!(i + 1, arr.value(i as usize)); +/// } +/// ``` +/// +/// # Example: From an iterator of option +/// +/// ``` +/// # use arrow_array::{Array, PrimitiveArray, types::Int32Type}; +/// let arr: PrimitiveArray = (0..10).map(|x| (x % 2 == 0).then_some(x)).collect(); +/// assert_eq!(10, arr.len()); +/// assert_eq!(5, arr.null_count()); +/// // Note: values for null indexes are arbitrary +/// assert_eq!(arr.values(), &[0, 0, 2, 0, 4, 0, 6, 0, 8, 0]) +/// ``` +/// +/// # Example: Using Builder +/// +/// ``` +/// # use arrow_array::Array; +/// # use arrow_array::builder::PrimitiveBuilder; +/// # use arrow_array::types::Int32Type; +/// let mut builder = PrimitiveBuilder::::new(); +/// builder.append_value(1); +/// builder.append_null(); +/// builder.append_value(2); +/// let array = builder.finish(); +/// // Note: values for null indexes are arbitrary +/// assert_eq!(array.values(), &[1, 0, 2]); +/// assert!(array.is_null(1)); +/// ``` +pub struct PrimitiveArray { + data_type: DataType, + /// Values data + values: ScalarBuffer, + nulls: Option, +} + +impl Clone for PrimitiveArray { + fn clone(&self) -> Self { + Self { + data_type: self.data_type.clone(), + values: self.values.clone(), + nulls: self.nulls.clone(), + } + } +} + +impl PrimitiveArray { + /// Create a new [`PrimitiveArray`] from the provided values and nulls + /// + /// # Panics + /// + /// Panics if [`Self::try_new`] returns an error + /// + /// # Example + /// + /// Creating a [`PrimitiveArray`] directly from a [`ScalarBuffer`] and [`NullBuffer`] using + /// this constructor is the most performant approach, avoiding any additional allocations + /// + /// ``` + /// # use arrow_array::Int32Array; + /// # use arrow_array::types::Int32Type; + /// # use arrow_buffer::NullBuffer; + /// // [1, 2, 3, 4] + /// let array = Int32Array::new(vec![1, 2, 3, 4].into(), None); + /// // [1, null, 3, 4] + /// let nulls = NullBuffer::from(vec![true, false, true, true]); + /// let array = Int32Array::new(vec![1, 2, 3, 4].into(), Some(nulls)); + /// ``` + pub fn new(values: ScalarBuffer, nulls: Option) -> Self { + Self::try_new(values, nulls).unwrap() + } + + /// Create a new [`PrimitiveArray`] of the given length where all values are null + pub fn new_null(length: usize) -> Self { + Self { + data_type: T::DATA_TYPE, + values: vec![T::Native::usize_as(0); length].into(), + nulls: Some(NullBuffer::new_null(length)), + } + } + + /// Create a new [`PrimitiveArray`] from the provided values and nulls + /// + /// # Errors + /// + /// Errors if: + /// - `values.len() != nulls.len()` + pub fn try_new( + values: ScalarBuffer, + nulls: Option, + ) -> Result { + if let Some(n) = nulls.as_ref() { + if n.len() != values.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect length of null buffer for PrimitiveArray, expected {} got {}", + values.len(), + n.len(), + ))); + } + } + + Ok(Self { + data_type: T::DATA_TYPE, + values, + nulls, + }) + } + + /// Create a new [`Scalar`] from `value` + pub fn new_scalar(value: T::Native) -> Scalar { + Scalar::new(Self { + data_type: T::DATA_TYPE, + values: vec![value].into(), + nulls: None, + }) + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts(self) -> (DataType, ScalarBuffer, Option) { + (self.data_type, self.values, self.nulls) + } + + /// Overrides the [`DataType`] of this [`PrimitiveArray`] + /// + /// Prefer using [`Self::with_timezone`] or [`Self::with_precision_and_scale`] where + /// the primitive type is suitably constrained, as these cannot panic + /// + /// # Panics + /// + /// Panics if ![Self::is_compatible] + pub fn with_data_type(self, data_type: DataType) -> Self { + Self::assert_compatible(&data_type); + Self { data_type, ..self } + } + + /// Asserts that `data_type` is compatible with `Self` + fn assert_compatible(data_type: &DataType) { + assert!( + Self::is_compatible(data_type), + "PrimitiveArray expected data type {} got {}", + T::DATA_TYPE, + data_type + ); + } + + /// Returns the length of this array. + #[inline] + pub fn len(&self) -> usize { + self.values.len() + } + + /// Returns whether this array is empty. + pub fn is_empty(&self) -> bool { + self.values.is_empty() + } + + /// Returns the values of this array + #[inline] + pub fn values(&self) -> &ScalarBuffer { + &self.values + } + + /// Returns a new primitive array builder + pub fn builder(capacity: usize) -> PrimitiveBuilder { + PrimitiveBuilder::::with_capacity(capacity) + } + + /// Returns if this [`PrimitiveArray`] is compatible with the provided [`DataType`] + /// + /// This is equivalent to `data_type == T::DATA_TYPE`, however ignores timestamp + /// timezones and decimal precision and scale + pub fn is_compatible(data_type: &DataType) -> bool { + match T::DATA_TYPE { + DataType::Timestamp(t1, _) => { + matches!(data_type, DataType::Timestamp(t2, _) if &t1 == t2) + } + DataType::Decimal128(_, _) => matches!(data_type, DataType::Decimal128(_, _)), + DataType::Decimal256(_, _) => matches!(data_type, DataType::Decimal256(_, _)), + _ => T::DATA_TYPE.eq(data_type), + } + } + + /// Returns the primitive value at index `i`. + /// + /// # Safety + /// + /// caller must ensure that the passed in offset is less than the array len() + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> T::Native { + *self.values.get_unchecked(i) + } + + /// Returns the primitive value at index `i`. + /// # Panics + /// Panics if index `i` is out of bounds + #[inline] + pub fn value(&self, i: usize) -> T::Native { + assert!( + i < self.len(), + "Trying to access an element at index {} from a PrimitiveArray of length {}", + i, + self.len() + ); + unsafe { self.value_unchecked(i) } + } + + /// Creates a PrimitiveArray based on an iterator of values without nulls + pub fn from_iter_values>(iter: I) -> Self { + let val_buf: Buffer = iter.into_iter().collect(); + let len = val_buf.len() / std::mem::size_of::(); + Self { + data_type: T::DATA_TYPE, + values: ScalarBuffer::new(val_buf, 0, len), + nulls: None, + } + } + + /// Creates a PrimitiveArray based on a constant value with `count` elements + pub fn from_value(value: T::Native, count: usize) -> Self { + unsafe { + let val_buf = Buffer::from_trusted_len_iter((0..count).map(|_| value)); + Self::new(val_buf.into(), None) + } + } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + pub fn take_iter<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> + 'a { + indexes.map(|opt_index| opt_index.map(|index| self.value(index))) + } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + /// # Safety + /// + /// caller must ensure that the offsets in the iterator are less than the array len() + pub unsafe fn take_iter_unchecked<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> + 'a { + indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) + } + + /// Returns a zero-copy slice of this array with the indicated offset and length. + pub fn slice(&self, offset: usize, length: usize) -> Self { + Self { + data_type: self.data_type.clone(), + values: self.values.slice(offset, length), + nulls: self.nulls.as_ref().map(|n| n.slice(offset, length)), + } + } + + /// Reinterprets this array's contents as a different data type without copying + /// + /// This can be used to efficiently convert between primitive arrays with the + /// same underlying representation + /// + /// Note: this will not modify the underlying values, and therefore may change + /// the semantic values of the array, e.g. 100 milliseconds in a [`TimestampNanosecondArray`] + /// will become 100 seconds in a [`TimestampSecondArray`]. + /// + /// For casts that preserve the semantic value, check out the [compute kernels] + /// + /// [compute kernels](https://docs.rs/arrow/latest/arrow/compute/kernels/cast/index.html) + /// + /// ``` + /// # use arrow_array::{Int64Array, TimestampNanosecondArray}; + /// let a = Int64Array::from_iter_values([1, 2, 3, 4]); + /// let b: TimestampNanosecondArray = a.reinterpret_cast(); + /// ``` + pub fn reinterpret_cast(&self) -> PrimitiveArray + where + K: ArrowPrimitiveType, + { + let d = self.to_data().into_builder().data_type(K::DATA_TYPE); + + // SAFETY: + // Native type is the same + PrimitiveArray::from(unsafe { d.build_unchecked() }) + } + + /// Applies an unary and infallible function to a primitive array. + /// This is the fastest way to perform an operation on a primitive array when + /// the benefits of a vectorized operation outweigh the cost of branching nulls and non-nulls. + /// + /// # Implementation + /// + /// This will apply the function for all values, including those on null slots. + /// This implies that the operation must be infallible for any value of the corresponding type + /// or this function may panic. + /// # Example + /// ```rust + /// # use arrow_array::{Int32Array, types::Int32Type}; + /// # fn main() { + /// let array = Int32Array::from(vec![Some(5), Some(7), None]); + /// let c = array.unary(|x| x * 2 + 1); + /// assert_eq!(c, Int32Array::from(vec![Some(11), Some(15), None])); + /// # } + /// ``` + pub fn unary(&self, op: F) -> PrimitiveArray + where + O: ArrowPrimitiveType, + F: Fn(T::Native) -> O::Native, + { + let nulls = self.nulls().cloned(); + let values = self.values().iter().map(|v| op(*v)); + // JUSTIFICATION + // Benefit + // ~60% speedup + // Soundness + // `values` is an iterator with a known size because arrays are sized. + let buffer = unsafe { Buffer::from_trusted_len_iter(values) }; + PrimitiveArray::new(buffer.into(), nulls) + } + + /// Applies an unary and infallible function to a mutable primitive array. + /// Mutable primitive array means that the buffer is not shared with other arrays. + /// As a result, this mutates the buffer directly without allocating new buffer. + /// + /// # Implementation + /// + /// This will apply the function for all values, including those on null slots. + /// This implies that the operation must be infallible for any value of the corresponding type + /// or this function may panic. + /// # Example + /// ```rust + /// # use arrow_array::{Int32Array, types::Int32Type}; + /// # fn main() { + /// let array = Int32Array::from(vec![Some(5), Some(7), None]); + /// let c = array.unary_mut(|x| x * 2 + 1).unwrap(); + /// assert_eq!(c, Int32Array::from(vec![Some(11), Some(15), None])); + /// # } + /// ``` + pub fn unary_mut(self, op: F) -> Result, PrimitiveArray> + where + F: Fn(T::Native) -> T::Native, + { + let mut builder = self.into_builder()?; + builder + .values_slice_mut() + .iter_mut() + .for_each(|v| *v = op(*v)); + Ok(builder.finish()) + } + + /// Applies a unary and fallible function to all valid values in a primitive array + /// + /// This is unlike [`Self::unary`] which will apply an infallible function to all rows + /// regardless of validity, in many cases this will be significantly faster and should + /// be preferred if `op` is infallible. + /// + /// Note: LLVM is currently unable to effectively vectorize fallible operations + pub fn try_unary(&self, op: F) -> Result, E> + where + O: ArrowPrimitiveType, + F: Fn(T::Native) -> Result, + { + let len = self.len(); + + let nulls = self.nulls().cloned(); + let mut buffer = BufferBuilder::::new(len); + buffer.append_n_zeroed(len); + let slice = buffer.as_slice_mut(); + + let f = |idx| { + unsafe { *slice.get_unchecked_mut(idx) = op(self.value_unchecked(idx))? }; + Ok::<_, E>(()) + }; + + match &nulls { + Some(nulls) => nulls.try_for_each_valid_idx(f)?, + None => (0..len).try_for_each(f)?, + } + + let values = buffer.finish().into(); + Ok(PrimitiveArray::new(values, nulls)) + } + + /// Applies an unary and fallible function to all valid values in a mutable primitive array. + /// Mutable primitive array means that the buffer is not shared with other arrays. + /// As a result, this mutates the buffer directly without allocating new buffer. + /// + /// This is unlike [`Self::unary_mut`] which will apply an infallible function to all rows + /// regardless of validity, in many cases this will be significantly faster and should + /// be preferred if `op` is infallible. + /// + /// This returns an `Err` when the input array is shared buffer with other + /// array. In the case, returned `Err` wraps input array. If the function + /// encounters an error during applying on values. In the case, this returns an `Err` within + /// an `Ok` which wraps the actual error. + /// + /// Note: LLVM is currently unable to effectively vectorize fallible operations + pub fn try_unary_mut( + self, + op: F, + ) -> Result, E>, PrimitiveArray> + where + F: Fn(T::Native) -> Result, + { + let len = self.len(); + let null_count = self.null_count(); + let mut builder = self.into_builder()?; + + let (slice, null_buffer) = builder.slices_mut(); + + match try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| { + unsafe { *slice.get_unchecked_mut(idx) = op(*slice.get_unchecked(idx))? }; + Ok::<_, E>(()) + }) { + Ok(_) => {} + Err(err) => return Ok(Err(err)), + }; + + Ok(Ok(builder.finish())) + } + + /// Applies a unary and nullable function to all valid values in a primitive array + /// + /// This is unlike [`Self::unary`] which will apply an infallible function to all rows + /// regardless of validity, in many cases this will be significantly faster and should + /// be preferred if `op` is infallible. + /// + /// Note: LLVM is currently unable to effectively vectorize fallible operations + pub fn unary_opt(&self, op: F) -> PrimitiveArray + where + O: ArrowPrimitiveType, + F: Fn(T::Native) -> Option, + { + let len = self.len(); + let (nulls, null_count, offset) = match self.nulls() { + Some(n) => (Some(n.validity()), n.null_count(), n.offset()), + None => (None, 0, 0), + }; + + let mut null_builder = BooleanBufferBuilder::new(len); + match nulls { + Some(b) => null_builder.append_packed_range(offset..offset + len, b), + None => null_builder.append_n(len, true), + } + + let mut buffer = BufferBuilder::::new(len); + buffer.append_n_zeroed(len); + let slice = buffer.as_slice_mut(); + + let mut out_null_count = null_count; + + let _ = try_for_each_valid_idx(len, offset, null_count, nulls, |idx| { + match op(unsafe { self.value_unchecked(idx) }) { + Some(v) => unsafe { *slice.get_unchecked_mut(idx) = v }, + None => { + out_null_count += 1; + null_builder.set_bit(idx, false); + } + } + Ok::<_, ()>(()) + }); + + let nulls = null_builder.finish(); + let values = buffer.finish().into(); + let nulls = unsafe { NullBuffer::new_unchecked(nulls, out_null_count) }; + PrimitiveArray::new(values, Some(nulls)) + } + + /// Returns `PrimitiveBuilder` of this primitive array for mutating its values if the underlying + /// data buffer is not shared by others. + pub fn into_builder(self) -> Result, Self> { + let len = self.len(); + let data = self.into_data(); + let null_bit_buffer = data.nulls().map(|b| b.inner().sliced()); + + let element_len = std::mem::size_of::(); + let buffer = + data.buffers()[0].slice_with_length(data.offset() * element_len, len * element_len); + + drop(data); + + let try_mutable_null_buffer = match null_bit_buffer { + None => Ok(None), + Some(null_buffer) => { + // Null buffer exists, tries to make it mutable + null_buffer.into_mutable().map(Some) + } + }; + + let try_mutable_buffers = match try_mutable_null_buffer { + Ok(mutable_null_buffer) => { + // Got mutable null buffer, tries to get mutable value buffer + let try_mutable_buffer = buffer.into_mutable(); + + // try_mutable_buffer.map(...).map_err(...) doesn't work as the compiler complains + // mutable_null_buffer is moved into map closure. + match try_mutable_buffer { + Ok(mutable_buffer) => Ok(PrimitiveBuilder::::new_from_buffer( + mutable_buffer, + mutable_null_buffer, + )), + Err(buffer) => Err((buffer, mutable_null_buffer.map(|b| b.into()))), + } + } + Err(mutable_null_buffer) => { + // Unable to get mutable null buffer + Err((buffer, Some(mutable_null_buffer))) + } + }; + + match try_mutable_buffers { + Ok(builder) => Ok(builder), + Err((buffer, null_bit_buffer)) => { + let builder = ArrayData::builder(T::DATA_TYPE) + .len(len) + .add_buffer(buffer) + .null_bit_buffer(null_bit_buffer); + + let array_data = unsafe { builder.build_unchecked() }; + let array = PrimitiveArray::::from(array_data); + + Err(array) + } + } + } +} + +impl From> for ArrayData { + fn from(array: PrimitiveArray) -> Self { + let builder = ArrayDataBuilder::new(array.data_type) + .len(array.values.len()) + .nulls(array.nulls) + .buffers(vec![array.values.into_inner()]); + + unsafe { builder.build_unchecked() } + } +} + +impl Array for PrimitiveArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + self.clone().into() + } + + fn into_data(self) -> ArrayData { + self.into() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.values.len() + } + + fn is_empty(&self) -> bool { + self.values.is_empty() + } + + fn offset(&self) -> usize { + 0 + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.nulls.as_ref() + } + + fn get_buffer_memory_size(&self) -> usize { + let mut size = self.values.inner().capacity(); + if let Some(n) = self.nulls.as_ref() { + size += n.buffer().capacity(); + } + size + } + + fn get_array_memory_size(&self) -> usize { + std::mem::size_of::() + self.get_buffer_memory_size() + } +} + +impl<'a, T: ArrowPrimitiveType> ArrayAccessor for &'a PrimitiveArray { + type Item = T::Native; + + fn value(&self, index: usize) -> Self::Item { + PrimitiveArray::value(self, index) + } + + #[inline] + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + PrimitiveArray::value_unchecked(self, index) + } +} + +impl PrimitiveArray +where + i64: From, +{ + /// Returns value as a chrono `NaiveDateTime`, handling time resolution + /// + /// If a data type cannot be converted to `NaiveDateTime`, a `None` is returned. + /// A valid value is expected, thus the user should first check for validity. + pub fn value_as_datetime(&self, i: usize) -> Option { + as_datetime::(i64::from(self.value(i))) + } + + /// Returns value as a chrono `NaiveDateTime`, handling time resolution with the provided tz + /// + /// functionally it is same as `value_as_datetime`, however it adds + /// the passed tz to the to-be-returned NaiveDateTime + pub fn value_as_datetime_with_tz(&self, i: usize, tz: Tz) -> Option> { + as_datetime_with_timezone::(i64::from(self.value(i)), tz) + } + + /// Returns value as a chrono `NaiveDate` by using `Self::datetime()` + /// + /// If a data type cannot be converted to `NaiveDate`, a `None` is returned + pub fn value_as_date(&self, i: usize) -> Option { + self.value_as_datetime(i).map(|datetime| datetime.date()) + } + + /// Returns a value as a chrono `NaiveTime` + /// + /// `Date32` and `Date64` return UTC midnight as they do not have time resolution + pub fn value_as_time(&self, i: usize) -> Option { + as_time::(i64::from(self.value(i))) + } + + /// Returns a value as a chrono `Duration` + /// + /// If a data type cannot be converted to `Duration`, a `None` is returned + pub fn value_as_duration(&self, i: usize) -> Option { + as_duration::(i64::from(self.value(i))) + } +} + +impl std::fmt::Debug for PrimitiveArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let data_type = self.data_type(); + write!(f, "PrimitiveArray<{data_type:?}>\n[\n")?; + print_long_array(self, f, |array, index, f| match data_type { + DataType::Date32 | DataType::Date64 => { + let v = self.value(index).to_isize().unwrap() as i64; + match as_date::(v) { + Some(date) => write!(f, "{date:?}"), + None => write!(f, "null"), + } + } + DataType::Time32(_) | DataType::Time64(_) => { + let v = self.value(index).to_isize().unwrap() as i64; + match as_time::(v) { + Some(time) => write!(f, "{time:?}"), + None => write!(f, "null"), + } + } + DataType::Timestamp(_, tz_string_opt) => { + let v = self.value(index).to_isize().unwrap() as i64; + match tz_string_opt { + // for Timestamp with TimeZone + Some(tz_string) => { + match tz_string.parse::() { + // if the time zone is valid, construct a DateTime and format it as rfc3339 + Ok(tz) => match as_datetime_with_timezone::(v, tz) { + Some(datetime) => write!(f, "{}", datetime.to_rfc3339()), + None => write!(f, "null"), + }, + // if the time zone is invalid, shows NaiveDateTime with an error message + Err(_) => match as_datetime::(v) { + Some(datetime) => { + write!(f, "{datetime:?} (Unknown Time Zone '{tz_string}')") + } + None => write!(f, "null"), + }, + } + } + // for Timestamp without TimeZone + None => match as_datetime::(v) { + Some(datetime) => write!(f, "{datetime:?}"), + None => write!(f, "null"), + }, + } + } + _ => std::fmt::Debug::fmt(&array.value(index), f), + })?; + write!(f, "]") + } +} + +impl<'a, T: ArrowPrimitiveType> IntoIterator for &'a PrimitiveArray { + type Item = Option<::Native>; + type IntoIter = PrimitiveIter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + PrimitiveIter::<'a, T>::new(self) + } +} + +impl<'a, T: ArrowPrimitiveType> PrimitiveArray { + /// constructs a new iterator + pub fn iter(&'a self) -> PrimitiveIter<'a, T> { + PrimitiveIter::<'a, T>::new(self) + } +} + +/// An optional primitive value +/// +/// This struct is used as an adapter when creating `PrimitiveArray` from an iterator. +/// `FromIterator` for `PrimitiveArray` takes an iterator where the elements can be `into` +/// this struct. So once implementing `From` or `Into` trait for a type, an iterator of +/// the type can be collected to `PrimitiveArray`. +#[derive(Debug)] +pub struct NativeAdapter { + /// Corresponding Rust native type if available + pub native: Option, +} + +macro_rules! def_from_for_primitive { + ( $ty:ident, $tt:tt) => { + impl From<$tt> for NativeAdapter<$ty> { + fn from(value: $tt) -> Self { + NativeAdapter { + native: Some(value), + } + } + } + }; +} + +def_from_for_primitive!(Int8Type, i8); +def_from_for_primitive!(Int16Type, i16); +def_from_for_primitive!(Int32Type, i32); +def_from_for_primitive!(Int64Type, i64); +def_from_for_primitive!(UInt8Type, u8); +def_from_for_primitive!(UInt16Type, u16); +def_from_for_primitive!(UInt32Type, u32); +def_from_for_primitive!(UInt64Type, u64); +def_from_for_primitive!(Float16Type, f16); +def_from_for_primitive!(Float32Type, f32); +def_from_for_primitive!(Float64Type, f64); +def_from_for_primitive!(Decimal128Type, i128); +def_from_for_primitive!(Decimal256Type, i256); + +impl From::Native>> for NativeAdapter { + fn from(value: Option<::Native>) -> Self { + NativeAdapter { native: value } + } +} + +impl From<&Option<::Native>> for NativeAdapter { + fn from(value: &Option<::Native>) -> Self { + NativeAdapter { native: *value } + } +} + +impl>> FromIterator for PrimitiveArray { + fn from_iter>(iter: I) -> Self { + let iter = iter.into_iter(); + let (lower, _) = iter.size_hint(); + + let mut null_builder = BooleanBufferBuilder::new(lower); + + let buffer: Buffer = iter + .map(|item| { + if let Some(a) = item.into().native { + null_builder.append(true); + a + } else { + null_builder.append(false); + // this ensures that null items on the buffer are not arbitrary. + // This is important because fallible operations can use null values (e.g. a vectorized "add") + // which may panic (e.g. overflow if the number on the slots happen to be very large). + T::Native::default() + } + }) + .collect(); + + let len = null_builder.len(); + + let data = unsafe { + ArrayData::new_unchecked( + T::DATA_TYPE, + len, + None, + Some(null_builder.into()), + 0, + vec![buffer], + vec![], + ) + }; + PrimitiveArray::from(data) + } +} + +impl PrimitiveArray { + /// Creates a [`PrimitiveArray`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter(iter: I) -> Self + where + P: std::borrow::Borrow::Native>>, + I: IntoIterator, + { + let iterator = iter.into_iter(); + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let (null, buffer) = trusted_len_unzip(iterator); + + let data = + ArrayData::new_unchecked(T::DATA_TYPE, len, None, Some(null), 0, vec![buffer], vec![]); + PrimitiveArray::from(data) + } +} + +// TODO: the macro is needed here because we'd get "conflicting implementations" error +// otherwise with both `From>` and `From>>`. +// We should revisit this in future. +macro_rules! def_numeric_from_vec { + ( $ty:ident ) => { + impl From::Native>> for PrimitiveArray<$ty> { + fn from(data: Vec<<$ty as ArrowPrimitiveType>::Native>) -> Self { + let array_data = ArrayData::builder($ty::DATA_TYPE) + .len(data.len()) + .add_buffer(Buffer::from_vec(data)); + let array_data = unsafe { array_data.build_unchecked() }; + PrimitiveArray::from(array_data) + } + } + + // Constructs a primitive array from a vector. Should only be used for testing. + impl From::Native>>> for PrimitiveArray<$ty> { + fn from(data: Vec::Native>>) -> Self { + PrimitiveArray::from_iter(data.iter()) + } + } + }; +} + +def_numeric_from_vec!(Int8Type); +def_numeric_from_vec!(Int16Type); +def_numeric_from_vec!(Int32Type); +def_numeric_from_vec!(Int64Type); +def_numeric_from_vec!(UInt8Type); +def_numeric_from_vec!(UInt16Type); +def_numeric_from_vec!(UInt32Type); +def_numeric_from_vec!(UInt64Type); +def_numeric_from_vec!(Float16Type); +def_numeric_from_vec!(Float32Type); +def_numeric_from_vec!(Float64Type); +def_numeric_from_vec!(Decimal128Type); +def_numeric_from_vec!(Decimal256Type); + +def_numeric_from_vec!(Date32Type); +def_numeric_from_vec!(Date64Type); +def_numeric_from_vec!(Time32SecondType); +def_numeric_from_vec!(Time32MillisecondType); +def_numeric_from_vec!(Time64MicrosecondType); +def_numeric_from_vec!(Time64NanosecondType); +def_numeric_from_vec!(IntervalYearMonthType); +def_numeric_from_vec!(IntervalDayTimeType); +def_numeric_from_vec!(IntervalMonthDayNanoType); +def_numeric_from_vec!(DurationSecondType); +def_numeric_from_vec!(DurationMillisecondType); +def_numeric_from_vec!(DurationMicrosecondType); +def_numeric_from_vec!(DurationNanosecondType); +def_numeric_from_vec!(TimestampSecondType); +def_numeric_from_vec!(TimestampMillisecondType); +def_numeric_from_vec!(TimestampMicrosecondType); +def_numeric_from_vec!(TimestampNanosecondType); + +impl PrimitiveArray { + /// Construct a timestamp array from a vec of i64 values and an optional timezone + #[deprecated(note = "Use with_timezone_opt instead")] + pub fn from_vec(data: Vec, timezone: Option) -> Self + where + Self: From>, + { + Self::from(data).with_timezone_opt(timezone) + } + + /// Construct a timestamp array from a vec of `Option` values and an optional timezone + #[deprecated(note = "Use with_timezone_opt instead")] + pub fn from_opt_vec(data: Vec>, timezone: Option) -> Self + where + Self: From>>, + { + Self::from(data).with_timezone_opt(timezone) + } + + /// Returns the timezone of this array if any + pub fn timezone(&self) -> Option<&str> { + match self.data_type() { + DataType::Timestamp(_, tz) => tz.as_deref(), + _ => unreachable!(), + } + } + + /// Construct a timestamp array with new timezone + pub fn with_timezone(self, timezone: impl Into>) -> Self { + self.with_timezone_opt(Some(timezone.into())) + } + + /// Construct a timestamp array with UTC + pub fn with_timezone_utc(self) -> Self { + self.with_timezone("+00:00") + } + + /// Construct a timestamp array with an optional timezone + pub fn with_timezone_opt>>(self, timezone: Option) -> Self { + Self { + data_type: DataType::Timestamp(T::UNIT, timezone.map(Into::into)), + ..self + } + } +} + +/// Constructs a `PrimitiveArray` from an array data reference. +impl From for PrimitiveArray { + fn from(data: ArrayData) -> Self { + Self::assert_compatible(data.data_type()); + assert_eq!( + data.buffers().len(), + 1, + "PrimitiveArray data should contain a single buffer only (values buffer)" + ); + + let values = ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()); + Self { + data_type: data.data_type().clone(), + values, + nulls: data.nulls().cloned(), + } + } +} + +impl PrimitiveArray { + /// Returns a Decimal array with the same data as self, with the + /// specified precision and scale. + /// + /// See [`validate_decimal_precision_and_scale`] + pub fn with_precision_and_scale(self, precision: u8, scale: i8) -> Result { + validate_decimal_precision_and_scale::(precision, scale)?; + Ok(Self { + data_type: T::TYPE_CONSTRUCTOR(precision, scale), + ..self + }) + } + + /// Validates values in this array can be properly interpreted + /// with the specified precision. + pub fn validate_decimal_precision(&self, precision: u8) -> Result<(), ArrowError> { + (0..self.len()).try_for_each(|idx| { + if self.is_valid(idx) { + let decimal = unsafe { self.value_unchecked(idx) }; + T::validate_decimal_precision(decimal, precision) + } else { + Ok(()) + } + }) + } + + /// Validates the Decimal Array, if the value of slot is overflow for the specified precision, and + /// will be casted to Null + pub fn null_if_overflow_precision(&self, precision: u8) -> Self { + self.unary_opt::<_, T>(|v| { + (T::validate_decimal_precision(v, precision).is_ok()).then_some(v) + }) + } + + /// Returns [`Self::value`] formatted as a string + pub fn value_as_string(&self, row: usize) -> String { + T::format_decimal(self.value(row), self.precision(), self.scale()) + } + + /// Returns the decimal precision of this array + pub fn precision(&self) -> u8 { + match T::BYTE_LENGTH { + 16 => { + if let DataType::Decimal128(p, _) = self.data_type() { + *p + } else { + unreachable!( + "Decimal128Array datatype is not DataType::Decimal128 but {}", + self.data_type() + ) + } + } + 32 => { + if let DataType::Decimal256(p, _) = self.data_type() { + *p + } else { + unreachable!( + "Decimal256Array datatype is not DataType::Decimal256 but {}", + self.data_type() + ) + } + } + other => unreachable!("Unsupported byte length for decimal array {}", other), + } + } + + /// Returns the decimal scale of this array + pub fn scale(&self) -> i8 { + match T::BYTE_LENGTH { + 16 => { + if let DataType::Decimal128(_, s) = self.data_type() { + *s + } else { + unreachable!( + "Decimal128Array datatype is not DataType::Decimal128 but {}", + self.data_type() + ) + } + } + 32 => { + if let DataType::Decimal256(_, s) = self.data_type() { + *s + } else { + unreachable!( + "Decimal256Array datatype is not DataType::Decimal256 but {}", + self.data_type() + ) + } + } + other => unreachable!("Unsupported byte length for decimal array {}", other), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::builder::{Decimal128Builder, Decimal256Builder}; + use crate::cast::downcast_array; + use crate::{ArrayRef, BooleanArray}; + use arrow_schema::TimeUnit; + use std::sync::Arc; + + #[test] + fn test_primitive_array_from_vec() { + let buf = Buffer::from_slice_ref([0, 1, 2, 3, 4]); + let arr = Int32Array::from(vec![0, 1, 2, 3, 4]); + assert_eq!(&buf, arr.values.inner()); + assert_eq!(5, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + for i in 0..5 { + assert!(!arr.is_null(i)); + assert!(arr.is_valid(i)); + assert_eq!(i as i32, arr.value(i)); + } + } + + #[test] + fn test_primitive_array_from_vec_option() { + // Test building a primitive array with null values + let arr = Int32Array::from(vec![Some(0), None, Some(2), None, Some(4)]); + assert_eq!(5, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(2, arr.null_count()); + for i in 0..5 { + if i % 2 == 0 { + assert!(!arr.is_null(i)); + assert!(arr.is_valid(i)); + assert_eq!(i as i32, arr.value(i)); + } else { + assert!(arr.is_null(i)); + assert!(!arr.is_valid(i)); + } + } + } + + #[test] + fn test_date64_array_from_vec_option() { + // Test building a primitive array with null values + // we use Int32 and Int64 as a backing array, so all Int32 and Int64 conventions + // work + let arr: PrimitiveArray = + vec![Some(1550902545147), None, Some(1550902545147)].into(); + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(1, arr.null_count()); + for i in 0..3 { + if i % 2 == 0 { + assert!(!arr.is_null(i)); + assert!(arr.is_valid(i)); + assert_eq!(1550902545147, arr.value(i)); + // roundtrip to and from datetime + assert_eq!( + 1550902545147, + arr.value_as_datetime(i).unwrap().timestamp_millis() + ); + } else { + assert!(arr.is_null(i)); + assert!(!arr.is_valid(i)); + } + } + } + + #[test] + fn test_time32_millisecond_array_from_vec() { + // 1: 00:00:00.001 + // 37800005: 10:30:00.005 + // 86399210: 23:59:59.210 + let arr: PrimitiveArray = vec![1, 37_800_005, 86_399_210].into(); + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + let formatted = ["00:00:00.001", "10:30:00.005", "23:59:59.210"]; + for (i, formatted) in formatted.iter().enumerate().take(3) { + // check that we can't create dates or datetimes from time instances + assert_eq!(None, arr.value_as_datetime(i)); + assert_eq!(None, arr.value_as_date(i)); + let time = arr.value_as_time(i).unwrap(); + assert_eq!(*formatted, time.format("%H:%M:%S%.3f").to_string()); + } + } + + #[test] + fn test_time64_nanosecond_array_from_vec() { + // Test building a primitive array with null values + // we use Int32 and Int64 as a backing array, so all Int32 and Int64 conventions + // work + + // 1e6: 00:00:00.001 + // 37800005e6: 10:30:00.005 + // 86399210e6: 23:59:59.210 + let arr: PrimitiveArray = + vec![1_000_000, 37_800_005_000_000, 86_399_210_000_000].into(); + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + let formatted = ["00:00:00.001", "10:30:00.005", "23:59:59.210"]; + for (i, item) in formatted.iter().enumerate().take(3) { + // check that we can't create dates or datetimes from time instances + assert_eq!(None, arr.value_as_datetime(i)); + assert_eq!(None, arr.value_as_date(i)); + let time = arr.value_as_time(i).unwrap(); + assert_eq!(*item, time.format("%H:%M:%S%.3f").to_string()); + } + } + + #[test] + fn test_interval_array_from_vec() { + // intervals are currently not treated specially, but are Int32 and Int64 arrays + let arr = IntervalYearMonthArray::from(vec![Some(1), None, Some(-5)]); + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(1, arr.null_count()); + assert_eq!(1, arr.value(0)); + assert_eq!(1, arr.values()[0]); + assert!(arr.is_null(1)); + assert_eq!(-5, arr.value(2)); + assert_eq!(-5, arr.values()[2]); + + // a day_time interval contains days and milliseconds, but we do not yet have accessors for the values + let arr = IntervalDayTimeArray::from(vec![Some(1), None, Some(-5)]); + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(1, arr.null_count()); + assert_eq!(1, arr.value(0)); + assert_eq!(1, arr.values()[0]); + assert!(arr.is_null(1)); + assert_eq!(-5, arr.value(2)); + assert_eq!(-5, arr.values()[2]); + + // a month_day_nano interval contains months, days and nanoseconds, + // but we do not yet have accessors for the values. + // TODO: implement month, day, and nanos access method for month_day_nano. + let arr = IntervalMonthDayNanoArray::from(vec![ + Some(100000000000000000000), + None, + Some(-500000000000000000000), + ]); + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(1, arr.null_count()); + assert_eq!(100000000000000000000, arr.value(0)); + assert_eq!(100000000000000000000, arr.values()[0]); + assert!(arr.is_null(1)); + assert_eq!(-500000000000000000000, arr.value(2)); + assert_eq!(-500000000000000000000, arr.values()[2]); + } + + #[test] + fn test_duration_array_from_vec() { + let arr = DurationSecondArray::from(vec![Some(1), None, Some(-5)]); + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(1, arr.null_count()); + assert_eq!(1, arr.value(0)); + assert_eq!(1, arr.values()[0]); + assert!(arr.is_null(1)); + assert_eq!(-5, arr.value(2)); + assert_eq!(-5, arr.values()[2]); + + let arr = DurationMillisecondArray::from(vec![Some(1), None, Some(-5)]); + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(1, arr.null_count()); + assert_eq!(1, arr.value(0)); + assert_eq!(1, arr.values()[0]); + assert!(arr.is_null(1)); + assert_eq!(-5, arr.value(2)); + assert_eq!(-5, arr.values()[2]); + + let arr = DurationMicrosecondArray::from(vec![Some(1), None, Some(-5)]); + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(1, arr.null_count()); + assert_eq!(1, arr.value(0)); + assert_eq!(1, arr.values()[0]); + assert!(arr.is_null(1)); + assert_eq!(-5, arr.value(2)); + assert_eq!(-5, arr.values()[2]); + + let arr = DurationNanosecondArray::from(vec![Some(1), None, Some(-5)]); + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(1, arr.null_count()); + assert_eq!(1, arr.value(0)); + assert_eq!(1, arr.values()[0]); + assert!(arr.is_null(1)); + assert_eq!(-5, arr.value(2)); + assert_eq!(-5, arr.values()[2]); + } + + #[test] + fn test_timestamp_array_from_vec() { + let arr = TimestampSecondArray::from(vec![1, -5]); + assert_eq!(2, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + assert_eq!(1, arr.value(0)); + assert_eq!(-5, arr.value(1)); + assert_eq!(&[1, -5], arr.values()); + + let arr = TimestampMillisecondArray::from(vec![1, -5]); + assert_eq!(2, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + assert_eq!(1, arr.value(0)); + assert_eq!(-5, arr.value(1)); + assert_eq!(&[1, -5], arr.values()); + + let arr = TimestampMicrosecondArray::from(vec![1, -5]); + assert_eq!(2, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + assert_eq!(1, arr.value(0)); + assert_eq!(-5, arr.value(1)); + assert_eq!(&[1, -5], arr.values()); + + let arr = TimestampNanosecondArray::from(vec![1, -5]); + assert_eq!(2, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + assert_eq!(1, arr.value(0)); + assert_eq!(-5, arr.value(1)); + assert_eq!(&[1, -5], arr.values()); + } + + #[test] + fn test_primitive_array_slice() { + let arr = Int32Array::from(vec![ + Some(0), + None, + Some(2), + None, + Some(4), + Some(5), + Some(6), + None, + None, + ]); + assert_eq!(9, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(4, arr.null_count()); + + let arr2 = arr.slice(2, 5); + assert_eq!(5, arr2.len()); + assert_eq!(1, arr2.null_count()); + + for i in 0..arr2.len() { + assert_eq!(i == 1, arr2.is_null(i)); + assert_eq!(i != 1, arr2.is_valid(i)); + } + let int_arr2 = arr2.as_any().downcast_ref::().unwrap(); + assert_eq!(2, int_arr2.values()[0]); + assert_eq!(&[4, 5, 6], &int_arr2.values()[2..5]); + + let arr3 = arr2.slice(2, 3); + assert_eq!(3, arr3.len()); + assert_eq!(0, arr3.null_count()); + + let int_arr3 = arr3.as_any().downcast_ref::().unwrap(); + assert_eq!(&[4, 5, 6], int_arr3.values()); + assert_eq!(4, int_arr3.value(0)); + assert_eq!(5, int_arr3.value(1)); + assert_eq!(6, int_arr3.value(2)); + } + + #[test] + fn test_boolean_array_slice() { + let arr = BooleanArray::from(vec![ + Some(true), + None, + Some(false), + None, + Some(true), + Some(false), + Some(true), + Some(false), + None, + Some(true), + ]); + + assert_eq!(10, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(3, arr.null_count()); + + let arr2 = arr.slice(3, 5); + assert_eq!(5, arr2.len()); + assert_eq!(3, arr2.offset()); + assert_eq!(1, arr2.null_count()); + + let bool_arr = arr2.as_any().downcast_ref::().unwrap(); + + assert!(!bool_arr.is_valid(0)); + + assert!(bool_arr.is_valid(1)); + assert!(bool_arr.value(1)); + + assert!(bool_arr.is_valid(2)); + assert!(!bool_arr.value(2)); + + assert!(bool_arr.is_valid(3)); + assert!(bool_arr.value(3)); + + assert!(bool_arr.is_valid(4)); + assert!(!bool_arr.value(4)); + } + + #[test] + fn test_int32_fmt_debug() { + let arr = Int32Array::from(vec![0, 1, 2, 3, 4]); + assert_eq!( + "PrimitiveArray\n[\n 0,\n 1,\n 2,\n 3,\n 4,\n]", + format!("{arr:?}") + ); + } + + #[test] + fn test_fmt_debug_up_to_20_elements() { + (1..=20).for_each(|i| { + let values = (0..i).collect::>(); + let array_expected = format!( + "PrimitiveArray\n[\n{}\n]", + values + .iter() + .map(|v| { format!(" {v},") }) + .collect::>() + .join("\n") + ); + let array = Int16Array::from(values); + + assert_eq!(array_expected, format!("{array:?}")); + }) + } + + #[test] + fn test_int32_with_null_fmt_debug() { + let mut builder = Int32Array::builder(3); + builder.append_slice(&[0, 1]); + builder.append_null(); + builder.append_slice(&[3, 4]); + let arr = builder.finish(); + assert_eq!( + "PrimitiveArray\n[\n 0,\n 1,\n null,\n 3,\n 4,\n]", + format!("{arr:?}") + ); + } + + #[test] + fn test_timestamp_fmt_debug() { + let arr: PrimitiveArray = + TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]); + assert_eq!( + "PrimitiveArray\n[\n 2018-12-31T00:00:00,\n 2018-12-31T00:00:00,\n 1921-01-02T00:00:00,\n]", + format!("{arr:?}") + ); + } + + #[test] + fn test_timestamp_utc_fmt_debug() { + let arr: PrimitiveArray = + TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]) + .with_timezone_utc(); + assert_eq!( + "PrimitiveArray\n[\n 2018-12-31T00:00:00+00:00,\n 2018-12-31T00:00:00+00:00,\n 1921-01-02T00:00:00+00:00,\n]", + format!("{arr:?}") + ); + } + + #[test] + #[cfg(feature = "chrono-tz")] + fn test_timestamp_with_named_tz_fmt_debug() { + let arr: PrimitiveArray = + TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]) + .with_timezone("Asia/Taipei".to_string()); + assert_eq!( + "PrimitiveArray\n[\n 2018-12-31T08:00:00+08:00,\n 2018-12-31T08:00:00+08:00,\n 1921-01-02T08:00:00+08:00,\n]", + format!("{:?}", arr) + ); + } + + #[test] + #[cfg(not(feature = "chrono-tz"))] + fn test_timestamp_with_named_tz_fmt_debug() { + let arr: PrimitiveArray = + TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]) + .with_timezone("Asia/Taipei".to_string()); + + println!("{arr:?}"); + + assert_eq!( + "PrimitiveArray\n[\n 2018-12-31T00:00:00 (Unknown Time Zone 'Asia/Taipei'),\n 2018-12-31T00:00:00 (Unknown Time Zone 'Asia/Taipei'),\n 1921-01-02T00:00:00 (Unknown Time Zone 'Asia/Taipei'),\n]", + format!("{arr:?}") + ); + } + + #[test] + fn test_timestamp_with_fixed_offset_tz_fmt_debug() { + let arr: PrimitiveArray = + TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]) + .with_timezone("+08:00".to_string()); + assert_eq!( + "PrimitiveArray\n[\n 2018-12-31T08:00:00+08:00,\n 2018-12-31T08:00:00+08:00,\n 1921-01-02T08:00:00+08:00,\n]", + format!("{arr:?}") + ); + } + + #[test] + fn test_timestamp_with_incorrect_tz_fmt_debug() { + let arr: PrimitiveArray = + TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]) + .with_timezone("xxx".to_string()); + assert_eq!( + "PrimitiveArray\n[\n 2018-12-31T00:00:00 (Unknown Time Zone 'xxx'),\n 2018-12-31T00:00:00 (Unknown Time Zone 'xxx'),\n 1921-01-02T00:00:00 (Unknown Time Zone 'xxx'),\n]", + format!("{arr:?}") + ); + } + + #[test] + #[cfg(feature = "chrono-tz")] + fn test_timestamp_with_tz_with_daylight_saving_fmt_debug() { + let arr: PrimitiveArray = TimestampMillisecondArray::from(vec![ + 1647161999000, + 1647162000000, + 1667717999000, + 1667718000000, + ]) + .with_timezone("America/Denver".to_string()); + assert_eq!( + "PrimitiveArray\n[\n 2022-03-13T01:59:59-07:00,\n 2022-03-13T03:00:00-06:00,\n 2022-11-06T00:59:59-06:00,\n 2022-11-06T01:00:00-06:00,\n]", + format!("{:?}", arr) + ); + } + + #[test] + fn test_date32_fmt_debug() { + let arr: PrimitiveArray = vec![12356, 13548, -365].into(); + assert_eq!( + "PrimitiveArray\n[\n 2003-10-31,\n 2007-02-04,\n 1969-01-01,\n]", + format!("{arr:?}") + ); + } + + #[test] + fn test_time32second_fmt_debug() { + let arr: PrimitiveArray = vec![7201, 60054].into(); + assert_eq!( + "PrimitiveArray\n[\n 02:00:01,\n 16:40:54,\n]", + format!("{arr:?}") + ); + } + + #[test] + fn test_time32second_invalid_neg() { + // chrono::NaiveDatetime::from_timestamp_opt returns None while input is invalid + let arr: PrimitiveArray = vec![-7201, -60054].into(); + assert_eq!( + "PrimitiveArray\n[\n null,\n null,\n]", + format!("{arr:?}") + ) + } + + #[test] + fn test_timestamp_micros_out_of_range() { + // replicate the issue from https://github.com/apache/arrow-datafusion/issues/3832 + let arr: PrimitiveArray = vec![9065525203050843594].into(); + assert_eq!( + "PrimitiveArray\n[\n null,\n]", + format!("{arr:?}") + ) + } + + #[test] + fn test_primitive_array_builder() { + // Test building a primitive array with ArrayData builder and offset + let buf = Buffer::from_slice_ref([0i32, 1, 2, 3, 4, 5, 6]); + let buf2 = buf.slice_with_length(8, 20); + let data = ArrayData::builder(DataType::Int32) + .len(5) + .offset(2) + .add_buffer(buf) + .build() + .unwrap(); + let arr = Int32Array::from(data); + assert_eq!(&buf2, arr.values.inner()); + assert_eq!(5, arr.len()); + assert_eq!(0, arr.null_count()); + for i in 0..3 { + assert_eq!((i + 2) as i32, arr.value(i)); + } + } + + #[test] + fn test_primitive_from_iter_values() { + // Test building a primitive array with from_iter_values + let arr: PrimitiveArray = PrimitiveArray::from_iter_values(0..10); + assert_eq!(10, arr.len()); + assert_eq!(0, arr.null_count()); + for i in 0..10i32 { + assert_eq!(i, arr.value(i as usize)); + } + } + + #[test] + fn test_primitive_array_from_unbound_iter() { + // iterator that doesn't declare (upper) size bound + let value_iter = (0..) + .scan(0usize, |pos, i| { + if *pos < 10 { + *pos += 1; + Some(Some(i)) + } else { + // actually returns up to 10 values + None + } + }) + // limited using take() + .take(100); + + let (_, upper_size_bound) = value_iter.size_hint(); + // the upper bound, defined by take above, is 100 + assert_eq!(upper_size_bound, Some(100)); + let primitive_array: PrimitiveArray = value_iter.collect(); + // but the actual number of items in the array should be 10 + assert_eq!(primitive_array.len(), 10); + } + + #[test] + fn test_primitive_array_from_non_null_iter() { + let iter = (0..10_i32).map(Some); + let primitive_array = PrimitiveArray::::from_iter(iter); + assert_eq!(primitive_array.len(), 10); + assert_eq!(primitive_array.null_count(), 0); + assert!(primitive_array.nulls().is_none()); + assert_eq!(primitive_array.values(), &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + } + + #[test] + #[should_panic(expected = "PrimitiveArray data should contain a single buffer only \ + (values buffer)")] + // Different error messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] + fn test_primitive_array_invalid_buffer_len() { + let buffer = Buffer::from_slice_ref([0i32, 1, 2, 3, 4]); + let data = unsafe { + ArrayData::builder(DataType::Int32) + .add_buffer(buffer.clone()) + .add_buffer(buffer) + .len(5) + .build_unchecked() + }; + + drop(Int32Array::from(data)); + } + + #[test] + fn test_access_array_concurrently() { + let a = Int32Array::from(vec![5, 6, 7, 8, 9]); + let ret = std::thread::spawn(move || a.value(3)).join(); + + assert!(ret.is_ok()); + assert_eq!(8, ret.ok().unwrap()); + } + + #[test] + fn test_primitive_array_creation() { + let array1: Int8Array = [10_i8, 11, 12, 13, 14].into_iter().collect(); + let array2: Int8Array = [10_i8, 11, 12, 13, 14].into_iter().map(Some).collect(); + + assert_eq!(array1, array2); + } + + #[test] + #[should_panic( + expected = "Trying to access an element at index 4 from a PrimitiveArray of length 3" + )] + fn test_string_array_get_value_index_out_of_bound() { + let array: Int8Array = [10_i8, 11, 12].into_iter().collect(); + + array.value(4); + } + + #[test] + #[should_panic(expected = "PrimitiveArray expected data type Int64 got Int32")] + fn test_from_array_data_validation() { + let foo = PrimitiveArray::::from_iter([1, 2, 3]); + let _ = PrimitiveArray::::from(foo.into_data()); + } + + #[test] + fn test_decimal128() { + let values: Vec<_> = vec![0, 1, -1, i128::MIN, i128::MAX]; + let array: PrimitiveArray = + PrimitiveArray::from_iter(values.iter().copied()); + assert_eq!(array.values(), &values); + + let array: PrimitiveArray = + PrimitiveArray::from_iter_values(values.iter().copied()); + assert_eq!(array.values(), &values); + + let array = PrimitiveArray::::from(values.clone()); + assert_eq!(array.values(), &values); + + let array = PrimitiveArray::::from(array.to_data()); + assert_eq!(array.values(), &values); + } + + #[test] + fn test_decimal256() { + let values: Vec<_> = vec![i256::ZERO, i256::ONE, i256::MINUS_ONE, i256::MIN, i256::MAX]; + + let array: PrimitiveArray = + PrimitiveArray::from_iter(values.iter().copied()); + assert_eq!(array.values(), &values); + + let array: PrimitiveArray = + PrimitiveArray::from_iter_values(values.iter().copied()); + assert_eq!(array.values(), &values); + + let array = PrimitiveArray::::from(values.clone()); + assert_eq!(array.values(), &values); + + let array = PrimitiveArray::::from(array.to_data()); + assert_eq!(array.values(), &values); + } + + #[test] + fn test_decimal_array() { + // let val_8887: [u8; 16] = [192, 219, 180, 17, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + // let val_neg_8887: [u8; 16] = [64, 36, 75, 238, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255]; + let values: [u8; 32] = [ + 192, 219, 180, 17, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 64, 36, 75, 238, 253, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]; + let array_data = ArrayData::builder(DataType::Decimal128(38, 6)) + .len(2) + .add_buffer(Buffer::from(&values[..])) + .build() + .unwrap(); + let decimal_array = Decimal128Array::from(array_data); + assert_eq!(8_887_000_000_i128, decimal_array.value(0)); + assert_eq!(-8_887_000_000_i128, decimal_array.value(1)); + } + + #[test] + fn test_decimal_append_error_value() { + let mut decimal_builder = Decimal128Builder::with_capacity(10); + decimal_builder.append_value(123456); + decimal_builder.append_value(12345); + let result = decimal_builder.finish().with_precision_and_scale(5, 3); + assert!(result.is_ok()); + let arr = result.unwrap(); + assert_eq!("12.345", arr.value_as_string(1)); + + // Validate it explicitly + let result = arr.validate_decimal_precision(5); + let error = result.unwrap_err(); + assert_eq!( + "Invalid argument error: 123456 is too large to store in a Decimal128 of precision 5. Max is 99999", + error.to_string() + ); + + decimal_builder = Decimal128Builder::new(); + decimal_builder.append_value(100); + decimal_builder.append_value(99); + decimal_builder.append_value(-100); + decimal_builder.append_value(-99); + let result = decimal_builder.finish().with_precision_and_scale(2, 1); + assert!(result.is_ok()); + let arr = result.unwrap(); + assert_eq!("9.9", arr.value_as_string(1)); + assert_eq!("-9.9", arr.value_as_string(3)); + + // Validate it explicitly + let result = arr.validate_decimal_precision(2); + let error = result.unwrap_err(); + assert_eq!( + "Invalid argument error: 100 is too large to store in a Decimal128 of precision 2. Max is 99", + error.to_string() + ); + } + + #[test] + fn test_decimal_from_iter_values() { + let array = Decimal128Array::from_iter_values(vec![-100, 0, 101]); + assert_eq!(array.len(), 3); + assert_eq!(array.data_type(), &DataType::Decimal128(38, 10)); + assert_eq!(-100_i128, array.value(0)); + assert!(!array.is_null(0)); + assert_eq!(0_i128, array.value(1)); + assert!(!array.is_null(1)); + assert_eq!(101_i128, array.value(2)); + assert!(!array.is_null(2)); + } + + #[test] + fn test_decimal_from_iter() { + let array: Decimal128Array = vec![Some(-100), None, Some(101)].into_iter().collect(); + assert_eq!(array.len(), 3); + assert_eq!(array.data_type(), &DataType::Decimal128(38, 10)); + assert_eq!(-100_i128, array.value(0)); + assert!(!array.is_null(0)); + assert!(array.is_null(1)); + assert_eq!(101_i128, array.value(2)); + assert!(!array.is_null(2)); + } + + #[test] + fn test_decimal_iter_sized() { + let data = vec![Some(-100), None, Some(101)]; + let array: Decimal128Array = data.into_iter().collect(); + let mut iter = array.into_iter(); + + // is exact sized + assert_eq!(array.len(), 3); + + // size_hint is reported correctly + assert_eq!(iter.size_hint(), (3, Some(3))); + iter.next().unwrap(); + assert_eq!(iter.size_hint(), (2, Some(2))); + iter.next().unwrap(); + iter.next().unwrap(); + assert_eq!(iter.size_hint(), (0, Some(0))); + assert!(iter.next().is_none()); + assert_eq!(iter.size_hint(), (0, Some(0))); + } + + #[test] + fn test_decimal_array_value_as_string() { + let arr = [123450, -123450, 100, -100, 10, -10, 0] + .into_iter() + .map(Some) + .collect::() + .with_precision_and_scale(6, 3) + .unwrap(); + + assert_eq!("123.450", arr.value_as_string(0)); + assert_eq!("-123.450", arr.value_as_string(1)); + assert_eq!("0.100", arr.value_as_string(2)); + assert_eq!("-0.100", arr.value_as_string(3)); + assert_eq!("0.010", arr.value_as_string(4)); + assert_eq!("-0.010", arr.value_as_string(5)); + assert_eq!("0.000", arr.value_as_string(6)); + } + + #[test] + fn test_decimal_array_with_precision_and_scale() { + let arr = Decimal128Array::from_iter_values([12345, 456, 7890, -123223423432432]) + .with_precision_and_scale(20, 2) + .unwrap(); + + assert_eq!(arr.data_type(), &DataType::Decimal128(20, 2)); + assert_eq!(arr.precision(), 20); + assert_eq!(arr.scale(), 2); + + let actual: Vec<_> = (0..arr.len()).map(|i| arr.value_as_string(i)).collect(); + let expected = vec!["123.45", "4.56", "78.90", "-1232234234324.32"]; + + assert_eq!(actual, expected); + } + + #[test] + #[should_panic( + expected = "-123223423432432 is too small to store in a Decimal128 of precision 5. Min is -99999" + )] + fn test_decimal_array_with_precision_and_scale_out_of_range() { + let arr = Decimal128Array::from_iter_values([12345, 456, 7890, -123223423432432]) + // precision is too small to hold value + .with_precision_and_scale(5, 2) + .unwrap(); + arr.validate_decimal_precision(5).unwrap(); + } + + #[test] + #[should_panic(expected = "precision cannot be 0, has to be between [1, 38]")] + fn test_decimal_array_with_precision_zero() { + Decimal128Array::from_iter_values([12345, 456]) + .with_precision_and_scale(0, 2) + .unwrap(); + } + + #[test] + #[should_panic(expected = "precision 40 is greater than max 38")] + fn test_decimal_array_with_precision_and_scale_invalid_precision() { + Decimal128Array::from_iter_values([12345, 456]) + .with_precision_and_scale(40, 2) + .unwrap(); + } + + #[test] + #[should_panic(expected = "scale 40 is greater than max 38")] + fn test_decimal_array_with_precision_and_scale_invalid_scale() { + Decimal128Array::from_iter_values([12345, 456]) + .with_precision_and_scale(20, 40) + .unwrap(); + } + + #[test] + #[should_panic(expected = "scale 10 is greater than precision 4")] + fn test_decimal_array_with_precision_and_scale_invalid_precision_and_scale() { + Decimal128Array::from_iter_values([12345, 456]) + .with_precision_and_scale(4, 10) + .unwrap(); + } + + #[test] + fn test_decimal_array_set_null_if_overflow_with_precision() { + let array = Decimal128Array::from(vec![Some(123456), Some(123), None, Some(123456)]); + let result = array.null_if_overflow_precision(5); + let expected = Decimal128Array::from(vec![None, Some(123), None, None]); + assert_eq!(result, expected); + } + + #[test] + fn test_decimal256_iter() { + let mut builder = Decimal256Builder::with_capacity(30); + let decimal1 = i256::from_i128(12345); + builder.append_value(decimal1); + + builder.append_null(); + + let decimal2 = i256::from_i128(56789); + builder.append_value(decimal2); + + let array: Decimal256Array = builder.finish().with_precision_and_scale(76, 6).unwrap(); + + let collected: Vec<_> = array.iter().collect(); + assert_eq!(vec![Some(decimal1), None, Some(decimal2)], collected); + } + + #[test] + fn test_from_iter_decimal256array() { + let value1 = i256::from_i128(12345); + let value2 = i256::from_i128(56789); + + let mut array: Decimal256Array = + vec![Some(value1), None, Some(value2)].into_iter().collect(); + array = array.with_precision_and_scale(76, 10).unwrap(); + assert_eq!(array.len(), 3); + assert_eq!(array.data_type(), &DataType::Decimal256(76, 10)); + assert_eq!(value1, array.value(0)); + assert!(!array.is_null(0)); + assert!(array.is_null(1)); + assert_eq!(value2, array.value(2)); + assert!(!array.is_null(2)); + } + + #[test] + fn test_from_iter_decimal128array() { + let mut array: Decimal128Array = vec![Some(-100), None, Some(101)].into_iter().collect(); + array = array.with_precision_and_scale(38, 10).unwrap(); + assert_eq!(array.len(), 3); + assert_eq!(array.data_type(), &DataType::Decimal128(38, 10)); + assert_eq!(-100_i128, array.value(0)); + assert!(!array.is_null(0)); + assert!(array.is_null(1)); + assert_eq!(101_i128, array.value(2)); + assert!(!array.is_null(2)); + } + + #[test] + fn test_unary_opt() { + let array = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7]); + let r = array.unary_opt::<_, Int32Type>(|x| (x % 2 != 0).then_some(x)); + + let expected = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5), None, Some(7)]); + assert_eq!(r, expected); + + let r = expected.unary_opt::<_, Int32Type>(|x| (x % 3 != 0).then_some(x)); + let expected = Int32Array::from(vec![Some(1), None, None, None, Some(5), None, Some(7)]); + assert_eq!(r, expected); + } + + #[test] + #[should_panic( + expected = "Trying to access an element at index 4 from a PrimitiveArray of length 3" + )] + fn test_fixed_size_binary_array_get_value_index_out_of_bound() { + let array = Decimal128Array::from(vec![-100, 0, 101]); + array.value(4); + } + + #[test] + fn test_into_builder() { + let array: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); + + let boxed: ArrayRef = Arc::new(array); + let col: Int32Array = downcast_array(&boxed); + drop(boxed); + + let mut builder = col.into_builder().unwrap(); + + let slice = builder.values_slice_mut(); + assert_eq!(slice, &[1, 2, 3]); + + slice[0] = 4; + slice[1] = 2; + slice[2] = 1; + + let expected: Int32Array = vec![Some(4), Some(2), Some(1)].into_iter().collect(); + + let new_array = builder.finish(); + assert_eq!(expected, new_array); + } + + #[test] + fn test_into_builder_cloned_array() { + let array: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); + + let boxed: ArrayRef = Arc::new(array); + + let col: Int32Array = PrimitiveArray::::from(boxed.to_data()); + let err = col.into_builder(); + + match err { + Ok(_) => panic!("Should not get builder from cloned array"), + Err(returned) => { + let expected: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); + assert_eq!(expected, returned) + } + } + } + + #[test] + fn test_into_builder_on_sliced_array() { + let array: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); + let slice = array.slice(1, 2); + let col: Int32Array = downcast_array(&slice); + + drop(slice); + + col.into_builder() + .expect_err("Should not build builder from sliced array"); + } + + #[test] + fn test_unary_mut() { + let array: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); + + let c = array.unary_mut(|x| x * 2 + 1).unwrap(); + let expected: Int32Array = vec![3, 5, 7].into_iter().map(Some).collect(); + + assert_eq!(expected, c); + + let array: Int32Array = Int32Array::from(vec![Some(5), Some(7), None]); + let c = array.unary_mut(|x| x * 2 + 1).unwrap(); + assert_eq!(c, Int32Array::from(vec![Some(11), Some(15), None])); + } + + #[test] + #[should_panic( + expected = "PrimitiveArray expected data type Interval(MonthDayNano) got Interval(DayTime)" + )] + fn test_invalid_interval_type() { + let array = IntervalDayTimeArray::from(vec![1, 2, 3]); + let _ = IntervalMonthDayNanoArray::from(array.into_data()); + } + + #[test] + fn test_timezone() { + let array = TimestampNanosecondArray::from_iter_values([1, 2]); + assert_eq!(array.timezone(), None); + + let array = array.with_timezone("+02:00"); + assert_eq!(array.timezone(), Some("+02:00")); + } + + #[test] + fn test_try_new() { + Int32Array::new(vec![1, 2, 3, 4].into(), None); + Int32Array::new(vec![1, 2, 3, 4].into(), Some(NullBuffer::new_null(4))); + + let err = Int32Array::try_new(vec![1, 2, 3, 4].into(), Some(NullBuffer::new_null(3))) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: Incorrect length of null buffer for PrimitiveArray, expected 4 got 3" + ); + + TimestampNanosecondArray::new(vec![1, 2, 3, 4].into(), None).with_data_type( + DataType::Timestamp(TimeUnit::Nanosecond, Some("03:00".into())), + ); + } + + #[test] + #[should_panic(expected = "PrimitiveArray expected data type Int32 got Date32")] + fn test_with_data_type() { + Int32Array::new(vec![1, 2, 3, 4].into(), None).with_data_type(DataType::Date32); + } +} diff --git a/arrow-array/src/array/run_array.rs b/arrow-array/src/array/run_array.rs new file mode 100644 index 000000000000..4877f9f850a3 --- /dev/null +++ b/arrow-array/src/array/run_array.rs @@ -0,0 +1,1095 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer, RunEndBuffer}; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::{ArrowError, DataType, Field}; + +use crate::{ + builder::StringRunBuilder, + make_array, + run_iterator::RunArrayIter, + types::{Int16Type, Int32Type, Int64Type, RunEndIndexType}, + Array, ArrayAccessor, ArrayRef, PrimitiveArray, +}; + +/// An array of [run-end encoded values](https://arrow.apache.org/docs/format/Columnar.html#run-end-encoded-layout) +/// +/// This encoding is variation on [run-length encoding (RLE)](https://en.wikipedia.org/wiki/Run-length_encoding) +/// and is good for representing data containing same values repeated consecutively. +/// +/// [`RunArray`] contains `run_ends` array and `values` array of same length. +/// The `run_ends` array stores the indexes at which the run ends. The `values` array +/// stores the value of each run. Below example illustrates how a logical array is represented in +/// [`RunArray`] +/// +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┐ +/// ┌─────────────────┐ ┌─────────┐ ┌─────────────────┐ +/// │ │ A │ │ 2 │ │ │ A │ +/// ├─────────────────┤ ├─────────┤ ├─────────────────┤ +/// │ │ D │ │ 3 │ │ │ A │ run length of 'A' = runs_ends[0] - 0 = 2 +/// ├─────────────────┤ ├─────────┤ ├─────────────────┤ +/// │ │ B │ │ 6 │ │ │ D │ run length of 'D' = run_ends[1] - run_ends[0] = 1 +/// └─────────────────┘ └─────────┘ ├─────────────────┤ +/// │ values run_ends │ │ B │ +/// ├─────────────────┤ +/// └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘ │ B │ +/// ├─────────────────┤ +/// RunArray │ B │ run length of 'B' = run_ends[2] - run_ends[1] = 3 +/// length = 3 └─────────────────┘ +/// +/// Logical array +/// Contents +/// ``` + +pub struct RunArray { + data_type: DataType, + run_ends: RunEndBuffer, + values: ArrayRef, +} + +impl Clone for RunArray { + fn clone(&self) -> Self { + Self { + data_type: self.data_type.clone(), + run_ends: self.run_ends.clone(), + values: self.values.clone(), + } + } +} + +impl RunArray { + /// Calculates the logical length of the array encoded + /// by the given run_ends array. + pub fn logical_len(run_ends: &PrimitiveArray) -> usize { + let len = run_ends.len(); + if len == 0 { + return 0; + } + run_ends.value(len - 1).as_usize() + } + + /// Attempts to create RunArray using given run_ends (index where a run ends) + /// and the values (value of the run). Returns an error if the given data is not compatible + /// with RunEndEncoded specification. + pub fn try_new(run_ends: &PrimitiveArray, values: &dyn Array) -> Result { + let run_ends_type = run_ends.data_type().clone(); + let values_type = values.data_type().clone(); + let ree_array_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", run_ends_type, false)), + Arc::new(Field::new("values", values_type, true)), + ); + let len = RunArray::logical_len(run_ends); + let builder = ArrayDataBuilder::new(ree_array_type) + .len(len) + .add_child_data(run_ends.to_data()) + .add_child_data(values.to_data()); + + // `build_unchecked` is used to avoid recursive validation of child arrays. + let array_data = unsafe { builder.build_unchecked() }; + + // Safety: `validate_data` checks below + // 1. The given array data has exactly two child arrays. + // 2. The first child array (run_ends) has valid data type. + // 3. run_ends array does not have null values + // 4. run_ends array has non-zero and strictly increasing values. + // 5. The length of run_ends array and values array are the same. + array_data.validate_data()?; + + Ok(array_data.into()) + } + + /// Returns a reference to [`RunEndBuffer`] + pub fn run_ends(&self) -> &RunEndBuffer { + &self.run_ends + } + + /// Returns a reference to values array + /// + /// Note: any slicing of this [`RunArray`] array is not applied to the returned array + /// and must be handled separately + pub fn values(&self) -> &ArrayRef { + &self.values + } + + /// Returns the physical index at which the array slice starts. + pub fn get_start_physical_index(&self) -> usize { + self.run_ends.get_start_physical_index() + } + + /// Returns the physical index at which the array slice ends. + pub fn get_end_physical_index(&self) -> usize { + self.run_ends.get_end_physical_index() + } + + /// Downcast this [`RunArray`] to a [`TypedRunArray`] + /// + /// ``` + /// use arrow_array::{Array, ArrayAccessor, RunArray, StringArray, types::Int32Type}; + /// + /// let orig = [Some("a"), Some("b"), None]; + /// let run_array = RunArray::::from_iter(orig); + /// let typed = run_array.downcast::().unwrap(); + /// assert_eq!(typed.value(0), "a"); + /// assert_eq!(typed.value(1), "b"); + /// assert!(typed.values().is_null(2)); + /// ``` + /// + pub fn downcast(&self) -> Option> { + let values = self.values.as_any().downcast_ref()?; + Some(TypedRunArray { + run_array: self, + values, + }) + } + + /// Returns index to the physical array for the given index to the logical array. + /// This function adjusts the input logical index based on `ArrayData::offset` + /// Performs a binary search on the run_ends array for the input index. + /// + /// The result is arbitrary if `logical_index >= self.len()` + pub fn get_physical_index(&self, logical_index: usize) -> usize { + self.run_ends.get_physical_index(logical_index) + } + + /// Returns the physical indices of the input logical indices. Returns error if any of the logical + /// index cannot be converted to physical index. The logical indices are sorted and iterated along + /// with run_ends array to find matching physical index. The approach used here was chosen over + /// finding physical index for each logical index using binary search using the function + /// `get_physical_index`. Running benchmarks on both approaches showed that the approach used here + /// scaled well for larger inputs. + /// See for more details. + #[inline] + pub fn get_physical_indices(&self, logical_indices: &[I]) -> Result, ArrowError> + where + I: ArrowNativeType, + { + let len = self.run_ends().len(); + let offset = self.run_ends().offset(); + + let indices_len = logical_indices.len(); + + if indices_len == 0 { + return Ok(vec![]); + } + + // `ordered_indices` store index into `logical_indices` and can be used + // to iterate `logical_indices` in sorted order. + let mut ordered_indices: Vec = (0..indices_len).collect(); + + // Instead of sorting `logical_indices` directly, sort the `ordered_indices` + // whose values are index of `logical_indices` + ordered_indices.sort_unstable_by(|lhs, rhs| { + logical_indices[*lhs] + .partial_cmp(&logical_indices[*rhs]) + .unwrap() + }); + + // Return early if all the logical indices cannot be converted to physical indices. + let largest_logical_index = logical_indices[*ordered_indices.last().unwrap()].as_usize(); + if largest_logical_index >= len { + return Err(ArrowError::InvalidArgumentError(format!( + "Cannot convert all logical indices to physical indices. The logical index cannot be converted is {largest_logical_index}.", + ))); + } + + // Skip some physical indices based on offset. + let skip_value = self.get_start_physical_index(); + + let mut physical_indices = vec![0; indices_len]; + + let mut ordered_index = 0_usize; + for (physical_index, run_end) in self.run_ends.values().iter().enumerate().skip(skip_value) + { + // Get the run end index (relative to offset) of current physical index + let run_end_value = run_end.as_usize() - offset; + + // All the `logical_indices` that are less than current run end index + // belongs to current physical index. + while ordered_index < indices_len + && logical_indices[ordered_indices[ordered_index]].as_usize() < run_end_value + { + physical_indices[ordered_indices[ordered_index]] = physical_index; + ordered_index += 1; + } + } + + // If there are input values >= run_ends.last_value then we'll not be able to convert + // all logical indices to physical indices. + if ordered_index < logical_indices.len() { + let logical_index = logical_indices[ordered_indices[ordered_index]].as_usize(); + return Err(ArrowError::InvalidArgumentError(format!( + "Cannot convert all logical indices to physical indices. The logical index cannot be converted is {logical_index}.", + ))); + } + Ok(physical_indices) + } + + /// Returns a zero-copy slice of this array with the indicated offset and length. + pub fn slice(&self, offset: usize, length: usize) -> Self { + Self { + data_type: self.data_type.clone(), + run_ends: self.run_ends.slice(offset, length), + values: self.values.clone(), + } + } +} + +impl From for RunArray { + // The method assumes the caller already validated the data using `ArrayData::validate_data()` + fn from(data: ArrayData) -> Self { + match data.data_type() { + DataType::RunEndEncoded(_, _) => {} + _ => { + panic!("Invalid data type for RunArray. The data type should be DataType::RunEndEncoded"); + } + } + + // Safety + // ArrayData is valid + let child = &data.child_data()[0]; + assert_eq!(child.data_type(), &R::DATA_TYPE, "Incorrect run ends type"); + let run_ends = unsafe { + let scalar = child.buffers()[0].clone().into(); + RunEndBuffer::new_unchecked(scalar, data.offset(), data.len()) + }; + + let values = make_array(data.child_data()[1].clone()); + Self { + data_type: data.data_type().clone(), + run_ends, + values, + } + } +} + +impl From> for ArrayData { + fn from(array: RunArray) -> Self { + let len = array.run_ends.len(); + let offset = array.run_ends.offset(); + + let run_ends = ArrayDataBuilder::new(R::DATA_TYPE) + .len(array.run_ends.values().len()) + .buffers(vec![array.run_ends.into_inner().into_inner()]); + + let run_ends = unsafe { run_ends.build_unchecked() }; + + let builder = ArrayDataBuilder::new(array.data_type) + .len(len) + .offset(offset) + .child_data(vec![run_ends, array.values.to_data()]); + + unsafe { builder.build_unchecked() } + } +} + +impl Array for RunArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + self.clone().into() + } + + fn into_data(self) -> ArrayData { + self.into() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.run_ends.len() + } + + fn is_empty(&self) -> bool { + self.run_ends.is_empty() + } + + fn offset(&self) -> usize { + self.run_ends.offset() + } + + fn nulls(&self) -> Option<&NullBuffer> { + None + } + + fn logical_nulls(&self) -> Option { + let len = self.len(); + let nulls = self.values.logical_nulls()?; + let mut out = BooleanBufferBuilder::new(len); + let offset = self.run_ends.offset(); + let mut valid_start = 0; + let mut last_end = 0; + for (idx, end) in self.run_ends.values().iter().enumerate() { + let end = end.as_usize(); + if end < offset { + continue; + } + let end = (end - offset).min(len); + if nulls.is_null(idx) { + if valid_start < last_end { + out.append_n(last_end - valid_start, true); + } + out.append_n(end - last_end, false); + valid_start = end; + } + last_end = end; + if end == len { + break; + } + } + if valid_start < len { + out.append_n(len - valid_start, true) + } + // Sanity check + assert_eq!(out.len(), len); + Some(out.finish().into()) + } + + fn is_nullable(&self) -> bool { + !self.is_empty() && self.values.is_nullable() + } + + fn get_buffer_memory_size(&self) -> usize { + self.run_ends.inner().inner().capacity() + self.values.get_buffer_memory_size() + } + + fn get_array_memory_size(&self) -> usize { + std::mem::size_of::() + + self.run_ends.inner().inner().capacity() + + self.values.get_array_memory_size() + } +} + +impl std::fmt::Debug for RunArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + writeln!( + f, + "RunArray {{run_ends: {:?}, values: {:?}}}", + self.run_ends.values(), + self.values + ) + } +} + +/// Constructs a `RunArray` from an iterator of optional strings. +/// +/// # Example: +/// ``` +/// use arrow_array::{RunArray, PrimitiveArray, StringArray, types::Int16Type}; +/// +/// let test = vec!["a", "a", "b", "c", "c"]; +/// let array: RunArray = test +/// .iter() +/// .map(|&x| if x == "b" { None } else { Some(x) }) +/// .collect(); +/// assert_eq!( +/// "RunArray {run_ends: [2, 3, 5], values: StringArray\n[\n \"a\",\n null,\n \"c\",\n]}\n", +/// format!("{:?}", array) +/// ); +/// ``` +impl<'a, T: RunEndIndexType> FromIterator> for RunArray { + fn from_iter>>(iter: I) -> Self { + let it = iter.into_iter(); + let (lower, _) = it.size_hint(); + let mut builder = StringRunBuilder::with_capacity(lower, 256); + it.for_each(|i| { + builder.append_option(i); + }); + + builder.finish() + } +} + +/// Constructs a `RunArray` from an iterator of strings. +/// +/// # Example: +/// +/// ``` +/// use arrow_array::{RunArray, PrimitiveArray, StringArray, types::Int16Type}; +/// +/// let test = vec!["a", "a", "b", "c"]; +/// let array: RunArray = test.into_iter().collect(); +/// assert_eq!( +/// "RunArray {run_ends: [2, 3, 4], values: StringArray\n[\n \"a\",\n \"b\",\n \"c\",\n]}\n", +/// format!("{:?}", array) +/// ); +/// ``` +impl<'a, T: RunEndIndexType> FromIterator<&'a str> for RunArray { + fn from_iter>(iter: I) -> Self { + let it = iter.into_iter(); + let (lower, _) = it.size_hint(); + let mut builder = StringRunBuilder::with_capacity(lower, 256); + it.for_each(|i| { + builder.append_value(i); + }); + + builder.finish() + } +} + +/// +/// A [`RunArray`] with `i16` run ends +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, Int16RunArray, Int16Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: Int16RunArray = vec!["a", "a", "b", "c", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.run_ends().values(), &[2, 3, 5]); +/// assert_eq!(array.values(), &values); +/// ``` +pub type Int16RunArray = RunArray; + +/// +/// A [`RunArray`] with `i32` run ends +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, Int32RunArray, Int32Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: Int32RunArray = vec!["a", "a", "b", "c", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.run_ends().values(), &[2, 3, 5]); +/// assert_eq!(array.values(), &values); +/// ``` +pub type Int32RunArray = RunArray; + +/// +/// A [`RunArray`] with `i64` run ends +/// +/// # Example: Using `collect` +/// ``` +/// # use arrow_array::{Array, Int64RunArray, Int64Array, StringArray}; +/// # use std::sync::Arc; +/// +/// let array: Int64RunArray = vec!["a", "a", "b", "c", "c"].into_iter().collect(); +/// let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); +/// assert_eq!(array.run_ends().values(), &[2, 3, 5]); +/// assert_eq!(array.values(), &values); +/// ``` +pub type Int64RunArray = RunArray; + +/// A [`RunArray`] typed typed on its child values array +/// +/// Implements [`ArrayAccessor`] and [`IntoIterator`] allowing fast access to its elements +/// +/// ``` +/// use arrow_array::{RunArray, StringArray, types::Int32Type}; +/// +/// let orig = ["a", "b", "a", "b"]; +/// let ree_array = RunArray::::from_iter(orig); +/// +/// // `TypedRunArray` allows you to access the values directly +/// let typed = ree_array.downcast::().unwrap(); +/// +/// for (maybe_val, orig) in typed.into_iter().zip(orig) { +/// assert_eq!(maybe_val.unwrap(), orig) +/// } +/// ``` +pub struct TypedRunArray<'a, R: RunEndIndexType, V> { + /// The run array + run_array: &'a RunArray, + + /// The values of the run_array + values: &'a V, +} + +// Manually implement `Clone` to avoid `V: Clone` type constraint +impl<'a, R: RunEndIndexType, V> Clone for TypedRunArray<'a, R, V> { + fn clone(&self) -> Self { + *self + } +} + +impl<'a, R: RunEndIndexType, V> Copy for TypedRunArray<'a, R, V> {} + +impl<'a, R: RunEndIndexType, V> std::fmt::Debug for TypedRunArray<'a, R, V> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + writeln!(f, "TypedRunArray({:?})", self.run_array) + } +} + +impl<'a, R: RunEndIndexType, V> TypedRunArray<'a, R, V> { + /// Returns the run_ends of this [`TypedRunArray`] + pub fn run_ends(&self) -> &'a RunEndBuffer { + self.run_array.run_ends() + } + + /// Returns the values of this [`TypedRunArray`] + pub fn values(&self) -> &'a V { + self.values + } + + /// Returns the run array of this [`TypedRunArray`] + pub fn run_array(&self) -> &'a RunArray { + self.run_array + } +} + +impl<'a, R: RunEndIndexType, V: Sync> Array for TypedRunArray<'a, R, V> { + fn as_any(&self) -> &dyn Any { + self.run_array + } + + fn to_data(&self) -> ArrayData { + self.run_array.to_data() + } + + fn into_data(self) -> ArrayData { + self.run_array.into_data() + } + + fn data_type(&self) -> &DataType { + self.run_array.data_type() + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.run_array.slice(offset, length)) + } + + fn len(&self) -> usize { + self.run_array.len() + } + + fn is_empty(&self) -> bool { + self.run_array.is_empty() + } + + fn offset(&self) -> usize { + self.run_array.offset() + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.run_array.nulls() + } + + fn logical_nulls(&self) -> Option { + self.run_array.logical_nulls() + } + + fn is_nullable(&self) -> bool { + self.run_array.is_nullable() + } + + fn get_buffer_memory_size(&self) -> usize { + self.run_array.get_buffer_memory_size() + } + + fn get_array_memory_size(&self) -> usize { + self.run_array.get_array_memory_size() + } +} + +// Array accessor converts the index of logical array to the index of the physical array +// using binary search. The time complexity is O(log N) where N is number of runs. +impl<'a, R, V> ArrayAccessor for TypedRunArray<'a, R, V> +where + R: RunEndIndexType, + V: Sync + Send, + &'a V: ArrayAccessor, + <&'a V as ArrayAccessor>::Item: Default, +{ + type Item = <&'a V as ArrayAccessor>::Item; + + fn value(&self, logical_index: usize) -> Self::Item { + assert!( + logical_index < self.len(), + "Trying to access an element at index {} from a TypedRunArray of length {}", + logical_index, + self.len() + ); + unsafe { self.value_unchecked(logical_index) } + } + + unsafe fn value_unchecked(&self, logical_index: usize) -> Self::Item { + let physical_index = self.run_array.get_physical_index(logical_index); + self.values().value_unchecked(physical_index) + } +} + +impl<'a, R, V> IntoIterator for TypedRunArray<'a, R, V> +where + R: RunEndIndexType, + V: Sync + Send, + &'a V: ArrayAccessor, + <&'a V as ArrayAccessor>::Item: Default, +{ + type Item = Option<<&'a V as ArrayAccessor>::Item>; + type IntoIter = RunArrayIter<'a, R, V>; + + fn into_iter(self) -> Self::IntoIter { + RunArrayIter::new(self) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use rand::seq::SliceRandom; + use rand::thread_rng; + use rand::Rng; + + use super::*; + use crate::builder::PrimitiveRunBuilder; + use crate::cast::AsArray; + use crate::types::{Int16Type, Int32Type, Int8Type, UInt32Type}; + use crate::{Array, Int32Array, StringArray}; + + fn build_input_array(size: usize) -> Vec> { + // The input array is created by shuffling and repeating + // the seed values random number of times. + let mut seed: Vec> = vec![ + None, + None, + None, + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + ]; + let mut result: Vec> = Vec::with_capacity(size); + let mut ix = 0; + let mut rng = thread_rng(); + // run length can go up to 8. Cap the max run length for smaller arrays to size / 2. + let max_run_length = 8_usize.min(1_usize.max(size / 2)); + while result.len() < size { + // shuffle the seed array if all the values are iterated. + if ix == 0 { + seed.shuffle(&mut rng); + } + // repeat the items between 1 and 8 times. Cap the length for smaller sized arrays + let num = max_run_length.min(rand::thread_rng().gen_range(1..=max_run_length)); + for _ in 0..num { + result.push(seed[ix]); + } + ix += 1; + if ix == seed.len() { + ix = 0 + } + } + result.resize(size, None); + result + } + + // Asserts that `logical_array[logical_indices[*]] == physical_array[physical_indices[*]]` + fn compare_logical_and_physical_indices( + logical_indices: &[u32], + logical_array: &[Option], + physical_indices: &[usize], + physical_array: &PrimitiveArray, + ) { + assert_eq!(logical_indices.len(), physical_indices.len()); + + // check value in logical index in the logical_array matches physical index in physical_array + logical_indices + .iter() + .map(|f| f.as_usize()) + .zip(physical_indices.iter()) + .for_each(|(logical_ix, physical_ix)| { + let expected = logical_array[logical_ix]; + match expected { + Some(val) => { + assert!(physical_array.is_valid(*physical_ix)); + let actual = physical_array.value(*physical_ix); + assert_eq!(val, actual); + } + None => { + assert!(physical_array.is_null(*physical_ix)) + } + }; + }); + } + #[test] + fn test_run_array() { + // Construct a value array + let value_data = + PrimitiveArray::::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + + // Construct a run_ends array: + let run_ends_values = [4_i16, 6, 7, 9, 13, 18, 20, 22]; + let run_ends_data = + PrimitiveArray::::from_iter_values(run_ends_values.iter().copied()); + + // Construct a run ends encoded array from the above two + let ree_array = RunArray::::try_new(&run_ends_data, &value_data).unwrap(); + + assert_eq!(ree_array.len(), 22); + assert_eq!(ree_array.null_count(), 0); + + let values = ree_array.values(); + assert_eq!(value_data.into_data(), values.to_data()); + assert_eq!(&DataType::Int8, values.data_type()); + + let run_ends = ree_array.run_ends(); + assert_eq!(run_ends.values(), &run_ends_values); + } + + #[test] + fn test_run_array_fmt_debug() { + let mut builder = PrimitiveRunBuilder::::with_capacity(3); + builder.append_value(12345678); + builder.append_null(); + builder.append_value(22345678); + let array = builder.finish(); + assert_eq!( + "RunArray {run_ends: [1, 2, 3], values: PrimitiveArray\n[\n 12345678,\n null,\n 22345678,\n]}\n", + format!("{array:?}") + ); + + let mut builder = PrimitiveRunBuilder::::with_capacity(20); + for _ in 0..20 { + builder.append_value(1); + } + let array = builder.finish(); + + assert_eq!(array.len(), 20); + assert_eq!(array.null_count(), 0); + + assert_eq!( + "RunArray {run_ends: [20], values: PrimitiveArray\n[\n 1,\n]}\n", + format!("{array:?}") + ); + } + + #[test] + fn test_run_array_from_iter() { + let test = vec!["a", "a", "b", "c"]; + let array: RunArray = test + .iter() + .map(|&x| if x == "b" { None } else { Some(x) }) + .collect(); + assert_eq!( + "RunArray {run_ends: [2, 3, 4], values: StringArray\n[\n \"a\",\n null,\n \"c\",\n]}\n", + format!("{array:?}") + ); + + assert_eq!(array.len(), 4); + assert_eq!(array.null_count(), 0); + + let array: RunArray = test.into_iter().collect(); + assert_eq!( + "RunArray {run_ends: [2, 3, 4], values: StringArray\n[\n \"a\",\n \"b\",\n \"c\",\n]}\n", + format!("{array:?}") + ); + } + + #[test] + fn test_run_array_run_ends_as_primitive_array() { + let test = vec!["a", "b", "c", "a"]; + let array: RunArray = test.into_iter().collect(); + + assert_eq!(array.len(), 4); + assert_eq!(array.null_count(), 0); + + let run_ends = array.run_ends(); + assert_eq!(&[1, 2, 3, 4], run_ends.values()); + } + + #[test] + fn test_run_array_as_primitive_array_with_null() { + let test = vec![Some("a"), None, Some("b"), None, None, Some("a")]; + let array: RunArray = test.into_iter().collect(); + + assert_eq!(array.len(), 6); + assert_eq!(array.null_count(), 0); + + let run_ends = array.run_ends(); + assert_eq!(&[1, 2, 3, 5, 6], run_ends.values()); + + let values_data = array.values(); + assert_eq!(2, values_data.null_count()); + assert_eq!(5, values_data.len()); + } + + #[test] + fn test_run_array_all_nulls() { + let test = vec![None, None, None]; + let array: RunArray = test.into_iter().collect(); + + assert_eq!(array.len(), 3); + assert_eq!(array.null_count(), 0); + + let run_ends = array.run_ends(); + assert_eq!(3, run_ends.len()); + assert_eq!(&[3], run_ends.values()); + + let values_data = array.values(); + assert_eq!(1, values_data.null_count()); + } + + #[test] + fn test_run_array_try_new() { + let values: StringArray = [Some("foo"), Some("bar"), None, Some("baz")] + .into_iter() + .collect(); + let run_ends: Int32Array = [Some(1), Some(2), Some(3), Some(4)].into_iter().collect(); + + let array = RunArray::::try_new(&run_ends, &values).unwrap(); + assert_eq!(array.values().data_type(), &DataType::Utf8); + + assert_eq!(array.null_count(), 0); + assert_eq!(array.len(), 4); + assert_eq!(array.values().null_count(), 1); + + assert_eq!( + "RunArray {run_ends: [1, 2, 3, 4], values: StringArray\n[\n \"foo\",\n \"bar\",\n null,\n \"baz\",\n]}\n", + format!("{array:?}") + ); + } + + #[test] + fn test_run_array_int16_type_definition() { + let array: Int16RunArray = vec!["a", "a", "b", "c", "c"].into_iter().collect(); + let values: Arc = Arc::new(StringArray::from(vec!["a", "b", "c"])); + assert_eq!(array.run_ends().values(), &[2, 3, 5]); + assert_eq!(array.values(), &values); + } + + #[test] + fn test_run_array_empty_string() { + let array: Int16RunArray = vec!["a", "a", "", "", "c"].into_iter().collect(); + let values: Arc = Arc::new(StringArray::from(vec!["a", "", "c"])); + assert_eq!(array.run_ends().values(), &[2, 4, 5]); + assert_eq!(array.values(), &values); + } + + #[test] + fn test_run_array_length_mismatch() { + let values: StringArray = [Some("foo"), Some("bar"), None, Some("baz")] + .into_iter() + .collect(); + let run_ends: Int32Array = [Some(1), Some(2), Some(3)].into_iter().collect(); + + let actual = RunArray::::try_new(&run_ends, &values); + let expected = ArrowError::InvalidArgumentError("The run_ends array length should be the same as values array length. Run_ends array length is 3, values array length is 4".to_string()); + assert_eq!(expected.to_string(), actual.err().unwrap().to_string()); + } + + #[test] + fn test_run_array_run_ends_with_null() { + let values: StringArray = [Some("foo"), Some("bar"), Some("baz")] + .into_iter() + .collect(); + let run_ends: Int32Array = [Some(1), None, Some(3)].into_iter().collect(); + + let actual = RunArray::::try_new(&run_ends, &values); + let expected = ArrowError::InvalidArgumentError( + "Found null values in run_ends array. The run_ends array should not have null values." + .to_string(), + ); + assert_eq!(expected.to_string(), actual.err().unwrap().to_string()); + } + + #[test] + fn test_run_array_run_ends_with_zeroes() { + let values: StringArray = [Some("foo"), Some("bar"), Some("baz")] + .into_iter() + .collect(); + let run_ends: Int32Array = [Some(0), Some(1), Some(3)].into_iter().collect(); + + let actual = RunArray::::try_new(&run_ends, &values); + let expected = ArrowError::InvalidArgumentError("The values in run_ends array should be strictly positive. Found value 0 at index 0 that does not match the criteria.".to_string()); + assert_eq!(expected.to_string(), actual.err().unwrap().to_string()); + } + + #[test] + fn test_run_array_run_ends_non_increasing() { + let values: StringArray = [Some("foo"), Some("bar"), Some("baz")] + .into_iter() + .collect(); + let run_ends: Int32Array = [Some(1), Some(4), Some(4)].into_iter().collect(); + + let actual = RunArray::::try_new(&run_ends, &values); + let expected = ArrowError::InvalidArgumentError("The values in run_ends array should be strictly increasing. Found value 4 at index 2 with previous value 4 that does not match the criteria.".to_string()); + assert_eq!(expected.to_string(), actual.err().unwrap().to_string()); + } + + #[test] + #[should_panic(expected = "Incorrect run ends type")] + fn test_run_array_run_ends_data_type_mismatch() { + let a = RunArray::::from_iter(["32"]); + let _ = RunArray::::from(a.into_data()); + } + + #[test] + fn test_ree_array_accessor() { + let input_array = build_input_array(256); + + // Encode the input_array to ree_array + let mut builder = + PrimitiveRunBuilder::::with_capacity(input_array.len()); + builder.extend(input_array.iter().copied()); + let run_array = builder.finish(); + let typed = run_array.downcast::>().unwrap(); + + // Access every index and check if the value in the input array matches returned value. + for (i, inp_val) in input_array.iter().enumerate() { + if let Some(val) = inp_val { + let actual = typed.value(i); + assert_eq!(*val, actual) + } else { + let physical_ix = run_array.get_physical_index(i); + assert!(typed.values().is_null(physical_ix)); + }; + } + } + + #[test] + #[cfg_attr(miri, ignore)] // Takes too long + fn test_get_physical_indices() { + // Test for logical lengths starting from 10 to 250 increasing by 10 + for logical_len in (0..250).step_by(10) { + let input_array = build_input_array(logical_len); + + // create run array using input_array + let mut builder = PrimitiveRunBuilder::::new(); + builder.extend(input_array.clone().into_iter()); + + let run_array = builder.finish(); + let physical_values_array = run_array.values().as_primitive::(); + + // create an array consisting of all the indices repeated twice and shuffled. + let mut logical_indices: Vec = (0_u32..(logical_len as u32)).collect(); + // add same indices once more + logical_indices.append(&mut logical_indices.clone()); + let mut rng = thread_rng(); + logical_indices.shuffle(&mut rng); + + let physical_indices = run_array.get_physical_indices(&logical_indices).unwrap(); + + assert_eq!(logical_indices.len(), physical_indices.len()); + + // check value in logical index in the input_array matches physical index in typed_run_array + compare_logical_and_physical_indices( + &logical_indices, + &input_array, + &physical_indices, + physical_values_array, + ); + } + } + + #[test] + #[cfg_attr(miri, ignore)] // Takes too long + fn test_get_physical_indices_sliced() { + let total_len = 80; + let input_array = build_input_array(total_len); + + // Encode the input_array to run array + let mut builder = + PrimitiveRunBuilder::::with_capacity(input_array.len()); + builder.extend(input_array.iter().copied()); + let run_array = builder.finish(); + let physical_values_array = run_array.values().as_primitive::(); + + // test for all slice lengths. + for slice_len in 1..=total_len { + // create an array consisting of all the indices repeated twice and shuffled. + let mut logical_indices: Vec = (0_u32..(slice_len as u32)).collect(); + // add same indices once more + logical_indices.append(&mut logical_indices.clone()); + let mut rng = thread_rng(); + logical_indices.shuffle(&mut rng); + + // test for offset = 0 and slice length = slice_len + // slice the input array using which the run array was built. + let sliced_input_array = &input_array[0..slice_len]; + + // slice the run array + let sliced_run_array: RunArray = + run_array.slice(0, slice_len).into_data().into(); + + // Get physical indices. + let physical_indices = sliced_run_array + .get_physical_indices(&logical_indices) + .unwrap(); + + compare_logical_and_physical_indices( + &logical_indices, + sliced_input_array, + &physical_indices, + physical_values_array, + ); + + // test for offset = total_len - slice_len and slice length = slice_len + // slice the input array using which the run array was built. + let sliced_input_array = &input_array[total_len - slice_len..total_len]; + + // slice the run array + let sliced_run_array: RunArray = run_array + .slice(total_len - slice_len, slice_len) + .into_data() + .into(); + + // Get physical indices + let physical_indices = sliced_run_array + .get_physical_indices(&logical_indices) + .unwrap(); + + compare_logical_and_physical_indices( + &logical_indices, + sliced_input_array, + &physical_indices, + physical_values_array, + ); + } + } + + #[test] + fn test_logical_nulls() { + let run = Int32Array::from(vec![3, 6, 9, 12]); + let values = Int32Array::from(vec![Some(0), None, Some(1), None]); + let array = RunArray::try_new(&run, &values).unwrap(); + + let expected = [ + true, true, true, false, false, false, true, true, true, false, false, false, + ]; + + let n = array.logical_nulls().unwrap(); + assert_eq!(n.null_count(), 6); + + let slices = [(0, 12), (0, 2), (2, 5), (3, 0), (3, 3), (3, 4), (4, 8)]; + for (offset, length) in slices { + let a = array.slice(offset, length); + let n = a.logical_nulls().unwrap(); + let n = n.into_iter().collect::>(); + assert_eq!(&n, &expected[offset..offset + length], "{offset} {length}"); + } + } +} diff --git a/arrow/src/array/array_string.rs b/arrow-array/src/array/string_array.rs similarity index 55% rename from arrow/src/array/array_string.rs rename to arrow-array/src/array/string_array.rs index 62743a20a119..9d266e0ca4b8 100644 --- a/arrow/src/array/array_string.rs +++ b/arrow-array/src/array/string_array.rs @@ -15,69 +15,20 @@ // specific language governing permissions and limitations // under the License. -use std::convert::From; -use std::fmt; -use std::{any::Any, iter::FromIterator}; - -use super::{ - array::print_long_array, raw_pointer::RawPtrBox, Array, ArrayData, - GenericBinaryArray, GenericListArray, GenericStringIter, OffsetSizeTrait, -}; -use crate::array::array::ArrayAccessor; -use crate::buffer::Buffer; -use crate::util::bit_util; -use crate::{buffer::MutableBuffer, datatypes::DataType}; - -/// Generic struct for \[Large\]StringArray -/// -/// See [`StringArray`] and [`LargeStringArray`] for storing -/// specific string data. -pub struct GenericStringArray { - data: ArrayData, - value_offsets: RawPtrBox, - value_data: RawPtrBox, -} +use crate::types::GenericStringType; +use crate::{GenericBinaryArray, GenericByteArray, GenericListArray, OffsetSizeTrait}; +use arrow_schema::{ArrowError, DataType}; -impl GenericStringArray { - /// Data type of the array. - pub const DATA_TYPE: DataType = if OffsetSize::IS_LARGE { - DataType::LargeUtf8 - } else { - DataType::Utf8 - }; +/// A [`GenericByteArray`] for storing `str` +pub type GenericStringArray = GenericByteArray>; +impl GenericStringArray { /// Get the data type of the array. #[deprecated(note = "please use `Self::DATA_TYPE` instead")] pub const fn get_data_type() -> DataType { Self::DATA_TYPE } - /// Returns the length for the element at index `i`. - #[inline] - pub fn value_length(&self, i: usize) -> OffsetSize { - let offsets = self.value_offsets(); - offsets[i + 1] - offsets[i] - } - - /// Returns the offset values in the offsets buffer - #[inline] - pub fn value_offsets(&self) -> &[OffsetSize] { - // Soundness - // pointer alignment & location is ensured by RawPtrBox - // buffer bounds/offset is ensured by the ArrayData instance. - unsafe { - std::slice::from_raw_parts( - self.value_offsets.as_ptr().add(self.data.offset()), - self.len() + 1, - ) - } - } - - /// Returns a clone of the value data buffer - pub fn value_data(&self) -> Buffer { - self.data.buffers()[1].clone() - } - /// Returns the number of `Unicode Scalar Value` in the string at index `i`. /// # Performance /// This function has `O(n)` time complexity where `n` is the string length. @@ -87,123 +38,6 @@ impl GenericStringArray { self.value(i).chars().count() } - /// Returns the element at index - /// # Safety - /// caller is responsible for ensuring that index is within the array bounds - #[inline] - pub unsafe fn value_unchecked(&self, i: usize) -> &str { - let end = self.value_offsets().get_unchecked(i + 1); - let start = self.value_offsets().get_unchecked(i); - - // Soundness - // pointer alignment & location is ensured by RawPtrBox - // buffer bounds/offset is ensured by the value_offset invariants - // ISSUE: utf-8 well formedness is not checked - - // Safety of `to_isize().unwrap()` - // `start` and `end` are &OffsetSize, which is a generic type that implements the - // OffsetSizeTrait. Currently, only i32 and i64 implement OffsetSizeTrait, - // both of which should cleanly cast to isize on an architecture that supports - // 32/64-bit offsets - let slice = std::slice::from_raw_parts( - self.value_data.as_ptr().offset(start.to_isize().unwrap()), - (*end - *start).to_usize().unwrap(), - ); - std::str::from_utf8_unchecked(slice) - } - - /// Returns the element at index `i` as &str - /// # Panics - /// Panics if index `i` is out of bounds. - #[inline] - pub fn value(&self, i: usize) -> &str { - assert!( - i < self.data.len(), - "Trying to access an element at index {} from a StringArray of length {}", - i, - self.len() - ); - // Safety: - // `i < self.data.len() - unsafe { self.value_unchecked(i) } - } - - /// Convert a list array to a string array. - /// This method is unsound because it does - /// not check the utf-8 validation for each element. - fn from_list(v: GenericListArray) -> Self { - assert_eq!( - v.data_ref().child_data().len(), - 1, - "StringArray can only be created from list array of u8 values \ - (i.e. List>)." - ); - let child_data = &v.data_ref().child_data()[0]; - - assert_eq!( - child_data.child_data().len(), - 0, - "StringArray can only be created from list array of u8 values \ - (i.e. List>)." - ); - assert_eq!( - child_data.data_type(), - &DataType::UInt8, - "StringArray can only be created from List arrays, mismatched data types." - ); - assert_eq!( - child_data.null_count(), - 0, - "The child array cannot contain null values." - ); - - let builder = ArrayData::builder(Self::DATA_TYPE) - .len(v.len()) - .offset(v.offset()) - .add_buffer(v.data().buffers()[0].clone()) - .add_buffer(child_data.buffers()[0].slice(child_data.offset())) - .null_bit_buffer(v.data().null_buffer().cloned()); - - let array_data = unsafe { builder.build_unchecked() }; - Self::from(array_data) - } - - /// Creates a [`GenericStringArray`] based on an iterator of values without nulls - pub fn from_iter_values(iter: I) -> Self - where - Ptr: AsRef, - I: IntoIterator, - { - let iter = iter.into_iter(); - let (_, data_len) = iter.size_hint(); - let data_len = data_len.expect("Iterator must be sized"); // panic if no upper bound. - - let mut offsets = - MutableBuffer::new((data_len + 1) * std::mem::size_of::()); - let mut values = MutableBuffer::new(0); - - let mut length_so_far = OffsetSize::zero(); - offsets.push(length_so_far); - - for i in iter { - let s = i.as_ref(); - length_so_far += OffsetSize::from_usize(s.len()).unwrap(); - offsets.push(length_so_far); - values.extend_from_slice(s.as_bytes()); - } - - // iterator size hint may not be correct so compute the actual number of offsets - assert!(!offsets.is_empty()); // wrote at least one - let actual_len = (offsets.len() / std::mem::size_of::()) - 1; - - let array_data = ArrayData::builder(Self::DATA_TYPE) - .len(actual_len) - .add_buffer(offsets.into()) - .add_buffer(values.into()); - let array_data = unsafe { array_data.build_unchecked() }; - Self::from(array_data) - } - /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` pub fn take_iter<'a>( &'a self, @@ -222,120 +56,12 @@ impl GenericStringArray { ) -> impl Iterator> + 'a { indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) } -} - -impl<'a, Ptr, OffsetSize: OffsetSizeTrait> FromIterator<&'a Option> - for GenericStringArray -where - Ptr: AsRef + 'a, -{ - /// Creates a [`GenericStringArray`] based on an iterator of `Option` references. - fn from_iter>>(iter: I) -> Self { - // Convert each owned Ptr into &str and wrap in an owned `Option` - let iter = iter.into_iter().map(|o| o.as_ref().map(|p| p.as_ref())); - // Build a `GenericStringArray` with the resulting iterator - iter.collect::>() - } -} - -impl FromIterator> - for GenericStringArray -where - Ptr: AsRef, -{ - /// Creates a [`GenericStringArray`] based on an iterator of [`Option`]s - fn from_iter>>(iter: I) -> Self { - let iter = iter.into_iter(); - let (_, data_len) = iter.size_hint(); - let data_len = data_len.expect("Iterator must be sized"); // panic if no upper bound. - - let offset_size = std::mem::size_of::(); - let mut offsets = MutableBuffer::new((data_len + 1) * offset_size); - let mut values = MutableBuffer::new(0); - let mut null_buf = MutableBuffer::new_null(data_len); - let null_slice = null_buf.as_slice_mut(); - let mut length_so_far = OffsetSize::zero(); - offsets.push(length_so_far); - - for (i, s) in iter.enumerate() { - let value_bytes = if let Some(ref s) = s { - // set null bit - bit_util::set_bit(null_slice, i); - let s_bytes = s.as_ref().as_bytes(); - length_so_far += OffsetSize::from_usize(s_bytes.len()).unwrap(); - s_bytes - } else { - b"" - }; - values.extend_from_slice(value_bytes); - offsets.push(length_so_far); - } - - // calculate actual data_len, which may be different from the iterator's upper bound - let data_len = (offsets.len() / offset_size) - 1; - let array_data = ArrayData::builder(Self::DATA_TYPE) - .len(data_len) - .add_buffer(offsets.into()) - .add_buffer(values.into()) - .null_bit_buffer(Some(null_buf.into())); - let array_data = unsafe { array_data.build_unchecked() }; - Self::from(array_data) - } -} - -impl<'a, T: OffsetSizeTrait> IntoIterator for &'a GenericStringArray { - type Item = Option<&'a str>; - type IntoIter = GenericStringIter<'a, T>; - - fn into_iter(self) -> Self::IntoIter { - GenericStringIter::<'a, T>::new(self) - } -} - -impl<'a, T: OffsetSizeTrait> GenericStringArray { - /// constructs a new iterator - pub fn iter(&'a self) -> GenericStringIter<'a, T> { - GenericStringIter::<'a, T>::new(self) - } -} - -impl fmt::Debug for GenericStringArray { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let prefix = OffsetSize::PREFIX; - - write!(f, "{}StringArray\n[\n", prefix)?; - print_long_array(self, f, |array, index, f| { - fmt::Debug::fmt(&array.value(index), f) - })?; - write!(f, "]") - } -} - -impl Array for GenericStringArray { - fn as_any(&self) -> &dyn Any { - self - } - - fn data(&self) -> &ArrayData { - &self.data - } - - fn into_data(self) -> ArrayData { - self.into() - } -} - -impl<'a, OffsetSize: OffsetSizeTrait> ArrayAccessor - for &'a GenericStringArray -{ - type Item = &'a str; - - fn value(&self, index: usize) -> Self::Item { - GenericStringArray::value(self, index) - } - unsafe fn value_unchecked(&self, index: usize) -> Self::Item { - GenericStringArray::value_unchecked(self, index) + /// Fallibly creates a [`GenericStringArray`] from a [`GenericBinaryArray`] returning + /// an error if [`GenericBinaryArray`] contains invalid UTF-8 data + pub fn try_from_binary(v: GenericBinaryArray) -> Result { + let (offsets, values, nulls) = v.into_parts(); + Self::try_new(offsets, values, nulls) } } @@ -343,7 +69,7 @@ impl From> for GenericStringArray { fn from(v: GenericListArray) -> Self { - GenericStringArray::::from_list(v) + GenericBinaryArray::::from(v).into() } } @@ -351,37 +77,11 @@ impl From> for GenericStringArray { fn from(v: GenericBinaryArray) -> Self { - let builder = v.into_data().into_builder().data_type(Self::DATA_TYPE); - let data = unsafe { builder.build_unchecked() }; - Self::from(data) + Self::try_from_binary(v).unwrap() } } -impl From for GenericStringArray { - fn from(data: ArrayData) -> Self { - assert_eq!( - data.data_type(), - &Self::DATA_TYPE, - "[Large]StringArray expects Datatype::[Large]Utf8" - ); - assert_eq!( - data.buffers().len(), - 2, - "StringArray data should contain 2 buffers only (offsets and values)" - ); - let offsets = data.buffers()[0].as_ptr(); - let values = data.buffers()[1].as_ptr(); - Self { - data, - value_offsets: unsafe { RawPtrBox::new(offsets) }, - value_data: unsafe { RawPtrBox::new(values) }, - } - } -} - -impl From>> - for GenericStringArray -{ +impl From>> for GenericStringArray { fn from(v: Vec>) -> Self { v.into_iter().collect() } @@ -393,51 +93,82 @@ impl From> for GenericStringArray From> for GenericStringArray { - fn from(v: Vec) -> Self { - Self::from_iter_values(v) +impl From>> for GenericStringArray { + fn from(v: Vec>) -> Self { + v.into_iter().collect() } } -impl From> for ArrayData { - fn from(array: GenericStringArray) -> Self { - array.data +impl From> for GenericStringArray { + fn from(v: Vec) -> Self { + Self::from_iter_values(v) } } -/// An array where each element is a variable-sized sequence of bytes representing a string -/// whose maximum length (in bytes) is represented by a i32. +/// A [`GenericStringArray`] of `str` using `i32` offsets /// -/// Example +/// # Examples /// +/// Construction +/// +/// ``` +/// # use arrow_array::StringArray; +/// // Create from Vec> +/// let arr = StringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]); +/// // Create from Vec<&str> +/// let arr = StringArray::from(vec!["foo", "bar", "baz"]); +/// // Create from iter/collect (requires Option<&str>) +/// let arr: StringArray = std::iter::repeat(Some("foo")).take(10).collect(); /// ``` -/// use arrow::array::StringArray; +/// +/// Construction and Access +/// +/// ``` +/// # use arrow_array::StringArray; /// let array = StringArray::from(vec![Some("foo"), None, Some("bar")]); /// assert_eq!(array.value(0), "foo"); /// ``` +/// +/// See [`GenericByteArray`] for more information and examples pub type StringArray = GenericStringArray; -/// An array where each element is a variable-sized sequence of bytes representing a string -/// whose maximum length (in bytes) is represented by a i64. +/// A [`GenericStringArray`] of `str` using `i64` offsets +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::LargeStringArray; +/// // Create from Vec> +/// let arr = LargeStringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]); +/// // Create from Vec<&str> +/// let arr = LargeStringArray::from(vec!["foo", "bar", "baz"]); +/// // Create from iter/collect (requires Option<&str>) +/// let arr: LargeStringArray = std::iter::repeat(Some("foo")).take(10).collect(); +/// ``` /// -/// Example +/// Construction and Access /// /// ``` -/// use arrow::array::LargeStringArray; +/// use arrow_array::LargeStringArray; /// let array = LargeStringArray::from(vec![Some("foo"), None, Some("bar")]); /// assert_eq!(array.value(2), "bar"); /// ``` +/// +/// See [`GenericByteArray`] for more information and examples pub type LargeStringArray = GenericStringArray; #[cfg(test)] mod tests { - - use crate::{ - array::{ListBuilder, StringBuilder}, - datatypes::Field, - }; - use super::*; + use crate::builder::{ListBuilder, PrimitiveBuilder, StringBuilder}; + use crate::types::UInt8Type; + use crate::Array; + use arrow_buffer::Buffer; + use arrow_data::ArrayData; + use arrow_schema::Field; + use std::sync::Arc; #[test] fn test_string_array_from_u8_slice() { @@ -465,7 +196,7 @@ mod tests { } #[test] - #[should_panic(expected = "[Large]StringArray expects Datatype::[Large]Utf8")] + #[should_panic(expected = "StringArray expects DataType::Utf8")] fn test_string_array_from_int() { let array = LargeStringArray::from(vec!["a", "b"]); drop(StringArray::from(array.into_data())); @@ -538,8 +269,8 @@ mod tests { let offsets: [i32; 4] = [0, 5, 5, 12]; let array_data = ArrayData::builder(DataType::Utf8) .len(3) - .add_buffer(Buffer::from_slice_ref(&offsets)) - .add_buffer(Buffer::from_slice_ref(&values)) + .add_buffer(Buffer::from_slice_ref(offsets)) + .add_buffer(Buffer::from_slice_ref(values)) .build() .unwrap(); let string_array = StringArray::from(array_data); @@ -551,7 +282,7 @@ mod tests { let arr: StringArray = vec!["hello", "arrow"].into(); assert_eq!( "StringArray\n[\n \"hello\",\n \"arrow\",\n]", - format!("{:?}", arr) + format!("{arr:?}") ); } @@ -560,7 +291,7 @@ mod tests { let arr: LargeStringArray = vec!["hello", "arrow"].into(); assert_eq!( "LargeStringArray\n[\n \"hello\",\n \"arrow\",\n]", - format!("{:?}", arr) + format!("{arr:?}") ); } @@ -587,11 +318,18 @@ mod tests { #[test] fn test_string_array_from_iter_values() { - let data = vec!["hello", "hello2"]; + let data = ["hello", "hello2"]; let array1 = StringArray::from_iter_values(data.iter()); assert_eq!(array1.value(0), "hello"); assert_eq!(array1.value(1), "hello2"); + + // Also works with String types. + let data2 = ["goodbye".to_string(), "goodbye2".to_string()]; + let array2 = StringArray::from_iter_values(data2.iter()); + + assert_eq!(array2.value(0), "goodbye"); + assert_eq!(array2.value(1), "goodbye2"); } #[test] @@ -601,7 +339,7 @@ mod tests { .scan(0usize, |pos, i| { if *pos < 10 { *pos += 1; - Some(Some(format!("value {}", i))) + Some(Some(format!("value {i}"))) } else { // actually returns up to 10 values None @@ -620,20 +358,20 @@ mod tests { #[test] fn test_string_array_all_null() { - let data = vec![None]; + let data: Vec> = vec![None]; let array = StringArray::from(data); array - .data() + .into_data() .validate_full() .expect("All null array has valid array data"); } #[test] fn test_large_string_array_all_null() { - let data = vec![None]; + let data: Vec> = vec![None]; let array = LargeStringArray::from(data); array - .data() + .into_data() .validate_full() .expect("All null array has valid array data"); } @@ -694,13 +432,11 @@ mod tests { let expected: LargeStringArray = data.clone().into_iter().map(Some).collect(); // Iterator reports too many items - let arr = - LargeStringArray::from_iter_values(BadIterator::new(3, 10, data.clone())); + let arr = LargeStringArray::from_iter_values(BadIterator::new(3, 10, data.clone())); assert_eq!(expected, arr); // Iterator reports too few items - let arr = - LargeStringArray::from_iter_values(BadIterator::new(3, 1, data.clone())); + let arr = LargeStringArray::from_iter_values(BadIterator::new(3, 1, data.clone())); assert_eq!(expected, arr); } @@ -715,16 +451,18 @@ mod tests { .unwrap(); let offsets = [0, 5, 8, 15].map(|n| O::from_usize(n).unwrap()); - let null_buffer = Buffer::from_slice_ref(&[0b101]); - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Box::new( - Field::new("item", DataType::UInt8, false), - )); + let null_buffer = Buffer::from_slice_ref([0b101]); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( + "item", + DataType::UInt8, + false, + ))); // [None, Some("Parquet")] let array_data = ArrayData::builder(data_type) .len(2) .offset(1) - .add_buffer(Buffer::from_slice_ref(&offsets)) + .add_buffer(Buffer::from_slice_ref(offsets)) .null_bit_buffer(Some(null_buffer)) .add_child_data(child_data) .build() @@ -749,26 +487,29 @@ mod tests { _test_generic_string_array_from_list_array::(); } - fn _test_generic_string_array_from_list_array_with_child_nulls_failed< - O: OffsetSizeTrait, - >() { + fn _test_generic_string_array_from_list_array_with_child_nulls_failed() { let values = b"HelloArrow"; let child_data = ArrayData::builder(DataType::UInt8) .len(10) .add_buffer(Buffer::from(&values[..])) - .null_bit_buffer(Some(Buffer::from_slice_ref(&[0b1010101010]))) + .null_bit_buffer(Some(Buffer::from_slice_ref([0b1010101010]))) .build() .unwrap(); let offsets = [0, 5, 10].map(|n| O::from_usize(n).unwrap()); - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Box::new( - Field::new("item", DataType::UInt8, false), - )); + + // It is possible to create a null struct containing a non-nullable child + // see https://github.com/apache/arrow-rs/pull/3244 for details + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( + "item", + DataType::UInt8, + true, + ))); // [None, Some(b"Parquet")] let array_data = ArrayData::builder(data_type) .len(2) - .add_buffer(Buffer::from_slice_ref(&offsets)) + .add_buffer(Buffer::from_slice_ref(offsets)) .add_child_data(child_data) .build() .unwrap(); @@ -778,7 +519,7 @@ mod tests { #[test] #[should_panic(expected = "The child array cannot contain null values.")] - fn test_stirng_array_from_list_array_with_child_nulls_failed() { + fn test_string_array_from_list_array_with_child_nulls_failed() { _test_generic_string_array_from_list_array_with_child_nulls_failed::(); } @@ -797,13 +538,15 @@ mod tests { .unwrap(); let offsets = [0, 2, 3].map(|n| O::from_usize(n).unwrap()); - let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Box::new( - Field::new("item", DataType::UInt16, false), - )); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Arc::new(Field::new( + "item", + DataType::UInt16, + false, + ))); let array_data = ArrayData::builder(data_type) .len(2) - .add_buffer(Buffer::from_slice_ref(&offsets)) + .add_buffer(Buffer::from_slice_ref(offsets)) .add_child_data(child_data) .build() .unwrap(); @@ -813,7 +556,7 @@ mod tests { #[test] #[should_panic( - expected = "StringArray can only be created from List arrays, mismatched data types." + expected = "BinaryArray can only be created from List arrays, mismatched data types." )] fn test_string_array_from_list_array_wrong_type() { _test_generic_string_array_from_list_array_wrong_type::(); @@ -821,9 +564,67 @@ mod tests { #[test] #[should_panic( - expected = "StringArray can only be created from List arrays, mismatched data types." + expected = "BinaryArray can only be created from List arrays, mismatched data types." )] fn test_large_string_array_from_list_array_wrong_type() { - _test_generic_string_array_from_list_array_wrong_type::(); + _test_generic_string_array_from_list_array_wrong_type::(); + } + + #[test] + #[should_panic( + expected = "Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 0" + )] + fn test_list_array_utf8_validation() { + let mut builder = ListBuilder::new(PrimitiveBuilder::::new()); + builder.values().append_value(0xFF); + builder.append(true); + let list = builder.finish(); + let _ = StringArray::from(list); + } + + #[test] + fn test_empty_offsets() { + let string = StringArray::from( + ArrayData::builder(DataType::Utf8) + .buffers(vec![Buffer::from(&[]), Buffer::from(&[])]) + .build() + .unwrap(), + ); + assert_eq!(string.len(), 0); + assert_eq!(string.value_offsets(), &[0]); + + let string = LargeStringArray::from( + ArrayData::builder(DataType::LargeUtf8) + .buffers(vec![Buffer::from(&[]), Buffer::from(&[])]) + .build() + .unwrap(), + ); + assert_eq!(string.len(), 0); + assert_eq!(string.value_offsets(), &[0]); + } + + #[test] + fn test_into_builder() { + let array: StringArray = vec!["hello", "arrow"].into(); + + // Append values + let mut builder = array.into_builder().unwrap(); + + builder.append_value("rust"); + + let expected: StringArray = vec!["hello", "arrow", "rust"].into(); + let array = builder.finish(); + assert_eq!(expected, array); + } + + #[test] + fn test_into_builder_err() { + let array: StringArray = vec!["hello", "arrow"].into(); + + // Clone it, so we cannot get a mutable builder back + let shared_array = array.clone(); + + let err_return = array.into_builder().unwrap_err(); + assert_eq!(&err_return, &shared_array); } } diff --git a/arrow-array/src/array/struct_array.rs b/arrow-array/src/array/struct_array.rs new file mode 100644 index 000000000000..699da28cf7a3 --- /dev/null +++ b/arrow-array/src/array/struct_array.rs @@ -0,0 +1,735 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{make_array, new_null_array, Array, ArrayRef, RecordBatch}; +use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer}; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, SchemaBuilder}; +use std::sync::Arc; +use std::{any::Any, ops::Index}; + +/// An array of [structs](https://arrow.apache.org/docs/format/Columnar.html#struct-layout) +/// +/// Each child (called *field*) is represented by a separate array. +/// +/// # Comparison with [RecordBatch] +/// +/// Both [`RecordBatch`] and [`StructArray`] represent a collection of columns / arrays with the +/// same length. +/// +/// However, there are a couple of key differences: +/// +/// * [`StructArray`] can be nested within other [`Array`], including itself +/// * [`RecordBatch`] can contain top-level metadata on its associated [`Schema`][arrow_schema::Schema] +/// * [`StructArray`] can contain top-level nulls, i.e. `null` +/// * [`RecordBatch`] can only represent nulls in its child columns, i.e. `{"field": null}` +/// +/// [`StructArray`] is therefore a more general data container than [`RecordBatch`], and as such +/// code that needs to handle both will typically share an implementation in terms of +/// [`StructArray`] and convert to/from [`RecordBatch`] as necessary. +/// +/// [`From`] implementations are provided to facilitate this conversion, however, converting +/// from a [`StructArray`] containing top-level nulls to a [`RecordBatch`] will panic, as there +/// is no way to preserve them. +/// +/// # Example: Create an array from a vector of fields +/// +/// ``` +/// use std::sync::Arc; +/// use arrow_array::{Array, ArrayRef, BooleanArray, Int32Array, StructArray}; +/// use arrow_schema::{DataType, Field}; +/// +/// let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true])); +/// let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31])); +/// +/// let struct_array = StructArray::from(vec![ +/// ( +/// Arc::new(Field::new("b", DataType::Boolean, false)), +/// boolean.clone() as ArrayRef, +/// ), +/// ( +/// Arc::new(Field::new("c", DataType::Int32, false)), +/// int.clone() as ArrayRef, +/// ), +/// ]); +/// assert_eq!(struct_array.column(0).as_ref(), boolean.as_ref()); +/// assert_eq!(struct_array.column(1).as_ref(), int.as_ref()); +/// assert_eq!(4, struct_array.len()); +/// assert_eq!(0, struct_array.null_count()); +/// assert_eq!(0, struct_array.offset()); +/// ``` +#[derive(Clone)] +pub struct StructArray { + len: usize, + data_type: DataType, + nulls: Option, + fields: Vec, +} + +impl StructArray { + /// Create a new [`StructArray`] from the provided parts, panicking on failure + /// + /// # Panics + /// + /// Panics if [`Self::try_new`] returns an error + pub fn new(fields: Fields, arrays: Vec, nulls: Option) -> Self { + Self::try_new(fields, arrays, nulls).unwrap() + } + + /// Create a new [`StructArray`] from the provided parts, returning an error on failure + /// + /// # Errors + /// + /// Errors if + /// + /// * `fields.len() != arrays.len()` + /// * `fields[i].data_type() != arrays[i].data_type()` + /// * `arrays[i].len() != arrays[j].len()` + /// * `arrays[i].len() != nulls.len()` + /// * `!fields[i].is_nullable() && !nulls.contains(arrays[i].nulls())` + pub fn try_new( + fields: Fields, + arrays: Vec, + nulls: Option, + ) -> Result { + if fields.len() != arrays.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect number of arrays for StructArray fields, expected {} got {}", + fields.len(), + arrays.len() + ))); + } + let len = arrays.first().map(|x| x.len()).unwrap_or_default(); + + if let Some(n) = nulls.as_ref() { + if n.len() != len { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect number of nulls for StructArray, expected {len} got {}", + n.len(), + ))); + } + } + + for (f, a) in fields.iter().zip(&arrays) { + if f.data_type() != a.data_type() { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect datatype for StructArray field {:?}, expected {} got {}", + f.name(), + f.data_type(), + a.data_type() + ))); + } + + if a.len() != len { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect array length for StructArray field {:?}, expected {} got {}", + f.name(), + len, + a.len() + ))); + } + + if !f.is_nullable() { + if let Some(a) = a.logical_nulls() { + if !nulls.as_ref().map(|n| n.contains(&a)).unwrap_or_default() { + return Err(ArrowError::InvalidArgumentError(format!( + "Found unmasked nulls for non-nullable StructArray field {:?}", + f.name() + ))); + } + } + } + } + + Ok(Self { + len, + data_type: DataType::Struct(fields), + nulls: nulls.filter(|n| n.null_count() > 0), + fields: arrays, + }) + } + + /// Create a new [`StructArray`] of length `len` where all values are null + pub fn new_null(fields: Fields, len: usize) -> Self { + let arrays = fields + .iter() + .map(|f| new_null_array(f.data_type(), len)) + .collect(); + + Self { + len, + data_type: DataType::Struct(fields), + nulls: Some(NullBuffer::new_null(len)), + fields: arrays, + } + } + + /// Create a new [`StructArray`] from the provided parts without validation + /// + /// # Safety + /// + /// Safe if [`Self::new`] would not panic with the given arguments + pub unsafe fn new_unchecked( + fields: Fields, + arrays: Vec, + nulls: Option, + ) -> Self { + let len = arrays.first().map(|x| x.len()).unwrap_or_default(); + Self { + len, + data_type: DataType::Struct(fields), + nulls, + fields: arrays, + } + } + + /// Create a new [`StructArray`] containing no fields + /// + /// # Panics + /// + /// If `len != nulls.len()` + pub fn new_empty_fields(len: usize, nulls: Option) -> Self { + if let Some(n) = &nulls { + assert_eq!(len, n.len()) + } + Self { + len, + data_type: DataType::Struct(Fields::empty()), + fields: vec![], + nulls, + } + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts(self) -> (Fields, Vec, Option) { + let f = match self.data_type { + DataType::Struct(f) => f, + _ => unreachable!(), + }; + (f, self.fields, self.nulls) + } + + /// Returns the field at `pos`. + pub fn column(&self, pos: usize) -> &ArrayRef { + &self.fields[pos] + } + + /// Return the number of fields in this struct array + pub fn num_columns(&self) -> usize { + self.fields.len() + } + + /// Returns the fields of the struct array + pub fn columns(&self) -> &[ArrayRef] { + &self.fields + } + + /// Returns child array refs of the struct array + #[deprecated(note = "Use columns().to_vec()")] + pub fn columns_ref(&self) -> Vec { + self.columns().to_vec() + } + + /// Return field names in this struct array + pub fn column_names(&self) -> Vec<&str> { + match self.data_type() { + DataType::Struct(fields) => fields + .iter() + .map(|f| f.name().as_str()) + .collect::>(), + _ => unreachable!("Struct array's data type is not struct!"), + } + } + + /// Returns the [`Fields`] of this [`StructArray`] + pub fn fields(&self) -> &Fields { + match self.data_type() { + DataType::Struct(f) => f, + _ => unreachable!(), + } + } + + /// Return child array whose field name equals to column_name + /// + /// Note: A schema can currently have duplicate field names, in which case + /// the first field will always be selected. + /// This issue will be addressed in [ARROW-11178](https://issues.apache.org/jira/browse/ARROW-11178) + pub fn column_by_name(&self, column_name: &str) -> Option<&ArrayRef> { + self.column_names() + .iter() + .position(|c| c == &column_name) + .map(|pos| self.column(pos)) + } + + /// Returns a zero-copy slice of this array with the indicated offset and length. + pub fn slice(&self, offset: usize, len: usize) -> Self { + assert!( + offset.saturating_add(len) <= self.len, + "the length + offset of the sliced StructArray cannot exceed the existing length" + ); + + let fields = self.fields.iter().map(|a| a.slice(offset, len)).collect(); + + Self { + len, + data_type: self.data_type.clone(), + nulls: self.nulls.as_ref().map(|n| n.slice(offset, len)), + fields, + } + } +} + +impl From for StructArray { + fn from(data: ArrayData) -> Self { + let fields = data + .child_data() + .iter() + .map(|cd| make_array(cd.clone())) + .collect(); + + Self { + len: data.len(), + data_type: data.data_type().clone(), + nulls: data.nulls().cloned(), + fields, + } + } +} + +impl From for ArrayData { + fn from(array: StructArray) -> Self { + let builder = ArrayDataBuilder::new(array.data_type) + .len(array.len) + .nulls(array.nulls) + .child_data(array.fields.iter().map(|x| x.to_data()).collect()); + + unsafe { builder.build_unchecked() } + } +} + +impl TryFrom> for StructArray { + type Error = ArrowError; + + /// builds a StructArray from a vector of names and arrays. + fn try_from(values: Vec<(&str, ArrayRef)>) -> Result { + let (schema, arrays): (SchemaBuilder, _) = values + .into_iter() + .map(|(name, array)| { + ( + Field::new(name, array.data_type().clone(), array.is_nullable()), + array, + ) + }) + .unzip(); + + StructArray::try_new(schema.finish().fields, arrays, None) + } +} + +impl Array for StructArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + self.clone().into() + } + + fn into_data(self) -> ArrayData { + self.into() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.len + } + + fn is_empty(&self) -> bool { + self.len == 0 + } + + fn offset(&self) -> usize { + 0 + } + + fn nulls(&self) -> Option<&NullBuffer> { + self.nulls.as_ref() + } + + fn get_buffer_memory_size(&self) -> usize { + let mut size = self.fields.iter().map(|a| a.get_buffer_memory_size()).sum(); + if let Some(n) = self.nulls.as_ref() { + size += n.buffer().capacity(); + } + size + } + + fn get_array_memory_size(&self) -> usize { + let mut size = self.fields.iter().map(|a| a.get_array_memory_size()).sum(); + size += std::mem::size_of::(); + if let Some(n) = self.nulls.as_ref() { + size += n.buffer().capacity(); + } + size + } +} + +impl From> for StructArray { + fn from(v: Vec<(FieldRef, ArrayRef)>) -> Self { + let (schema, arrays): (SchemaBuilder, _) = v.into_iter().unzip(); + StructArray::new(schema.finish().fields, arrays, None) + } +} + +impl std::fmt::Debug for StructArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "StructArray\n[\n")?; + for (child_index, name) in self.column_names().iter().enumerate() { + let column = self.column(child_index); + writeln!( + f, + "-- child {}: \"{}\" ({:?})", + child_index, + name, + column.data_type() + )?; + std::fmt::Debug::fmt(column, f)?; + writeln!(f)?; + } + write!(f, "]") + } +} + +impl From<(Vec<(FieldRef, ArrayRef)>, Buffer)> for StructArray { + fn from(pair: (Vec<(FieldRef, ArrayRef)>, Buffer)) -> Self { + let len = pair.0.first().map(|x| x.1.len()).unwrap_or_default(); + let (fields, arrays): (SchemaBuilder, Vec<_>) = pair.0.into_iter().unzip(); + let nulls = NullBuffer::new(BooleanBuffer::new(pair.1, 0, len)); + Self::new(fields.finish().fields, arrays, Some(nulls)) + } +} + +impl From for StructArray { + fn from(value: RecordBatch) -> Self { + Self { + len: value.num_rows(), + data_type: DataType::Struct(value.schema().fields().clone()), + nulls: None, + fields: value.columns().to_vec(), + } + } +} + +impl Index<&str> for StructArray { + type Output = ArrayRef; + + /// Get a reference to a column's array by name. + /// + /// Note: A schema can currently have duplicate field names, in which case + /// the first field will always be selected. + /// This issue will be addressed in [ARROW-11178](https://issues.apache.org/jira/browse/ARROW-11178) + /// + /// # Panics + /// + /// Panics if the name is not in the schema. + fn index(&self, name: &str) -> &Self::Output { + self.column_by_name(name).unwrap() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::{BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, StringArray}; + use arrow_buffer::ToByteSlice; + use std::sync::Arc; + + #[test] + fn test_struct_array_builder() { + let boolean_array = BooleanArray::from(vec![false, false, true, true]); + let int_array = Int64Array::from(vec![42, 28, 19, 31]); + + let fields = vec![ + Field::new("a", DataType::Boolean, false), + Field::new("b", DataType::Int64, false), + ]; + let struct_array_data = ArrayData::builder(DataType::Struct(fields.into())) + .len(4) + .add_child_data(boolean_array.to_data()) + .add_child_data(int_array.to_data()) + .build() + .unwrap(); + let struct_array = StructArray::from(struct_array_data); + + assert_eq!(struct_array.column(0).as_ref(), &boolean_array); + assert_eq!(struct_array.column(1).as_ref(), &int_array); + } + + #[test] + fn test_struct_array_from() { + let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true])); + let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31])); + + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("b", DataType::Boolean, false)), + boolean.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("c", DataType::Int32, false)), + int.clone() as ArrayRef, + ), + ]); + assert_eq!(struct_array.column(0).as_ref(), boolean.as_ref()); + assert_eq!(struct_array.column(1).as_ref(), int.as_ref()); + assert_eq!(4, struct_array.len()); + assert_eq!(0, struct_array.null_count()); + assert_eq!(0, struct_array.offset()); + } + + /// validates that struct can be accessed using `column_name` as index i.e. `struct_array["column_name"]`. + #[test] + fn test_struct_array_index_access() { + let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true])); + let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31])); + + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("b", DataType::Boolean, false)), + boolean.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("c", DataType::Int32, false)), + int.clone() as ArrayRef, + ), + ]); + assert_eq!(struct_array["b"].as_ref(), boolean.as_ref()); + assert_eq!(struct_array["c"].as_ref(), int.as_ref()); + } + + /// validates that the in-memory representation follows [the spec](https://arrow.apache.org/docs/format/Columnar.html#struct-layout) + #[test] + fn test_struct_array_from_vec() { + let strings: ArrayRef = Arc::new(StringArray::from(vec![ + Some("joe"), + None, + None, + Some("mark"), + ])); + let ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)])); + + let arr = + StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())]).unwrap(); + + let struct_data = arr.into_data(); + assert_eq!(4, struct_data.len()); + assert_eq!(0, struct_data.null_count()); + + let expected_string_data = ArrayData::builder(DataType::Utf8) + .len(4) + .null_bit_buffer(Some(Buffer::from(&[9_u8]))) + .add_buffer(Buffer::from(&[0, 3, 3, 3, 7].to_byte_slice())) + .add_buffer(Buffer::from(b"joemark")) + .build() + .unwrap(); + + let expected_int_data = ArrayData::builder(DataType::Int32) + .len(4) + .null_bit_buffer(Some(Buffer::from(&[11_u8]))) + .add_buffer(Buffer::from(&[1, 2, 0, 4].to_byte_slice())) + .build() + .unwrap(); + + assert_eq!(expected_string_data, struct_data.child_data()[0]); + assert_eq!(expected_int_data, struct_data.child_data()[1]); + } + + #[test] + fn test_struct_array_from_vec_error() { + let strings: ArrayRef = Arc::new(StringArray::from(vec![ + Some("joe"), + None, + None, + // 3 elements, not 4 + ])); + let ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)])); + + let err = StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())]) + .unwrap_err() + .to_string(); + + assert_eq!( + err, + "Invalid argument error: Incorrect array length for StructArray field \"f2\", expected 3 got 4" + ) + } + + #[test] + #[should_panic( + expected = "Incorrect datatype for StructArray field \\\"b\\\", expected Int16 got Boolean" + )] + fn test_struct_array_from_mismatched_types_single() { + drop(StructArray::from(vec![( + Arc::new(Field::new("b", DataType::Int16, false)), + Arc::new(BooleanArray::from(vec![false, false, true, true])) as Arc, + )])); + } + + #[test] + #[should_panic( + expected = "Incorrect datatype for StructArray field \\\"b\\\", expected Int16 got Boolean" + )] + fn test_struct_array_from_mismatched_types_multiple() { + drop(StructArray::from(vec![ + ( + Arc::new(Field::new("b", DataType::Int16, false)), + Arc::new(BooleanArray::from(vec![false, false, true, true])) as Arc, + ), + ( + Arc::new(Field::new("c", DataType::Utf8, false)), + Arc::new(Int32Array::from(vec![42, 28, 19, 31])), + ), + ])); + } + + #[test] + fn test_struct_array_slice() { + let boolean_data = ArrayData::builder(DataType::Boolean) + .len(5) + .add_buffer(Buffer::from([0b00010000])) + .null_bit_buffer(Some(Buffer::from([0b00010001]))) + .build() + .unwrap(); + let int_data = ArrayData::builder(DataType::Int32) + .len(5) + .add_buffer(Buffer::from([0, 28, 42, 0, 0].to_byte_slice())) + .null_bit_buffer(Some(Buffer::from([0b00000110]))) + .build() + .unwrap(); + + let field_types = vec![ + Field::new("a", DataType::Boolean, true), + Field::new("b", DataType::Int32, true), + ]; + let struct_array_data = ArrayData::builder(DataType::Struct(field_types.into())) + .len(5) + .add_child_data(boolean_data.clone()) + .add_child_data(int_data.clone()) + .null_bit_buffer(Some(Buffer::from([0b00010111]))) + .build() + .unwrap(); + let struct_array = StructArray::from(struct_array_data); + + assert_eq!(5, struct_array.len()); + assert_eq!(1, struct_array.null_count()); + assert!(struct_array.is_valid(0)); + assert!(struct_array.is_valid(1)); + assert!(struct_array.is_valid(2)); + assert!(struct_array.is_null(3)); + assert!(struct_array.is_valid(4)); + assert_eq!(boolean_data, struct_array.column(0).to_data()); + assert_eq!(int_data, struct_array.column(1).to_data()); + + let c0 = struct_array.column(0); + let c0 = c0.as_any().downcast_ref::().unwrap(); + assert_eq!(5, c0.len()); + assert_eq!(3, c0.null_count()); + assert!(c0.is_valid(0)); + assert!(!c0.value(0)); + assert!(c0.is_null(1)); + assert!(c0.is_null(2)); + assert!(c0.is_null(3)); + assert!(c0.is_valid(4)); + assert!(c0.value(4)); + + let c1 = struct_array.column(1); + let c1 = c1.as_any().downcast_ref::().unwrap(); + assert_eq!(5, c1.len()); + assert_eq!(3, c1.null_count()); + assert!(c1.is_null(0)); + assert!(c1.is_valid(1)); + assert_eq!(28, c1.value(1)); + assert!(c1.is_valid(2)); + assert_eq!(42, c1.value(2)); + assert!(c1.is_null(3)); + assert!(c1.is_null(4)); + + let sliced_array = struct_array.slice(2, 3); + let sliced_array = sliced_array.as_any().downcast_ref::().unwrap(); + assert_eq!(3, sliced_array.len()); + assert_eq!(1, sliced_array.null_count()); + assert!(sliced_array.is_valid(0)); + assert!(sliced_array.is_null(1)); + assert!(sliced_array.is_valid(2)); + + let sliced_c0 = sliced_array.column(0); + let sliced_c0 = sliced_c0.as_any().downcast_ref::().unwrap(); + assert_eq!(3, sliced_c0.len()); + assert!(sliced_c0.is_null(0)); + assert!(sliced_c0.is_null(1)); + assert!(sliced_c0.is_valid(2)); + assert!(sliced_c0.value(2)); + + let sliced_c1 = sliced_array.column(1); + let sliced_c1 = sliced_c1.as_any().downcast_ref::().unwrap(); + assert_eq!(3, sliced_c1.len()); + assert!(sliced_c1.is_valid(0)); + assert_eq!(42, sliced_c1.value(0)); + assert!(sliced_c1.is_null(1)); + assert!(sliced_c1.is_null(2)); + } + + #[test] + #[should_panic( + expected = "Incorrect array length for StructArray field \\\"c\\\", expected 1 got 2" + )] + fn test_invalid_struct_child_array_lengths() { + drop(StructArray::from(vec![ + ( + Arc::new(Field::new("b", DataType::Float32, false)), + Arc::new(Float32Array::from(vec![1.1])) as Arc, + ), + ( + Arc::new(Field::new("c", DataType::Float64, false)), + Arc::new(Float64Array::from(vec![2.2, 3.3])), + ), + ])); + } + + #[test] + fn test_struct_array_from_empty() { + let sa = StructArray::from(vec![]); + assert!(sa.is_empty()) + } + + #[test] + #[should_panic(expected = "Found unmasked nulls for non-nullable StructArray field \\\"c\\\"")] + fn test_struct_array_from_mismatched_nullability() { + drop(StructArray::from(vec![( + Arc::new(Field::new("c", DataType::Int32, false)), + Arc::new(Int32Array::from(vec![Some(42), None, Some(19)])) as ArrayRef, + )])); + } +} diff --git a/arrow/src/array/array_union.rs b/arrow-array/src/array/union_array.rs similarity index 71% rename from arrow/src/array/array_union.rs rename to arrow-array/src/array/union_array.rs index b221239b2dbe..94ac0bc879e4 100644 --- a/arrow/src/array/array_union.rs +++ b/arrow-array/src/array/union_array.rs @@ -15,26 +15,26 @@ // specific language governing permissions and limitations // under the License. +use crate::{make_array, Array, ArrayRef}; +use arrow_buffer::buffer::NullBuffer; +use arrow_buffer::{Buffer, ScalarBuffer}; +use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode}; /// Contains the `UnionArray` type. /// -use crate::array::{make_array, Array, ArrayData, ArrayRef}; -use crate::buffer::Buffer; -use crate::datatypes::*; -use crate::error::{ArrowError, Result}; - -use core::fmt; use std::any::Any; +use std::sync::Arc; -/// An Array that can represent slots of varying types. +/// An array of [values of varying types](https://arrow.apache.org/docs/format/Columnar.html#union-layout) /// /// Each slot in a [UnionArray] can have a value chosen from a number /// of types. Each of the possible types are named like the fields of -/// a [`StructArray`](crate::array::StructArray). A `UnionArray` can +/// a [`StructArray`](crate::StructArray). A `UnionArray` can /// have two possible memory layouts, "dense" or "sparse". For more /// information on please see the /// [specification](https://arrow.apache.org/docs/format/Columnar.html#union-layout). /// -/// [UnionBuilder](crate::array::UnionBuilder) can be used to +/// [UnionBuilder](crate::builder::UnionBuilder) can be used to /// create [UnionArray]'s of primitive types. `UnionArray`'s of nested /// types are also supported but not via `UnionBuilder`, see the tests /// for examples. @@ -42,10 +42,10 @@ use std::any::Any; /// # Examples /// ## Create a dense UnionArray `[1, 3.2, 34]` /// ``` -/// use arrow::buffer::Buffer; -/// use arrow::datatypes::*; +/// use arrow_buffer::Buffer; +/// use arrow_schema::*; /// use std::sync::Arc; -/// use arrow::array::{Array, Int32Array, Float64Array, UnionArray}; +/// use arrow_array::{Array, Int32Array, Float64Array, UnionArray}; /// /// let int_array = Int32Array::from(vec![1, 34]); /// let float_array = Float64Array::from(vec![3.2]); @@ -76,10 +76,10 @@ use std::any::Any; /// /// ## Create a sparse UnionArray `[1, 3.2, 34]` /// ``` -/// use arrow::buffer::Buffer; -/// use arrow::datatypes::*; +/// use arrow_buffer::Buffer; +/// use arrow_schema::*; /// use std::sync::Arc; -/// use arrow::array::{Array, Int32Array, Float64Array, UnionArray}; +/// use arrow_array::{Array, Int32Array, Float64Array, UnionArray}; /// /// let int_array = Int32Array::from(vec![Some(1), None, Some(34)]); /// let float_array = Float64Array::from(vec![None, Some(3.2), None]); @@ -106,9 +106,12 @@ use std::any::Any; /// let value = array.value(2).as_any().downcast_ref::().unwrap().value(0); /// assert_eq!(34, value); /// ``` +#[derive(Clone)] pub struct UnionArray { - data: ArrayData, - boxed_fields: Vec, + data_type: DataType, + type_ids: ScalarBuffer, + offsets: Option>, + fields: Vec>, } impl UnionArray { @@ -142,8 +145,7 @@ impl UnionArray { value_offsets: Option, child_arrays: Vec<(Field, ArrayRef)>, ) -> Self { - let (field_types, field_values): (Vec<_>, Vec<_>) = - child_arrays.into_iter().unzip(); + let (fields, field_values): (Vec<_>, Vec<_>) = child_arrays.into_iter().unzip(); let len = type_ids.len(); let mode = if value_offsets.is_some() { @@ -153,8 +155,7 @@ impl UnionArray { }; let builder = ArrayData::builder(DataType::Union( - field_types, - Vec::from(field_type_ids), + UnionFields::new(field_type_ids.iter().copied(), fields), mode, )) .add_buffer(type_ids) @@ -174,12 +175,11 @@ impl UnionArray { type_ids: Buffer, value_offsets: Option, child_arrays: Vec<(Field, ArrayRef)>, - ) -> Result { + ) -> Result { if let Some(b) = &value_offsets { if ((type_ids.len()) * 4) != b.len() { return Err(ArrowError::InvalidArgumentError( - "Type Ids and Offsets represent a different number of array slots." - .to_string(), + "Type Ids and Offsets represent a different number of array slots.".to_string(), )); } } @@ -193,8 +193,7 @@ impl UnionArray { if !invalid_type_ids.is_empty() { return Err(ArrowError::InvalidArgumentError(format!( "Type Ids must be positive and cannot be greater than the number of \ - child arrays, found:\n{:?}", - invalid_type_ids + child arrays, found:\n{invalid_type_ids:?}" ))); } @@ -209,18 +208,16 @@ impl UnionArray { if !invalid_offsets.is_empty() { return Err(ArrowError::InvalidArgumentError(format!( "Offsets must be positive and within the length of the Array, \ - found:\n{:?}", - invalid_offsets + found:\n{invalid_offsets:?}" ))); } } // Unsafe Justification: arguments were validated above (and // re-revalidated as part of data().validate() below) - let new_self = unsafe { - Self::new_unchecked(field_type_ids, type_ids, value_offsets, child_arrays) - }; - new_self.data().validate()?; + let new_self = + unsafe { Self::new_unchecked(field_type_ids, type_ids, value_offsets, child_arrays) }; + new_self.to_data().validate()?; Ok(new_self) } @@ -232,9 +229,8 @@ impl UnionArray { /// Panics if the `type_id` provided is less than zero or greater than the number of types /// in the `Union`. pub fn child(&self, type_id: i8) -> &ArrayRef { - assert!(0 <= type_id); - assert!((type_id as usize) < self.boxed_fields.len()); - &self.boxed_fields[type_id as usize] + let boxed = &self.fields[type_id as usize]; + boxed.as_ref().expect("invalid type id") } /// Returns the `type_id` for the array slot at `index`. @@ -243,8 +239,17 @@ impl UnionArray { /// /// Panics if `index` is greater than the length of the array. pub fn type_id(&self, index: usize) -> i8 { - assert!(index < self.len()); - self.data().buffers()[0].as_slice()[self.offset() + index] as i8 + self.type_ids[index] + } + + /// Returns the `type_ids` buffer for this array + pub fn type_ids(&self) -> &ScalarBuffer { + &self.type_ids + } + + /// Returns the `offsets` buffer if this is a dense array + pub fn offsets(&self) -> Option<&ScalarBuffer> { + self.offsets.as_ref() } /// Returns the offset into the underlying values array for the array slot at `index`. @@ -252,12 +257,11 @@ impl UnionArray { /// # Panics /// /// Panics if `index` is greater than the length of the array. - pub fn value_offset(&self, index: usize) -> i32 { + pub fn value_offset(&self, index: usize) -> usize { assert!(index < self.len()); - if self.is_dense() { - self.data().buffers()[1].typed_data::()[self.offset() + index] - } else { - (self.offset() + index) as i32 + match &self.offsets { + Some(offsets) => offsets[index] as usize, + None => self.offset() + index, } } @@ -266,17 +270,17 @@ impl UnionArray { /// Panics if index `i` is out of bounds pub fn value(&self, i: usize) -> ArrayRef { let type_id = self.type_id(i); - let value_offset = self.value_offset(i) as usize; - let child_data = self.boxed_fields[type_id as usize].clone(); - child_data.slice(value_offset, 1) + let value_offset = self.value_offset(i); + let child = self.child(type_id); + child.slice(value_offset, 1) } /// Returns the names of the types in the union. pub fn type_names(&self) -> Vec<&str> { - match self.data.data_type() { - DataType::Union(fields, _, _) => fields + match self.data_type() { + DataType::Union(fields, _) => fields .iter() - .map(|f| f.name().as_str()) + .map(|(_, f)| f.name().as_str()) .collect::>(), _ => unreachable!("Union array's data type is not a union!"), } @@ -284,26 +288,94 @@ impl UnionArray { /// Returns whether the `UnionArray` is dense (or sparse if `false`). fn is_dense(&self) -> bool { - match self.data.data_type() { - DataType::Union(_, _, mode) => mode == &UnionMode::Dense, + match self.data_type() { + DataType::Union(_, mode) => mode == &UnionMode::Dense, _ => unreachable!("Union array's data type is not a union!"), } } + + /// Returns a zero-copy slice of this array with the indicated offset and length. + pub fn slice(&self, offset: usize, length: usize) -> Self { + let (offsets, fields) = match self.offsets.as_ref() { + // If dense union, slice offsets + Some(offsets) => (Some(offsets.slice(offset, length)), self.fields.clone()), + // Otherwise need to slice sparse children + None => { + let fields = self + .fields + .iter() + .map(|x| x.as_ref().map(|x| x.slice(offset, length))) + .collect(); + (None, fields) + } + }; + + Self { + data_type: self.data_type.clone(), + type_ids: self.type_ids.slice(offset, length), + offsets, + fields, + } + } } impl From for UnionArray { fn from(data: ArrayData) -> Self { - let mut boxed_fields = vec![]; - for cd in data.child_data() { - boxed_fields.push(make_array(cd.clone())); + let (fields, mode) = match data.data_type() { + DataType::Union(fields, mode) => (fields, *mode), + d => panic!("UnionArray expected ArrayData with type Union got {d}"), + }; + let (type_ids, offsets) = match mode { + UnionMode::Sparse => ( + ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()), + None, + ), + UnionMode::Dense => ( + ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()), + Some(ScalarBuffer::new( + data.buffers()[1].clone(), + data.offset(), + data.len(), + )), + ), + }; + + let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize; + let mut boxed_fields = vec![None; max_id + 1]; + for (cd, (field_id, _)) in data.child_data().iter().zip(fields.iter()) { + boxed_fields[field_id as usize] = Some(make_array(cd.clone())); + } + Self { + data_type: data.data_type().clone(), + type_ids, + offsets, + fields: boxed_fields, } - Self { data, boxed_fields } } } impl From for ArrayData { fn from(array: UnionArray) -> Self { - array.data + let len = array.len(); + let f = match &array.data_type { + DataType::Union(f, _) => f, + _ => unreachable!(), + }; + let buffers = match array.offsets { + Some(o) => vec![array.type_ids.into_inner(), o.into_inner()], + None => vec![array.type_ids.into_inner()], + }; + + let child = f + .iter() + .map(|(i, _)| array.fields[i as usize].as_ref().unwrap().to_data()) + .collect(); + + let builder = ArrayDataBuilder::new(array.data_type) + .len(len) + .buffers(buffers) + .child_data(child); + unsafe { builder.build_unchecked() } } } @@ -312,14 +384,38 @@ impl Array for UnionArray { self } - fn data(&self) -> &ArrayData { - &self.data + fn to_data(&self) -> ArrayData { + self.clone().into() } fn into_data(self) -> ArrayData { self.into() } + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(self.slice(offset, length)) + } + + fn len(&self) -> usize { + self.type_ids.len() + } + + fn is_empty(&self) -> bool { + self.type_ids.is_empty() + } + + fn offset(&self) -> usize { + 0 + } + + fn nulls(&self) -> Option<&NullBuffer> { + None + } + /// Union types always return non null as there is no validity buffer. /// To check validity correctly you must check the underlying vector. fn is_null(&self, _index: usize) -> bool { @@ -337,35 +433,66 @@ impl Array for UnionArray { fn null_count(&self) -> usize { 0 } + + fn get_buffer_memory_size(&self) -> usize { + let mut sum = self.type_ids.inner().capacity(); + if let Some(o) = self.offsets.as_ref() { + sum += o.inner().capacity() + } + self.fields + .iter() + .flat_map(|x| x.as_ref().map(|x| x.get_buffer_memory_size())) + .sum::() + + sum + } + + fn get_array_memory_size(&self) -> usize { + let mut sum = self.type_ids.inner().capacity(); + if let Some(o) = self.offsets.as_ref() { + sum += o.inner().capacity() + } + std::mem::size_of::() + + self + .fields + .iter() + .flat_map(|x| x.as_ref().map(|x| x.get_array_memory_size())) + .sum::() + + sum + } } -impl fmt::Debug for UnionArray { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +impl std::fmt::Debug for UnionArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { let header = if self.is_dense() { "UnionArray(Dense)\n[" } else { "UnionArray(Sparse)\n[" }; - writeln!(f, "{}", header)?; + writeln!(f, "{header}")?; writeln!(f, "-- type id buffer:")?; - writeln!(f, "{:?}", self.data().buffers()[0])?; + writeln!(f, "{:?}", self.type_ids)?; - if self.is_dense() { + if let Some(offsets) = &self.offsets { writeln!(f, "-- offsets buffer:")?; - writeln!(f, "{:?}", self.data().buffers()[1])?; + writeln!(f, "{:?}", offsets)?; } - for (child_index, name) in self.type_names().iter().enumerate() { - let column = &self.boxed_fields[child_index]; + let fields = match self.data_type() { + DataType::Union(fields, _) => fields, + _ => unreachable!(), + }; + + for (type_id, field) in fields.iter() { + let child = self.child(type_id); writeln!( f, "-- child {}: \"{}\" ({:?})", - child_index, - *name, - column.data_type() + type_id, + field.name(), + field.data_type() )?; - fmt::Debug::fmt(column, f)?; + std::fmt::Debug::fmt(child, f)?; writeln!(f)?; } writeln!(f, "]") @@ -376,13 +503,14 @@ impl fmt::Debug for UnionArray { mod tests { use super::*; + use crate::builder::UnionBuilder; + use crate::cast::AsArray; + use crate::types::{Float32Type, Float64Type, Int32Type, Int64Type}; + use crate::RecordBatch; + use crate::{Float64Array, Int32Array, Int64Array, StringArray}; + use arrow_schema::Schema; use std::sync::Arc; - use crate::array::*; - use crate::buffer::Buffer; - use crate::datatypes::{DataType, Field}; - use crate::record_batch::RecordBatch; - #[test] fn test_dense_i32() { let mut builder = UnionBuilder::new_dense(); @@ -396,39 +524,33 @@ mod tests { let union = builder.build().unwrap(); let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1]; - let expected_value_offsets = vec![0_i32, 0, 0, 1, 1, 2, 1]; + let expected_offsets = vec![0_i32, 0, 0, 1, 1, 2, 1]; let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7]; // Check type ids - assert_eq!( - union.data().buffers()[0], - Buffer::from_slice_ref(&expected_type_ids) - ); + assert_eq!(*union.type_ids(), expected_type_ids); for (i, id) in expected_type_ids.iter().enumerate() { assert_eq!(id, &union.type_id(i)); } // Check offsets - assert_eq!( - union.data().buffers()[1], - Buffer::from_slice_ref(&expected_value_offsets) - ); - for (i, id) in expected_value_offsets.iter().enumerate() { - assert_eq!(&union.value_offset(i), id); + assert_eq!(*union.offsets().unwrap(), expected_offsets); + for (i, id) in expected_offsets.iter().enumerate() { + assert_eq!(union.value_offset(i), *id as usize); } // Check data assert_eq!( - union.data().child_data()[0].buffers()[0], - Buffer::from_slice_ref(&[1_i32, 4, 6]) + *union.child(0).as_primitive::().values(), + [1_i32, 4, 6] ); assert_eq!( - union.data().child_data()[1].buffers()[0], - Buffer::from_slice_ref(&[2_i32, 7]) + *union.child(1).as_primitive::().values(), + [2_i32, 7] ); assert_eq!( - union.data().child_data()[2].buffers()[0], - Buffer::from_slice_ref(&[3_i32, 5]), + *union.child(2).as_primitive::().values(), + [3_i32, 5] ); assert_eq!(expected_array_values.len(), union.len()); @@ -448,7 +570,7 @@ mod tests { let mut builder = UnionBuilder::new_dense(); let expected_type_ids = vec![0_i8; 1024]; - let expected_value_offsets: Vec<_> = (0..1024).collect(); + let expected_offsets: Vec<_> = (0..1024).collect(); let expected_array_values: Vec<_> = (1..=1024).collect(); expected_array_values @@ -458,27 +580,21 @@ mod tests { let union = builder.build().unwrap(); // Check type ids - assert_eq!( - union.data().buffers()[0], - Buffer::from_slice_ref(&expected_type_ids) - ); + assert_eq!(*union.type_ids(), expected_type_ids); for (i, id) in expected_type_ids.iter().enumerate() { assert_eq!(id, &union.type_id(i)); } // Check offsets - assert_eq!( - union.data().buffers()[1], - Buffer::from_slice_ref(&expected_value_offsets) - ); - for (i, id) in expected_value_offsets.iter().enumerate() { - assert_eq!(&union.value_offset(i), id); + assert_eq!(*union.offsets().unwrap(), expected_offsets); + for (i, id) in expected_offsets.iter().enumerate() { + assert_eq!(union.value_offset(i), *id as usize); } for (i, expected_value) in expected_array_values.iter().enumerate() { assert!(!union.is_null(i)); let slot = union.value(i); - let slot = slot.as_any().downcast_ref::().unwrap(); + let slot = slot.as_primitive::(); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(expected_value, &value); @@ -627,10 +743,10 @@ mod tests { let float_array = Float64Array::from(vec![10.0]); let type_ids = [1_i8, 0, 0, 2, 0, 1]; - let value_offsets = [0_i32, 0, 1, 0, 2, 1]; + let offsets = [0_i32, 0, 1, 0, 2, 1]; - let type_id_buffer = Buffer::from_slice_ref(&type_ids); - let value_offsets_buffer = Buffer::from_slice_ref(&value_offsets); + let type_id_buffer = Buffer::from_slice_ref(type_ids); + let value_offsets_buffer = Buffer::from_slice_ref(offsets); let children: Vec<(Field, Arc)> = vec![ ( @@ -652,18 +768,15 @@ mod tests { .unwrap(); // Check type ids - assert_eq!(Buffer::from_slice_ref(&type_ids), array.data().buffers()[0]); + assert_eq!(*array.type_ids(), type_ids); for (i, id) in type_ids.iter().enumerate() { assert_eq!(id, &array.type_id(i)); } // Check offsets - assert_eq!( - Buffer::from_slice_ref(&value_offsets), - array.data().buffers()[1] - ); - for (i, id) in value_offsets.iter().enumerate() { - assert_eq!(id, &array.value_offset(i)); + assert_eq!(*array.offsets().unwrap(), offsets); + for (i, id) in offsets.iter().enumerate() { + assert_eq!(*id as usize, array.value_offset(i)); } // Check values @@ -726,29 +839,26 @@ mod tests { let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7]; // Check type ids - assert_eq!( - Buffer::from_slice_ref(&expected_type_ids), - union.data().buffers()[0] - ); + assert_eq!(*union.type_ids(), expected_type_ids); for (i, id) in expected_type_ids.iter().enumerate() { assert_eq!(id, &union.type_id(i)); } // Check offsets, sparse union should only have a single buffer - assert_eq!(union.data().buffers().len(), 1); + assert!(union.offsets().is_none()); // Check data assert_eq!( - union.data().child_data()[0].buffers()[0], - Buffer::from_slice_ref(&[1_i32, 0, 0, 4, 0, 6, 0]), + *union.child(0).as_primitive::().values(), + [1_i32, 0, 0, 4, 0, 6, 0], ); assert_eq!( - Buffer::from_slice_ref(&[0_i32, 2_i32, 0, 0, 0, 0, 7]), - union.data().child_data()[1].buffers()[0] + *union.child(1).as_primitive::().values(), + [0_i32, 2_i32, 0, 0, 0, 0, 7] ); assert_eq!( - Buffer::from_slice_ref(&[0_i32, 0, 3_i32, 0, 5, 0, 0]), - union.data().child_data()[2].buffers()[0] + *union.child(2).as_primitive::().values(), + [0_i32, 0, 3_i32, 0, 5, 0, 0] ); assert_eq!(expected_array_values.len(), union.len()); @@ -775,16 +885,13 @@ mod tests { let expected_type_ids = vec![0_i8, 1, 0, 1, 0]; // Check type ids - assert_eq!( - Buffer::from_slice_ref(&expected_type_ids), - union.data().buffers()[0] - ); + assert_eq!(*union.type_ids(), expected_type_ids); for (i, id) in expected_type_ids.iter().enumerate() { assert_eq!(id, &union.type_id(i)); } // Check offsets, sparse union should only have a single buffer, i.e. no offsets - assert_eq!(union.data().buffers().len(), 1); + assert!(union.offsets().is_none()); for i in 0..union.len() { let slot = union.value(i); @@ -837,16 +944,13 @@ mod tests { let expected_type_ids = vec![0_i8, 0, 1, 0]; // Check type ids - assert_eq!( - Buffer::from_slice_ref(&expected_type_ids), - union.data().buffers()[0] - ); + assert_eq!(*union.type_ids(), expected_type_ids); for (i, id) in expected_type_ids.iter().enumerate() { assert_eq!(id, &union.type_id(i)); } // Check offsets, sparse union should only have a single buffer, i.e. no offsets - assert_eq!(union.data().buffers().len(), 1); + assert!(union.offsets().is_none()); for i in 0..union.len() { let slot = union.value(i); @@ -897,7 +1001,7 @@ mod tests { match i { 0 => assert!(slot.is_null(0)), 1 => { - let slot = slot.as_any().downcast_ref::().unwrap(); + let slot = slot.as_primitive::(); assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); @@ -905,7 +1009,7 @@ mod tests { } 2 => assert!(slot.is_null(0)), 3 => { - let slot = slot.as_any().downcast_ref::().unwrap(); + let slot = slot.as_primitive::(); assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); @@ -926,7 +1030,7 @@ mod tests { } #[test] - fn test_union_array_validaty() { + fn test_union_array_validity() { let mut builder = UnionBuilder::new_sparse(); builder.append::("a", 1).unwrap(); builder.append_null::("a").unwrap(); @@ -953,7 +1057,13 @@ mod tests { let mut builder = UnionBuilder::new_sparse(); builder.append::("a", 1.0).unwrap(); let err = builder.append::("a", 1).unwrap_err().to_string(); - assert!(err.contains("Attempt to write col \"a\" with type Int32 doesn't match existing type Float32"), "{}", err); + assert!( + err.contains( + "Attempt to write col \"a\" with type Int32 doesn't match existing type Float32" + ), + "{}", + err + ); } #[test] @@ -990,18 +1100,18 @@ mod tests { assert_eq!(union_slice.type_id(2), 1); let slot = union_slice.value(0); - let array = slot.as_any().downcast_ref::().unwrap(); + let array = slot.as_primitive::(); assert_eq!(array.len(), 1); assert!(array.is_null(0)); let slot = union_slice.value(1); - let array = slot.as_any().downcast_ref::().unwrap(); + let array = slot.as_primitive::(); assert_eq!(array.len(), 1); assert!(array.is_valid(0)); assert_eq!(array.value(0), 3.0); let slot = union_slice.value(2); - let array = slot.as_any().downcast_ref::().unwrap(); + let array = slot.as_primitive::(); assert_eq!(array.len(), 1); assert!(array.is_null(0)); } @@ -1020,4 +1130,74 @@ mod tests { let record_batch_slice = record_batch.slice(1, 3); test_slice_union(record_batch_slice); } + + #[test] + fn test_custom_type_ids() { + let data_type = DataType::Union( + UnionFields::new( + vec![8, 4, 9], + vec![ + Field::new("strings", DataType::Utf8, false), + Field::new("integers", DataType::Int32, false), + Field::new("floats", DataType::Float64, false), + ], + ), + UnionMode::Dense, + ); + + let string_array = StringArray::from(vec!["foo", "bar", "baz"]); + let int_array = Int32Array::from(vec![5, 6, 4]); + let float_array = Float64Array::from(vec![10.0]); + + let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]); + let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]); + + let data = ArrayData::builder(data_type) + .len(7) + .buffers(vec![type_ids, value_offsets]) + .child_data(vec![ + string_array.into_data(), + int_array.into_data(), + float_array.into_data(), + ]) + .build() + .unwrap(); + + let array = UnionArray::from(data); + + let v = array.value(0); + assert_eq!(v.data_type(), &DataType::Int32); + assert_eq!(v.len(), 1); + assert_eq!(v.as_primitive::().value(0), 5); + + let v = array.value(1); + assert_eq!(v.data_type(), &DataType::Utf8); + assert_eq!(v.len(), 1); + assert_eq!(v.as_string::().value(0), "foo"); + + let v = array.value(2); + assert_eq!(v.data_type(), &DataType::Int32); + assert_eq!(v.len(), 1); + assert_eq!(v.as_primitive::().value(0), 6); + + let v = array.value(3); + assert_eq!(v.data_type(), &DataType::Utf8); + assert_eq!(v.len(), 1); + assert_eq!(v.as_string::().value(0), "bar"); + + let v = array.value(4); + assert_eq!(v.data_type(), &DataType::Float64); + assert_eq!(v.len(), 1); + assert_eq!(v.as_primitive::().value(0), 10.0); + + let v = array.value(5); + assert_eq!(v.data_type(), &DataType::Int32); + assert_eq!(v.len(), 1); + assert_eq!(v.as_primitive::().value(0), 4); + + let v = array.value(6); + assert_eq!(v.data_type(), &DataType::Utf8); + assert_eq!(v.len(), 1); + assert_eq!(v.as_string::().value(0), "baz"); + } } diff --git a/arrow/src/array/builder/boolean_builder.rs b/arrow-array/src/builder/boolean_builder.rs similarity index 64% rename from arrow/src/array/builder/boolean_builder.rs rename to arrow-array/src/builder/boolean_builder.rs index eed14a55fd91..7e59d940a50e 100644 --- a/arrow/src/array/builder/boolean_builder.rs +++ b/arrow-array/src/builder/boolean_builder.rs @@ -15,50 +15,45 @@ // specific language governing permissions and limitations // under the License. +use crate::builder::{ArrayBuilder, BooleanBufferBuilder}; +use crate::{ArrayRef, BooleanArray}; +use arrow_buffer::Buffer; +use arrow_buffer::NullBufferBuilder; +use arrow_data::ArrayData; +use arrow_schema::{ArrowError, DataType}; use std::any::Any; use std::sync::Arc; -use crate::array::ArrayBuilder; -use crate::array::ArrayData; -use crate::array::ArrayRef; -use crate::array::BooleanArray; -use crate::datatypes::DataType; - -use crate::error::ArrowError; -use crate::error::Result; - -use super::BooleanBufferBuilder; -use super::NullBufferBuilder; - -/// Array builder for fixed-width primitive types +/// Builder for [`BooleanArray`] /// /// # Example /// /// Create a `BooleanArray` from a `BooleanBuilder` /// /// ``` -/// use arrow::array::{Array, BooleanArray, BooleanBuilder}; /// -/// let mut b = BooleanBuilder::new(); -/// b.append_value(true); -/// b.append_null(); -/// b.append_value(false); -/// b.append_value(true); -/// let arr = b.finish(); +/// # use arrow_array::{Array, BooleanArray, builder::BooleanBuilder}; /// -/// assert_eq!(4, arr.len()); -/// assert_eq!(1, arr.null_count()); -/// assert_eq!(true, arr.value(0)); -/// assert!(arr.is_valid(0)); -/// assert!(!arr.is_null(0)); -/// assert!(!arr.is_valid(1)); -/// assert!(arr.is_null(1)); -/// assert_eq!(false, arr.value(2)); -/// assert!(arr.is_valid(2)); -/// assert!(!arr.is_null(2)); -/// assert_eq!(true, arr.value(3)); -/// assert!(arr.is_valid(3)); -/// assert!(!arr.is_null(3)); +/// let mut b = BooleanBuilder::new(); +/// b.append_value(true); +/// b.append_null(); +/// b.append_value(false); +/// b.append_value(true); +/// let arr = b.finish(); +/// +/// assert_eq!(4, arr.len()); +/// assert_eq!(1, arr.null_count()); +/// assert_eq!(true, arr.value(0)); +/// assert!(arr.is_valid(0)); +/// assert!(!arr.is_null(0)); +/// assert!(!arr.is_valid(1)); +/// assert!(arr.is_null(1)); +/// assert_eq!(false, arr.value(2)); +/// assert!(arr.is_valid(2)); +/// assert!(!arr.is_null(2)); +/// assert_eq!(true, arr.value(3)); +/// assert!(arr.is_valid(3)); +/// assert!(!arr.is_null(3)); /// ``` #[derive(Debug)] pub struct BooleanBuilder { @@ -132,7 +127,7 @@ impl BooleanBuilder { /// /// Returns an error if the slices are of different lengths #[inline] - pub fn append_values(&mut self, values: &[bool], is_valid: &[bool]) -> Result<()> { + pub fn append_values(&mut self, values: &[bool], is_valid: &[bool]) -> Result<(), ArrowError> { if values.len() != is_valid.len() { Err(ArrowError::InvalidArgumentError( "Value and validity lengths must be equal".to_string(), @@ -150,12 +145,31 @@ impl BooleanBuilder { let null_bit_buffer = self.null_buffer_builder.finish(); let builder = ArrayData::builder(DataType::Boolean) .len(len) - .add_buffer(self.values_builder.finish()) - .null_bit_buffer(null_bit_buffer); + .add_buffer(self.values_builder.finish().into_inner()) + .nulls(null_bit_buffer); + + let array_data = unsafe { builder.build_unchecked() }; + BooleanArray::from(array_data) + } + + /// Builds the [BooleanArray] without resetting the builder. + pub fn finish_cloned(&self) -> BooleanArray { + let len = self.len(); + let nulls = self.null_buffer_builder.finish_cloned(); + let value_buffer = Buffer::from_slice_ref(self.values_builder.as_slice()); + let builder = ArrayData::builder(DataType::Boolean) + .len(len) + .add_buffer(value_buffer) + .nulls(nulls); let array_data = unsafe { builder.build_unchecked() }; BooleanArray::from(array_data) } + + /// Returns the current null buffer as a slice + pub fn validity_slice(&self) -> Option<&[u8]> { + self.null_buffer_builder.as_slice() + } } impl ArrayBuilder for BooleanBuilder { @@ -179,21 +193,31 @@ impl ArrayBuilder for BooleanBuilder { self.values_builder.len() } - /// Returns whether the number of array slots is zero - fn is_empty(&self) -> bool { - self.values_builder.is_empty() - } - /// Builds the array and reset this builder. fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } +} + +impl Extend> for BooleanBuilder { + #[inline] + fn extend>>(&mut self, iter: T) { + for v in iter { + self.append_option(v) + } + } } #[cfg(test)] mod tests { use super::*; - use crate::{array::Array, buffer::Buffer}; + use crate::Array; + use arrow_buffer::Buffer; #[test] fn test_boolean_array_builder() { @@ -209,21 +233,20 @@ mod tests { } let arr = builder.finish(); - assert_eq!(&buf, arr.values()); + assert_eq!(&buf, arr.values().inner()); assert_eq!(10, arr.len()); assert_eq!(0, arr.offset()); assert_eq!(0, arr.null_count()); for i in 0..10 { assert!(!arr.is_null(i)); assert!(arr.is_valid(i)); - assert_eq!(i == 3 || i == 6 || i == 9, arr.value(i), "failed at {}", i) + assert_eq!(i == 3 || i == 6 || i == 9, arr.value(i), "failed at {i}") } } #[test] fn test_boolean_array_builder_append_slice() { - let arr1 = - BooleanArray::from(vec![Some(true), Some(false), None, None, Some(false)]); + let arr1 = BooleanArray::from(vec![Some(true), Some(false), None, None, Some(false)]); let mut builder = BooleanArray::builder(0); builder.append_slice(&[true, false]); @@ -258,6 +281,41 @@ mod tests { let array = builder.finish(); assert_eq!(0, array.null_count()); - assert!(array.data().null_buffer().is_none()); + assert!(array.nulls().is_none()); + } + + #[test] + fn test_boolean_array_builder_finish_cloned() { + let mut builder = BooleanArray::builder(16); + builder.append_option(Some(true)); + builder.append_value(false); + builder.append_slice(&[true, false, true]); + let mut array = builder.finish_cloned(); + assert_eq!(3, array.true_count()); + assert_eq!(2, array.false_count()); + + builder + .append_values(&[false, false, true], &[true, true, true]) + .unwrap(); + + array = builder.finish(); + assert_eq!(4, array.true_count()); + assert_eq!(4, array.false_count()); + + assert_eq!(0, array.null_count()); + assert!(array.nulls().is_none()); + } + + #[test] + fn test_extend() { + let mut builder = BooleanBuilder::new(); + builder.extend([false, false, true, false, false].into_iter().map(Some)); + builder.extend([true, true, false].into_iter().map(Some)); + let array = builder.finish(); + let values = array.iter().map(|x| x.unwrap()).collect::>(); + assert_eq!( + &values, + &[false, false, true, false, false, true, true, false] + ) } } diff --git a/arrow-array/src/builder/buffer_builder.rs b/arrow-array/src/builder/buffer_builder.rs new file mode 100644 index 000000000000..2b66a8187fa9 --- /dev/null +++ b/arrow-array/src/builder/buffer_builder.rs @@ -0,0 +1,226 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::array::ArrowPrimitiveType; +pub use arrow_buffer::BufferBuilder; +use half::f16; + +use crate::types::*; + +/// Buffer builder for signed 8-bit integer type. +pub type Int8BufferBuilder = BufferBuilder; +/// Buffer builder for signed 16-bit integer type. +pub type Int16BufferBuilder = BufferBuilder; +/// Buffer builder for signed 32-bit integer type. +pub type Int32BufferBuilder = BufferBuilder; +/// Buffer builder for signed 64-bit integer type. +pub type Int64BufferBuilder = BufferBuilder; +/// Buffer builder for usigned 8-bit integer type. +pub type UInt8BufferBuilder = BufferBuilder; +/// Buffer builder for usigned 16-bit integer type. +pub type UInt16BufferBuilder = BufferBuilder; +/// Buffer builder for usigned 32-bit integer type. +pub type UInt32BufferBuilder = BufferBuilder; +/// Buffer builder for usigned 64-bit integer type. +pub type UInt64BufferBuilder = BufferBuilder; +/// Buffer builder for 16-bit floating point type. +pub type Float16BufferBuilder = BufferBuilder; +/// Buffer builder for 32-bit floating point type. +pub type Float32BufferBuilder = BufferBuilder; +/// Buffer builder for 64-bit floating point type. +pub type Float64BufferBuilder = BufferBuilder; + +/// Buffer builder for 128-bit decimal type. +pub type Decimal128BufferBuilder = BufferBuilder<::Native>; +/// Buffer builder for 256-bit decimal type. +pub type Decimal256BufferBuilder = BufferBuilder<::Native>; + +/// Buffer builder for timestamp type of second unit. +pub type TimestampSecondBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for timestamp type of millisecond unit. +pub type TimestampMillisecondBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for timestamp type of microsecond unit. +pub type TimestampMicrosecondBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for timestamp type of nanosecond unit. +pub type TimestampNanosecondBufferBuilder = + BufferBuilder<::Native>; + +/// Buffer builder for 32-bit date type. +pub type Date32BufferBuilder = BufferBuilder<::Native>; +/// Buffer builder for 64-bit date type. +pub type Date64BufferBuilder = BufferBuilder<::Native>; + +/// Buffer builder for 32-bit elaspsed time since midnight of second unit. +pub type Time32SecondBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for 32-bit elaspsed time since midnight of millisecond unit. +pub type Time32MillisecondBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for 64-bit elaspsed time since midnight of microsecond unit. +pub type Time64MicrosecondBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for 64-bit elaspsed time since midnight of nanosecond unit. +pub type Time64NanosecondBufferBuilder = + BufferBuilder<::Native>; + +/// Buffer builder for “calendar” interval in months. +pub type IntervalYearMonthBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for “calendar” interval in days and milliseconds. +pub type IntervalDayTimeBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder “calendar” interval in months, days, and nanoseconds. +pub type IntervalMonthDayNanoBufferBuilder = + BufferBuilder<::Native>; + +/// Buffer builder for elaspsed time of second unit. +pub type DurationSecondBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for elaspsed time of milliseconds unit. +pub type DurationMillisecondBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for elaspsed time of microseconds unit. +pub type DurationMicrosecondBufferBuilder = + BufferBuilder<::Native>; +/// Buffer builder for elaspsed time of nanoseconds unit. +pub type DurationNanosecondBufferBuilder = + BufferBuilder<::Native>; + +#[cfg(test)] +mod tests { + use crate::builder::{ArrayBuilder, Int32BufferBuilder, Int8Builder, UInt8BufferBuilder}; + use crate::Array; + + #[test] + fn test_builder_i32_empty() { + let mut b = Int32BufferBuilder::new(5); + assert_eq!(0, b.len()); + assert_eq!(16, b.capacity()); + let a = b.finish(); + assert_eq!(0, a.len()); + } + + #[test] + fn test_builder_i32_alloc_zero_bytes() { + let mut b = Int32BufferBuilder::new(0); + b.append(123); + let a = b.finish(); + assert_eq!(4, a.len()); + } + + #[test] + fn test_builder_i32() { + let mut b = Int32BufferBuilder::new(5); + for i in 0..5 { + b.append(i); + } + assert_eq!(16, b.capacity()); + let a = b.finish(); + assert_eq!(20, a.len()); + } + + #[test] + fn test_builder_i32_grow_buffer() { + let mut b = Int32BufferBuilder::new(2); + assert_eq!(16, b.capacity()); + for i in 0..20 { + b.append(i); + } + assert_eq!(32, b.capacity()); + let a = b.finish(); + assert_eq!(80, a.len()); + } + + #[test] + fn test_builder_finish() { + let mut b = Int32BufferBuilder::new(5); + assert_eq!(16, b.capacity()); + for i in 0..10 { + b.append(i); + } + let mut a = b.finish(); + assert_eq!(40, a.len()); + assert_eq!(0, b.len()); + assert_eq!(0, b.capacity()); + + // Try build another buffer after cleaning up. + for i in 0..20 { + b.append(i) + } + assert_eq!(32, b.capacity()); + a = b.finish(); + assert_eq!(80, a.len()); + } + + #[test] + fn test_reserve() { + let mut b = UInt8BufferBuilder::new(2); + assert_eq!(64, b.capacity()); + b.reserve(64); + assert_eq!(64, b.capacity()); + b.reserve(65); + assert_eq!(128, b.capacity()); + + let mut b = Int32BufferBuilder::new(2); + assert_eq!(16, b.capacity()); + b.reserve(16); + assert_eq!(16, b.capacity()); + b.reserve(17); + assert_eq!(32, b.capacity()); + } + + #[test] + fn test_append_slice() { + let mut b = UInt8BufferBuilder::new(0); + b.append_slice(b"Hello, "); + b.append_slice(b"World!"); + let buffer = b.finish(); + assert_eq!(13, buffer.len()); + + let mut b = Int32BufferBuilder::new(0); + b.append_slice(&[32, 54]); + let buffer = b.finish(); + assert_eq!(8, buffer.len()); + } + + #[test] + fn test_append_values() { + let mut a = Int8Builder::new(); + a.append_value(1); + a.append_null(); + a.append_value(-2); + assert_eq!(a.len(), 3); + + // append values + let values = &[1, 2, 3, 4]; + let is_valid = &[true, true, false, true]; + a.append_values(values, is_valid); + + assert_eq!(a.len(), 7); + let array = a.finish(); + assert_eq!(array.value(0), 1); + assert!(array.is_null(1)); + assert_eq!(array.value(2), -2); + assert_eq!(array.value(3), 1); + assert_eq!(array.value(4), 2); + assert!(array.is_null(5)); + assert_eq!(array.value(6), 4); + } +} diff --git a/arrow/src/array/builder/fixed_size_binary_builder.rs b/arrow-array/src/builder/fixed_size_binary_builder.rs similarity index 65% rename from arrow/src/array/builder/fixed_size_binary_builder.rs rename to arrow-array/src/builder/fixed_size_binary_builder.rs index 30c25e0a62b9..0a50eb8a50e9 100644 --- a/arrow/src/array/builder/fixed_size_binary_builder.rs +++ b/arrow-array/src/builder/fixed_size_binary_builder.rs @@ -15,16 +15,31 @@ // specific language governing permissions and limitations // under the License. -use crate::array::{ - ArrayBuilder, ArrayData, ArrayRef, FixedSizeBinaryArray, UInt8BufferBuilder, -}; -use crate::datatypes::DataType; -use crate::error::{ArrowError, Result}; +use crate::builder::{ArrayBuilder, UInt8BufferBuilder}; +use crate::{ArrayRef, FixedSizeBinaryArray}; +use arrow_buffer::Buffer; +use arrow_buffer::NullBufferBuilder; +use arrow_data::ArrayData; +use arrow_schema::{ArrowError, DataType}; use std::any::Any; use std::sync::Arc; -use super::NullBufferBuilder; - +/// Builder for [`FixedSizeBinaryArray`] +/// ``` +/// # use arrow_array::builder::FixedSizeBinaryBuilder; +/// # use arrow_array::Array; +/// # +/// let mut builder = FixedSizeBinaryBuilder::with_capacity(3, 5); +/// // [b"hello", null, b"arrow"] +/// builder.append_value(b"hello").unwrap(); +/// builder.append_null(); +/// builder.append_value(b"arrow").unwrap(); +/// +/// let array = builder.finish(); +/// assert_eq!(array.value(0), b"hello"); +/// assert!(array.is_null(1)); +/// assert_eq!(array.value(2), b"arrow"); +/// ``` #[derive(Debug)] pub struct FixedSizeBinaryBuilder { values_builder: UInt8BufferBuilder, @@ -43,8 +58,7 @@ impl FixedSizeBinaryBuilder { pub fn with_capacity(capacity: usize, byte_width: i32) -> Self { assert!( byte_width >= 0, - "value length ({}) of the array must >= 0", - byte_width + "value length ({byte_width}) of the array must >= 0" ); Self { values_builder: UInt8BufferBuilder::new(capacity * byte_width as usize), @@ -58,10 +72,11 @@ impl FixedSizeBinaryBuilder { /// Automatically update the null buffer to delimit the slice appended in as a /// distinct value element. #[inline] - pub fn append_value(&mut self, value: impl AsRef<[u8]>) -> Result<()> { + pub fn append_value(&mut self, value: impl AsRef<[u8]>) -> Result<(), ArrowError> { if self.value_length != value.as_ref().len() as i32 { Err(ArrowError::InvalidArgumentError( - "Byte slice does not have the same length as FixedSizeBinaryBuilder value lengths".to_string() + "Byte slice does not have the same length as FixedSizeBinaryBuilder value lengths" + .to_string(), )) } else { self.values_builder.append_slice(value.as_ref()); @@ -81,11 +96,22 @@ impl FixedSizeBinaryBuilder { /// Builds the [`FixedSizeBinaryArray`] and reset this builder. pub fn finish(&mut self) -> FixedSizeBinaryArray { let array_length = self.len(); - let array_data_builder = - ArrayData::builder(DataType::FixedSizeBinary(self.value_length)) - .add_buffer(self.values_builder.finish()) - .null_bit_buffer(self.null_buffer_builder.finish()) - .len(array_length); + let array_data_builder = ArrayData::builder(DataType::FixedSizeBinary(self.value_length)) + .add_buffer(self.values_builder.finish()) + .nulls(self.null_buffer_builder.finish()) + .len(array_length); + let array_data = unsafe { array_data_builder.build_unchecked() }; + FixedSizeBinaryArray::from(array_data) + } + + /// Builds the [`FixedSizeBinaryArray`] without resetting the builder. + pub fn finish_cloned(&self) -> FixedSizeBinaryArray { + let array_length = self.len(); + let values_buffer = Buffer::from_slice_ref(self.values_builder.as_slice()); + let array_data_builder = ArrayData::builder(DataType::FixedSizeBinary(self.value_length)) + .add_buffer(values_buffer) + .nulls(self.null_buffer_builder.finish_cloned()) + .len(array_length); let array_data = unsafe { array_data_builder.build_unchecked() }; FixedSizeBinaryArray::from(array_data) } @@ -112,24 +138,24 @@ impl ArrayBuilder for FixedSizeBinaryBuilder { self.null_buffer_builder.len() } - /// Returns whether the number of array slots is zero - fn is_empty(&self) -> bool { - self.null_buffer_builder.is_empty() - } - /// Builds the array and reset this builder. fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } } #[cfg(test)] mod tests { use super::*; - use crate::array::Array; - use crate::array::FixedSizeBinaryArray; - use crate::datatypes::DataType; + use crate::Array; + use crate::FixedSizeBinaryArray; + use arrow_schema::DataType; #[test] fn test_fixed_size_binary_builder() { @@ -148,6 +174,36 @@ mod tests { assert_eq!(5, array.value_length()); } + #[test] + fn test_fixed_size_binary_builder_finish_cloned() { + let mut builder = FixedSizeBinaryBuilder::with_capacity(3, 5); + + // [b"hello", null, "arrow"] + builder.append_value(b"hello").unwrap(); + builder.append_null(); + builder.append_value(b"arrow").unwrap(); + let mut array: FixedSizeBinaryArray = builder.finish_cloned(); + + assert_eq!(&DataType::FixedSizeBinary(5), array.data_type()); + assert_eq!(3, array.len()); + assert_eq!(1, array.null_count()); + assert_eq!(10, array.value_offset(2)); + assert_eq!(5, array.value_length()); + + // [b"finis", null, "clone"] + builder.append_value(b"finis").unwrap(); + builder.append_null(); + builder.append_value(b"clone").unwrap(); + + array = builder.finish(); + + assert_eq!(&DataType::FixedSizeBinary(5), array.data_type()); + assert_eq!(6, array.len()); + assert_eq!(2, array.null_count()); + assert_eq!(25, array.value_offset(5)); + assert_eq!(5, array.value_length()); + } + #[test] fn test_fixed_size_binary_builder_with_zero_value_length() { let mut builder = FixedSizeBinaryBuilder::new(0); diff --git a/arrow/src/array/builder/fixed_size_list_builder.rs b/arrow-array/src/builder/fixed_size_list_builder.rs similarity index 58% rename from arrow/src/array/builder/fixed_size_list_builder.rs rename to arrow-array/src/builder/fixed_size_list_builder.rs index da850d156243..0fe779d5c1a2 100644 --- a/arrow/src/array/builder/fixed_size_list_builder.rs +++ b/arrow-array/src/builder/fixed_size_list_builder.rs @@ -15,19 +15,53 @@ // specific language governing permissions and limitations // under the License. +use crate::builder::ArrayBuilder; +use crate::{ArrayRef, FixedSizeListArray}; +use arrow_buffer::NullBufferBuilder; +use arrow_data::ArrayData; +use arrow_schema::{DataType, Field}; use std::any::Any; use std::sync::Arc; -use crate::array::ArrayData; -use crate::array::ArrayRef; -use crate::array::FixedSizeListArray; -use crate::datatypes::DataType; -use crate::datatypes::Field; - -use super::ArrayBuilder; -use super::NullBufferBuilder; - -/// Array builder for [`FixedSizeListArray`] +/// Builder for [`FixedSizeListArray`] +/// ``` +/// use arrow_array::{builder::{Int32Builder, FixedSizeListBuilder}, Array, Int32Array}; +/// let values_builder = Int32Builder::new(); +/// let mut builder = FixedSizeListBuilder::new(values_builder, 3); +/// +/// // [[0, 1, 2], null, [3, null, 5], [6, 7, null]] +/// builder.values().append_value(0); +/// builder.values().append_value(1); +/// builder.values().append_value(2); +/// builder.append(true); +/// builder.values().append_null(); +/// builder.values().append_null(); +/// builder.values().append_null(); +/// builder.append(false); +/// builder.values().append_value(3); +/// builder.values().append_null(); +/// builder.values().append_value(5); +/// builder.append(true); +/// builder.values().append_value(6); +/// builder.values().append_value(7); +/// builder.values().append_null(); +/// builder.append(true); +/// let list_array = builder.finish(); +/// assert_eq!( +/// *list_array.value(0), +/// Int32Array::from(vec![Some(0), Some(1), Some(2)]) +/// ); +/// assert!(list_array.is_null(1)); +/// assert_eq!( +/// *list_array.value(2), +/// Int32Array::from(vec![Some(3), None, Some(5)]) +/// ); +/// assert_eq!( +/// *list_array.value(3), +/// Int32Array::from(vec![Some(6), Some(7), None]) +/// ) +/// ``` +/// #[derive(Debug)] pub struct FixedSizeListBuilder { null_buffer_builder: NullBufferBuilder, @@ -39,7 +73,11 @@ impl FixedSizeListBuilder { /// Creates a new [`FixedSizeListBuilder`] from a given values array builder /// `value_length` is the number of values within each array pub fn new(values_builder: T, value_length: i32) -> Self { - let capacity = values_builder.len(); + let capacity = values_builder + .len() + .checked_div(value_length as _) + .unwrap_or_default(); + Self::with_capacity(values_builder, value_length, capacity) } @@ -79,15 +117,15 @@ where self.null_buffer_builder.len() } - /// Returns whether the number of array slots is zero - fn is_empty(&self) -> bool { - self.null_buffer_builder.is_empty() - } - /// Builds the array and reset this builder. fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } } impl FixedSizeListBuilder @@ -102,6 +140,7 @@ where &mut self.values_builder } + /// Returns the length of the list pub fn value_length(&self) -> i32 { self.list_len } @@ -115,30 +154,53 @@ where /// Builds the [`FixedSizeListBuilder`] and reset this builder. pub fn finish(&mut self) -> FixedSizeListArray { let len = self.len(); - let values_arr = self - .values_builder - .as_any_mut() - .downcast_mut::() - .unwrap() - .finish(); - let values_data = values_arr.data(); - - assert!( - values_data.len() == len * self.list_len as usize, + let values_arr = self.values_builder.finish(); + let values_data = values_arr.to_data(); + + assert_eq!( + values_data.len(), len * self.list_len as usize, "Length of the child array ({}) must be the multiple of the value length ({}) and the array length ({}).", values_data.len(), self.list_len, len, ); - let null_bit_buffer = self.null_buffer_builder.finish(); + let nulls = self.null_buffer_builder.finish(); let array_data = ArrayData::builder(DataType::FixedSizeList( - Box::new(Field::new("item", values_data.data_type().clone(), true)), + Arc::new(Field::new("item", values_data.data_type().clone(), true)), self.list_len, )) .len(len) - .add_child_data(values_data.clone()) - .null_bit_buffer(null_bit_buffer); + .add_child_data(values_data) + .nulls(nulls); + + let array_data = unsafe { array_data.build_unchecked() }; + + FixedSizeListArray::from(array_data) + } + + /// Builds the [`FixedSizeListBuilder`] without resetting the builder. + pub fn finish_cloned(&self) -> FixedSizeListArray { + let len = self.len(); + let values_arr = self.values_builder.finish_cloned(); + let values_data = values_arr.to_data(); + + assert_eq!( + values_data.len(), len * self.list_len as usize, + "Length of the child array ({}) must be the multiple of the value length ({}) and the array length ({}).", + values_data.len(), + self.list_len, + len, + ); + + let nulls = self.null_buffer_builder.finish_cloned(); + let array_data = ArrayData::builder(DataType::FixedSizeList( + Arc::new(Field::new("item", values_data.data_type().clone(), true)), + self.list_len, + )) + .len(len) + .add_child_data(values_data) + .nulls(nulls); let array_data = unsafe { array_data.build_unchecked() }; @@ -150,9 +212,9 @@ where mod tests { use super::*; - use crate::array::Array; - use crate::array::Int32Array; - use crate::array::Int32Builder; + use crate::builder::Int32Builder; + use crate::Array; + use crate::Int32Array; #[test] fn test_fixed_size_list_array_builder() { @@ -185,6 +247,48 @@ mod tests { assert_eq!(3, list_array.value_length()); } + #[test] + fn test_fixed_size_list_array_builder_finish_cloned() { + let values_builder = Int32Builder::new(); + let mut builder = FixedSizeListBuilder::new(values_builder, 3); + + // [[0, 1, 2], null, [3, null, 5], [6, 7, null]] + builder.values().append_value(0); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true); + builder.values().append_null(); + builder.values().append_null(); + builder.values().append_null(); + builder.append(false); + builder.values().append_value(3); + builder.values().append_null(); + builder.values().append_value(5); + builder.append(true); + let mut list_array = builder.finish_cloned(); + + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(3, list_array.len()); + assert_eq!(1, list_array.null_count()); + assert_eq!(3, list_array.value_length()); + + builder.values().append_value(6); + builder.values().append_value(7); + builder.values().append_null(); + builder.append(true); + builder.values().append_null(); + builder.values().append_null(); + builder.values().append_null(); + builder.append(false); + list_array = builder.finish(); + + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(5, list_array.len()); + assert_eq!(2, list_array.null_count()); + assert_eq!(6, list_array.value_offset(2)); + assert_eq!(3, list_array.value_length()); + } + #[test] fn test_fixed_size_list_array_builder_empty() { let values_builder = Int32Array::builder(5); diff --git a/arrow-array/src/builder/generic_byte_run_builder.rs b/arrow-array/src/builder/generic_byte_run_builder.rs new file mode 100644 index 000000000000..3cde76c4a039 --- /dev/null +++ b/arrow-array/src/builder/generic_byte_run_builder.rs @@ -0,0 +1,514 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::types::bytes::ByteArrayNativeType; +use std::{any::Any, sync::Arc}; + +use crate::{ + types::{BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, RunEndIndexType, Utf8Type}, + ArrayRef, ArrowPrimitiveType, RunArray, +}; + +use super::{ArrayBuilder, GenericByteBuilder, PrimitiveBuilder}; + +use arrow_buffer::ArrowNativeType; + +/// Builder for [`RunArray`] of [`GenericByteArray`](crate::array::GenericByteArray) +/// +/// # Example: +/// +/// ``` +/// +/// # use arrow_array::builder::GenericByteRunBuilder; +/// # use arrow_array::{GenericByteArray, BinaryArray}; +/// # use arrow_array::types::{BinaryType, Int16Type}; +/// # use arrow_array::{Array, Int16Array}; +/// # use arrow_array::cast::AsArray; +/// +/// let mut builder = +/// GenericByteRunBuilder::::new(); +/// builder.extend([Some(b"abc"), Some(b"abc"), None, Some(b"def")].into_iter()); +/// builder.append_value(b"def"); +/// builder.append_null(); +/// let array = builder.finish(); +/// +/// assert_eq!(array.run_ends().values(), &[2, 3, 5, 6]); +/// +/// let av = array.values(); +/// +/// assert!(!av.is_null(0)); +/// assert!(av.is_null(1)); +/// assert!(!av.is_null(2)); +/// assert!(av.is_null(3)); +/// +/// // Values are polymorphic and so require a downcast. +/// let ava: &BinaryArray = av.as_binary(); +/// +/// assert_eq!(ava.value(0), b"abc"); +/// assert_eq!(ava.value(2), b"def"); +/// ``` +#[derive(Debug)] +pub struct GenericByteRunBuilder +where + R: ArrowPrimitiveType, + V: ByteArrayType, +{ + run_ends_builder: PrimitiveBuilder, + values_builder: GenericByteBuilder, + current_value: Vec, + has_current_value: bool, + current_run_end_index: usize, + prev_run_end_index: usize, +} + +impl Default for GenericByteRunBuilder +where + R: ArrowPrimitiveType, + V: ByteArrayType, +{ + fn default() -> Self { + Self::new() + } +} + +impl GenericByteRunBuilder +where + R: ArrowPrimitiveType, + V: ByteArrayType, +{ + /// Creates a new `GenericByteRunBuilder` + pub fn new() -> Self { + Self { + run_ends_builder: PrimitiveBuilder::new(), + values_builder: GenericByteBuilder::::new(), + current_value: Vec::new(), + has_current_value: false, + current_run_end_index: 0, + prev_run_end_index: 0, + } + } + + /// Creates a new `GenericByteRunBuilder` with the provided capacity + /// + /// `capacity`: the expected number of run-end encoded values. + /// `data_capacity`: the expected number of bytes of run end encoded values + pub fn with_capacity(capacity: usize, data_capacity: usize) -> Self { + Self { + run_ends_builder: PrimitiveBuilder::with_capacity(capacity), + values_builder: GenericByteBuilder::::with_capacity(capacity, data_capacity), + current_value: Vec::new(), + has_current_value: false, + current_run_end_index: 0, + prev_run_end_index: 0, + } + } +} + +impl ArrayBuilder for GenericByteRunBuilder +where + R: RunEndIndexType, + V: ByteArrayType, +{ + /// Returns the builder as a non-mutable `Any` reference. + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns the builder as a mutable `Any` reference. + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + /// Returns the boxed builder as a box of `Any`. + fn into_box_any(self: Box) -> Box { + self + } + + /// Returns the length of logical array encoded by + /// the eventual runs array. + fn len(&self) -> usize { + self.current_run_end_index + } + + /// Builds the array and reset this builder. + fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } +} + +impl GenericByteRunBuilder +where + R: RunEndIndexType, + V: ByteArrayType, +{ + /// Appends optional value to the logical array encoded by the RunArray. + pub fn append_option(&mut self, input_value: Option>) { + match input_value { + Some(value) => self.append_value(value), + None => self.append_null(), + } + } + + /// Appends value to the logical array encoded by the RunArray. + pub fn append_value(&mut self, input_value: impl AsRef) { + let value: &[u8] = input_value.as_ref().as_ref(); + if !self.has_current_value { + self.append_run_end(); + self.current_value.extend_from_slice(value); + self.has_current_value = true; + } else if self.current_value.as_slice() != value { + self.append_run_end(); + self.current_value.clear(); + self.current_value.extend_from_slice(value); + } + self.current_run_end_index += 1; + } + + /// Appends null to the logical array encoded by the RunArray. + pub fn append_null(&mut self) { + if self.has_current_value { + self.append_run_end(); + self.current_value.clear(); + self.has_current_value = false; + } + self.current_run_end_index += 1; + } + + /// Creates the RunArray and resets the builder. + /// Panics if RunArray cannot be built. + pub fn finish(&mut self) -> RunArray { + // write the last run end to the array. + self.append_run_end(); + + // reset the run end index to zero. + self.current_value.clear(); + self.has_current_value = false; + self.current_run_end_index = 0; + self.prev_run_end_index = 0; + + // build the run encoded array by adding run_ends and values array as its children. + let run_ends_array = self.run_ends_builder.finish(); + let values_array = self.values_builder.finish(); + RunArray::::try_new(&run_ends_array, &values_array).unwrap() + } + + /// Creates the RunArray and without resetting the builder. + /// Panics if RunArray cannot be built. + pub fn finish_cloned(&self) -> RunArray { + let mut run_ends_array = self.run_ends_builder.finish_cloned(); + let mut values_array = self.values_builder.finish_cloned(); + + // Add current run if one exists + if self.prev_run_end_index != self.current_run_end_index { + let mut run_end_builder = run_ends_array.into_builder().unwrap(); + let mut values_builder = values_array.into_builder().unwrap(); + self.append_run_end_with_builders(&mut run_end_builder, &mut values_builder); + run_ends_array = run_end_builder.finish(); + values_array = values_builder.finish(); + } + + RunArray::::try_new(&run_ends_array, &values_array).unwrap() + } + + // Appends the current run to the array. + fn append_run_end(&mut self) { + // empty array or the function called without appending any value. + if self.prev_run_end_index == self.current_run_end_index { + return; + } + let run_end_index = self.run_end_index_as_native(); + self.run_ends_builder.append_value(run_end_index); + if self.has_current_value { + let slice = self.current_value.as_slice(); + let native = unsafe { + // Safety: + // As self.current_value is created from V::Native. The value V::Native can be + // built back from the bytes without validations + V::Native::from_bytes_unchecked(slice) + }; + self.values_builder.append_value(native); + } else { + self.values_builder.append_null(); + } + self.prev_run_end_index = self.current_run_end_index; + } + + // Similar to `append_run_end` but on custom builders. + // Used in `finish_cloned` which is not suppose to mutate `self`. + fn append_run_end_with_builders( + &self, + run_ends_builder: &mut PrimitiveBuilder, + values_builder: &mut GenericByteBuilder, + ) { + let run_end_index = self.run_end_index_as_native(); + run_ends_builder.append_value(run_end_index); + if self.has_current_value { + let slice = self.current_value.as_slice(); + let native = unsafe { + // Safety: + // As self.current_value is created from V::Native. The value V::Native can be + // built back from the bytes without validations + V::Native::from_bytes_unchecked(slice) + }; + values_builder.append_value(native); + } else { + values_builder.append_null(); + } + } + + fn run_end_index_as_native(&self) -> R::Native { + R::Native::from_usize(self.current_run_end_index).unwrap_or_else(|| { + panic!( + "Cannot convert the value {} from `usize` to native form of arrow datatype {}", + self.current_run_end_index, + R::DATA_TYPE + ) + }) + } +} + +impl Extend> for GenericByteRunBuilder +where + R: RunEndIndexType, + V: ByteArrayType, + S: AsRef, +{ + fn extend>>(&mut self, iter: T) { + for elem in iter { + self.append_option(elem); + } + } +} + +/// Builder for [`RunArray`] of [`StringArray`](crate::array::StringArray) +/// +/// ``` +/// // Create a run-end encoded array with run-end indexes data type as `i16`. +/// // The encoded values are Strings. +/// +/// # use arrow_array::builder::StringRunBuilder; +/// # use arrow_array::{Int16Array, StringArray}; +/// # use arrow_array::types::Int16Type; +/// # use arrow_array::cast::AsArray; +/// # +/// let mut builder = StringRunBuilder::::new(); +/// +/// // The builder builds the dictionary value by value +/// builder.append_value("abc"); +/// builder.append_null(); +/// builder.extend([Some("def"), Some("def"), Some("abc")]); +/// let array = builder.finish(); +/// +/// assert_eq!(array.run_ends().values(), &[1, 2, 4, 5]); +/// +/// // Values are polymorphic and so require a downcast. +/// let av = array.values(); +/// let ava: &StringArray = av.as_string::(); +/// +/// assert_eq!(ava.value(0), "abc"); +/// assert!(av.is_null(1)); +/// assert_eq!(ava.value(2), "def"); +/// assert_eq!(ava.value(3), "abc"); +/// +/// ``` +pub type StringRunBuilder = GenericByteRunBuilder; + +/// Builder for [`RunArray`] of [`LargeStringArray`](crate::array::LargeStringArray) +pub type LargeStringRunBuilder = GenericByteRunBuilder; + +/// Builder for [`RunArray`] of [`BinaryArray`](crate::array::BinaryArray) +/// +/// ``` +/// // Create a run-end encoded array with run-end indexes data type as `i16`. +/// // The encoded data is binary values. +/// +/// # use arrow_array::builder::BinaryRunBuilder; +/// # use arrow_array::{BinaryArray, Int16Array}; +/// # use arrow_array::cast::AsArray; +/// # use arrow_array::types::Int16Type; +/// +/// let mut builder = BinaryRunBuilder::::new(); +/// +/// // The builder builds the dictionary value by value +/// builder.append_value(b"abc"); +/// builder.append_null(); +/// builder.extend([Some(b"def"), Some(b"def"), Some(b"abc")]); +/// let array = builder.finish(); +/// +/// assert_eq!(array.run_ends().values(), &[1, 2, 4, 5]); +/// +/// // Values are polymorphic and so require a downcast. +/// let av = array.values(); +/// let ava: &BinaryArray = av.as_binary(); +/// +/// assert_eq!(ava.value(0), b"abc"); +/// assert!(av.is_null(1)); +/// assert_eq!(ava.value(2), b"def"); +/// assert_eq!(ava.value(3), b"abc"); +/// +/// ``` +pub type BinaryRunBuilder = GenericByteRunBuilder; + +/// Builder for [`RunArray`] of [`LargeBinaryArray`](crate::array::LargeBinaryArray) +pub type LargeBinaryRunBuilder = GenericByteRunBuilder; + +#[cfg(test)] +mod tests { + use super::*; + + use crate::array::Array; + use crate::cast::AsArray; + use crate::types::{Int16Type, Int32Type}; + use crate::GenericByteArray; + use crate::Int16RunArray; + + fn test_bytes_run_builder(values: Vec<&T::Native>) + where + T: ByteArrayType, + ::Native: PartialEq, + ::Native: AsRef<::Native>, + { + let mut builder = GenericByteRunBuilder::::new(); + builder.append_value(values[0]); + builder.append_value(values[0]); + builder.append_value(values[0]); + builder.append_null(); + builder.append_null(); + builder.append_value(values[1]); + builder.append_value(values[1]); + builder.append_value(values[2]); + builder.append_value(values[2]); + builder.append_value(values[2]); + builder.append_value(values[2]); + let array = builder.finish(); + + assert_eq!(array.len(), 11); + assert_eq!(array.null_count(), 0); + + assert_eq!(array.run_ends().values(), &[3, 5, 7, 11]); + + // Values are polymorphic and so require a downcast. + let av = array.values(); + let ava: &GenericByteArray = av.as_any().downcast_ref::>().unwrap(); + + assert_eq!(*ava.value(0), *values[0]); + assert!(ava.is_null(1)); + assert_eq!(*ava.value(2), *values[1]); + assert_eq!(*ava.value(3), *values[2]); + } + + #[test] + fn test_string_run_builder() { + test_bytes_run_builder::(vec!["abc", "def", "ghi"]); + } + + #[test] + fn test_string_run_builder_with_empty_strings() { + test_bytes_run_builder::(vec!["abc", "", "ghi"]); + } + + #[test] + fn test_binary_run_builder() { + test_bytes_run_builder::(vec![b"abc", b"def", b"ghi"]); + } + + fn test_bytes_run_builder_finish_cloned(values: Vec<&T::Native>) + where + T: ByteArrayType, + ::Native: PartialEq, + ::Native: AsRef<::Native>, + { + let mut builder = GenericByteRunBuilder::::new(); + + builder.append_value(values[0]); + builder.append_null(); + builder.append_value(values[1]); + builder.append_value(values[1]); + builder.append_value(values[0]); + let mut array: Int16RunArray = builder.finish_cloned(); + + assert_eq!(array.len(), 5); + assert_eq!(array.null_count(), 0); + + assert_eq!(array.run_ends().values(), &[1, 2, 4, 5]); + + // Values are polymorphic and so require a downcast. + let av = array.values(); + let ava: &GenericByteArray = av.as_any().downcast_ref::>().unwrap(); + + assert_eq!(ava.value(0), values[0]); + assert!(ava.is_null(1)); + assert_eq!(ava.value(2), values[1]); + assert_eq!(ava.value(3), values[0]); + + // Append last value before `finish_cloned` (`value[0]`) again and ensure it has only + // one entry in final output. + builder.append_value(values[0]); + builder.append_value(values[0]); + builder.append_value(values[1]); + array = builder.finish(); + + assert_eq!(array.len(), 8); + assert_eq!(array.null_count(), 0); + + assert_eq!(array.run_ends().values(), &[1, 2, 4, 7, 8]); + + // Values are polymorphic and so require a downcast. + let av2 = array.values(); + let ava2: &GenericByteArray = + av2.as_any().downcast_ref::>().unwrap(); + + assert_eq!(ava2.value(0), values[0]); + assert!(ava2.is_null(1)); + assert_eq!(ava2.value(2), values[1]); + // The value appended before and after `finish_cloned` has only one entry. + assert_eq!(ava2.value(3), values[0]); + assert_eq!(ava2.value(4), values[1]); + } + + #[test] + fn test_string_run_builder_finish_cloned() { + test_bytes_run_builder_finish_cloned::(vec!["abc", "def", "ghi"]); + } + + #[test] + fn test_binary_run_builder_finish_cloned() { + test_bytes_run_builder_finish_cloned::(vec![b"abc", b"def", b"ghi"]); + } + + #[test] + fn test_extend() { + let mut builder = StringRunBuilder::::new(); + builder.extend(["a", "a", "a", "", "", "b", "b"].into_iter().map(Some)); + builder.extend(["b", "cupcakes", "cupcakes"].into_iter().map(Some)); + let array = builder.finish(); + + assert_eq!(array.len(), 10); + assert_eq!(array.run_ends().values(), &[3, 5, 8, 10]); + + let str_array = array.values().as_string::(); + assert_eq!(str_array.value(0), "a"); + assert_eq!(str_array.value(1), ""); + assert_eq!(str_array.value(2), "b"); + assert_eq!(str_array.value(3), "cupcakes"); + } +} diff --git a/arrow-array/src/builder/generic_bytes_builder.rs b/arrow-array/src/builder/generic_bytes_builder.rs new file mode 100644 index 000000000000..2c7ee7a3e448 --- /dev/null +++ b/arrow-array/src/builder/generic_bytes_builder.rs @@ -0,0 +1,482 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::builder::{ArrayBuilder, BufferBuilder, UInt8BufferBuilder}; +use crate::types::{ByteArrayType, GenericBinaryType, GenericStringType}; +use crate::{ArrayRef, GenericByteArray, OffsetSizeTrait}; +use arrow_buffer::NullBufferBuilder; +use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer}; +use arrow_data::ArrayDataBuilder; +use std::any::Any; +use std::fmt::Write; +use std::sync::Arc; + +/// Builder for [`GenericByteArray`] +pub struct GenericByteBuilder { + value_builder: UInt8BufferBuilder, + offsets_builder: BufferBuilder, + null_buffer_builder: NullBufferBuilder, +} + +impl GenericByteBuilder { + /// Creates a new [`GenericByteBuilder`]. + pub fn new() -> Self { + Self::with_capacity(1024, 1024) + } + + /// Creates a new [`GenericByteBuilder`]. + /// + /// - `item_capacity` is the number of items to pre-allocate. + /// The size of the preallocated buffer of offsets is the number of items plus one. + /// - `data_capacity` is the total number of bytes of data to pre-allocate + /// (for all items, not per item). + pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { + let mut offsets_builder = BufferBuilder::::new(item_capacity + 1); + offsets_builder.append(T::Offset::from_usize(0).unwrap()); + Self { + value_builder: UInt8BufferBuilder::new(data_capacity), + offsets_builder, + null_buffer_builder: NullBufferBuilder::new(item_capacity), + } + } + + /// Creates a new [`GenericByteBuilder`] from buffers. + /// + /// # Safety + /// This doesn't verify buffer contents as it assumes the buffers are from existing and + /// valid [`GenericByteArray`]. + pub unsafe fn new_from_buffer( + offsets_buffer: MutableBuffer, + value_buffer: MutableBuffer, + null_buffer: Option, + ) -> Self { + let offsets_builder = BufferBuilder::::new_from_buffer(offsets_buffer); + let value_builder = BufferBuilder::::new_from_buffer(value_buffer); + + let null_buffer_builder = null_buffer + .map(|buffer| NullBufferBuilder::new_from_buffer(buffer, offsets_builder.len() - 1)) + .unwrap_or_else(|| NullBufferBuilder::new_with_len(offsets_builder.len() - 1)); + + Self { + offsets_builder, + value_builder, + null_buffer_builder, + } + } + + #[inline] + fn next_offset(&self) -> T::Offset { + T::Offset::from_usize(self.value_builder.len()).expect("byte array offset overflow") + } + + /// Appends a value into the builder. + /// + /// # Panics + /// + /// Panics if the resulting length of [`Self::values_slice`] would exceed `T::Offset::MAX` + #[inline] + pub fn append_value(&mut self, value: impl AsRef) { + self.value_builder.append_slice(value.as_ref().as_ref()); + self.null_buffer_builder.append(true); + self.offsets_builder.append(self.next_offset()); + } + + /// Append an `Option` value into the builder. + #[inline] + pub fn append_option(&mut self, value: Option>) { + match value { + None => self.append_null(), + Some(v) => self.append_value(v), + }; + } + + /// Append a null value into the builder. + #[inline] + pub fn append_null(&mut self) { + self.null_buffer_builder.append(false); + self.offsets_builder.append(self.next_offset()); + } + + /// Builds the [`GenericByteArray`] and reset this builder. + pub fn finish(&mut self) -> GenericByteArray { + let array_type = T::DATA_TYPE; + let array_builder = ArrayDataBuilder::new(array_type) + .len(self.len()) + .add_buffer(self.offsets_builder.finish()) + .add_buffer(self.value_builder.finish()) + .nulls(self.null_buffer_builder.finish()); + + self.offsets_builder.append(self.next_offset()); + let array_data = unsafe { array_builder.build_unchecked() }; + GenericByteArray::from(array_data) + } + + /// Builds the [`GenericByteArray`] without resetting the builder. + pub fn finish_cloned(&self) -> GenericByteArray { + let array_type = T::DATA_TYPE; + let offset_buffer = Buffer::from_slice_ref(self.offsets_builder.as_slice()); + let value_buffer = Buffer::from_slice_ref(self.value_builder.as_slice()); + let array_builder = ArrayDataBuilder::new(array_type) + .len(self.len()) + .add_buffer(offset_buffer) + .add_buffer(value_buffer) + .nulls(self.null_buffer_builder.finish_cloned()); + + let array_data = unsafe { array_builder.build_unchecked() }; + GenericByteArray::from(array_data) + } + + /// Returns the current values buffer as a slice + pub fn values_slice(&self) -> &[u8] { + self.value_builder.as_slice() + } + + /// Returns the current offsets buffer as a slice + pub fn offsets_slice(&self) -> &[T::Offset] { + self.offsets_builder.as_slice() + } + + /// Returns the current null buffer as a slice + pub fn validity_slice(&self) -> Option<&[u8]> { + self.null_buffer_builder.as_slice() + } + + /// Returns the current null buffer as a mutable slice + pub fn validity_slice_mut(&mut self) -> Option<&mut [u8]> { + self.null_buffer_builder.as_slice_mut() + } +} + +impl std::fmt::Debug for GenericByteBuilder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}{}Builder", T::Offset::PREFIX, T::PREFIX)?; + f.debug_struct("") + .field("value_builder", &self.value_builder) + .field("offsets_builder", &self.offsets_builder) + .field("null_buffer_builder", &self.null_buffer_builder) + .finish() + } +} + +impl Default for GenericByteBuilder { + fn default() -> Self { + Self::new() + } +} + +impl ArrayBuilder for GenericByteBuilder { + /// Returns the number of binary slots in the builder + fn len(&self) -> usize { + self.null_buffer_builder.len() + } + + /// Builds the array and reset this builder. + fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } + + /// Returns the builder as a non-mutable `Any` reference. + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns the builder as a mutable `Any` reference. + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + /// Returns the boxed builder as a box of `Any`. + fn into_box_any(self: Box) -> Box { + self + } +} + +impl> Extend> for GenericByteBuilder { + #[inline] + fn extend>>(&mut self, iter: I) { + for v in iter { + self.append_option(v) + } + } +} + +/// Array builder for [`GenericStringArray`][crate::GenericStringArray] +/// +/// Values can be appended using [`GenericByteBuilder::append_value`], and nulls with +/// [`GenericByteBuilder::append_null`] as normal. +/// +/// Additionally implements [`std::fmt::Write`] with any written data included in the next +/// appended value. This allows use with [`std::fmt::Display`] without intermediate allocations +/// +/// ``` +/// # use std::fmt::Write; +/// # use arrow_array::builder::GenericStringBuilder; +/// let mut builder = GenericStringBuilder::::new(); +/// +/// // Write data +/// write!(builder, "foo").unwrap(); +/// write!(builder, "bar").unwrap(); +/// +/// // Finish value +/// builder.append_value("baz"); +/// +/// // Write second value +/// write!(builder, "v2").unwrap(); +/// builder.append_value(""); +/// +/// let array = builder.finish(); +/// assert_eq!(array.value(0), "foobarbaz"); +/// assert_eq!(array.value(1), "v2"); +/// ``` +pub type GenericStringBuilder = GenericByteBuilder>; + +impl Write for GenericStringBuilder { + fn write_str(&mut self, s: &str) -> std::fmt::Result { + self.value_builder.append_slice(s.as_bytes()); + Ok(()) + } +} + +/// Array builder for [`GenericBinaryArray`][crate::GenericBinaryArray] +pub type GenericBinaryBuilder = GenericByteBuilder>; + +#[cfg(test)] +mod tests { + use super::*; + use crate::array::{Array, OffsetSizeTrait}; + use crate::GenericStringArray; + + fn _test_generic_binary_builder() { + let mut builder = GenericBinaryBuilder::::new(); + + builder.append_value(b"hello"); + builder.append_value(b""); + builder.append_null(); + builder.append_value(b"rust"); + + let array = builder.finish(); + + assert_eq!(4, array.len()); + assert_eq!(1, array.null_count()); + assert_eq!(b"hello", array.value(0)); + assert_eq!([] as [u8; 0], array.value(1)); + assert!(array.is_null(2)); + assert_eq!(b"rust", array.value(3)); + assert_eq!(O::from_usize(5).unwrap(), array.value_offsets()[2]); + assert_eq!(O::from_usize(4).unwrap(), array.value_length(3)); + } + + #[test] + fn test_binary_builder() { + _test_generic_binary_builder::() + } + + #[test] + fn test_large_binary_builder() { + _test_generic_binary_builder::() + } + + fn _test_generic_binary_builder_all_nulls() { + let mut builder = GenericBinaryBuilder::::new(); + builder.append_null(); + builder.append_null(); + builder.append_null(); + assert_eq!(3, builder.len()); + assert!(!builder.is_empty()); + + let array = builder.finish(); + assert_eq!(3, array.null_count()); + assert_eq!(3, array.len()); + assert!(array.is_null(0)); + assert!(array.is_null(1)); + assert!(array.is_null(2)); + } + + #[test] + fn test_binary_builder_all_nulls() { + _test_generic_binary_builder_all_nulls::() + } + + #[test] + fn test_large_binary_builder_all_nulls() { + _test_generic_binary_builder_all_nulls::() + } + + fn _test_generic_binary_builder_reset() { + let mut builder = GenericBinaryBuilder::::new(); + + builder.append_value(b"hello"); + builder.append_value(b""); + builder.append_null(); + builder.append_value(b"rust"); + builder.finish(); + + assert!(builder.is_empty()); + + builder.append_value(b"parquet"); + builder.append_null(); + builder.append_value(b"arrow"); + builder.append_value(b""); + let array = builder.finish(); + + assert_eq!(4, array.len()); + assert_eq!(1, array.null_count()); + assert_eq!(b"parquet", array.value(0)); + assert!(array.is_null(1)); + assert_eq!(b"arrow", array.value(2)); + assert_eq!(b"", array.value(1)); + assert_eq!(O::zero(), array.value_offsets()[0]); + assert_eq!(O::from_usize(7).unwrap(), array.value_offsets()[2]); + assert_eq!(O::from_usize(5).unwrap(), array.value_length(2)); + } + + #[test] + fn test_binary_builder_reset() { + _test_generic_binary_builder_reset::() + } + + #[test] + fn test_large_binary_builder_reset() { + _test_generic_binary_builder_reset::() + } + + fn _test_generic_string_array_builder() { + let mut builder = GenericStringBuilder::::new(); + let owned = "arrow".to_owned(); + + builder.append_value("hello"); + builder.append_value(""); + builder.append_value(&owned); + builder.append_null(); + builder.append_option(Some("rust")); + builder.append_option(None::<&str>); + builder.append_option(None::); + assert_eq!(7, builder.len()); + + assert_eq!( + GenericStringArray::::from(vec![ + Some("hello"), + Some(""), + Some("arrow"), + None, + Some("rust"), + None, + None + ]), + builder.finish() + ); + } + + #[test] + fn test_string_array_builder() { + _test_generic_string_array_builder::() + } + + #[test] + fn test_large_string_array_builder() { + _test_generic_string_array_builder::() + } + + fn _test_generic_string_array_builder_finish() { + let mut builder = GenericStringBuilder::::with_capacity(3, 11); + + builder.append_value("hello"); + builder.append_value("rust"); + builder.append_null(); + + builder.finish(); + assert!(builder.is_empty()); + assert_eq!(&[O::zero()], builder.offsets_slice()); + + builder.append_value("arrow"); + builder.append_value("parquet"); + let arr = builder.finish(); + // array should not have null buffer because there is not `null` value. + assert!(arr.nulls().is_none()); + assert_eq!(GenericStringArray::::from(vec!["arrow", "parquet"]), arr,) + } + + #[test] + fn test_string_array_builder_finish() { + _test_generic_string_array_builder_finish::() + } + + #[test] + fn test_large_string_array_builder_finish() { + _test_generic_string_array_builder_finish::() + } + + fn _test_generic_string_array_builder_finish_cloned() { + let mut builder = GenericStringBuilder::::with_capacity(3, 11); + + builder.append_value("hello"); + builder.append_value("rust"); + builder.append_null(); + + let mut arr = builder.finish_cloned(); + assert!(!builder.is_empty()); + assert_eq!(3, arr.len()); + + builder.append_value("arrow"); + builder.append_value("parquet"); + arr = builder.finish(); + + assert!(arr.nulls().is_some()); + assert_eq!(&[O::zero()], builder.offsets_slice()); + assert_eq!(5, arr.len()); + } + + #[test] + fn test_string_array_builder_finish_cloned() { + _test_generic_string_array_builder_finish_cloned::() + } + + #[test] + fn test_large_string_array_builder_finish_cloned() { + _test_generic_string_array_builder_finish_cloned::() + } + + #[test] + fn test_extend() { + let mut builder = GenericStringBuilder::::new(); + builder.extend(["a", "b", "c", "", "a", "b", "c"].into_iter().map(Some)); + builder.extend(["d", "cupcakes", "hello"].into_iter().map(Some)); + let array = builder.finish(); + assert_eq!(array.value_offsets(), &[0, 1, 2, 3, 3, 4, 5, 6, 7, 15, 20]); + assert_eq!(array.value_data(), b"abcabcdcupcakeshello"); + } + + #[test] + fn test_write() { + let mut builder = GenericStringBuilder::::new(); + write!(builder, "foo").unwrap(); + builder.append_value(""); + writeln!(builder, "bar").unwrap(); + builder.append_value(""); + write!(builder, "fiz").unwrap(); + write!(builder, "buz").unwrap(); + builder.append_value(""); + let a = builder.finish(); + let r: Vec<_> = a.iter().map(|x| x.unwrap()).collect(); + assert_eq!(r, &["foo", "bar\n", "fizbuz"]) + } +} diff --git a/arrow-array/src/builder/generic_bytes_dictionary_builder.rs b/arrow-array/src/builder/generic_bytes_dictionary_builder.rs new file mode 100644 index 000000000000..b0c722ae7cda --- /dev/null +++ b/arrow-array/src/builder/generic_bytes_dictionary_builder.rs @@ -0,0 +1,626 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::builder::{ArrayBuilder, GenericByteBuilder, PrimitiveBuilder}; +use crate::types::{ArrowDictionaryKeyType, ByteArrayType, GenericBinaryType, GenericStringType}; +use crate::{Array, ArrayRef, DictionaryArray, GenericByteArray}; +use arrow_buffer::ArrowNativeType; +use arrow_schema::{ArrowError, DataType}; +use hashbrown::hash_map::RawEntryMut; +use hashbrown::HashMap; +use std::any::Any; +use std::sync::Arc; + +/// Builder for [`DictionaryArray`] of [`GenericByteArray`] +/// +/// For example to map a set of byte indices to String values. Note that +/// the use of a `HashMap` here will not scale to very large arrays or +/// result in an ordered dictionary. +#[derive(Debug)] +pub struct GenericByteDictionaryBuilder +where + K: ArrowDictionaryKeyType, + T: ByteArrayType, +{ + state: ahash::RandomState, + /// Used to provide a lookup from string value to key type + /// + /// Note: usize's hash implementation is not used, instead the raw entry + /// API is used to store keys w.r.t the hash of the strings themselves + /// + dedup: HashMap, + + keys_builder: PrimitiveBuilder, + values_builder: GenericByteBuilder, +} + +impl Default for GenericByteDictionaryBuilder +where + K: ArrowDictionaryKeyType, + T: ByteArrayType, +{ + fn default() -> Self { + Self::new() + } +} + +impl GenericByteDictionaryBuilder +where + K: ArrowDictionaryKeyType, + T: ByteArrayType, +{ + /// Creates a new `GenericByteDictionaryBuilder` + pub fn new() -> Self { + let keys_builder = PrimitiveBuilder::new(); + let values_builder = GenericByteBuilder::::new(); + Self { + state: Default::default(), + dedup: HashMap::with_capacity_and_hasher(keys_builder.capacity(), ()), + keys_builder, + values_builder, + } + } + + /// Creates a new `GenericByteDictionaryBuilder` with the provided capacities + /// + /// `keys_capacity`: the number of keys, i.e. length of array to build + /// `value_capacity`: the number of distinct dictionary values, i.e. size of dictionary + /// `data_capacity`: the total number of bytes of all distinct bytes in the dictionary + pub fn with_capacity( + keys_capacity: usize, + value_capacity: usize, + data_capacity: usize, + ) -> Self { + Self { + state: Default::default(), + dedup: Default::default(), + keys_builder: PrimitiveBuilder::with_capacity(keys_capacity), + values_builder: GenericByteBuilder::::with_capacity(value_capacity, data_capacity), + } + } + + /// Creates a new `GenericByteDictionaryBuilder` from a keys capacity and a dictionary + /// which is initialized with the given values. + /// The indices of those dictionary values are used as keys. + /// + /// # Example + /// + /// ``` + /// # use arrow_array::builder::StringDictionaryBuilder; + /// # use arrow_array::{Int16Array, StringArray}; + /// + /// let dictionary_values = StringArray::from(vec![None, Some("abc"), Some("def")]); + /// + /// let mut builder = StringDictionaryBuilder::new_with_dictionary(3, &dictionary_values).unwrap(); + /// builder.append("def").unwrap(); + /// builder.append_null(); + /// builder.append("abc").unwrap(); + /// + /// let dictionary_array = builder.finish(); + /// + /// let keys = dictionary_array.keys(); + /// + /// assert_eq!(keys, &Int16Array::from(vec![Some(2), None, Some(1)])); + /// ``` + pub fn new_with_dictionary( + keys_capacity: usize, + dictionary_values: &GenericByteArray, + ) -> Result { + let state = ahash::RandomState::default(); + let dict_len = dictionary_values.len(); + + let mut dedup = HashMap::with_capacity_and_hasher(dict_len, ()); + + let values_len = dictionary_values.value_data().len(); + let mut values_builder = GenericByteBuilder::::with_capacity(dict_len, values_len); + + K::Native::from_usize(dictionary_values.len()) + .ok_or(ArrowError::DictionaryKeyOverflowError)?; + + for (idx, maybe_value) in dictionary_values.iter().enumerate() { + match maybe_value { + Some(value) => { + let value_bytes: &[u8] = value.as_ref(); + let hash = state.hash_one(value_bytes); + + let entry = dedup.raw_entry_mut().from_hash(hash, |idx: &usize| { + value_bytes == get_bytes(&values_builder, *idx) + }); + + if let RawEntryMut::Vacant(v) = entry { + v.insert_with_hasher(hash, idx, (), |idx| { + state.hash_one(get_bytes(&values_builder, *idx)) + }); + } + + values_builder.append_value(value); + } + None => values_builder.append_null(), + } + } + + Ok(Self { + state, + dedup, + keys_builder: PrimitiveBuilder::with_capacity(keys_capacity), + values_builder, + }) + } +} + +impl ArrayBuilder for GenericByteDictionaryBuilder +where + K: ArrowDictionaryKeyType, + T: ByteArrayType, +{ + /// Returns the builder as an non-mutable `Any` reference. + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns the builder as an mutable `Any` reference. + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + /// Returns the boxed builder as a box of `Any`. + fn into_box_any(self: Box) -> Box { + self + } + + /// Returns the number of array slots in the builder + fn len(&self) -> usize { + self.keys_builder.len() + } + + /// Builds the array and reset this builder. + fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } +} + +impl GenericByteDictionaryBuilder +where + K: ArrowDictionaryKeyType, + T: ByteArrayType, +{ + /// Append a value to the array. Return an existing index + /// if already present in the values array or a new index if the + /// value is appended to the values array. + /// + /// Returns an error if the new index would overflow the key type. + pub fn append(&mut self, value: impl AsRef) -> Result { + let value_native: &T::Native = value.as_ref(); + let value_bytes: &[u8] = value_native.as_ref(); + + let state = &self.state; + let storage = &mut self.values_builder; + let hash = state.hash_one(value_bytes); + + let entry = self + .dedup + .raw_entry_mut() + .from_hash(hash, |idx| value_bytes == get_bytes(storage, *idx)); + + let key = match entry { + RawEntryMut::Occupied(entry) => K::Native::usize_as(*entry.into_key()), + RawEntryMut::Vacant(entry) => { + let idx = storage.len(); + storage.append_value(value); + + entry.insert_with_hasher(hash, idx, (), |idx| { + state.hash_one(get_bytes(storage, *idx)) + }); + + K::Native::from_usize(idx).ok_or(ArrowError::DictionaryKeyOverflowError)? + } + }; + self.keys_builder.append_value(key); + + Ok(key) + } + + /// Infallibly append a value to this builder + /// + /// # Panics + /// + /// Panics if the resulting length of the dictionary values array would exceed `T::Native::MAX` + pub fn append_value(&mut self, value: impl AsRef) { + self.append(value).expect("dictionary key overflow"); + } + + /// Appends a null slot into the builder + #[inline] + pub fn append_null(&mut self) { + self.keys_builder.append_null() + } + + /// Append an `Option` value into the builder + /// + /// # Panics + /// + /// Panics if the resulting length of the dictionary values array would exceed `T::Native::MAX` + #[inline] + pub fn append_option(&mut self, value: Option>) { + match value { + None => self.append_null(), + Some(v) => self.append_value(v), + }; + } + + /// Builds the `DictionaryArray` and reset this builder. + pub fn finish(&mut self) -> DictionaryArray { + self.dedup.clear(); + let values = self.values_builder.finish(); + let keys = self.keys_builder.finish(); + + let data_type = DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(T::DATA_TYPE)); + + let builder = keys + .into_data() + .into_builder() + .data_type(data_type) + .child_data(vec![values.into_data()]); + + DictionaryArray::from(unsafe { builder.build_unchecked() }) + } + + /// Builds the `DictionaryArray` without resetting the builder. + pub fn finish_cloned(&self) -> DictionaryArray { + let values = self.values_builder.finish_cloned(); + let keys = self.keys_builder.finish_cloned(); + + let data_type = DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(T::DATA_TYPE)); + + let builder = keys + .into_data() + .into_builder() + .data_type(data_type) + .child_data(vec![values.into_data()]); + + DictionaryArray::from(unsafe { builder.build_unchecked() }) + } +} + +impl> Extend> + for GenericByteDictionaryBuilder +{ + #[inline] + fn extend>>(&mut self, iter: I) { + for v in iter { + self.append_option(v) + } + } +} + +fn get_bytes(values: &GenericByteBuilder, idx: usize) -> &[u8] { + let offsets = values.offsets_slice(); + let values = values.values_slice(); + + let end_offset = offsets[idx + 1].as_usize(); + let start_offset = offsets[idx].as_usize(); + + &values[start_offset..end_offset] +} + +/// Builder for [`DictionaryArray`] of [`StringArray`](crate::array::StringArray) +/// +/// ``` +/// // Create a dictionary array indexed by bytes whose values are Strings. +/// // It can thus hold up to 256 distinct string values. +/// +/// # use arrow_array::builder::StringDictionaryBuilder; +/// # use arrow_array::{Int8Array, StringArray}; +/// # use arrow_array::types::Int8Type; +/// +/// let mut builder = StringDictionaryBuilder::::new(); +/// +/// // The builder builds the dictionary value by value +/// builder.append("abc").unwrap(); +/// builder.append_null(); +/// builder.append("def").unwrap(); +/// builder.append("def").unwrap(); +/// builder.append("abc").unwrap(); +/// let array = builder.finish(); +/// +/// assert_eq!( +/// array.keys(), +/// &Int8Array::from(vec![Some(0), None, Some(1), Some(1), Some(0)]) +/// ); +/// +/// // Values are polymorphic and so require a downcast. +/// let av = array.values(); +/// let ava: &StringArray = av.as_any().downcast_ref::().unwrap(); +/// +/// assert_eq!(ava.value(0), "abc"); +/// assert_eq!(ava.value(1), "def"); +/// +/// ``` +pub type StringDictionaryBuilder = GenericByteDictionaryBuilder>; + +/// Builder for [`DictionaryArray`] of [`LargeStringArray`](crate::array::LargeStringArray) +pub type LargeStringDictionaryBuilder = GenericByteDictionaryBuilder>; + +/// Builder for [`DictionaryArray`] of [`BinaryArray`](crate::array::BinaryArray) +/// +/// ``` +/// // Create a dictionary array indexed by bytes whose values are binary. +/// // It can thus hold up to 256 distinct binary values. +/// +/// # use arrow_array::builder::BinaryDictionaryBuilder; +/// # use arrow_array::{BinaryArray, Int8Array}; +/// # use arrow_array::types::Int8Type; +/// +/// let mut builder = BinaryDictionaryBuilder::::new(); +/// +/// // The builder builds the dictionary value by value +/// builder.append(b"abc").unwrap(); +/// builder.append_null(); +/// builder.append(b"def").unwrap(); +/// builder.append(b"def").unwrap(); +/// builder.append(b"abc").unwrap(); +/// let array = builder.finish(); +/// +/// assert_eq!( +/// array.keys(), +/// &Int8Array::from(vec![Some(0), None, Some(1), Some(1), Some(0)]) +/// ); +/// +/// // Values are polymorphic and so require a downcast. +/// let av = array.values(); +/// let ava: &BinaryArray = av.as_any().downcast_ref::().unwrap(); +/// +/// assert_eq!(ava.value(0), b"abc"); +/// assert_eq!(ava.value(1), b"def"); +/// +/// ``` +pub type BinaryDictionaryBuilder = GenericByteDictionaryBuilder>; + +/// Builder for [`DictionaryArray`] of [`LargeBinaryArray`](crate::array::LargeBinaryArray) +pub type LargeBinaryDictionaryBuilder = GenericByteDictionaryBuilder>; + +#[cfg(test)] +mod tests { + use super::*; + + use crate::array::Array; + use crate::array::Int8Array; + use crate::types::{Int16Type, Int32Type, Int8Type, Utf8Type}; + use crate::{BinaryArray, StringArray}; + + fn test_bytes_dictionary_builder(values: Vec<&T::Native>) + where + T: ByteArrayType, + ::Native: PartialEq, + ::Native: AsRef<::Native>, + { + let mut builder = GenericByteDictionaryBuilder::::new(); + builder.append(values[0]).unwrap(); + builder.append_null(); + builder.append(values[1]).unwrap(); + builder.append(values[1]).unwrap(); + builder.append(values[0]).unwrap(); + let array = builder.finish(); + + assert_eq!( + array.keys(), + &Int8Array::from(vec![Some(0), None, Some(1), Some(1), Some(0)]) + ); + + // Values are polymorphic and so require a downcast. + let av = array.values(); + let ava: &GenericByteArray = av.as_any().downcast_ref::>().unwrap(); + + assert_eq!(*ava.value(0), *values[0]); + assert_eq!(*ava.value(1), *values[1]); + } + + #[test] + fn test_string_dictionary_builder() { + test_bytes_dictionary_builder::>(vec!["abc", "def"]); + } + + #[test] + fn test_binary_dictionary_builder() { + test_bytes_dictionary_builder::>(vec![b"abc", b"def"]); + } + + fn test_bytes_dictionary_builder_finish_cloned(values: Vec<&T::Native>) + where + T: ByteArrayType, + ::Native: PartialEq, + ::Native: AsRef<::Native>, + { + let mut builder = GenericByteDictionaryBuilder::::new(); + + builder.append(values[0]).unwrap(); + builder.append_null(); + builder.append(values[1]).unwrap(); + builder.append(values[1]).unwrap(); + builder.append(values[0]).unwrap(); + let mut array = builder.finish_cloned(); + + assert_eq!( + array.keys(), + &Int8Array::from(vec![Some(0), None, Some(1), Some(1), Some(0)]) + ); + + // Values are polymorphic and so require a downcast. + let av = array.values(); + let ava: &GenericByteArray = av.as_any().downcast_ref::>().unwrap(); + + assert_eq!(ava.value(0), values[0]); + assert_eq!(ava.value(1), values[1]); + + builder.append(values[0]).unwrap(); + builder.append(values[2]).unwrap(); + builder.append(values[1]).unwrap(); + + array = builder.finish(); + + assert_eq!( + array.keys(), + &Int8Array::from(vec![ + Some(0), + None, + Some(1), + Some(1), + Some(0), + Some(0), + Some(2), + Some(1) + ]) + ); + + // Values are polymorphic and so require a downcast. + let av2 = array.values(); + let ava2: &GenericByteArray = + av2.as_any().downcast_ref::>().unwrap(); + + assert_eq!(ava2.value(0), values[0]); + assert_eq!(ava2.value(1), values[1]); + assert_eq!(ava2.value(2), values[2]); + } + + #[test] + fn test_string_dictionary_builder_finish_cloned() { + test_bytes_dictionary_builder_finish_cloned::>(vec![ + "abc", "def", "ghi", + ]); + } + + #[test] + fn test_binary_dictionary_builder_finish_cloned() { + test_bytes_dictionary_builder_finish_cloned::>(vec![ + b"abc", b"def", b"ghi", + ]); + } + + fn test_bytes_dictionary_builder_with_existing_dictionary( + dictionary: GenericByteArray, + values: Vec<&T::Native>, + ) where + T: ByteArrayType, + ::Native: PartialEq, + ::Native: AsRef<::Native>, + { + let mut builder = + GenericByteDictionaryBuilder::::new_with_dictionary(6, &dictionary) + .unwrap(); + builder.append(values[0]).unwrap(); + builder.append_null(); + builder.append(values[1]).unwrap(); + builder.append(values[1]).unwrap(); + builder.append(values[0]).unwrap(); + builder.append(values[2]).unwrap(); + let array = builder.finish(); + + assert_eq!( + array.keys(), + &Int8Array::from(vec![Some(2), None, Some(1), Some(1), Some(2), Some(3)]) + ); + + // Values are polymorphic and so require a downcast. + let av = array.values(); + let ava: &GenericByteArray = av.as_any().downcast_ref::>().unwrap(); + + assert!(!ava.is_valid(0)); + assert_eq!(ava.value(1), values[1]); + assert_eq!(ava.value(2), values[0]); + assert_eq!(ava.value(3), values[2]); + } + + #[test] + fn test_string_dictionary_builder_with_existing_dictionary() { + test_bytes_dictionary_builder_with_existing_dictionary::>( + StringArray::from(vec![None, Some("def"), Some("abc")]), + vec!["abc", "def", "ghi"], + ); + } + + #[test] + fn test_binary_dictionary_builder_with_existing_dictionary() { + let values: Vec> = vec![None, Some(b"def"), Some(b"abc")]; + test_bytes_dictionary_builder_with_existing_dictionary::>( + BinaryArray::from(values), + vec![b"abc", b"def", b"ghi"], + ); + } + + fn test_bytes_dictionary_builder_with_reserved_null_value( + dictionary: GenericByteArray, + values: Vec<&T::Native>, + ) where + T: ByteArrayType, + ::Native: PartialEq, + ::Native: AsRef<::Native>, + { + let mut builder = + GenericByteDictionaryBuilder::::new_with_dictionary(4, &dictionary) + .unwrap(); + builder.append(values[0]).unwrap(); + builder.append_null(); + builder.append(values[1]).unwrap(); + builder.append(values[0]).unwrap(); + let array = builder.finish(); + + assert!(array.is_null(1)); + assert!(!array.is_valid(1)); + + let keys = array.keys(); + + assert_eq!(keys.value(0), 1); + assert!(keys.is_null(1)); + // zero initialization is currently guaranteed by Buffer allocation and resizing + assert_eq!(keys.value(1), 0); + assert_eq!(keys.value(2), 2); + assert_eq!(keys.value(3), 1); + } + + #[test] + fn test_string_dictionary_builder_with_reserved_null_value() { + let v: Vec> = vec![None]; + test_bytes_dictionary_builder_with_reserved_null_value::>( + StringArray::from(v), + vec!["abc", "def"], + ); + } + + #[test] + fn test_binary_dictionary_builder_with_reserved_null_value() { + let values: Vec> = vec![None]; + test_bytes_dictionary_builder_with_reserved_null_value::>( + BinaryArray::from(values), + vec![b"abc", b"def"], + ); + } + + #[test] + fn test_extend() { + let mut builder = GenericByteDictionaryBuilder::::new(); + builder.extend(["a", "b", "c", "a", "b", "c"].into_iter().map(Some)); + builder.extend(["c", "d", "a"].into_iter().map(Some)); + let dict = builder.finish(); + assert_eq!(dict.keys().values(), &[0, 1, 2, 0, 1, 2, 2, 3, 0]); + assert_eq!(dict.values().len(), 4); + } +} diff --git a/arrow-array/src/builder/generic_list_builder.rs b/arrow-array/src/builder/generic_list_builder.rs new file mode 100644 index 000000000000..116e2553cfb7 --- /dev/null +++ b/arrow-array/src/builder/generic_list_builder.rs @@ -0,0 +1,768 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::builder::{ArrayBuilder, BufferBuilder}; +use crate::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; +use arrow_buffer::Buffer; +use arrow_buffer::NullBufferBuilder; +use arrow_data::ArrayData; +use arrow_schema::Field; +use std::any::Any; +use std::sync::Arc; + +/// Builder for [`GenericListArray`] +/// +/// Use [`ListBuilder`] to build [`ListArray`]s and [`LargeListBuilder`] to build [`LargeListArray`]s. +/// +/// # Example +/// +/// Here is code that constructs a ListArray with the contents: +/// `[[A,B,C], [], NULL, [D], [NULL, F]]` +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{builder::ListBuilder, builder::StringBuilder, ArrayRef, StringArray, Array}; +/// # +/// let values_builder = StringBuilder::new(); +/// let mut builder = ListBuilder::new(values_builder); +/// +/// // [A, B, C] +/// builder.values().append_value("A"); +/// builder.values().append_value("B"); +/// builder.values().append_value("C"); +/// builder.append(true); +/// +/// // [ ] (empty list) +/// builder.append(true); +/// +/// // Null +/// builder.values().append_value("?"); // irrelevant +/// builder.append(false); +/// +/// // [D] +/// builder.values().append_value("D"); +/// builder.append(true); +/// +/// // [NULL, F] +/// builder.values().append_null(); +/// builder.values().append_value("F"); +/// builder.append(true); +/// +/// // Build the array +/// let array = builder.finish(); +/// +/// // Values is a string array +/// // "A", "B" "C", "?", "D", NULL, "F" +/// assert_eq!( +/// array.values().as_ref(), +/// &StringArray::from(vec![ +/// Some("A"), Some("B"), Some("C"), +/// Some("?"), Some("D"), None, +/// Some("F") +/// ]) +/// ); +/// +/// // Offsets are indexes into the values array +/// assert_eq!( +/// array.value_offsets(), +/// &[0, 3, 3, 4, 5, 7] +/// ); +/// ``` +/// +/// [`ListBuilder`]: crate::builder::ListBuilder +/// [`ListArray`]: crate::array::ListArray +/// [`LargeListBuilder`]: crate::builder::LargeListBuilder +/// [`LargeListArray`]: crate::array::LargeListArray +#[derive(Debug)] +pub struct GenericListBuilder { + offsets_builder: BufferBuilder, + null_buffer_builder: NullBufferBuilder, + values_builder: T, +} + +impl Default for GenericListBuilder { + fn default() -> Self { + Self::new(T::default()) + } +} + +impl GenericListBuilder { + /// Creates a new [`GenericListBuilder`] from a given values array builder + pub fn new(values_builder: T) -> Self { + let capacity = values_builder.len(); + Self::with_capacity(values_builder, capacity) + } + + /// Creates a new [`GenericListBuilder`] from a given values array builder + /// `capacity` is the number of items to pre-allocate space for in this builder + pub fn with_capacity(values_builder: T, capacity: usize) -> Self { + let mut offsets_builder = BufferBuilder::::new(capacity + 1); + offsets_builder.append(OffsetSize::zero()); + Self { + offsets_builder, + null_buffer_builder: NullBufferBuilder::new(capacity), + values_builder, + } + } +} + +impl ArrayBuilder + for GenericListBuilder +where + T: 'static, +{ + /// Returns the builder as a non-mutable `Any` reference. + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns the builder as a mutable `Any` reference. + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + /// Returns the boxed builder as a box of `Any`. + fn into_box_any(self: Box) -> Box { + self + } + + /// Returns the number of array slots in the builder + fn len(&self) -> usize { + self.null_buffer_builder.len() + } + + /// Builds the array and reset this builder. + fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } +} + +impl GenericListBuilder +where + T: 'static, +{ + /// Returns the child array builder as a mutable reference. + /// + /// This mutable reference can be used to append values into the child array builder, + /// but you must call [`append`](#method.append) to delimit each distinct list value. + pub fn values(&mut self) -> &mut T { + &mut self.values_builder + } + + /// Returns the child array builder as an immutable reference + pub fn values_ref(&self) -> &T { + &self.values_builder + } + + /// Finish the current variable-length list array slot + /// + /// # Panics + /// + /// Panics if the length of [`Self::values`] exceeds `OffsetSize::MAX` + #[inline] + pub fn append(&mut self, is_valid: bool) { + self.offsets_builder.append(self.next_offset()); + self.null_buffer_builder.append(is_valid); + } + + /// Returns the next offset + /// + /// # Panics + /// + /// Panics if the length of [`Self::values`] exceeds `OffsetSize::MAX` + #[inline] + fn next_offset(&self) -> OffsetSize { + OffsetSize::from_usize(self.values_builder.len()).unwrap() + } + + /// Append a value to this [`GenericListBuilder`] + /// + /// ``` + /// # use arrow_array::builder::{Int32Builder, ListBuilder}; + /// # use arrow_array::cast::AsArray; + /// # use arrow_array::{Array, Int32Array}; + /// # use arrow_array::types::Int32Type; + /// let mut builder = ListBuilder::new(Int32Builder::new()); + /// + /// builder.append_value([Some(1), Some(2), Some(3)]); + /// builder.append_value([]); + /// builder.append_value([None]); + /// + /// let array = builder.finish(); + /// assert_eq!(array.len(), 3); + /// + /// assert_eq!(array.value_offsets(), &[0, 3, 3, 4]); + /// let values = array.values().as_primitive::(); + /// assert_eq!(values, &Int32Array::from(vec![Some(1), Some(2), Some(3), None])); + /// ``` + /// + /// This is an alternative API to appending directly to [`Self::values`] and + /// delimiting the result with [`Self::append`] + /// + /// ``` + /// # use arrow_array::builder::{Int32Builder, ListBuilder}; + /// # use arrow_array::cast::AsArray; + /// # use arrow_array::{Array, Int32Array}; + /// # use arrow_array::types::Int32Type; + /// let mut builder = ListBuilder::new(Int32Builder::new()); + /// + /// builder.values().append_value(1); + /// builder.values().append_value(2); + /// builder.values().append_value(3); + /// builder.append(true); + /// builder.append(true); + /// builder.values().append_null(); + /// builder.append(true); + /// + /// let array = builder.finish(); + /// assert_eq!(array.len(), 3); + /// + /// assert_eq!(array.value_offsets(), &[0, 3, 3, 4]); + /// let values = array.values().as_primitive::(); + /// assert_eq!(values, &Int32Array::from(vec![Some(1), Some(2), Some(3), None])); + /// ``` + #[inline] + pub fn append_value(&mut self, i: I) + where + T: Extend>, + I: IntoIterator>, + { + self.extend(std::iter::once(Some(i))) + } + + /// Append a null to this [`GenericListBuilder`] + /// + /// See [`Self::append_value`] for an example use. + #[inline] + pub fn append_null(&mut self) { + self.offsets_builder.append(self.next_offset()); + self.null_buffer_builder.append_null(); + } + + /// Appends an optional value into this [`GenericListBuilder`] + /// + /// If `Some` calls [`Self::append_value`] otherwise calls [`Self::append_null`] + #[inline] + pub fn append_option(&mut self, i: Option) + where + T: Extend>, + I: IntoIterator>, + { + match i { + Some(i) => self.append_value(i), + None => self.append_null(), + } + } + + /// Builds the [`GenericListArray`] and reset this builder. + pub fn finish(&mut self) -> GenericListArray { + let len = self.len(); + let values_arr = self.values_builder.finish(); + let values_data = values_arr.to_data(); + + let offset_buffer = self.offsets_builder.finish(); + let null_bit_buffer = self.null_buffer_builder.finish(); + self.offsets_builder.append(OffsetSize::zero()); + let field = Arc::new(Field::new( + "item", + values_data.data_type().clone(), + true, // TODO: find a consistent way of getting this + )); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(field); + let array_data_builder = ArrayData::builder(data_type) + .len(len) + .add_buffer(offset_buffer) + .add_child_data(values_data) + .nulls(null_bit_buffer); + + let array_data = unsafe { array_data_builder.build_unchecked() }; + + GenericListArray::::from(array_data) + } + + /// Builds the [`GenericListArray`] without resetting the builder. + pub fn finish_cloned(&self) -> GenericListArray { + let len = self.len(); + let values_arr = self.values_builder.finish_cloned(); + let values_data = values_arr.to_data(); + + let offset_buffer = Buffer::from_slice_ref(self.offsets_builder.as_slice()); + let nulls = self.null_buffer_builder.finish_cloned(); + let field = Arc::new(Field::new( + "item", + values_data.data_type().clone(), + true, // TODO: find a consistent way of getting this + )); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(field); + let array_data_builder = ArrayData::builder(data_type) + .len(len) + .add_buffer(offset_buffer) + .add_child_data(values_data) + .nulls(nulls); + + let array_data = unsafe { array_data_builder.build_unchecked() }; + + GenericListArray::::from(array_data) + } + + /// Returns the current offsets buffer as a slice + pub fn offsets_slice(&self) -> &[OffsetSize] { + self.offsets_builder.as_slice() + } +} + +impl Extend> for GenericListBuilder +where + O: OffsetSizeTrait, + B: ArrayBuilder + Extend, + V: IntoIterator, +{ + #[inline] + fn extend>>(&mut self, iter: T) { + for v in iter { + match v { + Some(elements) => { + self.values_builder.extend(elements); + self.append(true); + } + None => self.append(false), + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::builder::{make_builder, Int32Builder, ListBuilder}; + use crate::cast::AsArray; + use crate::types::Int32Type; + use crate::{Array, Int32Array}; + use arrow_schema::DataType; + + fn _test_generic_list_array_builder() { + let values_builder = Int32Builder::with_capacity(10); + let mut builder = GenericListBuilder::::new(values_builder); + + // [[0, 1, 2], [3, 4, 5], [6, 7]] + builder.values().append_value(0); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true); + builder.values().append_value(3); + builder.values().append_value(4); + builder.values().append_value(5); + builder.append(true); + builder.values().append_value(6); + builder.values().append_value(7); + builder.append(true); + let list_array = builder.finish(); + + let list_values = list_array.values().as_primitive::(); + assert_eq!(list_values.values(), &[0, 1, 2, 3, 4, 5, 6, 7]); + assert_eq!(list_array.value_offsets(), [0, 3, 6, 8].map(O::usize_as)); + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(3, list_array.len()); + assert_eq!(0, list_array.null_count()); + assert_eq!(O::from_usize(6).unwrap(), list_array.value_offsets()[2]); + assert_eq!(O::from_usize(2).unwrap(), list_array.value_length(2)); + for i in 0..3 { + assert!(list_array.is_valid(i)); + assert!(!list_array.is_null(i)); + } + } + + #[test] + fn test_list_array_builder() { + _test_generic_list_array_builder::() + } + + #[test] + fn test_large_list_array_builder() { + _test_generic_list_array_builder::() + } + + fn _test_generic_list_array_builder_nulls() { + let values_builder = Int32Builder::with_capacity(10); + let mut builder = GenericListBuilder::::new(values_builder); + + // [[0, 1, 2], null, [3, null, 5], [6, 7]] + builder.values().append_value(0); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true); + builder.append(false); + builder.values().append_value(3); + builder.values().append_null(); + builder.values().append_value(5); + builder.append(true); + builder.values().append_value(6); + builder.values().append_value(7); + builder.append(true); + + let list_array = builder.finish(); + + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(4, list_array.len()); + assert_eq!(1, list_array.null_count()); + assert_eq!(O::from_usize(3).unwrap(), list_array.value_offsets()[2]); + assert_eq!(O::from_usize(3).unwrap(), list_array.value_length(2)); + } + + #[test] + fn test_list_array_builder_nulls() { + _test_generic_list_array_builder_nulls::() + } + + #[test] + fn test_large_list_array_builder_nulls() { + _test_generic_list_array_builder_nulls::() + } + + #[test] + fn test_list_array_builder_finish() { + let values_builder = Int32Array::builder(5); + let mut builder = ListBuilder::new(values_builder); + + builder.values().append_slice(&[1, 2, 3]); + builder.append(true); + builder.values().append_slice(&[4, 5, 6]); + builder.append(true); + + let mut arr = builder.finish(); + assert_eq!(2, arr.len()); + assert!(builder.is_empty()); + + builder.values().append_slice(&[7, 8, 9]); + builder.append(true); + arr = builder.finish(); + assert_eq!(1, arr.len()); + assert!(builder.is_empty()); + } + + #[test] + fn test_list_array_builder_finish_cloned() { + let values_builder = Int32Array::builder(5); + let mut builder = ListBuilder::new(values_builder); + + builder.values().append_slice(&[1, 2, 3]); + builder.append(true); + builder.values().append_slice(&[4, 5, 6]); + builder.append(true); + + let mut arr = builder.finish_cloned(); + assert_eq!(2, arr.len()); + assert!(!builder.is_empty()); + + builder.values().append_slice(&[7, 8, 9]); + builder.append(true); + arr = builder.finish(); + assert_eq!(3, arr.len()); + assert!(builder.is_empty()); + } + + #[test] + fn test_list_list_array_builder() { + let primitive_builder = Int32Builder::with_capacity(10); + let values_builder = ListBuilder::new(primitive_builder); + let mut builder = ListBuilder::new(values_builder); + + // [[[1, 2], [3, 4]], [[5, 6, 7], null, [8]], null, [[9, 10]]] + builder.values().values().append_value(1); + builder.values().values().append_value(2); + builder.values().append(true); + builder.values().values().append_value(3); + builder.values().values().append_value(4); + builder.values().append(true); + builder.append(true); + + builder.values().values().append_value(5); + builder.values().values().append_value(6); + builder.values().values().append_value(7); + builder.values().append(true); + builder.values().append(false); + builder.values().values().append_value(8); + builder.values().append(true); + builder.append(true); + + builder.append(false); + + builder.values().values().append_value(9); + builder.values().values().append_value(10); + builder.values().append(true); + builder.append(true); + + let l1 = builder.finish(); + + assert_eq!(4, l1.len()); + assert_eq!(1, l1.null_count()); + + assert_eq!(l1.value_offsets(), &[0, 2, 5, 5, 6]); + let l2 = l1.values().as_list::(); + + assert_eq!(6, l2.len()); + assert_eq!(1, l2.null_count()); + assert_eq!(l2.value_offsets(), &[0, 2, 4, 7, 7, 8, 10]); + + let i1 = l2.values().as_primitive::(); + assert_eq!(10, i1.len()); + assert_eq!(0, i1.null_count()); + assert_eq!(i1.values(), &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + } + + #[test] + fn test_extend() { + let mut builder = ListBuilder::new(Int32Builder::new()); + builder.extend([ + Some(vec![Some(1), Some(2), Some(7), None]), + Some(vec![]), + Some(vec![Some(4), Some(5)]), + None, + ]); + + let array = builder.finish(); + assert_eq!(array.value_offsets(), [0, 4, 4, 6, 6]); + assert_eq!(array.null_count(), 1); + assert!(array.is_null(3)); + let elements = array.values().as_primitive::(); + assert_eq!(elements.values(), &[1, 2, 7, 0, 4, 5]); + assert_eq!(elements.null_count(), 1); + assert!(elements.is_null(3)); + } + + #[test] + fn test_boxed_primitive_aray_builder() { + let values_builder = make_builder(&DataType::Int32, 5); + let mut builder = ListBuilder::new(values_builder); + + builder + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_slice(&[1, 2, 3]); + builder.append(true); + + builder + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_slice(&[4, 5, 6]); + builder.append(true); + + let arr = builder.finish(); + assert_eq!(2, arr.len()); + + let elements = arr.values().as_primitive::(); + assert_eq!(elements.values(), &[1, 2, 3, 4, 5, 6]); + } + + #[test] + fn test_boxed_list_list_array_builder() { + // This test is same as `test_list_list_array_builder` but uses boxed builders. + let values_builder = make_builder( + &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + 10, + ); + test_boxed_generic_list_generic_list_array_builder::(values_builder); + } + + #[test] + fn test_boxed_large_list_large_list_array_builder() { + // This test is same as `test_list_list_array_builder` but uses boxed builders. + let values_builder = make_builder( + &DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, true))), + 10, + ); + test_boxed_generic_list_generic_list_array_builder::(values_builder); + } + + fn test_boxed_generic_list_generic_list_array_builder( + values_builder: Box, + ) { + let mut builder: GenericListBuilder> = + GenericListBuilder::>::new(values_builder); + + // [[[1, 2], [3, 4]], [[5, 6, 7], null, [8]], null, [[9, 10]]] + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(1); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(2); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .append(true); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(3); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(4); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .append(true); + builder.append(true); + + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(5); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(6); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an (Large)ListBuilder") + .append_value(7); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .append(true); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .append(false); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(8); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .append(true); + builder.append(true); + + builder.append(false); + + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(9); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .values() + .as_any_mut() + .downcast_mut::() + .expect("should be an Int32Builder") + .append_value(10); + builder + .values() + .as_any_mut() + .downcast_mut::>>() + .expect("should be an (Large)ListBuilder") + .append(true); + builder.append(true); + + let l1 = builder.finish(); + + assert_eq!(4, l1.len()); + assert_eq!(1, l1.null_count()); + + assert_eq!(l1.value_offsets(), &[0, 2, 5, 5, 6].map(O::usize_as)); + let l2 = l1.values().as_list::(); + + assert_eq!(6, l2.len()); + assert_eq!(1, l2.null_count()); + assert_eq!(l2.value_offsets(), &[0, 2, 4, 7, 7, 8, 10].map(O::usize_as)); + + let i1 = l2.values().as_primitive::(); + assert_eq!(10, i1.len()); + assert_eq!(0, i1.null_count()); + assert_eq!(i1.values(), &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + } +} diff --git a/arrow/src/array/builder/map_builder.rs b/arrow-array/src/builder/map_builder.rs similarity index 54% rename from arrow/src/array/builder/map_builder.rs rename to arrow-array/src/builder/map_builder.rs index 766e8a56b387..3a5244ed81a0 100644 --- a/arrow/src/array/builder/map_builder.rs +++ b/arrow-array/src/builder/map_builder.rs @@ -15,20 +15,45 @@ // specific language governing permissions and limitations // under the License. +use crate::builder::{ArrayBuilder, BufferBuilder}; +use crate::{Array, ArrayRef, MapArray, StructArray}; +use arrow_buffer::Buffer; +use arrow_buffer::{NullBuffer, NullBufferBuilder}; +use arrow_data::ArrayData; +use arrow_schema::{ArrowError, DataType, Field}; use std::any::Any; use std::sync::Arc; -use super::{ArrayBuilder, BufferBuilder, NullBufferBuilder}; -use crate::array::array::Array; -use crate::array::ArrayData; -use crate::array::ArrayRef; -use crate::array::MapArray; -use crate::array::StructArray; -use crate::datatypes::DataType; -use crate::datatypes::Field; -use crate::error::ArrowError; -use crate::error::Result; - +/// Builder for [`MapArray`] +/// +/// ``` +/// # use arrow_array::builder::{Int32Builder, MapBuilder, StringBuilder}; +/// # use arrow_array::{Int32Array, StringArray}; +/// +/// let string_builder = StringBuilder::new(); +/// let int_builder = Int32Builder::with_capacity(4); +/// +/// // Construct `[{"joe": 1}, {"blogs": 2, "foo": 4}, {}, null]` +/// let mut builder = MapBuilder::new(None, string_builder, int_builder); +/// +/// builder.keys().append_value("joe"); +/// builder.values().append_value(1); +/// builder.append(true).unwrap(); +/// +/// builder.keys().append_value("blogs"); +/// builder.values().append_value(2); +/// builder.keys().append_value("foo"); +/// builder.values().append_value(4); +/// builder.append(true).unwrap(); +/// builder.append(true).unwrap(); +/// builder.append(false).unwrap(); +/// +/// let array = builder.finish(); +/// assert_eq!(array.value_offsets(), &[0, 1, 3, 3, 3]); +/// assert_eq!(array.values().as_ref(), &Int32Array::from(vec![1, 2, 4])); +/// assert_eq!(array.keys().as_ref(), &StringArray::from(vec!["joe", "blogs", "foo"])); +/// +/// ``` #[derive(Debug)] pub struct MapBuilder { offsets_builder: BufferBuilder, @@ -38,10 +63,14 @@ pub struct MapBuilder { value_builder: V, } +/// The [`Field`] names for a [`MapArray`] #[derive(Debug, Clone)] pub struct MapFieldNames { + /// [`Field`] name for map entries pub entry: String, + /// [`Field`] name for map key pub key: String, + /// [`Field`] name for map value pub value: String, } @@ -55,17 +84,14 @@ impl Default for MapFieldNames { } } -#[allow(dead_code)] impl MapBuilder { - pub fn new( - field_names: Option, - key_builder: K, - value_builder: V, - ) -> Self { + /// Creates a new `MapBuilder` + pub fn new(field_names: Option, key_builder: K, value_builder: V) -> Self { let capacity = key_builder.len(); Self::with_capacity(field_names, key_builder, value_builder, capacity) } + /// Creates a new `MapBuilder` with capacity pub fn with_capacity( field_names: Option, key_builder: K, @@ -73,8 +99,7 @@ impl MapBuilder { capacity: usize, ) -> Self { let mut offsets_builder = BufferBuilder::::new(capacity + 1); - let len = 0; - offsets_builder.append(len); + offsets_builder.append(0); Self { offsets_builder, null_buffer_builder: NullBufferBuilder::new(capacity), @@ -84,10 +109,12 @@ impl MapBuilder { } } + /// Returns the key array builder of the map pub fn keys(&mut self) -> &mut K { &mut self.key_builder } + /// Returns the value array builder of the map pub fn values(&mut self) -> &mut V { &mut self.value_builder } @@ -96,7 +123,7 @@ impl MapBuilder { /// /// Returns an error if the key and values builders are in an inconsistent state. #[inline] - pub fn append(&mut self, is_valid: bool) -> Result<()> { + pub fn append(&mut self, is_valid: bool) -> Result<(), ArrowError> { if self.key_builder.len() != self.value_builder.len() { return Err(ArrowError::InvalidArgumentError(format!( "Cannot append to a map builder when its keys and values have unequal lengths of {} and {}", @@ -109,41 +136,59 @@ impl MapBuilder { Ok(()) } + /// Builds the [`MapArray`] pub fn finish(&mut self) -> MapArray { let len = self.len(); + // Build the keys + let keys_arr = self.key_builder.finish(); + let values_arr = self.value_builder.finish(); + let offset_buffer = self.offsets_builder.finish(); + self.offsets_builder.append(0); + let null_bit_buffer = self.null_buffer_builder.finish(); + + self.finish_helper(keys_arr, values_arr, offset_buffer, null_bit_buffer, len) + } + /// Builds the [`MapArray`] without resetting the builder. + pub fn finish_cloned(&self) -> MapArray { + let len = self.len(); // Build the keys - let keys_arr = self - .key_builder - .as_any_mut() - .downcast_mut::() - .unwrap() - .finish(); - let values_arr = self - .value_builder - .as_any_mut() - .downcast_mut::() - .unwrap() - .finish(); - - let keys_field = Field::new( + let keys_arr = self.key_builder.finish_cloned(); + let values_arr = self.value_builder.finish_cloned(); + let offset_buffer = Buffer::from_slice_ref(self.offsets_builder.as_slice()); + let nulls = self.null_buffer_builder.finish_cloned(); + self.finish_helper(keys_arr, values_arr, offset_buffer, nulls, len) + } + + fn finish_helper( + &self, + keys_arr: Arc, + values_arr: Arc, + offset_buffer: Buffer, + nulls: Option, + len: usize, + ) -> MapArray { + assert!( + keys_arr.null_count() == 0, + "Keys array must have no null values, found {} null value(s)", + keys_arr.null_count() + ); + + let keys_field = Arc::new(Field::new( self.field_names.key.as_str(), keys_arr.data_type().clone(), - false, // always nullable - ); - let values_field = Field::new( + false, // always non-nullable + )); + let values_field = Arc::new(Field::new( self.field_names.value.as_str(), values_arr.data_type().clone(), true, - ); + )); let struct_array = StructArray::from(vec![(keys_field, keys_arr), (values_field, values_arr)]); - let offset_buffer = self.offsets_builder.finish(); - let null_bit_buffer = self.null_buffer_builder.finish(); - self.offsets_builder.append(0); - let map_field = Box::new(Field::new( + let map_field = Arc::new(Field::new( self.field_names.entry.as_str(), struct_array.data_type().clone(), false, // always non-nullable @@ -152,7 +197,7 @@ impl MapBuilder { .len(len) .add_buffer(offset_buffer) .add_child_data(struct_array.into_data()) - .null_bit_buffer(null_bit_buffer); + .nulls(nulls); let array_data = unsafe { array_data.build_unchecked() }; @@ -165,14 +210,15 @@ impl ArrayBuilder for MapBuilder { self.null_buffer_builder.len() } - fn is_empty(&self) -> bool { - self.len() == 0 - } - fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) } + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } + fn as_any(&self) -> &dyn Any { self } @@ -188,66 +234,18 @@ impl ArrayBuilder for MapBuilder { #[cfg(test)] mod tests { - use super::*; - - use crate::array::builder::StringBuilder; - use crate::array::Int32Builder; - use crate::bitmap::Bitmap; - use crate::buffer::Buffer; + use crate::builder::{Int32Builder, StringBuilder}; - // TODO: add a test that finishes building, after designing a spec-compliant - // way of inserting values to the map. - // A map's values shouldn't be repeated within a slot + use super::*; #[test] - fn test_map_array_builder() { - let string_builder = StringBuilder::new(); - let int_builder = Int32Builder::with_capacity(4); - - let mut builder = MapBuilder::new(None, string_builder, int_builder); - - let string_builder = builder.keys(); - string_builder.append_value("joe"); - string_builder.append_null(); - string_builder.append_null(); - string_builder.append_value("mark"); - - let int_builder = builder.values(); - int_builder.append_value(1); - int_builder.append_value(2); - int_builder.append_null(); - int_builder.append_value(4); - + #[should_panic(expected = "Keys array must have no null values, found 1 null value(s)")] + fn test_map_builder_with_null_keys_panics() { + let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new()); + builder.keys().append_null(); + builder.values().append_value(42); builder.append(true).unwrap(); - builder.append(false).unwrap(); - builder.append(true).unwrap(); - - let arr = builder.finish(); - - let map_data = arr.data(); - assert_eq!(3, map_data.len()); - assert_eq!(1, map_data.null_count()); - assert_eq!( - Some(&Bitmap::from(Buffer::from(&[5_u8]))), - map_data.null_bitmap() - ); - let expected_string_data = ArrayData::builder(DataType::Utf8) - .len(4) - .null_bit_buffer(Some(Buffer::from(&[9_u8]))) - .add_buffer(Buffer::from_slice_ref(&[0, 3, 3, 3, 7])) - .add_buffer(Buffer::from_slice_ref(b"joemark")) - .build() - .unwrap(); - - let expected_int_data = ArrayData::builder(DataType::Int32) - .len(4) - .null_bit_buffer(Some(Buffer::from_slice_ref(&[11_u8]))) - .add_buffer(Buffer::from_slice_ref(&[1, 2, 0, 4])) - .build() - .unwrap(); - - assert_eq!(&expected_string_data, arr.keys().data()); - assert_eq!(&expected_int_data, arr.values().data()); + builder.finish(); } } diff --git a/arrow-array/src/builder/mod.rs b/arrow-array/src/builder/mod.rs new file mode 100644 index 000000000000..8382f7af87b0 --- /dev/null +++ b/arrow-array/src/builder/mod.rs @@ -0,0 +1,314 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines push-based APIs for constructing arrays +//! +//! # Basic Usage +//! +//! Builders can be used to build simple, non-nested arrays +//! +//! ``` +//! # use arrow_array::builder::Int32Builder; +//! # use arrow_array::PrimitiveArray; +//! let mut a = Int32Builder::new(); +//! a.append_value(1); +//! a.append_null(); +//! a.append_value(2); +//! let a = a.finish(); +//! +//! assert_eq!(a, PrimitiveArray::from(vec![Some(1), None, Some(2)])); +//! ``` +//! +//! ``` +//! # use arrow_array::builder::StringBuilder; +//! # use arrow_array::{Array, StringArray}; +//! let mut a = StringBuilder::new(); +//! a.append_value("foo"); +//! a.append_value("bar"); +//! a.append_null(); +//! let a = a.finish(); +//! +//! assert_eq!(a, StringArray::from_iter([Some("foo"), Some("bar"), None])); +//! ``` +//! +//! # Nested Usage +//! +//! Builders can also be used to build more complex nested arrays, such as lists +//! +//! ``` +//! # use arrow_array::builder::{Int32Builder, ListBuilder}; +//! # use arrow_array::ListArray; +//! # use arrow_array::types::Int32Type; +//! let mut a = ListBuilder::new(Int32Builder::new()); +//! // [1, 2] +//! a.values().append_value(1); +//! a.values().append_value(2); +//! a.append(true); +//! // null +//! a.append(false); +//! // [] +//! a.append(true); +//! // [3, null] +//! a.values().append_value(3); +//! a.values().append_null(); +//! a.append(true); +//! +//! // [[1, 2], null, [], [3, null]] +//! let a = a.finish(); +//! +//! assert_eq!(a, ListArray::from_iter_primitive::([ +//! Some(vec![Some(1), Some(2)]), +//! None, +//! Some(vec![]), +//! Some(vec![Some(3), None])] +//! )) +//! ``` +//! +//! # Custom Builders +//! +//! It is common to have a collection of statically defined Rust types that +//! you want to convert to Arrow arrays. +//! +//! An example of doing so is below +//! +//! ``` +//! # use std::any::Any; +//! # use arrow_array::builder::{ArrayBuilder, Int32Builder, ListBuilder, StringBuilder}; +//! # use arrow_array::{ArrayRef, RecordBatch, StructArray}; +//! # use arrow_schema::{DataType, Field}; +//! # use std::sync::Arc; +//! /// A custom row representation +//! struct MyRow { +//! i32: i32, +//! optional_i32: Option, +//! string: Option, +//! i32_list: Option>>, +//! } +//! +//! /// Converts `Vec` into `StructArray` +//! #[derive(Debug, Default)] +//! struct MyRowBuilder { +//! i32: Int32Builder, +//! string: StringBuilder, +//! i32_list: ListBuilder, +//! } +//! +//! impl MyRowBuilder { +//! fn append(&mut self, row: &MyRow) { +//! self.i32.append_value(row.i32); +//! self.string.append_option(row.string.as_ref()); +//! self.i32_list.append_option(row.i32_list.as_ref().map(|x| x.iter().copied())); +//! } +//! +//! /// Note: returns StructArray to allow nesting within another array if desired +//! fn finish(&mut self) -> StructArray { +//! let i32 = Arc::new(self.i32.finish()) as ArrayRef; +//! let i32_field = Arc::new(Field::new("i32", DataType::Int32, false)); +//! +//! let string = Arc::new(self.string.finish()) as ArrayRef; +//! let string_field = Arc::new(Field::new("i32", DataType::Utf8, false)); +//! +//! let i32_list = Arc::new(self.i32_list.finish()) as ArrayRef; +//! let value_field = Arc::new(Field::new("item", DataType::Int32, true)); +//! let i32_list_field = Arc::new(Field::new("i32_list", DataType::List(value_field), true)); +//! +//! StructArray::from(vec![ +//! (i32_field, i32), +//! (string_field, string), +//! (i32_list_field, i32_list), +//! ]) +//! } +//! } +//! +//! impl<'a> Extend<&'a MyRow> for MyRowBuilder { +//! fn extend>(&mut self, iter: T) { +//! iter.into_iter().for_each(|row| self.append(row)); +//! } +//! } +//! +//! /// Converts a slice of [`MyRow`] to a [`RecordBatch`] +//! fn rows_to_batch(rows: &[MyRow]) -> RecordBatch { +//! let mut builder = MyRowBuilder::default(); +//! builder.extend(rows); +//! RecordBatch::from(&builder.finish()) +//! } +//! ``` + +pub use arrow_buffer::BooleanBufferBuilder; + +mod boolean_builder; +pub use boolean_builder::*; +mod buffer_builder; +pub use buffer_builder::*; +mod fixed_size_binary_builder; +pub use fixed_size_binary_builder::*; +mod fixed_size_list_builder; +pub use fixed_size_list_builder::*; +mod generic_bytes_builder; +pub use generic_bytes_builder::*; +mod generic_list_builder; +pub use generic_list_builder::*; +mod map_builder; +pub use map_builder::*; +mod null_builder; +pub use null_builder::*; +mod primitive_builder; +pub use primitive_builder::*; +mod primitive_dictionary_builder; +pub use primitive_dictionary_builder::*; +mod primitive_run_builder; +pub use primitive_run_builder::*; +mod struct_builder; +pub use struct_builder::*; +mod generic_bytes_dictionary_builder; +pub use generic_bytes_dictionary_builder::*; +mod generic_byte_run_builder; +pub use generic_byte_run_builder::*; +mod union_builder; +pub use union_builder::*; + +use crate::ArrayRef; +use std::any::Any; + +/// Trait for dealing with different array builders at runtime +/// +/// # Example +/// +/// ``` +/// // Create +/// # use arrow_array::{ArrayRef, StringArray}; +/// # use arrow_array::builder::{ArrayBuilder, Float64Builder, Int64Builder, StringBuilder}; +/// +/// let mut data_builders: Vec> = vec![ +/// Box::new(Float64Builder::new()), +/// Box::new(Int64Builder::new()), +/// Box::new(StringBuilder::new()), +/// ]; +/// +/// // Fill +/// data_builders[0] +/// .as_any_mut() +/// .downcast_mut::() +/// .unwrap() +/// .append_value(3.14); +/// data_builders[1] +/// .as_any_mut() +/// .downcast_mut::() +/// .unwrap() +/// .append_value(-1); +/// data_builders[2] +/// .as_any_mut() +/// .downcast_mut::() +/// .unwrap() +/// .append_value("🍎"); +/// +/// // Finish +/// let array_refs: Vec = data_builders +/// .iter_mut() +/// .map(|builder| builder.finish()) +/// .collect(); +/// assert_eq!(array_refs[0].len(), 1); +/// assert_eq!(array_refs[1].is_null(0), false); +/// assert_eq!( +/// array_refs[2] +/// .as_any() +/// .downcast_ref::() +/// .unwrap() +/// .value(0), +/// "🍎" +/// ); +/// ``` +pub trait ArrayBuilder: Any + Send { + /// Returns the number of array slots in the builder + fn len(&self) -> usize; + + /// Returns whether number of array slots is zero + fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Builds the array + fn finish(&mut self) -> ArrayRef; + + /// Builds the array without resetting the underlying builder. + fn finish_cloned(&self) -> ArrayRef; + + /// Returns the builder as a non-mutable `Any` reference. + /// + /// This is most useful when one wants to call non-mutable APIs on a specific builder + /// type. In this case, one can first cast this into a `Any`, and then use + /// `downcast_ref` to get a reference on the specific builder. + fn as_any(&self) -> &dyn Any; + + /// Returns the builder as a mutable `Any` reference. + /// + /// This is most useful when one wants to call mutable APIs on a specific builder + /// type. In this case, one can first cast this into a `Any`, and then use + /// `downcast_mut` to get a reference on the specific builder. + fn as_any_mut(&mut self) -> &mut dyn Any; + + /// Returns the boxed builder as a box of `Any`. + fn into_box_any(self: Box) -> Box; +} + +impl ArrayBuilder for Box { + fn len(&self) -> usize { + (**self).len() + } + + fn is_empty(&self) -> bool { + (**self).is_empty() + } + + fn finish(&mut self) -> ArrayRef { + (**self).finish() + } + + fn finish_cloned(&self) -> ArrayRef { + (**self).finish_cloned() + } + + fn as_any(&self) -> &dyn Any { + (**self).as_any() + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + (**self).as_any_mut() + } + + fn into_box_any(self: Box) -> Box { + self + } +} + +/// Builder for [`ListArray`](crate::array::ListArray) +pub type ListBuilder = GenericListBuilder; + +/// Builder for [`LargeListArray`](crate::array::LargeListArray) +pub type LargeListBuilder = GenericListBuilder; + +/// Builder for [`BinaryArray`](crate::array::BinaryArray) +pub type BinaryBuilder = GenericBinaryBuilder; + +/// Builder for [`LargeBinaryArray`](crate::array::LargeBinaryArray) +pub type LargeBinaryBuilder = GenericBinaryBuilder; + +/// Builder for [`StringArray`](crate::array::StringArray) +pub type StringBuilder = GenericStringBuilder; + +/// Builder for [`LargeStringArray`](crate::array::LargeStringArray) +pub type LargeStringBuilder = GenericStringBuilder; diff --git a/arrow-array/src/builder/null_builder.rs b/arrow-array/src/builder/null_builder.rs new file mode 100644 index 000000000000..53a6b103d541 --- /dev/null +++ b/arrow-array/src/builder/null_builder.rs @@ -0,0 +1,180 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::builder::ArrayBuilder; +use crate::{ArrayRef, NullArray}; +use arrow_data::ArrayData; +use arrow_schema::DataType; +use std::any::Any; +use std::sync::Arc; + +/// Builder for [`NullArray`] +/// +/// # Example +/// +/// Create a `NullArray` from a `NullBuilder` +/// +/// ``` +/// +/// # use arrow_array::{Array, NullArray, builder::NullBuilder}; +/// +/// let mut b = NullBuilder::new(); +/// b.append_empty_value(); +/// b.append_null(); +/// b.append_nulls(3); +/// b.append_empty_values(3); +/// let arr = b.finish(); +/// +/// assert_eq!(8, arr.len()); +/// assert_eq!(0, arr.null_count()); +/// ``` +#[derive(Debug)] +pub struct NullBuilder { + len: usize, +} + +impl Default for NullBuilder { + fn default() -> Self { + Self::new() + } +} + +impl NullBuilder { + /// Creates a new null builder + pub fn new() -> Self { + Self { len: 0 } + } + + /// Creates a new null builder with space for `capacity` elements without re-allocating + pub fn with_capacity(capacity: usize) -> Self { + Self { len: capacity } + } + + /// Returns the capacity of this builder measured in slots of type `T` + pub fn capacity(&self) -> usize { + self.len + } + + /// Appends a null slot into the builder + #[inline] + pub fn append_null(&mut self) { + self.len += 1; + } + + /// Appends `n` `null`s into the builder. + #[inline] + pub fn append_nulls(&mut self, n: usize) { + self.len += n; + } + + /// Appends a null slot into the builder + #[inline] + pub fn append_empty_value(&mut self) { + self.append_null(); + } + + /// Appends `n` `null`s into the builder. + #[inline] + pub fn append_empty_values(&mut self, n: usize) { + self.append_nulls(n); + } + + /// Builds the [NullArray] and reset this builder. + pub fn finish(&mut self) -> NullArray { + let len = self.len(); + let builder = ArrayData::new_null(&DataType::Null, len).into_builder(); + + let array_data = unsafe { builder.build_unchecked() }; + NullArray::from(array_data) + } + + /// Builds the [NullArray] without resetting the builder. + pub fn finish_cloned(&self) -> NullArray { + let len = self.len(); + let builder = ArrayData::new_null(&DataType::Null, len).into_builder(); + + let array_data = unsafe { builder.build_unchecked() }; + NullArray::from(array_data) + } +} + +impl ArrayBuilder for NullBuilder { + /// Returns the builder as a non-mutable `Any` reference. + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns the builder as a mutable `Any` reference. + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + /// Returns the boxed builder as a box of `Any`. + fn into_box_any(self: Box) -> Box { + self + } + + /// Returns the number of array slots in the builder + fn len(&self) -> usize { + self.len + } + + /// Builds the array and reset this builder. + fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Array; + + #[test] + fn test_null_array_builder() { + let mut builder = NullArray::builder(10); + builder.append_null(); + builder.append_nulls(4); + builder.append_empty_value(); + builder.append_empty_values(4); + + let arr = builder.finish(); + assert_eq!(20, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + assert!(arr.is_nullable()); + } + + #[test] + fn test_null_array_builder_finish_cloned() { + let mut builder = NullArray::builder(16); + builder.append_null(); + builder.append_empty_value(); + builder.append_empty_values(3); + let mut array = builder.finish_cloned(); + assert_eq!(21, array.len()); + + builder.append_empty_values(5); + array = builder.finish(); + assert_eq!(26, array.len()); + } +} diff --git a/arrow/src/array/builder/primitive_builder.rs b/arrow-array/src/builder/primitive_builder.rs similarity index 52% rename from arrow/src/array/builder/primitive_builder.rs rename to arrow-array/src/builder/primitive_builder.rs index 38c8b4471477..0aad2dbfce0e 100644 --- a/arrow/src/array/builder/primitive_builder.rs +++ b/arrow-array/src/builder/primitive_builder.rs @@ -15,21 +15,89 @@ // specific language governing permissions and limitations // under the License. +use crate::builder::{ArrayBuilder, BufferBuilder}; +use crate::types::*; +use crate::{ArrayRef, ArrowPrimitiveType, PrimitiveArray}; +use arrow_buffer::NullBufferBuilder; +use arrow_buffer::{Buffer, MutableBuffer}; +use arrow_data::ArrayData; +use arrow_schema::{ArrowError, DataType}; use std::any::Any; use std::sync::Arc; -use crate::array::ArrayData; -use crate::array::ArrayRef; -use crate::array::PrimitiveArray; -use crate::datatypes::ArrowPrimitiveType; - -use super::{ArrayBuilder, BufferBuilder, NullBufferBuilder}; - -/// Array builder for fixed-width primitive types +/// A signed 8-bit integer array builder. +pub type Int8Builder = PrimitiveBuilder; +/// A signed 16-bit integer array builder. +pub type Int16Builder = PrimitiveBuilder; +/// A signed 32-bit integer array builder. +pub type Int32Builder = PrimitiveBuilder; +/// A signed 64-bit integer array builder. +pub type Int64Builder = PrimitiveBuilder; +/// An usigned 8-bit integer array builder. +pub type UInt8Builder = PrimitiveBuilder; +/// An usigned 16-bit integer array builder. +pub type UInt16Builder = PrimitiveBuilder; +/// An usigned 32-bit integer array builder. +pub type UInt32Builder = PrimitiveBuilder; +/// An usigned 64-bit integer array builder. +pub type UInt64Builder = PrimitiveBuilder; +/// A 16-bit floating point array builder. +pub type Float16Builder = PrimitiveBuilder; +/// A 32-bit floating point array builder. +pub type Float32Builder = PrimitiveBuilder; +/// A 64-bit floating point array builder. +pub type Float64Builder = PrimitiveBuilder; + +/// A timestamp second array builder. +pub type TimestampSecondBuilder = PrimitiveBuilder; +/// A timestamp millisecond array builder. +pub type TimestampMillisecondBuilder = PrimitiveBuilder; +/// A timestamp microsecond array builder. +pub type TimestampMicrosecondBuilder = PrimitiveBuilder; +/// A timestamp nanosecond array builder. +pub type TimestampNanosecondBuilder = PrimitiveBuilder; + +/// A 32-bit date array builder. +pub type Date32Builder = PrimitiveBuilder; +/// A 64-bit date array builder. +pub type Date64Builder = PrimitiveBuilder; + +/// A 32-bit elaspsed time in seconds array builder. +pub type Time32SecondBuilder = PrimitiveBuilder; +/// A 32-bit elaspsed time in milliseconds array builder. +pub type Time32MillisecondBuilder = PrimitiveBuilder; +/// A 64-bit elaspsed time in microseconds array builder. +pub type Time64MicrosecondBuilder = PrimitiveBuilder; +/// A 64-bit elaspsed time in nanoseconds array builder. +pub type Time64NanosecondBuilder = PrimitiveBuilder; + +/// A “calendar” interval in months array builder. +pub type IntervalYearMonthBuilder = PrimitiveBuilder; +/// A “calendar” interval in days and milliseconds array builder. +pub type IntervalDayTimeBuilder = PrimitiveBuilder; +/// A “calendar” interval in months, days, and nanoseconds array builder. +pub type IntervalMonthDayNanoBuilder = PrimitiveBuilder; + +/// An elapsed time in seconds array builder. +pub type DurationSecondBuilder = PrimitiveBuilder; +/// An elapsed time in milliseconds array builder. +pub type DurationMillisecondBuilder = PrimitiveBuilder; +/// An elapsed time in microseconds array builder. +pub type DurationMicrosecondBuilder = PrimitiveBuilder; +/// An elapsed time in nanoseconds array builder. +pub type DurationNanosecondBuilder = PrimitiveBuilder; + +/// A decimal 128 array builder +pub type Decimal128Builder = PrimitiveBuilder; +/// A decimal 256 array builder +pub type Decimal256Builder = PrimitiveBuilder; + +/// Builder for [`PrimitiveArray`] #[derive(Debug)] pub struct PrimitiveBuilder { values_builder: BufferBuilder, null_buffer_builder: NullBufferBuilder, + data_type: DataType, } impl ArrayBuilder for PrimitiveBuilder { @@ -53,15 +121,15 @@ impl ArrayBuilder for PrimitiveBuilder { self.values_builder.len() } - /// Returns whether the number of array slots is zero - fn is_empty(&self) -> bool { - self.values_builder.is_empty() - } - /// Builds the array and reset this builder. fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } } impl Default for PrimitiveBuilder { @@ -81,9 +149,47 @@ impl PrimitiveBuilder { Self { values_builder: BufferBuilder::::new(capacity), null_buffer_builder: NullBufferBuilder::new(capacity), + data_type: T::DATA_TYPE, + } + } + + /// Creates a new primitive array builder from buffers + pub fn new_from_buffer( + values_buffer: MutableBuffer, + null_buffer: Option, + ) -> Self { + let values_builder = BufferBuilder::::new_from_buffer(values_buffer); + + let null_buffer_builder = null_buffer + .map(|buffer| NullBufferBuilder::new_from_buffer(buffer, values_builder.len())) + .unwrap_or_else(|| NullBufferBuilder::new_with_len(values_builder.len())); + + Self { + values_builder, + null_buffer_builder, + data_type: T::DATA_TYPE, } } + /// By default [`PrimitiveBuilder`] uses [`ArrowPrimitiveType::DATA_TYPE`] as the + /// data type of the generated array. + /// + /// This method allows overriding the data type, to allow specifying timezones + /// for [`DataType::Timestamp`] or precision and scale for [`DataType::Decimal128`] and [`DataType::Decimal256`] + /// + /// # Panics + /// + /// This method panics if `data_type` is not [PrimitiveArray::is_compatible] + pub fn with_data_type(self, data_type: DataType) -> Self { + assert!( + PrimitiveArray::::is_compatible(&data_type), + "incompatible data type for builder, expected {} got {}", + T::DATA_TYPE, + data_type + ); + Self { data_type, ..self } + } + /// Returns the capacity of this builder measured in slots of type `T` pub fn capacity(&self) -> usize { self.values_builder.capacity() @@ -103,6 +209,7 @@ impl PrimitiveBuilder { self.values_builder.advance(1); } + /// Appends `n` no. of null's into the builder #[inline] pub fn append_nulls(&mut self, n: usize) { self.null_buffer_builder.append_n_nulls(n); @@ -126,6 +233,10 @@ impl PrimitiveBuilder { } /// Appends values from a slice of type `T` and a validity boolean slice + /// + /// # Panics + /// + /// Panics if `values` and `is_valid` have different lengths #[inline] pub fn append_values(&mut self, values: &[T::Native], is_valid: &[bool]) { assert_eq!( @@ -143,10 +254,7 @@ impl PrimitiveBuilder { /// This requires the iterator be a trusted length. This could instead require /// the iterator implement `TrustedLen` once that is stabilized. #[inline] - pub unsafe fn append_trusted_len_iter( - &mut self, - iter: impl IntoIterator, - ) { + pub unsafe fn append_trusted_len_iter(&mut self, iter: impl IntoIterator) { let iter = iter.into_iter(); let len = iter .size_hint() @@ -160,11 +268,25 @@ impl PrimitiveBuilder { /// Builds the [`PrimitiveArray`] and reset this builder. pub fn finish(&mut self) -> PrimitiveArray { let len = self.len(); - let null_bit_buffer = self.null_buffer_builder.finish(); - let builder = ArrayData::builder(T::DATA_TYPE) + let nulls = self.null_buffer_builder.finish(); + let builder = ArrayData::builder(self.data_type.clone()) .len(len) .add_buffer(self.values_builder.finish()) - .null_bit_buffer(null_bit_buffer); + .nulls(nulls); + + let array_data = unsafe { builder.build_unchecked() }; + PrimitiveArray::::from(array_data) + } + + /// Builds the [`PrimitiveArray`] without resetting the builder. + pub fn finish_cloned(&self) -> PrimitiveArray { + let len = self.len(); + let nulls = self.null_buffer_builder.finish_cloned(); + let values_buffer = Buffer::from_slice_ref(self.values_builder.as_slice()); + let builder = ArrayData::builder(self.data_type.clone()) + .len(len) + .add_buffer(values_buffer) + .nulls(nulls); let array_data = unsafe { builder.build_unchecked() }; PrimitiveArray::::from(array_data) @@ -174,19 +296,78 @@ impl PrimitiveBuilder { pub fn values_slice(&self) -> &[T::Native] { self.values_builder.as_slice() } + + /// Returns the current values buffer as a mutable slice + pub fn values_slice_mut(&mut self) -> &mut [T::Native] { + self.values_builder.as_slice_mut() + } + + /// Returns the current null buffer as a slice + pub fn validity_slice(&self) -> Option<&[u8]> { + self.null_buffer_builder.as_slice() + } + + /// Returns the current null buffer as a mutable slice + pub fn validity_slice_mut(&mut self) -> Option<&mut [u8]> { + self.null_buffer_builder.as_slice_mut() + } + + /// Returns the current values buffer and null buffer as a slice + pub fn slices_mut(&mut self) -> (&mut [T::Native], Option<&mut [u8]>) { + ( + self.values_builder.as_slice_mut(), + self.null_buffer_builder.as_slice_mut(), + ) + } +} + +impl PrimitiveBuilder

{ + /// Sets the precision and scale + pub fn with_precision_and_scale(self, precision: u8, scale: i8) -> Result { + validate_decimal_precision_and_scale::

(precision, scale)?; + Ok(Self { + data_type: P::TYPE_CONSTRUCTOR(precision, scale), + ..self + }) + } +} + +impl PrimitiveBuilder

{ + /// Sets the timezone + pub fn with_timezone(self, timezone: impl Into>) -> Self { + self.with_timezone_opt(Some(timezone.into())) + } + + /// Sets an optional timezone + pub fn with_timezone_opt>>(self, timezone: Option) -> Self { + Self { + data_type: DataType::Timestamp(P::UNIT, timezone.map(Into::into)), + ..self + } + } +} + +impl Extend> for PrimitiveBuilder

{ + #[inline] + fn extend>>(&mut self, iter: T) { + for v in iter { + self.append_option(v) + } + } } #[cfg(test)] mod tests { use super::*; + use arrow_buffer::Buffer; + use arrow_schema::TimeUnit; use crate::array::Array; use crate::array::BooleanArray; use crate::array::Date32Array; use crate::array::Int32Array; - use crate::array::Int32Builder; use crate::array::TimestampSecondArray; - use crate::buffer::Buffer; + use crate::builder::Int32Builder; #[test] fn test_primitive_array_builder_i32() { @@ -282,14 +463,14 @@ mod tests { } let arr = builder.finish(); - assert_eq!(&buf, arr.values()); + assert_eq!(&buf, arr.values().inner()); assert_eq!(10, arr.len()); assert_eq!(0, arr.offset()); assert_eq!(0, arr.null_count()); for i in 0..10 { assert!(!arr.is_null(i)); assert!(arr.is_valid(i)); - assert_eq!(i == 3 || i == 6 || i == 9, arr.value(i), "failed at {}", i) + assert_eq!(i == 3 || i == 6 || i == 9, arr.value(i), "failed at {i}") } } @@ -377,4 +558,56 @@ mod tests { assert_eq!(5, arr.len()); assert_eq!(0, builder.len()); } + + #[test] + fn test_primitive_array_builder_finish_cloned() { + let mut builder = Int32Builder::new(); + builder.append_value(23); + builder.append_value(45); + let result = builder.finish_cloned(); + assert_eq!(result, Int32Array::from(vec![23, 45])); + builder.append_value(56); + assert_eq!(builder.finish_cloned(), Int32Array::from(vec![23, 45, 56])); + + builder.append_slice(&[2, 4, 6, 8]); + let mut arr = builder.finish(); + assert_eq!(7, arr.len()); + assert_eq!(arr, Int32Array::from(vec![23, 45, 56, 2, 4, 6, 8])); + assert_eq!(0, builder.len()); + + builder.append_slice(&[1, 3, 5, 7, 9]); + arr = builder.finish(); + assert_eq!(5, arr.len()); + assert_eq!(0, builder.len()); + } + + #[test] + fn test_primitive_array_builder_with_data_type() { + let mut builder = Decimal128Builder::new().with_data_type(DataType::Decimal128(1, 2)); + builder.append_value(1); + let array = builder.finish(); + assert_eq!(array.precision(), 1); + assert_eq!(array.scale(), 2); + + let data_type = DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())); + let mut builder = TimestampNanosecondBuilder::new().with_data_type(data_type.clone()); + builder.append_value(1); + let array = builder.finish(); + assert_eq!(array.data_type(), &data_type); + } + + #[test] + #[should_panic(expected = "incompatible data type for builder, expected Int32 got Int64")] + fn test_invalid_with_data_type() { + Int32Builder::new().with_data_type(DataType::Int64); + } + + #[test] + fn test_extend() { + let mut builder = PrimitiveBuilder::::new(); + builder.extend([1, 2, 3, 5, 2, 4, 4].into_iter().map(Some)); + builder.extend([2, 4, 6, 2].into_iter().map(Some)); + let array = builder.finish(); + assert_eq!(array.values(), &[1, 2, 3, 5, 2, 4, 4, 2, 4, 6, 2]); + } } diff --git a/arrow-array/src/builder/primitive_dictionary_builder.rs b/arrow-array/src/builder/primitive_dictionary_builder.rs new file mode 100644 index 000000000000..a47b2d30d4f3 --- /dev/null +++ b/arrow-array/src/builder/primitive_dictionary_builder.rs @@ -0,0 +1,398 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::builder::{ArrayBuilder, PrimitiveBuilder}; +use crate::types::ArrowDictionaryKeyType; +use crate::{Array, ArrayRef, ArrowPrimitiveType, DictionaryArray}; +use arrow_buffer::{ArrowNativeType, ToByteSlice}; +use arrow_schema::{ArrowError, DataType}; +use std::any::Any; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::sync::Arc; + +/// Wraps a type implementing `ToByteSlice` implementing `Hash` and `Eq` for it +/// +/// This is necessary to handle types such as f32, which don't natively implement these +#[derive(Debug)] +struct Value(T); + +impl std::hash::Hash for Value { + fn hash(&self, state: &mut H) { + self.0.to_byte_slice().hash(state) + } +} + +impl PartialEq for Value { + fn eq(&self, other: &Self) -> bool { + self.0.to_byte_slice().eq(other.0.to_byte_slice()) + } +} + +impl Eq for Value {} + +/// Builder for [`DictionaryArray`] of [`PrimitiveArray`](crate::array::PrimitiveArray) +/// +/// # Example: +/// +/// ``` +/// +/// # use arrow_array::builder::PrimitiveDictionaryBuilder; +/// # use arrow_array::types::{UInt32Type, UInt8Type}; +/// # use arrow_array::{Array, UInt32Array, UInt8Array}; +/// +/// let mut builder = PrimitiveDictionaryBuilder::::new(); +/// builder.append(12345678).unwrap(); +/// builder.append_null(); +/// builder.append(22345678).unwrap(); +/// let array = builder.finish(); +/// +/// assert_eq!( +/// array.keys(), +/// &UInt8Array::from(vec![Some(0), None, Some(1)]) +/// ); +/// +/// // Values are polymorphic and so require a downcast. +/// let av = array.values(); +/// let ava: &UInt32Array = av.as_any().downcast_ref::().unwrap(); +/// let avs: &[u32] = ava.values(); +/// +/// assert!(!array.is_null(0)); +/// assert!(array.is_null(1)); +/// assert!(!array.is_null(2)); +/// +/// assert_eq!(avs, &[12345678, 22345678]); +/// ``` +#[derive(Debug)] +pub struct PrimitiveDictionaryBuilder +where + K: ArrowPrimitiveType, + V: ArrowPrimitiveType, +{ + keys_builder: PrimitiveBuilder, + values_builder: PrimitiveBuilder, + map: HashMap, usize>, +} + +impl Default for PrimitiveDictionaryBuilder +where + K: ArrowPrimitiveType, + V: ArrowPrimitiveType, +{ + fn default() -> Self { + Self::new() + } +} + +impl PrimitiveDictionaryBuilder +where + K: ArrowPrimitiveType, + V: ArrowPrimitiveType, +{ + /// Creates a new `PrimitiveDictionaryBuilder`. + pub fn new() -> Self { + Self { + keys_builder: PrimitiveBuilder::new(), + values_builder: PrimitiveBuilder::new(), + map: HashMap::new(), + } + } + + /// Creates a new `PrimitiveDictionaryBuilder` from the provided keys and values builders. + /// + /// # Panics + /// + /// This method panics if `keys_builder` or `values_builder` is not empty. + pub fn new_from_empty_builders( + keys_builder: PrimitiveBuilder, + values_builder: PrimitiveBuilder, + ) -> Self { + assert!( + keys_builder.is_empty() && values_builder.is_empty(), + "keys and values builders must be empty" + ); + Self { + keys_builder, + values_builder, + map: HashMap::new(), + } + } + + /// Creates a new `PrimitiveDictionaryBuilder` from existing `PrimitiveBuilder`s of keys and values. + /// + /// # Safety + /// + /// caller must ensure that the passed in builders are valid for DictionaryArray. + pub unsafe fn new_from_builders( + keys_builder: PrimitiveBuilder, + values_builder: PrimitiveBuilder, + ) -> Self { + let keys = keys_builder.values_slice(); + let values = values_builder.values_slice(); + let mut map = HashMap::with_capacity(values.len()); + + keys.iter().zip(values.iter()).for_each(|(key, value)| { + map.insert(Value(*value), K::Native::to_usize(*key).unwrap()); + }); + + Self { + keys_builder, + values_builder, + map, + } + } + + /// Creates a new `PrimitiveDictionaryBuilder` with the provided capacities + /// + /// `keys_capacity`: the number of keys, i.e. length of array to build + /// `values_capacity`: the number of distinct dictionary values, i.e. size of dictionary + pub fn with_capacity(keys_capacity: usize, values_capacity: usize) -> Self { + Self { + keys_builder: PrimitiveBuilder::with_capacity(keys_capacity), + values_builder: PrimitiveBuilder::with_capacity(values_capacity), + map: HashMap::with_capacity(values_capacity), + } + } +} + +impl ArrayBuilder for PrimitiveDictionaryBuilder +where + K: ArrowDictionaryKeyType, + V: ArrowPrimitiveType, +{ + /// Returns the builder as an non-mutable `Any` reference. + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns the builder as an mutable `Any` reference. + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + /// Returns the boxed builder as a box of `Any`. + fn into_box_any(self: Box) -> Box { + self + } + + /// Returns the number of array slots in the builder + fn len(&self) -> usize { + self.keys_builder.len() + } + + /// Builds the array and reset this builder. + fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } +} + +impl PrimitiveDictionaryBuilder +where + K: ArrowDictionaryKeyType, + V: ArrowPrimitiveType, +{ + /// Append a primitive value to the array. Return an existing index + /// if already present in the values array or a new index if the + /// value is appended to the values array. + #[inline] + pub fn append(&mut self, value: V::Native) -> Result { + let key = match self.map.entry(Value(value)) { + Entry::Vacant(vacant) => { + // Append new value. + let key = self.values_builder.len(); + self.values_builder.append_value(value); + vacant.insert(key); + K::Native::from_usize(key).ok_or(ArrowError::DictionaryKeyOverflowError)? + } + Entry::Occupied(o) => K::Native::usize_as(*o.get()), + }; + + self.keys_builder.append_value(key); + Ok(key) + } + + /// Infallibly append a value to this builder + /// + /// # Panics + /// + /// Panics if the resulting length of the dictionary values array would exceed `T::Native::MAX` + #[inline] + pub fn append_value(&mut self, value: V::Native) { + self.append(value).expect("dictionary key overflow"); + } + + /// Appends a null slot into the builder + #[inline] + pub fn append_null(&mut self) { + self.keys_builder.append_null() + } + + /// Append an `Option` value into the builder + /// + /// # Panics + /// + /// Panics if the resulting length of the dictionary values array would exceed `T::Native::MAX` + #[inline] + pub fn append_option(&mut self, value: Option) { + match value { + None => self.append_null(), + Some(v) => self.append_value(v), + }; + } + + /// Builds the `DictionaryArray` and reset this builder. + pub fn finish(&mut self) -> DictionaryArray { + self.map.clear(); + let values = self.values_builder.finish(); + let keys = self.keys_builder.finish(); + + let data_type = + DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(values.data_type().clone())); + + let builder = keys + .into_data() + .into_builder() + .data_type(data_type) + .child_data(vec![values.into_data()]); + + DictionaryArray::from(unsafe { builder.build_unchecked() }) + } + + /// Builds the `DictionaryArray` without resetting the builder. + pub fn finish_cloned(&self) -> DictionaryArray { + let values = self.values_builder.finish_cloned(); + let keys = self.keys_builder.finish_cloned(); + + let data_type = DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(V::DATA_TYPE)); + + let builder = keys + .into_data() + .into_builder() + .data_type(data_type) + .child_data(vec![values.into_data()]); + + DictionaryArray::from(unsafe { builder.build_unchecked() }) + } + + /// Returns the current dictionary values buffer as a slice + pub fn values_slice(&self) -> &[V::Native] { + self.values_builder.values_slice() + } + + /// Returns the current dictionary values buffer as a mutable slice + pub fn values_slice_mut(&mut self) -> &mut [V::Native] { + self.values_builder.values_slice_mut() + } +} + +impl Extend> + for PrimitiveDictionaryBuilder +{ + #[inline] + fn extend>>(&mut self, iter: T) { + for v in iter { + self.append_option(v) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::array::Array; + use crate::array::UInt32Array; + use crate::array::UInt8Array; + use crate::builder::Decimal128Builder; + use crate::types::{Decimal128Type, Int32Type, UInt32Type, UInt8Type}; + + #[test] + fn test_primitive_dictionary_builder() { + let mut builder = PrimitiveDictionaryBuilder::::with_capacity(3, 2); + builder.append(12345678).unwrap(); + builder.append_null(); + builder.append(22345678).unwrap(); + let array = builder.finish(); + + assert_eq!( + array.keys(), + &UInt8Array::from(vec![Some(0), None, Some(1)]) + ); + + // Values are polymorphic and so require a downcast. + let av = array.values(); + let ava: &UInt32Array = av.as_any().downcast_ref::().unwrap(); + let avs: &[u32] = ava.values(); + + assert!(!array.is_null(0)); + assert!(array.is_null(1)); + assert!(!array.is_null(2)); + + assert_eq!(avs, &[12345678, 22345678]); + } + + #[test] + fn test_extend() { + let mut builder = PrimitiveDictionaryBuilder::::new(); + builder.extend([1, 2, 3, 1, 2, 3, 1, 2, 3].into_iter().map(Some)); + builder.extend([4, 5, 1, 3, 1].into_iter().map(Some)); + let dict = builder.finish(); + assert_eq!( + dict.keys().values(), + &[0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 4, 0, 2, 0] + ); + assert_eq!(dict.values().len(), 5); + } + + #[test] + #[should_panic(expected = "DictionaryKeyOverflowError")] + fn test_primitive_dictionary_overflow() { + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(257, 257); + // 256 unique keys. + for i in 0..256 { + builder.append(i + 1000).unwrap(); + } + // Special error if the key overflows (256th entry) + builder.append(1257).unwrap(); + } + + #[test] + fn test_primitive_dictionary_with_builders() { + let keys_builder = PrimitiveBuilder::::new(); + let values_builder = Decimal128Builder::new().with_data_type(DataType::Decimal128(1, 2)); + let mut builder = + PrimitiveDictionaryBuilder::::new_from_empty_builders( + keys_builder, + values_builder, + ); + let dict_array = builder.finish(); + assert_eq!(dict_array.value_type(), DataType::Decimal128(1, 2)); + assert_eq!( + dict_array.data_type(), + &DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Decimal128(1, 2)), + ) + ); + } +} diff --git a/arrow-array/src/builder/primitive_run_builder.rs b/arrow-array/src/builder/primitive_run_builder.rs new file mode 100644 index 000000000000..01a989199b58 --- /dev/null +++ b/arrow-array/src/builder/primitive_run_builder.rs @@ -0,0 +1,311 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use crate::{types::RunEndIndexType, ArrayRef, ArrowPrimitiveType, RunArray}; + +use super::{ArrayBuilder, PrimitiveBuilder}; + +use arrow_buffer::ArrowNativeType; + +/// Builder for [`RunArray`] of [`PrimitiveArray`](crate::array::PrimitiveArray) +/// +/// # Example: +/// +/// ``` +/// +/// # use arrow_array::builder::PrimitiveRunBuilder; +/// # use arrow_array::cast::AsArray; +/// # use arrow_array::types::{UInt32Type, Int16Type}; +/// # use arrow_array::{Array, UInt32Array, Int16Array}; +/// +/// let mut builder = +/// PrimitiveRunBuilder::::new(); +/// builder.append_value(1234); +/// builder.append_value(1234); +/// builder.append_value(1234); +/// builder.append_null(); +/// builder.append_value(5678); +/// builder.append_value(5678); +/// let array = builder.finish(); +/// +/// assert_eq!(array.run_ends().values(), &[3, 4, 6]); +/// +/// let av = array.values(); +/// +/// assert!(!av.is_null(0)); +/// assert!(av.is_null(1)); +/// assert!(!av.is_null(2)); +/// +/// // Values are polymorphic and so require a downcast. +/// let ava: &UInt32Array = av.as_primitive::(); +/// +/// assert_eq!(ava, &UInt32Array::from(vec![Some(1234), None, Some(5678)])); +/// ``` +#[derive(Debug)] +pub struct PrimitiveRunBuilder +where + R: RunEndIndexType, + V: ArrowPrimitiveType, +{ + run_ends_builder: PrimitiveBuilder, + values_builder: PrimitiveBuilder, + current_value: Option, + current_run_end_index: usize, + prev_run_end_index: usize, +} + +impl Default for PrimitiveRunBuilder +where + R: RunEndIndexType, + V: ArrowPrimitiveType, +{ + fn default() -> Self { + Self::new() + } +} + +impl PrimitiveRunBuilder +where + R: RunEndIndexType, + V: ArrowPrimitiveType, +{ + /// Creates a new `PrimitiveRunBuilder` + pub fn new() -> Self { + Self { + run_ends_builder: PrimitiveBuilder::new(), + values_builder: PrimitiveBuilder::new(), + current_value: None, + current_run_end_index: 0, + prev_run_end_index: 0, + } + } + + /// Creates a new `PrimitiveRunBuilder` with the provided capacity + /// + /// `capacity`: the expected number of run-end encoded values. + pub fn with_capacity(capacity: usize) -> Self { + Self { + run_ends_builder: PrimitiveBuilder::with_capacity(capacity), + values_builder: PrimitiveBuilder::with_capacity(capacity), + current_value: None, + current_run_end_index: 0, + prev_run_end_index: 0, + } + } +} + +impl ArrayBuilder for PrimitiveRunBuilder +where + R: RunEndIndexType, + V: ArrowPrimitiveType, +{ + /// Returns the builder as a non-mutable `Any` reference. + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns the builder as a mutable `Any` reference. + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + /// Returns the boxed builder as a box of `Any`. + fn into_box_any(self: Box) -> Box { + self + } + + /// Returns the length of logical array encoded by + /// the eventual runs array. + fn len(&self) -> usize { + self.current_run_end_index + } + + /// Builds the array and reset this builder. + fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } + + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } +} + +impl PrimitiveRunBuilder +where + R: RunEndIndexType, + V: ArrowPrimitiveType, +{ + /// Appends optional value to the logical array encoded by the RunArray. + pub fn append_option(&mut self, value: Option) { + if self.current_run_end_index == 0 { + self.current_run_end_index = 1; + self.current_value = value; + return; + } + if self.current_value != value { + self.append_run_end(); + self.current_value = value; + } + + self.current_run_end_index += 1; + } + + /// Appends value to the logical array encoded by the run-ends array. + pub fn append_value(&mut self, value: V::Native) { + self.append_option(Some(value)) + } + + /// Appends null to the logical array encoded by the run-ends array. + pub fn append_null(&mut self) { + self.append_option(None) + } + + /// Creates the RunArray and resets the builder. + /// Panics if RunArray cannot be built. + pub fn finish(&mut self) -> RunArray { + // write the last run end to the array. + self.append_run_end(); + + // reset the run index to zero. + self.current_value = None; + self.current_run_end_index = 0; + + // build the run encoded array by adding run_ends and values array as its children. + let run_ends_array = self.run_ends_builder.finish(); + let values_array = self.values_builder.finish(); + RunArray::::try_new(&run_ends_array, &values_array).unwrap() + } + + /// Creates the RunArray and without resetting the builder. + /// Panics if RunArray cannot be built. + pub fn finish_cloned(&self) -> RunArray { + let mut run_ends_array = self.run_ends_builder.finish_cloned(); + let mut values_array = self.values_builder.finish_cloned(); + + // Add current run if one exists + if self.prev_run_end_index != self.current_run_end_index { + let mut run_end_builder = run_ends_array.into_builder().unwrap(); + let mut values_builder = values_array.into_builder().unwrap(); + self.append_run_end_with_builders(&mut run_end_builder, &mut values_builder); + run_ends_array = run_end_builder.finish(); + values_array = values_builder.finish(); + } + + RunArray::try_new(&run_ends_array, &values_array).unwrap() + } + + // Appends the current run to the array. + fn append_run_end(&mut self) { + // empty array or the function called without appending any value. + if self.prev_run_end_index == self.current_run_end_index { + return; + } + let run_end_index = self.run_end_index_as_native(); + self.run_ends_builder.append_value(run_end_index); + self.values_builder.append_option(self.current_value); + self.prev_run_end_index = self.current_run_end_index; + } + + // Similar to `append_run_end` but on custom builders. + // Used in `finish_cloned` which is not suppose to mutate `self`. + fn append_run_end_with_builders( + &self, + run_ends_builder: &mut PrimitiveBuilder, + values_builder: &mut PrimitiveBuilder, + ) { + let run_end_index = self.run_end_index_as_native(); + run_ends_builder.append_value(run_end_index); + values_builder.append_option(self.current_value); + } + + fn run_end_index_as_native(&self) -> R::Native { + R::Native::from_usize(self.current_run_end_index) + .unwrap_or_else(|| panic!( + "Cannot convert `current_run_end_index` {} from `usize` to native form of arrow datatype {}", + self.current_run_end_index, + R::DATA_TYPE + )) + } +} + +impl Extend> for PrimitiveRunBuilder +where + R: RunEndIndexType, + V: ArrowPrimitiveType, +{ + fn extend>>(&mut self, iter: T) { + for elem in iter { + self.append_option(elem); + } + } +} + +#[cfg(test)] +mod tests { + use crate::builder::PrimitiveRunBuilder; + use crate::cast::AsArray; + use crate::types::{Int16Type, UInt32Type}; + use crate::{Array, UInt32Array}; + + #[test] + fn test_primitive_ree_array_builder() { + let mut builder = PrimitiveRunBuilder::::new(); + builder.append_value(1234); + builder.append_value(1234); + builder.append_value(1234); + builder.append_null(); + builder.append_value(5678); + builder.append_value(5678); + + let array = builder.finish(); + + assert_eq!(array.null_count(), 0); + assert_eq!(array.len(), 6); + + assert_eq!(array.run_ends().values(), &[3, 4, 6]); + + let av = array.values(); + + assert!(!av.is_null(0)); + assert!(av.is_null(1)); + assert!(!av.is_null(2)); + + // Values are polymorphic and so require a downcast. + let ava: &UInt32Array = av.as_primitive::(); + + assert_eq!(ava, &UInt32Array::from(vec![Some(1234), None, Some(5678)])); + } + + #[test] + fn test_extend() { + let mut builder = PrimitiveRunBuilder::::new(); + builder.extend([1, 2, 2, 5, 5, 4, 4].into_iter().map(Some)); + builder.extend([4, 4, 6, 2].into_iter().map(Some)); + let array = builder.finish(); + + assert_eq!(array.len(), 11); + assert_eq!(array.null_count(), 0); + assert_eq!(array.run_ends().values(), &[1, 3, 5, 9, 10, 11]); + assert_eq!( + array.values().as_primitive::().values(), + &[1, 2, 5, 4, 6, 2] + ); + } +} diff --git a/arrow/src/array/builder/struct_builder.rs b/arrow-array/src/builder/struct_builder.rs similarity index 55% rename from arrow/src/array/builder/struct_builder.rs rename to arrow-array/src/builder/struct_builder.rs index c5db09119e08..960949a2f09f 100644 --- a/arrow/src/array/builder/struct_builder.rs +++ b/arrow-array/src/builder/struct_builder.rs @@ -15,29 +15,25 @@ // specific language governing permissions and limitations // under the License. +use crate::builder::*; +use crate::{ArrayRef, StructArray}; +use arrow_buffer::NullBufferBuilder; +use arrow_schema::{DataType, Fields, IntervalUnit, TimeUnit}; use std::any::Any; -use std::fmt; use std::sync::Arc; -use crate::array::builder::decimal_builder::Decimal128Builder; -use crate::array::*; -use crate::datatypes::DataType; -use crate::datatypes::Field; - -use super::NullBufferBuilder; - -/// Array builder for Struct types. +/// Builder for [`StructArray`] /// /// Note that callers should make sure that methods of all the child field builders are /// properly called to maintain the consistency of the data structure. pub struct StructBuilder { - fields: Vec, + fields: Fields, field_builders: Vec>, null_buffer_builder: NullBufferBuilder, } -impl fmt::Debug for StructBuilder { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl std::fmt::Debug for StructBuilder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("StructBuilder") .field("fields", &self.fields) .field("bitmap_builder", &self.null_buffer_builder) @@ -56,16 +52,16 @@ impl ArrayBuilder for StructBuilder { self.null_buffer_builder.len() } - /// Returns whether the number of array slots is zero - fn is_empty(&self) -> bool { - self.len() == 0 - } - /// Builds the array. fn finish(&mut self) -> ArrayRef { Arc::new(self.finish()) } + /// Builds the array without resetting the builder. + fn finish_cloned(&self) -> ArrayRef { + Arc::new(self.finish_cloned()) + } + /// Returns the builder as a non-mutable `Any` reference. /// /// This is most useful when one wants to call non-mutable APIs on a specific builder @@ -94,8 +90,9 @@ impl ArrayBuilder for StructBuilder { /// This function is useful to construct arrays from an arbitrary vectors with known/expected /// schema. pub fn make_builder(datatype: &DataType, capacity: usize) -> Box { + use crate::builder::*; match datatype { - DataType::Null => unimplemented!(), + DataType::Null => Box::new(NullBuilder::with_capacity(capacity)), DataType::Boolean => Box::new(BooleanBuilder::with_capacity(capacity)), DataType::Int8 => Box::new(Int8Builder::with_capacity(capacity)), DataType::Int16 => Box::new(Int16Builder::with_capacity(capacity)), @@ -105,16 +102,22 @@ pub fn make_builder(datatype: &DataType, capacity: usize) -> Box Box::new(UInt16Builder::with_capacity(capacity)), DataType::UInt32 => Box::new(UInt32Builder::with_capacity(capacity)), DataType::UInt64 => Box::new(UInt64Builder::with_capacity(capacity)), + DataType::Float16 => Box::new(Float16Builder::with_capacity(capacity)), DataType::Float32 => Box::new(Float32Builder::with_capacity(capacity)), DataType::Float64 => Box::new(Float64Builder::with_capacity(capacity)), - DataType::Binary => Box::new(BinaryBuilder::with_capacity(1024, capacity)), + DataType::Binary => Box::new(BinaryBuilder::with_capacity(capacity, 1024)), + DataType::LargeBinary => Box::new(LargeBinaryBuilder::with_capacity(capacity, 1024)), DataType::FixedSizeBinary(len) => { Box::new(FixedSizeBinaryBuilder::with_capacity(capacity, *len)) } - DataType::Decimal128(precision, scale) => Box::new( - Decimal128Builder::with_capacity(capacity, *precision, *scale), + DataType::Decimal128(p, s) => Box::new( + Decimal128Builder::with_capacity(capacity).with_data_type(DataType::Decimal128(*p, *s)), + ), + DataType::Decimal256(p, s) => Box::new( + Decimal256Builder::with_capacity(capacity).with_data_type(DataType::Decimal256(*p, *s)), ), - DataType::Utf8 => Box::new(StringBuilder::with_capacity(1024, capacity)), + DataType::Utf8 => Box::new(StringBuilder::with_capacity(capacity, 1024)), + DataType::LargeUtf8 => Box::new(LargeStringBuilder::with_capacity(capacity, 1024)), DataType::Date32 => Box::new(Date32Builder::with_capacity(capacity)), DataType::Date64 => Box::new(Date64Builder::with_capacity(capacity)), DataType::Time32(TimeUnit::Second) => { @@ -129,18 +132,22 @@ pub fn make_builder(datatype: &DataType, capacity: usize) -> Box { Box::new(Time64NanosecondBuilder::with_capacity(capacity)) } - DataType::Timestamp(TimeUnit::Second, _) => { - Box::new(TimestampSecondBuilder::with_capacity(capacity)) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - Box::new(TimestampMillisecondBuilder::with_capacity(capacity)) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - Box::new(TimestampMicrosecondBuilder::with_capacity(capacity)) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - Box::new(TimestampNanosecondBuilder::with_capacity(capacity)) - } + DataType::Timestamp(TimeUnit::Second, tz) => Box::new( + TimestampSecondBuilder::with_capacity(capacity) + .with_data_type(DataType::Timestamp(TimeUnit::Second, tz.clone())), + ), + DataType::Timestamp(TimeUnit::Millisecond, tz) => Box::new( + TimestampMillisecondBuilder::with_capacity(capacity) + .with_data_type(DataType::Timestamp(TimeUnit::Millisecond, tz.clone())), + ), + DataType::Timestamp(TimeUnit::Microsecond, tz) => Box::new( + TimestampMicrosecondBuilder::with_capacity(capacity) + .with_data_type(DataType::Timestamp(TimeUnit::Microsecond, tz.clone())), + ), + DataType::Timestamp(TimeUnit::Nanosecond, tz) => Box::new( + TimestampNanosecondBuilder::with_capacity(capacity) + .with_data_type(DataType::Timestamp(TimeUnit::Nanosecond, tz.clone())), + ), DataType::Interval(IntervalUnit::YearMonth) => { Box::new(IntervalYearMonthBuilder::with_capacity(capacity)) } @@ -162,23 +169,32 @@ pub fn make_builder(datatype: &DataType, capacity: usize) -> Box { Box::new(DurationNanosecondBuilder::with_capacity(capacity)) } - DataType::Struct(fields) => { - Box::new(StructBuilder::from_fields(fields.clone(), capacity)) + DataType::List(field) => { + let builder = make_builder(field.data_type(), capacity); + Box::new(ListBuilder::with_capacity(builder, capacity)) + } + DataType::LargeList(field) => { + let builder = make_builder(field.data_type(), capacity); + Box::new(LargeListBuilder::with_capacity(builder, capacity)) } - t => panic!("Data type {:?} is not currently supported", t), + DataType::Struct(fields) => Box::new(StructBuilder::from_fields(fields.clone(), capacity)), + t => panic!("Data type {t:?} is not currently supported"), } } impl StructBuilder { - pub fn new(fields: Vec, field_builders: Vec>) -> Self { + /// Creates a new `StructBuilder` + pub fn new(fields: impl Into, field_builders: Vec>) -> Self { Self { - fields, field_builders, + fields: fields.into(), null_buffer_builder: NullBufferBuilder::new(0), } } - pub fn from_fields(fields: Vec, capacity: usize) -> Self { + /// Creates a new `StructBuilder` from [`Fields`] and `capacity` + pub fn from_fields(fields: impl Into, capacity: usize) -> Self { + let fields = fields.into(); let mut builders = Vec::with_capacity(fields.len()); for field in &fields { builders.push(make_builder(field.data_type(), capacity)); @@ -214,22 +230,35 @@ impl StructBuilder { /// Builds the `StructArray` and reset this builder. pub fn finish(&mut self) -> StructArray { self.validate_content(); + if self.fields.is_empty() { + return StructArray::new_empty_fields(self.len(), self.null_buffer_builder.finish()); + } + + let arrays = self.field_builders.iter_mut().map(|f| f.finish()).collect(); + let nulls = self.null_buffer_builder.finish(); + StructArray::new(self.fields.clone(), arrays, nulls) + } + + /// Builds the `StructArray` without resetting the builder. + pub fn finish_cloned(&self) -> StructArray { + self.validate_content(); - let mut child_data = Vec::with_capacity(self.field_builders.len()); - for f in &mut self.field_builders { - let arr = f.finish(); - child_data.push(arr.into_data()); + if self.fields.is_empty() { + return StructArray::new_empty_fields( + self.len(), + self.null_buffer_builder.finish_cloned(), + ); } - let length = self.len(); - let null_bit_buffer = self.null_buffer_builder.finish(); - let builder = ArrayData::builder(DataType::Struct(self.fields.clone())) - .len(length) - .child_data(child_data) - .null_bit_buffer(null_bit_buffer); + let arrays = self + .field_builders + .iter() + .map(|f| f.finish_cloned()) + .collect(); + + let nulls = self.null_buffer_builder.finish_cloned(); - let array_data = unsafe { builder.build_unchecked() }; - StructArray::from(array_data) + StructArray::new(self.fields.clone(), arrays, nulls) } /// Constructs and validates contents in the builder to ensure that @@ -248,22 +277,25 @@ impl StructBuilder { #[cfg(test)] mod tests { use super::*; + use arrow_buffer::Buffer; + use arrow_data::ArrayData; + use arrow_schema::Field; use crate::array::Array; - use crate::bitmap::Bitmap; - use crate::buffer::Buffer; #[test] fn test_struct_array_builder() { let string_builder = StringBuilder::new(); let int_builder = Int32Builder::new(); - let mut fields = Vec::new(); - let mut field_builders = Vec::new(); - fields.push(Field::new("f1", DataType::Utf8, false)); - field_builders.push(Box::new(string_builder) as Box); - fields.push(Field::new("f2", DataType::Int32, false)); - field_builders.push(Box::new(int_builder) as Box); + let fields = vec![ + Field::new("f1", DataType::Utf8, true), + Field::new("f2", DataType::Int32, true), + ]; + let field_builders = vec![ + Box::new(string_builder) as Box, + Box::new(int_builder) as Box, + ]; let mut builder = StructBuilder::new(fields, field_builders); assert_eq!(2, builder.num_fields()); @@ -289,33 +321,29 @@ mod tests { builder.append_null(); builder.append(true); - let arr = builder.finish(); + let struct_data = builder.finish().into_data(); - let struct_data = arr.data(); assert_eq!(4, struct_data.len()); assert_eq!(1, struct_data.null_count()); - assert_eq!( - Some(&Bitmap::from(Buffer::from(&[11_u8]))), - struct_data.null_bitmap() - ); + assert_eq!(&[11_u8], struct_data.nulls().unwrap().validity()); let expected_string_data = ArrayData::builder(DataType::Utf8) .len(4) .null_bit_buffer(Some(Buffer::from(&[9_u8]))) - .add_buffer(Buffer::from_slice_ref(&[0, 3, 3, 3, 7])) + .add_buffer(Buffer::from_slice_ref([0, 3, 3, 3, 7])) .add_buffer(Buffer::from_slice_ref(b"joemark")) .build() .unwrap(); let expected_int_data = ArrayData::builder(DataType::Int32) .len(4) - .null_bit_buffer(Some(Buffer::from_slice_ref(&[11_u8]))) - .add_buffer(Buffer::from_slice_ref(&[1, 2, 0, 4])) + .null_bit_buffer(Some(Buffer::from_slice_ref([11_u8]))) + .add_buffer(Buffer::from_slice_ref([1, 2, 0, 4])) .build() .unwrap(); - assert_eq!(expected_string_data, *arr.column(0).data()); - assert_eq!(expected_int_data, *arr.column(1).data()); + assert_eq!(expected_string_data, struct_data.child_data()[0]); + assert_eq!(expected_int_data, struct_data.child_data()[1]); } #[test] @@ -323,12 +351,14 @@ mod tests { let int_builder = Int32Builder::new(); let bool_builder = BooleanBuilder::new(); - let mut fields = Vec::new(); - let mut field_builders = Vec::new(); - fields.push(Field::new("f1", DataType::Int32, false)); - field_builders.push(Box::new(int_builder) as Box); - fields.push(Field::new("f2", DataType::Boolean, false)); - field_builders.push(Box::new(bool_builder) as Box); + let fields = vec![ + Field::new("f1", DataType::Int32, false), + Field::new("f2", DataType::Boolean, false), + ]; + let field_builders = vec![ + Box::new(int_builder) as Box, + Box::new(bool_builder) as Box, + ]; let mut builder = StructBuilder::new(fields, field_builders); builder @@ -376,6 +406,66 @@ mod tests { assert_eq!(0, builder.len()); } + #[test] + fn test_struct_array_builder_finish_cloned() { + let int_builder = Int32Builder::new(); + let bool_builder = BooleanBuilder::new(); + + let fields = vec![ + Field::new("f1", DataType::Int32, false), + Field::new("f2", DataType::Boolean, false), + ]; + let field_builders = vec![ + Box::new(int_builder) as Box, + Box::new(bool_builder) as Box, + ]; + + let mut builder = StructBuilder::new(fields, field_builders); + builder + .field_builder::(0) + .unwrap() + .append_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + builder + .field_builder::(1) + .unwrap() + .append_slice(&[ + false, true, false, true, false, true, false, true, false, true, + ]); + + // Append slot values - all are valid. + for _ in 0..10 { + builder.append(true); + } + + assert_eq!(10, builder.len()); + + let mut arr = builder.finish_cloned(); + + assert_eq!(10, arr.len()); + assert_eq!(10, builder.len()); + + builder + .field_builder::(0) + .unwrap() + .append_slice(&[1, 3, 5, 7, 9]); + builder + .field_builder::(1) + .unwrap() + .append_slice(&[false, true, false, true, false]); + + // Append slot values - all are valid. + for _ in 0..5 { + builder.append(true); + } + + assert_eq!(15, builder.len()); + + arr = builder.finish(); + + assert_eq!(15, arr.len()); + assert_eq!(0, builder.len()); + } + #[test] fn test_struct_array_builder_from_schema() { let mut fields = vec![ @@ -386,7 +476,7 @@ mod tests { Field::new("g1", DataType::Int32, false), Field::new("g2", DataType::Boolean, false), ]; - let struct_type = DataType::Struct(sub_fields); + let struct_type = DataType::Struct(sub_fields.into()); fields.push(Field::new("f3", struct_type, false)); let mut builder = StructBuilder::from_fields(fields, 5); @@ -396,15 +486,48 @@ mod tests { assert!(builder.field_builder::(2).is_some()); } + #[test] + fn test_datatype_properties() { + let fields = Fields::from(vec![ + Field::new("f1", DataType::Decimal128(1, 2), false), + Field::new( + "f2", + DataType::Timestamp(TimeUnit::Millisecond, Some("+00:00".into())), + false, + ), + ]); + let mut builder = StructBuilder::from_fields(fields.clone(), 1); + builder + .field_builder::(0) + .unwrap() + .append_value(1); + builder + .field_builder::(1) + .unwrap() + .append_value(1); + builder.append(true); + let array = builder.finish(); + + assert_eq!(array.data_type(), &DataType::Struct(fields.clone())); + assert_eq!(array.column(0).data_type(), fields[0].data_type()); + assert_eq!(array.column(1).data_type(), fields[1].data_type()); + } + #[test] #[should_panic( - expected = "Data type List(Field { name: \"item\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }) is not currently supported" + expected = "Data type Map(Field { name: \"entries\", data_type: Struct([Field { name: \"keys\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"values\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false) is not currently supported" )] fn test_struct_array_builder_from_schema_unsupported_type() { - let mut fields = vec![Field::new("f1", DataType::Int16, false)]; - let list_type = - DataType::List(Box::new(Field::new("item", DataType::Int64, true))); - fields.push(Field::new("f2", list_type, false)); + let keys = Arc::new(Field::new("keys", DataType::Int32, false)); + let values = Arc::new(Field::new("values", DataType::UInt32, false)); + let struct_type = DataType::Struct(Fields::from(vec![keys, values])); + let map_data_type = + DataType::Map(Arc::new(Field::new("entries", struct_type, false)), false); + + let fields = vec![ + Field::new("f1", DataType::Int16, false), + Field::new("f2", map_data_type, false), + ]; let _ = StructBuilder::from_fields(fields, 5); } @@ -413,10 +536,8 @@ mod tests { fn test_struct_array_builder_field_builder_type_mismatch() { let int_builder = Int32Builder::with_capacity(10); - let mut fields = Vec::new(); - let mut field_builders = Vec::new(); - fields.push(Field::new("f1", DataType::Int32, false)); - field_builders.push(Box::new(int_builder) as Box); + let fields = vec![Field::new("f1", DataType::Int32, false)]; + let field_builders = vec![Box::new(int_builder) as Box]; let mut builder = StructBuilder::new(fields, field_builders); assert!(builder.field_builder::(0).is_none()); @@ -432,12 +553,14 @@ mod tests { int_builder.append_value(2); bool_builder.append_value(true); - let mut fields = Vec::new(); - let mut field_builders = Vec::new(); - fields.push(Field::new("f1", DataType::Int32, false)); - field_builders.push(Box::new(int_builder) as Box); - fields.push(Field::new("f2", DataType::Boolean, false)); - field_builders.push(Box::new(bool_builder) as Box); + let fields = vec![ + Field::new("f1", DataType::Int32, false), + Field::new("f2", DataType::Boolean, false), + ]; + let field_builders = vec![ + Box::new(int_builder) as Box, + Box::new(bool_builder) as Box, + ]; let mut builder = StructBuilder::new(fields, field_builders); builder.append(true); @@ -446,19 +569,50 @@ mod tests { } #[test] - #[should_panic( - expected = "Number of fields is not equal to the number of field_builders." - )] + #[should_panic(expected = "Number of fields is not equal to the number of field_builders.")] fn test_struct_array_builder_unequal_field_field_builders() { let int_builder = Int32Builder::with_capacity(10); - let mut fields = Vec::new(); - let mut field_builders = Vec::new(); - fields.push(Field::new("f1", DataType::Int32, false)); - field_builders.push(Box::new(int_builder) as Box); - fields.push(Field::new("f2", DataType::Boolean, false)); + let fields = vec![ + Field::new("f1", DataType::Int32, false), + Field::new("f2", DataType::Boolean, false), + ]; + let field_builders = vec![Box::new(int_builder) as Box]; let mut builder = StructBuilder::new(fields, field_builders); builder.finish(); } + + #[test] + #[should_panic( + expected = "Incorrect datatype for StructArray field \\\"timestamp\\\", expected Timestamp(Nanosecond, Some(\\\"UTC\\\")) got Timestamp(Nanosecond, None)" + )] + fn test_struct_array_mismatch_builder() { + let fields = vec![Field::new( + "timestamp", + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".to_owned().into())), + false, + )]; + + let field_builders: Vec> = + vec![Box::new(TimestampNanosecondBuilder::new())]; + + let mut sa = StructBuilder::new(fields, field_builders); + sa.finish(); + } + + #[test] + fn test_empty() { + let mut builder = StructBuilder::new(Fields::empty(), vec![]); + builder.append(true); + builder.append(false); + + let a1 = builder.finish_cloned(); + let a2 = builder.finish(); + assert_eq!(a1, a2); + assert_eq!(a1.len(), 2); + assert_eq!(a1.null_count(), 1); + assert!(a1.is_valid(0)); + assert!(a1.is_null(1)); + } } diff --git a/arrow/src/array/builder/union_builder.rs b/arrow-array/src/builder/union_builder.rs similarity index 83% rename from arrow/src/array/builder/union_builder.rs rename to arrow-array/src/builder/union_builder.rs index c0ae76853dd2..4f88c9d41b9a 100644 --- a/arrow/src/array/builder/union_builder.rs +++ b/arrow-array/src/builder/union_builder.rs @@ -15,24 +15,16 @@ // specific language governing permissions and limitations // under the License. +use crate::builder::buffer_builder::{Int32BufferBuilder, Int8BufferBuilder}; +use crate::builder::BufferBuilder; +use crate::{make_array, ArrowPrimitiveType, UnionArray}; +use arrow_buffer::NullBufferBuilder; +use arrow_buffer::{ArrowNativeType, Buffer}; +use arrow_data::ArrayDataBuilder; +use arrow_schema::{ArrowError, DataType, Field}; use std::any::Any; use std::collections::HashMap; -use crate::array::ArrayDataBuilder; -use crate::array::Int32BufferBuilder; -use crate::array::Int8BufferBuilder; -use crate::array::UnionArray; -use crate::buffer::Buffer; - -use crate::datatypes::DataType; -use crate::datatypes::Field; -use crate::datatypes::{ArrowNativeType, ArrowPrimitiveType}; -use crate::error::{ArrowError, Result}; - -use super::{BufferBuilder, NullBufferBuilder}; - -use crate::array::make_array; - /// `FieldData` is a helper struct to track the state of the fields in the `UnionBuilder`. #[derive(Debug)] struct FieldData { @@ -73,11 +65,7 @@ impl FieldDataValues for BufferBuilder { impl FieldData { /// Creates a new `FieldData`. - fn new( - type_id: i8, - data_type: DataType, - capacity: usize, - ) -> Self { + fn new(type_id: i8, data_type: DataType, capacity: usize) -> Self { Self { type_id, data_type, @@ -107,13 +95,13 @@ impl FieldData { } } -/// Builder type for creating a new `UnionArray`. +/// Builder for [`UnionArray`] /// /// Example: **Dense Memory Layout** /// /// ``` -/// use arrow::array::UnionBuilder; -/// use arrow::datatypes::{Float64Type, Int32Type}; +/// # use arrow_array::builder::UnionBuilder; +/// # use arrow_array::types::{Float64Type, Int32Type}; /// /// let mut builder = UnionBuilder::new_dense(); /// builder.append::("a", 1).unwrap(); @@ -121,19 +109,19 @@ impl FieldData { /// builder.append::("a", 4).unwrap(); /// let union = builder.build().unwrap(); /// -/// assert_eq!(union.type_id(0), 0_i8); -/// assert_eq!(union.type_id(1), 1_i8); -/// assert_eq!(union.type_id(2), 0_i8); +/// assert_eq!(union.type_id(0), 0); +/// assert_eq!(union.type_id(1), 1); +/// assert_eq!(union.type_id(2), 0); /// -/// assert_eq!(union.value_offset(0), 0_i32); -/// assert_eq!(union.value_offset(1), 0_i32); -/// assert_eq!(union.value_offset(2), 1_i32); +/// assert_eq!(union.value_offset(0), 0); +/// assert_eq!(union.value_offset(1), 0); +/// assert_eq!(union.value_offset(2), 1); /// ``` /// /// Example: **Sparse Memory Layout** /// ``` -/// use arrow::array::UnionBuilder; -/// use arrow::datatypes::{Float64Type, Int32Type}; +/// # use arrow_array::builder::UnionBuilder; +/// # use arrow_array::types::{Float64Type, Int32Type}; /// /// let mut builder = UnionBuilder::new_sparse(); /// builder.append::("a", 1).unwrap(); @@ -141,13 +129,13 @@ impl FieldData { /// builder.append::("a", 4).unwrap(); /// let union = builder.build().unwrap(); /// -/// assert_eq!(union.type_id(0), 0_i8); -/// assert_eq!(union.type_id(1), 1_i8); -/// assert_eq!(union.type_id(2), 0_i8); +/// assert_eq!(union.type_id(0), 0); +/// assert_eq!(union.type_id(1), 1); +/// assert_eq!(union.type_id(2), 0); /// -/// assert_eq!(union.value_offset(0), 0_i32); -/// assert_eq!(union.value_offset(1), 1_i32); -/// assert_eq!(union.value_offset(2), 2_i32); +/// assert_eq!(union.value_offset(0), 0); +/// assert_eq!(union.value_offset(1), 1); +/// assert_eq!(union.value_offset(2), 2); /// ``` #[derive(Debug)] pub struct UnionBuilder { @@ -203,7 +191,10 @@ impl UnionBuilder { /// is part of the final array, appending a NULL requires /// specifying which field (child) to use. #[inline] - pub fn append_null(&mut self, type_name: &str) -> Result<()> { + pub fn append_null( + &mut self, + type_name: &str, + ) -> Result<(), ArrowError> { self.append_option::(type_name, None) } @@ -213,7 +204,7 @@ impl UnionBuilder { &mut self, type_name: &str, v: T::Native, - ) -> Result<()> { + ) -> Result<(), ArrowError> { self.append_option::(type_name, Some(v)) } @@ -221,13 +212,18 @@ impl UnionBuilder { &mut self, type_name: &str, v: Option, - ) -> Result<()> { + ) -> Result<(), ArrowError> { let type_name = type_name.to_string(); let mut field_data = match self.fields.remove(&type_name) { Some(data) => { if data.data_type != T::DATA_TYPE { - return Err(ArrowError::InvalidArgumentError(format!("Attempt to write col \"{}\" with type {} doesn't match existing type {}", type_name, T::DATA_TYPE, data.data_type))); + return Err(ArrowError::InvalidArgumentError(format!( + "Attempt to write col \"{}\" with type {} doesn't match existing type {}", + type_name, + T::DATA_TYPE, + data.data_type + ))); } data } @@ -278,7 +274,7 @@ impl UnionBuilder { } /// Builds this builder creating a new `UnionArray`. - pub fn build(mut self) -> Result { + pub fn build(mut self) -> Result { let type_id_buffer = self.type_id_builder.finish(); let value_offsets_buffer = self.value_offset_builder.map(|mut b| b.finish()); let mut children = Vec::new(); @@ -297,11 +293,11 @@ impl UnionBuilder { let arr_data_builder = ArrayDataBuilder::new(data_type.clone()) .add_buffer(buffer) .len(slots) - .null_bit_buffer(bitmap_builder.finish()); + .nulls(bitmap_builder.finish()); let arr_data_ref = unsafe { arr_data_builder.build_unchecked() }; let array_ref = make_array(arr_data_ref); - children.push((type_id, (Field::new(&name, data_type, false), array_ref))) + children.push((type_id, (Field::new(name, data_type, false), array_ref))) } children.sort_by(|a, b| { diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs new file mode 100644 index 000000000000..2e21f3e7e640 --- /dev/null +++ b/arrow-array/src/cast.rs @@ -0,0 +1,969 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines helper functions for downcasting [`dyn Array`](Array) to concrete types + +use crate::array::*; +use crate::types::*; +use arrow_data::ArrayData; + +/// Repeats the provided pattern based on the number of comma separated identifiers +#[doc(hidden)] +#[macro_export] +macro_rules! repeat_pat { + ($e:pat, $v_:expr) => { + $e + }; + ($e:pat, $v_:expr $(, $tail:expr)+) => { + ($e, $crate::repeat_pat!($e $(, $tail)+)) + } +} + +/// Given one or more expressions evaluating to an integer [`DataType`] invokes the provided macro +/// `m` with the corresponding integer [`ArrowPrimitiveType`], followed by any additional arguments +/// +/// ``` +/// # use arrow_array::{downcast_primitive, ArrowPrimitiveType, downcast_integer}; +/// # use arrow_schema::DataType; +/// +/// macro_rules! dictionary_key_size_helper { +/// ($t:ty, $o:ty) => { +/// std::mem::size_of::<<$t as ArrowPrimitiveType>::Native>() as $o +/// }; +/// } +/// +/// fn dictionary_key_size(t: &DataType) -> u8 { +/// match t { +/// DataType::Dictionary(k, _) => downcast_integer! { +/// k.as_ref() => (dictionary_key_size_helper, u8), +/// _ => unreachable!(), +/// }, +/// _ => u8::MAX, +/// } +/// } +/// +/// assert_eq!(dictionary_key_size(&DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8))), 4); +/// assert_eq!(dictionary_key_size(&DataType::Dictionary(Box::new(DataType::Int64), Box::new(DataType::Utf8))), 8); +/// assert_eq!(dictionary_key_size(&DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8))), 2); +/// ``` +/// +/// [`DataType`]: arrow_schema::DataType +#[macro_export] +macro_rules! downcast_integer { + ($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat => $fallback:expr $(,)*)*) => { + match ($($data_type),+) { + $crate::repeat_pat!(arrow_schema::DataType::Int8, $($data_type),+) => { + $m!($crate::types::Int8Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Int16, $($data_type),+) => { + $m!($crate::types::Int16Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Int32, $($data_type),+) => { + $m!($crate::types::Int32Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Int64, $($data_type),+) => { + $m!($crate::types::Int64Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::UInt8, $($data_type),+) => { + $m!($crate::types::UInt8Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::UInt16, $($data_type),+) => { + $m!($crate::types::UInt16Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::UInt32, $($data_type),+) => { + $m!($crate::types::UInt32Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::UInt64, $($data_type),+) => { + $m!($crate::types::UInt64Type $(, $args)*) + } + $($p => $fallback,)* + } + }; +} + +/// Given one or more expressions evaluating to an integer [`DataType`] invokes the provided macro +/// `m` with the corresponding integer [`RunEndIndexType`], followed by any additional arguments +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{downcast_primitive, ArrowPrimitiveType, downcast_run_end_index}; +/// # use arrow_schema::{DataType, Field}; +/// +/// macro_rules! run_end_size_helper { +/// ($t:ty, $o:ty) => { +/// std::mem::size_of::<<$t as ArrowPrimitiveType>::Native>() as $o +/// }; +/// } +/// +/// fn run_end_index_size(t: &DataType) -> u8 { +/// match t { +/// DataType::RunEndEncoded(k, _) => downcast_run_end_index! { +/// k.data_type() => (run_end_size_helper, u8), +/// _ => unreachable!(), +/// }, +/// _ => u8::MAX, +/// } +/// } +/// +/// assert_eq!(run_end_index_size(&DataType::RunEndEncoded(Arc::new(Field::new("a", DataType::Int32, false)), Arc::new(Field::new("b", DataType::Utf8, true)))), 4); +/// assert_eq!(run_end_index_size(&DataType::RunEndEncoded(Arc::new(Field::new("a", DataType::Int64, false)), Arc::new(Field::new("b", DataType::Utf8, true)))), 8); +/// assert_eq!(run_end_index_size(&DataType::RunEndEncoded(Arc::new(Field::new("a", DataType::Int16, false)), Arc::new(Field::new("b", DataType::Utf8, true)))), 2); +/// ``` +/// +/// [`DataType`]: arrow_schema::DataType +#[macro_export] +macro_rules! downcast_run_end_index { + ($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat => $fallback:expr $(,)*)*) => { + match ($($data_type),+) { + $crate::repeat_pat!(arrow_schema::DataType::Int16, $($data_type),+) => { + $m!($crate::types::Int16Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Int32, $($data_type),+) => { + $m!($crate::types::Int32Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Int64, $($data_type),+) => { + $m!($crate::types::Int64Type $(, $args)*) + } + $($p => $fallback,)* + } + }; +} + +/// Given one or more expressions evaluating to primitive [`DataType`] invokes the provided macro +/// `m` with the corresponding [`ArrowPrimitiveType`], followed by any additional arguments +/// +/// ``` +/// # use arrow_array::{downcast_temporal, ArrowPrimitiveType}; +/// # use arrow_schema::DataType; +/// +/// macro_rules! temporal_size_helper { +/// ($t:ty, $o:ty) => { +/// std::mem::size_of::<<$t as ArrowPrimitiveType>::Native>() as $o +/// }; +/// } +/// +/// fn temporal_size(t: &DataType) -> u8 { +/// downcast_temporal! { +/// t => (temporal_size_helper, u8), +/// _ => u8::MAX +/// } +/// } +/// +/// assert_eq!(temporal_size(&DataType::Date32), 4); +/// assert_eq!(temporal_size(&DataType::Date64), 8); +/// ``` +/// +/// [`DataType`]: arrow_schema::DataType +#[macro_export] +macro_rules! downcast_temporal { + ($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat => $fallback:expr $(,)*)*) => { + match ($($data_type),+) { + $crate::repeat_pat!(arrow_schema::DataType::Time32(arrow_schema::TimeUnit::Second), $($data_type),+) => { + $m!($crate::types::Time32SecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Time32(arrow_schema::TimeUnit::Millisecond), $($data_type),+) => { + $m!($crate::types::Time32MillisecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Time64(arrow_schema::TimeUnit::Microsecond), $($data_type),+) => { + $m!($crate::types::Time64MicrosecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Time64(arrow_schema::TimeUnit::Nanosecond), $($data_type),+) => { + $m!($crate::types::Time64NanosecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Date32, $($data_type),+) => { + $m!($crate::types::Date32Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Date64, $($data_type),+) => { + $m!($crate::types::Date64Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Timestamp(arrow_schema::TimeUnit::Second, _), $($data_type),+) => { + $m!($crate::types::TimestampSecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, _), $($data_type),+) => { + $m!($crate::types::TimestampMillisecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, _), $($data_type),+) => { + $m!($crate::types::TimestampMicrosecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, _), $($data_type),+) => { + $m!($crate::types::TimestampNanosecondType $(, $args)*) + } + $($p => $fallback,)* + } + }; +} + +/// Downcast an [`Array`] to a temporal [`PrimitiveArray`] based on its [`DataType`] +/// accepts a number of subsequent patterns to match the data type +/// +/// ``` +/// # use arrow_array::{Array, downcast_temporal_array, cast::as_string_array}; +/// # use arrow_schema::DataType; +/// +/// fn print_temporal(array: &dyn Array) { +/// downcast_temporal_array!( +/// array => { +/// for v in array { +/// println!("{:?}", v); +/// } +/// } +/// DataType::Utf8 => { +/// for v in as_string_array(array) { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported datatype {}", t) +/// ) +/// } +/// ``` +/// +/// [`DataType`]: arrow_schema::DataType +#[macro_export] +macro_rules! downcast_temporal_array { + ($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_temporal_array!($values => {$e} $($p => $fallback)*) + }; + (($($values:ident),+) => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_temporal_array!($($values),+ => {$e} $($p => $fallback)*) + }; + ($($values:ident),+ => $e:block $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_temporal_array!(($($values),+) => $e $($p => $fallback)*) + }; + (($($values:ident),+) => $e:block $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_temporal!{ + $($values.data_type()),+ => ($crate::downcast_primitive_array_helper, $($values),+, $e), + $($p => $fallback,)* + } + }; +} + +/// Given one or more expressions evaluating to primitive [`DataType`] invokes the provided macro +/// `m` with the corresponding [`ArrowPrimitiveType`], followed by any additional arguments +/// +/// ``` +/// # use arrow_array::{downcast_primitive, ArrowPrimitiveType}; +/// # use arrow_schema::DataType; +/// +/// macro_rules! primitive_size_helper { +/// ($t:ty, $o:ty) => { +/// std::mem::size_of::<<$t as ArrowPrimitiveType>::Native>() as $o +/// }; +/// } +/// +/// fn primitive_size(t: &DataType) -> u8 { +/// downcast_primitive! { +/// t => (primitive_size_helper, u8), +/// _ => u8::MAX +/// } +/// } +/// +/// assert_eq!(primitive_size(&DataType::Int32), 4); +/// assert_eq!(primitive_size(&DataType::Int64), 8); +/// assert_eq!(primitive_size(&DataType::Float16), 2); +/// assert_eq!(primitive_size(&DataType::Decimal128(38, 10)), 16); +/// assert_eq!(primitive_size(&DataType::Decimal256(76, 20)), 32); +/// ``` +/// +/// [`DataType`]: arrow_schema::DataType +#[macro_export] +macro_rules! downcast_primitive { + ($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_integer! { + $($data_type),+ => ($m $(, $args)*), + $crate::repeat_pat!(arrow_schema::DataType::Float16, $($data_type),+) => { + $m!($crate::types::Float16Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Float32, $($data_type),+) => { + $m!($crate::types::Float32Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Float64, $($data_type),+) => { + $m!($crate::types::Float64Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Decimal128(_, _), $($data_type),+) => { + $m!($crate::types::Decimal128Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Decimal256(_, _), $($data_type),+) => { + $m!($crate::types::Decimal256Type $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Interval(arrow_schema::IntervalUnit::YearMonth), $($data_type),+) => { + $m!($crate::types::IntervalYearMonthType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Interval(arrow_schema::IntervalUnit::DayTime), $($data_type),+) => { + $m!($crate::types::IntervalDayTimeType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Interval(arrow_schema::IntervalUnit::MonthDayNano), $($data_type),+) => { + $m!($crate::types::IntervalMonthDayNanoType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Duration(arrow_schema::TimeUnit::Second), $($data_type),+) => { + $m!($crate::types::DurationSecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Duration(arrow_schema::TimeUnit::Millisecond), $($data_type),+) => { + $m!($crate::types::DurationMillisecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Duration(arrow_schema::TimeUnit::Microsecond), $($data_type),+) => { + $m!($crate::types::DurationMicrosecondType $(, $args)*) + } + $crate::repeat_pat!(arrow_schema::DataType::Duration(arrow_schema::TimeUnit::Nanosecond), $($data_type),+) => { + $m!($crate::types::DurationNanosecondType $(, $args)*) + } + _ => { + $crate::downcast_temporal! { + $($data_type),+ => ($m $(, $args)*), + $($p => $fallback,)* + } + } + } + }; +} + +#[macro_export] +#[doc(hidden)] +macro_rules! downcast_primitive_array_helper { + ($t:ty, $($values:ident),+, $e:block) => {{ + $(let $values = $crate::cast::as_primitive_array::<$t>($values);)+ + $e + }}; +} + +/// Downcast an [`Array`] to a [`PrimitiveArray`] based on its [`DataType`] +/// accepts a number of subsequent patterns to match the data type +/// +/// ``` +/// # use arrow_array::{Array, downcast_primitive_array, cast::as_string_array}; +/// # use arrow_schema::DataType; +/// +/// fn print_primitive(array: &dyn Array) { +/// downcast_primitive_array!( +/// array => { +/// for v in array { +/// println!("{:?}", v); +/// } +/// } +/// DataType::Utf8 => { +/// for v in as_string_array(array) { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported datatype {}", t) +/// ) +/// } +/// ``` +/// +/// [`DataType`]: arrow_schema::DataType +#[macro_export] +macro_rules! downcast_primitive_array { + ($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_primitive_array!($values => {$e} $($p => $fallback)*) + }; + (($($values:ident),+) => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_primitive_array!($($values),+ => {$e} $($p => $fallback)*) + }; + ($($values:ident),+ => $e:block $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_primitive_array!(($($values),+) => $e $($p => $fallback)*) + }; + (($($values:ident),+) => $e:block $($p:pat => $fallback:expr $(,)*)*) => { + $crate::downcast_primitive!{ + $($values.data_type()),+ => ($crate::downcast_primitive_array_helper, $($values),+, $e), + $($p => $fallback,)* + } + }; +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`], to +/// [`PrimitiveArray`], panic'ing on failure. +/// +/// # Example +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, Int32Array}; +/// # use arrow_array::cast::as_primitive_array; +/// # use arrow_array::types::Int32Type; +/// +/// let arr: ArrayRef = Arc::new(Int32Array::from(vec![Some(1)])); +/// +/// // Downcast an `ArrayRef` to Int32Array / PrimitiveArray: +/// let primitive_array: &Int32Array = as_primitive_array(&arr); +/// +/// // Equivalently: +/// let primitive_array = as_primitive_array::(&arr); +/// +/// // This is the equivalent of: +/// let primitive_array = arr +/// .as_any() +/// .downcast_ref::() +/// .unwrap(); +/// ``` + +pub fn as_primitive_array(arr: &dyn Array) -> &PrimitiveArray +where + T: ArrowPrimitiveType, +{ + arr.as_any() + .downcast_ref::>() + .expect("Unable to downcast to primitive array") +} + +#[macro_export] +#[doc(hidden)] +macro_rules! downcast_dictionary_array_helper { + ($t:ty, $($values:ident),+, $e:block) => {{ + $(let $values = $crate::cast::as_dictionary_array::<$t>($values);)+ + $e + }}; +} + +/// Downcast an [`Array`] to a [`DictionaryArray`] based on its [`DataType`], accepts +/// a number of subsequent patterns to match the data type +/// +/// ``` +/// # use arrow_array::{Array, StringArray, downcast_dictionary_array, cast::as_string_array}; +/// # use arrow_schema::DataType; +/// +/// fn print_strings(array: &dyn Array) { +/// downcast_dictionary_array!( +/// array => match array.values().data_type() { +/// DataType::Utf8 => { +/// for v in array.downcast_dict::().unwrap() { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported dictionary value type {}", t), +/// }, +/// DataType::Utf8 => { +/// for v in as_string_array(array) { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported datatype {}", t) +/// ) +/// } +/// ``` +/// +/// [`DataType`]: arrow_schema::DataType +#[macro_export] +macro_rules! downcast_dictionary_array { + ($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { + downcast_dictionary_array!($values => {$e} $($p => $fallback)*) + }; + + ($values:ident => $e:block $($p:pat => $fallback:expr $(,)*)*) => { + match $values.data_type() { + arrow_schema::DataType::Dictionary(k, _) => { + $crate::downcast_integer! { + k.as_ref() => ($crate::downcast_dictionary_array_helper, $values, $e), + k => unreachable!("unsupported dictionary key type: {}", k) + } + } + $($p => $fallback,)* + } + } +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`DictionaryArray`], panic'ing on failure. +/// +/// # Example +/// +/// ``` +/// # use arrow_array::{ArrayRef, DictionaryArray}; +/// # use arrow_array::cast::as_dictionary_array; +/// # use arrow_array::types::Int32Type; +/// +/// let arr: DictionaryArray = vec![Some("foo")].into_iter().collect(); +/// let arr: ArrayRef = std::sync::Arc::new(arr); +/// let dict_array: &DictionaryArray = as_dictionary_array::(&arr); +/// ``` +pub fn as_dictionary_array(arr: &dyn Array) -> &DictionaryArray +where + T: ArrowDictionaryKeyType, +{ + arr.as_any() + .downcast_ref::>() + .expect("Unable to downcast to dictionary array") +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`RunArray`], panic'ing on failure. +/// +/// # Example +/// +/// ``` +/// # use arrow_array::{ArrayRef, RunArray}; +/// # use arrow_array::cast::as_run_array; +/// # use arrow_array::types::Int32Type; +/// +/// let arr: RunArray = vec![Some("foo")].into_iter().collect(); +/// let arr: ArrayRef = std::sync::Arc::new(arr); +/// let run_array: &RunArray = as_run_array::(&arr); +/// ``` +pub fn as_run_array(arr: &dyn Array) -> &RunArray +where + T: RunEndIndexType, +{ + arr.as_any() + .downcast_ref::>() + .expect("Unable to downcast to run array") +} + +#[macro_export] +#[doc(hidden)] +macro_rules! downcast_run_array_helper { + ($t:ty, $($values:ident),+, $e:block) => {{ + $(let $values = $crate::cast::as_run_array::<$t>($values);)+ + $e + }}; +} + +/// Downcast an [`Array`] to a [`RunArray`] based on its [`DataType`], accepts +/// a number of subsequent patterns to match the data type +/// +/// ``` +/// # use arrow_array::{Array, StringArray, downcast_run_array, cast::as_string_array}; +/// # use arrow_schema::DataType; +/// +/// fn print_strings(array: &dyn Array) { +/// downcast_run_array!( +/// array => match array.values().data_type() { +/// DataType::Utf8 => { +/// for v in array.downcast::().unwrap() { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported run array value type {}", t), +/// }, +/// DataType::Utf8 => { +/// for v in as_string_array(array) { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported datatype {}", t) +/// ) +/// } +/// ``` +/// +/// [`DataType`]: arrow_schema::DataType +#[macro_export] +macro_rules! downcast_run_array { + ($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { + downcast_run_array!($values => {$e} $($p => $fallback)*) + }; + + ($values:ident => $e:block $($p:pat => $fallback:expr $(,)*)*) => { + match $values.data_type() { + arrow_schema::DataType::RunEndEncoded(k, _) => { + $crate::downcast_run_end_index! { + k.data_type() => ($crate::downcast_run_array_helper, $values, $e), + k => unreachable!("unsupported run end index type: {}", k) + } + } + $($p => $fallback,)* + } + } +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`GenericListArray`], panicking on failure. +pub fn as_generic_list_array(arr: &dyn Array) -> &GenericListArray { + arr.as_any() + .downcast_ref::>() + .expect("Unable to downcast to list array") +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`ListArray`], panicking on failure. +#[inline] +pub fn as_list_array(arr: &dyn Array) -> &ListArray { + as_generic_list_array::(arr) +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`FixedSizeListArray`], panicking on failure. +#[inline] +pub fn as_fixed_size_list_array(arr: &dyn Array) -> &FixedSizeListArray { + arr.as_any() + .downcast_ref::() + .expect("Unable to downcast to fixed size list array") +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`LargeListArray`], panicking on failure. +#[inline] +pub fn as_large_list_array(arr: &dyn Array) -> &LargeListArray { + as_generic_list_array::(arr) +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`GenericBinaryArray`], panicking on failure. +#[inline] +pub fn as_generic_binary_array(arr: &dyn Array) -> &GenericBinaryArray { + arr.as_any() + .downcast_ref::>() + .expect("Unable to downcast to binary array") +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`StringArray`], panicking on failure. +/// +/// # Example +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::cast::as_string_array; +/// # use arrow_array::{ArrayRef, StringArray}; +/// +/// let arr: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("foo")])); +/// let string_array = as_string_array(&arr); +/// ``` +pub fn as_string_array(arr: &dyn Array) -> &StringArray { + arr.as_any() + .downcast_ref::() + .expect("Unable to downcast to StringArray") +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`BooleanArray`], panicking on failure. +/// +/// # Example +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, BooleanArray}; +/// # use arrow_array::cast::as_boolean_array; +/// +/// let arr: ArrayRef = Arc::new(BooleanArray::from_iter(vec![Some(true)])); +/// let boolean_array = as_boolean_array(&arr); +/// ``` +pub fn as_boolean_array(arr: &dyn Array) -> &BooleanArray { + arr.as_any() + .downcast_ref::() + .expect("Unable to downcast to BooleanArray") +} + +macro_rules! array_downcast_fn { + ($name: ident, $arrty: ty, $arrty_str:expr) => { + #[doc = "Force downcast of an [`Array`], such as an [`ArrayRef`] to "] + #[doc = $arrty_str] + pub fn $name(arr: &dyn Array) -> &$arrty { + arr.as_any().downcast_ref::<$arrty>().expect(concat!( + "Unable to downcast to typed array through ", + stringify!($name) + )) + } + }; + + // use recursive macro to generate dynamic doc string for a given array type + ($name: ident, $arrty: ty) => { + array_downcast_fn!( + $name, + $arrty, + concat!("[`", stringify!($arrty), "`], panicking on failure.") + ); + }; +} + +array_downcast_fn!(as_largestring_array, LargeStringArray); +array_downcast_fn!(as_null_array, NullArray); +array_downcast_fn!(as_struct_array, StructArray); +array_downcast_fn!(as_union_array, UnionArray); +array_downcast_fn!(as_map_array, MapArray); + +/// Force downcast of an Array, such as an ArrayRef to Decimal128Array, panic’ing on failure. +#[deprecated(note = "please use `as_primitive_array::` instead")] +pub fn as_decimal_array(arr: &dyn Array) -> &PrimitiveArray { + as_primitive_array::(arr) +} + +/// Downcasts a `dyn Array` to a concrete type +/// +/// ``` +/// # use arrow_array::{BooleanArray, Int32Array, RecordBatch, StringArray}; +/// # use arrow_array::cast::downcast_array; +/// struct ConcreteBatch { +/// col1: Int32Array, +/// col2: BooleanArray, +/// col3: StringArray, +/// } +/// +/// impl ConcreteBatch { +/// fn new(batch: &RecordBatch) -> Self { +/// Self { +/// col1: downcast_array(batch.column(0).as_ref()), +/// col2: downcast_array(batch.column(1).as_ref()), +/// col3: downcast_array(batch.column(2).as_ref()), +/// } +/// } +/// } +/// ``` +/// +/// # Panics +/// +/// Panics if array is not of the correct data type +pub fn downcast_array(array: &dyn Array) -> T +where + T: From, +{ + T::from(array.to_data()) +} + +mod private { + pub trait Sealed {} +} + +/// An extension trait for `dyn Array` that provides ergonomic downcasting +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, Int32Array}; +/// # use arrow_array::cast::AsArray; +/// # use arrow_array::types::Int32Type; +/// let col = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; +/// assert_eq!(col.as_primitive::().values(), &[1, 2, 3]); +/// ``` +pub trait AsArray: private::Sealed { + /// Downcast this to a [`BooleanArray`] returning `None` if not possible + fn as_boolean_opt(&self) -> Option<&BooleanArray>; + + /// Downcast this to a [`BooleanArray`] panicking if not possible + fn as_boolean(&self) -> &BooleanArray { + self.as_boolean_opt().expect("boolean array") + } + + /// Downcast this to a [`PrimitiveArray`] returning `None` if not possible + fn as_primitive_opt(&self) -> Option<&PrimitiveArray>; + + /// Downcast this to a [`PrimitiveArray`] panicking if not possible + fn as_primitive(&self) -> &PrimitiveArray { + self.as_primitive_opt().expect("primitive array") + } + + /// Downcast this to a [`GenericByteArray`] returning `None` if not possible + fn as_bytes_opt(&self) -> Option<&GenericByteArray>; + + /// Downcast this to a [`GenericByteArray`] panicking if not possible + fn as_bytes(&self) -> &GenericByteArray { + self.as_bytes_opt().expect("byte array") + } + + /// Downcast this to a [`GenericStringArray`] returning `None` if not possible + fn as_string_opt(&self) -> Option<&GenericStringArray> { + self.as_bytes_opt() + } + + /// Downcast this to a [`GenericStringArray`] panicking if not possible + fn as_string(&self) -> &GenericStringArray { + self.as_bytes_opt().expect("string array") + } + + /// Downcast this to a [`GenericBinaryArray`] returning `None` if not possible + fn as_binary_opt(&self) -> Option<&GenericBinaryArray> { + self.as_bytes_opt() + } + + /// Downcast this to a [`GenericBinaryArray`] panicking if not possible + fn as_binary(&self) -> &GenericBinaryArray { + self.as_bytes_opt().expect("binary array") + } + + /// Downcast this to a [`StructArray`] returning `None` if not possible + fn as_struct_opt(&self) -> Option<&StructArray>; + + /// Downcast this to a [`StructArray`] panicking if not possible + fn as_struct(&self) -> &StructArray { + self.as_struct_opt().expect("struct array") + } + + /// Downcast this to a [`GenericListArray`] returning `None` if not possible + fn as_list_opt(&self) -> Option<&GenericListArray>; + + /// Downcast this to a [`GenericListArray`] panicking if not possible + fn as_list(&self) -> &GenericListArray { + self.as_list_opt().expect("list array") + } + + /// Downcast this to a [`FixedSizeBinaryArray`] returning `None` if not possible + fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray>; + + /// Downcast this to a [`FixedSizeBinaryArray`] panicking if not possible + fn as_fixed_size_binary(&self) -> &FixedSizeBinaryArray { + self.as_fixed_size_binary_opt() + .expect("fixed size binary array") + } + + /// Downcast this to a [`FixedSizeListArray`] returning `None` if not possible + fn as_fixed_size_list_opt(&self) -> Option<&FixedSizeListArray>; + + /// Downcast this to a [`FixedSizeListArray`] panicking if not possible + fn as_fixed_size_list(&self) -> &FixedSizeListArray { + self.as_fixed_size_list_opt() + .expect("fixed size list array") + } + + /// Downcast this to a [`MapArray`] returning `None` if not possible + fn as_map_opt(&self) -> Option<&MapArray>; + + /// Downcast this to a [`MapArray`] panicking if not possible + fn as_map(&self) -> &MapArray { + self.as_map_opt().expect("map array") + } + + /// Downcast this to a [`DictionaryArray`] returning `None` if not possible + fn as_dictionary_opt(&self) -> Option<&DictionaryArray>; + + /// Downcast this to a [`DictionaryArray`] panicking if not possible + fn as_dictionary(&self) -> &DictionaryArray { + self.as_dictionary_opt().expect("dictionary array") + } + + /// Downcasts this to a [`AnyDictionaryArray`] returning `None` if not possible + fn as_any_dictionary_opt(&self) -> Option<&dyn AnyDictionaryArray>; + + /// Downcasts this to a [`AnyDictionaryArray`] panicking if not possible + fn as_any_dictionary(&self) -> &dyn AnyDictionaryArray { + self.as_any_dictionary_opt().expect("any dictionary array") + } +} + +impl private::Sealed for dyn Array + '_ {} +impl AsArray for dyn Array + '_ { + fn as_boolean_opt(&self) -> Option<&BooleanArray> { + self.as_any().downcast_ref() + } + + fn as_primitive_opt(&self) -> Option<&PrimitiveArray> { + self.as_any().downcast_ref() + } + + fn as_bytes_opt(&self) -> Option<&GenericByteArray> { + self.as_any().downcast_ref() + } + + fn as_struct_opt(&self) -> Option<&StructArray> { + self.as_any().downcast_ref() + } + + fn as_list_opt(&self) -> Option<&GenericListArray> { + self.as_any().downcast_ref() + } + + fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray> { + self.as_any().downcast_ref() + } + + fn as_fixed_size_list_opt(&self) -> Option<&FixedSizeListArray> { + self.as_any().downcast_ref() + } + + fn as_map_opt(&self) -> Option<&MapArray> { + self.as_any().downcast_ref() + } + + fn as_dictionary_opt(&self) -> Option<&DictionaryArray> { + self.as_any().downcast_ref() + } + + fn as_any_dictionary_opt(&self) -> Option<&dyn AnyDictionaryArray> { + let array = self; + downcast_dictionary_array! { + array => Some(array), + _ => None + } + } +} + +impl private::Sealed for ArrayRef {} +impl AsArray for ArrayRef { + fn as_boolean_opt(&self) -> Option<&BooleanArray> { + self.as_ref().as_boolean_opt() + } + + fn as_primitive_opt(&self) -> Option<&PrimitiveArray> { + self.as_ref().as_primitive_opt() + } + + fn as_bytes_opt(&self) -> Option<&GenericByteArray> { + self.as_ref().as_bytes_opt() + } + + fn as_struct_opt(&self) -> Option<&StructArray> { + self.as_ref().as_struct_opt() + } + + fn as_list_opt(&self) -> Option<&GenericListArray> { + self.as_ref().as_list_opt() + } + + fn as_fixed_size_binary_opt(&self) -> Option<&FixedSizeBinaryArray> { + self.as_ref().as_fixed_size_binary_opt() + } + + fn as_fixed_size_list_opt(&self) -> Option<&FixedSizeListArray> { + self.as_ref().as_fixed_size_list_opt() + } + + fn as_map_opt(&self) -> Option<&MapArray> { + self.as_any().downcast_ref() + } + + fn as_dictionary_opt(&self) -> Option<&DictionaryArray> { + self.as_ref().as_dictionary_opt() + } + + fn as_any_dictionary_opt(&self) -> Option<&dyn AnyDictionaryArray> { + self.as_ref().as_any_dictionary_opt() + } +} + +#[cfg(test)] +mod tests { + use arrow_buffer::i256; + use std::sync::Arc; + + use super::*; + + #[test] + fn test_as_primitive_array_ref() { + let array: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); + assert!(!as_primitive_array::(&array).is_empty()); + + // should also work when wrapped in an Arc + let array: ArrayRef = Arc::new(array); + assert!(!as_primitive_array::(&array).is_empty()); + } + + #[test] + fn test_as_string_array_ref() { + let array: StringArray = vec!["foo", "bar"].into_iter().map(Some).collect(); + assert!(!as_string_array(&array).is_empty()); + + // should also work when wrapped in an Arc + let array: ArrayRef = Arc::new(array); + assert!(!as_string_array(&array).is_empty()) + } + + #[test] + fn test_decimal128array() { + let a = Decimal128Array::from_iter_values([1, 2, 4, 5]); + assert!(!as_primitive_array::(&a).is_empty()); + } + + #[test] + fn test_decimal256array() { + let a = Decimal256Array::from_iter_values([1, 2, 4, 5].into_iter().map(i256::from_i128)); + assert!(!as_primitive_array::(&a).is_empty()); + } +} diff --git a/arrow-array/src/delta.rs b/arrow-array/src/delta.rs new file mode 100644 index 000000000000..d9aa4aa6de5d --- /dev/null +++ b/arrow-array/src/delta.rs @@ -0,0 +1,285 @@ +// MIT License +// +// Copyright (c) 2020-2022 Oliver Margetts +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Copied from chronoutil crate + +//! Contains utility functions for shifting Date objects. +use chrono::{DateTime, Datelike, Days, Months, TimeZone}; +use std::cmp::Ordering; + +/// Shift a date by the given number of months. +pub(crate) fn shift_months(date: D, months: i32) -> D +where + D: Datelike + std::ops::Add + std::ops::Sub, +{ + match months.cmp(&0) { + Ordering::Equal => date, + Ordering::Greater => date + Months::new(months as u32), + Ordering::Less => date - Months::new(months.unsigned_abs()), + } +} + +/// Add the given number of months to the given datetime. +/// +/// Returns `None` when it will result in overflow. +pub(crate) fn add_months_datetime( + dt: DateTime, + months: i32, +) -> Option> { + match months.cmp(&0) { + Ordering::Equal => Some(dt), + Ordering::Greater => dt.checked_add_months(Months::new(months as u32)), + Ordering::Less => dt.checked_sub_months(Months::new(months.unsigned_abs())), + } +} + +/// Add the given number of days to the given datetime. +/// +/// Returns `None` when it will result in overflow. +pub(crate) fn add_days_datetime(dt: DateTime, days: i32) -> Option> { + match days.cmp(&0) { + Ordering::Equal => Some(dt), + Ordering::Greater => dt.checked_add_days(Days::new(days as u64)), + Ordering::Less => dt.checked_sub_days(Days::new(days.unsigned_abs() as u64)), + } +} + +/// Substract the given number of months to the given datetime. +/// +/// Returns `None` when it will result in overflow. +pub(crate) fn sub_months_datetime( + dt: DateTime, + months: i32, +) -> Option> { + match months.cmp(&0) { + Ordering::Equal => Some(dt), + Ordering::Greater => dt.checked_sub_months(Months::new(months as u32)), + Ordering::Less => dt.checked_add_months(Months::new(months.unsigned_abs())), + } +} + +/// Substract the given number of days to the given datetime. +/// +/// Returns `None` when it will result in overflow. +pub(crate) fn sub_days_datetime(dt: DateTime, days: i32) -> Option> { + match days.cmp(&0) { + Ordering::Equal => Some(dt), + Ordering::Greater => dt.checked_sub_days(Days::new(days as u64)), + Ordering::Less => dt.checked_add_days(Days::new(days.unsigned_abs() as u64)), + } +} + +#[cfg(test)] +mod tests { + + use chrono::naive::{NaiveDate, NaiveDateTime, NaiveTime}; + + use super::*; + + #[test] + fn test_shift_months() { + let base = NaiveDate::from_ymd_opt(2020, 1, 31).unwrap(); + + assert_eq!( + shift_months(base, 0), + NaiveDate::from_ymd_opt(2020, 1, 31).unwrap() + ); + assert_eq!( + shift_months(base, 1), + NaiveDate::from_ymd_opt(2020, 2, 29).unwrap() + ); + assert_eq!( + shift_months(base, 2), + NaiveDate::from_ymd_opt(2020, 3, 31).unwrap() + ); + assert_eq!( + shift_months(base, 3), + NaiveDate::from_ymd_opt(2020, 4, 30).unwrap() + ); + assert_eq!( + shift_months(base, 4), + NaiveDate::from_ymd_opt(2020, 5, 31).unwrap() + ); + assert_eq!( + shift_months(base, 5), + NaiveDate::from_ymd_opt(2020, 6, 30).unwrap() + ); + assert_eq!( + shift_months(base, 6), + NaiveDate::from_ymd_opt(2020, 7, 31).unwrap() + ); + assert_eq!( + shift_months(base, 7), + NaiveDate::from_ymd_opt(2020, 8, 31).unwrap() + ); + assert_eq!( + shift_months(base, 8), + NaiveDate::from_ymd_opt(2020, 9, 30).unwrap() + ); + assert_eq!( + shift_months(base, 9), + NaiveDate::from_ymd_opt(2020, 10, 31).unwrap() + ); + assert_eq!( + shift_months(base, 10), + NaiveDate::from_ymd_opt(2020, 11, 30).unwrap() + ); + assert_eq!( + shift_months(base, 11), + NaiveDate::from_ymd_opt(2020, 12, 31).unwrap() + ); + assert_eq!( + shift_months(base, 12), + NaiveDate::from_ymd_opt(2021, 1, 31).unwrap() + ); + assert_eq!( + shift_months(base, 13), + NaiveDate::from_ymd_opt(2021, 2, 28).unwrap() + ); + + assert_eq!( + shift_months(base, -1), + NaiveDate::from_ymd_opt(2019, 12, 31).unwrap() + ); + assert_eq!( + shift_months(base, -2), + NaiveDate::from_ymd_opt(2019, 11, 30).unwrap() + ); + assert_eq!( + shift_months(base, -3), + NaiveDate::from_ymd_opt(2019, 10, 31).unwrap() + ); + assert_eq!( + shift_months(base, -4), + NaiveDate::from_ymd_opt(2019, 9, 30).unwrap() + ); + assert_eq!( + shift_months(base, -5), + NaiveDate::from_ymd_opt(2019, 8, 31).unwrap() + ); + assert_eq!( + shift_months(base, -6), + NaiveDate::from_ymd_opt(2019, 7, 31).unwrap() + ); + assert_eq!( + shift_months(base, -7), + NaiveDate::from_ymd_opt(2019, 6, 30).unwrap() + ); + assert_eq!( + shift_months(base, -8), + NaiveDate::from_ymd_opt(2019, 5, 31).unwrap() + ); + assert_eq!( + shift_months(base, -9), + NaiveDate::from_ymd_opt(2019, 4, 30).unwrap() + ); + assert_eq!( + shift_months(base, -10), + NaiveDate::from_ymd_opt(2019, 3, 31).unwrap() + ); + assert_eq!( + shift_months(base, -11), + NaiveDate::from_ymd_opt(2019, 2, 28).unwrap() + ); + assert_eq!( + shift_months(base, -12), + NaiveDate::from_ymd_opt(2019, 1, 31).unwrap() + ); + assert_eq!( + shift_months(base, -13), + NaiveDate::from_ymd_opt(2018, 12, 31).unwrap() + ); + + assert_eq!( + shift_months(base, 1265), + NaiveDate::from_ymd_opt(2125, 6, 30).unwrap() + ); + } + + #[test] + fn test_shift_months_with_overflow() { + let base = NaiveDate::from_ymd_opt(2020, 12, 31).unwrap(); + + assert_eq!(shift_months(base, 0), base); + assert_eq!( + shift_months(base, 1), + NaiveDate::from_ymd_opt(2021, 1, 31).unwrap() + ); + assert_eq!( + shift_months(base, 2), + NaiveDate::from_ymd_opt(2021, 2, 28).unwrap() + ); + assert_eq!( + shift_months(base, 12), + NaiveDate::from_ymd_opt(2021, 12, 31).unwrap() + ); + assert_eq!( + shift_months(base, 18), + NaiveDate::from_ymd_opt(2022, 6, 30).unwrap() + ); + + assert_eq!( + shift_months(base, -1), + NaiveDate::from_ymd_opt(2020, 11, 30).unwrap() + ); + assert_eq!( + shift_months(base, -2), + NaiveDate::from_ymd_opt(2020, 10, 31).unwrap() + ); + assert_eq!( + shift_months(base, -10), + NaiveDate::from_ymd_opt(2020, 2, 29).unwrap() + ); + assert_eq!( + shift_months(base, -12), + NaiveDate::from_ymd_opt(2019, 12, 31).unwrap() + ); + assert_eq!( + shift_months(base, -18), + NaiveDate::from_ymd_opt(2019, 6, 30).unwrap() + ); + } + + #[test] + fn test_shift_months_datetime() { + let date = NaiveDate::from_ymd_opt(2020, 1, 31).unwrap(); + let o_clock = NaiveTime::from_hms_opt(1, 2, 3).unwrap(); + + let base = NaiveDateTime::new(date, o_clock); + + assert_eq!( + shift_months(base, 0).date(), + NaiveDate::from_ymd_opt(2020, 1, 31).unwrap() + ); + assert_eq!( + shift_months(base, 1).date(), + NaiveDate::from_ymd_opt(2020, 2, 29).unwrap() + ); + assert_eq!( + shift_months(base, 2).date(), + NaiveDate::from_ymd_opt(2020, 3, 31).unwrap() + ); + assert_eq!(shift_months(base, 0).time(), o_clock); + assert_eq!(shift_months(base, 1).time(), o_clock); + assert_eq!(shift_months(base, 2).time(), o_clock); + } +} diff --git a/arrow/src/array/iterator.rs b/arrow-array/src/iterator.rs similarity index 73% rename from arrow/src/array/iterator.rs rename to arrow-array/src/iterator.rs index 4269e99625b7..3f9cc0d525c1 100644 --- a/arrow/src/array/iterator.rs +++ b/arrow-array/src/iterator.rs @@ -15,20 +15,39 @@ // specific language governing permissions and limitations // under the License. -use crate::array::array::ArrayAccessor; -use crate::array::{DecimalArray, FixedSizeBinaryArray}; -use crate::datatypes::{Decimal128Type, Decimal256Type}; +//! Idiomatic iterators for [`Array`](crate::Array) -use super::{ - BooleanArray, GenericBinaryArray, GenericListArray, GenericStringArray, - PrimitiveArray, +use crate::array::{ + ArrayAccessor, BooleanArray, FixedSizeBinaryArray, GenericBinaryArray, GenericListArray, + GenericStringArray, PrimitiveArray, }; - -/// an iterator that returns Some(T) or None, that can be used on any [`ArrayAccessor`] -// Note: This implementation is based on std's [Vec]s' [IntoIter]. +use crate::{FixedSizeListArray, MapArray}; +use arrow_buffer::NullBuffer; + +/// An iterator that returns Some(T) or None, that can be used on any [`ArrayAccessor`] +/// +/// # Performance +/// +/// [`ArrayIter`] provides an idiomatic way to iterate over an array, however, this +/// comes at the cost of performance. In particular the interleaved handling of +/// the null mask is often sub-optimal. +/// +/// If performing an infallible operation, it is typically faster to perform the operation +/// on every index of the array, and handle the null mask separately. For [`PrimitiveArray`] +/// this functionality is provided by [`compute::unary`] +/// +/// If performing a fallible operation, it isn't possible to perform the operation independently +/// of the null mask, as this might result in a spurious failure on a null index. However, +/// there are more efficient ways to iterate over just the non-null indices, this functionality +/// is provided by [`compute::try_unary`] +/// +/// [`PrimitiveArray`]: crate::PrimitiveArray +/// [`compute::unary`]: https://docs.rs/arrow/latest/arrow/compute/fn.unary.html +/// [`compute::try_unary`]: https://docs.rs/arrow/latest/arrow/compute/fn.try_unary.html #[derive(Debug)] pub struct ArrayIter { array: T, + logical_nulls: Option, current: usize, current_end: usize, } @@ -37,12 +56,22 @@ impl ArrayIter { /// create a new iterator pub fn new(array: T) -> Self { let len = array.len(); + let logical_nulls = array.logical_nulls(); ArrayIter { array, + logical_nulls, current: 0, current_end: len, } } + + #[inline] + fn is_null(&self, idx: usize) -> bool { + self.logical_nulls + .as_ref() + .map(|x| x.is_null(idx)) + .unwrap_or_default() + } } impl Iterator for ArrayIter { @@ -52,7 +81,7 @@ impl Iterator for ArrayIter { fn next(&mut self) -> Option { if self.current == self.current_end { None - } else if self.array.is_null(self.current) { + } else if self.is_null(self.current) { self.current += 1; Some(None) } else { @@ -81,7 +110,7 @@ impl DoubleEndedIterator for ArrayIter { None } else { self.current_end -= 1; - Some(if self.array.is_null(self.current_end) { + Some(if self.is_null(self.current_end) { None } else { // Safety: @@ -100,20 +129,20 @@ impl ExactSizeIterator for ArrayIter {} /// an iterator that returns Some(T) or None, that can be used on any PrimitiveArray pub type PrimitiveIter<'a, T> = ArrayIter<&'a PrimitiveArray>; +/// an iterator that returns Some(T) or None, that can be used on any BooleanArray pub type BooleanIter<'a> = ArrayIter<&'a BooleanArray>; +/// an iterator that returns Some(T) or None, that can be used on any Utf8Array pub type GenericStringIter<'a, T> = ArrayIter<&'a GenericStringArray>; +/// an iterator that returns Some(T) or None, that can be used on any BinaryArray pub type GenericBinaryIter<'a, T> = ArrayIter<&'a GenericBinaryArray>; +/// an iterator that returns Some(T) or None, that can be used on any FixedSizeBinaryArray pub type FixedSizeBinaryIter<'a> = ArrayIter<&'a FixedSizeBinaryArray>; +/// an iterator that returns Some(T) or None, that can be used on any FixedSizeListArray +pub type FixedSizeListIter<'a> = ArrayIter<&'a FixedSizeListArray>; +/// an iterator that returns Some(T) or None, that can be used on any ListArray pub type GenericListArrayIter<'a, O> = ArrayIter<&'a GenericListArray>; - -pub type DecimalIter<'a, T> = ArrayIter<&'a DecimalArray>; -/// an iterator that returns `Some(Decimal128)` or `None`, that can be used on a -/// [`super::Decimal128Array`] -pub type Decimal128Iter<'a> = DecimalIter<'a, Decimal128Type>; - -/// an iterator that returns `Some(Decimal256)` or `None`, that can be used on a -/// [`super::Decimal256Array`] -pub type Decimal256Iter<'a> = DecimalIter<'a, Decimal256Type>; +/// an iterator that returns Some(T) or None, that can be used on any MapArray +pub type MapArrayIter<'a> = ArrayIter<&'a MapArray>; #[cfg(test)] mod tests { @@ -158,8 +187,7 @@ mod tests { #[test] fn test_string_array_iter_round_trip() { - let array = - StringArray::from(vec![Some("a"), None, Some("aaa"), None, Some("aaaaa")]); + let array = StringArray::from(vec![Some("a"), None, Some("aaa"), None, Some("aaaaa")]); let array = Arc::new(array) as ArrayRef; let array = array.as_any().downcast_ref::().unwrap(); @@ -182,8 +210,7 @@ mod tests { // check if DoubleEndedIterator is implemented let result: StringArray = array.iter().rev().collect(); - let rev_array = - StringArray::from(vec![Some("aaaaa"), None, Some("aaa"), None, Some("a")]); + let rev_array = StringArray::from(vec![Some("aaaaa"), None, Some("aaa"), None, Some("a")]); assert_eq!(result, rev_array); // check if ExactSizeIterator is implemented let _ = array.iter().rposition(|opt_b| opt_b == Some("a")); diff --git a/arrow-array/src/lib.rs b/arrow-array/src/lib.rs new file mode 100644 index 000000000000..ef98c5efefb0 --- /dev/null +++ b/arrow-array/src/lib.rs @@ -0,0 +1,239 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! The central type in Apache Arrow are arrays, which are a known-length sequence of values +//! all having the same type. This crate provides concrete implementations of each type, as +//! well as an [`Array`] trait that can be used for type-erasure. +//! +//! # Building an Array +//! +//! Most [`Array`] implementations can be constructed directly from iterators or [`Vec`] +//! +//! ``` +//! # use arrow_array::{Int32Array, ListArray, StringArray}; +//! # use arrow_array::types::Int32Type; +//! # +//! Int32Array::from(vec![1, 2]); +//! Int32Array::from(vec![Some(1), None]); +//! Int32Array::from_iter([1, 2, 3, 4]); +//! Int32Array::from_iter([Some(1), Some(2), None, Some(4)]); +//! +//! StringArray::from(vec!["foo", "bar"]); +//! StringArray::from(vec![Some("foo"), None]); +//! StringArray::from_iter([Some("foo"), None]); +//! StringArray::from_iter_values(["foo", "bar"]); +//! +//! ListArray::from_iter_primitive::([ +//! Some(vec![Some(1), None, Some(3)]), +//! None, +//! Some(vec![]) +//! ]); +//! ``` +//! +//! Additionally [`ArrayBuilder`](builder::ArrayBuilder) implementations can be +//! used to construct arrays with a push-based interface +//! +//! ``` +//! # use arrow_array::Int16Array; +//! # +//! // Create a new builder with a capacity of 100 +//! let mut builder = Int16Array::builder(100); +//! +//! // Append a single primitive value +//! builder.append_value(1); +//! // Append a null value +//! builder.append_null(); +//! // Append a slice of primitive values +//! builder.append_slice(&[2, 3, 4]); +//! +//! // Build the array +//! let array = builder.finish(); +//! +//! assert_eq!(5, array.len()); +//! assert_eq!(2, array.value(2)); +//! assert_eq!(&array.values()[3..5], &[3, 4]) +//! ``` +//! +//! # Low-level API +//! +//! Internally, arrays consist of one or more shared memory regions backed by a [`Buffer`], +//! the number and meaning of which depend on the array’s data type, as documented in +//! the [Arrow specification]. +//! +//! For example, the type [`Int16Array`] represents an array of 16-bit integers and consists of: +//! +//! * An optional [`NullBuffer`] identifying any null values +//! * A contiguous [`ScalarBuffer`] of values +//! +//! Similarly, the type [`StringArray`] represents an array of UTF-8 strings and consists of: +//! +//! * An optional [`NullBuffer`] identifying any null values +//! * An offsets [`OffsetBuffer`] identifying valid UTF-8 sequences within the values buffer +//! * A values [`Buffer`] of UTF-8 encoded string data +//! +//! Array constructors such as [`PrimitiveArray::try_new`] provide the ability to cheaply +//! construct an array from these parts, with functions such as [`PrimitiveArray::into_parts`] +//! providing the reverse operation. +//! +//! ``` +//! # use arrow_array::{Array, Int32Array, StringArray}; +//! # use arrow_buffer::OffsetBuffer; +//! # +//! // Create a Int32Array from Vec without copying +//! let array = Int32Array::new(vec![1, 2, 3].into(), None); +//! assert_eq!(array.values(), &[1, 2, 3]); +//! assert_eq!(array.null_count(), 0); +//! +//! // Create a StringArray from parts +//! let offsets = OffsetBuffer::new(vec![0, 5, 10].into()); +//! let array = StringArray::new(offsets, b"helloworld".into(), None); +//! let values: Vec<_> = array.iter().map(|x| x.unwrap()).collect(); +//! assert_eq!(values, &["hello", "world"]); +//! ``` +//! +//! As [`Buffer`], and its derivatives, can be created from [`Vec`] without copying, this provides +//! an efficient way to not only interoperate with other Rust code, but also implement kernels +//! optimised for the arrow data layout - e.g. by handling buffers instead of values. +//! +//! # Zero-Copy Slicing +//! +//! Given an [`Array`] of arbitrary length, it is possible to create an owned slice of this +//! data. Internally this just increments some ref-counts, and so is incredibly cheap +//! +//! ```rust +//! # use arrow_array::Int32Array; +//! let array = Int32Array::from_iter([1, 2, 3]); +//! +//! // Slice with offset 1 and length 2 +//! let sliced = array.slice(1, 2); +//! assert_eq!(sliced.values(), &[2, 3]); +//! ``` +//! +//! # Downcasting an Array +//! +//! Arrays are often passed around as a dynamically typed [`&dyn Array`] or [`ArrayRef`]. +//! For example, [`RecordBatch`](`crate::RecordBatch`) stores columns as [`ArrayRef`]. +//! +//! Whilst these arrays can be passed directly to the [`compute`], [`csv`], [`json`], etc... APIs, +//! it is often the case that you wish to interact with the concrete arrays directly. +//! +//! This requires downcasting to the concrete type of the array: +//! +//! ``` +//! # use arrow_array::{Array, Float32Array, Int32Array}; +//! +//! // Safely downcast an `Array` to an `Int32Array` and compute the sum +//! // using native i32 values +//! fn sum_int32(array: &dyn Array) -> i32 { +//! let integers: &Int32Array = array.as_any().downcast_ref().unwrap(); +//! integers.iter().map(|val| val.unwrap_or_default()).sum() +//! } +//! +//! // Safely downcasts the array to a `Float32Array` and returns a &[f32] view of the data +//! // Note: the values for positions corresponding to nulls will be arbitrary (but still valid f32) +//! fn as_f32_slice(array: &dyn Array) -> &[f32] { +//! array.as_any().downcast_ref::().unwrap().values() +//! } +//! ``` +//! +//! The [`cast::AsArray`] extension trait can make this more ergonomic +//! +//! ``` +//! # use arrow_array::Array; +//! # use arrow_array::cast::{AsArray, as_primitive_array}; +//! # use arrow_array::types::Float32Type; +//! +//! fn as_f32_slice(array: &dyn Array) -> &[f32] { +//! array.as_primitive::().values() +//! } +//! ``` +//! +//! [`ScalarBuffer`]: arrow_buffer::ScalarBuffer +//! [`ScalarBuffer`]: arrow_buffer::ScalarBuffer +//! [`OffsetBuffer`]: arrow_buffer::OffsetBuffer +//! [`NullBuffer`]: arrow_buffer::NullBuffer +//! [Arrow specification]: https://arrow.apache.org/docs/format/Columnar.html +//! [`&dyn Array`]: Array +//! [`NullBuffer`]: arrow_buffer::NullBuffer +//! [`Buffer`]: arrow_buffer::Buffer +//! [`compute`]: https://docs.rs/arrow/latest/arrow/compute/index.html +//! [`json`]: https://docs.rs/arrow/latest/arrow/json/index.html +//! [`csv`]: https://docs.rs/arrow/latest/arrow/csv/index.html + +#![deny(rustdoc::broken_intra_doc_links)] +#![warn(missing_docs)] + +pub mod array; +pub use array::*; + +mod record_batch; +pub use record_batch::{ + RecordBatch, RecordBatchIterator, RecordBatchOptions, RecordBatchReader, RecordBatchWriter, +}; + +mod arithmetic; +pub use arithmetic::ArrowNativeTypeOp; + +mod numeric; +pub use numeric::*; + +mod scalar; +pub use scalar::*; + +pub mod builder; +pub mod cast; +mod delta; +pub mod iterator; +pub mod run_iterator; +pub mod temporal_conversions; +pub mod timezone; +mod trusted_len; +pub mod types; + +#[cfg(test)] +mod tests { + use crate::builder::*; + + #[test] + fn test_buffer_builder_availability() { + let _builder = Int8BufferBuilder::new(10); + let _builder = Int16BufferBuilder::new(10); + let _builder = Int32BufferBuilder::new(10); + let _builder = Int64BufferBuilder::new(10); + let _builder = UInt16BufferBuilder::new(10); + let _builder = UInt32BufferBuilder::new(10); + let _builder = Float32BufferBuilder::new(10); + let _builder = Float64BufferBuilder::new(10); + let _builder = TimestampSecondBufferBuilder::new(10); + let _builder = TimestampMillisecondBufferBuilder::new(10); + let _builder = TimestampMicrosecondBufferBuilder::new(10); + let _builder = TimestampNanosecondBufferBuilder::new(10); + let _builder = Date32BufferBuilder::new(10); + let _builder = Date64BufferBuilder::new(10); + let _builder = Time32SecondBufferBuilder::new(10); + let _builder = Time32MillisecondBufferBuilder::new(10); + let _builder = Time64MicrosecondBufferBuilder::new(10); + let _builder = Time64NanosecondBufferBuilder::new(10); + let _builder = IntervalYearMonthBufferBuilder::new(10); + let _builder = IntervalDayTimeBufferBuilder::new(10); + let _builder = IntervalMonthDayNanoBufferBuilder::new(10); + let _builder = DurationSecondBufferBuilder::new(10); + let _builder = DurationMillisecondBufferBuilder::new(10); + let _builder = DurationMicrosecondBufferBuilder::new(10); + let _builder = DurationNanosecondBufferBuilder::new(10); + } +} diff --git a/arrow/src/csv/mod.rs b/arrow-array/src/numeric.rs similarity index 73% rename from arrow/src/csv/mod.rs rename to arrow-array/src/numeric.rs index ffe82f335801..a3cd7bde5d36 100644 --- a/arrow/src/csv/mod.rs +++ b/arrow-array/src/numeric.rs @@ -15,13 +15,9 @@ // specific language governing permissions and limitations // under the License. -//! Transfer data between the Arrow memory format and CSV (comma-separated values). +use crate::ArrowPrimitiveType; -pub mod reader; -pub mod writer; +/// A subtype of primitive type that represents numeric values. +pub trait ArrowNumericType: ArrowPrimitiveType {} -pub use self::reader::infer_schema_from_files; -pub use self::reader::Reader; -pub use self::reader::ReaderBuilder; -pub use self::writer::Writer; -pub use self::writer::WriterBuilder; +impl ArrowNumericType for T {} diff --git a/arrow/src/record_batch.rs b/arrow-array/src/record_batch.rs similarity index 63% rename from arrow/src/record_batch.rs rename to arrow-array/src/record_batch.rs index 47257b496c1b..4e859fdfe7ea 100644 --- a/arrow/src/record_batch.rs +++ b/arrow-array/src/record_batch.rs @@ -16,17 +16,50 @@ // under the License. //! A two-dimensional batch of column-oriented data with a defined -//! [schema](crate::datatypes::Schema). +//! [schema](arrow_schema::Schema). +use crate::{new_empty_array, Array, ArrayRef, StructArray}; +use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaBuilder, SchemaRef}; +use std::ops::Index; use std::sync::Arc; -use crate::array::*; -use crate::compute::kernels::concat::concat; -use crate::datatypes::*; -use crate::error::{ArrowError, Result}; +/// Trait for types that can read `RecordBatch`'s. +/// +/// To create from an iterator, see [RecordBatchIterator]. +pub trait RecordBatchReader: Iterator> { + /// Returns the schema of this `RecordBatchReader`. + /// + /// Implementation of this trait should guarantee that all `RecordBatch`'s returned by this + /// reader should have the same schema as returned from this method. + fn schema(&self) -> SchemaRef; + + /// Reads the next `RecordBatch`. + #[deprecated( + since = "2.0.0", + note = "This method is deprecated in favour of `next` from the trait Iterator." + )] + fn next_batch(&mut self) -> Result, ArrowError> { + self.next().transpose() + } +} + +impl RecordBatchReader for Box { + fn schema(&self) -> SchemaRef { + self.as_ref().schema() + } +} + +/// Trait for types that can write `RecordBatch`'s. +pub trait RecordBatchWriter { + /// Write a single batch to the writer. + fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>; + + /// Write footer or termination data, then mark the writer as done. + fn close(self) -> Result<(), ArrowError>; +} /// A two-dimensional batch of column-oriented data with a defined -/// [schema](crate::datatypes::Schema). +/// [schema](arrow_schema::Schema). /// /// A `RecordBatch` is a two-dimensional dataset of a number of /// contiguous arrays, each the same length. @@ -35,8 +68,6 @@ use crate::error::{ArrowError, Result}; /// /// Record batches are a convenient unit of work for various /// serialization and computation functions, possibly incremental. -/// See also [CSV reader](crate::csv::Reader) and -/// [JSON reader](crate::json::Reader). #[derive(Clone, Debug, PartialEq)] pub struct RecordBatch { schema: SchemaRef, @@ -62,12 +93,10 @@ impl RecordBatch { /// # Example /// /// ``` - /// use std::sync::Arc; - /// use arrow::array::Int32Array; - /// use arrow::datatypes::{Schema, Field, DataType}; - /// use arrow::record_batch::RecordBatch; + /// # use std::sync::Arc; + /// # use arrow_array::{Int32Array, RecordBatch}; + /// # use arrow_schema::{DataType, Field, Schema}; /// - /// # fn main() -> arrow::error::Result<()> { /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]); /// let schema = Schema::new(vec![ /// Field::new("id", DataType::Int32, false) @@ -76,12 +105,10 @@ impl RecordBatch { /// let batch = RecordBatch::try_new( /// Arc::new(schema), /// vec![Arc::new(id_array)] - /// )?; - /// # Ok(()) - /// # } + /// ).unwrap(); /// ``` - pub fn try_new(schema: SchemaRef, columns: Vec) -> Result { - let options = RecordBatchOptions::default(); + pub fn try_new(schema: SchemaRef, columns: Vec) -> Result { + let options = RecordBatchOptions::new(); Self::try_new_impl(schema, columns, &options) } @@ -93,7 +120,7 @@ impl RecordBatch { schema: SchemaRef, columns: Vec, options: &RecordBatchOptions, - ) -> Result { + ) -> Result { Self::try_new_impl(schema, columns, options) } @@ -118,7 +145,7 @@ impl RecordBatch { schema: SchemaRef, columns: Vec, options: &RecordBatchOptions, - ) -> Result { + ) -> Result { // check that number of fields in schema match column length if schema.fields().len() != columns.len() { return Err(ArrowError::InvalidArgumentError(format!( @@ -128,7 +155,6 @@ impl RecordBatch { ))); } - // check that all columns have the same row count let row_count = options .row_count .or_else(|| columns.first().map(|col| col.len())) @@ -147,11 +173,10 @@ impl RecordBatch { } } + // check that all columns have the same row count if columns.iter().any(|c| c.len() != row_count) { let err = match options.row_count { - Some(_) => { - "all columns in a record batch must have the specified row count" - } + Some(_) => "all columns in a record batch must have the specified row count", None => "all columns in a record batch must have the same length", }; return Err(ArrowError::InvalidArgumentError(err.to_string())); @@ -160,9 +185,7 @@ impl RecordBatch { // function for comparing column type and field type // return true if 2 types are not matched let type_not_match = if options.match_field_names { - |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| { - col_type != field_type - } + |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type } else { |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| { !col_type.equals_datatype(field_type) @@ -179,10 +202,7 @@ impl RecordBatch { if let Some((i, (col_type, field_type))) = not_match { return Err(ArrowError::InvalidArgumentError(format!( - "column types must match schema types, expected {:?} but found {:?} at column index {}", - field_type, - col_type, - i))); + "column types must match schema types, expected {field_type:?} but found {col_type:?} at column index {i}"))); } Ok(RecordBatch { @@ -192,13 +212,32 @@ impl RecordBatch { }) } - /// Returns the [`Schema`](crate::datatypes::Schema) of the record batch. + /// Override the schema of this [`RecordBatch`] + /// + /// Returns an error if `schema` is not a superset of the current schema + /// as determined by [`Schema::contains`] + pub fn with_schema(self, schema: SchemaRef) -> Result { + if !schema.contains(self.schema.as_ref()) { + return Err(ArrowError::SchemaError(format!( + "{schema} is not a superset of {}", + self.schema + ))); + } + + Ok(Self { + schema, + columns: self.columns, + row_count: self.row_count, + }) + } + + /// Returns the [`Schema`] of the record batch. pub fn schema(&self) -> SchemaRef { self.schema.clone() } /// Projects the schema onto the specified columns - pub fn project(&self, indices: &[usize]) -> Result { + pub fn project(&self, indices: &[usize]) -> Result { let projected_schema = self.schema.project(indices)?; let batch_fields = indices .iter() @@ -211,9 +250,16 @@ impl RecordBatch { )) }) }) - .collect::>>()?; - - RecordBatch::try_new(SchemaRef::new(projected_schema), batch_fields) + .collect::, _>>()?; + + RecordBatch::try_new_with_options( + SchemaRef::new(projected_schema), + batch_fields, + &RecordBatchOptions { + match_field_names: true, + row_count: Some(self.row_count), + }, + ) } /// Returns the number of columns in the record batch. @@ -221,22 +267,18 @@ impl RecordBatch { /// # Example /// /// ``` - /// use std::sync::Arc; - /// use arrow::array::Int32Array; - /// use arrow::datatypes::{Schema, Field, DataType}; - /// use arrow::record_batch::RecordBatch; + /// # use std::sync::Arc; + /// # use arrow_array::{Int32Array, RecordBatch}; + /// # use arrow_schema::{DataType, Field, Schema}; /// - /// # fn main() -> arrow::error::Result<()> { /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]); /// let schema = Schema::new(vec![ /// Field::new("id", DataType::Int32, false) /// ]); /// - /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)])?; + /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap(); /// /// assert_eq!(batch.num_columns(), 1); - /// # Ok(()) - /// # } /// ``` pub fn num_columns(&self) -> usize { self.columns.len() @@ -247,22 +289,18 @@ impl RecordBatch { /// # Example /// /// ``` - /// use std::sync::Arc; - /// use arrow::array::Int32Array; - /// use arrow::datatypes::{Schema, Field, DataType}; - /// use arrow::record_batch::RecordBatch; + /// # use std::sync::Arc; + /// # use arrow_array::{Int32Array, RecordBatch}; + /// # use arrow_schema::{DataType, Field, Schema}; /// - /// # fn main() -> arrow::error::Result<()> { /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]); /// let schema = Schema::new(vec![ /// Field::new("id", DataType::Int32, false) /// ]); /// - /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)])?; + /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap(); /// /// assert_eq!(batch.num_rows(), 5); - /// # Ok(()) - /// # } /// ``` pub fn num_rows(&self) -> usize { self.row_count @@ -277,11 +315,52 @@ impl RecordBatch { &self.columns[index] } + /// Get a reference to a column's array by name. + pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> { + self.schema() + .column_with_name(name) + .map(|(index, _)| &self.columns[index]) + } + /// Get a reference to all columns in the record batch. pub fn columns(&self) -> &[ArrayRef] { &self.columns[..] } + /// Remove column by index and return it. + /// + /// Return the `ArrayRef` if the column is removed. + /// + /// # Panics + /// + /// Panics if `index`` out of bounds. + /// + /// # Example + /// + /// ``` + /// use std::sync::Arc; + /// use arrow_array::{BooleanArray, Int32Array, RecordBatch}; + /// use arrow_schema::{DataType, Field, Schema}; + /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]); + /// let bool_array = BooleanArray::from(vec![true, false, false, true, true]); + /// let schema = Schema::new(vec![ + /// Field::new("id", DataType::Int32, false), + /// Field::new("bool", DataType::Boolean, false), + /// ]); + /// + /// let mut batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array), Arc::new(bool_array)]).unwrap(); + /// + /// let removed_column = batch.remove_column(0); + /// assert_eq!(removed_column.as_any().downcast_ref::().unwrap(), &Int32Array::from(vec![1, 2, 3, 4, 5])); + /// assert_eq!(batch.num_columns(), 1); + /// ``` + pub fn remove_column(&mut self, index: usize) -> ArrayRef { + let mut builder = SchemaBuilder::from(self.schema.fields()); + builder.remove(index); + self.schema = Arc::new(builder.finish()); + self.columns.remove(index) + } + /// Return a new RecordBatch where each column is sliced /// according to `offset` and `length` /// @@ -316,10 +395,8 @@ impl RecordBatch { /// /// Example: /// ``` - /// use std::sync::Arc; - /// use arrow::array::{ArrayRef, Int32Array, StringArray}; - /// use arrow::datatypes::{Schema, Field, DataType}; - /// use arrow::record_batch::RecordBatch; + /// # use std::sync::Arc; + /// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray}; /// /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); /// let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"])); @@ -329,7 +406,7 @@ impl RecordBatch { /// ("b", b), /// ]); /// ``` - pub fn try_from_iter(value: I) -> Result + pub fn try_from_iter(value: I) -> Result where I: IntoIterator, F: AsRef, @@ -353,10 +430,8 @@ impl RecordBatch { /// /// Example: /// ``` - /// use std::sync::Arc; - /// use arrow::array::{ArrayRef, Int32Array, StringArray}; - /// use arrow::datatypes::{Schema, Field, DataType}; - /// use arrow::record_batch::RecordBatch; + /// # use std::sync::Arc; + /// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray}; /// /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); /// let b: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), Some("b")])); @@ -368,54 +443,32 @@ impl RecordBatch { /// ("b", b, true), /// ]); /// ``` - pub fn try_from_iter_with_nullable(value: I) -> Result + pub fn try_from_iter_with_nullable(value: I) -> Result where I: IntoIterator, F: AsRef, { - // TODO: implement `TryFrom` trait, once - // https://github.com/rust-lang/rust/issues/50133 is no longer an - // issue - let (fields, columns) = value - .into_iter() - .map(|(field_name, array, nullable)| { - let field_name = field_name.as_ref(); - let field = Field::new(field_name, array.data_type().clone(), nullable); - (field, array) - }) - .unzip(); + let iter = value.into_iter(); + let capacity = iter.size_hint().0; + let mut schema = SchemaBuilder::with_capacity(capacity); + let mut columns = Vec::with_capacity(capacity); + + for (field_name, array, nullable) in iter { + let field_name = field_name.as_ref(); + schema.push(Field::new(field_name, array.data_type().clone(), nullable)); + columns.push(array); + } - let schema = Arc::new(Schema::new(fields)); + let schema = Arc::new(schema.finish()); RecordBatch::try_new(schema, columns) } - /// Concatenates `batches` together into a single record batch. - pub fn concat(schema: &SchemaRef, batches: &[Self]) -> Result { - if batches.is_empty() { - return Ok(RecordBatch::new_empty(schema.clone())); - } - if let Some((i, _)) = batches + /// Returns the total number of bytes of memory occupied physically by this batch. + pub fn get_array_memory_size(&self) -> usize { + self.columns() .iter() - .enumerate() - .find(|&(_, batch)| batch.schema() != *schema) - { - return Err(ArrowError::InvalidArgumentError(format!( - "batches[{}] schema is different with argument schema.", - i - ))); - } - let field_num = schema.fields().len(); - let mut arrays = Vec::with_capacity(field_num); - for i in 0..field_num { - let array = concat( - &batches - .iter() - .map(|batch| batch.column(i).as_ref()) - .collect::>(), - )?; - arrays.push(array); - } - Self::try_new(schema.clone(), arrays) + .map(|array| array.get_array_memory_size()) + .sum() } } @@ -430,71 +483,146 @@ pub struct RecordBatchOptions { pub row_count: Option, } -impl Default for RecordBatchOptions { - fn default() -> Self { +impl RecordBatchOptions { + /// Creates a new `RecordBatchOptions` + pub fn new() -> Self { Self { match_field_names: true, row_count: None, } } + /// Sets the row_count of RecordBatchOptions and returns self + pub fn with_row_count(mut self, row_count: Option) -> Self { + self.row_count = row_count; + self + } + /// Sets the match_field_names of RecordBatchOptions and returns self + pub fn with_match_field_names(mut self, match_field_names: bool) -> Self { + self.match_field_names = match_field_names; + self + } +} +impl Default for RecordBatchOptions { + fn default() -> Self { + Self::new() + } +} +impl From for RecordBatch { + fn from(value: StructArray) -> Self { + let row_count = value.len(); + let (fields, columns, nulls) = value.into_parts(); + assert_eq!( + nulls.map(|n| n.null_count()).unwrap_or_default(), + 0, + "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation" + ); + + RecordBatch { + schema: Arc::new(Schema::new(fields)), + row_count, + columns, + } + } } impl From<&StructArray> for RecordBatch { - /// Create a record batch from struct array, where each field of - /// the `StructArray` becomes a `Field` in the schema. - /// - /// This currently does not flatten and nested struct types fn from(struct_array: &StructArray) -> Self { - if let DataType::Struct(fields) = struct_array.data_type() { - let schema = Schema::new(fields.clone()); - let columns = struct_array.boxed_fields.clone(); - RecordBatch { - schema: Arc::new(schema), - row_count: struct_array.len(), - columns, - } - } else { - unreachable!("unable to get datatype as struct") - } + struct_array.clone().into() } } -impl From for StructArray { - fn from(batch: RecordBatch) -> Self { - batch - .schema - .fields - .iter() - .zip(batch.columns.iter()) - .map(|t| (t.0.clone(), t.1.clone())) - .collect::>() - .into() +impl Index<&str> for RecordBatch { + type Output = ArrayRef; + + /// Get a reference to a column's array by name. + /// + /// # Panics + /// + /// Panics if the name is not in the schema. + fn index(&self, name: &str) -> &Self::Output { + self.column_by_name(name).unwrap() } } -/// Trait for types that can read `RecordBatch`'s. -pub trait RecordBatchReader: Iterator> { - /// Returns the schema of this `RecordBatchReader`. +/// Generic implementation of [RecordBatchReader] that wraps an iterator. +/// +/// # Example +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, RecordBatchIterator, RecordBatchReader}; +/// # +/// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); +/// let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"])); +/// +/// let record_batch = RecordBatch::try_from_iter(vec![ +/// ("a", a), +/// ("b", b), +/// ]).unwrap(); +/// +/// let batches: Vec = vec![record_batch.clone(), record_batch.clone()]; +/// +/// let mut reader = RecordBatchIterator::new(batches.into_iter().map(Ok), record_batch.schema()); +/// +/// assert_eq!(reader.schema(), record_batch.schema()); +/// assert_eq!(reader.next().unwrap().unwrap(), record_batch); +/// # assert_eq!(reader.next().unwrap().unwrap(), record_batch); +/// # assert!(reader.next().is_none()); +/// ``` +pub struct RecordBatchIterator +where + I: IntoIterator>, +{ + inner: I::IntoIter, + inner_schema: SchemaRef, +} + +impl RecordBatchIterator +where + I: IntoIterator>, +{ + /// Create a new [RecordBatchIterator]. /// - /// Implementation of this trait should guarantee that all `RecordBatch`'s returned by this - /// reader should have the same schema as returned from this method. - fn schema(&self) -> SchemaRef; + /// If `iter` is an infallible iterator, use `.map(Ok)`. + pub fn new(iter: I, schema: SchemaRef) -> Self { + Self { + inner: iter.into_iter(), + inner_schema: schema, + } + } +} - /// Reads the next `RecordBatch`. - #[deprecated( - since = "2.0.0", - note = "This method is deprecated in favour of `next` from the trait Iterator." - )] - fn next_batch(&mut self) -> Result> { - self.next().transpose() +impl Iterator for RecordBatchIterator +where + I: IntoIterator>, +{ + type Item = I::Item; + + fn next(&mut self) -> Option { + self.inner.next() + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +impl RecordBatchReader for RecordBatchIterator +where + I: IntoIterator>, +{ + fn schema(&self) -> SchemaRef { + self.inner_schema.clone() } } #[cfg(test)] mod tests { use super::*; - - use crate::buffer::Buffer; + use crate::{BooleanArray, Int32Array, Int64Array, Int8Array, ListArray, StringArray}; + use arrow_buffer::{Buffer, ToByteSlice}; + use arrow_data::{ArrayData, ArrayDataBuilder}; + use arrow_schema::Fields; #[test] fn create_record_batch() { @@ -507,18 +635,32 @@ mod tests { let b = StringArray::from(vec!["a", "b", "c", "d", "e"]); let record_batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]) - .unwrap(); + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap(); check_batch(record_batch, 5) } + #[test] + fn byte_size_should_not_regress() { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + ]); + + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + let b = StringArray::from(vec!["a", "b", "c", "d", "e"]); + + let record_batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap(); + assert_eq!(record_batch.get_array_memory_size(), 364); + } + fn check_batch(record_batch: RecordBatch, num_rows: usize) { assert_eq!(num_rows, record_batch.num_rows()); assert_eq!(2, record_batch.num_columns()); assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type()); assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type()); - assert_eq!(num_rows, record_batch.column(0).data().len()); - assert_eq!(num_rows, record_batch.column(1).data().len()); + assert_eq!(num_rows, record_batch.column(0).len()); + assert_eq!(num_rows, record_batch.column(1).len()); } #[test] @@ -534,8 +676,7 @@ mod tests { let b = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "h", "i"]); let record_batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]) - .unwrap(); + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap(); let offset = 2; let length = 5; @@ -559,7 +700,7 @@ mod tests { #[test] #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")] fn create_record_batch_slice_empty_batch() { - let schema = Schema::new(vec![]); + let schema = Schema::empty(); let record_batch = RecordBatch::new_empty(Arc::new(schema)); @@ -584,8 +725,8 @@ mod tests { ])); let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])); - let record_batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]) - .expect("valid conversion"); + let record_batch = + RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).expect("valid conversion"); let expected_schema = Schema::new(vec![ Field::new("a", DataType::Int32, true), @@ -601,11 +742,9 @@ mod tests { let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])); // Note there are no nulls in a or b, but we specify that b is nullable - let record_batch = RecordBatch::try_from_iter_with_nullable(vec![ - ("a", a, false), - ("b", b, true), - ]) - .expect("valid conversion"); + let record_batch = + RecordBatch::try_from_iter_with_nullable(vec![("a", a, false), ("b", b, true)]) + .expect("valid conversion"); let expected_schema = Schema::new(vec![ Field::new("a", DataType::Int32, false), @@ -627,34 +766,29 @@ mod tests { #[test] fn create_record_batch_field_name_mismatch() { - let struct_fields = vec![ + let fields = vec![ Field::new("a1", DataType::Int32, false), - Field::new( - "a2", - DataType::List(Box::new(Field::new("item", DataType::Int8, false))), - false, - ), + Field::new_list("a2", Field::new("item", DataType::Int8, false), false), ]; - let struct_type = DataType::Struct(struct_fields); - let schema = Arc::new(Schema::new(vec![Field::new("a", struct_type, true)])); + let schema = Arc::new(Schema::new(vec![Field::new_struct("a", fields, true)])); let a1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); let a2_child = Int8Array::from(vec![1, 2, 3, 4]); - let a2 = ArrayDataBuilder::new(DataType::List(Box::new(Field::new( + let a2 = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new( "array", DataType::Int8, false, )))) .add_child_data(a2_child.into_data()) .len(2) - .add_buffer(Buffer::from(vec![0i32, 3, 4].to_byte_slice())) + .add_buffer(Buffer::from([0i32, 3, 4].to_byte_slice())) .build() .unwrap(); let a2: ArrayRef = Arc::new(ListArray::from(a2)); - let a = ArrayDataBuilder::new(DataType::Struct(vec![ + let a = ArrayDataBuilder::new(DataType::Struct(Fields::from(vec![ Field::new("aa1", DataType::Int32, false), Field::new("a2", a2.data_type().clone(), false), - ])) + ]))) .add_child_data(a1.into_data()) .add_child_data(a2.into_data()) .len(2) @@ -682,8 +816,7 @@ mod tests { let a = Int32Array::from(vec![1, 2, 3, 4, 5]); let b = Int32Array::from(vec![1, 2, 3, 4, 5]); - let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]); assert!(batch.is_err()); } @@ -693,11 +826,11 @@ mod tests { let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31])); let struct_array = StructArray::from(vec![ ( - Field::new("b", DataType::Boolean, false), + Arc::new(Field::new("b", DataType::Boolean, false)), boolean.clone() as ArrayRef, ), ( - Field::new("c", DataType::Int32, false), + Arc::new(Field::new("c", DataType::Int32, false)), int.clone() as ArrayRef, ), ]); @@ -707,84 +840,12 @@ mod tests { assert_eq!(4, batch.num_rows()); assert_eq!( struct_array.data_type(), - &DataType::Struct(batch.schema().fields().to_vec()) + &DataType::Struct(batch.schema().fields().clone()) ); assert_eq!(batch.column(0).as_ref(), boolean.as_ref()); assert_eq!(batch.column(1).as_ref(), int.as_ref()); } - #[test] - fn concat_record_batches() { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Utf8, false), - ])); - let batch1 = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![1, 2])), - Arc::new(StringArray::from(vec!["a", "b"])), - ], - ) - .unwrap(); - let batch2 = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![3, 4])), - Arc::new(StringArray::from(vec!["c", "d"])), - ], - ) - .unwrap(); - let new_batch = RecordBatch::concat(&schema, &[batch1, batch2]).unwrap(); - assert_eq!(new_batch.schema().as_ref(), schema.as_ref()); - assert_eq!(2, new_batch.num_columns()); - assert_eq!(4, new_batch.num_rows()); - } - - #[test] - fn concat_empty_record_batch() { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Utf8, false), - ])); - let batch = RecordBatch::concat(&schema, &[]).unwrap(); - assert_eq!(batch.schema().as_ref(), schema.as_ref()); - assert_eq!(0, batch.num_rows()); - } - - #[test] - fn concat_record_batches_of_different_schemas() { - let schema1 = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Utf8, false), - ])); - let schema2 = Arc::new(Schema::new(vec![ - Field::new("c", DataType::Int32, false), - Field::new("d", DataType::Utf8, false), - ])); - let batch1 = RecordBatch::try_new( - schema1.clone(), - vec![ - Arc::new(Int32Array::from(vec![1, 2])), - Arc::new(StringArray::from(vec!["a", "b"])), - ], - ) - .unwrap(); - let batch2 = RecordBatch::try_new( - schema2, - vec![ - Arc::new(Int32Array::from(vec![3, 4])), - Arc::new(StringArray::from(vec!["c", "d"])), - ], - ) - .unwrap(); - let error = RecordBatch::concat(&schema1, &[batch1, batch2]).unwrap_err(); - assert_eq!( - error.to_string(), - "Invalid argument error: batches[1] schema is different with argument schema.", - ); - } - #[test] fn record_batch_equality() { let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]); @@ -816,6 +877,22 @@ mod tests { assert_eq!(batch1, batch2); } + /// validates if the record batch can be accessed using `column_name` as index i.e. `record_batch["column_name"]` + #[test] + fn record_batch_index_access() { + let id_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4])); + let val_arr = Arc::new(Int32Array::from(vec![5, 6, 7, 8])); + let schema1 = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("val", DataType::Int32, false), + ]); + let record_batch = + RecordBatch::try_new(Arc::new(schema1), vec![id_arr.clone(), val_arr.clone()]).unwrap(); + + assert_eq!(record_batch["id"].as_ref(), id_arr.as_ref()); + assert_eq!(record_batch["val"].as_ref(), val_arr.as_ref()); + } + #[test] fn record_batch_vals_ne() { let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]); @@ -948,35 +1025,48 @@ mod tests { let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"])); let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"])); - let record_batch = RecordBatch::try_from_iter(vec![ - ("a", a.clone()), - ("b", b.clone()), - ("c", c.clone()), - ]) - .expect("valid conversion"); + let record_batch = + RecordBatch::try_from_iter(vec![("a", a.clone()), ("b", b.clone()), ("c", c.clone())]) + .expect("valid conversion"); - let expected = RecordBatch::try_from_iter(vec![("a", a), ("c", c)]) - .expect("valid conversion"); + let expected = + RecordBatch::try_from_iter(vec![("a", a), ("c", c)]).expect("valid conversion"); assert_eq!(expected, record_batch.project(&[0, 2]).unwrap()); } + #[test] + fn project_empty() { + let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"])); + + let record_batch = + RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion"); + + let expected = RecordBatch::try_new_with_options( + Arc::new(Schema::empty()), + vec![], + &RecordBatchOptions { + match_field_names: true, + row_count: Some(3), + }, + ) + .expect("valid conversion"); + + assert_eq!(expected, record_batch.project(&[]).unwrap()); + } + #[test] fn test_no_column_record_batch() { - let schema = Arc::new(Schema::new(vec![])); + let schema = Arc::new(Schema::empty()); let err = RecordBatch::try_new(schema.clone(), vec![]).unwrap_err(); assert!(err .to_string() .contains("must either specify a row count or at least one column")); - let options = RecordBatchOptions { - row_count: Some(10), - ..Default::default() - }; + let options = RecordBatchOptions::new().with_row_count(Some(10)); - let ok = - RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap(); + let ok = RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap(); assert_eq!(ok.num_rows(), 10); let a = ok.slice(2, 5); @@ -998,4 +1088,71 @@ mod tests { ); assert_eq!("Invalid argument error: Column 'a' is declared as non-nullable but contains null values", format!("{}", maybe_batch.err().unwrap())); } + #[test] + fn test_record_batch_options() { + let options = RecordBatchOptions::new() + .with_match_field_names(false) + .with_row_count(Some(20)); + assert!(!options.match_field_names); + assert_eq!(options.row_count.unwrap(), 20) + } + + #[test] + #[should_panic(expected = "Cannot convert nullable StructArray to RecordBatch")] + fn test_from_struct() { + let s = StructArray::from(ArrayData::new_null( + // Note child is not nullable + &DataType::Struct(vec![Field::new("foo", DataType::Int32, false)].into()), + 2, + )); + let _ = RecordBatch::from(s); + } + + #[test] + fn test_with_schema() { + let required_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let required_schema = Arc::new(required_schema); + let nullable_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let nullable_schema = Arc::new(nullable_schema); + + let batch = RecordBatch::try_new( + required_schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as _], + ) + .unwrap(); + + // Can add nullability + let batch = batch.with_schema(nullable_schema.clone()).unwrap(); + + // Cannot remove nullability + batch.clone().with_schema(required_schema).unwrap_err(); + + // Can add metadata + let metadata = vec![("foo".to_string(), "bar".to_string())] + .into_iter() + .collect(); + let metadata_schema = nullable_schema.as_ref().clone().with_metadata(metadata); + let batch = batch.with_schema(Arc::new(metadata_schema)).unwrap(); + + // Cannot remove metadata + batch.with_schema(nullable_schema).unwrap_err(); + } + + #[test] + fn test_boxed_reader() { + // Make sure we can pass a boxed reader to a function generic over + // RecordBatchReader. + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let schema = Arc::new(schema); + + let reader = RecordBatchIterator::new(std::iter::empty(), schema); + let reader: Box = Box::new(reader); + + fn get_size(reader: impl RecordBatchReader) -> usize { + reader.size_hint().0 + } + + let size = get_size(reader); + assert_eq!(size, 0); + } } diff --git a/arrow-array/src/run_iterator.rs b/arrow-array/src/run_iterator.rs new file mode 100644 index 000000000000..7a98fccb73b5 --- /dev/null +++ b/arrow-array/src/run_iterator.rs @@ -0,0 +1,384 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Idiomatic iterator for [`RunArray`](crate::Array) + +use crate::{array::ArrayAccessor, types::RunEndIndexType, Array, TypedRunArray}; +use arrow_buffer::ArrowNativeType; + +/// The [`RunArrayIter`] provides an idiomatic way to iterate over the run array. +/// It returns Some(T) if there is a value or None if the value is null. +/// +/// The iterator comes with a cost as it has to iterate over three arrays to determine +/// the value to be returned. The run_ends array is used to determine the index of the value. +/// The nulls array is used to determine if the value is null and the values array is used to +/// get the value. +/// +/// Unlike other iterators in this crate, [`RunArrayIter`] does not use [`ArrayAccessor`] +/// because the run array accessor does binary search to access each value which is too slow. +/// The run array iterator can determine the next value in constant time. +/// +#[derive(Debug)] +pub struct RunArrayIter<'a, R, V> +where + R: RunEndIndexType, + V: Sync + Send, + &'a V: ArrayAccessor, + <&'a V as ArrayAccessor>::Item: Default, +{ + array: TypedRunArray<'a, R, V>, + current_front_logical: usize, + current_front_physical: usize, + current_back_logical: usize, + current_back_physical: usize, +} + +impl<'a, R, V> RunArrayIter<'a, R, V> +where + R: RunEndIndexType, + V: Sync + Send, + &'a V: ArrayAccessor, + <&'a V as ArrayAccessor>::Item: Default, +{ + /// create a new iterator + pub fn new(array: TypedRunArray<'a, R, V>) -> Self { + let current_front_physical = array.run_array().get_start_physical_index(); + let current_back_physical = array.run_array().get_end_physical_index() + 1; + RunArrayIter { + array, + current_front_logical: array.offset(), + current_front_physical, + current_back_logical: array.offset() + array.len(), + current_back_physical, + } + } +} + +impl<'a, R, V> Iterator for RunArrayIter<'a, R, V> +where + R: RunEndIndexType, + V: Sync + Send, + &'a V: ArrayAccessor, + <&'a V as ArrayAccessor>::Item: Default, +{ + type Item = Option<<&'a V as ArrayAccessor>::Item>; + + #[inline] + fn next(&mut self) -> Option { + if self.current_front_logical == self.current_back_logical { + return None; + } + + // If current logical index is greater than current run end index then increment + // the physical index. + let run_ends = self.array.run_ends().values(); + if self.current_front_logical >= run_ends[self.current_front_physical].as_usize() { + // As the run_ends is expected to be strictly increasing, there + // should be at least one logical entry in one physical entry. Because of this + // reason the next value can be accessed by incrementing physical index once. + self.current_front_physical += 1; + } + if self.array.values().is_null(self.current_front_physical) { + self.current_front_logical += 1; + Some(None) + } else { + self.current_front_logical += 1; + // Safety: + // The self.current_physical is kept within bounds of self.current_logical. + // The self.current_logical will not go out of bounds because of the check + // `self.current_logical = self.current_end_logical` above. + unsafe { + Some(Some( + self.array + .values() + .value_unchecked(self.current_front_physical), + )) + } + } + } + + fn size_hint(&self) -> (usize, Option) { + ( + self.current_back_logical - self.current_front_logical, + Some(self.current_back_logical - self.current_front_logical), + ) + } +} + +impl<'a, R, V> DoubleEndedIterator for RunArrayIter<'a, R, V> +where + R: RunEndIndexType, + V: Sync + Send, + &'a V: ArrayAccessor, + <&'a V as ArrayAccessor>::Item: Default, +{ + fn next_back(&mut self) -> Option { + if self.current_back_logical == self.current_front_logical { + return None; + } + + self.current_back_logical -= 1; + + let run_ends = self.array.run_ends().values(); + if self.current_back_physical > 0 + && self.current_back_logical < run_ends[self.current_back_physical - 1].as_usize() + { + // As the run_ends is expected to be strictly increasing, there + // should be at least one logical entry in one physical entry. Because of this + // reason the next value can be accessed by decrementing physical index once. + self.current_back_physical -= 1; + } + Some(if self.array.values().is_null(self.current_back_physical) { + None + } else { + // Safety: + // The check `self.current_end_physical > 0` ensures the value will not underflow. + // Also self.current_end_physical starts with array.len() and + // decrements based on the bounds of self.current_end_logical. + unsafe { + Some( + self.array + .values() + .value_unchecked(self.current_back_physical), + ) + } + }) + } +} + +/// all arrays have known size. +impl<'a, R, V> ExactSizeIterator for RunArrayIter<'a, R, V> +where + R: RunEndIndexType, + V: Sync + Send, + &'a V: ArrayAccessor, + <&'a V as ArrayAccessor>::Item: Default, +{ +} + +#[cfg(test)] +mod tests { + use rand::{seq::SliceRandom, thread_rng, Rng}; + + use crate::{ + array::{Int32Array, StringArray}, + builder::PrimitiveRunBuilder, + types::{Int16Type, Int32Type}, + Array, Int64RunArray, PrimitiveArray, RunArray, + }; + + fn build_input_array(size: usize) -> Vec> { + // The input array is created by shuffling and repeating + // the seed values random number of times. + let mut seed: Vec> = vec![ + None, + None, + None, + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + ]; + let mut result: Vec> = Vec::with_capacity(size); + let mut ix = 0; + let mut rng = thread_rng(); + // run length can go up to 8. Cap the max run length for smaller arrays to size / 2. + let max_run_length = 8_usize.min(1_usize.max(size / 2)); + while result.len() < size { + // shuffle the seed array if all the values are iterated. + if ix == 0 { + seed.shuffle(&mut rng); + } + // repeat the items between 1 and 8 times. Cap the length for smaller sized arrays + let num = max_run_length.min(rand::thread_rng().gen_range(1..=max_run_length)); + for _ in 0..num { + result.push(seed[ix]); + } + ix += 1; + if ix == seed.len() { + ix = 0 + } + } + result.resize(size, None); + result + } + + #[test] + fn test_primitive_array_iter_round_trip() { + let mut input_vec = vec![ + Some(32), + Some(32), + None, + Some(64), + Some(64), + Some(64), + Some(72), + ]; + let mut builder = PrimitiveRunBuilder::::new(); + builder.extend(input_vec.iter().copied()); + let ree_array = builder.finish(); + let ree_array = ree_array.downcast::().unwrap(); + + let output_vec: Vec> = ree_array.into_iter().collect(); + assert_eq!(input_vec, output_vec); + + let rev_output_vec: Vec> = ree_array.into_iter().rev().collect(); + input_vec.reverse(); + assert_eq!(input_vec, rev_output_vec); + } + + #[test] + fn test_double_ended() { + let input_vec = vec![ + Some(32), + Some(32), + None, + Some(64), + Some(64), + Some(64), + Some(72), + ]; + let mut builder = PrimitiveRunBuilder::::new(); + builder.extend(input_vec); + let ree_array = builder.finish(); + let ree_array = ree_array.downcast::().unwrap(); + + let mut iter = ree_array.into_iter(); + assert_eq!(Some(Some(32)), iter.next()); + assert_eq!(Some(Some(72)), iter.next_back()); + assert_eq!(Some(Some(32)), iter.next()); + assert_eq!(Some(Some(64)), iter.next_back()); + assert_eq!(Some(None), iter.next()); + assert_eq!(Some(Some(64)), iter.next_back()); + assert_eq!(Some(Some(64)), iter.next()); + assert_eq!(None, iter.next_back()); + assert_eq!(None, iter.next()); + } + + #[test] + fn test_run_iterator_comprehensive() { + // Test forward and backward iterator for different array lengths. + let logical_lengths = vec![1_usize, 2, 3, 4, 15, 16, 17, 63, 64, 65]; + + for logical_len in logical_lengths { + let input_array = build_input_array(logical_len); + + let mut run_array_builder = PrimitiveRunBuilder::::new(); + run_array_builder.extend(input_array.iter().copied()); + let run_array = run_array_builder.finish(); + let typed_array = run_array.downcast::().unwrap(); + + // test forward iterator + let mut input_iter = input_array.iter().copied(); + let mut run_array_iter = typed_array.into_iter(); + for _ in 0..logical_len { + assert_eq!(input_iter.next(), run_array_iter.next()); + } + assert_eq!(None, run_array_iter.next()); + + // test reverse iterator + let mut input_iter = input_array.iter().rev().copied(); + let mut run_array_iter = typed_array.into_iter().rev(); + for _ in 0..logical_len { + assert_eq!(input_iter.next(), run_array_iter.next()); + } + assert_eq!(None, run_array_iter.next()); + } + } + + #[test] + fn test_string_array_iter_round_trip() { + let input_vec = vec!["ab", "ab", "ba", "cc", "cc"]; + let input_ree_array: Int64RunArray = input_vec.into_iter().collect(); + let string_ree_array = input_ree_array.downcast::().unwrap(); + + // to and from iter, with a +1 + let result: Vec> = string_ree_array + .into_iter() + .map(|e| { + e.map(|e| { + let mut a = e.to_string(); + a.push('b'); + a + }) + }) + .collect(); + + let result_asref: Vec> = result.iter().map(|f| f.as_deref()).collect(); + + let expected_vec = vec![ + Some("abb"), + Some("abb"), + Some("bab"), + Some("ccb"), + Some("ccb"), + ]; + + assert_eq!(expected_vec, result_asref); + } + + #[test] + #[cfg_attr(miri, ignore)] // Takes too long + fn test_sliced_run_array_iterator() { + let total_len = 80; + let input_array = build_input_array(total_len); + + // Encode the input_array to run array + let mut builder = + PrimitiveRunBuilder::::with_capacity(input_array.len()); + builder.extend(input_array.iter().copied()); + let run_array = builder.finish(); + + // test for all slice lengths. + for slice_len in 1..=total_len { + // test for offset = 0, slice length = slice_len + let sliced_run_array: RunArray = + run_array.slice(0, slice_len).into_data().into(); + let sliced_typed_run_array = sliced_run_array + .downcast::>() + .unwrap(); + + // Iterate on sliced typed run array + let actual: Vec> = sliced_typed_run_array.into_iter().collect(); + let expected: Vec> = input_array.iter().take(slice_len).copied().collect(); + assert_eq!(expected, actual); + + // test for offset = total_len - slice_len, length = slice_len + let sliced_run_array: RunArray = run_array + .slice(total_len - slice_len, slice_len) + .into_data() + .into(); + let sliced_typed_run_array = sliced_run_array + .downcast::>() + .unwrap(); + + // Iterate on sliced typed run array + let actual: Vec> = sliced_typed_run_array.into_iter().collect(); + let expected: Vec> = input_array + .iter() + .skip(total_len - slice_len) + .copied() + .collect(); + assert_eq!(expected, actual); + } + } +} diff --git a/arrow-array/src/scalar.rs b/arrow-array/src/scalar.rs new file mode 100644 index 000000000000..f2a696a8f329 --- /dev/null +++ b/arrow-array/src/scalar.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::Array; + +/// A possibly [`Scalar`] [`Array`] +/// +/// This allows optimised binary kernels where one or more arguments are constant +/// +/// ``` +/// # use arrow_array::*; +/// # use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer}; +/// # use arrow_schema::ArrowError; +/// # +/// fn eq_impl( +/// a: &PrimitiveArray, +/// a_scalar: bool, +/// b: &PrimitiveArray, +/// b_scalar: bool, +/// ) -> BooleanArray { +/// let (array, scalar) = match (a_scalar, b_scalar) { +/// (true, true) | (false, false) => { +/// let len = a.len().min(b.len()); +/// let nulls = NullBuffer::union(a.nulls(), b.nulls()); +/// let buffer = BooleanBuffer::collect_bool(len, |idx| a.value(idx) == b.value(idx)); +/// return BooleanArray::new(buffer, nulls); +/// } +/// (true, false) => (b, (a.null_count() == 0).then(|| a.value(0))), +/// (false, true) => (a, (b.null_count() == 0).then(|| b.value(0))), +/// }; +/// match scalar { +/// Some(v) => { +/// let len = array.len(); +/// let nulls = array.nulls().cloned(); +/// let buffer = BooleanBuffer::collect_bool(len, |idx| array.value(idx) == v); +/// BooleanArray::new(buffer, nulls) +/// } +/// None => BooleanArray::new_null(array.len()), +/// } +/// } +/// +/// pub fn eq(l: &dyn Datum, r: &dyn Datum) -> Result { +/// let (l_array, l_scalar) = l.get(); +/// let (r_array, r_scalar) = r.get(); +/// downcast_primitive_array!( +/// (l_array, r_array) => Ok(eq_impl(l_array, l_scalar, r_array, r_scalar)), +/// (a, b) => Err(ArrowError::NotYetImplemented(format!("{a} == {b}"))), +/// ) +/// } +/// +/// // Comparison of two arrays +/// let a = Int32Array::from(vec![1, 2, 3, 4, 5]); +/// let b = Int32Array::from(vec![1, 2, 4, 7, 3]); +/// let r = eq(&a, &b).unwrap(); +/// let values: Vec<_> = r.values().iter().collect(); +/// assert_eq!(values, &[true, true, false, false, false]); +/// +/// // Comparison of an array and a scalar +/// let a = Int32Array::from(vec![1, 2, 3, 4, 5]); +/// let b = Int32Array::new_scalar(1); +/// let r = eq(&a, &b).unwrap(); +/// let values: Vec<_> = r.values().iter().collect(); +/// assert_eq!(values, &[true, false, false, false, false]); +pub trait Datum { + /// Returns the value for this [`Datum`] and a boolean indicating if the value is scalar + fn get(&self) -> (&dyn Array, bool); +} + +impl Datum for T { + fn get(&self) -> (&dyn Array, bool) { + (self, false) + } +} + +impl Datum for dyn Array { + fn get(&self) -> (&dyn Array, bool) { + (self, false) + } +} + +impl Datum for &dyn Array { + fn get(&self) -> (&dyn Array, bool) { + (*self, false) + } +} + +/// A wrapper around a single value [`Array`] that implements +/// [`Datum`] and indicates [compute] kernels should treat this array +/// as a scalar value (a single value). +/// +/// Using a [`Scalar`] is often much more efficient than creating an +/// [`Array`] with the same (repeated) value. +/// +/// See [`Datum`] for more information. +/// +/// # Example +/// +/// ```rust +/// # use arrow_array::{Scalar, Int32Array, ArrayRef}; +/// # fn get_array() -> ArrayRef { std::sync::Arc::new(Int32Array::from(vec![42])) } +/// // Create a (typed) scalar for Int32Array for the value 42 +/// let scalar = Scalar::new(Int32Array::from(vec![42])); +/// +/// // Create a scalar using PrimtiveArray::scalar +/// let scalar = Int32Array::new_scalar(42); +/// +/// // create a scalar from an ArrayRef (for dynamic typed Arrays) +/// let array: ArrayRef = get_array(); +/// let scalar = Scalar::new(array); +/// ``` +/// +/// [compute]: https://docs.rs/arrow/latest/arrow/compute/index.html +#[derive(Debug, Copy, Clone)] +pub struct Scalar(T); + +impl Scalar { + /// Create a new [`Scalar`] from an [`Array`] + /// + /// # Panics + /// + /// Panics if `array.len() != 1` + pub fn new(array: T) -> Self { + assert_eq!(array.len(), 1); + Self(array) + } +} + +impl Datum for Scalar { + fn get(&self) -> (&dyn Array, bool) { + (&self.0, true) + } +} diff --git a/arrow-array/src/temporal_conversions.rs b/arrow-array/src/temporal_conversions.rs new file mode 100644 index 000000000000..e0edcc9bc182 --- /dev/null +++ b/arrow-array/src/temporal_conversions.rs @@ -0,0 +1,347 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Conversion methods for dates and times. + +use crate::timezone::Tz; +use crate::ArrowPrimitiveType; +use arrow_schema::{DataType, TimeUnit}; +use chrono::{DateTime, Duration, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Timelike, Utc}; + +/// Number of seconds in a day +pub const SECONDS_IN_DAY: i64 = 86_400; +/// Number of milliseconds in a second +pub const MILLISECONDS: i64 = 1_000; +/// Number of microseconds in a second +pub const MICROSECONDS: i64 = 1_000_000; +/// Number of nanoseconds in a second +pub const NANOSECONDS: i64 = 1_000_000_000; + +/// Number of milliseconds in a day +pub const MILLISECONDS_IN_DAY: i64 = SECONDS_IN_DAY * MILLISECONDS; +/// Number of microseconds in a day +pub const MICROSECONDS_IN_DAY: i64 = SECONDS_IN_DAY * MICROSECONDS; +/// Number of nanoseconds in a day +pub const NANOSECONDS_IN_DAY: i64 = SECONDS_IN_DAY * NANOSECONDS; +/// Number of days between 0001-01-01 and 1970-01-01 +pub const EPOCH_DAYS_FROM_CE: i32 = 719_163; + +/// converts a `i32` representing a `date32` to [`NaiveDateTime`] +#[inline] +pub fn date32_to_datetime(v: i32) -> Option { + NaiveDateTime::from_timestamp_opt(v as i64 * SECONDS_IN_DAY, 0) +} + +/// converts a `i64` representing a `date64` to [`NaiveDateTime`] +#[inline] +pub fn date64_to_datetime(v: i64) -> Option { + let (sec, milli_sec) = split_second(v, MILLISECONDS); + + NaiveDateTime::from_timestamp_opt( + // extract seconds from milliseconds + sec, + // discard extracted seconds and convert milliseconds to nanoseconds + milli_sec * MICROSECONDS as u32, + ) +} + +/// converts a `i32` representing a `time32(s)` to [`NaiveDateTime`] +#[inline] +pub fn time32s_to_time(v: i32) -> Option { + NaiveTime::from_num_seconds_from_midnight_opt(v as u32, 0) +} + +/// converts a `i32` representing a `time32(ms)` to [`NaiveDateTime`] +#[inline] +pub fn time32ms_to_time(v: i32) -> Option { + let v = v as i64; + NaiveTime::from_num_seconds_from_midnight_opt( + // extract seconds from milliseconds + (v / MILLISECONDS) as u32, + // discard extracted seconds and convert milliseconds to + // nanoseconds + (v % MILLISECONDS * MICROSECONDS) as u32, + ) +} + +/// converts a `i64` representing a `time64(us)` to [`NaiveDateTime`] +#[inline] +pub fn time64us_to_time(v: i64) -> Option { + NaiveTime::from_num_seconds_from_midnight_opt( + // extract seconds from microseconds + (v / MICROSECONDS) as u32, + // discard extracted seconds and convert microseconds to + // nanoseconds + (v % MICROSECONDS * MILLISECONDS) as u32, + ) +} + +/// converts a `i64` representing a `time64(ns)` to [`NaiveDateTime`] +#[inline] +pub fn time64ns_to_time(v: i64) -> Option { + NaiveTime::from_num_seconds_from_midnight_opt( + // extract seconds from nanoseconds + (v / NANOSECONDS) as u32, + // discard extracted seconds + (v % NANOSECONDS) as u32, + ) +} + +/// converts [`NaiveTime`] to a `i32` representing a `time32(s)` +#[inline] +pub fn time_to_time32s(v: NaiveTime) -> i32 { + v.num_seconds_from_midnight() as i32 +} + +/// converts [`NaiveTime`] to a `i32` representing a `time32(ms)` +#[inline] +pub fn time_to_time32ms(v: NaiveTime) -> i32 { + (v.num_seconds_from_midnight() as i64 * MILLISECONDS + + v.nanosecond() as i64 * MILLISECONDS / NANOSECONDS) as i32 +} + +/// converts [`NaiveTime`] to a `i64` representing a `time64(us)` +#[inline] +pub fn time_to_time64us(v: NaiveTime) -> i64 { + v.num_seconds_from_midnight() as i64 * MICROSECONDS + + v.nanosecond() as i64 * MICROSECONDS / NANOSECONDS +} + +/// converts [`NaiveTime`] to a `i64` representing a `time64(ns)` +#[inline] +pub fn time_to_time64ns(v: NaiveTime) -> i64 { + v.num_seconds_from_midnight() as i64 * NANOSECONDS + v.nanosecond() as i64 +} + +/// converts a `i64` representing a `timestamp(s)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_s_to_datetime(v: i64) -> Option { + NaiveDateTime::from_timestamp_opt(v, 0) +} + +/// converts a `i64` representing a `timestamp(ms)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_ms_to_datetime(v: i64) -> Option { + let (sec, milli_sec) = split_second(v, MILLISECONDS); + + NaiveDateTime::from_timestamp_opt( + // extract seconds from milliseconds + sec, + // discard extracted seconds and convert milliseconds to nanoseconds + milli_sec * MICROSECONDS as u32, + ) +} + +/// converts a `i64` representing a `timestamp(us)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_us_to_datetime(v: i64) -> Option { + let (sec, micro_sec) = split_second(v, MICROSECONDS); + + NaiveDateTime::from_timestamp_opt( + // extract seconds from microseconds + sec, + // discard extracted seconds and convert microseconds to nanoseconds + micro_sec * MILLISECONDS as u32, + ) +} + +/// converts a `i64` representing a `timestamp(ns)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_ns_to_datetime(v: i64) -> Option { + let (sec, nano_sec) = split_second(v, NANOSECONDS); + + NaiveDateTime::from_timestamp_opt( + // extract seconds from nanoseconds + sec, // discard extracted seconds + nano_sec, + ) +} + +#[inline] +pub(crate) fn split_second(v: i64, base: i64) -> (i64, u32) { + (v.div_euclid(base), v.rem_euclid(base) as u32) +} + +/// converts a `i64` representing a `duration(s)` to [`Duration`] +#[inline] +pub fn duration_s_to_duration(v: i64) -> Duration { + Duration::seconds(v) +} + +/// converts a `i64` representing a `duration(ms)` to [`Duration`] +#[inline] +pub fn duration_ms_to_duration(v: i64) -> Duration { + Duration::milliseconds(v) +} + +/// converts a `i64` representing a `duration(us)` to [`Duration`] +#[inline] +pub fn duration_us_to_duration(v: i64) -> Duration { + Duration::microseconds(v) +} + +/// converts a `i64` representing a `duration(ns)` to [`Duration`] +#[inline] +pub fn duration_ns_to_duration(v: i64) -> Duration { + Duration::nanoseconds(v) +} + +/// Converts an [`ArrowPrimitiveType`] to [`NaiveDateTime`] +pub fn as_datetime(v: i64) -> Option { + match T::DATA_TYPE { + DataType::Date32 => date32_to_datetime(v as i32), + DataType::Date64 => date64_to_datetime(v), + DataType::Time32(_) | DataType::Time64(_) => None, + DataType::Timestamp(unit, _) => match unit { + TimeUnit::Second => timestamp_s_to_datetime(v), + TimeUnit::Millisecond => timestamp_ms_to_datetime(v), + TimeUnit::Microsecond => timestamp_us_to_datetime(v), + TimeUnit::Nanosecond => timestamp_ns_to_datetime(v), + }, + // interval is not yet fully documented [ARROW-3097] + DataType::Interval(_) => None, + _ => None, + } +} + +/// Converts an [`ArrowPrimitiveType`] to [`DateTime`] +pub fn as_datetime_with_timezone(v: i64, tz: Tz) -> Option> { + let naive = as_datetime::(v)?; + Some(Utc.from_utc_datetime(&naive).with_timezone(&tz)) +} + +/// Converts an [`ArrowPrimitiveType`] to [`NaiveDate`] +pub fn as_date(v: i64) -> Option { + as_datetime::(v).map(|datetime| datetime.date()) +} + +/// Converts an [`ArrowPrimitiveType`] to [`NaiveTime`] +pub fn as_time(v: i64) -> Option { + match T::DATA_TYPE { + DataType::Time32(unit) => { + // safe to immediately cast to u32 as `self.value(i)` is positive i32 + let v = v as u32; + match unit { + TimeUnit::Second => time32s_to_time(v as i32), + TimeUnit::Millisecond => time32ms_to_time(v as i32), + _ => None, + } + } + DataType::Time64(unit) => match unit { + TimeUnit::Microsecond => time64us_to_time(v), + TimeUnit::Nanosecond => time64ns_to_time(v), + _ => None, + }, + DataType::Timestamp(_, _) => as_datetime::(v).map(|datetime| datetime.time()), + DataType::Date32 | DataType::Date64 => NaiveTime::from_hms_opt(0, 0, 0), + DataType::Interval(_) => None, + _ => None, + } +} + +/// Converts an [`ArrowPrimitiveType`] to [`Duration`] +pub fn as_duration(v: i64) -> Option { + match T::DATA_TYPE { + DataType::Duration(unit) => match unit { + TimeUnit::Second => Some(duration_s_to_duration(v)), + TimeUnit::Millisecond => Some(duration_ms_to_duration(v)), + TimeUnit::Microsecond => Some(duration_us_to_duration(v)), + TimeUnit::Nanosecond => Some(duration_ns_to_duration(v)), + }, + _ => None, + } +} + +#[cfg(test)] +mod tests { + use crate::temporal_conversions::{ + date64_to_datetime, split_second, timestamp_ms_to_datetime, timestamp_ns_to_datetime, + timestamp_us_to_datetime, NANOSECONDS, + }; + use chrono::NaiveDateTime; + + #[test] + fn negative_input_timestamp_ns_to_datetime() { + assert_eq!( + timestamp_ns_to_datetime(-1), + NaiveDateTime::from_timestamp_opt(-1, 999_999_999) + ); + + assert_eq!( + timestamp_ns_to_datetime(-1_000_000_001), + NaiveDateTime::from_timestamp_opt(-2, 999_999_999) + ); + } + + #[test] + fn negative_input_timestamp_us_to_datetime() { + assert_eq!( + timestamp_us_to_datetime(-1), + NaiveDateTime::from_timestamp_opt(-1, 999_999_000) + ); + + assert_eq!( + timestamp_us_to_datetime(-1_000_001), + NaiveDateTime::from_timestamp_opt(-2, 999_999_000) + ); + } + + #[test] + fn negative_input_timestamp_ms_to_datetime() { + assert_eq!( + timestamp_ms_to_datetime(-1), + NaiveDateTime::from_timestamp_opt(-1, 999_000_000) + ); + + assert_eq!( + timestamp_ms_to_datetime(-1_001), + NaiveDateTime::from_timestamp_opt(-2, 999_000_000) + ); + } + + #[test] + fn negative_input_date64_to_datetime() { + assert_eq!( + date64_to_datetime(-1), + NaiveDateTime::from_timestamp_opt(-1, 999_000_000) + ); + + assert_eq!( + date64_to_datetime(-1_001), + NaiveDateTime::from_timestamp_opt(-2, 999_000_000) + ); + } + + #[test] + fn test_split_seconds() { + let (sec, nano_sec) = split_second(100, NANOSECONDS); + assert_eq!(sec, 0); + assert_eq!(nano_sec, 100); + + let (sec, nano_sec) = split_second(123_000_000_456, NANOSECONDS); + assert_eq!(sec, 123); + assert_eq!(nano_sec, 456); + + let (sec, nano_sec) = split_second(-1, NANOSECONDS); + assert_eq!(sec, -1); + assert_eq!(nano_sec, 999_999_999); + + let (sec, nano_sec) = split_second(-123_000_000_001, NANOSECONDS); + assert_eq!(sec, -124); + assert_eq!(nano_sec, 999_999_999); + } +} diff --git a/arrow-array/src/timezone.rs b/arrow-array/src/timezone.rs new file mode 100644 index 000000000000..dc91886f34c5 --- /dev/null +++ b/arrow-array/src/timezone.rs @@ -0,0 +1,339 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Timezone for timestamp arrays + +use arrow_schema::ArrowError; +use chrono::FixedOffset; +pub use private::{Tz, TzOffset}; + +/// Parses a fixed offset of the form "+09:00", "-09" or "+0930" +fn parse_fixed_offset(tz: &str) -> Option { + let bytes = tz.as_bytes(); + + let mut values = match bytes.len() { + // [+-]XX:XX + 6 if bytes[3] == b':' => [bytes[1], bytes[2], bytes[4], bytes[5]], + // [+-]XXXX + 5 => [bytes[1], bytes[2], bytes[3], bytes[4]], + // [+-]XX + 3 => [bytes[1], bytes[2], b'0', b'0'], + _ => return None, + }; + values.iter_mut().for_each(|x| *x = x.wrapping_sub(b'0')); + if values.iter().any(|x| *x > 9) { + return None; + } + let secs = + (values[0] * 10 + values[1]) as i32 * 60 * 60 + (values[2] * 10 + values[3]) as i32 * 60; + + match bytes[0] { + b'+' => FixedOffset::east_opt(secs), + b'-' => FixedOffset::west_opt(secs), + _ => None, + } +} + +#[cfg(feature = "chrono-tz")] +mod private { + use super::*; + use chrono::offset::TimeZone; + use chrono::{LocalResult, NaiveDate, NaiveDateTime, Offset}; + use std::str::FromStr; + + /// An [`Offset`] for [`Tz`] + #[derive(Debug, Copy, Clone)] + pub struct TzOffset { + tz: Tz, + offset: FixedOffset, + } + + impl std::fmt::Display for TzOffset { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.offset.fmt(f) + } + } + + impl Offset for TzOffset { + fn fix(&self) -> FixedOffset { + self.offset + } + } + + /// An Arrow [`TimeZone`] + #[derive(Debug, Copy, Clone)] + pub struct Tz(TzInner); + + #[derive(Debug, Copy, Clone)] + enum TzInner { + Timezone(chrono_tz::Tz), + Offset(FixedOffset), + } + + impl FromStr for Tz { + type Err = ArrowError; + + fn from_str(tz: &str) -> Result { + match parse_fixed_offset(tz) { + Some(offset) => Ok(Self(TzInner::Offset(offset))), + None => Ok(Self(TzInner::Timezone(tz.parse().map_err(|e| { + ArrowError::ParseError(format!("Invalid timezone \"{tz}\": {e}")) + })?))), + } + } + } + + macro_rules! tz { + ($s:ident, $tz:ident, $b:block) => { + match $s.0 { + TzInner::Timezone($tz) => $b, + TzInner::Offset($tz) => $b, + } + }; + } + + impl TimeZone for Tz { + type Offset = TzOffset; + + fn from_offset(offset: &Self::Offset) -> Self { + offset.tz + } + + fn offset_from_local_date(&self, local: &NaiveDate) -> LocalResult { + tz!(self, tz, { + tz.offset_from_local_date(local).map(|x| TzOffset { + tz: *self, + offset: x.fix(), + }) + }) + } + + fn offset_from_local_datetime(&self, local: &NaiveDateTime) -> LocalResult { + tz!(self, tz, { + tz.offset_from_local_datetime(local).map(|x| TzOffset { + tz: *self, + offset: x.fix(), + }) + }) + } + + fn offset_from_utc_date(&self, utc: &NaiveDate) -> Self::Offset { + tz!(self, tz, { + TzOffset { + tz: *self, + offset: tz.offset_from_utc_date(utc).fix(), + } + }) + } + + fn offset_from_utc_datetime(&self, utc: &NaiveDateTime) -> Self::Offset { + tz!(self, tz, { + TzOffset { + tz: *self, + offset: tz.offset_from_utc_datetime(utc).fix(), + } + }) + } + } + + #[cfg(test)] + mod tests { + use super::*; + use chrono::{Timelike, Utc}; + + #[test] + fn test_with_timezone() { + let vals = [ + Utc.timestamp_millis_opt(37800000).unwrap(), + Utc.timestamp_millis_opt(86339000).unwrap(), + ]; + + assert_eq!(10, vals[0].hour()); + assert_eq!(23, vals[1].hour()); + + let tz: Tz = "America/Los_Angeles".parse().unwrap(); + + assert_eq!(2, vals[0].with_timezone(&tz).hour()); + assert_eq!(15, vals[1].with_timezone(&tz).hour()); + } + + #[test] + fn test_using_chrono_tz_and_utc_naive_date_time() { + let sydney_tz = "Australia/Sydney".to_string(); + let tz: Tz = sydney_tz.parse().unwrap(); + let sydney_offset_without_dst = FixedOffset::east_opt(10 * 60 * 60).unwrap(); + let sydney_offset_with_dst = FixedOffset::east_opt(11 * 60 * 60).unwrap(); + // Daylight savings ends + // When local daylight time was about to reach + // Sunday, 4 April 2021, 3:00:00 am clocks were turned backward 1 hour to + // Sunday, 4 April 2021, 2:00:00 am local standard time instead. + + // Daylight savings starts + // When local standard time was about to reach + // Sunday, 3 October 2021, 2:00:00 am clocks were turned forward 1 hour to + // Sunday, 3 October 2021, 3:00:00 am local daylight time instead. + + // Sydney 2021-04-04T02:30:00+11:00 is 2021-04-03T15:30:00Z + let utc_just_before_sydney_dst_ends = NaiveDate::from_ymd_opt(2021, 4, 3) + .unwrap() + .and_hms_nano_opt(15, 30, 0, 0) + .unwrap(); + assert_eq!( + tz.offset_from_utc_datetime(&utc_just_before_sydney_dst_ends) + .fix(), + sydney_offset_with_dst + ); + // Sydney 2021-04-04T02:30:00+10:00 is 2021-04-03T16:30:00Z + let utc_just_after_sydney_dst_ends = NaiveDate::from_ymd_opt(2021, 4, 3) + .unwrap() + .and_hms_nano_opt(16, 30, 0, 0) + .unwrap(); + assert_eq!( + tz.offset_from_utc_datetime(&utc_just_after_sydney_dst_ends) + .fix(), + sydney_offset_without_dst + ); + // Sydney 2021-10-03T01:30:00+10:00 is 2021-10-02T15:30:00Z + let utc_just_before_sydney_dst_starts = NaiveDate::from_ymd_opt(2021, 10, 2) + .unwrap() + .and_hms_nano_opt(15, 30, 0, 0) + .unwrap(); + assert_eq!( + tz.offset_from_utc_datetime(&utc_just_before_sydney_dst_starts) + .fix(), + sydney_offset_without_dst + ); + // Sydney 2021-04-04T03:30:00+11:00 is 2021-10-02T16:30:00Z + let utc_just_after_sydney_dst_starts = NaiveDate::from_ymd_opt(2022, 10, 2) + .unwrap() + .and_hms_nano_opt(16, 30, 0, 0) + .unwrap(); + assert_eq!( + tz.offset_from_utc_datetime(&utc_just_after_sydney_dst_starts) + .fix(), + sydney_offset_with_dst + ); + } + } +} + +#[cfg(not(feature = "chrono-tz"))] +mod private { + use super::*; + use chrono::offset::TimeZone; + use chrono::{FixedOffset, LocalResult, NaiveDate, NaiveDateTime, Offset}; + use std::str::FromStr; + + /// An [`Offset`] for [`Tz`] + #[derive(Debug, Copy, Clone)] + pub struct TzOffset(FixedOffset); + + impl std::fmt::Display for TzOffset { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } + } + + impl Offset for TzOffset { + fn fix(&self) -> FixedOffset { + self.0 + } + } + + /// An Arrow [`TimeZone`] + #[derive(Debug, Copy, Clone)] + pub struct Tz(FixedOffset); + + impl FromStr for Tz { + type Err = ArrowError; + + fn from_str(tz: &str) -> Result { + let offset = parse_fixed_offset(tz).ok_or_else(|| { + ArrowError::ParseError(format!( + "Invalid timezone \"{tz}\": only offset based timezones supported without chrono-tz feature" + )) + })?; + Ok(Self(offset)) + } + } + + impl TimeZone for Tz { + type Offset = TzOffset; + + fn from_offset(offset: &Self::Offset) -> Self { + Self(offset.0) + } + + fn offset_from_local_date(&self, local: &NaiveDate) -> LocalResult { + self.0.offset_from_local_date(local).map(TzOffset) + } + + fn offset_from_local_datetime(&self, local: &NaiveDateTime) -> LocalResult { + self.0.offset_from_local_datetime(local).map(TzOffset) + } + + fn offset_from_utc_date(&self, utc: &NaiveDate) -> Self::Offset { + TzOffset(self.0.offset_from_utc_date(utc).fix()) + } + + fn offset_from_utc_datetime(&self, utc: &NaiveDateTime) -> Self::Offset { + TzOffset(self.0.offset_from_utc_datetime(utc).fix()) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use chrono::{NaiveDate, Offset, TimeZone}; + + #[test] + fn test_with_offset() { + let t = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(); + + let tz: Tz = "-00:00".parse().unwrap(); + assert_eq!(tz.offset_from_utc_date(&t).fix().local_minus_utc(), 0); + let tz: Tz = "+00:00".parse().unwrap(); + assert_eq!(tz.offset_from_utc_date(&t).fix().local_minus_utc(), 0); + + let tz: Tz = "-10:00".parse().unwrap(); + assert_eq!( + tz.offset_from_utc_date(&t).fix().local_minus_utc(), + -10 * 60 * 60 + ); + let tz: Tz = "+09:00".parse().unwrap(); + assert_eq!( + tz.offset_from_utc_date(&t).fix().local_minus_utc(), + 9 * 60 * 60 + ); + + let tz = "+09".parse::().unwrap(); + assert_eq!( + tz.offset_from_utc_date(&t).fix().local_minus_utc(), + 9 * 60 * 60 + ); + + let tz = "+0900".parse::().unwrap(); + assert_eq!( + tz.offset_from_utc_date(&t).fix().local_minus_utc(), + 9 * 60 * 60 + ); + + let err = "+9:00".parse::().unwrap_err().to_string(); + assert!(err.contains("Invalid timezone"), "{}", err); + } +} diff --git a/arrow/src/util/trusted_len.rs b/arrow-array/src/trusted_len.rs similarity index 94% rename from arrow/src/util/trusted_len.rs rename to arrow-array/src/trusted_len.rs index 84a66238b634..781cad38f7e9 100644 --- a/arrow/src/util/trusted_len.rs +++ b/arrow-array/src/trusted_len.rs @@ -15,11 +15,7 @@ // specific language governing permissions and limitations // under the License. -use super::bit_util; -use crate::{ - buffer::{Buffer, MutableBuffer}, - datatypes::ArrowNativeType, -}; +use arrow_buffer::{bit_util, ArrowNativeType, Buffer, MutableBuffer}; /// Creates two [`Buffer`]s from an iterator of `Option`. /// The first buffer corresponds to a bitmap buffer, the second one @@ -67,7 +63,7 @@ mod tests { #[test] fn trusted_len_unzip_good() { - let vec = vec![Some(1u32), None]; + let vec = [Some(1u32), None]; let (null, buffer) = unsafe { trusted_len_unzip(vec.iter()) }; assert_eq!(null.as_slice(), &[0b00000001]); assert_eq!(buffer.as_slice(), &[1u8, 0, 0, 0, 0, 0, 0, 0]); diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs new file mode 100644 index 000000000000..6e177838c4f5 --- /dev/null +++ b/arrow-array/src/types.rs @@ -0,0 +1,1634 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Zero-sized types used to parameterize generic array implementations + +use crate::delta::{ + add_days_datetime, add_months_datetime, shift_months, sub_days_datetime, sub_months_datetime, +}; +use crate::temporal_conversions::as_datetime_with_timezone; +use crate::timezone::Tz; +use crate::{ArrowNativeTypeOp, OffsetSizeTrait}; +use arrow_buffer::{i256, Buffer, OffsetBuffer}; +use arrow_data::decimal::{validate_decimal256_precision, validate_decimal_precision}; +use arrow_schema::{ + ArrowError, DataType, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL_DEFAULT_SCALE, +}; +use chrono::{Duration, NaiveDate, NaiveDateTime}; +use half::f16; +use std::marker::PhantomData; +use std::ops::{Add, Sub}; + +// BooleanType is special: its bit-width is not the size of the primitive type, and its `index` +// operation assumes bit-packing. +/// A boolean datatype +#[derive(Debug)] +pub struct BooleanType {} + +impl BooleanType { + /// The corresponding Arrow data type + pub const DATA_TYPE: DataType = DataType::Boolean; +} + +/// Trait bridging the dynamic-typed nature of Arrow (via [`DataType`]) with the +/// static-typed nature of rust types ([`ArrowNativeType`]) for all types that implement [`ArrowNativeType`]. +/// +/// [`ArrowNativeType`]: arrow_buffer::ArrowNativeType +pub trait ArrowPrimitiveType: primitive::PrimitiveTypeSealed + 'static { + /// Corresponding Rust native type for the primitive type. + type Native: ArrowNativeTypeOp; + + /// the corresponding Arrow data type of this primitive type. + const DATA_TYPE: DataType; + + /// Returns the byte width of this primitive type. + fn get_byte_width() -> usize { + std::mem::size_of::() + } + + /// Returns a default value of this primitive type. + /// + /// This is useful for aggregate array ops like `sum()`, `mean()`. + fn default_value() -> Self::Native { + Default::default() + } +} + +mod primitive { + pub trait PrimitiveTypeSealed {} +} + +macro_rules! make_type { + ($name:ident, $native_ty:ty, $data_ty:expr, $doc_string: literal) => { + #[derive(Debug)] + #[doc = $doc_string] + pub struct $name {} + + impl ArrowPrimitiveType for $name { + type Native = $native_ty; + const DATA_TYPE: DataType = $data_ty; + } + + impl primitive::PrimitiveTypeSealed for $name {} + }; +} + +make_type!(Int8Type, i8, DataType::Int8, "A signed 8-bit integer type."); +make_type!( + Int16Type, + i16, + DataType::Int16, + "A signed 16-bit integer type." +); +make_type!( + Int32Type, + i32, + DataType::Int32, + "A signed 32-bit integer type." +); +make_type!( + Int64Type, + i64, + DataType::Int64, + "A signed 64-bit integer type." +); +make_type!( + UInt8Type, + u8, + DataType::UInt8, + "An unsigned 8-bit integer type." +); +make_type!( + UInt16Type, + u16, + DataType::UInt16, + "An unsigned 16-bit integer type." +); +make_type!( + UInt32Type, + u32, + DataType::UInt32, + "An unsigned 32-bit integer type." +); +make_type!( + UInt64Type, + u64, + DataType::UInt64, + "An unsigned 64-bit integer type." +); +make_type!( + Float16Type, + f16, + DataType::Float16, + "A 16-bit floating point number type." +); +make_type!( + Float32Type, + f32, + DataType::Float32, + "A 32-bit floating point number type." +); +make_type!( + Float64Type, + f64, + DataType::Float64, + "A 64-bit floating point number type." +); +make_type!( + TimestampSecondType, + i64, + DataType::Timestamp(TimeUnit::Second, None), + "A timestamp second type with an optional timezone." +); +make_type!( + TimestampMillisecondType, + i64, + DataType::Timestamp(TimeUnit::Millisecond, None), + "A timestamp millisecond type with an optional timezone." +); +make_type!( + TimestampMicrosecondType, + i64, + DataType::Timestamp(TimeUnit::Microsecond, None), + "A timestamp microsecond type with an optional timezone." +); +make_type!( + TimestampNanosecondType, + i64, + DataType::Timestamp(TimeUnit::Nanosecond, None), + "A timestamp nanosecond type with an optional timezone." +); +make_type!( + Date32Type, + i32, + DataType::Date32, + "A 32-bit date type representing the elapsed time since UNIX epoch in days(32 bits)." +); +make_type!( + Date64Type, + i64, + DataType::Date64, + "A 64-bit date type representing the elapsed time since UNIX epoch in milliseconds(64 bits)." +); +make_type!( + Time32SecondType, + i32, + DataType::Time32(TimeUnit::Second), + "A 32-bit time type representing the elapsed time since midnight in seconds." +); +make_type!( + Time32MillisecondType, + i32, + DataType::Time32(TimeUnit::Millisecond), + "A 32-bit time type representing the elapsed time since midnight in milliseconds." +); +make_type!( + Time64MicrosecondType, + i64, + DataType::Time64(TimeUnit::Microsecond), + "A 64-bit time type representing the elapsed time since midnight in microseconds." +); +make_type!( + Time64NanosecondType, + i64, + DataType::Time64(TimeUnit::Nanosecond), + "A 64-bit time type representing the elapsed time since midnight in nanoseconds." +); +make_type!( + IntervalYearMonthType, + i32, + DataType::Interval(IntervalUnit::YearMonth), + "A “calendar” interval stored as the number of whole months." +); +make_type!( + IntervalDayTimeType, + i64, + DataType::Interval(IntervalUnit::DayTime), + r#"A “calendar” interval type in days and milliseconds. + +## Representation +This type is stored as a single 64 bit integer, interpreted as two i32 fields: +1. the number of elapsed days +2. The number of milliseconds (no leap seconds), + +```text + ┌──────────────┬──────────────┐ + │ Days │ Milliseconds │ + │ (32 bits) │ (32 bits) │ + └──────────────┴──────────────┘ + 0 31 63 bit offset +``` +Please see the [Arrow Spec](https://github.com/apache/arrow/blob/081b4022fe6f659d8765efc82b3f4787c5039e3c/format/Schema.fbs#L406-L408) for more details + +## Note on Comparing and Ordering for Calendar Types + +Values of `IntervalDayTimeType` are compared using their binary representation, +which can lead to surprising results. Please see the description of ordering on +[`IntervalMonthDayNanoType`] for more details +"# +); +make_type!( + IntervalMonthDayNanoType, + i128, + DataType::Interval(IntervalUnit::MonthDayNano), + r#"A “calendar” interval type in months, days, and nanoseconds. + +## Representation +This type is stored as a single 128 bit integer, +interpreted as three different signed integral fields: + +1. The number of months (32 bits) +2. The number days (32 bits) +2. The number of nanoseconds (64 bits). + +Nanoseconds does not allow for leap seconds. +Each field is independent (e.g. there is no constraint that the quantity of +nanoseconds represents less than a day's worth of time). + +```text +┌──────────────────────────────┬─────────────┬──────────────┐ +│ Nanos │ Days │ Months │ +│ (64 bits) │ (32 bits) │ (32 bits) │ +└──────────────────────────────┴─────────────┴──────────────┘ + 0 63 95 127 bit offset +``` +Please see the [Arrow Spec](https://github.com/apache/arrow/blob/081b4022fe6f659d8765efc82b3f4787c5039e3c/format/Schema.fbs#L409-L415) for more details + +## Note on Comparing and Ordering for Calendar Types +Values of `IntervalMonthDayNanoType` are compared using their binary representation, +which can lead to surprising results. + +Spans of time measured in calendar units are not fixed in absolute size (e.g. +number of seconds) which makes defining comparisons and ordering non trivial. +For example `1 month` is 28 days for February but `1 month` is 31 days +in December. + +This makes the seemingly simple operation of comparing two intervals +complicated in practice. For example is `1 month` more or less than `30 days`? The +answer depends on what month you are talking about. + +This crate defines comparisons for calendar types using their binary +representation which is fast and efficient, but leads +to potentially surprising results. + +For example a +`IntervalMonthDayNano` of `1 month` will compare as **greater** than a +`IntervalMonthDayNano` of `100 days` because the binary representation of `1 month` +is larger than the binary representation of 100 days. +"# +); +make_type!( + DurationSecondType, + i64, + DataType::Duration(TimeUnit::Second), + "An elapsed time type in seconds." +); +make_type!( + DurationMillisecondType, + i64, + DataType::Duration(TimeUnit::Millisecond), + "An elapsed time type in milliseconds." +); +make_type!( + DurationMicrosecondType, + i64, + DataType::Duration(TimeUnit::Microsecond), + "An elapsed time type in microseconds." +); +make_type!( + DurationNanosecondType, + i64, + DataType::Duration(TimeUnit::Nanosecond), + "An elapsed time type in nanoseconds." +); + +/// A subtype of primitive type that represents legal dictionary keys. +/// See +pub trait ArrowDictionaryKeyType: ArrowPrimitiveType {} + +impl ArrowDictionaryKeyType for Int8Type {} + +impl ArrowDictionaryKeyType for Int16Type {} + +impl ArrowDictionaryKeyType for Int32Type {} + +impl ArrowDictionaryKeyType for Int64Type {} + +impl ArrowDictionaryKeyType for UInt8Type {} + +impl ArrowDictionaryKeyType for UInt16Type {} + +impl ArrowDictionaryKeyType for UInt32Type {} + +impl ArrowDictionaryKeyType for UInt64Type {} + +/// A subtype of primitive type that is used as run-ends index +/// in `RunArray`. +/// See +pub trait RunEndIndexType: ArrowPrimitiveType {} + +impl RunEndIndexType for Int16Type {} + +impl RunEndIndexType for Int32Type {} + +impl RunEndIndexType for Int64Type {} + +/// A subtype of primitive type that represents temporal values. +pub trait ArrowTemporalType: ArrowPrimitiveType {} + +impl ArrowTemporalType for TimestampSecondType {} +impl ArrowTemporalType for TimestampMillisecondType {} +impl ArrowTemporalType for TimestampMicrosecondType {} +impl ArrowTemporalType for TimestampNanosecondType {} +impl ArrowTemporalType for Date32Type {} +impl ArrowTemporalType for Date64Type {} +impl ArrowTemporalType for Time32SecondType {} +impl ArrowTemporalType for Time32MillisecondType {} +impl ArrowTemporalType for Time64MicrosecondType {} +impl ArrowTemporalType for Time64NanosecondType {} +// impl ArrowTemporalType for IntervalYearMonthType {} +// impl ArrowTemporalType for IntervalDayTimeType {} +// impl ArrowTemporalType for IntervalMonthDayNanoType {} +impl ArrowTemporalType for DurationSecondType {} +impl ArrowTemporalType for DurationMillisecondType {} +impl ArrowTemporalType for DurationMicrosecondType {} +impl ArrowTemporalType for DurationNanosecondType {} + +/// A timestamp type allows us to create array builders that take a timestamp. +pub trait ArrowTimestampType: ArrowTemporalType { + /// The [`TimeUnit`] of this timestamp. + const UNIT: TimeUnit; + + /// Returns the `TimeUnit` of this timestamp. + #[deprecated(note = "Use Self::UNIT")] + fn get_time_unit() -> TimeUnit { + Self::UNIT + } + + /// Creates a ArrowTimestampType::Native from the provided [`NaiveDateTime`] + /// + /// See [`DataType::Timestamp`] for more information on timezone handling + fn make_value(naive: NaiveDateTime) -> Option; +} + +impl ArrowTimestampType for TimestampSecondType { + const UNIT: TimeUnit = TimeUnit::Second; + + fn make_value(naive: NaiveDateTime) -> Option { + Some(naive.timestamp()) + } +} +impl ArrowTimestampType for TimestampMillisecondType { + const UNIT: TimeUnit = TimeUnit::Millisecond; + + fn make_value(naive: NaiveDateTime) -> Option { + let millis = naive.timestamp().checked_mul(1_000)?; + millis.checked_add(naive.timestamp_subsec_millis() as i64) + } +} +impl ArrowTimestampType for TimestampMicrosecondType { + const UNIT: TimeUnit = TimeUnit::Microsecond; + + fn make_value(naive: NaiveDateTime) -> Option { + let micros = naive.timestamp().checked_mul(1_000_000)?; + micros.checked_add(naive.timestamp_subsec_micros() as i64) + } +} +impl ArrowTimestampType for TimestampNanosecondType { + const UNIT: TimeUnit = TimeUnit::Nanosecond; + + fn make_value(naive: NaiveDateTime) -> Option { + let nanos = naive.timestamp().checked_mul(1_000_000_000)?; + nanos.checked_add(naive.timestamp_subsec_nanos() as i64) + } +} + +fn add_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, +) -> Option<::Native> { + let months = IntervalYearMonthType::to_months(delta); + let res = as_datetime_with_timezone::(timestamp, tz)?; + let res = add_months_datetime(res, months)?; + let res = res.naive_utc(); + T::make_value(res) +} + +fn add_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, +) -> Option<::Native> { + let (days, ms) = IntervalDayTimeType::to_parts(delta); + let res = as_datetime_with_timezone::(timestamp, tz)?; + let res = add_days_datetime(res, days)?; + let res = res.checked_add_signed(Duration::milliseconds(ms as i64))?; + let res = res.naive_utc(); + T::make_value(res) +} + +fn add_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, +) -> Option<::Native> { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); + let res = as_datetime_with_timezone::(timestamp, tz)?; + let res = add_months_datetime(res, months)?; + let res = add_days_datetime(res, days)?; + let res = res.checked_add_signed(Duration::nanoseconds(nanos))?; + let res = res.naive_utc(); + T::make_value(res) +} + +fn subtract_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, +) -> Option<::Native> { + let months = IntervalYearMonthType::to_months(delta); + let res = as_datetime_with_timezone::(timestamp, tz)?; + let res = sub_months_datetime(res, months)?; + let res = res.naive_utc(); + T::make_value(res) +} + +fn subtract_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, +) -> Option<::Native> { + let (days, ms) = IntervalDayTimeType::to_parts(delta); + let res = as_datetime_with_timezone::(timestamp, tz)?; + let res = sub_days_datetime(res, days)?; + let res = res.checked_sub_signed(Duration::milliseconds(ms as i64))?; + let res = res.naive_utc(); + T::make_value(res) +} + +fn subtract_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, +) -> Option<::Native> { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); + let res = as_datetime_with_timezone::(timestamp, tz)?; + let res = sub_months_datetime(res, months)?; + let res = sub_days_datetime(res, days)?; + let res = res.checked_sub_signed(Duration::nanoseconds(nanos))?; + let res = res.naive_utc(); + T::make_value(res) +} + +impl TimestampSecondType { + /// Adds the given IntervalYearMonthType to an arrow TimestampSecondType. + /// + /// Returns `None` when it will result in overflow. + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_year_months::(timestamp, delta, tz) + } + + /// Adds the given IntervalDayTimeType to an arrow TimestampSecondType. + /// + /// Returns `None` when it will result in overflow. + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_day_time::(timestamp, delta, tz) + } + + /// Adds the given IntervalMonthDayNanoType to an arrow TimestampSecondType + /// + /// Returns `None` when it will result in overflow. + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_month_day_nano::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalYearMonthType to an arrow TimestampSecondType + /// + /// Returns `None` when it will result in overflow. + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_year_months::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalDayTimeType to an arrow TimestampSecondType + /// + /// Returns `None` when it will result in overflow. + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_day_time::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalMonthDayNanoType to an arrow TimestampSecondType + /// + /// Returns `None` when it will result in overflow. + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_month_day_nano::(timestamp, delta, tz) + } +} + +impl TimestampMicrosecondType { + /// Adds the given IntervalYearMonthType to an arrow TimestampMicrosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_year_months::(timestamp, delta, tz) + } + + /// Adds the given IntervalDayTimeType to an arrow TimestampMicrosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_day_time::(timestamp, delta, tz) + } + + /// Adds the given IntervalMonthDayNanoType to an arrow TimestampMicrosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_month_day_nano::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalYearMonthType to an arrow TimestampMicrosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_year_months::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalDayTimeType to an arrow TimestampMicrosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_day_time::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalMonthDayNanoType to an arrow TimestampMicrosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_month_day_nano::(timestamp, delta, tz) + } +} + +impl TimestampMillisecondType { + /// Adds the given IntervalYearMonthType to an arrow TimestampMillisecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_year_months::(timestamp, delta, tz) + } + + /// Adds the given IntervalDayTimeType to an arrow TimestampMillisecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_day_time::(timestamp, delta, tz) + } + + /// Adds the given IntervalMonthDayNanoType to an arrow TimestampMillisecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_month_day_nano::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalYearMonthType to an arrow TimestampMillisecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_year_months::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalDayTimeType to an arrow TimestampMillisecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_day_time::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalMonthDayNanoType to an arrow TimestampMillisecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_month_day_nano::(timestamp, delta, tz) + } +} + +impl TimestampNanosecondType { + /// Adds the given IntervalYearMonthType to an arrow TimestampNanosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_year_months::(timestamp, delta, tz) + } + + /// Adds the given IntervalDayTimeType to an arrow TimestampNanosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_day_time::(timestamp, delta, tz) + } + + /// Adds the given IntervalMonthDayNanoType to an arrow TimestampNanosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn add_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + add_month_day_nano::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalYearMonthType to an arrow TimestampNanosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_year_months( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_year_months::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalDayTimeType to an arrow TimestampNanosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_day_time( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_day_time::(timestamp, delta, tz) + } + + /// Subtracts the given IntervalMonthDayNanoType to an arrow TimestampNanosecondType + /// + /// # Arguments + /// + /// * `timestamp` - The date on which to perform the operation + /// * `delta` - The interval to add + /// * `tz` - The timezone in which to interpret `timestamp` + pub fn subtract_month_day_nano( + timestamp: ::Native, + delta: ::Native, + tz: Tz, + ) -> Option<::Native> { + subtract_month_day_nano::(timestamp, delta, tz) + } +} + +impl IntervalYearMonthType { + /// Creates a IntervalYearMonthType::Native + /// + /// # Arguments + /// + /// * `years` - The number of years (+/-) represented in this interval + /// * `months` - The number of months (+/-) represented in this interval + #[inline] + pub fn make_value( + years: i32, + months: i32, + ) -> ::Native { + years * 12 + months + } + + /// Turns a IntervalYearMonthType type into an i32 of months. + /// + /// This operation is technically a no-op, it is included for comprehensiveness. + /// + /// # Arguments + /// + /// * `i` - The IntervalYearMonthType::Native to convert + #[inline] + pub fn to_months(i: ::Native) -> i32 { + i + } +} + +impl IntervalDayTimeType { + /// Creates a IntervalDayTimeType::Native + /// + /// # Arguments + /// + /// * `days` - The number of days (+/-) represented in this interval + /// * `millis` - The number of milliseconds (+/-) represented in this interval + #[inline] + pub fn make_value( + days: i32, + millis: i32, + ) -> ::Native { + /* + https://github.com/apache/arrow/blob/02c8598d264c839a5b5cf3109bfd406f3b8a6ba5/cpp/src/arrow/type.h#L1433 + struct DayMilliseconds { + int32_t days = 0; + int32_t milliseconds = 0; + ... + } + 64 56 48 40 32 24 16 8 0 + +-------+-------+-------+-------+-------+-------+-------+-------+ + | days | milliseconds | + +-------+-------+-------+-------+-------+-------+-------+-------+ + */ + let m = millis as u64 & u32::MAX as u64; + let d = (days as u64 & u32::MAX as u64) << 32; + (m | d) as ::Native + } + + /// Turns a IntervalDayTimeType into a tuple of (days, milliseconds) + /// + /// # Arguments + /// + /// * `i` - The IntervalDayTimeType to convert + #[inline] + pub fn to_parts(i: ::Native) -> (i32, i32) { + let days = (i >> 32) as i32; + let ms = i as i32; + (days, ms) + } +} + +impl IntervalMonthDayNanoType { + /// Creates a IntervalMonthDayNanoType::Native + /// + /// # Arguments + /// + /// * `months` - The number of months (+/-) represented in this interval + /// * `days` - The number of days (+/-) represented in this interval + /// * `nanos` - The number of nanoseconds (+/-) represented in this interval + #[inline] + pub fn make_value( + months: i32, + days: i32, + nanos: i64, + ) -> ::Native { + /* + https://github.com/apache/arrow/blob/02c8598d264c839a5b5cf3109bfd406f3b8a6ba5/cpp/src/arrow/type.h#L1475 + struct MonthDayNanos { + int32_t months; + int32_t days; + int64_t nanoseconds; + } + 128 112 96 80 64 48 32 16 0 + +-------+-------+-------+-------+-------+-------+-------+-------+ + | months | days | nanos | + +-------+-------+-------+-------+-------+-------+-------+-------+ + */ + let m = (months as u128 & u32::MAX as u128) << 96; + let d = (days as u128 & u32::MAX as u128) << 64; + let n = nanos as u128 & u64::MAX as u128; + (m | d | n) as ::Native + } + + /// Turns a IntervalMonthDayNanoType into a tuple of (months, days, nanos) + /// + /// # Arguments + /// + /// * `i` - The IntervalMonthDayNanoType to convert + #[inline] + pub fn to_parts( + i: ::Native, + ) -> (i32, i32, i64) { + let months = (i >> 96) as i32; + let days = (i >> 64) as i32; + let nanos = i as i64; + (months, days, nanos) + } +} + +impl Date32Type { + /// Converts an arrow Date32Type into a chrono::NaiveDate + /// + /// # Arguments + /// + /// * `i` - The Date32Type to convert + pub fn to_naive_date(i: ::Native) -> NaiveDate { + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + epoch.add(Duration::days(i as i64)) + } + + /// Converts a chrono::NaiveDate into an arrow Date32Type + /// + /// # Arguments + /// + /// * `d` - The NaiveDate to convert + pub fn from_naive_date(d: NaiveDate) -> ::Native { + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + d.sub(epoch).num_days() as ::Native + } + + /// Adds the given IntervalYearMonthType to an arrow Date32Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to add + pub fn add_year_months( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let prior = Date32Type::to_naive_date(date); + let months = IntervalYearMonthType::to_months(delta); + let posterior = shift_months(prior, months); + Date32Type::from_naive_date(posterior) + } + + /// Adds the given IntervalDayTimeType to an arrow Date32Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to add + pub fn add_day_time( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let (days, ms) = IntervalDayTimeType::to_parts(delta); + let res = Date32Type::to_naive_date(date); + let res = res.add(Duration::days(days as i64)); + let res = res.add(Duration::milliseconds(ms as i64)); + Date32Type::from_naive_date(res) + } + + /// Adds the given IntervalMonthDayNanoType to an arrow Date32Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to add + pub fn add_month_day_nano( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); + let res = Date32Type::to_naive_date(date); + let res = shift_months(res, months); + let res = res.add(Duration::days(days as i64)); + let res = res.add(Duration::nanoseconds(nanos)); + Date32Type::from_naive_date(res) + } + + /// Subtract the given IntervalYearMonthType to an arrow Date32Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to subtract + pub fn subtract_year_months( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let prior = Date32Type::to_naive_date(date); + let months = IntervalYearMonthType::to_months(-delta); + let posterior = shift_months(prior, months); + Date32Type::from_naive_date(posterior) + } + + /// Subtract the given IntervalDayTimeType to an arrow Date32Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to subtract + pub fn subtract_day_time( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let (days, ms) = IntervalDayTimeType::to_parts(delta); + let res = Date32Type::to_naive_date(date); + let res = res.sub(Duration::days(days as i64)); + let res = res.sub(Duration::milliseconds(ms as i64)); + Date32Type::from_naive_date(res) + } + + /// Subtract the given IntervalMonthDayNanoType to an arrow Date32Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to subtract + pub fn subtract_month_day_nano( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); + let res = Date32Type::to_naive_date(date); + let res = shift_months(res, -months); + let res = res.sub(Duration::days(days as i64)); + let res = res.sub(Duration::nanoseconds(nanos)); + Date32Type::from_naive_date(res) + } +} + +impl Date64Type { + /// Converts an arrow Date64Type into a chrono::NaiveDate + /// + /// # Arguments + /// + /// * `i` - The Date64Type to convert + pub fn to_naive_date(i: ::Native) -> NaiveDate { + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + epoch.add(Duration::milliseconds(i)) + } + + /// Converts a chrono::NaiveDate into an arrow Date64Type + /// + /// # Arguments + /// + /// * `d` - The NaiveDate to convert + pub fn from_naive_date(d: NaiveDate) -> ::Native { + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + d.sub(epoch).num_milliseconds() as ::Native + } + + /// Adds the given IntervalYearMonthType to an arrow Date64Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to add + pub fn add_year_months( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let prior = Date64Type::to_naive_date(date); + let months = IntervalYearMonthType::to_months(delta); + let posterior = shift_months(prior, months); + Date64Type::from_naive_date(posterior) + } + + /// Adds the given IntervalDayTimeType to an arrow Date64Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to add + pub fn add_day_time( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let (days, ms) = IntervalDayTimeType::to_parts(delta); + let res = Date64Type::to_naive_date(date); + let res = res.add(Duration::days(days as i64)); + let res = res.add(Duration::milliseconds(ms as i64)); + Date64Type::from_naive_date(res) + } + + /// Adds the given IntervalMonthDayNanoType to an arrow Date64Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to add + pub fn add_month_day_nano( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); + let res = Date64Type::to_naive_date(date); + let res = shift_months(res, months); + let res = res.add(Duration::days(days as i64)); + let res = res.add(Duration::nanoseconds(nanos)); + Date64Type::from_naive_date(res) + } + + /// Subtract the given IntervalYearMonthType to an arrow Date64Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to subtract + pub fn subtract_year_months( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let prior = Date64Type::to_naive_date(date); + let months = IntervalYearMonthType::to_months(-delta); + let posterior = shift_months(prior, months); + Date64Type::from_naive_date(posterior) + } + + /// Subtract the given IntervalDayTimeType to an arrow Date64Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to subtract + pub fn subtract_day_time( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let (days, ms) = IntervalDayTimeType::to_parts(delta); + let res = Date64Type::to_naive_date(date); + let res = res.sub(Duration::days(days as i64)); + let res = res.sub(Duration::milliseconds(ms as i64)); + Date64Type::from_naive_date(res) + } + + /// Subtract the given IntervalMonthDayNanoType to an arrow Date64Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to subtract + pub fn subtract_month_day_nano( + date: ::Native, + delta: ::Native, + ) -> ::Native { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); + let res = Date64Type::to_naive_date(date); + let res = shift_months(res, -months); + let res = res.sub(Duration::days(days as i64)); + let res = res.sub(Duration::nanoseconds(nanos)); + Date64Type::from_naive_date(res) + } +} + +/// Crate private types for Decimal Arrays +/// +/// Not intended to be used outside this crate +mod decimal { + use super::*; + + pub trait DecimalTypeSealed {} + impl DecimalTypeSealed for Decimal128Type {} + impl DecimalTypeSealed for Decimal256Type {} +} + +/// A trait over the decimal types, used by [`PrimitiveArray`] to provide a generic +/// implementation across the various decimal types +/// +/// Implemented by [`Decimal128Type`] and [`Decimal256Type`] for [`Decimal128Array`] +/// and [`Decimal256Array`] respectively +/// +/// [`PrimitiveArray`]: crate::array::PrimitiveArray +/// [`Decimal128Array`]: crate::array::Decimal128Array +/// [`Decimal256Array`]: crate::array::Decimal256Array +pub trait DecimalType: + 'static + Send + Sync + ArrowPrimitiveType + decimal::DecimalTypeSealed +{ + /// Width of the type + const BYTE_LENGTH: usize; + /// Maximum number of significant digits + const MAX_PRECISION: u8; + /// Maximum no of digits after the decimal point (note the scale can be negative) + const MAX_SCALE: i8; + /// fn to create its [`DataType`] + const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType; + /// Default values for [`DataType`] + const DEFAULT_TYPE: DataType; + + /// "Decimal128" or "Decimal256", for use in error messages + const PREFIX: &'static str; + + /// Formats the decimal value with the provided precision and scale + fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String; + + /// Validates that `value` contains no more than `precision` decimal digits + fn validate_decimal_precision(value: Self::Native, precision: u8) -> Result<(), ArrowError>; +} + +/// Validate that `precision` and `scale` are valid for `T` +/// +/// Returns an Error if: +/// - `precision` is zero +/// - `precision` is larger than `T:MAX_PRECISION` +/// - `scale` is larger than `T::MAX_SCALE` +/// - `scale` is > `precision` +pub fn validate_decimal_precision_and_scale( + precision: u8, + scale: i8, +) -> Result<(), ArrowError> { + if precision == 0 { + return Err(ArrowError::InvalidArgumentError(format!( + "precision cannot be 0, has to be between [1, {}]", + T::MAX_PRECISION + ))); + } + if precision > T::MAX_PRECISION { + return Err(ArrowError::InvalidArgumentError(format!( + "precision {} is greater than max {}", + precision, + T::MAX_PRECISION + ))); + } + if scale > T::MAX_SCALE { + return Err(ArrowError::InvalidArgumentError(format!( + "scale {} is greater than max {}", + scale, + T::MAX_SCALE + ))); + } + if scale > 0 && scale as u8 > precision { + return Err(ArrowError::InvalidArgumentError(format!( + "scale {scale} is greater than precision {precision}" + ))); + } + + Ok(()) +} + +/// The decimal type for a Decimal128Array +#[derive(Debug)] +pub struct Decimal128Type {} + +impl DecimalType for Decimal128Type { + const BYTE_LENGTH: usize = 16; + const MAX_PRECISION: u8 = DECIMAL128_MAX_PRECISION; + const MAX_SCALE: i8 = DECIMAL128_MAX_SCALE; + const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal128; + const DEFAULT_TYPE: DataType = + DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE); + const PREFIX: &'static str = "Decimal128"; + + fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String { + format_decimal_str(&value.to_string(), precision as usize, scale) + } + + fn validate_decimal_precision(num: i128, precision: u8) -> Result<(), ArrowError> { + validate_decimal_precision(num, precision) + } +} + +impl ArrowPrimitiveType for Decimal128Type { + type Native = i128; + + const DATA_TYPE: DataType = ::DEFAULT_TYPE; +} + +impl primitive::PrimitiveTypeSealed for Decimal128Type {} + +/// The decimal type for a Decimal256Array +#[derive(Debug)] +pub struct Decimal256Type {} + +impl DecimalType for Decimal256Type { + const BYTE_LENGTH: usize = 32; + const MAX_PRECISION: u8 = DECIMAL256_MAX_PRECISION; + const MAX_SCALE: i8 = DECIMAL256_MAX_SCALE; + const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal256; + const DEFAULT_TYPE: DataType = + DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE); + const PREFIX: &'static str = "Decimal256"; + + fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String { + format_decimal_str(&value.to_string(), precision as usize, scale) + } + + fn validate_decimal_precision(num: i256, precision: u8) -> Result<(), ArrowError> { + validate_decimal256_precision(num, precision) + } +} + +impl ArrowPrimitiveType for Decimal256Type { + type Native = i256; + + const DATA_TYPE: DataType = ::DEFAULT_TYPE; +} + +impl primitive::PrimitiveTypeSealed for Decimal256Type {} + +fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String { + let (sign, rest) = match value_str.strip_prefix('-') { + Some(stripped) => ("-", stripped), + None => ("", value_str), + }; + let bound = precision.min(rest.len()) + sign.len(); + let value_str = &value_str[0..bound]; + + if scale == 0 { + value_str.to_string() + } else if scale < 0 { + let padding = value_str.len() + scale.unsigned_abs() as usize; + format!("{value_str:0 scale as usize { + // Decimal separator is in the middle of the string + let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize); + format!("{whole}.{decimal}") + } else { + // String has to be padded + format!("{}0.{:0>width$}", sign, rest, width = scale as usize) + } +} + +/// Crate private types for Byte Arrays +/// +/// Not intended to be used outside this crate +pub(crate) mod bytes { + use super::*; + + pub trait ByteArrayTypeSealed {} + impl ByteArrayTypeSealed for GenericStringType {} + impl ByteArrayTypeSealed for GenericBinaryType {} + + pub trait ByteArrayNativeType: std::fmt::Debug + Send + Sync { + /// # Safety + /// + /// `b` must be a valid byte sequence for `Self` + unsafe fn from_bytes_unchecked(b: &[u8]) -> &Self; + } + + impl ByteArrayNativeType for [u8] { + #[inline] + unsafe fn from_bytes_unchecked(b: &[u8]) -> &Self { + b + } + } + + impl ByteArrayNativeType for str { + #[inline] + unsafe fn from_bytes_unchecked(b: &[u8]) -> &Self { + std::str::from_utf8_unchecked(b) + } + } +} + +/// A trait over the variable-size byte array types +/// +/// See [Variable Size Binary Layout](https://arrow.apache.org/docs/format/Columnar.html#variable-size-binary-layout) +pub trait ByteArrayType: 'static + Send + Sync + bytes::ByteArrayTypeSealed { + /// Type of offset i.e i32/i64 + type Offset: OffsetSizeTrait; + /// Type for representing its equivalent rust type i.e + /// Utf8Array will have native type has &str + /// BinaryArray will have type as [u8] + type Native: bytes::ByteArrayNativeType + AsRef + AsRef<[u8]> + ?Sized; + + /// "Binary" or "String", for use in error messages + const PREFIX: &'static str; + + /// Datatype of array elements + const DATA_TYPE: DataType; + + /// Verifies that every consecutive pair of `offsets` denotes a valid slice of `values` + fn validate(offsets: &OffsetBuffer, values: &Buffer) -> Result<(), ArrowError>; +} + +/// [`ByteArrayType`] for string arrays +pub struct GenericStringType { + phantom: PhantomData, +} + +impl ByteArrayType for GenericStringType { + type Offset = O; + type Native = str; + const PREFIX: &'static str = "String"; + + const DATA_TYPE: DataType = if O::IS_LARGE { + DataType::LargeUtf8 + } else { + DataType::Utf8 + }; + + fn validate(offsets: &OffsetBuffer, values: &Buffer) -> Result<(), ArrowError> { + // Verify that the slice as a whole is valid UTF-8 + let validated = std::str::from_utf8(values).map_err(|e| { + ArrowError::InvalidArgumentError(format!("Encountered non UTF-8 data: {e}")) + })?; + + // Verify each offset is at a valid character boundary in this UTF-8 array + for offset in offsets.iter() { + let o = offset.as_usize(); + if !validated.is_char_boundary(o) { + if o < validated.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Split UTF-8 codepoint at offset {o}" + ))); + } + return Err(ArrowError::InvalidArgumentError(format!( + "Offset of {o} exceeds length of values {}", + validated.len() + ))); + } + } + Ok(()) + } +} + +/// An arrow utf8 array with i32 offsets +pub type Utf8Type = GenericStringType; +/// An arrow utf8 array with i64 offsets +pub type LargeUtf8Type = GenericStringType; + +/// [`ByteArrayType`] for binary arrays +pub struct GenericBinaryType { + phantom: PhantomData, +} + +impl ByteArrayType for GenericBinaryType { + type Offset = O; + type Native = [u8]; + const PREFIX: &'static str = "Binary"; + + const DATA_TYPE: DataType = if O::IS_LARGE { + DataType::LargeBinary + } else { + DataType::Binary + }; + + fn validate(offsets: &OffsetBuffer, values: &Buffer) -> Result<(), ArrowError> { + // offsets are guaranteed to be monotonically increasing and non-empty + let max_offset = offsets.last().unwrap().as_usize(); + if values.len() < max_offset { + return Err(ArrowError::InvalidArgumentError(format!( + "Maximum offset of {max_offset} is larger than values of length {}", + values.len() + ))); + } + Ok(()) + } +} + +/// An arrow binary array with i32 offsets +pub type BinaryType = GenericBinaryType; +/// An arrow binary array with i64 offsets +pub type LargeBinaryType = GenericBinaryType; + +#[cfg(test)] +mod tests { + use super::*; + use arrow_data::{layout, BufferSpec}; + + #[test] + fn month_day_nano_should_roundtrip() { + let value = IntervalMonthDayNanoType::make_value(1, 2, 3); + assert_eq!(IntervalMonthDayNanoType::to_parts(value), (1, 2, 3)); + } + + #[test] + fn month_day_nano_should_roundtrip_neg() { + let value = IntervalMonthDayNanoType::make_value(-1, -2, -3); + assert_eq!(IntervalMonthDayNanoType::to_parts(value), (-1, -2, -3)); + } + + #[test] + fn day_time_should_roundtrip() { + let value = IntervalDayTimeType::make_value(1, 2); + assert_eq!(IntervalDayTimeType::to_parts(value), (1, 2)); + } + + #[test] + fn day_time_should_roundtrip_neg() { + let value = IntervalDayTimeType::make_value(-1, -2); + assert_eq!(IntervalDayTimeType::to_parts(value), (-1, -2)); + } + + #[test] + fn year_month_should_roundtrip() { + let value = IntervalYearMonthType::make_value(1, 2); + assert_eq!(IntervalYearMonthType::to_months(value), 14); + } + + #[test] + fn year_month_should_roundtrip_neg() { + let value = IntervalYearMonthType::make_value(-1, -2); + assert_eq!(IntervalYearMonthType::to_months(value), -14); + } + + fn test_layout() { + let layout = layout(&T::DATA_TYPE); + + assert_eq!(layout.buffers.len(), 1); + + let spec = &layout.buffers[0]; + assert_eq!( + spec, + &BufferSpec::FixedWidth { + byte_width: std::mem::size_of::(), + alignment: std::mem::align_of::(), + } + ); + } + + #[test] + fn test_layouts() { + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + test_layout::(); + } +} diff --git a/arrow-avro/Cargo.toml b/arrow-avro/Cargo.toml new file mode 100644 index 000000000000..9575874c41d2 --- /dev/null +++ b/arrow-avro/Cargo.toml @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "arrow-avro" +version = { workspace = true } +description = "Support for parsing Avro format into the Arrow format" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "arrow_avro" +path = "src/lib.rs" +bench = false + +[dependencies] +arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +arrow-cast = { workspace = true } +arrow-data = { workspace = true } +arrow-schema = { workspace = true } +serde_json = { version = "1.0", default-features = false, features = ["std"] } +serde = { version = "1.0.188", features = ["derive"] } + +[dev-dependencies] + diff --git a/arrow-avro/src/compression.rs b/arrow-avro/src/compression.rs new file mode 100644 index 000000000000..a1a44fc22b68 --- /dev/null +++ b/arrow-avro/src/compression.rs @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use serde::{Deserialize, Serialize}; + +/// The metadata key used for storing the JSON encoded [`CompressionCodec`] +pub const CODEC_METADATA_KEY: &str = "avro.codec"; + +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum CompressionCodec { + Null, + Deflate, + BZip2, + Snappy, + XZ, + ZStandard, +} diff --git a/arrow-avro/src/lib.rs b/arrow-avro/src/lib.rs new file mode 100644 index 000000000000..c76ecb399a45 --- /dev/null +++ b/arrow-avro/src/lib.rs @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Convert data to / from the [Apache Arrow] memory format and [Apache Avro] +//! +//! [Apache Arrow]: https://arrow.apache.org +//! [Apache Avro]: https://avro.apache.org/ + +#![allow(unused)] // Temporary + +pub mod reader; +mod schema; + +mod compression; + +#[cfg(test)] +mod test_util { + pub fn arrow_test_data(path: &str) -> String { + match std::env::var("ARROW_TEST_DATA") { + Ok(dir) => format!("{dir}/{path}"), + Err(_) => format!("../testing/data/{path}"), + } + } +} diff --git a/arrow-avro/src/reader/block.rs b/arrow-avro/src/reader/block.rs new file mode 100644 index 000000000000..479f0ef90909 --- /dev/null +++ b/arrow-avro/src/reader/block.rs @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Decoder for [`Block`] + +use crate::reader::vlq::VLQDecoder; +use arrow_schema::ArrowError; + +/// A file data block +/// +/// +#[derive(Debug, Default)] +pub struct Block { + /// The number of objects in this block + pub count: usize, + /// The serialized objects within this block + pub data: Vec, + /// The sync marker + pub sync: [u8; 16], +} + +/// A decoder for [`Block`] +#[derive(Debug)] +pub struct BlockDecoder { + state: BlockDecoderState, + in_progress: Block, + vlq_decoder: VLQDecoder, + bytes_remaining: usize, +} + +#[derive(Debug)] +enum BlockDecoderState { + Count, + Size, + Data, + Sync, + Finished, +} + +impl Default for BlockDecoder { + fn default() -> Self { + Self { + state: BlockDecoderState::Count, + in_progress: Default::default(), + vlq_decoder: Default::default(), + bytes_remaining: 0, + } + } +} + +impl BlockDecoder { + /// Parse [`Block`] from `buf`, returning the number of bytes read + /// + /// This method can be called multiple times with consecutive chunks of data, allowing + /// integration with chunked IO systems like [`BufRead::fill_buf`] + /// + /// All errors should be considered fatal, and decoding aborted + /// + /// Once an entire [`Block`] has been decoded this method will not read any further + /// input bytes, until [`Self::flush`] is called. Afterwards [`Self::decode`] + /// can then be used again to read the next block, if any + /// + /// [`BufRead::fill_buf`]: std::io::BufRead::fill_buf + pub fn decode(&mut self, mut buf: &[u8]) -> Result { + let max_read = buf.len(); + while !buf.is_empty() { + match self.state { + BlockDecoderState::Count => { + if let Some(c) = self.vlq_decoder.long(&mut buf) { + self.in_progress.count = c.try_into().map_err(|_| { + ArrowError::ParseError(format!( + "Block count cannot be negative, got {c}" + )) + })?; + + self.state = BlockDecoderState::Size; + } + } + BlockDecoderState::Size => { + if let Some(c) = self.vlq_decoder.long(&mut buf) { + self.bytes_remaining = c.try_into().map_err(|_| { + ArrowError::ParseError(format!( + "Block size cannot be negative, got {c}" + )) + })?; + + self.in_progress.data.reserve(self.bytes_remaining); + self.state = BlockDecoderState::Data; + } + } + BlockDecoderState::Data => { + let to_read = self.bytes_remaining.min(buf.len()); + self.in_progress.data.extend_from_slice(&buf[..to_read]); + buf = &buf[to_read..]; + self.bytes_remaining -= to_read; + if self.bytes_remaining == 0 { + self.bytes_remaining = 16; + self.state = BlockDecoderState::Sync; + } + } + BlockDecoderState::Sync => { + let to_decode = buf.len().min(self.bytes_remaining); + let write = &mut self.in_progress.sync[16 - to_decode..]; + write[..to_decode].copy_from_slice(&buf[..to_decode]); + self.bytes_remaining -= to_decode; + buf = &buf[to_decode..]; + if self.bytes_remaining == 0 { + self.state = BlockDecoderState::Finished; + } + } + BlockDecoderState::Finished => return Ok(max_read - buf.len()), + } + } + Ok(max_read) + } + + /// Flush this decoder returning the parsed [`Block`] if any + pub fn flush(&mut self) -> Option { + match self.state { + BlockDecoderState::Finished => { + self.state = BlockDecoderState::Count; + Some(std::mem::take(&mut self.in_progress)) + } + _ => None, + } + } +} diff --git a/arrow-avro/src/reader/header.rs b/arrow-avro/src/reader/header.rs new file mode 100644 index 000000000000..00e85b39be73 --- /dev/null +++ b/arrow-avro/src/reader/header.rs @@ -0,0 +1,288 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Decoder for [`Header`] + +use crate::reader::vlq::VLQDecoder; +use crate::schema::Schema; +use arrow_schema::ArrowError; + +#[derive(Debug)] +enum HeaderDecoderState { + /// Decoding the [`MAGIC`] prefix + Magic, + /// Decoding a block count + BlockCount, + /// Decoding a block byte length + BlockLen, + /// Decoding a key length + KeyLen, + /// Decoding a key string + Key, + /// Decoding a value length + ValueLen, + /// Decoding a value payload + Value, + /// Decoding sync marker + Sync, + /// Finished decoding + Finished, +} + +/// A decoded header for an [Object Container File](https://avro.apache.org/docs/1.11.1/specification/#object-container-files) +#[derive(Debug, Clone)] +pub struct Header { + meta_offsets: Vec, + meta_buf: Vec, + sync: [u8; 16], +} + +impl Header { + /// Returns an iterator over the meta keys in this header + pub fn metadata(&self) -> impl Iterator { + let mut last = 0; + self.meta_offsets.windows(2).map(move |w| { + let start = last; + last = w[1]; + (&self.meta_buf[start..w[0]], &self.meta_buf[w[0]..w[1]]) + }) + } + + /// Returns the value for a given metadata key if present + pub fn get(&self, key: impl AsRef<[u8]>) -> Option<&[u8]> { + self.metadata() + .find_map(|(k, v)| (k == key.as_ref()).then_some(v)) + } + + /// Returns the sync token for this file + pub fn sync(&self) -> [u8; 16] { + self.sync + } +} + +/// A decoder for [`Header`] +/// +/// The avro file format does not encode the length of the header, and so it +/// is necessary to provide a push-based decoder that can be used with streams +#[derive(Debug)] +pub struct HeaderDecoder { + state: HeaderDecoderState, + vlq_decoder: VLQDecoder, + + /// The end offsets of strings in `meta_buf` + meta_offsets: Vec, + /// The raw binary data of the metadata map + meta_buf: Vec, + + /// The decoded sync marker + sync_marker: [u8; 16], + + /// The number of remaining tuples in the current block + tuples_remaining: usize, + /// The number of bytes remaining in the current string/bytes payload + bytes_remaining: usize, +} + +impl Default for HeaderDecoder { + fn default() -> Self { + Self { + state: HeaderDecoderState::Magic, + meta_offsets: vec![], + meta_buf: vec![], + sync_marker: [0; 16], + vlq_decoder: Default::default(), + tuples_remaining: 0, + bytes_remaining: MAGIC.len(), + } + } +} + +const MAGIC: &[u8; 4] = b"Obj\x01"; + +impl HeaderDecoder { + /// Parse [`Header`] from `buf`, returning the number of bytes read + /// + /// This method can be called multiple times with consecutive chunks of data, allowing + /// integration with chunked IO systems like [`BufRead::fill_buf`] + /// + /// All errors should be considered fatal, and decoding aborted + /// + /// Once the entire [`Header`] has been decoded this method will not read any further + /// input bytes, and the header can be obtained with [`Self::flush`] + /// + /// [`BufRead::fill_buf`]: std::io::BufRead::fill_buf + pub fn decode(&mut self, mut buf: &[u8]) -> Result { + let max_read = buf.len(); + while !buf.is_empty() { + match self.state { + HeaderDecoderState::Magic => { + let remaining = &MAGIC[MAGIC.len() - self.bytes_remaining..]; + let to_decode = buf.len().min(remaining.len()); + if !buf.starts_with(&remaining[..to_decode]) { + return Err(ArrowError::ParseError("Incorrect avro magic".to_string())); + } + self.bytes_remaining -= to_decode; + buf = &buf[to_decode..]; + if self.bytes_remaining == 0 { + self.state = HeaderDecoderState::BlockCount; + } + } + HeaderDecoderState::BlockCount => { + if let Some(block_count) = self.vlq_decoder.long(&mut buf) { + match block_count.try_into() { + Ok(0) => { + self.state = HeaderDecoderState::Sync; + self.bytes_remaining = 16; + } + Ok(remaining) => { + self.tuples_remaining = remaining; + self.state = HeaderDecoderState::KeyLen; + } + Err(_) => { + self.tuples_remaining = block_count.unsigned_abs() as _; + self.state = HeaderDecoderState::BlockLen; + } + } + } + } + HeaderDecoderState::BlockLen => { + if self.vlq_decoder.long(&mut buf).is_some() { + self.state = HeaderDecoderState::KeyLen + } + } + HeaderDecoderState::Key => { + let to_read = self.bytes_remaining.min(buf.len()); + self.meta_buf.extend_from_slice(&buf[..to_read]); + self.bytes_remaining -= to_read; + buf = &buf[to_read..]; + if self.bytes_remaining == 0 { + self.meta_offsets.push(self.meta_buf.len()); + self.state = HeaderDecoderState::ValueLen; + } + } + HeaderDecoderState::Value => { + let to_read = self.bytes_remaining.min(buf.len()); + self.meta_buf.extend_from_slice(&buf[..to_read]); + self.bytes_remaining -= to_read; + buf = &buf[to_read..]; + if self.bytes_remaining == 0 { + self.meta_offsets.push(self.meta_buf.len()); + + self.tuples_remaining -= 1; + match self.tuples_remaining { + 0 => self.state = HeaderDecoderState::BlockCount, + _ => self.state = HeaderDecoderState::KeyLen, + } + } + } + HeaderDecoderState::KeyLen => { + if let Some(len) = self.vlq_decoder.long(&mut buf) { + self.bytes_remaining = len as _; + self.state = HeaderDecoderState::Key; + } + } + HeaderDecoderState::ValueLen => { + if let Some(len) = self.vlq_decoder.long(&mut buf) { + self.bytes_remaining = len as _; + self.state = HeaderDecoderState::Value; + } + } + HeaderDecoderState::Sync => { + let to_decode = buf.len().min(self.bytes_remaining); + let write = &mut self.sync_marker[16 - to_decode..]; + write[..to_decode].copy_from_slice(&buf[..to_decode]); + self.bytes_remaining -= to_decode; + buf = &buf[to_decode..]; + if self.bytes_remaining == 0 { + self.state = HeaderDecoderState::Finished; + } + } + HeaderDecoderState::Finished => return Ok(max_read - buf.len()), + } + } + Ok(max_read) + } + + /// Flush this decoder returning the parsed [`Header`] if any + pub fn flush(&mut self) -> Option

{ + match self.state { + HeaderDecoderState::Finished => { + self.state = HeaderDecoderState::Magic; + Some(Header { + meta_offsets: std::mem::take(&mut self.meta_offsets), + meta_buf: std::mem::take(&mut self.meta_buf), + sync: self.sync_marker, + }) + } + _ => None, + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::reader::read_header; + use crate::schema::SCHEMA_METADATA_KEY; + use crate::test_util::arrow_test_data; + use std::fs::File; + use std::io::{BufRead, BufReader}; + + #[test] + fn test_header_decode() { + let mut decoder = HeaderDecoder::default(); + for m in MAGIC { + decoder.decode(std::slice::from_ref(m)).unwrap(); + } + + let mut decoder = HeaderDecoder::default(); + assert_eq!(decoder.decode(MAGIC).unwrap(), 4); + + let mut decoder = HeaderDecoder::default(); + decoder.decode(b"Ob").unwrap(); + let err = decoder.decode(b"s").unwrap_err().to_string(); + assert_eq!(err, "Parser error: Incorrect avro magic"); + } + + fn decode_file(file: &str) -> Header { + let file = File::open(file).unwrap(); + read_header(BufReader::with_capacity(100, file)).unwrap() + } + + #[test] + fn test_header() { + let header = decode_file(&arrow_test_data("avro/alltypes_plain.avro")); + let schema_json = header.get(SCHEMA_METADATA_KEY).unwrap(); + let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"id","type":["int","null"]},{"name":"bool_col","type":["boolean","null"]},{"name":"tinyint_col","type":["int","null"]},{"name":"smallint_col","type":["int","null"]},{"name":"int_col","type":["int","null"]},{"name":"bigint_col","type":["long","null"]},{"name":"float_col","type":["float","null"]},{"name":"double_col","type":["double","null"]},{"name":"date_string_col","type":["bytes","null"]},{"name":"string_col","type":["bytes","null"]},{"name":"timestamp_col","type":[{"type":"long","logicalType":"timestamp-micros"},"null"]}]}"#; + assert_eq!(schema_json, expected); + let _schema: Schema<'_> = serde_json::from_slice(schema_json).unwrap(); + assert_eq!( + u128::from_le_bytes(header.sync()), + 226966037233754408753420635932530907102 + ); + + let header = decode_file(&arrow_test_data("avro/fixed_length_decimal.avro")); + let schema_json = header.get(SCHEMA_METADATA_KEY).unwrap(); + let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"value","type":[{"type":"fixed","name":"fixed","namespace":"topLevelRecord.value","size":11,"logicalType":"decimal","precision":25,"scale":2},"null"]}]}"#; + assert_eq!(schema_json, expected); + let _schema: Schema<'_> = serde_json::from_slice(schema_json).unwrap(); + assert_eq!( + u128::from_le_bytes(header.sync()), + 325166208089902833952788552656412487328 + ); + } +} diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs new file mode 100644 index 000000000000..7769bbbc4998 --- /dev/null +++ b/arrow-avro/src/reader/mod.rs @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Read Avro data to Arrow + +use crate::reader::block::{Block, BlockDecoder}; +use crate::reader::header::{Header, HeaderDecoder}; +use arrow_schema::ArrowError; +use std::io::BufRead; + +mod header; + +mod block; + +mod vlq; + +/// Read a [`Header`] from the provided [`BufRead`] +fn read_header(mut reader: R) -> Result { + let mut decoder = HeaderDecoder::default(); + loop { + let buf = reader.fill_buf()?; + if buf.is_empty() { + break; + } + let read = buf.len(); + let decoded = decoder.decode(buf)?; + reader.consume(decoded); + if decoded != read { + break; + } + } + + decoder + .flush() + .ok_or_else(|| ArrowError::ParseError("Unexpected EOF".to_string())) +} + +/// Return an iterator of [`Block`] from the provided [`BufRead`] +fn read_blocks(mut reader: R) -> impl Iterator> { + let mut decoder = BlockDecoder::default(); + + let mut try_next = move || { + loop { + let buf = reader.fill_buf()?; + if buf.is_empty() { + break; + } + let read = buf.len(); + let decoded = decoder.decode(buf)?; + reader.consume(decoded); + if decoded != read { + break; + } + } + Ok(decoder.flush()) + }; + std::iter::from_fn(move || try_next().transpose()) +} + +#[cfg(test)] +mod test { + use crate::reader::{read_blocks, read_header}; + use crate::test_util::arrow_test_data; + use std::fs::File; + use std::io::BufReader; + + #[test] + fn test_mux() { + let file = File::open(arrow_test_data("avro/alltypes_plain.avro")).unwrap(); + let mut reader = BufReader::new(file); + let header = read_header(&mut reader).unwrap(); + for result in read_blocks(reader) { + let block = result.unwrap(); + assert_eq!(block.sync, header.sync()); + } + } +} diff --git a/arrow-avro/src/reader/vlq.rs b/arrow-avro/src/reader/vlq.rs new file mode 100644 index 000000000000..80f1c60eec7d --- /dev/null +++ b/arrow-avro/src/reader/vlq.rs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Decoder for zig-zag encoded variable length (VLW) integers +/// +/// See also: +/// +/// +#[derive(Debug, Default)] +pub struct VLQDecoder { + /// Scratch space for decoding VLQ integers + in_progress: u64, + shift: u32, +} + +impl VLQDecoder { + /// Decode a signed long from `buf` + pub fn long(&mut self, buf: &mut &[u8]) -> Option { + while let Some(byte) = buf.first().copied() { + *buf = &buf[1..]; + self.in_progress |= ((byte & 0x7F) as u64) << self.shift; + self.shift += 7; + if byte & 0x80 == 0 { + let val = self.in_progress; + self.in_progress = 0; + self.shift = 0; + return Some((val >> 1) as i64 ^ -((val & 1) as i64)); + } + } + None + } +} diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs new file mode 100644 index 000000000000..17b82cf861b7 --- /dev/null +++ b/arrow-avro/src/schema.rs @@ -0,0 +1,482 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// The metadata key used for storing the JSON encoded [`Schema`] +pub const SCHEMA_METADATA_KEY: &str = "avro.schema"; + +/// Either a [`PrimitiveType`] or a reference to a previously defined named type +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum TypeName<'a> { + Primitive(PrimitiveType), + Ref(&'a str), +} + +/// A primitive type +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum PrimitiveType { + Null, + Boolean, + Int, + Long, + Float, + Double, + Bytes, + String, +} + +/// Additional attributes within a [`Schema`] +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Default, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct Attributes<'a> { + /// A logical type name + /// + /// + #[serde(default)] + pub logical_type: Option<&'a str>, + + /// Additional JSON attributes + #[serde(flatten)] + pub additional: HashMap<&'a str, serde_json::Value>, +} + +/// A type definition that is not a variant of [`ComplexType`] +#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct Type<'a> { + #[serde(borrow)] + pub r#type: TypeName<'a>, + #[serde(flatten)] + pub attributes: Attributes<'a>, +} + +/// An Avro schema +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum Schema<'a> { + #[serde(borrow)] + TypeName(TypeName<'a>), + #[serde(borrow)] + Union(Vec>), + #[serde(borrow)] + Complex(ComplexType<'a>), + #[serde(borrow)] + Type(Type<'a>), +} + +/// A complex type +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum ComplexType<'a> { + #[serde(borrow)] + Union(Vec>), + #[serde(borrow)] + Record(Record<'a>), + #[serde(borrow)] + Enum(Enum<'a>), + #[serde(borrow)] + Array(Array<'a>), + #[serde(borrow)] + Map(Map<'a>), + #[serde(borrow)] + Fixed(Fixed<'a>), +} + +/// A record +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Record<'a> { + #[serde(borrow)] + pub name: &'a str, + #[serde(borrow, default)] + pub namespace: Option<&'a str>, + #[serde(borrow, default)] + pub doc: Option<&'a str>, + #[serde(borrow, default)] + pub aliases: Vec<&'a str>, + #[serde(borrow)] + pub fields: Vec>, + #[serde(flatten)] + pub attributes: Attributes<'a>, +} + +/// A field within a [`Record`] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Field<'a> { + #[serde(borrow)] + pub name: &'a str, + #[serde(borrow, default)] + pub doc: Option<&'a str>, + #[serde(borrow)] + pub r#type: Schema<'a>, + #[serde(borrow, default)] + pub default: Option<&'a str>, +} + +/// An enumeration +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Enum<'a> { + #[serde(borrow)] + pub name: &'a str, + #[serde(borrow, default)] + pub namespace: Option<&'a str>, + #[serde(borrow, default)] + pub doc: Option<&'a str>, + #[serde(borrow, default)] + pub aliases: Vec<&'a str>, + #[serde(borrow)] + pub symbols: Vec<&'a str>, + #[serde(borrow, default)] + pub default: Option<&'a str>, + #[serde(flatten)] + pub attributes: Attributes<'a>, +} + +/// An array +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Array<'a> { + #[serde(borrow)] + pub items: Box>, + #[serde(flatten)] + pub attributes: Attributes<'a>, +} + +/// A map +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Map<'a> { + #[serde(borrow)] + pub values: Box>, + #[serde(flatten)] + pub attributes: Attributes<'a>, +} + +/// A fixed length binary array +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Fixed<'a> { + #[serde(borrow)] + pub name: &'a str, + #[serde(borrow, default)] + pub namespace: Option<&'a str>, + #[serde(borrow, default)] + pub aliases: Vec<&'a str>, + pub size: usize, + #[serde(flatten)] + pub attributes: Attributes<'a>, +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + #[test] + fn test_deserialize() { + let t: Schema = serde_json::from_str("\"string\"").unwrap(); + assert_eq!( + t, + Schema::TypeName(TypeName::Primitive(PrimitiveType::String)) + ); + + let t: Schema = serde_json::from_str("[\"int\", \"null\"]").unwrap(); + assert_eq!( + t, + Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + ]) + ); + + let t: Type = serde_json::from_str( + r#"{ + "type":"long", + "logicalType":"timestamp-micros" + }"#, + ) + .unwrap(); + + let timestamp = Type { + r#type: TypeName::Primitive(PrimitiveType::Long), + attributes: Attributes { + logical_type: Some("timestamp-micros"), + additional: Default::default(), + }, + }; + + assert_eq!(t, timestamp); + + let t: ComplexType = serde_json::from_str( + r#"{ + "type":"fixed", + "name":"fixed", + "namespace":"topLevelRecord.value", + "size":11, + "logicalType":"decimal", + "precision":25, + "scale":2 + }"#, + ) + .unwrap(); + + let decimal = ComplexType::Fixed(Fixed { + name: "fixed", + namespace: Some("topLevelRecord.value"), + aliases: vec![], + size: 11, + attributes: Attributes { + logical_type: Some("decimal"), + additional: vec![("precision", json!(25)), ("scale", json!(2))] + .into_iter() + .collect(), + }, + }); + + assert_eq!(t, decimal); + + let schema: Schema = serde_json::from_str( + r#"{ + "type":"record", + "name":"topLevelRecord", + "fields":[ + { + "name":"value", + "type":[ + { + "type":"fixed", + "name":"fixed", + "namespace":"topLevelRecord.value", + "size":11, + "logicalType":"decimal", + "precision":25, + "scale":2 + }, + "null" + ] + } + ] + }"#, + ) + .unwrap(); + + assert_eq!( + schema, + Schema::Complex(ComplexType::Record(Record { + name: "topLevelRecord", + namespace: None, + doc: None, + aliases: vec![], + fields: vec![Field { + name: "value", + doc: None, + r#type: Schema::Union(vec![ + Schema::Complex(decimal), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + ]), + default: None, + },], + attributes: Default::default(), + })) + ); + + let schema: Schema = serde_json::from_str( + r#"{ + "type": "record", + "name": "LongList", + "aliases": ["LinkedLongs"], + "fields" : [ + {"name": "value", "type": "long"}, + {"name": "next", "type": ["null", "LongList"]} + ] + }"#, + ) + .unwrap(); + + assert_eq!( + schema, + Schema::Complex(ComplexType::Record(Record { + name: "LongList", + namespace: None, + doc: None, + aliases: vec!["LinkedLongs"], + fields: vec![ + Field { + name: "value", + doc: None, + r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), + default: None, + }, + Field { + name: "next", + doc: None, + r#type: Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + Schema::TypeName(TypeName::Ref("LongList")), + ]), + default: None, + } + ], + attributes: Attributes::default(), + })) + ); + + let schema: Schema = serde_json::from_str( + r#"{ + "type":"record", + "name":"topLevelRecord", + "fields":[ + { + "name":"id", + "type":[ + "int", + "null" + ] + }, + { + "name":"timestamp_col", + "type":[ + { + "type":"long", + "logicalType":"timestamp-micros" + }, + "null" + ] + } + ] + }"#, + ) + .unwrap(); + + assert_eq!( + schema, + Schema::Complex(ComplexType::Record(Record { + name: "topLevelRecord", + namespace: None, + doc: None, + aliases: vec![], + fields: vec![ + Field { + name: "id", + doc: None, + r#type: Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + ]), + default: None, + }, + Field { + name: "timestamp_col", + doc: None, + r#type: Schema::Union(vec![ + Schema::Type(timestamp), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + ]), + default: None, + } + ], + attributes: Default::default(), + })) + ); + + let schema: Schema = serde_json::from_str( + r#"{ + "type": "record", + "name": "HandshakeRequest", "namespace":"org.apache.avro.ipc", + "fields": [ + {"name": "clientHash", + "type": {"type": "fixed", "name": "MD5", "size": 16}}, + {"name": "clientProtocol", "type": ["null", "string"]}, + {"name": "serverHash", "type": "MD5"}, + {"name": "meta", "type": ["null", {"type": "map", "values": "bytes"}]} + ] + }"#, + ) + .unwrap(); + + assert_eq!( + schema, + Schema::Complex(ComplexType::Record(Record { + name: "HandshakeRequest", + namespace: Some("org.apache.avro.ipc"), + doc: None, + aliases: vec![], + fields: vec![ + Field { + name: "clientHash", + doc: None, + r#type: Schema::Complex(ComplexType::Fixed(Fixed { + name: "MD5", + namespace: None, + aliases: vec![], + size: 16, + attributes: Default::default(), + })), + default: None, + }, + Field { + name: "clientProtocol", + doc: None, + r#type: Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + ]), + default: None, + }, + Field { + name: "serverHash", + doc: None, + r#type: Schema::TypeName(TypeName::Ref("MD5")), + default: None, + }, + Field { + name: "meta", + doc: None, + r#type: Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + Schema::Complex(ComplexType::Map(Map { + values: Box::new(Schema::TypeName(TypeName::Primitive( + PrimitiveType::Bytes + ))), + attributes: Default::default(), + })), + ]), + default: None, + } + ], + attributes: Default::default(), + })) + ); + } +} diff --git a/arrow-buffer/Cargo.toml b/arrow-buffer/Cargo.toml new file mode 100644 index 000000000000..746045cc8dde --- /dev/null +++ b/arrow-buffer/Cargo.toml @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "arrow-buffer" +version = { workspace = true } +description = "Buffer abstractions for Apache Arrow" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "arrow_buffer" +path = "src/lib.rs" +bench = false + +[dependencies] +bytes = { version = "1.4" } +num = { version = "0.4", default-features = false, features = ["std"] } +half = { version = "2.1", default-features = false } + +[dev-dependencies] +criterion = { version = "0.5", default-features = false } +rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } + +[build-dependencies] + +[[bench]] +name = "i256" +harness = false \ No newline at end of file diff --git a/arrow-buffer/benches/i256.rs b/arrow-buffer/benches/i256.rs new file mode 100644 index 000000000000..ebb45e793bd0 --- /dev/null +++ b/arrow-buffer/benches/i256.rs @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_buffer::i256; +use criterion::*; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::str::FromStr; + +const SIZE: usize = 1024; + +fn criterion_benchmark(c: &mut Criterion) { + let numbers = vec![ + i256::ZERO, + i256::ONE, + i256::MINUS_ONE, + i256::from_i128(1233456789), + i256::from_i128(-1233456789), + i256::from_i128(i128::MAX), + i256::from_i128(i128::MIN), + i256::MIN, + i256::MAX, + ]; + + for number in numbers { + let t = black_box(number.to_string()); + c.bench_function(&format!("i256_parse({t})"), |b| { + b.iter(|| i256::from_str(&t).unwrap()); + }); + } + + let mut rng = StdRng::seed_from_u64(42); + + let numerators: Vec<_> = (0..SIZE) + .map(|_| { + let high = rng.gen_range(1000..i128::MAX); + let low = rng.gen(); + i256::from_parts(low, high) + }) + .collect(); + + let divisors: Vec<_> = numerators + .iter() + .map(|n| { + let quotient = rng.gen_range(1..100_i32); + n.wrapping_div(i256::from(quotient)) + }) + .collect(); + + c.bench_function("i256_div_rem small quotient", |b| { + b.iter(|| { + for (n, d) in numerators.iter().zip(&divisors) { + black_box(n.wrapping_div(*d)); + } + }); + }); + + let divisors: Vec<_> = (0..SIZE) + .map(|_| i256::from(rng.gen_range(1..100_i32))) + .collect(); + + c.bench_function("i256_div_rem small divisor", |b| { + b.iter(|| { + for (n, d) in numerators.iter().zip(&divisors) { + black_box(n.wrapping_div(*d)); + } + }); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow/src/alloc/alignment.rs b/arrow-buffer/src/alloc/alignment.rs similarity index 97% rename from arrow/src/alloc/alignment.rs rename to arrow-buffer/src/alloc/alignment.rs index 1bd15c54b990..b3979e1d6a06 100644 --- a/arrow/src/alloc/alignment.rs +++ b/arrow-buffer/src/alloc/alignment.rs @@ -18,7 +18,7 @@ // NOTE: Below code is written for spatial/temporal prefetcher optimizations. Memory allocation // should align well with usage pattern of cache access and block sizes on layers of storage levels from // registers to non-volatile memory. These alignments are all cache aware alignments incorporated -// from [cuneiform](https://crates.io/crates/cuneiform) crate. This approach mimicks Intel TBB's +// from [cuneiform](https://crates.io/crates/cuneiform) crate. This approach mimics Intel TBB's // cache_aligned_allocator which exploits cache locality and minimizes prefetch signals // resulting in less round trip time between the layers of storage. // For further info: https://software.intel.com/en-us/node/506094 @@ -117,3 +117,7 @@ pub const ALIGNMENT: usize = 1 << 7; /// Cache and allocation multiple alignment size #[cfg(target_arch = "aarch64")] pub const ALIGNMENT: usize = 1 << 6; + +/// Cache and allocation multiple alignment size +#[cfg(target_arch = "loongarch64")] +pub const ALIGNMENT: usize = 1 << 6; diff --git a/arrow-buffer/src/alloc/mod.rs b/arrow-buffer/src/alloc/mod.rs new file mode 100644 index 000000000000..a3cb6253f324 --- /dev/null +++ b/arrow-buffer/src/alloc/mod.rs @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the low-level [`Allocation`] API for shared memory regions + +use std::alloc::Layout; +use std::fmt::{Debug, Formatter}; +use std::panic::RefUnwindSafe; +use std::sync::Arc; + +mod alignment; + +pub use alignment::ALIGNMENT; + +/// The owner of an allocation. +/// The trait implementation is responsible for dropping the allocations once no more references exist. +pub trait Allocation: RefUnwindSafe + Send + Sync {} + +impl Allocation for T {} + +/// Mode of deallocating memory regions +pub(crate) enum Deallocation { + /// An allocation using [`std::alloc`] + Standard(Layout), + /// An allocation from an external source like the FFI interface + /// Deallocation will happen on `Allocation::drop` + Custom(Arc), +} + +impl Debug for Deallocation { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + match self { + Deallocation::Standard(layout) => { + write!(f, "Deallocation::Standard {layout:?}") + } + Deallocation::Custom(_) => { + write!(f, "Deallocation::Custom {{ capacity: unknown }}") + } + } + } +} diff --git a/arrow-buffer/src/bigint/div.rs b/arrow-buffer/src/bigint/div.rs new file mode 100644 index 000000000000..e1b2ed4f8aa5 --- /dev/null +++ b/arrow-buffer/src/bigint/div.rs @@ -0,0 +1,302 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! N-digit division +//! +//! Implementation heavily inspired by [uint] +//! +//! [uint]: https://github.com/paritytech/parity-common/blob/d3a9327124a66e52ca1114bb8640c02c18c134b8/uint/src/uint.rs#L844 + +/// Unsigned, little-endian, n-digit division with remainder +/// +/// # Panics +/// +/// Panics if divisor is zero +pub fn div_rem(numerator: &[u64; N], divisor: &[u64; N]) -> ([u64; N], [u64; N]) { + let numerator_bits = bits(numerator); + let divisor_bits = bits(divisor); + assert_ne!(divisor_bits, 0, "division by zero"); + + if numerator_bits < divisor_bits { + return ([0; N], *numerator); + } + + if divisor_bits <= 64 { + return div_rem_small(numerator, divisor[0]); + } + + let numerator_words = (numerator_bits + 63) / 64; + let divisor_words = (divisor_bits + 63) / 64; + let n = divisor_words; + let m = numerator_words - divisor_words; + + div_rem_knuth(numerator, divisor, n, m) +} + +/// Return the least number of bits needed to represent the number +fn bits(arr: &[u64]) -> usize { + for (idx, v) in arr.iter().enumerate().rev() { + if *v > 0 { + return 64 - v.leading_zeros() as usize + 64 * idx; + } + } + 0 +} + +/// Division of numerator by a u64 divisor +fn div_rem_small(numerator: &[u64; N], divisor: u64) -> ([u64; N], [u64; N]) { + let mut rem = 0u64; + let mut numerator = *numerator; + numerator.iter_mut().rev().for_each(|d| { + let (q, r) = div_rem_word(rem, *d, divisor); + *d = q; + rem = r; + }); + + let mut rem_padded = [0; N]; + rem_padded[0] = rem; + (numerator, rem_padded) +} + +/// Use Knuth Algorithm D to compute `numerator / divisor` returning the +/// quotient and remainder +/// +/// `n` is the number of non-zero 64-bit words in `divisor` +/// `m` is the number of non-zero 64-bit words present in `numerator` beyond `divisor`, and +/// therefore the number of words in the quotient +/// +/// A good explanation of the algorithm can be found [here](https://ridiculousfish.com/blog/posts/labor-of-division-episode-iv.html) +fn div_rem_knuth( + numerator: &[u64; N], + divisor: &[u64; N], + n: usize, + m: usize, +) -> ([u64; N], [u64; N]) { + assert!(n + m <= N); + + // The algorithm works by incrementally generating guesses `q_hat`, for the next digit + // of the quotient, starting from the most significant digit. + // + // This relies on the property that for any `q_hat` where + // + // (q_hat << (j * 64)) * divisor <= numerator` + // + // We can set + // + // q += q_hat << (j * 64) + // numerator -= (q_hat << (j * 64)) * divisor + // + // And then iterate until `numerator < divisor` + + // We normalize the divisor so that the highest bit in the highest digit of the + // divisor is set, this ensures our initial guess of `q_hat` is at most 2 off from + // the correct value for q[j] + let shift = divisor[n - 1].leading_zeros(); + // As the shift is computed based on leading zeros, don't need to perform full_shl + let divisor = shl_word(divisor, shift); + // numerator may have fewer leading zeros than divisor, so must add another digit + let mut numerator = full_shl(numerator, shift); + + // The two most significant digits of the divisor + let b0 = divisor[n - 1]; + let b1 = divisor[n - 2]; + + let mut q = [0; N]; + + for j in (0..=m).rev() { + let a0 = numerator[j + n]; + let a1 = numerator[j + n - 1]; + + let mut q_hat = if a0 < b0 { + // The first estimate is [a1, a0] / b0, it may be too large by at most 2 + let (mut q_hat, mut r_hat) = div_rem_word(a0, a1, b0); + + // r_hat = [a1, a0] - q_hat * b0 + // + // Now we want to compute a more precise estimate [a2,a1,a0] / [b1,b0] + // which can only be less or equal to the current q_hat + // + // q_hat is too large if: + // [a2,a1,a0] < q_hat * [b1,b0] + // [a2,r_hat] < q_hat * b1 + let a2 = numerator[j + n - 2]; + loop { + let r = u128::from(q_hat) * u128::from(b1); + let (lo, hi) = (r as u64, (r >> 64) as u64); + if (hi, lo) <= (r_hat, a2) { + break; + } + + q_hat -= 1; + let (new_r_hat, overflow) = r_hat.overflowing_add(b0); + r_hat = new_r_hat; + + if overflow { + break; + } + } + q_hat + } else { + u64::MAX + }; + + // q_hat is now either the correct quotient digit, or in rare cases 1 too large + + // Compute numerator -= (q_hat * divisor) << (j * 64) + let q_hat_v = full_mul_u64(&divisor, q_hat); + let c = sub_assign(&mut numerator[j..], &q_hat_v[..n + 1]); + + // If underflow, q_hat was too large by 1 + if c { + // Reduce q_hat by 1 + q_hat -= 1; + + // Add back one multiple of divisor + let c = add_assign(&mut numerator[j..], &divisor[..n]); + numerator[j + n] = numerator[j + n].wrapping_add(u64::from(c)); + } + + // q_hat is the correct value for q[j] + q[j] = q_hat; + } + + // The remainder is what is left in numerator, with the initial normalization shl reversed + let remainder = full_shr(&numerator, shift); + (q, remainder) +} + +/// Perform narrowing division of a u128 by a u64 divisor, returning the quotient and remainder +/// +/// This method may trap or panic if hi >= divisor, i.e. the quotient would not fit +/// into a 64-bit integer +fn div_rem_word(hi: u64, lo: u64, divisor: u64) -> (u64, u64) { + debug_assert!(hi < divisor); + debug_assert_ne!(divisor, 0); + + // LLVM fails to use the div instruction as it is not able to prove + // that hi < divisor, and therefore the result will fit into 64-bits + #[cfg(target_arch = "x86_64")] + unsafe { + let mut quot = lo; + let mut rem = hi; + std::arch::asm!( + "div {divisor}", + divisor = in(reg) divisor, + inout("rax") quot, + inout("rdx") rem, + options(pure, nomem, nostack) + ); + (quot, rem) + } + #[cfg(not(target_arch = "x86_64"))] + { + let x = (u128::from(hi) << 64) + u128::from(lo); + let y = u128::from(divisor); + ((x / y) as u64, (x % y) as u64) + } +} + +/// Perform `a += b` +fn add_assign(a: &mut [u64], b: &[u64]) -> bool { + binop_slice(a, b, u64::overflowing_add) +} + +/// Perform `a -= b` +fn sub_assign(a: &mut [u64], b: &[u64]) -> bool { + binop_slice(a, b, u64::overflowing_sub) +} + +/// Converts an overflowing binary operation on scalars to one on slices +fn binop_slice(a: &mut [u64], b: &[u64], binop: impl Fn(u64, u64) -> (u64, bool) + Copy) -> bool { + let mut c = false; + a.iter_mut().zip(b.iter()).for_each(|(x, y)| { + let (res1, overflow1) = y.overflowing_add(u64::from(c)); + let (res2, overflow2) = binop(*x, res1); + *x = res2; + c = overflow1 || overflow2; + }); + c +} + +/// Widening multiplication of an N-digit array with a u64 +fn full_mul_u64(a: &[u64; N], b: u64) -> ArrayPlusOne { + let mut carry = 0; + let mut out = [0; N]; + out.iter_mut().zip(a).for_each(|(o, v)| { + let r = *v as u128 * b as u128 + carry as u128; + *o = r as u64; + carry = (r >> 64) as u64; + }); + ArrayPlusOne(out, carry) +} + +/// Left shift of an N-digit array by at most 63 bits +fn shl_word(v: &[u64; N], shift: u32) -> [u64; N] { + full_shl(v, shift).0 +} + +/// Widening left shift of an N-digit array by at most 63 bits +fn full_shl(v: &[u64; N], shift: u32) -> ArrayPlusOne { + debug_assert!(shift < 64); + if shift == 0 { + return ArrayPlusOne(*v, 0); + } + let mut out = [0u64; N]; + out[0] = v[0] << shift; + for i in 1..N { + out[i] = v[i - 1] >> (64 - shift) | v[i] << shift + } + let carry = v[N - 1] >> (64 - shift); + ArrayPlusOne(out, carry) +} + +/// Narrowing right shift of an (N+1)-digit array by at most 63 bits +fn full_shr(a: &ArrayPlusOne, shift: u32) -> [u64; N] { + debug_assert!(shift < 64); + if shift == 0 { + return a.0; + } + let mut out = [0; N]; + for i in 0..N - 1 { + out[i] = a[i] >> shift | a[i + 1] << (64 - shift) + } + out[N - 1] = a[N - 1] >> shift; + out +} + +/// An array of N + 1 elements +/// +/// This is a hack around lack of support for const arithmetic +#[repr(C)] +struct ArrayPlusOne([T; N], T); + +impl std::ops::Deref for ArrayPlusOne { + type Target = [T]; + + #[inline] + fn deref(&self) -> &Self::Target { + let x = self as *const Self; + unsafe { std::slice::from_raw_parts(x as *const T, N + 1) } + } +} + +impl std::ops::DerefMut for ArrayPlusOne { + fn deref_mut(&mut self) -> &mut Self::Target { + let x = self as *mut Self; + unsafe { std::slice::from_raw_parts_mut(x as *mut T, N + 1) } + } +} diff --git a/arrow-buffer/src/bigint/mod.rs b/arrow-buffer/src/bigint/mod.rs new file mode 100644 index 000000000000..afbb3a31df12 --- /dev/null +++ b/arrow-buffer/src/bigint/mod.rs @@ -0,0 +1,1267 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::bigint::div::div_rem; +use num::cast::AsPrimitive; +use num::{BigInt, FromPrimitive, ToPrimitive}; +use std::cmp::Ordering; +use std::num::ParseIntError; +use std::ops::{BitAnd, BitOr, BitXor, Neg, Shl, Shr}; +use std::str::FromStr; + +mod div; + +/// An opaque error similar to [`std::num::ParseIntError`] +#[derive(Debug)] +pub struct ParseI256Error {} + +impl From for ParseI256Error { + fn from(_: ParseIntError) -> Self { + Self {} + } +} + +impl std::fmt::Display for ParseI256Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Failed to parse as i256") + } +} +impl std::error::Error for ParseI256Error {} + +/// Error returned by i256::DivRem +enum DivRemError { + /// Division by zero + DivideByZero, + /// Division overflow + DivideOverflow, +} + +/// A signed 256-bit integer +#[allow(non_camel_case_types)] +#[derive(Copy, Clone, Default, Eq, PartialEq, Hash)] +pub struct i256 { + low: u128, + high: i128, +} + +impl std::fmt::Debug for i256 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self}") + } +} + +impl std::fmt::Display for i256 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", BigInt::from_signed_bytes_le(&self.to_le_bytes())) + } +} + +impl FromStr for i256 { + type Err = ParseI256Error; + + fn from_str(s: &str) -> Result { + // i128 can store up to 38 decimal digits + if s.len() <= 38 { + return Ok(Self::from_i128(i128::from_str(s)?)); + } + + let (negative, s) = match s.as_bytes()[0] { + b'-' => (true, &s[1..]), + b'+' => (false, &s[1..]), + _ => (false, s), + }; + + // Trim leading 0s + let s = s.trim_start_matches('0'); + if s.is_empty() { + return Ok(i256::ZERO); + } + + if !s.as_bytes()[0].is_ascii_digit() { + // Ensures no duplicate sign + return Err(ParseI256Error {}); + } + + parse_impl(s, negative) + } +} + +impl From for i256 { + fn from(value: i8) -> Self { + Self::from_i128(value.into()) + } +} + +impl From for i256 { + fn from(value: i16) -> Self { + Self::from_i128(value.into()) + } +} + +impl From for i256 { + fn from(value: i32) -> Self { + Self::from_i128(value.into()) + } +} + +impl From for i256 { + fn from(value: i64) -> Self { + Self::from_i128(value.into()) + } +} + +/// Parse `s` with any sign and leading 0s removed +fn parse_impl(s: &str, negative: bool) -> Result { + if s.len() <= 38 { + let low = i128::from_str(s)?; + return Ok(match negative { + true => i256::from_parts(low.neg() as _, -1), + false => i256::from_parts(low as _, 0), + }); + } + + let split = s.len() - 38; + if !s.as_bytes()[split].is_ascii_digit() { + // Ensures not splitting codepoint and no sign + return Err(ParseI256Error {}); + } + let (hs, ls) = s.split_at(split); + + let mut low = i128::from_str(ls)?; + let high = parse_impl(hs, negative)?; + + if negative { + low = -low; + } + + let low = i256::from_i128(low); + + high.checked_mul(i256::from_i128(10_i128.pow(38))) + .and_then(|high| high.checked_add(low)) + .ok_or(ParseI256Error {}) +} + +impl PartialOrd for i256 { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for i256 { + fn cmp(&self, other: &Self) -> Ordering { + // This is 25x faster than using a variable length encoding such + // as BigInt as it avoids allocation and branching + self.high.cmp(&other.high).then(self.low.cmp(&other.low)) + } +} + +impl i256 { + /// The additive identity for this integer type, i.e. `0`. + pub const ZERO: Self = i256 { low: 0, high: 0 }; + + /// The multiplicative identity for this integer type, i.e. `1`. + pub const ONE: Self = i256 { low: 1, high: 0 }; + + /// The multiplicative inverse for this integer type, i.e. `-1`. + pub const MINUS_ONE: Self = i256 { + low: u128::MAX, + high: -1, + }; + + /// The maximum value that can be represented by this integer type + pub const MAX: Self = i256 { + low: u128::MAX, + high: i128::MAX, + }; + + /// The minimum value that can be represented by this integer type + pub const MIN: Self = i256 { + low: u128::MIN, + high: i128::MIN, + }; + + /// Create an integer value from its representation as a byte array in little-endian. + #[inline] + pub const fn from_le_bytes(b: [u8; 32]) -> Self { + let (low, high) = split_array(b); + Self { + high: i128::from_le_bytes(high), + low: u128::from_le_bytes(low), + } + } + + /// Create an integer value from its representation as a byte array in big-endian. + #[inline] + pub const fn from_be_bytes(b: [u8; 32]) -> Self { + let (high, low) = split_array(b); + Self { + high: i128::from_be_bytes(high), + low: u128::from_be_bytes(low), + } + } + + pub const fn from_i128(v: i128) -> Self { + Self::from_parts(v as u128, v >> 127) + } + + /// Create an integer value from its representation as string. + #[inline] + pub fn from_string(value_str: &str) -> Option { + value_str.parse().ok() + } + + /// Create an optional i256 from the provided `f64`. Returning `None` + /// if overflow occurred + pub fn from_f64(v: f64) -> Option { + BigInt::from_f64(v).and_then(|i| { + let (integer, overflow) = i256::from_bigint_with_overflow(i); + if overflow { + None + } else { + Some(integer) + } + }) + } + + /// Create an i256 from the provided low u128 and high i128 + #[inline] + pub const fn from_parts(low: u128, high: i128) -> Self { + Self { low, high } + } + + /// Returns this `i256` as a low u128 and high i128 + pub const fn to_parts(self) -> (u128, i128) { + (self.low, self.high) + } + + /// Converts this `i256` into an `i128` returning `None` if this would result + /// in truncation/overflow + pub fn to_i128(self) -> Option { + let as_i128 = self.low as i128; + + let high_negative = self.high < 0; + let low_negative = as_i128 < 0; + let high_valid = self.high == -1 || self.high == 0; + + (high_negative == low_negative && high_valid).then_some(self.low as i128) + } + + /// Wraps this `i256` into an `i128` + pub fn as_i128(self) -> i128 { + self.low as i128 + } + + /// Return the memory representation of this integer as a byte array in little-endian byte order. + #[inline] + pub const fn to_le_bytes(self) -> [u8; 32] { + let low = self.low.to_le_bytes(); + let high = self.high.to_le_bytes(); + let mut t = [0; 32]; + let mut i = 0; + while i != 16 { + t[i] = low[i]; + t[i + 16] = high[i]; + i += 1; + } + t + } + + /// Return the memory representation of this integer as a byte array in big-endian byte order. + #[inline] + pub const fn to_be_bytes(self) -> [u8; 32] { + let low = self.low.to_be_bytes(); + let high = self.high.to_be_bytes(); + let mut t = [0; 32]; + let mut i = 0; + while i != 16 { + t[i] = high[i]; + t[i + 16] = low[i]; + i += 1; + } + t + } + + /// Create an i256 from the provided [`BigInt`] returning a bool indicating + /// if overflow occurred + fn from_bigint_with_overflow(v: BigInt) -> (Self, bool) { + let v_bytes = v.to_signed_bytes_le(); + match v_bytes.len().cmp(&32) { + Ordering::Less => { + let mut bytes = if num::Signed::is_negative(&v) { + [255_u8; 32] + } else { + [0; 32] + }; + bytes[0..v_bytes.len()].copy_from_slice(&v_bytes[..v_bytes.len()]); + (Self::from_le_bytes(bytes), false) + } + Ordering::Equal => (Self::from_le_bytes(v_bytes.try_into().unwrap()), false), + Ordering::Greater => (Self::from_le_bytes(v_bytes[..32].try_into().unwrap()), true), + } + } + + /// Computes the absolute value of this i256 + #[inline] + pub fn wrapping_abs(self) -> Self { + // -1 if negative, otherwise 0 + let sa = self.high >> 127; + let sa = Self::from_parts(sa as u128, sa); + + // Inverted if negative + Self::from_parts(self.low ^ sa.low, self.high ^ sa.high).wrapping_sub(sa) + } + + /// Computes the absolute value of this i256 returning `None` if `Self == Self::MIN` + #[inline] + pub fn checked_abs(self) -> Option { + (self != Self::MIN).then(|| self.wrapping_abs()) + } + + /// Negates this i256 + #[inline] + pub fn wrapping_neg(self) -> Self { + Self::from_parts(!self.low, !self.high).wrapping_add(i256::ONE) + } + + /// Negates this i256 returning `None` if `Self == Self::MIN` + #[inline] + pub fn checked_neg(self) -> Option { + (self != Self::MIN).then(|| self.wrapping_neg()) + } + + /// Performs wrapping addition + #[inline] + pub fn wrapping_add(self, other: Self) -> Self { + let (low, carry) = self.low.overflowing_add(other.low); + let high = self.high.wrapping_add(other.high).wrapping_add(carry as _); + Self { low, high } + } + + /// Performs checked addition + #[inline] + pub fn checked_add(self, other: Self) -> Option { + let r = self.wrapping_add(other); + ((other.is_negative() && r < self) || (!other.is_negative() && r >= self)).then_some(r) + } + + /// Performs wrapping subtraction + #[inline] + pub fn wrapping_sub(self, other: Self) -> Self { + let (low, carry) = self.low.overflowing_sub(other.low); + let high = self.high.wrapping_sub(other.high).wrapping_sub(carry as _); + Self { low, high } + } + + /// Performs checked subtraction + #[inline] + pub fn checked_sub(self, other: Self) -> Option { + let r = self.wrapping_sub(other); + ((other.is_negative() && r > self) || (!other.is_negative() && r <= self)).then_some(r) + } + + /// Performs wrapping multiplication + #[inline] + pub fn wrapping_mul(self, other: Self) -> Self { + let (low, high) = mulx(self.low, other.low); + + // Compute the high multiples, only impacting the high 128-bits + let hl = self.high.wrapping_mul(other.low as i128); + let lh = (self.low as i128).wrapping_mul(other.high); + + Self { + low, + high: (high as i128).wrapping_add(hl).wrapping_add(lh), + } + } + + /// Performs checked multiplication + #[inline] + pub fn checked_mul(self, other: Self) -> Option { + if self == i256::ZERO || other == i256::ZERO { + return Some(i256::ZERO); + } + + // Shift sign bit down to construct mask of all set bits if negative + let l_sa = self.high >> 127; + let r_sa = other.high >> 127; + let out_sa = (l_sa ^ r_sa) as u128; + + // Compute absolute values + let l_abs = self.wrapping_abs(); + let r_abs = other.wrapping_abs(); + + // Overflow if both high parts are non-zero + if l_abs.high != 0 && r_abs.high != 0 { + return None; + } + + // Perform checked multiplication on absolute values + let (low, high) = mulx(l_abs.low, r_abs.low); + + // Compute the high multiples, only impacting the high 128-bits + let hl = (l_abs.high as u128).checked_mul(r_abs.low)?; + let lh = l_abs.low.checked_mul(r_abs.high as u128)?; + + let high = high.checked_add(hl)?.checked_add(lh)?; + + // Reverse absolute value, if necessary + let (low, c) = (low ^ out_sa).overflowing_sub(out_sa); + let high = (high ^ out_sa).wrapping_sub(out_sa).wrapping_sub(c as u128) as i128; + + // Check for overflow in final conversion + (high.is_negative() == (self.is_negative() ^ other.is_negative())) + .then_some(Self { low, high }) + } + + /// Division operation, returns (quotient, remainder). + /// This basically implements [Long division]: `` + #[inline] + fn div_rem(self, other: Self) -> Result<(Self, Self), DivRemError> { + if other == Self::ZERO { + return Err(DivRemError::DivideByZero); + } + if other == Self::MINUS_ONE && self == Self::MIN { + return Err(DivRemError::DivideOverflow); + } + + let a = self.wrapping_abs(); + let b = other.wrapping_abs(); + + let (div, rem) = div_rem(&a.as_digits(), &b.as_digits()); + let div = Self::from_digits(div); + let rem = Self::from_digits(rem); + + Ok(( + if self.is_negative() == other.is_negative() { + div + } else { + div.wrapping_neg() + }, + if self.is_negative() { + rem.wrapping_neg() + } else { + rem + }, + )) + } + + /// Interpret this [`i256`] as 4 `u64` digits, least significant first + fn as_digits(self) -> [u64; 4] { + [ + self.low as u64, + (self.low >> 64) as u64, + self.high as u64, + (self.high as u128 >> 64) as u64, + ] + } + + /// Interpret 4 `u64` digits, least significant first, as a [`i256`] + fn from_digits(digits: [u64; 4]) -> Self { + Self::from_parts( + digits[0] as u128 | (digits[1] as u128) << 64, + digits[2] as i128 | (digits[3] as i128) << 64, + ) + } + + /// Performs wrapping division + #[inline] + pub fn wrapping_div(self, other: Self) -> Self { + match self.div_rem(other) { + Ok((v, _)) => v, + Err(DivRemError::DivideByZero) => panic!("attempt to divide by zero"), + Err(_) => Self::MIN, + } + } + + /// Performs checked division + #[inline] + pub fn checked_div(self, other: Self) -> Option { + self.div_rem(other).map(|(v, _)| v).ok() + } + + /// Performs wrapping remainder + #[inline] + pub fn wrapping_rem(self, other: Self) -> Self { + match self.div_rem(other) { + Ok((_, v)) => v, + Err(DivRemError::DivideByZero) => panic!("attempt to divide by zero"), + Err(_) => Self::ZERO, + } + } + + /// Performs checked remainder + #[inline] + pub fn checked_rem(self, other: Self) -> Option { + self.div_rem(other).map(|(_, v)| v).ok() + } + + /// Performs checked exponentiation + #[inline] + pub fn checked_pow(self, mut exp: u32) -> Option { + if exp == 0 { + return Some(i256::from_i128(1)); + } + + let mut base = self; + let mut acc: Self = i256::from_i128(1); + + while exp > 1 { + if (exp & 1) == 1 { + acc = acc.checked_mul(base)?; + } + exp /= 2; + base = base.checked_mul(base)?; + } + // since exp!=0, finally the exp must be 1. + // Deal with the final bit of the exponent separately, since + // squaring the base afterwards is not necessary and may cause a + // needless overflow. + acc.checked_mul(base) + } + + /// Performs wrapping exponentiation + #[inline] + pub fn wrapping_pow(self, mut exp: u32) -> Self { + if exp == 0 { + return i256::from_i128(1); + } + + let mut base = self; + let mut acc: Self = i256::from_i128(1); + + while exp > 1 { + if (exp & 1) == 1 { + acc = acc.wrapping_mul(base); + } + exp /= 2; + base = base.wrapping_mul(base); + } + + // since exp!=0, finally the exp must be 1. + // Deal with the final bit of the exponent separately, since + // squaring the base afterwards is not necessary and may cause a + // needless overflow. + acc.wrapping_mul(base) + } + + /// Returns a number [`i256`] representing sign of this [`i256`]. + /// + /// 0 if the number is zero + /// 1 if the number is positive + /// -1 if the number is negative + pub const fn signum(self) -> Self { + if self.is_positive() { + i256::ONE + } else if self.is_negative() { + i256::MINUS_ONE + } else { + i256::ZERO + } + } + + /// Returns `true` if this [`i256`] is negative + #[inline] + pub const fn is_negative(self) -> bool { + self.high.is_negative() + } + + /// Returns `true` if this [`i256`] is positive + pub const fn is_positive(self) -> bool { + self.high.is_positive() || self.high == 0 && self.low != 0 + } +} + +/// Temporary workaround due to lack of stable const array slicing +/// See +const fn split_array(vals: [u8; N]) -> ([u8; M], [u8; M]) { + let mut a = [0; M]; + let mut b = [0; M]; + let mut i = 0; + while i != M { + a[i] = vals[i]; + b[i] = vals[i + M]; + i += 1; + } + (a, b) +} + +/// Performs an unsigned multiplication of `a * b` returning a tuple of +/// `(low, high)` where `low` contains the lower 128-bits of the result +/// and `high` the higher 128-bits +/// +/// This mirrors the x86 mulx instruction but for 128-bit types +#[inline] +fn mulx(a: u128, b: u128) -> (u128, u128) { + let split = |a: u128| (a & (u64::MAX as u128), a >> 64); + + const MASK: u128 = u64::MAX as _; + + let (a_low, a_high) = split(a); + let (b_low, b_high) = split(b); + + // Carry stores the upper 64-bits of low and lower 64-bits of high + let (mut low, mut carry) = split(a_low * b_low); + carry += a_high * b_low; + + // Update low and high with corresponding parts of carry + low += carry << 64; + let mut high = carry >> 64; + + // Update carry with overflow from low + carry = low >> 64; + low &= MASK; + + // Perform multiply including overflow from low + carry += b_high * a_low; + + // Update low and high with values from carry + low += carry << 64; + high += carry >> 64; + + // Perform 4th multiplication + high += a_high * b_high; + + (low, high) +} + +macro_rules! derive_op { + ($t:ident, $op:ident, $wrapping:ident, $checked:ident) => { + impl std::ops::$t for i256 { + type Output = i256; + + #[cfg(debug_assertions)] + fn $op(self, rhs: Self) -> Self::Output { + self.$checked(rhs).expect("i256 overflow") + } + + #[cfg(not(debug_assertions))] + fn $op(self, rhs: Self) -> Self::Output { + self.$wrapping(rhs) + } + } + + impl<'a> std::ops::$t for &'a i256 { + type Output = i256; + + fn $op(self, rhs: i256) -> Self::Output { + (*self).$op(rhs) + } + } + + impl<'a> std::ops::$t<&'a i256> for i256 { + type Output = i256; + + fn $op(self, rhs: &'a i256) -> Self::Output { + self.$op(*rhs) + } + } + + impl<'a, 'b> std::ops::$t<&'b i256> for &'a i256 { + type Output = i256; + + fn $op(self, rhs: &'b i256) -> Self::Output { + (*self).$op(*rhs) + } + } + }; +} + +derive_op!(Add, add, wrapping_add, checked_add); +derive_op!(Sub, sub, wrapping_sub, checked_sub); +derive_op!(Mul, mul, wrapping_mul, checked_mul); +derive_op!(Div, div, wrapping_div, checked_div); +derive_op!(Rem, rem, wrapping_rem, checked_rem); + +impl std::ops::Neg for i256 { + type Output = i256; + + #[cfg(debug_assertions)] + fn neg(self) -> Self::Output { + self.checked_neg().expect("i256 overflow") + } + + #[cfg(not(debug_assertions))] + fn neg(self) -> Self::Output { + self.wrapping_neg() + } +} + +impl BitAnd for i256 { + type Output = i256; + + #[inline] + fn bitand(self, rhs: Self) -> Self::Output { + Self { + low: self.low & rhs.low, + high: self.high & rhs.high, + } + } +} + +impl BitOr for i256 { + type Output = i256; + + #[inline] + fn bitor(self, rhs: Self) -> Self::Output { + Self { + low: self.low | rhs.low, + high: self.high | rhs.high, + } + } +} + +impl BitXor for i256 { + type Output = i256; + + #[inline] + fn bitxor(self, rhs: Self) -> Self::Output { + Self { + low: self.low ^ rhs.low, + high: self.high ^ rhs.high, + } + } +} + +impl Shl for i256 { + type Output = i256; + + #[inline] + fn shl(self, rhs: u8) -> Self::Output { + if rhs == 0 { + self + } else if rhs < 128 { + Self { + high: self.high << rhs | (self.low >> (128 - rhs)) as i128, + low: self.low << rhs, + } + } else { + Self { + high: (self.low << (rhs - 128)) as i128, + low: 0, + } + } + } +} + +impl Shr for i256 { + type Output = i256; + + #[inline] + fn shr(self, rhs: u8) -> Self::Output { + if rhs == 0 { + self + } else if rhs < 128 { + Self { + high: self.high >> rhs, + low: self.low >> rhs | ((self.high as u128) << (128 - rhs)), + } + } else { + Self { + high: self.high >> 127, + low: (self.high >> (rhs - 128)) as u128, + } + } + } +} + +macro_rules! define_as_primitive { + ($native_ty:ty) => { + impl AsPrimitive for $native_ty { + fn as_(self) -> i256 { + i256::from_i128(self as i128) + } + } + }; +} + +define_as_primitive!(i8); +define_as_primitive!(i16); +define_as_primitive!(i32); +define_as_primitive!(i64); +define_as_primitive!(u8); +define_as_primitive!(u16); +define_as_primitive!(u32); +define_as_primitive!(u64); + +impl ToPrimitive for i256 { + fn to_i64(&self) -> Option { + let as_i128 = self.low as i128; + + let high_negative = self.high < 0; + let low_negative = as_i128 < 0; + let high_valid = self.high == -1 || self.high == 0; + + if high_negative == low_negative && high_valid { + let (low_bytes, high_bytes) = split_array(u128::to_le_bytes(self.low)); + let high = i64::from_le_bytes(high_bytes); + let low = i64::from_le_bytes(low_bytes); + + let high_negative = high < 0; + let low_negative = low < 0; + let high_valid = self.high == -1 || self.high == 0; + + (high_negative == low_negative && high_valid).then_some(low) + } else { + None + } + } + + fn to_u64(&self) -> Option { + let as_i128 = self.low as i128; + + let high_negative = self.high < 0; + let low_negative = as_i128 < 0; + let high_valid = self.high == -1 || self.high == 0; + + if high_negative == low_negative && high_valid { + self.low.to_u64() + } else { + None + } + } +} + +#[cfg(all(test, not(miri)))] // llvm.x86.subborrow.64 not supported by MIRI +mod tests { + use super::*; + use num::{BigInt, FromPrimitive, Signed, ToPrimitive}; + use rand::{thread_rng, Rng}; + use std::ops::Neg; + + #[test] + fn test_signed_cmp() { + let a = i256::from_parts(i128::MAX as u128, 12); + let b = i256::from_parts(i128::MIN as u128, 12); + assert!(a < b); + + let a = i256::from_parts(i128::MAX as u128, 12); + let b = i256::from_parts(i128::MIN as u128, -12); + assert!(a > b); + } + + #[test] + fn test_to_i128() { + let vals = [ + BigInt::from_i128(-1).unwrap(), + BigInt::from_i128(i128::MAX).unwrap(), + BigInt::from_i128(i128::MIN).unwrap(), + BigInt::from_u128(u128::MIN).unwrap(), + BigInt::from_u128(u128::MAX).unwrap(), + ]; + + for v in vals { + let (t, overflow) = i256::from_bigint_with_overflow(v.clone()); + assert!(!overflow); + assert_eq!(t.to_i128(), v.to_i128(), "{v} vs {t}"); + } + } + + /// Tests operations against the two provided [`i256`] + fn test_ops(il: i256, ir: i256) { + let bl = BigInt::from_signed_bytes_le(&il.to_le_bytes()); + let br = BigInt::from_signed_bytes_le(&ir.to_le_bytes()); + + // Comparison + assert_eq!(il.cmp(&ir), bl.cmp(&br), "{bl} cmp {br}"); + + // Conversions + assert_eq!(i256::from_le_bytes(il.to_le_bytes()), il); + assert_eq!(i256::from_be_bytes(il.to_be_bytes()), il); + assert_eq!(i256::from_le_bytes(ir.to_le_bytes()), ir); + assert_eq!(i256::from_be_bytes(ir.to_be_bytes()), ir); + + // To i128 + assert_eq!(il.to_i128(), bl.to_i128(), "{bl}"); + assert_eq!(ir.to_i128(), br.to_i128(), "{br}"); + + // Absolute value + let (abs, overflow) = i256::from_bigint_with_overflow(bl.abs()); + assert_eq!(il.wrapping_abs(), abs); + assert_eq!(il.checked_abs().is_none(), overflow); + + let (abs, overflow) = i256::from_bigint_with_overflow(br.abs()); + assert_eq!(ir.wrapping_abs(), abs); + assert_eq!(ir.checked_abs().is_none(), overflow); + + // Negation + let (neg, overflow) = i256::from_bigint_with_overflow(bl.clone().neg()); + assert_eq!(il.wrapping_neg(), neg); + assert_eq!(il.checked_neg().is_none(), overflow); + + // Negation + let (neg, overflow) = i256::from_bigint_with_overflow(br.clone().neg()); + assert_eq!(ir.wrapping_neg(), neg); + assert_eq!(ir.checked_neg().is_none(), overflow); + + // Addition + let actual = il.wrapping_add(ir); + let (expected, overflow) = i256::from_bigint_with_overflow(bl.clone() + br.clone()); + assert_eq!(actual, expected); + + let checked = il.checked_add(ir); + match overflow { + true => assert!(checked.is_none()), + false => assert_eq!(checked, Some(actual)), + } + + // Subtraction + let actual = il.wrapping_sub(ir); + let (expected, overflow) = i256::from_bigint_with_overflow(bl.clone() - br.clone()); + assert_eq!(actual.to_string(), expected.to_string()); + + let checked = il.checked_sub(ir); + match overflow { + true => assert!(checked.is_none()), + false => assert_eq!(checked, Some(actual), "{bl} - {br} = {expected}"), + } + + // Multiplication + let actual = il.wrapping_mul(ir); + let (expected, overflow) = i256::from_bigint_with_overflow(bl.clone() * br.clone()); + assert_eq!(actual.to_string(), expected.to_string()); + + let checked = il.checked_mul(ir); + match overflow { + true => assert!( + checked.is_none(), + "{il} * {ir} = {actual} vs {bl} * {br} = {expected}" + ), + false => assert_eq!( + checked, + Some(actual), + "{il} * {ir} = {actual} vs {bl} * {br} = {expected}" + ), + } + + // Division + if ir != i256::ZERO { + let actual = il.wrapping_div(ir); + let expected = bl.clone() / br.clone(); + let checked = il.checked_div(ir); + + if ir == i256::MINUS_ONE && il == i256::MIN { + // BigInt produces an integer over i256::MAX + assert_eq!(actual, i256::MIN); + assert!(checked.is_none()); + } else { + assert_eq!(actual.to_string(), expected.to_string()); + assert_eq!(checked.unwrap().to_string(), expected.to_string()); + } + } else { + // `wrapping_div` panics on division by zero + assert!(il.checked_div(ir).is_none()); + } + + // Remainder + if ir != i256::ZERO { + let actual = il.wrapping_rem(ir); + let expected = bl.clone() % br.clone(); + let checked = il.checked_rem(ir); + + assert_eq!(actual.to_string(), expected.to_string(), "{il} % {ir}"); + + if ir == i256::MINUS_ONE && il == i256::MIN { + assert!(checked.is_none()); + } else { + assert_eq!(checked.unwrap().to_string(), expected.to_string()); + } + } else { + // `wrapping_rem` panics on division by zero + assert!(il.checked_rem(ir).is_none()); + } + + // Exponentiation + for exp in vec![0, 1, 2, 3, 8, 100].into_iter() { + let actual = il.wrapping_pow(exp); + let (expected, overflow) = i256::from_bigint_with_overflow(bl.clone().pow(exp)); + assert_eq!(actual.to_string(), expected.to_string()); + + let checked = il.checked_pow(exp); + match overflow { + true => assert!( + checked.is_none(), + "{il} ^ {exp} = {actual} vs {bl} * {exp} = {expected}" + ), + false => assert_eq!( + checked, + Some(actual), + "{il} ^ {exp} = {actual} vs {bl} ^ {exp} = {expected}" + ), + } + } + + // Bit operations + let actual = il & ir; + let (expected, _) = i256::from_bigint_with_overflow(bl.clone() & br.clone()); + assert_eq!(actual.to_string(), expected.to_string()); + + let actual = il | ir; + let (expected, _) = i256::from_bigint_with_overflow(bl.clone() | br.clone()); + assert_eq!(actual.to_string(), expected.to_string()); + + let actual = il ^ ir; + let (expected, _) = i256::from_bigint_with_overflow(bl.clone() ^ br); + assert_eq!(actual.to_string(), expected.to_string()); + + for shift in [0_u8, 1, 4, 126, 128, 129, 254, 255] { + let actual = il << shift; + let (expected, _) = i256::from_bigint_with_overflow(bl.clone() << shift); + assert_eq!(actual.to_string(), expected.to_string()); + + let actual = il >> shift; + let (expected, _) = i256::from_bigint_with_overflow(bl.clone() >> shift); + assert_eq!(actual.to_string(), expected.to_string()); + } + } + + #[test] + fn test_i256() { + let candidates = [ + i256::ZERO, + i256::ONE, + i256::MINUS_ONE, + i256::from_i128(2), + i256::from_i128(-2), + i256::from_parts(u128::MAX, 1), + i256::from_parts(u128::MAX, -1), + i256::from_parts(0, 1), + i256::from_parts(0, -1), + i256::from_parts(1, -1), + i256::from_parts(1, 1), + i256::from_parts(0, i128::MAX), + i256::from_parts(0, i128::MIN), + i256::from_parts(1, i128::MAX), + i256::from_parts(1, i128::MIN), + i256::from_parts(u128::MAX, i128::MIN), + i256::from_parts(100, 32), + i256::MIN, + i256::MAX, + i256::MIN >> 1, + i256::MAX >> 1, + i256::ONE << 127, + i256::ONE << 128, + i256::ONE << 129, + i256::MINUS_ONE << 127, + i256::MINUS_ONE << 128, + i256::MINUS_ONE << 129, + ]; + + for il in candidates { + for ir in candidates { + test_ops(il, ir) + } + } + } + + #[test] + fn test_signed_ops() { + // signum + assert_eq!(i256::from_i128(1).signum(), i256::ONE); + assert_eq!(i256::from_i128(0).signum(), i256::ZERO); + assert_eq!(i256::from_i128(-0).signum(), i256::ZERO); + assert_eq!(i256::from_i128(-1).signum(), i256::MINUS_ONE); + + // is_positive + assert!(i256::from_i128(1).is_positive()); + assert!(!i256::from_i128(0).is_positive()); + assert!(!i256::from_i128(-0).is_positive()); + assert!(!i256::from_i128(-1).is_positive()); + + // is_negative + assert!(!i256::from_i128(1).is_negative()); + assert!(!i256::from_i128(0).is_negative()); + assert!(!i256::from_i128(-0).is_negative()); + assert!(i256::from_i128(-1).is_negative()); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_i256_fuzz() { + let mut rng = thread_rng(); + + for _ in 0..1000 { + let mut l = [0_u8; 32]; + let len = rng.gen_range(0..32); + l.iter_mut().take(len).for_each(|x| *x = rng.gen()); + + let mut r = [0_u8; 32]; + let len = rng.gen_range(0..32); + r.iter_mut().take(len).for_each(|x| *x = rng.gen()); + + test_ops(i256::from_le_bytes(l), i256::from_le_bytes(r)) + } + } + + #[test] + fn test_i256_to_primitive() { + let a = i256::MAX; + assert!(a.to_i64().is_none()); + assert!(a.to_u64().is_none()); + + let a = i256::from_i128(i128::MAX); + assert!(a.to_i64().is_none()); + assert!(a.to_u64().is_none()); + + let a = i256::from_i128(i64::MAX as i128); + assert_eq!(a.to_i64().unwrap(), i64::MAX); + assert_eq!(a.to_u64().unwrap(), i64::MAX as u64); + + let a = i256::from_i128(i64::MAX as i128 + 1); + assert!(a.to_i64().is_none()); + assert_eq!(a.to_u64().unwrap(), i64::MAX as u64 + 1); + + let a = i256::MIN; + assert!(a.to_i64().is_none()); + assert!(a.to_u64().is_none()); + + let a = i256::from_i128(i128::MIN); + assert!(a.to_i64().is_none()); + assert!(a.to_u64().is_none()); + + let a = i256::from_i128(i64::MIN as i128); + assert_eq!(a.to_i64().unwrap(), i64::MIN); + assert!(a.to_u64().is_none()); + + let a = i256::from_i128(i64::MIN as i128 - 1); + assert!(a.to_i64().is_none()); + assert!(a.to_u64().is_none()); + } + + #[test] + fn test_i256_as_i128() { + let a = i256::from_i128(i128::MAX).wrapping_add(i256::from_i128(1)); + let i128 = a.as_i128(); + assert_eq!(i128, i128::MIN); + + let a = i256::from_i128(i128::MAX).wrapping_add(i256::from_i128(2)); + let i128 = a.as_i128(); + assert_eq!(i128, i128::MIN + 1); + + let a = i256::from_i128(i128::MIN).wrapping_sub(i256::from_i128(1)); + let i128 = a.as_i128(); + assert_eq!(i128, i128::MAX); + + let a = i256::from_i128(i128::MIN).wrapping_sub(i256::from_i128(2)); + let i128 = a.as_i128(); + assert_eq!(i128, i128::MAX - 1); + } + + #[test] + fn test_string_roundtrip() { + let roundtrip_cases = [ + i256::ZERO, + i256::ONE, + i256::MINUS_ONE, + i256::from_i128(123456789), + i256::from_i128(-123456789), + i256::from_i128(i128::MIN), + i256::from_i128(i128::MAX), + i256::MIN, + i256::MAX, + ]; + for case in roundtrip_cases { + let formatted = case.to_string(); + let back: i256 = formatted.parse().unwrap(); + assert_eq!(case, back); + } + } + + #[test] + fn test_from_string() { + let cases = [ + ( + "000000000000000000000000000000000000000011", + Some(i256::from_i128(11)), + ), + ( + "-000000000000000000000000000000000000000011", + Some(i256::from_i128(-11)), + ), + ( + "-0000000000000000000000000000000000000000123456789", + Some(i256::from_i128(-123456789)), + ), + ("-", None), + ("+", None), + ("--1", None), + ("-+1", None), + ("000000000000000000000000000000000000000", Some(i256::ZERO)), + ("0000000000000000000000000000000000000000-11", None), + ("11-1111111111111111111111111111111111111", None), + ( + "115792089237316195423570985008687907853269984665640564039457584007913129639936", + None, + ), + ]; + for (case, expected) in cases { + assert_eq!(i256::from_string(case), expected) + } + } + + #[allow(clippy::op_ref)] + fn test_reference_op(il: i256, ir: i256) { + let r1 = il + ir; + let r2 = &il + ir; + let r3 = il + &ir; + let r4 = &il + &ir; + assert_eq!(r1, r2); + assert_eq!(r1, r3); + assert_eq!(r1, r4); + + let r1 = il - ir; + let r2 = &il - ir; + let r3 = il - &ir; + let r4 = &il - &ir; + assert_eq!(r1, r2); + assert_eq!(r1, r3); + assert_eq!(r1, r4); + + let r1 = il * ir; + let r2 = &il * ir; + let r3 = il * &ir; + let r4 = &il * &ir; + assert_eq!(r1, r2); + assert_eq!(r1, r3); + assert_eq!(r1, r4); + + let r1 = il / ir; + let r2 = &il / ir; + let r3 = il / &ir; + let r4 = &il / &ir; + assert_eq!(r1, r2); + assert_eq!(r1, r3); + assert_eq!(r1, r4); + } + + #[test] + fn test_i256_reference_op() { + let candidates = [ + i256::ONE, + i256::MINUS_ONE, + i256::from_i128(2), + i256::from_i128(-2), + i256::from_i128(3), + i256::from_i128(-3), + ]; + + for il in candidates { + for ir in candidates { + test_reference_op(il, ir) + } + } + } +} diff --git a/arrow-buffer/src/buffer/boolean.rs b/arrow-buffer/src/buffer/boolean.rs new file mode 100644 index 000000000000..1589cc5b102b --- /dev/null +++ b/arrow-buffer/src/buffer/boolean.rs @@ -0,0 +1,417 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::bit_chunk_iterator::BitChunks; +use crate::bit_iterator::{BitIndexIterator, BitIterator, BitSliceIterator}; +use crate::{ + bit_util, buffer_bin_and, buffer_bin_or, buffer_bin_xor, buffer_unary_not, + BooleanBufferBuilder, Buffer, MutableBuffer, +}; +use std::ops::{BitAnd, BitOr, BitXor, Not}; + +/// A slice-able [`Buffer`] containing bit-packed booleans +#[derive(Debug, Clone, Eq)] +pub struct BooleanBuffer { + buffer: Buffer, + offset: usize, + len: usize, +} + +impl PartialEq for BooleanBuffer { + fn eq(&self, other: &Self) -> bool { + if self.len != other.len { + return false; + } + + let lhs = self.bit_chunks().iter_padded(); + let rhs = other.bit_chunks().iter_padded(); + lhs.zip(rhs).all(|(a, b)| a == b) + } +} + +impl BooleanBuffer { + /// Create a new [`BooleanBuffer`] from a [`Buffer`], an `offset` and `length` in bits + /// + /// # Panics + /// + /// This method will panic if `buffer` is not large enough + pub fn new(buffer: Buffer, offset: usize, len: usize) -> Self { + let total_len = offset.saturating_add(len); + let bit_len = buffer.len().saturating_mul(8); + assert!(total_len <= bit_len); + Self { + buffer, + offset, + len, + } + } + + /// Create a new [`BooleanBuffer`] of `length` where all values are `true` + pub fn new_set(length: usize) -> Self { + let mut builder = BooleanBufferBuilder::new(length); + builder.append_n(length, true); + builder.finish() + } + + /// Create a new [`BooleanBuffer`] of `length` where all values are `false` + pub fn new_unset(length: usize) -> Self { + let buffer = MutableBuffer::new_null(length).into_buffer(); + Self { + buffer, + offset: 0, + len: length, + } + } + + /// Invokes `f` with indexes `0..len` collecting the boolean results into a new `BooleanBuffer` + pub fn collect_bool bool>(len: usize, f: F) -> Self { + let buffer = MutableBuffer::collect_bool(len, f); + Self::new(buffer.into(), 0, len) + } + + /// Returns the number of set bits in this buffer + pub fn count_set_bits(&self) -> usize { + self.buffer.count_set_bits_offset(self.offset, self.len) + } + + /// Returns a `BitChunks` instance which can be used to iterate over + /// this buffer's bits in `u64` chunks + #[inline] + pub fn bit_chunks(&self) -> BitChunks { + BitChunks::new(self.values(), self.offset, self.len) + } + + /// Returns `true` if the bit at index `i` is set + /// + /// # Panics + /// + /// Panics if `i >= self.len()` + #[inline] + #[deprecated(note = "use BooleanBuffer::value")] + pub fn is_set(&self, i: usize) -> bool { + self.value(i) + } + + /// Returns the offset of this [`BooleanBuffer`] in bits + #[inline] + pub fn offset(&self) -> usize { + self.offset + } + + /// Returns the length of this [`BooleanBuffer`] in bits + #[inline] + pub fn len(&self) -> usize { + self.len + } + + /// Returns true if this [`BooleanBuffer`] is empty + #[inline] + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Returns the boolean value at index `i`. + /// + /// # Panics + /// + /// Panics if `i >= self.len()` + #[inline] + pub fn value(&self, idx: usize) -> bool { + assert!(idx < self.len); + unsafe { self.value_unchecked(idx) } + } + + /// Returns the boolean value at index `i`. + /// + /// # Safety + /// This doesn't check bounds, the caller must ensure that index < self.len() + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> bool { + unsafe { bit_util::get_bit_raw(self.buffer.as_ptr(), i + self.offset) } + } + + /// Returns the packed values of this [`BooleanBuffer`] not including any offset + #[inline] + pub fn values(&self) -> &[u8] { + &self.buffer + } + + /// Slices this [`BooleanBuffer`] by the provided `offset` and `length` + pub fn slice(&self, offset: usize, len: usize) -> Self { + assert!( + offset.saturating_add(len) <= self.len, + "the length + offset of the sliced BooleanBuffer cannot exceed the existing length" + ); + Self { + buffer: self.buffer.clone(), + offset: self.offset + offset, + len, + } + } + + /// Returns a [`Buffer`] containing the sliced contents of this [`BooleanBuffer`] + /// + /// Equivalent to `self.buffer.bit_slice(self.offset, self.len)` + pub fn sliced(&self) -> Buffer { + self.buffer.bit_slice(self.offset, self.len) + } + + /// Returns true if this [`BooleanBuffer`] is equal to `other`, using pointer comparisons + /// to determine buffer equality. This is cheaper than `PartialEq::eq` but may + /// return false when the arrays are logically equal + pub fn ptr_eq(&self, other: &Self) -> bool { + self.buffer.as_ptr() == other.buffer.as_ptr() + && self.offset == other.offset + && self.len == other.len + } + + /// Returns the inner [`Buffer`] + #[inline] + pub fn inner(&self) -> &Buffer { + &self.buffer + } + + /// Returns the inner [`Buffer`], consuming self + pub fn into_inner(self) -> Buffer { + self.buffer + } + + /// Returns an iterator over the bits in this [`BooleanBuffer`] + pub fn iter(&self) -> BitIterator<'_> { + self.into_iter() + } + + /// Returns an iterator over the set bit positions in this [`BooleanBuffer`] + pub fn set_indices(&self) -> BitIndexIterator<'_> { + BitIndexIterator::new(self.values(), self.offset, self.len) + } + + /// Returns a [`BitSliceIterator`] yielding contiguous ranges of set bits + pub fn set_slices(&self) -> BitSliceIterator<'_> { + BitSliceIterator::new(self.values(), self.offset, self.len) + } +} + +impl Not for &BooleanBuffer { + type Output = BooleanBuffer; + + fn not(self) -> Self::Output { + BooleanBuffer { + buffer: buffer_unary_not(&self.buffer, self.offset, self.len), + offset: 0, + len: self.len, + } + } +} + +impl BitAnd<&BooleanBuffer> for &BooleanBuffer { + type Output = BooleanBuffer; + + fn bitand(self, rhs: &BooleanBuffer) -> Self::Output { + assert_eq!(self.len, rhs.len); + BooleanBuffer { + buffer: buffer_bin_and(&self.buffer, self.offset, &rhs.buffer, rhs.offset, self.len), + offset: 0, + len: self.len, + } + } +} + +impl BitOr<&BooleanBuffer> for &BooleanBuffer { + type Output = BooleanBuffer; + + fn bitor(self, rhs: &BooleanBuffer) -> Self::Output { + assert_eq!(self.len, rhs.len); + BooleanBuffer { + buffer: buffer_bin_or(&self.buffer, self.offset, &rhs.buffer, rhs.offset, self.len), + offset: 0, + len: self.len, + } + } +} + +impl BitXor<&BooleanBuffer> for &BooleanBuffer { + type Output = BooleanBuffer; + + fn bitxor(self, rhs: &BooleanBuffer) -> Self::Output { + assert_eq!(self.len, rhs.len); + BooleanBuffer { + buffer: buffer_bin_xor(&self.buffer, self.offset, &rhs.buffer, rhs.offset, self.len), + offset: 0, + len: self.len, + } + } +} + +impl<'a> IntoIterator for &'a BooleanBuffer { + type Item = bool; + type IntoIter = BitIterator<'a>; + + fn into_iter(self) -> Self::IntoIter { + BitIterator::new(self.values(), self.offset, self.len) + } +} + +impl From<&[bool]> for BooleanBuffer { + fn from(value: &[bool]) -> Self { + let mut builder = BooleanBufferBuilder::new(value.len()); + builder.append_slice(value); + builder.finish() + } +} + +impl From> for BooleanBuffer { + fn from(value: Vec) -> Self { + value.as_slice().into() + } +} + +impl FromIterator for BooleanBuffer { + fn from_iter>(iter: T) -> Self { + let iter = iter.into_iter(); + let (hint, _) = iter.size_hint(); + let mut builder = BooleanBufferBuilder::new(hint); + iter.for_each(|b| builder.append(b)); + builder.finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_boolean_new() { + let bytes = &[0, 1, 2, 3, 4]; + let buf = Buffer::from(bytes); + let offset = 0; + let len = 24; + + let boolean_buf = BooleanBuffer::new(buf.clone(), offset, len); + assert_eq!(bytes, boolean_buf.values()); + assert_eq!(offset, boolean_buf.offset()); + assert_eq!(len, boolean_buf.len()); + + assert_eq!(2, boolean_buf.count_set_bits()); + assert_eq!(&buf, boolean_buf.inner()); + assert_eq!(buf, boolean_buf.clone().into_inner()); + + assert!(!boolean_buf.is_empty()) + } + + #[test] + fn test_boolean_data_equality() { + let boolean_buf1 = BooleanBuffer::new(Buffer::from(&[0, 1, 4, 3, 5]), 0, 32); + let boolean_buf2 = BooleanBuffer::new(Buffer::from(&[0, 1, 4, 3, 5]), 0, 32); + assert_eq!(boolean_buf1, boolean_buf2); + + // slice with same offset and same length should still preserve equality + let boolean_buf3 = boolean_buf1.slice(8, 16); + assert_ne!(boolean_buf1, boolean_buf3); + let boolean_buf4 = boolean_buf1.slice(0, 32); + assert_eq!(boolean_buf1, boolean_buf4); + + // unequal because of different elements + let boolean_buf2 = BooleanBuffer::new(Buffer::from(&[0, 0, 2, 3, 4]), 0, 32); + assert_ne!(boolean_buf1, boolean_buf2); + + // unequal because of different length + let boolean_buf2 = BooleanBuffer::new(Buffer::from(&[0, 1, 4, 3, 5]), 0, 24); + assert_ne!(boolean_buf1, boolean_buf2); + + // ptr_eq + assert!(boolean_buf1.ptr_eq(&boolean_buf1)); + assert!(boolean_buf2.ptr_eq(&boolean_buf2)); + assert!(!boolean_buf1.ptr_eq(&boolean_buf2)); + } + + #[test] + fn test_boolean_slice() { + let bytes = &[0, 3, 2, 6, 2]; + let boolean_buf1 = BooleanBuffer::new(Buffer::from(bytes), 0, 32); + let boolean_buf2 = BooleanBuffer::new(Buffer::from(bytes), 0, 32); + + let boolean_slice1 = boolean_buf1.slice(16, 16); + let boolean_slice2 = boolean_buf2.slice(0, 16); + assert_eq!(boolean_slice1.values(), boolean_slice2.values()); + + assert_eq!(bytes, boolean_slice1.values()); + assert_eq!(16, boolean_slice1.offset); + assert_eq!(16, boolean_slice1.len); + + assert_eq!(bytes, boolean_slice2.values()); + assert_eq!(0, boolean_slice2.offset); + assert_eq!(16, boolean_slice2.len); + } + + #[test] + fn test_boolean_bitand() { + let offset = 0; + let len = 40; + + let buf1 = Buffer::from(&[0, 1, 1, 0, 0]); + let boolean_buf1 = &BooleanBuffer::new(buf1, offset, len); + + let buf2 = Buffer::from(&[0, 1, 1, 1, 0]); + let boolean_buf2 = &BooleanBuffer::new(buf2, offset, len); + + let expected = BooleanBuffer::new(Buffer::from(&[0, 1, 1, 0, 0]), offset, len); + assert_eq!(boolean_buf1 & boolean_buf2, expected); + } + + #[test] + fn test_boolean_bitor() { + let offset = 0; + let len = 40; + + let buf1 = Buffer::from(&[0, 1, 1, 0, 0]); + let boolean_buf1 = &BooleanBuffer::new(buf1, offset, len); + + let buf2 = Buffer::from(&[0, 1, 1, 1, 0]); + let boolean_buf2 = &BooleanBuffer::new(buf2, offset, len); + + let expected = BooleanBuffer::new(Buffer::from(&[0, 1, 1, 1, 0]), offset, len); + assert_eq!(boolean_buf1 | boolean_buf2, expected); + } + + #[test] + fn test_boolean_bitxor() { + let offset = 0; + let len = 40; + + let buf1 = Buffer::from(&[0, 1, 1, 0, 0]); + let boolean_buf1 = &BooleanBuffer::new(buf1, offset, len); + + let buf2 = Buffer::from(&[0, 1, 1, 1, 0]); + let boolean_buf2 = &BooleanBuffer::new(buf2, offset, len); + + let expected = BooleanBuffer::new(Buffer::from(&[0, 0, 0, 1, 0]), offset, len); + assert_eq!(boolean_buf1 ^ boolean_buf2, expected); + } + + #[test] + fn test_boolean_not() { + let offset = 0; + let len = 40; + + let buf = Buffer::from(&[0, 1, 1, 0, 0]); + let boolean_buf = &BooleanBuffer::new(buf, offset, len); + + let expected = BooleanBuffer::new(Buffer::from(&[255, 254, 254, 255, 255]), offset, len); + assert_eq!(!boolean_buf, expected); + } +} diff --git a/arrow/src/buffer/immutable.rs b/arrow-buffer/src/buffer/immutable.rs similarity index 64% rename from arrow/src/buffer/immutable.rs rename to arrow-buffer/src/buffer/immutable.rs index 28042a3817be..9db8732f3611 100644 --- a/arrow/src/buffer/immutable.rs +++ b/arrow-buffer/src/buffer/immutable.rs @@ -15,55 +15,79 @@ // specific language governing permissions and limitations // under the License. +use std::alloc::Layout; use std::fmt::Debug; use std::iter::FromIterator; use std::ptr::NonNull; use std::sync::Arc; -use std::{convert::AsRef, usize}; -use crate::alloc::{Allocation, Deallocation}; +use crate::alloc::{Allocation, Deallocation, ALIGNMENT}; use crate::util::bit_chunk_iterator::{BitChunks, UnalignedBitChunk}; -use crate::{bytes::Bytes, datatypes::ArrowNativeType}; +use crate::BufferBuilder; +use crate::{bytes::Bytes, native::ArrowNativeType}; use super::ops::bitwise_unary_op_helper; use super::MutableBuffer; /// Buffer represents a contiguous memory region that can be shared with other buffers and across /// thread boundaries. -#[derive(Clone, PartialEq, Debug)] +#[derive(Clone, Debug)] pub struct Buffer { /// the internal byte buffer. data: Arc, - /// The offset into the buffer. - offset: usize, + /// Pointer into `data` valid + /// + /// We store a pointer instead of an offset to avoid pointer arithmetic + /// which causes LLVM to fail to vectorise code correctly + ptr: *const u8, /// Byte length of the buffer. + /// + /// Must be less than or equal to `data.len()` length: usize, } +impl PartialEq for Buffer { + fn eq(&self, other: &Self) -> bool { + self.as_slice().eq(other.as_slice()) + } +} + +impl Eq for Buffer {} + +unsafe impl Send for Buffer where Bytes: Send {} +unsafe impl Sync for Buffer where Bytes: Sync {} + impl Buffer { /// Auxiliary method to create a new Buffer #[inline] pub fn from_bytes(bytes: Bytes) -> Self { let length = bytes.len(); + let ptr = bytes.as_ptr(); Buffer { data: Arc::new(bytes), - offset: 0, + ptr, length, } } + /// Create a [`Buffer`] from the provided [`Vec`] without copying + #[inline] + pub fn from_vec(vec: Vec) -> Self { + MutableBuffer::from(vec).into() + } + /// Initializes a [Buffer] from a slice of items. - pub fn from_slice_ref>(items: &T) -> Self { + pub fn from_slice_ref>(items: T) -> Self { let slice = items.as_ref(); - let capacity = slice.len() * std::mem::size_of::(); + let capacity = std::mem::size_of_val(slice); let mut buffer = MutableBuffer::with_capacity(capacity); buffer.extend_from_slice(slice); buffer.into() } - /// Creates a buffer from an existing memory region (must already be byte-aligned), this + /// Creates a buffer from an existing aligned memory region (must already be byte-aligned), this /// `Buffer` will free this piece of memory when dropped. /// /// # Arguments @@ -76,9 +100,11 @@ impl Buffer { /// /// This function is unsafe as there is no guarantee that the given pointer is valid for `len` /// bytes. If the `ptr` and `capacity` come from a `Buffer`, then this is guaranteed. + #[deprecated(note = "Use Buffer::from_vec")] pub unsafe fn from_raw_parts(ptr: NonNull, len: usize, capacity: usize) -> Self { assert!(len <= capacity); - Buffer::build_with_arguments(ptr, len, Deallocation::Arrow(capacity)) + let layout = Layout::from_size_align(capacity, ALIGNMENT).unwrap(); + Buffer::build_with_arguments(ptr, len, Deallocation::Standard(layout)) } /// Creates a buffer from an existing memory region. Ownership of the memory is tracked via reference counting @@ -108,9 +134,10 @@ impl Buffer { deallocation: Deallocation, ) -> Self { let bytes = Bytes::new(ptr, len, deallocation); + let ptr = bytes.as_ptr(); Buffer { + ptr, data: Arc::new(bytes), - offset: 0, length: len, } } @@ -136,7 +163,11 @@ impl Buffer { /// Returns the byte slice stored in this buffer pub fn as_slice(&self) -> &[u8] { - &self.data[self.offset..(self.offset + self.length)] + unsafe { std::slice::from_raw_parts(self.ptr, self.length) } + } + + pub(crate) fn deallocation(&self) -> &Deallocation { + self.data.deallocation() } /// Returns a new [Buffer] that is a slice of this buffer starting at `offset`. @@ -145,13 +176,18 @@ impl Buffer { /// Panics iff `offset` is larger than `len`. pub fn slice(&self, offset: usize) -> Self { assert!( - offset <= self.len(), + offset <= self.length, "the offset of the new Buffer cannot exceed the existing length" ); + // Safety: + // This cannot overflow as + // `self.offset + self.length < self.data.len()` + // `offset < self.length` + let ptr = unsafe { self.ptr.add(offset) }; Self { data: self.data.clone(), - offset: self.offset + offset, length: self.length - offset, + ptr, } } @@ -162,12 +198,15 @@ impl Buffer { /// Panics iff `(offset + length)` is larger than the existing length. pub fn slice_with_length(&self, offset: usize, length: usize) -> Self { assert!( - offset + length <= self.len(), + offset.saturating_add(length) <= self.length, "the offset of the new Buffer cannot exceed the existing length" ); + // Safety: + // offset + length <= self.length + let ptr = unsafe { self.ptr.add(offset) }; Self { data: self.data.clone(), - offset: self.offset + offset, + ptr, length, } } @@ -178,10 +217,10 @@ impl Buffer { /// stored anywhere, to avoid dangling pointers. #[inline] pub fn as_ptr(&self) -> *const u8 { - unsafe { self.data.ptr().as_ptr().add(self.offset) } + self.ptr } - /// View buffer as typed slice. + /// View buffer as a slice of a specific type. /// /// # Panics /// @@ -215,6 +254,7 @@ impl Buffer { } /// Returns the number of 1-bits in this buffer. + #[deprecated(note = "use count_set_bits_offset instead")] pub fn count_set_bits(&self) -> usize { let len_in_bits = self.len() * 8; // self.offset is already taken into consideration by the bit_chunks implementation @@ -226,6 +266,72 @@ impl Buffer { pub fn count_set_bits_offset(&self, offset: usize, len: usize) -> usize { UnalignedBitChunk::new(self.as_slice(), offset, len).count_ones() } + + /// Returns `MutableBuffer` for mutating the buffer if this buffer is not shared. + /// Returns `Err` if this is shared or its allocation is from an external source or + /// it is not allocated with alignment [`ALIGNMENT`] + pub fn into_mutable(self) -> Result { + let ptr = self.ptr; + let length = self.length; + Arc::try_unwrap(self.data) + .and_then(|bytes| { + // The pointer of underlying buffer should not be offset. + assert_eq!(ptr, bytes.ptr().as_ptr()); + MutableBuffer::from_bytes(bytes).map_err(Arc::new) + }) + .map_err(|bytes| Buffer { + data: bytes, + ptr, + length, + }) + } + + /// Returns `Vec` for mutating the buffer + /// + /// Returns `Err(self)` if this buffer does not have the same [`Layout`] as + /// the destination Vec or contains a non-zero offset + pub fn into_vec(self) -> Result, Self> { + let layout = match self.data.deallocation() { + Deallocation::Standard(l) => l, + _ => return Err(self), // Custom allocation + }; + + if self.ptr != self.data.as_ptr() { + return Err(self); // Data is offset + } + + let v_capacity = layout.size() / std::mem::size_of::(); + match Layout::array::(v_capacity) { + Ok(expected) if layout == &expected => {} + _ => return Err(self), // Incorrect layout + } + + let length = self.length; + let ptr = self.ptr; + let v_len = self.length / std::mem::size_of::(); + + Arc::try_unwrap(self.data) + .map(|bytes| unsafe { + let ptr = bytes.ptr().as_ptr() as _; + std::mem::forget(bytes); + // Safety + // Verified that bytes layout matches that of Vec + Vec::from_raw_parts(ptr, v_len, v_capacity) + }) + .map_err(|bytes| Buffer { + data: bytes, + ptr, + length, + }) + } + + /// Returns true if this [`Buffer`] is equal to `other`, using pointer comparisons + /// to determine buffer equality. This is cheaper than `PartialEq::eq` but may + /// return false when the arrays are logically equal + #[inline] + pub fn ptr_eq(&self, other: &Self) -> bool { + self.ptr == other.ptr && self.length == other.length + } } /// Creating a `Buffer` instance by copying the memory from a `AsRef<[u8]>` into a newly @@ -242,7 +348,7 @@ impl> From for Buffer { } /// Creating a `Buffer` instance by storing the boolean values into the buffer -impl std::iter::FromIterator for Buffer { +impl FromIterator for Buffer { fn from_iter(iter: I) -> Self where I: IntoIterator, @@ -266,12 +372,18 @@ impl From for Buffer { } } +impl From> for Buffer { + fn from(mut value: BufferBuilder) -> Self { + value.finish() + } +} + impl Buffer { /// Creates a [`Buffer`] from an [`Iterator`] with a trusted (upper) length. /// Prefer this to `collect` whenever possible, as it is ~60% faster. /// # Example /// ``` - /// # use arrow::buffer::Buffer; + /// # use arrow_buffer::buffer::Buffer; /// let v = vec![1u32]; /// let iter = v.iter().map(|x| x * 2); /// let buffer = unsafe { Buffer::from_trusted_len_iter(iter) }; @@ -301,10 +413,10 @@ impl Buffer { pub unsafe fn try_from_trusted_len_iter< E, T: ArrowNativeType, - I: Iterator>, + I: Iterator>, >( iterator: I, - ) -> std::result::Result { + ) -> Result { Ok(MutableBuffer::try_from_trusted_len_iter(iterator)?.into()) } } @@ -335,6 +447,7 @@ impl FromIterator for Buffer { #[cfg(test)] mod tests { + use crate::i256; use std::panic::{RefUnwindSafe, UnwindSafe}; use std::thread; @@ -417,9 +530,7 @@ mod tests { } #[test] - #[should_panic( - expected = "the offset of the new Buffer cannot exceed the existing length" - )] + #[should_panic(expected = "the offset of the new Buffer cannot exceed the existing length")] fn test_slice_offset_out_of_bound() { let buf = Buffer::from(&[2, 4, 6, 8, 10]); buf.slice(6); @@ -466,11 +577,17 @@ mod tests { #[test] fn test_count_bits() { - assert_eq!(0, Buffer::from(&[0b00000000]).count_set_bits()); - assert_eq!(8, Buffer::from(&[0b11111111]).count_set_bits()); - assert_eq!(3, Buffer::from(&[0b00001101]).count_set_bits()); - assert_eq!(6, Buffer::from(&[0b01001001, 0b01010010]).count_set_bits()); - assert_eq!(16, Buffer::from(&[0b11111111, 0b11111111]).count_set_bits()); + assert_eq!(0, Buffer::from(&[0b00000000]).count_set_bits_offset(0, 8)); + assert_eq!(8, Buffer::from(&[0b11111111]).count_set_bits_offset(0, 8)); + assert_eq!(3, Buffer::from(&[0b00001101]).count_set_bits_offset(0, 8)); + assert_eq!( + 6, + Buffer::from(&[0b01001001, 0b01010010]).count_set_bits_offset(0, 16) + ); + assert_eq!( + 16, + Buffer::from(&[0b11111111, 0b11111111]).count_set_bits_offset(0, 16) + ); } #[test] @@ -479,31 +596,31 @@ mod tests { 0, Buffer::from(&[0b11111111, 0b00000000]) .slice(1) - .count_set_bits() + .count_set_bits_offset(0, 8) ); assert_eq!( 8, Buffer::from(&[0b11111111, 0b11111111]) .slice_with_length(1, 1) - .count_set_bits() + .count_set_bits_offset(0, 8) ); assert_eq!( 3, Buffer::from(&[0b11111111, 0b11111111, 0b00001101]) .slice(2) - .count_set_bits() + .count_set_bits_offset(0, 8) ); assert_eq!( 6, Buffer::from(&[0b11111111, 0b01001001, 0b01010010]) .slice_with_length(1, 2) - .count_set_bits() + .count_set_bits_offset(0, 16) ); assert_eq!( 16, Buffer::from(&[0b11111111, 0b11111111, 0b11111111, 0b11111111]) .slice(2) - .count_set_bits() + .count_set_bits_offset(0, 16) ); } @@ -574,4 +691,133 @@ mod tests { let slice = buffer.typed_data::(); assert_eq!(slice, &[2, 3, 4, 5]); } + + #[test] + #[should_panic(expected = "the offset of the new Buffer cannot exceed the existing length")] + fn slice_overflow() { + let buffer = Buffer::from(MutableBuffer::from_len_zeroed(12)); + buffer.slice_with_length(2, usize::MAX); + } + + #[test] + fn test_vec_interop() { + // Test empty vec + let a: Vec = Vec::new(); + let b = Buffer::from_vec(a); + b.into_vec::().unwrap(); + + // Test vec with capacity + let a: Vec = Vec::with_capacity(20); + let b = Buffer::from_vec(a); + let back = b.into_vec::().unwrap(); + assert_eq!(back.len(), 0); + assert_eq!(back.capacity(), 20); + + // Test vec with values + let mut a: Vec = Vec::with_capacity(3); + a.extend_from_slice(&[1, 2, 3]); + let b = Buffer::from_vec(a); + let back = b.into_vec::().unwrap(); + assert_eq!(back.len(), 3); + assert_eq!(back.capacity(), 3); + + // Test vec with values and spare capacity + let mut a: Vec = Vec::with_capacity(20); + a.extend_from_slice(&[1, 4, 7, 8, 9, 3, 6]); + let b = Buffer::from_vec(a); + let back = b.into_vec::().unwrap(); + assert_eq!(back.len(), 7); + assert_eq!(back.capacity(), 20); + + // Test incorrect alignment + let a: Vec = Vec::new(); + let b = Buffer::from_vec(a); + let b = b.into_vec::().unwrap_err(); + b.into_vec::().unwrap_err(); + + // Test convert between types with same alignment + // This is an implementation quirk, but isn't harmful + // as ArrowNativeType are trivially transmutable + let a: Vec = vec![1, 2, 3, 4]; + let b = Buffer::from_vec(a); + let back = b.into_vec::().unwrap(); + assert_eq!(back.len(), 4); + assert_eq!(back.capacity(), 4); + + // i256 has the same layout as i128 so this is valid + let mut b: Vec = Vec::with_capacity(4); + b.extend_from_slice(&[1, 2, 3, 4]); + let b = Buffer::from_vec(b); + let back = b.into_vec::().unwrap(); + assert_eq!(back.len(), 2); + assert_eq!(back.capacity(), 2); + + // Invalid layout + let b: Vec = vec![1, 2, 3]; + let b = Buffer::from_vec(b); + b.into_vec::().unwrap_err(); + + // Invalid layout + let mut b: Vec = Vec::with_capacity(5); + b.extend_from_slice(&[1, 2, 3, 4]); + let b = Buffer::from_vec(b); + b.into_vec::().unwrap_err(); + + // Truncates length + // This is an implementation quirk, but isn't harmful + let mut b: Vec = Vec::with_capacity(4); + b.extend_from_slice(&[1, 2, 3]); + let b = Buffer::from_vec(b); + let back = b.into_vec::().unwrap(); + assert_eq!(back.len(), 1); + assert_eq!(back.capacity(), 2); + + // Cannot use aligned allocation + let b = Buffer::from(MutableBuffer::new(10)); + let b = b.into_vec::().unwrap_err(); + b.into_vec::().unwrap_err(); + + // Test slicing + let mut a: Vec = Vec::with_capacity(20); + a.extend_from_slice(&[1, 4, 7, 8, 9, 3, 6]); + let b = Buffer::from_vec(a); + let slice = b.slice_with_length(0, 64); + + // Shared reference fails + let slice = slice.into_vec::().unwrap_err(); + drop(b); + + // Succeeds as no outstanding shared reference + let back = slice.into_vec::().unwrap(); + assert_eq!(&back, &[1, 4, 7, 8]); + assert_eq!(back.capacity(), 20); + + // Slicing by non-multiple length truncates + let mut a: Vec = Vec::with_capacity(8); + a.extend_from_slice(&[1, 4, 7, 3]); + + let b = Buffer::from_vec(a); + let slice = b.slice_with_length(0, 34); + drop(b); + + let back = slice.into_vec::().unwrap(); + assert_eq!(&back, &[1, 4]); + assert_eq!(back.capacity(), 8); + + // Offset prevents conversion + let a: Vec = vec![1, 3, 4, 6]; + let b = Buffer::from_vec(a).slice(2); + b.into_vec::().unwrap_err(); + + let b = MutableBuffer::new(16).into_buffer(); + let b = b.into_vec::().unwrap_err(); // Invalid layout + let b = b.into_vec::().unwrap_err(); // Invalid layout + b.into_mutable().unwrap(); + + let b = Buffer::from_vec(vec![1_u32, 3, 5]); + let b = b.into_mutable().unwrap(); + let b = Buffer::from(b); + let b = b.into_vec::().unwrap(); + assert_eq!(b, &[1, 3, 5]); + } } diff --git a/parquet/build.rs b/arrow-buffer/src/buffer/mod.rs similarity index 72% rename from parquet/build.rs rename to arrow-buffer/src/buffer/mod.rs index 8aada1835ce1..d33e68795e4e 100644 --- a/parquet/build.rs +++ b/arrow-buffer/src/buffer/mod.rs @@ -15,10 +15,21 @@ // specific language governing permissions and limitations // under the License. -fn main() { - // Set Parquet version and "created by" string. - let version = env!("CARGO_PKG_VERSION"); - let created_by = format!("parquet-rs version {}", version); - println!("cargo:rustc-env=PARQUET_VERSION={}", version); - println!("cargo:rustc-env=PARQUET_CREATED_BY={}", created_by); -} +//! Types of shared memory region + +mod offset; +pub use offset::*; +mod immutable; +pub use immutable::*; +mod mutable; +pub use mutable::*; +mod ops; +pub use ops::*; +mod scalar; +pub use scalar::*; +mod boolean; +pub use boolean::*; +mod null; +pub use null::*; +mod run; +pub use run::*; diff --git a/arrow/src/buffer/mutable.rs b/arrow-buffer/src/buffer/mutable.rs similarity index 72% rename from arrow/src/buffer/mutable.rs rename to arrow-buffer/src/buffer/mutable.rs index 1c662ec23eef..69c986cc1056 100644 --- a/arrow/src/buffer/mutable.rs +++ b/arrow-buffer/src/buffer/mutable.rs @@ -15,28 +15,35 @@ // specific language governing permissions and limitations // under the License. -use super::Buffer; -use crate::alloc::Deallocation; +use std::alloc::{handle_alloc_error, Layout}; +use std::mem; +use std::ptr::NonNull; + +use crate::alloc::{Deallocation, ALIGNMENT}; use crate::{ - alloc, bytes::Bytes, - datatypes::{ArrowNativeType, ToByteSlice}, + native::{ArrowNativeType, ToByteSlice}, util::bit_util, }; -use std::ptr::NonNull; + +use super::Buffer; /// A [`MutableBuffer`] is Arrow's interface to build a [`Buffer`] out of items or slices of items. +/// /// [`Buffer`]s created from [`MutableBuffer`] (via `into`) are guaranteed to have its pointer aligned /// along cache lines and in multiple of 64 bytes. +/// /// Use [MutableBuffer::push] to insert an item, [MutableBuffer::extend_from_slice] /// to insert many items, and `into` to convert it to [`Buffer`]. /// -/// For a safe, strongly typed API consider using [`crate::array::BufferBuilder`] +/// For a safe, strongly typed API consider using [`Vec`] and [`ScalarBuffer`](crate::ScalarBuffer) +/// +/// Note: this may be deprecated in a future release ([#1176](https://github.com/apache/arrow-rs/issues/1176)) /// /// # Example /// /// ``` -/// # use arrow::buffer::{Buffer, MutableBuffer}; +/// # use arrow_buffer::buffer::{Buffer, MutableBuffer}; /// let mut buffer = MutableBuffer::new(0); /// buffer.push(256u32); /// buffer.extend_from_slice(&[1u32]); @@ -49,7 +56,7 @@ pub struct MutableBuffer { data: NonNull, // invariant: len <= capacity len: usize, - capacity: usize, + layout: Layout, } impl MutableBuffer { @@ -63,11 +70,19 @@ impl MutableBuffer { #[inline] pub fn with_capacity(capacity: usize) -> Self { let capacity = bit_util::round_upto_multiple_of_64(capacity); - let ptr = alloc::allocate_aligned(capacity); + let layout = Layout::from_size_align(capacity, ALIGNMENT).unwrap(); + let data = match layout.size() { + 0 => dangling_ptr(), + _ => { + // Safety: Verified size != 0 + let raw_ptr = unsafe { std::alloc::alloc(layout) }; + NonNull::new(raw_ptr).unwrap_or_else(|| handle_alloc_error(layout)) + } + }; Self { - data: ptr, + data, len: 0, - capacity, + layout, } } @@ -75,7 +90,7 @@ impl MutableBuffer { /// all bytes are guaranteed to be `0u8`. /// # Example /// ``` - /// # use arrow::buffer::{Buffer, MutableBuffer}; + /// # use arrow_buffer::buffer::{Buffer, MutableBuffer}; /// let mut buffer = MutableBuffer::from_len_zeroed(127); /// assert_eq!(buffer.len(), 127); /// assert!(buffer.capacity() >= 127); @@ -83,13 +98,37 @@ impl MutableBuffer { /// assert_eq!(data[126], 0u8); /// ``` pub fn from_len_zeroed(len: usize) -> Self { - let new_capacity = bit_util::round_upto_multiple_of_64(len); - let ptr = alloc::allocate_aligned_zeroed(new_capacity); - Self { - data: ptr, - len, - capacity: new_capacity, - } + let layout = Layout::from_size_align(len, ALIGNMENT).unwrap(); + let data = match layout.size() { + 0 => dangling_ptr(), + _ => { + // Safety: Verified size != 0 + let raw_ptr = unsafe { std::alloc::alloc_zeroed(layout) }; + NonNull::new(raw_ptr).unwrap_or_else(|| handle_alloc_error(layout)) + } + }; + Self { data, len, layout } + } + + /// Create a [`MutableBuffer`] from the provided [`Vec`] without copying + #[inline] + #[deprecated(note = "Use From>")] + pub fn from_vec(vec: Vec) -> Self { + Self::from(vec) + } + + /// Allocates a new [MutableBuffer] from given `Bytes`. + pub(crate) fn from_bytes(bytes: Bytes) -> Result { + let layout = match bytes.deallocation() { + Deallocation::Standard(layout) => *layout, + _ => return Err(bytes), + }; + + let len = bytes.len(); + let data = bytes.ptr(); + mem::forget(bytes); + + Ok(Self { data, len, layout }) } /// creates a new [MutableBuffer] with capacity and length capable of holding `len` bits. @@ -106,7 +145,7 @@ impl MutableBuffer { /// the buffer directly (e.g., modifying the buffer by holding a mutable reference /// from `data_mut()`). pub fn with_bitset(mut self, end: usize, val: bool) -> Self { - assert!(end <= self.capacity); + assert!(end <= self.layout.size()); let v = if val { 255 } else { 0 }; unsafe { std::ptr::write_bytes(self.data.as_ptr(), v, end); @@ -121,7 +160,14 @@ impl MutableBuffer { /// `len` of the buffer and so can be used to initialize the memory region from /// `len` to `capacity`. pub fn set_null_bits(&mut self, start: usize, count: usize) { - assert!(start + count <= self.capacity); + assert!( + start.saturating_add(count) <= self.layout.size(), + "range start index {start} and count {count} out of bounds for \ + buffer of length {}", + self.layout.size(), + ); + + // Safety: `self.data[start..][..count]` is in-bounds and well-aligned for `u8` unsafe { std::ptr::write_bytes(self.data.as_ptr().add(start), 0, count); } @@ -131,7 +177,7 @@ impl MutableBuffer { /// `self.len + additional > capacity`. /// # Example /// ``` - /// # use arrow::buffer::{Buffer, MutableBuffer}; + /// # use arrow_buffer::buffer::{Buffer, MutableBuffer}; /// let mut buffer = MutableBuffer::new(0); /// buffer.reserve(253); // allocates for the first time /// (0..253u8).for_each(|i| buffer.push(i)); // no reallocation @@ -143,19 +189,35 @@ impl MutableBuffer { #[inline(always)] pub fn reserve(&mut self, additional: usize) { let required_cap = self.len + additional; - if required_cap > self.capacity { - // JUSTIFICATION - // Benefit - // necessity - // Soundness - // `self.data` is valid for `self.capacity`. - let (ptr, new_capacity) = - unsafe { reallocate(self.data, self.capacity, required_cap) }; - self.data = ptr; - self.capacity = new_capacity; + if required_cap > self.layout.size() { + let new_capacity = bit_util::round_upto_multiple_of_64(required_cap); + let new_capacity = std::cmp::max(new_capacity, self.layout.size() * 2); + self.reallocate(new_capacity) } } + #[cold] + fn reallocate(&mut self, capacity: usize) { + let new_layout = Layout::from_size_align(capacity, self.layout.align()).unwrap(); + if new_layout.size() == 0 { + if self.layout.size() != 0 { + // Safety: data was allocated with layout + unsafe { std::alloc::dealloc(self.as_mut_ptr(), self.layout) }; + self.layout = new_layout + } + return; + } + + let data = match self.layout.size() { + // Safety: new_layout is not empty + 0 => unsafe { std::alloc::alloc(new_layout) }, + // Safety: verified new layout is valid and not empty + _ => unsafe { std::alloc::realloc(self.as_mut_ptr(), self.layout, capacity) }, + }; + self.data = NonNull::new(data).unwrap_or_else(|| handle_alloc_error(new_layout)); + self.layout = new_layout; + } + /// Truncates this buffer to `len` bytes /// /// If `len` is greater than the buffer's current length, this has no effect @@ -171,7 +233,7 @@ impl MutableBuffer { /// growing it (potentially reallocating it) and writing `value` in the newly available bytes. /// # Example /// ``` - /// # use arrow::buffer::{Buffer, MutableBuffer}; + /// # use arrow_buffer::buffer::{Buffer, MutableBuffer}; /// let mut buffer = MutableBuffer::new(0); /// buffer.resize(253, 2); // allocates for the first time /// assert_eq!(buffer.as_slice()[252], 2u8); @@ -195,7 +257,7 @@ impl MutableBuffer { /// /// # Example /// ``` - /// # use arrow::buffer::{Buffer, MutableBuffer}; + /// # use arrow_buffer::buffer::{Buffer, MutableBuffer}; /// // 2 cache lines /// let mut buffer = MutableBuffer::new(128); /// assert_eq!(buffer.capacity(), 128); @@ -207,17 +269,8 @@ impl MutableBuffer { /// ``` pub fn shrink_to_fit(&mut self) { let new_capacity = bit_util::round_upto_multiple_of_64(self.len); - if new_capacity < self.capacity { - // JUSTIFICATION - // Benefit - // necessity - // Soundness - // `self.data` is valid for `self.capacity`. - let ptr = - unsafe { alloc::reallocate(self.data, self.capacity, new_capacity) }; - - self.data = ptr; - self.capacity = new_capacity; + if new_capacity < self.layout.size() { + self.reallocate(new_capacity) } } @@ -238,7 +291,7 @@ impl MutableBuffer { /// The invariant `buffer.len() <= buffer.capacity()` is always upheld. #[inline] pub const fn capacity(&self) -> usize { - self.capacity + self.layout.size() } /// Clear all existing data from this buffer. @@ -281,14 +334,12 @@ impl MutableBuffer { #[inline] pub(super) fn into_buffer(self) -> Buffer { - let bytes = unsafe { - Bytes::new(self.data, self.len, Deallocation::Arrow(self.capacity)) - }; + let bytes = unsafe { Bytes::new(self.data, self.len, Deallocation::Standard(self.layout)) }; std::mem::forget(self); Buffer::from_bytes(bytes) } - /// View this buffer as a slice of a specific type. + /// View this buffer as a mutable slice of a specific type. /// /// # Panics /// @@ -298,8 +349,22 @@ impl MutableBuffer { // SAFETY // ArrowNativeType is trivially transmutable, is sealed to prevent potentially incorrect // implementation outside this crate, and this method checks alignment - let (prefix, offsets, suffix) = - unsafe { self.as_slice_mut().align_to_mut::() }; + let (prefix, offsets, suffix) = unsafe { self.as_slice_mut().align_to_mut::() }; + assert!(prefix.is_empty() && suffix.is_empty()); + offsets + } + + /// View buffer as a immutable slice of a specific type. + /// + /// # Panics + /// + /// This function panics if the underlying buffer is not aligned + /// correctly for type `T`. + pub fn typed_data(&self) -> &[T] { + // SAFETY + // ArrowNativeType is trivially transmutable, is sealed to prevent potentially incorrect + // implementation outside this crate, and this method checks alignment + let (prefix, offsets, suffix) = unsafe { self.as_slice().align_to::() }; assert!(prefix.is_empty() && suffix.is_empty()); offsets } @@ -307,15 +372,14 @@ impl MutableBuffer { /// Extends this buffer from a slice of items that can be represented in bytes, increasing its capacity if needed. /// # Example /// ``` - /// # use arrow::buffer::MutableBuffer; + /// # use arrow_buffer::buffer::MutableBuffer; /// let mut buffer = MutableBuffer::new(0); /// buffer.extend_from_slice(&[2u32, 0]); /// assert_eq!(buffer.len(), 8) // u32 has 4 bytes /// ``` #[inline] pub fn extend_from_slice(&mut self, items: &[T]) { - let len = items.len(); - let additional = len * std::mem::size_of::(); + let additional = mem::size_of_val(items); self.reserve(additional); unsafe { // this assumes that `[ToByteSlice]` can be copied directly @@ -331,7 +395,7 @@ impl MutableBuffer { /// Extends the buffer with a new item, increasing its capacity if needed. /// # Example /// ``` - /// # use arrow::buffer::MutableBuffer; + /// # use arrow_buffer::buffer::MutableBuffer; /// let mut buffer = MutableBuffer::new(0); /// buffer.push(256u32); /// assert_eq!(buffer.len(), 4) // u32 has 4 bytes @@ -350,7 +414,7 @@ impl MutableBuffer { /// Extends the buffer with a new item, without checking for sufficient capacity /// # Safety - /// Caller must ensure that the capacity()-len()>=size_of() + /// Caller must ensure that the capacity()-len()>=`size_of`() #[inline] pub unsafe fn push_unchecked(&mut self, item: T) { let additional = std::mem::size_of::(); @@ -369,24 +433,54 @@ impl MutableBuffer { /// # Safety /// The caller must ensure that the buffer was properly initialized up to `len`. #[inline] - pub(crate) unsafe fn set_len(&mut self, len: usize) { + pub unsafe fn set_len(&mut self, len: usize) { assert!(len <= self.capacity()); self.len = len; } + + /// Invokes `f` with values `0..len` collecting the boolean results into a new `MutableBuffer` + /// + /// This is similar to `from_trusted_len_iter_bool`, however, can be significantly faster + /// as it eliminates the conditional `Iterator::next` + #[inline] + pub fn collect_bool bool>(len: usize, mut f: F) -> Self { + let mut buffer = Self::new(bit_util::ceil(len, 64) * 8); + + let chunks = len / 64; + let remainder = len % 64; + for chunk in 0..chunks { + let mut packed = 0; + for bit_idx in 0..64 { + let i = bit_idx + chunk * 64; + packed |= (f(i) as u64) << bit_idx; + } + + // SAFETY: Already allocated sufficient capacity + unsafe { buffer.push_unchecked(packed) } + } + + if remainder != 0 { + let mut packed = 0; + for bit_idx in 0..remainder { + let i = bit_idx + chunks * 64; + packed |= (f(i) as u64) << bit_idx; + } + + // SAFETY: Already allocated sufficient capacity + unsafe { buffer.push_unchecked(packed) } + } + + buffer.truncate(bit_util::ceil(len, 8)); + buffer + } } -/// # Safety -/// `ptr` must be allocated for `old_capacity`. -#[cold] -unsafe fn reallocate( - ptr: NonNull, - old_capacity: usize, - new_capacity: usize, -) -> (NonNull, usize) { - let new_capacity = bit_util::round_upto_multiple_of_64(new_capacity); - let new_capacity = std::cmp::max(new_capacity, old_capacity * 2); - let ptr = alloc::reallocate(ptr, old_capacity, new_capacity); - (ptr, new_capacity) +#[inline] +fn dangling_ptr() -> NonNull { + // SAFETY: ALIGNMENT is a non-zero usize which is then casted + // to a *mut T. Therefore, `ptr` is not null and the conditions for + // calling new_unchecked() are respected. + unsafe { NonNull::new_unchecked(ALIGNMENT as *mut u8) } } impl Extend for MutableBuffer { @@ -397,6 +491,21 @@ impl Extend for MutableBuffer { } } +impl From> for MutableBuffer { + fn from(value: Vec) -> Self { + // Safety + // Vec::as_ptr guaranteed to not be null and ArrowNativeType are trivially transmutable + let data = unsafe { NonNull::new_unchecked(value.as_ptr() as _) }; + let len = value.len() * mem::size_of::(); + // Safety + // Vec guaranteed to have a valid layout matching that of `Layout::array` + // This is based on `RawVec::current_memory` + let layout = unsafe { Layout::array::(value.capacity()).unwrap_unchecked() }; + mem::forget(value); + Self { data, len, layout } + } +} + impl MutableBuffer { #[inline] pub(super) fn extend_from_iter>( @@ -411,7 +520,7 @@ impl MutableBuffer { // this is necessary because of https://github.com/rust-lang/rust/issues/32155 let mut len = SetLenOnDrop::new(&mut self.len); let mut dst = unsafe { self.data.as_ptr().add(len.local_len) }; - let capacity = self.capacity; + let capacity = self.layout.size(); while len.local_len + item_size <= capacity { if let Some(item) = iterator.next() { @@ -434,7 +543,7 @@ impl MutableBuffer { /// Prefer this to `collect` whenever possible, as it is faster ~60% faster. /// # Example /// ``` - /// # use arrow::buffer::MutableBuffer; + /// # use arrow_buffer::buffer::MutableBuffer; /// let v = vec![1u32]; /// let iter = v.iter().map(|x| x * 2); /// let buffer = unsafe { MutableBuffer::from_trusted_len_iter(iter) }; @@ -475,10 +584,10 @@ impl MutableBuffer { } /// Creates a [`MutableBuffer`] from a boolean [`Iterator`] with a trusted (upper) length. - /// # use arrow::buffer::MutableBuffer; + /// # use arrow_buffer::buffer::MutableBuffer; /// # Example /// ``` - /// # use arrow::buffer::MutableBuffer; + /// # use arrow_buffer::buffer::MutableBuffer; /// let v = vec![false, true, false]; /// let iter = v.iter().map(|x| *x || true); /// let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(iter) }; @@ -492,42 +601,11 @@ impl MutableBuffer { // we can't specialize `extend` for `TrustedLen` like `Vec` does. // 2. `from_trusted_len_iter_bool` is faster. #[inline] - pub unsafe fn from_trusted_len_iter_bool>( - mut iterator: I, - ) -> Self { + pub unsafe fn from_trusted_len_iter_bool>(mut iterator: I) -> Self { let (_, upper) = iterator.size_hint(); - let upper = upper.expect("from_trusted_len_iter requires an upper limit"); - - let mut result = { - let byte_capacity: usize = upper.saturating_add(7) / 8; - MutableBuffer::new(byte_capacity) - }; - - 'a: loop { - let mut byte_accum: u8 = 0; - let mut mask: u8 = 1; - - //collect (up to) 8 bits into a byte - while mask != 0 { - if let Some(value) = iterator.next() { - byte_accum |= match value { - true => mask, - false => 0, - }; - mask <<= 1; - } else { - if mask != 1 { - // Add last byte - result.push_unchecked(byte_accum); - } - break 'a; - } - } + let len = upper.expect("from_trusted_len_iter requires an upper limit"); - // Soundness: from_trusted_len - result.push_unchecked(byte_accum); - } - result + Self::collect_bool(len, |_| iterator.next().unwrap()) } /// Creates a [`MutableBuffer`] from an [`Iterator`] with a trusted (upper) length or errors @@ -540,10 +618,10 @@ impl MutableBuffer { pub unsafe fn try_from_trusted_len_iter< E, T: ArrowNativeType, - I: Iterator>, + I: Iterator>, >( iterator: I, - ) -> std::result::Result { + ) -> Result { let item_size = std::mem::size_of::(); let (_, upper) = iterator.size_hint(); let upper = upper.expect("try_from_trusted_len_iter requires an upper limit"); @@ -574,6 +652,12 @@ impl MutableBuffer { } } +impl Default for MutableBuffer { + fn default() -> Self { + Self::with_capacity(0) + } +} + impl std::ops::Deref for MutableBuffer { type Target = [u8]; @@ -590,7 +674,10 @@ impl std::ops::DerefMut for MutableBuffer { impl Drop for MutableBuffer { fn drop(&mut self) { - unsafe { alloc::free_aligned(self.data, self.capacity) }; + if self.layout.size() != 0 { + // Safety: data was allocated with standard allocator with given layout + unsafe { std::alloc::dealloc(self.data.as_ptr() as _, self.layout) }; + } } } @@ -599,7 +686,7 @@ impl PartialEq for MutableBuffer { if self.len != other.len { return false; } - if self.capacity != other.capacity { + if self.layout != other.layout { return false; } self.as_slice() == other.as_slice() @@ -686,6 +773,14 @@ impl std::iter::FromIterator for MutableBuffer { } } +impl std::iter::FromIterator for MutableBuffer { + fn from_iter>(iter: I) -> Self { + let mut buffer = Self::default(); + buffer.extend_from_iter(iter.into_iter()); + buffer + } +} + #[cfg(test)] mod tests { use super::*; @@ -698,6 +793,19 @@ mod tests { assert!(buf.is_empty()); } + #[test] + fn test_mutable_default() { + let buf = MutableBuffer::default(); + assert_eq!(0, buf.capacity()); + assert_eq!(0, buf.len()); + assert!(buf.is_empty()); + + let mut buf = MutableBuffer::default(); + buf.extend_from_slice(b"hello"); + assert_eq!(5, buf.len()); + assert_eq!(b"hello", buf.as_slice()); + } + #[test] fn test_mutable_extend_from_slice() { let mut buf = MutableBuffer::new(100); @@ -860,4 +968,38 @@ mod tests { buffer.shrink_to_fit(); assert!(buffer.capacity() >= 64 && buffer.capacity() < 128); } + + #[test] + fn test_mutable_set_null_bits() { + let mut buffer = MutableBuffer::new(8).with_bitset(8, true); + + for i in 0..=buffer.capacity() { + buffer.set_null_bits(i, 0); + assert_eq!(buffer[..8], [255; 8][..]); + } + + buffer.set_null_bits(1, 4); + assert_eq!(buffer[..8], [255, 0, 0, 0, 0, 255, 255, 255][..]); + } + + #[test] + #[should_panic = "out of bounds for buffer of length"] + fn test_mutable_set_null_bits_oob() { + let mut buffer = MutableBuffer::new(64); + buffer.set_null_bits(1, buffer.capacity()); + } + + #[test] + #[should_panic = "out of bounds for buffer of length"] + fn test_mutable_set_null_bits_oob_by_overflow() { + let mut buffer = MutableBuffer::new(0); + buffer.set_null_bits(1, usize::MAX); + } + + #[test] + fn from_iter() { + let buffer = [1u16, 2, 3, 4].into_iter().collect::(); + assert_eq!(buffer.len(), 4 * mem::size_of::()); + assert_eq!(buffer.as_slice(), &[1, 0, 2, 0, 3, 0, 4, 0]); + } } diff --git a/arrow-buffer/src/buffer/null.rs b/arrow-buffer/src/buffer/null.rs new file mode 100644 index 000000000000..c79aef398059 --- /dev/null +++ b/arrow-buffer/src/buffer/null.rs @@ -0,0 +1,261 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::bit_iterator::{BitIndexIterator, BitIterator, BitSliceIterator}; +use crate::buffer::BooleanBuffer; +use crate::{Buffer, MutableBuffer}; + +/// A [`BooleanBuffer`] used to encode validity for arrow arrays +/// +/// As per the [Arrow specification], array validity is encoded in a packed bitmask with a +/// `true` value indicating the corresponding slot is not null, and `false` indicating +/// that it is null. +/// +/// [Arrow specification]: https://arrow.apache.org/docs/format/Columnar.html#validity-bitmaps +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct NullBuffer { + buffer: BooleanBuffer, + null_count: usize, +} + +impl NullBuffer { + /// Create a new [`NullBuffer`] computing the null count + pub fn new(buffer: BooleanBuffer) -> Self { + let null_count = buffer.len() - buffer.count_set_bits(); + Self { buffer, null_count } + } + + /// Create a new [`NullBuffer`] of length `len` where all values are null + pub fn new_null(len: usize) -> Self { + Self { + buffer: BooleanBuffer::new_unset(len), + null_count: len, + } + } + + /// Create a new [`NullBuffer`] of length `len` where all values are valid + /// + /// Note: it is more efficient to not set the null buffer if it is known to be all valid + pub fn new_valid(len: usize) -> Self { + Self { + buffer: BooleanBuffer::new_set(len), + null_count: 0, + } + } + + /// Create a new [`NullBuffer`] with the provided `buffer` and `null_count` + /// + /// # Safety + /// + /// `buffer` must contain `null_count` `0` bits + pub unsafe fn new_unchecked(buffer: BooleanBuffer, null_count: usize) -> Self { + Self { buffer, null_count } + } + + /// Computes the union of the nulls in two optional [`NullBuffer`] + /// + /// This is commonly used by binary operations where the result is NULL if either + /// of the input values is NULL. Handling the null mask separately in this way + /// can yield significant performance improvements over an iterator approach + pub fn union(lhs: Option<&NullBuffer>, rhs: Option<&NullBuffer>) -> Option { + match (lhs, rhs) { + (Some(lhs), Some(rhs)) => Some(Self::new(lhs.inner() & rhs.inner())), + (Some(n), None) | (None, Some(n)) => Some(n.clone()), + (None, None) => None, + } + } + + /// Returns true if all nulls in `other` also exist in self + pub fn contains(&self, other: &NullBuffer) -> bool { + if other.null_count == 0 { + return true; + } + let lhs = self.inner().bit_chunks().iter_padded(); + let rhs = other.inner().bit_chunks().iter_padded(); + lhs.zip(rhs).all(|(l, r)| (l & !r) == 0) + } + + /// Returns a new [`NullBuffer`] where each bit in the current null buffer + /// is repeated `count` times. This is useful for masking the nulls of + /// the child of a FixedSizeListArray based on its parent + pub fn expand(&self, count: usize) -> Self { + let capacity = self.buffer.len().checked_mul(count).unwrap(); + let mut buffer = MutableBuffer::new_null(capacity); + + // Expand each bit within `null_mask` into `element_len` + // bits, constructing the implicit mask of the child elements + for i in 0..self.buffer.len() { + if self.is_null(i) { + continue; + } + for j in 0..count { + crate::bit_util::set_bit(buffer.as_mut(), i * count + j) + } + } + Self { + buffer: BooleanBuffer::new(buffer.into(), 0, capacity), + null_count: self.null_count * count, + } + } + + /// Returns the length of this [`NullBuffer`] + #[inline] + pub fn len(&self) -> usize { + self.buffer.len() + } + + /// Returns the offset of this [`NullBuffer`] in bits + #[inline] + pub fn offset(&self) -> usize { + self.buffer.offset() + } + + /// Returns true if this [`NullBuffer`] is empty + #[inline] + pub fn is_empty(&self) -> bool { + self.buffer.is_empty() + } + + /// Returns the null count for this [`NullBuffer`] + #[inline] + pub fn null_count(&self) -> usize { + self.null_count + } + + /// Returns `true` if the value at `idx` is not null + #[inline] + pub fn is_valid(&self, idx: usize) -> bool { + self.buffer.value(idx) + } + + /// Returns `true` if the value at `idx` is null + #[inline] + pub fn is_null(&self, idx: usize) -> bool { + !self.is_valid(idx) + } + + /// Returns the packed validity of this [`NullBuffer`] not including any offset + #[inline] + pub fn validity(&self) -> &[u8] { + self.buffer.values() + } + + /// Slices this [`NullBuffer`] by the provided `offset` and `length` + pub fn slice(&self, offset: usize, len: usize) -> Self { + Self::new(self.buffer.slice(offset, len)) + } + + /// Returns an iterator over the bits in this [`NullBuffer`] + /// + /// * `true` indicates that the corresponding value is not NULL + /// * `false` indicates that the corresponding value is NULL + /// + /// Note: [`Self::valid_indices`] will be significantly faster for most use-cases + pub fn iter(&self) -> BitIterator<'_> { + self.buffer.iter() + } + + /// Returns a [`BitIndexIterator`] over the valid indices in this [`NullBuffer`] + /// + /// Valid indices indicate the corresponding value is not NULL + pub fn valid_indices(&self) -> BitIndexIterator<'_> { + self.buffer.set_indices() + } + + /// Returns a [`BitSliceIterator`] yielding contiguous ranges of valid indices + /// + /// Valid indices indicate the corresponding value is not NULL + pub fn valid_slices(&self) -> BitSliceIterator<'_> { + self.buffer.set_slices() + } + + /// Calls the provided closure for each index in this null mask that is set + #[inline] + pub fn try_for_each_valid_idx Result<(), E>>( + &self, + f: F, + ) -> Result<(), E> { + if self.null_count == self.len() { + return Ok(()); + } + self.valid_indices().try_for_each(f) + } + + /// Returns the inner [`BooleanBuffer`] + #[inline] + pub fn inner(&self) -> &BooleanBuffer { + &self.buffer + } + + /// Returns the inner [`BooleanBuffer`] + #[inline] + pub fn into_inner(self) -> BooleanBuffer { + self.buffer + } + + /// Returns the underlying [`Buffer`] + #[inline] + pub fn buffer(&self) -> &Buffer { + self.buffer.inner() + } +} + +impl<'a> IntoIterator for &'a NullBuffer { + type Item = bool; + type IntoIter = BitIterator<'a>; + + fn into_iter(self) -> Self::IntoIter { + self.buffer.iter() + } +} + +impl From for NullBuffer { + fn from(value: BooleanBuffer) -> Self { + Self::new(value) + } +} + +impl From<&[bool]> for NullBuffer { + fn from(value: &[bool]) -> Self { + BooleanBuffer::from(value).into() + } +} + +impl From> for NullBuffer { + fn from(value: Vec) -> Self { + BooleanBuffer::from(value).into() + } +} + +impl FromIterator for NullBuffer { + fn from_iter>(iter: T) -> Self { + BooleanBuffer::from_iter(iter).into() + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_size() { + // This tests that the niche optimisation eliminates the overhead of an option + assert_eq!( + std::mem::size_of::(), + std::mem::size_of::>() + ); + } +} diff --git a/arrow-buffer/src/buffer/offset.rs b/arrow-buffer/src/buffer/offset.rs new file mode 100644 index 000000000000..652d30c3b0ab --- /dev/null +++ b/arrow-buffer/src/buffer/offset.rs @@ -0,0 +1,237 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::buffer::ScalarBuffer; +use crate::{ArrowNativeType, MutableBuffer}; +use std::ops::Deref; + +/// A non-empty buffer of monotonically increasing, positive integers. +/// +/// [`OffsetBuffer`] are used to represent ranges of offsets. An +/// `OffsetBuffer` of `N+1` items contains `N` such ranges. The start +/// offset for element `i` is `offsets[i]` and the end offset is +/// `offsets[i+1]`. Equal offsets represent an empty range. +/// +/// # Example +/// +/// This example shows how 5 distinct ranges, are represented using a +/// 6 entry `OffsetBuffer`. The first entry `(0, 3)` represents the +/// three offsets `0, 1, 2`. The entry `(3,3)` represent no offsets +/// (e.g. an empty list). +/// +/// ```text +/// ┌───────┐ ┌───┐ +/// │ (0,3) │ │ 0 │ +/// ├───────┤ ├───┤ +/// │ (3,3) │ │ 3 │ +/// ├───────┤ ├───┤ +/// │ (3,4) │ │ 3 │ +/// ├───────┤ ├───┤ +/// │ (4,5) │ │ 4 │ +/// ├───────┤ ├───┤ +/// │ (5,7) │ │ 5 │ +/// └───────┘ ├───┤ +/// │ 7 │ +/// └───┘ +/// +/// Offsets Buffer +/// Logical +/// Offsets +/// +/// (offsets[i], +/// offsets[i+1]) +/// ``` + +#[derive(Debug, Clone)] +pub struct OffsetBuffer(ScalarBuffer); + +impl OffsetBuffer { + /// Create a new [`OffsetBuffer`] from the provided [`ScalarBuffer`] + /// + /// # Panics + /// + /// Panics if `buffer` is not a non-empty buffer containing + /// monotonically increasing values greater than or equal to zero + pub fn new(buffer: ScalarBuffer) -> Self { + assert!(!buffer.is_empty(), "offsets cannot be empty"); + assert!( + buffer[0] >= O::usize_as(0), + "offsets must be greater than 0" + ); + assert!( + buffer.windows(2).all(|w| w[0] <= w[1]), + "offsets must be monotonically increasing" + ); + Self(buffer) + } + + /// Create a new [`OffsetBuffer`] from the provided [`ScalarBuffer`] + /// + /// # Safety + /// + /// `buffer` must be a non-empty buffer containing monotonically increasing + /// values greater than or equal to zero + pub unsafe fn new_unchecked(buffer: ScalarBuffer) -> Self { + Self(buffer) + } + + /// Create a new [`OffsetBuffer`] containing a single 0 value + pub fn new_empty() -> Self { + let buffer = MutableBuffer::from_len_zeroed(std::mem::size_of::()); + Self(buffer.into_buffer().into()) + } + + /// Create a new [`OffsetBuffer`] containing `len + 1` `0` values + pub fn new_zeroed(len: usize) -> Self { + let len_bytes = len + .checked_add(1) + .and_then(|o| o.checked_mul(std::mem::size_of::())) + .expect("overflow"); + let buffer = MutableBuffer::from_len_zeroed(len_bytes); + Self(buffer.into_buffer().into()) + } + + /// Create a new [`OffsetBuffer`] from the iterator of slice lengths + /// + /// ``` + /// # use arrow_buffer::OffsetBuffer; + /// let offsets = OffsetBuffer::::from_lengths([1, 3, 5]); + /// assert_eq!(offsets.as_ref(), &[0, 1, 4, 9]); + /// ``` + /// + /// # Panics + /// + /// Panics on overflow + pub fn from_lengths(lengths: I) -> Self + where + I: IntoIterator, + { + let iter = lengths.into_iter(); + let mut out = Vec::with_capacity(iter.size_hint().0 + 1); + out.push(O::usize_as(0)); + + let mut acc = 0_usize; + for length in iter { + acc = acc.checked_add(length).expect("usize overflow"); + out.push(O::usize_as(acc)) + } + // Check for overflow + O::from_usize(acc).expect("offset overflow"); + Self(out.into()) + } + + /// Returns the inner [`ScalarBuffer`] + pub fn inner(&self) -> &ScalarBuffer { + &self.0 + } + + /// Returns the inner [`ScalarBuffer`], consuming self + pub fn into_inner(self) -> ScalarBuffer { + self.0 + } + + /// Returns a zero-copy slice of this buffer with length `len` and starting at `offset` + pub fn slice(&self, offset: usize, len: usize) -> Self { + Self(self.0.slice(offset, len.saturating_add(1))) + } + + /// Returns true if this [`OffsetBuffer`] is equal to `other`, using pointer comparisons + /// to determine buffer equality. This is cheaper than `PartialEq::eq` but may + /// return false when the arrays are logically equal + #[inline] + pub fn ptr_eq(&self, other: &Self) -> bool { + self.0.ptr_eq(&other.0) + } +} + +impl Deref for OffsetBuffer { + type Target = [T]; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl AsRef<[T]> for OffsetBuffer { + #[inline] + fn as_ref(&self) -> &[T] { + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[should_panic(expected = "offsets cannot be empty")] + fn empty_offsets() { + OffsetBuffer::new(Vec::::new().into()); + } + + #[test] + #[should_panic(expected = "offsets must be greater than 0")] + fn negative_offsets() { + OffsetBuffer::new(vec![-1, 0, 1].into()); + } + + #[test] + fn offsets() { + OffsetBuffer::new(vec![0, 1, 2, 3].into()); + + let offsets = OffsetBuffer::::new_zeroed(3); + assert_eq!(offsets.as_ref(), &[0; 4]); + + let offsets = OffsetBuffer::::new_zeroed(0); + assert_eq!(offsets.as_ref(), &[0; 1]); + } + + #[test] + #[should_panic(expected = "overflow")] + fn offsets_new_zeroed_overflow() { + OffsetBuffer::::new_zeroed(usize::MAX); + } + + #[test] + #[should_panic(expected = "offsets must be monotonically increasing")] + fn non_monotonic_offsets() { + OffsetBuffer::new(vec![1, 2, 0].into()); + } + + #[test] + fn from_lengths() { + let buffer = OffsetBuffer::::from_lengths([2, 6, 3, 7, 2]); + assert_eq!(buffer.as_ref(), &[0, 2, 8, 11, 18, 20]); + + let half_max = i32::MAX / 2; + let buffer = OffsetBuffer::::from_lengths([half_max as usize, half_max as usize]); + assert_eq!(buffer.as_ref(), &[0, half_max, half_max * 2]); + } + + #[test] + #[should_panic(expected = "offset overflow")] + fn from_lengths_offset_overflow() { + OffsetBuffer::::from_lengths([i32::MAX as usize, 1]); + } + + #[test] + #[should_panic(expected = "usize overflow")] + fn from_lengths_usize_overflow() { + OffsetBuffer::::from_lengths([usize::MAX, 1]); + } +} diff --git a/arrow/src/buffer/ops.rs b/arrow-buffer/src/buffer/ops.rs similarity index 76% rename from arrow/src/buffer/ops.rs rename to arrow-buffer/src/buffer/ops.rs index 7000f39767cb..ca00e41bea21 100644 --- a/arrow/src/buffer/ops.rs +++ b/arrow-buffer/src/buffer/ops.rs @@ -20,26 +20,19 @@ use crate::util::bit_util::ceil; /// Apply a bitwise operation `op` to four inputs and return the result as a Buffer. /// The inputs are treated as bitmaps, meaning that offsets and length are specified in number of bits. -#[allow(clippy::too_many_arguments)] -pub(crate) fn bitwise_quaternary_op_helper( - first: &Buffer, - first_offset_in_bits: usize, - second: &Buffer, - second_offset_in_bits: usize, - third: &Buffer, - third_offset_in_bits: usize, - fourth: &Buffer, - fourth_offset_in_bits: usize, +pub fn bitwise_quaternary_op_helper( + buffers: [&Buffer; 4], + offsets: [usize; 4], len_in_bits: usize, op: F, ) -> Buffer where F: Fn(u64, u64, u64, u64) -> u64, { - let first_chunks = first.bit_chunks(first_offset_in_bits, len_in_bits); - let second_chunks = second.bit_chunks(second_offset_in_bits, len_in_bits); - let third_chunks = third.bit_chunks(third_offset_in_bits, len_in_bits); - let fourth_chunks = fourth.bit_chunks(fourth_offset_in_bits, len_in_bits); + let first_chunks = buffers[0].bit_chunks(offsets[0], len_in_bits); + let second_chunks = buffers[1].bit_chunks(offsets[1], len_in_bits); + let third_chunks = buffers[2].bit_chunks(offsets[2], len_in_bits); + let fourth_chunks = buffers[3].bit_chunks(offsets[3], len_in_bits); let chunks = first_chunks .iter() @@ -73,10 +66,10 @@ pub fn bitwise_bin_op_helper( right: &Buffer, right_offset_in_bits: usize, len_in_bits: usize, - op: F, + mut op: F, ) -> Buffer where - F: Fn(u64, u64) -> u64, + F: FnMut(u64, u64) -> u64, { let left_chunks = left.bit_chunks(left_offset_in_bits, len_in_bits); let right_chunks = right.bit_chunks(right_offset_in_bits, len_in_bits); @@ -104,10 +97,10 @@ pub fn bitwise_unary_op_helper( left: &Buffer, offset_in_bits: usize, len_in_bits: usize, - op: F, + mut op: F, ) -> Buffer where - F: Fn(u64) -> u64, + F: FnMut(u64) -> u64, { // reserve capacity and set length so we can get a typed view of u64 chunks let mut result = @@ -132,6 +125,8 @@ where result.into() } +/// Apply a bitwise and to two inputs and return the result as a Buffer. +/// The inputs are treated as bitmaps, meaning that offsets and length are specified in number of bits. pub fn buffer_bin_and( left: &Buffer, left_offset_in_bits: usize, @@ -149,6 +144,8 @@ pub fn buffer_bin_and( ) } +/// Apply a bitwise or to two inputs and return the result as a Buffer. +/// The inputs are treated as bitmaps, meaning that offsets and length are specified in number of bits. pub fn buffer_bin_or( left: &Buffer, left_offset_in_bits: usize, @@ -166,10 +163,27 @@ pub fn buffer_bin_or( ) } -pub fn buffer_unary_not( +/// Apply a bitwise xor to two inputs and return the result as a Buffer. +/// The inputs are treated as bitmaps, meaning that offsets and length are specified in number of bits. +pub fn buffer_bin_xor( left: &Buffer, - offset_in_bits: usize, + left_offset_in_bits: usize, + right: &Buffer, + right_offset_in_bits: usize, len_in_bits: usize, ) -> Buffer { + bitwise_bin_op_helper( + left, + left_offset_in_bits, + right, + right_offset_in_bits, + len_in_bits, + |a, b| a ^ b, + ) +} + +/// Apply a bitwise not to one input and return the result as a Buffer. +/// The input is treated as a bitmap, meaning that offset and length are specified in number of bits. +pub fn buffer_unary_not(left: &Buffer, offset_in_bits: usize, len_in_bits: usize) -> Buffer { bitwise_unary_op_helper(left, offset_in_bits, len_in_bits, |a| !a) } diff --git a/arrow-buffer/src/buffer/run.rs b/arrow-buffer/src/buffer/run.rs new file mode 100644 index 000000000000..3dbbe344a025 --- /dev/null +++ b/arrow-buffer/src/buffer/run.rs @@ -0,0 +1,230 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::buffer::ScalarBuffer; +use crate::ArrowNativeType; + +/// A slice-able buffer of monotonically increasing, positive integers used to store run-ends +/// +/// # Logical vs Physical +/// +/// A [`RunEndBuffer`] is used to encode runs of the same value, the index of each run is +/// called the physical index. The logical index is then the corresponding index in the logical +/// run-encoded array, i.e. a single run of length `3`, would have the logical indices `0..3`. +/// +/// Each value in [`RunEndBuffer::values`] is the cumulative length of all runs in the +/// logical array, up to that physical index. +/// +/// Consider a [`RunEndBuffer`] containing `[3, 4, 6]`. The maximum physical index is `2`, +/// as there are `3` values, and the maximum logical index is `5`, as the maximum run end +/// is `6`. The physical indices are therefore `[0, 0, 0, 1, 2, 2]` +/// +/// ```text +/// ┌─────────┐ ┌─────────┐ ┌─────────┐ +/// │ 3 │ │ 0 │ ─┬──────▶ │ 0 │ +/// ├─────────┤ ├─────────┤ │ ├─────────┤ +/// │ 4 │ │ 1 │ ─┤ ┌────▶ │ 1 │ +/// ├─────────┤ ├─────────┤ │ │ ├─────────┤ +/// │ 6 │ │ 2 │ ─┘ │ ┌──▶ │ 2 │ +/// └─────────┘ ├─────────┤ │ │ └─────────┘ +/// run ends │ 3 │ ───┘ │ physical indices +/// ├─────────┤ │ +/// │ 4 │ ─────┤ +/// ├─────────┤ │ +/// │ 5 │ ─────┘ +/// └─────────┘ +/// logical indices +/// ``` +/// +/// # Slicing +/// +/// In order to provide zero-copy slicing, this container stores a separate offset and length +/// +/// For example, a [`RunEndBuffer`] containing values `[3, 6, 8]` with offset and length `4` would +/// describe the physical indices `1, 1, 2, 2` +/// +/// For example, a [`RunEndBuffer`] containing values `[6, 8, 9]` with offset `2` and length `5` +/// would describe the physical indices `0, 0, 0, 0, 1` +/// +/// [Run-End encoded layout]: https://arrow.apache.org/docs/format/Columnar.html#run-end-encoded-layout +#[derive(Debug, Clone)] +pub struct RunEndBuffer { + run_ends: ScalarBuffer, + len: usize, + offset: usize, +} + +impl RunEndBuffer +where + E: ArrowNativeType, +{ + /// Create a new [`RunEndBuffer`] from a [`ScalarBuffer`], an `offset` and `len` + /// + /// # Panics + /// + /// - `buffer` does not contain strictly increasing values greater than zero + /// - the last value of `buffer` is less than `offset + len` + pub fn new(run_ends: ScalarBuffer, offset: usize, len: usize) -> Self { + assert!( + run_ends.windows(2).all(|w| w[0] < w[1]), + "run-ends not strictly increasing" + ); + + if len != 0 { + assert!(!run_ends.is_empty(), "non-empty slice but empty run-ends"); + let end = E::from_usize(offset.saturating_add(len)).unwrap(); + assert!( + *run_ends.first().unwrap() > E::usize_as(0), + "run-ends not greater than 0" + ); + assert!( + *run_ends.last().unwrap() >= end, + "slice beyond bounds of run-ends" + ); + } + + Self { + run_ends, + offset, + len, + } + } + + /// Create a new [`RunEndBuffer`] from an [`ScalarBuffer`], an `offset` and `len` + /// + /// # Safety + /// + /// - `buffer` must contain strictly increasing values greater than zero + /// - The last value of `buffer` must be greater than or equal to `offset + len` + pub unsafe fn new_unchecked(run_ends: ScalarBuffer, offset: usize, len: usize) -> Self { + Self { + run_ends, + offset, + len, + } + } + + /// Returns the logical offset into the run-ends stored by this buffer + #[inline] + pub fn offset(&self) -> usize { + self.offset + } + + /// Returns the logical length of the run-ends stored by this buffer + #[inline] + pub fn len(&self) -> usize { + self.len + } + + /// Returns true if this buffer is empty + #[inline] + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Returns the values of this [`RunEndBuffer`] not including any offset + #[inline] + pub fn values(&self) -> &[E] { + &self.run_ends + } + + /// Returns the maximum run-end encoded in the underlying buffer + #[inline] + pub fn max_value(&self) -> usize { + self.values().last().copied().unwrap_or_default().as_usize() + } + + /// Performs a binary search to find the physical index for the given logical index + /// + /// The result is arbitrary if `logical_index >= self.len()` + pub fn get_physical_index(&self, logical_index: usize) -> usize { + let logical_index = E::usize_as(self.offset + logical_index); + let cmp = |p: &E| p.partial_cmp(&logical_index).unwrap(); + + match self.run_ends.binary_search_by(cmp) { + Ok(idx) => idx + 1, + Err(idx) => idx, + } + } + + /// Returns the physical index at which the logical array starts + pub fn get_start_physical_index(&self) -> usize { + if self.offset == 0 || self.len == 0 { + return 0; + } + // Fallback to binary search + self.get_physical_index(0) + } + + /// Returns the physical index at which the logical array ends + pub fn get_end_physical_index(&self) -> usize { + if self.len == 0 { + return 0; + } + if self.max_value() == self.offset + self.len { + return self.values().len() - 1; + } + // Fallback to binary search + self.get_physical_index(self.len - 1) + } + + /// Slices this [`RunEndBuffer`] by the provided `offset` and `length` + pub fn slice(&self, offset: usize, len: usize) -> Self { + assert!( + offset.saturating_add(len) <= self.len, + "the length + offset of the sliced RunEndBuffer cannot exceed the existing length" + ); + Self { + run_ends: self.run_ends.clone(), + offset: self.offset + offset, + len, + } + } + + /// Returns the inner [`ScalarBuffer`] + pub fn inner(&self) -> &ScalarBuffer { + &self.run_ends + } + + /// Returns the inner [`ScalarBuffer`], consuming self + pub fn into_inner(self) -> ScalarBuffer { + self.run_ends + } +} + +#[cfg(test)] +mod tests { + use crate::buffer::RunEndBuffer; + + #[test] + fn test_zero_length_slice() { + let buffer = RunEndBuffer::new(vec![1_i32, 4_i32].into(), 0, 4); + assert_eq!(buffer.get_start_physical_index(), 0); + assert_eq!(buffer.get_end_physical_index(), 1); + assert_eq!(buffer.get_physical_index(3), 1); + + for offset in 0..4 { + let sliced = buffer.slice(offset, 0); + assert_eq!(sliced.get_start_physical_index(), 0); + assert_eq!(sliced.get_end_physical_index(), 0); + } + + let buffer = RunEndBuffer::new(Vec::::new().into(), 0, 0); + assert_eq!(buffer.get_start_physical_index(), 0); + assert_eq!(buffer.get_end_physical_index(), 0); + } +} diff --git a/arrow-buffer/src/buffer/scalar.rs b/arrow-buffer/src/buffer/scalar.rs new file mode 100644 index 000000000000..f1c2ae785720 --- /dev/null +++ b/arrow-buffer/src/buffer/scalar.rs @@ -0,0 +1,287 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::alloc::Deallocation; +use crate::buffer::Buffer; +use crate::native::ArrowNativeType; +use crate::{BufferBuilder, MutableBuffer, OffsetBuffer}; +use std::fmt::Formatter; +use std::marker::PhantomData; +use std::ops::Deref; + +/// A strongly-typed [`Buffer`] supporting zero-copy cloning and slicing +/// +/// The easiest way to think about `ScalarBuffer` is being equivalent to a `Arc>`, +/// with the following differences: +/// +/// - slicing and cloning is O(1). +/// - it supports external allocated memory +/// +/// ``` +/// # use arrow_buffer::ScalarBuffer; +/// // Zero-copy conversion from Vec +/// let buffer = ScalarBuffer::from(vec![1, 2, 3]); +/// assert_eq!(&buffer, &[1, 2, 3]); +/// +/// // Zero-copy slicing +/// let sliced = buffer.slice(1, 2); +/// assert_eq!(&sliced, &[2, 3]); +/// ``` +#[derive(Clone)] +pub struct ScalarBuffer { + /// Underlying data buffer + buffer: Buffer, + phantom: PhantomData, +} + +impl std::fmt::Debug for ScalarBuffer { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("ScalarBuffer").field(&self.as_ref()).finish() + } +} + +impl ScalarBuffer { + /// Create a new [`ScalarBuffer`] from a [`Buffer`], and an `offset` + /// and `length` in units of `T` + /// + /// # Panics + /// + /// This method will panic if + /// + /// * `offset` or `len` would result in overflow + /// * `buffer` is not aligned to a multiple of `std::mem::size_of::` + /// * `bytes` is not large enough for the requested slice + pub fn new(buffer: Buffer, offset: usize, len: usize) -> Self { + let size = std::mem::size_of::(); + let byte_offset = offset.checked_mul(size).expect("offset overflow"); + let byte_len = len.checked_mul(size).expect("length overflow"); + buffer.slice_with_length(byte_offset, byte_len).into() + } + + /// Returns a zero-copy slice of this buffer with length `len` and starting at `offset` + pub fn slice(&self, offset: usize, len: usize) -> Self { + Self::new(self.buffer.clone(), offset, len) + } + + /// Returns the inner [`Buffer`] + pub fn inner(&self) -> &Buffer { + &self.buffer + } + + /// Returns the inner [`Buffer`], consuming self + pub fn into_inner(self) -> Buffer { + self.buffer + } + + /// Returns true if this [`ScalarBuffer`] is equal to `other`, using pointer comparisons + /// to determine buffer equality. This is cheaper than `PartialEq::eq` but may + /// return false when the arrays are logically equal + #[inline] + pub fn ptr_eq(&self, other: &Self) -> bool { + self.buffer.ptr_eq(&other.buffer) + } +} + +impl Deref for ScalarBuffer { + type Target = [T]; + + #[inline] + fn deref(&self) -> &Self::Target { + // SAFETY: Verified alignment in From + unsafe { + std::slice::from_raw_parts( + self.buffer.as_ptr() as *const T, + self.buffer.len() / std::mem::size_of::(), + ) + } + } +} + +impl AsRef<[T]> for ScalarBuffer { + #[inline] + fn as_ref(&self) -> &[T] { + self + } +} + +impl From for ScalarBuffer { + fn from(value: MutableBuffer) -> Self { + Buffer::from(value).into() + } +} + +impl From for ScalarBuffer { + fn from(buffer: Buffer) -> Self { + let align = std::mem::align_of::(); + let is_aligned = buffer.as_ptr().align_offset(align) == 0; + + match buffer.deallocation() { + Deallocation::Standard(_) => assert!( + is_aligned, + "Memory pointer is not aligned with the specified scalar type" + ), + Deallocation::Custom(_) => + assert!(is_aligned, "Memory pointer from external source (e.g, FFI) is not aligned with the specified scalar type. Before importing buffer through FFI, please make sure the allocation is aligned."), + } + + Self { + buffer, + phantom: Default::default(), + } + } +} + +impl From> for ScalarBuffer { + fn from(value: OffsetBuffer) -> Self { + value.into_inner() + } +} + +impl From> for ScalarBuffer { + fn from(value: Vec) -> Self { + Self { + buffer: Buffer::from_vec(value), + phantom: Default::default(), + } + } +} + +impl From> for ScalarBuffer { + fn from(mut value: BufferBuilder) -> Self { + let len = value.len(); + Self::new(value.finish(), 0, len) + } +} + +impl FromIterator for ScalarBuffer { + fn from_iter>(iter: I) -> Self { + iter.into_iter().collect::>().into() + } +} + +impl<'a, T: ArrowNativeType> IntoIterator for &'a ScalarBuffer { + type Item = &'a T; + type IntoIter = std::slice::Iter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.as_ref().iter() + } +} + +impl + ?Sized> PartialEq for ScalarBuffer { + fn eq(&self, other: &S) -> bool { + self.as_ref().eq(other.as_ref()) + } +} + +impl PartialEq> for [T; N] { + fn eq(&self, other: &ScalarBuffer) -> bool { + self.as_ref().eq(other.as_ref()) + } +} + +impl PartialEq> for [T] { + fn eq(&self, other: &ScalarBuffer) -> bool { + self.as_ref().eq(other.as_ref()) + } +} + +impl PartialEq> for Vec { + fn eq(&self, other: &ScalarBuffer) -> bool { + self.as_slice().eq(other.as_ref()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic() { + let expected = [0_i32, 1, 2]; + let buffer = Buffer::from_iter(expected.iter().cloned()); + let typed = ScalarBuffer::::new(buffer.clone(), 0, 3); + assert_eq!(*typed, expected); + + let typed = ScalarBuffer::::new(buffer.clone(), 1, 2); + assert_eq!(*typed, expected[1..]); + + let typed = ScalarBuffer::::new(buffer.clone(), 1, 0); + assert!(typed.is_empty()); + + let typed = ScalarBuffer::::new(buffer, 3, 0); + assert!(typed.is_empty()); + } + + #[test] + fn test_debug() { + let buffer = ScalarBuffer::from(vec![1, 2, 3]); + assert_eq!(format!("{buffer:?}"), "ScalarBuffer([1, 2, 3])"); + } + + #[test] + #[should_panic(expected = "Memory pointer is not aligned with the specified scalar type")] + fn test_unaligned() { + let expected = [0_i32, 1, 2]; + let buffer = Buffer::from_iter(expected.iter().cloned()); + let buffer = buffer.slice(1); + ScalarBuffer::::new(buffer, 0, 2); + } + + #[test] + #[should_panic(expected = "the offset of the new Buffer cannot exceed the existing length")] + fn test_length_out_of_bounds() { + let buffer = Buffer::from_iter([0_i32, 1, 2]); + ScalarBuffer::::new(buffer, 1, 3); + } + + #[test] + #[should_panic(expected = "the offset of the new Buffer cannot exceed the existing length")] + fn test_offset_out_of_bounds() { + let buffer = Buffer::from_iter([0_i32, 1, 2]); + ScalarBuffer::::new(buffer, 4, 0); + } + + #[test] + #[should_panic(expected = "offset overflow")] + fn test_length_overflow() { + let buffer = Buffer::from_iter([0_i32, 1, 2]); + ScalarBuffer::::new(buffer, usize::MAX, 1); + } + + #[test] + #[should_panic(expected = "offset overflow")] + fn test_start_overflow() { + let buffer = Buffer::from_iter([0_i32, 1, 2]); + ScalarBuffer::::new(buffer, usize::MAX / 4 + 1, 0); + } + + #[test] + #[should_panic(expected = "length overflow")] + fn test_end_overflow() { + let buffer = Buffer::from_iter([0_i32, 1, 2]); + ScalarBuffer::::new(buffer, 0, usize::MAX / 4 + 1); + } + + #[test] + fn convert_from_buffer_builder() { + let input = vec![1, 2, 3, 4]; + let buffer_builder = BufferBuilder::from(input.clone()); + let scalar_buffer = ScalarBuffer::from(buffer_builder); + assert_eq!(scalar_buffer.as_ref(), input); + } +} diff --git a/arrow/src/array/builder/boolean_buffer_builder.rs b/arrow-buffer/src/builder/boolean.rs similarity index 65% rename from arrow/src/array/builder/boolean_buffer_builder.rs rename to arrow-buffer/src/builder/boolean.rs index 5b6d1ce48478..ca178ae5ce4e 100644 --- a/arrow/src/array/builder/boolean_buffer_builder.rs +++ b/arrow-buffer/src/builder/boolean.rs @@ -15,12 +15,10 @@ // specific language governing permissions and limitations // under the License. -use crate::buffer::{Buffer, MutableBuffer}; - -use super::Range; - -use crate::util::bit_util; +use crate::{bit_mask, bit_util, BooleanBuffer, Buffer, MutableBuffer}; +use std::ops::Range; +/// Builder for [`BooleanBuffer`] #[derive(Debug)] pub struct BooleanBufferBuilder { buffer: MutableBuffer, @@ -28,6 +26,7 @@ pub struct BooleanBufferBuilder { } impl BooleanBufferBuilder { + /// Creates a new `BooleanBufferBuilder` #[inline] pub fn new(capacity: usize) -> Self { let byte_capacity = bit_util::ceil(capacity, 8); @@ -35,11 +34,24 @@ impl BooleanBufferBuilder { Self { buffer, len: 0 } } + /// Creates a new `BooleanBufferBuilder` from [`MutableBuffer`] of `len` + pub fn new_from_buffer(buffer: MutableBuffer, len: usize) -> Self { + assert!(len <= buffer.len() * 8); + let mut s = Self { + len: buffer.len() * 8, + buffer, + }; + s.truncate(len); + s + } + + /// Returns the length of the buffer #[inline] pub fn len(&self) -> usize { self.len } + /// Sets a bit in the buffer at `index` #[inline] pub fn set_bit(&mut self, index: usize, v: bool) { if v { @@ -49,21 +61,25 @@ impl BooleanBufferBuilder { } } + /// Gets a bit in the buffer at `index` #[inline] pub fn get_bit(&self, index: usize) -> bool { bit_util::get_bit(self.buffer.as_slice(), index) } + /// Returns true if empty #[inline] pub fn is_empty(&self) -> bool { self.len == 0 } + /// Returns the capacity of the buffer #[inline] pub fn capacity(&self) -> usize { self.buffer.capacity() * 8 } + /// Advances the buffer by `additional` bits #[inline] pub fn advance(&mut self, additional: usize) { let new_len = self.len + additional; @@ -74,6 +90,26 @@ impl BooleanBufferBuilder { self.len = new_len; } + /// Truncates the builder to the given length + /// + /// If `len` is greater than the buffer's current length, this has no effect + #[inline] + pub fn truncate(&mut self, len: usize) { + if len > self.len { + return; + } + + let new_len_bytes = bit_util::ceil(len, 8); + self.buffer.truncate(new_len_bytes); + self.len = len; + + let remainder = self.len % 8; + if remainder != 0 { + let mask = (1_u8 << remainder).wrapping_sub(1); + *self.buffer.as_mut().last_mut().unwrap() &= mask; + } + } + /// Reserve space to at least `additional` new bits. /// Capacity will be `>= self.len() + additional`. /// New bytes are uninitialized and reading them is undefined behavior. @@ -91,11 +127,13 @@ impl BooleanBufferBuilder { /// growing it (potentially reallocating it) and writing `false` in the newly available bits. #[inline] pub fn resize(&mut self, len: usize) { - let len_bytes = bit_util::ceil(len, 8); - self.buffer.resize(len_bytes, 0); - self.len = len; + match len.checked_sub(self.len) { + Some(delta) => self.advance(delta), + None => self.truncate(len), + } } + /// Appends a boolean `v` into the buffer #[inline] pub fn append(&mut self, v: bool) { self.advance(1); @@ -104,17 +142,32 @@ impl BooleanBufferBuilder { } } + /// Appends n `additional` bits of value `v` into the buffer #[inline] pub fn append_n(&mut self, additional: usize, v: bool) { - self.advance(additional); - if additional > 0 && v { - let offset = self.len() - additional; - (0..additional).for_each(|i| unsafe { - bit_util::set_bit_raw(self.buffer.as_mut_ptr(), offset + i) - }) + match v { + true => { + let new_len = self.len + additional; + let new_len_bytes = bit_util::ceil(new_len, 8); + let cur_remainder = self.len % 8; + let new_remainder = new_len % 8; + + if cur_remainder != 0 { + // Pad last byte with 1s + *self.buffer.as_slice_mut().last_mut().unwrap() |= !((1 << cur_remainder) - 1) + } + self.buffer.resize(new_len_bytes, 0xFF); + if new_remainder != 0 { + // Clear remaining bits + *self.buffer.as_slice_mut().last_mut().unwrap() &= (1 << new_remainder) - 1 + } + self.len = new_len; + } + false => self.advance(additional), } } + /// Appends a slice of booleans into the buffer #[inline] pub fn append_slice(&mut self, slice: &[bool]) { let additional = slice.len(); @@ -139,7 +192,7 @@ impl BooleanBufferBuilder { let offset_write = self.len; let len = range.end - range.start; self.advance(len); - crate::util::bit_mask::set_bits( + bit_mask::set_bits( self.buffer.as_slice_mut(), to_set, offset_write, @@ -148,16 +201,33 @@ impl BooleanBufferBuilder { ); } + /// Append [`BooleanBuffer`] to this [`BooleanBufferBuilder`] + pub fn append_buffer(&mut self, buffer: &BooleanBuffer) { + let range = buffer.offset()..buffer.offset() + buffer.len(); + self.append_packed_range(range, buffer.values()) + } + /// Returns the packed bits pub fn as_slice(&self) -> &[u8] { self.buffer.as_slice() } + /// Returns the packed bits + pub fn as_slice_mut(&mut self) -> &mut [u8] { + self.buffer.as_slice_mut() + } + + /// Creates a [`BooleanBuffer`] #[inline] - pub fn finish(&mut self) -> Buffer { + pub fn finish(&mut self) -> BooleanBuffer { let buf = std::mem::replace(&mut self.buffer, MutableBuffer::new(0)); - self.len = 0; - buf.into() + let len = std::mem::replace(&mut self.len, 0); + BooleanBuffer::new(buf.into(), 0, len) + } + + /// Builds the [BooleanBuffer] without resetting the builder. + pub fn finish_cloned(&self) -> BooleanBuffer { + BooleanBuffer::new(Buffer::from_slice_ref(self.as_slice()), 0, self.len) } } @@ -168,6 +238,13 @@ impl From for Buffer { } } +impl From for BooleanBuffer { + #[inline] + fn from(builder: BooleanBufferBuilder) -> Self { + BooleanBuffer::new(builder.buffer.into(), 0, builder.len) + } +} + #[cfg(test)] mod tests { use super::*; @@ -182,7 +259,7 @@ mod tests { assert_eq!(4, b.len()); assert_eq!(512, b.capacity()); let buffer = b.finish(); - assert_eq!(1, buffer.len()); + assert_eq!(4, buffer.len()); // Overallocate capacity let mut b = BooleanBufferBuilder::new(8); @@ -190,7 +267,7 @@ mod tests { assert_eq!(4, b.len()); assert_eq!(512, b.capacity()); let buffer = b.finish(); - assert_eq!(1, buffer.len()); + assert_eq!(4, buffer.len()); } #[test] @@ -202,7 +279,7 @@ mod tests { buffer.append(true); buffer.set_bit(0, false); assert_eq!(buffer.len(), 4); - assert_eq!(buffer.finish().as_slice(), &[0b1010_u8]); + assert_eq!(buffer.finish().values(), &[0b1010_u8]); } #[test] @@ -214,7 +291,7 @@ mod tests { buffer.append(true); buffer.set_bit(3, false); assert_eq!(buffer.len(), 4); - assert_eq!(buffer.finish().as_slice(), &[0b0011_u8]); + assert_eq!(buffer.finish().values(), &[0b0011_u8]); } #[test] @@ -226,7 +303,7 @@ mod tests { buffer.append(true); buffer.set_bit(1, false); assert_eq!(buffer.len(), 4); - assert_eq!(buffer.finish().as_slice(), &[0b1001_u8]); + assert_eq!(buffer.finish().values(), &[0b1001_u8]); } #[test] @@ -240,7 +317,7 @@ mod tests { buffer.set_bit(1, false); buffer.set_bit(2, false); assert_eq!(buffer.len(), 5); - assert_eq!(buffer.finish().as_slice(), &[0b10001_u8]); + assert_eq!(buffer.finish().values(), &[0b10001_u8]); } #[test] @@ -251,7 +328,7 @@ mod tests { buffer.set_bit(3, false); buffer.set_bit(9, false); assert_eq!(buffer.len(), 10); - assert_eq!(buffer.finish().as_slice(), &[0b11110110_u8, 0b01_u8]); + assert_eq!(buffer.finish().values(), &[0b11110110_u8, 0b01_u8]); } #[test] @@ -267,7 +344,7 @@ mod tests { buffer.set_bit(14, true); buffer.set_bit(13, false); assert_eq!(buffer.len(), 15); - assert_eq!(buffer.finish().as_slice(), &[0b01010110_u8, 0b1011100_u8]); + assert_eq!(buffer.finish().values(), &[0b01010110_u8, 0b1011100_u8]); } #[test] @@ -332,7 +409,7 @@ mod tests { let start = a.min(b); let end = a.max(b); - buffer.append_packed_range(start..end, compacted_src.as_slice()); + buffer.append_packed_range(start..end, compacted_src.values()); all_bools.extend_from_slice(&src[start..end]); } @@ -362,6 +439,45 @@ mod tests { assert_eq!(builder.as_slice(), &[0b11101111, 0b00000001]); } + #[test] + fn test_truncate() { + let b = MutableBuffer::from_iter([true, true, true, true]); + let mut builder = BooleanBufferBuilder::new_from_buffer(b, 2); + builder.advance(2); + let finished = builder.finish(); + assert_eq!(finished.values(), &[0b00000011]); + + let mut builder = BooleanBufferBuilder::new(10); + builder.append_n(5, true); + builder.resize(3); + builder.advance(2); + let finished = builder.finish(); + assert_eq!(finished.values(), &[0b00000111]); + + let mut builder = BooleanBufferBuilder::new(10); + builder.append_n(16, true); + assert_eq!(builder.as_slice(), &[0xFF, 0xFF]); + builder.truncate(20); + assert_eq!(builder.as_slice(), &[0xFF, 0xFF]); + builder.truncate(14); + assert_eq!(builder.as_slice(), &[0xFF, 0b00111111]); + builder.append(false); + builder.append(true); + assert_eq!(builder.as_slice(), &[0xFF, 0b10111111]); + builder.append_packed_range(0..3, &[0xFF]); + assert_eq!(builder.as_slice(), &[0xFF, 0b10111111, 0b00000111]); + builder.truncate(17); + assert_eq!(builder.as_slice(), &[0xFF, 0b10111111, 0b00000001]); + builder.append_packed_range(0..2, &[2]); + assert_eq!(builder.as_slice(), &[0xFF, 0b10111111, 0b0000101]); + builder.truncate(8); + assert_eq!(builder.as_slice(), &[0xFF]); + builder.resize(14); + assert_eq!(builder.as_slice(), &[0xFF, 0x00]); + builder.truncate(0); + assert_eq!(builder.as_slice(), &[]); + } + #[test] fn test_boolean_builder_increases_buffer_len() { // 00000010 01001000 @@ -377,7 +493,7 @@ mod tests { } let buf2 = builder.finish(); - assert_eq!(buf.len(), buf2.len()); - assert_eq!(buf.as_slice(), buf2.as_slice()); + assert_eq!(buf.len(), buf2.inner().len()); + assert_eq!(buf.as_slice(), buf2.values()); } } diff --git a/arrow/src/array/builder/buffer_builder.rs b/arrow-buffer/src/builder/mod.rs similarity index 58% rename from arrow/src/array/builder/buffer_builder.rs rename to arrow-buffer/src/builder/mod.rs index a6a81dfd6c0e..d5d5a7d3f18d 100644 --- a/arrow/src/array/builder/buffer_builder.rs +++ b/arrow-buffer/src/builder/mod.rs @@ -15,35 +15,34 @@ // specific language governing permissions and limitations // under the License. -use std::mem; +//! Buffer builders -use crate::buffer::{Buffer, MutableBuffer}; -use crate::datatypes::ArrowNativeType; +mod boolean; +pub use boolean::*; +mod null; +pub use null::*; -use super::PhantomData; +use crate::{ArrowNativeType, Buffer, MutableBuffer}; +use std::{iter, marker::PhantomData}; -/// Builder for creating a [`Buffer`](crate::buffer::Buffer) object. +/// Builder for creating a [Buffer] object. /// -/// A [`Buffer`](crate::buffer::Buffer) is the underlying data -/// structure of Arrow's [`Arrays`](crate::array::Array). +/// A [Buffer] is the underlying data structure of Arrow's Arrays. /// /// For all supported types, there are type definitions for the -/// generic version of `BufferBuilder`, e.g. `UInt8BufferBuilder`. +/// generic version of `BufferBuilder`, e.g. `BufferBuilder`. /// /// # Example: /// /// ``` -/// use arrow::array::UInt8BufferBuilder; +/// # use arrow_buffer::builder::BufferBuilder; /// -/// # fn main() -> arrow::error::Result<()> { -/// let mut builder = UInt8BufferBuilder::new(100); +/// let mut builder = BufferBuilder::::new(100); /// builder.append_slice(&[42, 43, 44]); /// builder.append(45); /// let buffer = builder.finish(); /// /// assert_eq!(unsafe { buffer.typed_data::() }, &[42, 43, 44, 45]); -/// # Ok(()) -/// # } /// ``` #[derive(Debug)] pub struct BufferBuilder { @@ -67,15 +66,15 @@ impl BufferBuilder { /// # Example: /// /// ``` - /// use arrow::array::UInt8BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt8BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// /// assert!(builder.capacity() >= 10); /// ``` #[inline] pub fn new(capacity: usize) -> Self { - let buffer = MutableBuffer::new(capacity * mem::size_of::()); + let buffer = MutableBuffer::new(capacity * std::mem::size_of::()); Self { buffer, @@ -84,14 +83,24 @@ impl BufferBuilder { } } + /// Creates a new builder from a [`MutableBuffer`] + pub fn new_from_buffer(buffer: MutableBuffer) -> Self { + let buffer_len = buffer.len(); + Self { + buffer, + len: buffer_len / std::mem::size_of::(), + _marker: PhantomData, + } + } + /// Returns the current number of array elements in the internal buffer. /// /// # Example: /// /// ``` - /// use arrow::array::UInt8BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt8BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// builder.append(42); /// /// assert_eq!(builder.len(), 1); @@ -105,9 +114,9 @@ impl BufferBuilder { /// # Example: /// /// ``` - /// use arrow::array::UInt8BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt8BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// builder.append(42); /// /// assert_eq!(builder.is_empty(), false); @@ -136,16 +145,16 @@ impl BufferBuilder { /// # Example: /// /// ``` - /// use arrow::array::UInt8BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt8BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// builder.advance(2); /// /// assert_eq!(builder.len(), 2); /// ``` #[inline] pub fn advance(&mut self, i: usize) { - self.buffer.extend_zeros(i * mem::size_of::()); + self.buffer.extend_zeros(i * std::mem::size_of::()); self.len += i; } @@ -154,16 +163,16 @@ impl BufferBuilder { /// # Example: /// /// ``` - /// use arrow::array::UInt8BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt8BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// builder.reserve(10); /// /// assert!(builder.capacity() >= 20); /// ``` #[inline] pub fn reserve(&mut self, n: usize) { - self.buffer.reserve(n * mem::size_of::()); + self.buffer.reserve(n * std::mem::size_of::()); } /// Appends a value of type `T` into the builder, @@ -172,9 +181,9 @@ impl BufferBuilder { /// # Example: /// /// ``` - /// use arrow::array::UInt8BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt8BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// builder.append(42); /// /// assert_eq!(builder.len(), 1); @@ -192,9 +201,9 @@ impl BufferBuilder { /// # Example: /// /// ``` - /// use arrow::array::UInt8BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt8BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// builder.append_n(10, 42); /// /// assert_eq!(builder.len(), 10); @@ -202,10 +211,7 @@ impl BufferBuilder { #[inline] pub fn append_n(&mut self, n: usize, v: T) { self.reserve(n); - for _ in 0..n { - self.buffer.push(v); - } - self.len += n; + self.extend(iter::repeat(v).take(n)) } /// Appends `n`, zero-initialized values @@ -213,16 +219,16 @@ impl BufferBuilder { /// # Example: /// /// ``` - /// use arrow::array::UInt32BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt32BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// builder.append_n_zeroed(3); /// /// assert_eq!(builder.len(), 3); /// assert_eq!(builder.as_slice(), &[0, 0, 0]) #[inline] pub fn append_n_zeroed(&mut self, n: usize) { - self.buffer.extend_zeros(n * mem::size_of::()); + self.buffer.extend_zeros(n * std::mem::size_of::()); self.len += n; } @@ -231,9 +237,9 @@ impl BufferBuilder { /// # Example: /// /// ``` - /// use arrow::array::UInt8BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt8BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// builder.append_slice(&[42, 44, 46]); /// /// assert_eq!(builder.len(), 3); @@ -247,9 +253,9 @@ impl BufferBuilder { /// View the contents of this buffer as a slice /// /// ``` - /// use arrow::array::Float64BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = Float64BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// builder.append(1.3); /// builder.append_n(2, 2.3); /// @@ -270,9 +276,9 @@ impl BufferBuilder { /// # Example: /// /// ``` - /// use arrow::array::Float32BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = Float32BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// /// builder.append_slice(&[1., 2., 3.4]); /// assert_eq!(builder.as_slice(), &[1., 2., 3.4]); @@ -297,9 +303,9 @@ impl BufferBuilder { /// # Example: /// /// ``` - /// use arrow::array::UInt16BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt16BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// /// builder.append_slice(&[42, 44, 46]); /// assert_eq!(builder.as_slice(), &[42, 44, 46]); @@ -312,7 +318,7 @@ impl BufferBuilder { /// ``` #[inline] pub fn truncate(&mut self, len: usize) { - self.buffer.truncate(len * mem::size_of::()); + self.buffer.truncate(len * std::mem::size_of::()); self.len = len; } @@ -327,20 +333,17 @@ impl BufferBuilder { .1 .expect("append_trusted_len_iter expects upper bound"); self.reserve(len); - for v in iter { - self.buffer.push(v) - } - self.len += len; + self.extend(iter); } - /// Resets this builder and returns an immutable [`Buffer`](crate::buffer::Buffer). + /// Resets this builder and returns an immutable [Buffer]. /// /// # Example: /// /// ``` - /// use arrow::array::UInt8BufferBuilder; + /// # use arrow_buffer::builder::BufferBuilder; /// - /// let mut builder = UInt8BufferBuilder::new(10); + /// let mut builder = BufferBuilder::::new(10); /// builder.append_slice(&[42, 44, 46]); /// /// let buffer = builder.finish(); @@ -349,133 +352,67 @@ impl BufferBuilder { /// ``` #[inline] pub fn finish(&mut self) -> Buffer { - let buf = std::mem::replace(&mut self.buffer, MutableBuffer::new(0)); + let buf = std::mem::take(&mut self.buffer); self.len = 0; buf.into() } } -#[cfg(test)] -mod tests { - use crate::array::array::Array; - use crate::array::builder::ArrayBuilder; - use crate::array::Int32BufferBuilder; - use crate::array::Int8Builder; - use crate::array::UInt8BufferBuilder; - - #[test] - fn test_builder_i32_empty() { - let mut b = Int32BufferBuilder::new(5); - assert_eq!(0, b.len()); - assert_eq!(16, b.capacity()); - let a = b.finish(); - assert_eq!(0, a.len()); +impl Default for BufferBuilder { + fn default() -> Self { + Self::new(0) } +} - #[test] - fn test_builder_i32_alloc_zero_bytes() { - let mut b = Int32BufferBuilder::new(0); - b.append(123); - let a = b.finish(); - assert_eq!(4, a.len()); +impl Extend for BufferBuilder { + fn extend>(&mut self, iter: I) { + self.buffer.extend(iter.into_iter().inspect(|_| { + self.len += 1; + })) } +} - #[test] - fn test_builder_i32() { - let mut b = Int32BufferBuilder::new(5); - for i in 0..5 { - b.append(i); - } - assert_eq!(16, b.capacity()); - let a = b.finish(); - assert_eq!(20, a.len()); +impl From> for BufferBuilder { + fn from(value: Vec) -> Self { + Self::new_from_buffer(MutableBuffer::from(value)) } +} - #[test] - fn test_builder_i32_grow_buffer() { - let mut b = Int32BufferBuilder::new(2); - assert_eq!(16, b.capacity()); - for i in 0..20 { - b.append(i); - } - assert_eq!(32, b.capacity()); - let a = b.finish(); - assert_eq!(80, a.len()); +impl FromIterator for BufferBuilder { + fn from_iter>(iter: I) -> Self { + let mut builder = Self::default(); + builder.extend(iter); + builder } +} - #[test] - fn test_builder_finish() { - let mut b = Int32BufferBuilder::new(5); - assert_eq!(16, b.capacity()); - for i in 0..10 { - b.append(i); - } - let mut a = b.finish(); - assert_eq!(40, a.len()); - assert_eq!(0, b.len()); - assert_eq!(0, b.capacity()); - - // Try build another buffer after cleaning up. - for i in 0..20 { - b.append(i) - } - assert_eq!(32, b.capacity()); - a = b.finish(); - assert_eq!(80, a.len()); - } +#[cfg(test)] +mod tests { + use super::*; + use std::mem; #[test] - fn test_reserve() { - let mut b = UInt8BufferBuilder::new(2); - assert_eq!(64, b.capacity()); - b.reserve(64); - assert_eq!(64, b.capacity()); - b.reserve(65); - assert_eq!(128, b.capacity()); - - let mut b = Int32BufferBuilder::new(2); - assert_eq!(16, b.capacity()); - b.reserve(16); - assert_eq!(16, b.capacity()); - b.reserve(17); - assert_eq!(32, b.capacity()); + fn default() { + let builder = BufferBuilder::::default(); + assert!(builder.is_empty()); + assert!(builder.buffer.is_empty()); + assert_eq!(builder.buffer.capacity(), 0); } #[test] - fn test_append_slice() { - let mut b = UInt8BufferBuilder::new(0); - b.append_slice(b"Hello, "); - b.append_slice(b"World!"); - let buffer = b.finish(); - assert_eq!(13, buffer.len()); - - let mut b = Int32BufferBuilder::new(0); - b.append_slice(&[32, 54]); - let buffer = b.finish(); - assert_eq!(8, buffer.len()); + fn from_iter() { + let input = [1u16, 2, 3, 4]; + let builder = input.into_iter().collect::>(); + assert_eq!(builder.len(), 4); + assert_eq!(builder.buffer.len(), 4 * mem::size_of::()); } #[test] - fn test_append_values() { - let mut a = Int8Builder::new(); - a.append_value(1); - a.append_null(); - a.append_value(-2); - assert_eq!(a.len(), 3); - - // append values - let values = &[1, 2, 3, 4]; - let is_valid = &[true, true, false, true]; - a.append_values(values, is_valid); - - assert_eq!(a.len(), 7); - let array = a.finish(); - assert_eq!(array.value(0), 1); - assert!(array.is_null(1)); - assert_eq!(array.value(2), -2); - assert_eq!(array.value(3), 1); - assert_eq!(array.value(4), 2); - assert!(array.is_null(5)); - assert_eq!(array.value(6), 4); + fn extend() { + let input = [1, 2]; + let mut builder = input.into_iter().collect::>(); + assert_eq!(builder.len(), 2); + builder.extend([3, 4]); + assert_eq!(builder.len(), 4); } } diff --git a/arrow/src/array/builder/null_buffer_builder.rs b/arrow-buffer/src/builder/null.rs similarity index 79% rename from arrow/src/array/builder/null_buffer_builder.rs rename to arrow-buffer/src/builder/null.rs index ef2e4c50ab9c..d805b79f09e6 100644 --- a/arrow/src/array/builder/null_buffer_builder.rs +++ b/arrow-buffer/src/builder/null.rs @@ -15,9 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::buffer::Buffer; - -use super::BooleanBufferBuilder; +use crate::{BooleanBufferBuilder, MutableBuffer, NullBuffer}; /// Builder for creating the null bit buffer. /// This builder only materializes the buffer when we append `false`. @@ -25,7 +23,7 @@ use super::BooleanBufferBuilder; /// `None` when calling [`finish`](#method.finish). /// This optimization is **very** important for the performance. #[derive(Debug)] -pub(super) struct NullBufferBuilder { +pub struct NullBufferBuilder { bitmap_builder: Option, /// Store the length of the buffer before materializing. len: usize, @@ -43,6 +41,29 @@ impl NullBufferBuilder { } } + /// Creates a new builder with given length. + pub fn new_with_len(len: usize) -> Self { + Self { + bitmap_builder: None, + len, + capacity: len, + } + } + + /// Creates a new builder from a `MutableBuffer`. + pub fn new_from_buffer(buffer: MutableBuffer, len: usize) -> Self { + let capacity = buffer.len() * 8; + + assert!(len < capacity); + + let bitmap_builder = Some(BooleanBufferBuilder::new_from_buffer(buffer, len)); + Self { + bitmap_builder, + len, + capacity, + } + } + /// Appends `n` `true`s into the builder /// to indicate that these `n` items are not nulls. #[inline] @@ -106,14 +127,22 @@ impl NullBufferBuilder { /// Builds the null buffer and resets the builder. /// Returns `None` if the builder only contains `true`s. - pub fn finish(&mut self) -> Option { - let buf = self.bitmap_builder.as_mut().map(|b| b.finish()); - self.bitmap_builder = None; + pub fn finish(&mut self) -> Option { self.len = 0; - buf + Some(NullBuffer::new(self.bitmap_builder.take()?.finish())) + } + + /// Builds the [NullBuffer] without resetting the builder. + pub fn finish_cloned(&self) -> Option { + let buffer = self.bitmap_builder.as_ref()?.finish_cloned(); + Some(NullBuffer::new(buffer)) + } + + /// Returns the inner bitmap builder as slice + pub fn as_slice(&self) -> Option<&[u8]> { + Some(self.bitmap_builder.as_ref()?.as_slice()) } - #[inline] fn materialize_if_needed(&mut self) { if self.bitmap_builder.is_none() { self.materialize() @@ -128,6 +157,10 @@ impl NullBufferBuilder { self.bitmap_builder = Some(b); } } + + pub fn as_slice_mut(&mut self) -> Option<&mut [u8]> { + self.bitmap_builder.as_mut().map(|b| b.as_slice_mut()) + } } impl NullBufferBuilder { @@ -158,7 +191,7 @@ mod tests { assert_eq!(6, builder.len()); let buf = builder.finish().unwrap(); - assert_eq!(Buffer::from(&[0b110010_u8]), buf); + assert_eq!(&[0b110010_u8], buf.validity()); } #[test] @@ -170,7 +203,7 @@ mod tests { assert_eq!(6, builder.len()); let buf = builder.finish().unwrap(); - assert_eq!(Buffer::from(&[0b0_u8]), buf); + assert_eq!(&[0b0_u8], buf.validity()); } #[test] @@ -199,6 +232,6 @@ mod tests { builder.append_slice(&[true, true, false, true]); let buf = builder.finish().unwrap(); - assert_eq!(Buffer::from(&[0b1011_u8]), buf); + assert_eq!(&[0b1011_u8], buf.validity()); } } diff --git a/arrow/src/bytes.rs b/arrow-buffer/src/bytes.rs similarity index 76% rename from arrow/src/bytes.rs rename to arrow-buffer/src/bytes.rs index 75137a55295b..81860b604868 100644 --- a/arrow/src/bytes.rs +++ b/arrow-buffer/src/bytes.rs @@ -23,15 +23,15 @@ use core::slice; use std::ptr::NonNull; use std::{fmt::Debug, fmt::Formatter}; -use crate::alloc; use crate::alloc::Deallocation; /// A continuous, fixed-size, immutable memory region that knows how to de-allocate itself. +/// /// This structs' API is inspired by the `bytes::Bytes`, but it is not limited to using rust's /// global allocator nor u8 alignment. /// -/// In the most common case, this buffer is allocated using [`allocate_aligned`](crate::alloc::allocate_aligned) -/// and deallocated accordingly [`free_aligned`](crate::alloc::free_aligned). +/// In the most common case, this buffer is allocated using [`alloc`](std::alloc::alloc) +/// with an alignment of [`ALIGNMENT`](crate::alloc::ALIGNMENT) /// /// When the region is allocated by a different allocator, [Deallocation::Custom], this calls the /// custom deallocator to deallocate the region when it is no longer needed. @@ -53,18 +53,14 @@ impl Bytes { /// /// * `ptr` - Pointer to raw parts /// * `len` - Length of raw parts in **bytes** - /// * `capacity` - Total allocated memory for the pointer `ptr`, in **bytes** + /// * `deallocation` - Type of allocation /// /// # Safety /// /// This function is unsafe as there is no guarantee that the given pointer is valid for `len` /// bytes. If the `ptr` and `capacity` come from a `Buffer`, then this is guaranteed. #[inline] - pub(crate) unsafe fn new( - ptr: std::ptr::NonNull, - len: usize, - deallocation: Deallocation, - ) -> Bytes { + pub(crate) unsafe fn new(ptr: NonNull, len: usize, deallocation: Deallocation) -> Bytes { Bytes { ptr, len, @@ -93,12 +89,17 @@ impl Bytes { pub fn capacity(&self) -> usize { match self.deallocation { - Deallocation::Arrow(capacity) => capacity, + Deallocation::Standard(layout) => layout.size(), // we cannot determine this in general, // and thus we state that this is externally-owned memory Deallocation::Custom(_) => 0, } } + + #[inline] + pub(crate) fn deallocation(&self) -> &Deallocation { + &self.deallocation + } } // Deallocation is Send + Sync, repeating the bound here makes that refactoring safe @@ -110,9 +111,10 @@ impl Drop for Bytes { #[inline] fn drop(&mut self) { match &self.deallocation { - Deallocation::Arrow(capacity) => { - unsafe { alloc::free_aligned::(self.ptr, *capacity) }; - } + Deallocation::Standard(layout) => match layout.size() { + 0 => {} // Nothing to do + _ => unsafe { std::alloc::dealloc(self.ptr.as_ptr(), *layout) }, + }, // The automatic drop implementation will free the memory once the reference count reaches zero Deallocation::Custom(_allocation) => (), } @@ -142,3 +144,31 @@ impl Debug for Bytes { write!(f, " }}") } } + +impl From for Bytes { + fn from(value: bytes::Bytes) -> Self { + Self { + len: value.len(), + ptr: NonNull::new(value.as_ptr() as _).unwrap(), + deallocation: Deallocation::Custom(std::sync::Arc::new(value)), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_bytes() { + let bytes = bytes::Bytes::from(vec![1, 2, 3, 4]); + let arrow_bytes: Bytes = bytes.clone().into(); + + assert_eq!(bytes.as_ptr(), arrow_bytes.as_ptr()); + + drop(bytes); + drop(arrow_bytes); + + let _ = Bytes::from(bytes::Bytes::new()); + } +} diff --git a/arrow-buffer/src/lib.rs b/arrow-buffer/src/lib.rs new file mode 100644 index 000000000000..cbcdb979e693 --- /dev/null +++ b/arrow-buffer/src/lib.rs @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Low-level buffer abstractions for [Apache Arrow Rust](https://docs.rs/arrow) + +pub mod alloc; +pub mod buffer; +pub use buffer::*; + +pub mod builder; +pub use builder::*; + +mod bigint; +mod bytes; +mod native; +pub use bigint::i256; + +pub use native::*; +mod util; +pub use util::*; diff --git a/arrow-buffer/src/native.rs b/arrow-buffer/src/native.rs new file mode 100644 index 000000000000..38074a8dc26c --- /dev/null +++ b/arrow-buffer/src/native.rs @@ -0,0 +1,259 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::i256; +use half::f16; + +mod private { + pub trait Sealed {} +} + +/// Trait expressing a Rust type that has the same in-memory representation +/// as Arrow. This includes `i16`, `f32`, but excludes `bool` (which in arrow is represented in bits). +/// +/// In little endian machines, types that implement [`ArrowNativeType`] can be memcopied to arrow buffers +/// as is. +/// +/// # Transmute Safety +/// +/// A type T implementing this trait means that any arbitrary slice of bytes of length and +/// alignment `size_of::()` can be safely interpreted as a value of that type without +/// being unsound, i.e. potentially resulting in undefined behaviour. +/// +/// Note: in the case of floating point numbers this transmutation can result in a signalling +/// NaN, which, whilst sound, can be unwieldy. In general, whilst it is perfectly sound to +/// reinterpret bytes as different types using this trait, it is likely unwise. For more information +/// see [f32::from_bits] and [f64::from_bits]. +/// +/// Note: `bool` is restricted to `0` or `1`, and so `bool: !ArrowNativeType` +/// +/// # Sealed +/// +/// Due to the above restrictions, this trait is sealed to prevent accidental misuse +pub trait ArrowNativeType: + std::fmt::Debug + Send + Sync + Copy + PartialOrd + Default + private::Sealed + 'static +{ + /// Convert native integer type from usize + /// + /// Returns `None` if [`Self`] is not an integer or conversion would result + /// in truncation/overflow + fn from_usize(_: usize) -> Option; + + /// Convert to usize according to the [`as`] operator + /// + /// [`as`]: https://doc.rust-lang.org/reference/expressions/operator-expr.html#numeric-cast + fn as_usize(self) -> usize; + + /// Convert from usize according to the [`as`] operator + /// + /// [`as`]: https://doc.rust-lang.org/reference/expressions/operator-expr.html#numeric-cast + fn usize_as(i: usize) -> Self; + + /// Convert native type to usize. + /// + /// Returns `None` if [`Self`] is not an integer or conversion would result + /// in truncation/overflow + fn to_usize(self) -> Option; + + /// Convert native type to isize. + /// + /// Returns `None` if [`Self`] is not an integer or conversion would result + /// in truncation/overflow + fn to_isize(self) -> Option; + + /// Convert native type from i32. + /// + /// Returns `None` if [`Self`] is not `i32` + #[deprecated(note = "please use `Option::Some` instead")] + fn from_i32(_: i32) -> Option { + None + } + + /// Convert native type from i64. + /// + /// Returns `None` if [`Self`] is not `i64` + #[deprecated(note = "please use `Option::Some` instead")] + fn from_i64(_: i64) -> Option { + None + } + + /// Convert native type from i128. + /// + /// Returns `None` if [`Self`] is not `i128` + #[deprecated(note = "please use `Option::Some` instead")] + fn from_i128(_: i128) -> Option { + None + } +} + +macro_rules! native_integer { + ($t: ty $(, $from:ident)*) => { + impl private::Sealed for $t {} + impl ArrowNativeType for $t { + #[inline] + fn from_usize(v: usize) -> Option { + v.try_into().ok() + } + + #[inline] + fn to_usize(self) -> Option { + self.try_into().ok() + } + + #[inline] + fn to_isize(self) -> Option { + self.try_into().ok() + } + + #[inline] + fn as_usize(self) -> usize { + self as _ + } + + #[inline] + fn usize_as(i: usize) -> Self { + i as _ + } + + + $( + #[inline] + fn $from(v: $t) -> Option { + Some(v) + } + )* + } + }; +} + +native_integer!(i8); +native_integer!(i16); +native_integer!(i32, from_i32); +native_integer!(i64, from_i64); +native_integer!(i128, from_i128); +native_integer!(u8); +native_integer!(u16); +native_integer!(u32); +native_integer!(u64); + +macro_rules! native_float { + ($t:ty, $s:ident, $as_usize: expr, $i:ident, $usize_as: expr) => { + impl private::Sealed for $t {} + impl ArrowNativeType for $t { + #[inline] + fn from_usize(_: usize) -> Option { + None + } + + #[inline] + fn to_usize(self) -> Option { + None + } + + #[inline] + fn to_isize(self) -> Option { + None + } + + #[inline] + fn as_usize($s) -> usize { + $as_usize + } + + #[inline] + fn usize_as($i: usize) -> Self { + $usize_as + } + } + }; +} + +native_float!(f16, self, self.to_f32() as _, i, f16::from_f32(i as _)); +native_float!(f32, self, self as _, i, i as _); +native_float!(f64, self, self as _, i, i as _); + +impl private::Sealed for i256 {} +impl ArrowNativeType for i256 { + fn from_usize(u: usize) -> Option { + Some(Self::from_parts(u as u128, 0)) + } + + fn as_usize(self) -> usize { + self.to_parts().0 as usize + } + + fn usize_as(i: usize) -> Self { + Self::from_parts(i as u128, 0) + } + + fn to_usize(self) -> Option { + let (low, high) = self.to_parts(); + if high != 0 { + return None; + } + low.try_into().ok() + } + + fn to_isize(self) -> Option { + self.to_i128()?.try_into().ok() + } +} + +/// Allows conversion from supported Arrow types to a byte slice. +pub trait ToByteSlice { + /// Converts this instance into a byte slice + fn to_byte_slice(&self) -> &[u8]; +} + +impl ToByteSlice for [T] { + #[inline] + fn to_byte_slice(&self) -> &[u8] { + let raw_ptr = self.as_ptr() as *const u8; + unsafe { std::slice::from_raw_parts(raw_ptr, std::mem::size_of_val(self)) } + } +} + +impl ToByteSlice for T { + #[inline] + fn to_byte_slice(&self) -> &[u8] { + let raw_ptr = self as *const T as *const u8; + unsafe { std::slice::from_raw_parts(raw_ptr, std::mem::size_of::()) } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_i256() { + let a = i256::from_parts(0, 0); + assert_eq!(a.as_usize(), 0); + assert_eq!(a.to_usize().unwrap(), 0); + assert_eq!(a.to_isize().unwrap(), 0); + + let a = i256::from_parts(0, -1); + assert_eq!(a.as_usize(), 0); + assert!(a.to_usize().is_none()); + assert!(a.to_usize().is_none()); + + let a = i256::from_parts(u128::MAX, -1); + assert_eq!(a.as_usize(), usize::MAX); + assert!(a.to_usize().is_none()); + assert_eq!(a.to_isize().unwrap(), -1); + } +} diff --git a/arrow/src/util/bit_chunk_iterator.rs b/arrow-buffer/src/util/bit_chunk_iterator.rs similarity index 94% rename from arrow/src/util/bit_chunk_iterator.rs rename to arrow-buffer/src/util/bit_chunk_iterator.rs index f0127ed2267f..9e4fb8268dff 100644 --- a/arrow/src/util/bit_chunk_iterator.rs +++ b/arrow-buffer/src/util/bit_chunk_iterator.rs @@ -60,8 +60,7 @@ impl<'a> UnalignedBitChunk<'a> { // If less than 8 bytes, read into prefix if buffer.len() <= 8 { - let (suffix_mask, trailing_padding) = - compute_suffix_mask(len, offset_padding); + let (suffix_mask, trailing_padding) = compute_suffix_mask(len, offset_padding); let prefix = read_u64(buffer) & suffix_mask & prefix_mask; return Self { @@ -75,8 +74,7 @@ impl<'a> UnalignedBitChunk<'a> { // If less than 16 bytes, read into prefix and suffix if buffer.len() <= 16 { - let (suffix_mask, trailing_padding) = - compute_suffix_mask(len, offset_padding); + let (suffix_mask, trailing_padding) = compute_suffix_mask(len, offset_padding); let prefix = read_u64(&buffer[..8]) & prefix_mask; let suffix = read_u64(&buffer[8..]) & suffix_mask; @@ -153,11 +151,11 @@ impl<'a> UnalignedBitChunk<'a> { self.chunks } - pub(crate) fn iter(&self) -> UnalignedBitChunkIterator<'a> { + pub fn iter(&self) -> UnalignedBitChunkIterator<'a> { self.prefix .into_iter() .chain(self.chunks.iter().cloned()) - .chain(self.suffix.into_iter()) + .chain(self.suffix) } /// Counts the number of ones @@ -166,11 +164,8 @@ impl<'a> UnalignedBitChunk<'a> { } } -pub(crate) type UnalignedBitChunkIterator<'a> = std::iter::Chain< - std::iter::Chain< - std::option::IntoIter, - std::iter::Cloned>, - >, +pub type UnalignedBitChunkIterator<'a> = std::iter::Chain< + std::iter::Chain, std::iter::Cloned>>, std::option::IntoIter, >; @@ -178,7 +173,7 @@ pub(crate) type UnalignedBitChunkIterator<'a> = std::iter::Chain< fn read_u64(input: &[u8]) -> u64 { let len = input.len().min(8); let mut buf = [0_u8; 8]; - (&mut buf[..len]).copy_from_slice(input); + buf[..len].copy_from_slice(input); u64::from_le_bytes(buf) } @@ -296,6 +291,12 @@ impl<'a> BitChunks<'a> { index: 0, } } + + /// Returns an iterator over chunks of 64 bits, with the remaining bits zero padded to 64-bits + #[inline] + pub fn iter_padded(&self) -> impl Iterator + 'a { + self.iter().chain(std::iter::once(self.remainder_bits())) + } } impl<'a> IntoIterator for BitChunks<'a> { @@ -332,9 +333,8 @@ impl Iterator for BitChunkIterator<'_> { } else { // the constructor ensures that bit_offset is in 0..8 // that means we need to read at most one additional byte to fill in the high bits - let next = unsafe { - std::ptr::read_unaligned(raw_data.add(index + 1) as *const u8) as u64 - }; + let next = + unsafe { std::ptr::read_unaligned(raw_data.add(index + 1) as *const u8) as u64 }; (current >> bit_offset) | (next << (64 - bit_offset)) }; @@ -381,8 +381,8 @@ mod tests { #[test] fn test_iter_unaligned() { let input: &[u8] = &[ - 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, - 0b00100000, 0b01000000, 0b11111111, + 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000, + 0b01000000, 0b11111111, ]; let buffer: Buffer = Buffer::from(input); @@ -402,8 +402,8 @@ mod tests { #[test] fn test_iter_unaligned_remainder_1_byte() { let input: &[u8] = &[ - 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, - 0b00100000, 0b01000000, 0b11111111, + 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000, + 0b01000000, 0b11111111, ]; let buffer: Buffer = Buffer::from(input); @@ -436,8 +436,8 @@ mod tests { #[test] fn test_iter_unaligned_remainder_bits_large() { let input: &[u8] = &[ - 0b11111111, 0b00000000, 0b11111111, 0b00000000, 0b11111111, 0b00000000, - 0b11111111, 0b00000000, 0b11111111, + 0b11111111, 0b00000000, 0b11111111, 0b00000000, 0b11111111, 0b00000000, 0b11111111, + 0b00000000, 0b11111111, ]; let buffer: Buffer = Buffer::from(input); @@ -631,11 +631,8 @@ mod tests { let max_truncate = 128.min(mask_len - offset); let truncate = rng.gen::().checked_rem(max_truncate).unwrap_or(0); - let unaligned = UnalignedBitChunk::new( - buffer.as_slice(), - offset, - mask_len - offset - truncate, - ); + let unaligned = + UnalignedBitChunk::new(buffer.as_slice(), offset, mask_len - offset - truncate); let bool_slice = &bools[offset..mask_len - truncate]; diff --git a/arrow/src/util/bit_iterator.rs b/arrow-buffer/src/util/bit_iterator.rs similarity index 53% rename from arrow/src/util/bit_iterator.rs rename to arrow-buffer/src/util/bit_iterator.rs index bba9dac60a4b..df40a8fbaccb 100644 --- a/arrow/src/util/bit_iterator.rs +++ b/arrow-buffer/src/util/bit_iterator.rs @@ -15,7 +15,73 @@ // specific language governing permissions and limitations // under the License. -use crate::util::bit_chunk_iterator::{UnalignedBitChunk, UnalignedBitChunkIterator}; +//! Types for iterating over packed bitmasks + +use crate::bit_chunk_iterator::{UnalignedBitChunk, UnalignedBitChunkIterator}; +use crate::bit_util::{ceil, get_bit_raw}; + +/// Iterator over the bits within a packed bitmask +/// +/// To efficiently iterate over just the set bits see [`BitIndexIterator`] and [`BitSliceIterator`] +pub struct BitIterator<'a> { + buffer: &'a [u8], + current_offset: usize, + end_offset: usize, +} + +impl<'a> BitIterator<'a> { + /// Create a new [`BitIterator`] from the provided `buffer`, + /// and `offset` and `len` in bits + /// + /// # Panic + /// + /// Panics if `buffer` is too short for the provided offset and length + pub fn new(buffer: &'a [u8], offset: usize, len: usize) -> Self { + let end_offset = offset.checked_add(len).unwrap(); + let required_len = ceil(end_offset, 8); + assert!( + buffer.len() >= required_len, + "BitIterator buffer too small, expected {required_len} got {}", + buffer.len() + ); + + Self { + buffer, + current_offset: offset, + end_offset, + } + } +} + +impl<'a> Iterator for BitIterator<'a> { + type Item = bool; + + fn next(&mut self) -> Option { + if self.current_offset == self.end_offset { + return None; + } + // Safety: + // offsets in bounds + let v = unsafe { get_bit_raw(self.buffer.as_ptr(), self.current_offset) }; + self.current_offset += 1; + Some(v) + } +} + +impl<'a> ExactSizeIterator for BitIterator<'a> {} + +impl<'a> DoubleEndedIterator for BitIterator<'a> { + fn next_back(&mut self) -> Option { + if self.current_offset == self.end_offset { + return None; + } + self.end_offset -= 1; + // Safety: + // offsets in bounds + let v = unsafe { get_bit_raw(self.buffer.as_ptr(), self.end_offset) }; + Some(v) + } +} /// Iterator of contiguous ranges of set bits within a provided packed bitmask /// @@ -31,7 +97,7 @@ pub struct BitSliceIterator<'a> { } impl<'a> BitSliceIterator<'a> { - /// Create a new [`BitSliceIterator`] from the provide `buffer`, + /// Create a new [`BitSliceIterator`] from the provided `buffer`, /// and `offset` and `len` in bits pub fn new(buffer: &'a [u8], offset: usize, len: usize) -> Self { let chunk = UnalignedBitChunk::new(buffer, offset, len); @@ -157,4 +223,72 @@ impl<'a> Iterator for BitIndexIterator<'a> { } } -// Note: tests located in filter module +/// Calls the provided closure for each index in the provided null mask that is set, +/// using an adaptive strategy based on the null count +/// +/// Ideally this would be encapsulated in an [`Iterator`] that would determine the optimal +/// strategy up front, and then yield indexes based on this. +/// +/// Unfortunately, external iteration based on the resulting [`Iterator`] would match the strategy +/// variant on each call to [`Iterator::next`], and LLVM generally cannot eliminate this. +/// +/// One solution to this might be internal iteration, e.g. [`Iterator::try_fold`], however, +/// it is currently [not possible] to override this for custom iterators in stable Rust. +/// +/// As such this is the next best option +/// +/// [not possible]: https://github.com/rust-lang/rust/issues/69595 +#[inline] +pub fn try_for_each_valid_idx Result<(), E>>( + len: usize, + offset: usize, + null_count: usize, + nulls: Option<&[u8]>, + f: F, +) -> Result<(), E> { + let valid_count = len - null_count; + + if valid_count == len { + (0..len).try_for_each(f) + } else if null_count != len { + BitIndexIterator::new(nulls.unwrap(), offset, len).try_for_each(f) + } else { + Ok(()) + } +} + +// Note: further tests located in arrow_select::filter module + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_bit_iterator() { + let mask = &[0b00010010, 0b00100011, 0b00000101, 0b00010001, 0b10010011]; + let actual: Vec<_> = BitIterator::new(mask, 0, 5).collect(); + assert_eq!(actual, &[false, true, false, false, true]); + + let actual: Vec<_> = BitIterator::new(mask, 4, 5).collect(); + assert_eq!(actual, &[true, false, false, false, true]); + + let actual: Vec<_> = BitIterator::new(mask, 12, 14).collect(); + assert_eq!( + actual, + &[ + false, true, false, false, true, false, true, false, false, false, false, false, + true, false + ] + ); + + assert_eq!(BitIterator::new(mask, 0, 0).count(), 0); + assert_eq!(BitIterator::new(mask, 40, 0).count(), 0); + } + + #[test] + #[should_panic(expected = "BitIterator buffer too small, expected 3 got 2")] + fn test_bit_iterator_bounds() { + let mask = &[223, 23]; + BitIterator::new(mask, 17, 0); + } +} diff --git a/arrow/src/util/bit_mask.rs b/arrow-buffer/src/util/bit_mask.rs similarity index 85% rename from arrow/src/util/bit_mask.rs rename to arrow-buffer/src/util/bit_mask.rs index da542a2bb1f9..8f81cb7d0469 100644 --- a/arrow/src/util/bit_mask.rs +++ b/arrow-buffer/src/util/bit_mask.rs @@ -17,8 +17,8 @@ //! Utils for working with packed bit masks -use crate::util::bit_chunk_iterator::BitChunks; -use crate::util::bit_util::{ceil, get_bit, set_bit}; +use crate::bit_chunk_iterator::BitChunks; +use crate::bit_util::{ceil, get_bit, set_bit}; /// Sets all bits on `write_data` in the range `[offset_write..offset_write+len]` to be equal to the /// bits in `data` in the range `[offset_read..offset_read+len]` @@ -42,8 +42,7 @@ pub fn set_bits( let chunks = BitChunks::new(data, offset_read + bits_to_align, len - bits_to_align); chunks.iter().for_each(|chunk| { null_count += chunk.count_zeros(); - write_data[write_byte_index..write_byte_index + 8] - .copy_from_slice(&chunk.to_le_bytes()); + write_data[write_byte_index..write_byte_index + 8].copy_from_slice(&chunk.to_le_bytes()); write_byte_index += 8; }); @@ -70,8 +69,8 @@ mod tests { fn test_set_bits_aligned() { let mut destination: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; let source: &[u8] = &[ - 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b10100101, + 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, + 0b10100101, ]; let destination_offset = 8; @@ -80,8 +79,8 @@ mod tests { let len = 64; let expected_data: &[u8] = &[ - 0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b10100101, 0, + 0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, + 0b10100101, 0, ]; let expected_null_count = 24; let result = set_bits( @@ -100,8 +99,8 @@ mod tests { fn test_set_bits_unaligned_destination_start() { let mut destination: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; let source: &[u8] = &[ - 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b10100101, + 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, + 0b10100101, ]; let destination_offset = 3; @@ -110,8 +109,8 @@ mod tests { let len = 64; let expected_data: &[u8] = &[ - 0b00111000, 0b00101111, 0b11001101, 0b11011100, 0b01011110, 0b00011111, - 0b00111110, 0b00101111, 0b00000101, 0b00000000, + 0b00111000, 0b00101111, 0b11001101, 0b11011100, 0b01011110, 0b00011111, 0b00111110, + 0b00101111, 0b00000101, 0b00000000, ]; let expected_null_count = 24; let result = set_bits( @@ -130,8 +129,8 @@ mod tests { fn test_set_bits_unaligned_destination_end() { let mut destination: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; let source: &[u8] = &[ - 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b10100101, + 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, + 0b10100101, ]; let destination_offset = 8; @@ -140,8 +139,8 @@ mod tests { let len = 62; let expected_data: &[u8] = &[ - 0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b00100101, 0, + 0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, + 0b00100101, 0, ]; let expected_null_count = 23; let result = set_bits( @@ -160,9 +159,9 @@ mod tests { fn test_set_bits_unaligned() { let mut destination: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; let source: &[u8] = &[ - 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, + 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, + 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, 0b10100101, + 0b10011001, 0b11011011, 0b11101011, 0b11000011, ]; let destination_offset = 3; @@ -171,9 +170,8 @@ mod tests { let len = 95; let expected_data: &[u8] = &[ - 0b01111000, 0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, - 0b01111001, 0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, - 0b00000001, + 0b01111000, 0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, 0b01111001, + 0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, 0b00000001, ]; let expected_null_count = 35; let result = set_bits( diff --git a/arrow/src/util/bit_util.rs b/arrow-buffer/src/util/bit_util.rs similarity index 82% rename from arrow/src/util/bit_util.rs rename to arrow-buffer/src/util/bit_util.rs index 5752c5df972e..d2dbf3c84882 100644 --- a/arrow/src/util/bit_util.rs +++ b/arrow-buffer/src/util/bit_util.rs @@ -17,10 +17,6 @@ //! Utils for working with bits -use num::Integer; -#[cfg(feature = "simd")] -use packed_simd::u8x64; - const BIT_MASK: [u8; 8] = [1, 2, 4, 8, 16, 32, 64, 128]; const UNSET_BIT_MASK: [u8; 8] = [ 255 - 1, @@ -102,34 +98,16 @@ pub unsafe fn unset_bit_raw(data: *mut u8, i: usize) { pub fn ceil(value: usize, divisor: usize) -> usize { // Rewrite as `value.div_ceil(&divisor)` after // https://github.com/rust-lang/rust/issues/88581 is merged. - Integer::div_ceil(&value, &divisor) + value / divisor + (0 != value % divisor) as usize } -/// Performs SIMD bitwise binary operations. -/// -/// # Safety -/// -/// Note that each slice should be 64 bytes and it is the callers responsibility to ensure -/// that this is the case. If passed slices larger than 64 bytes the operation will only -/// be performed on the first 64 bytes. Slices less than 64 bytes will panic. -#[cfg(feature = "simd")] -pub unsafe fn bitwise_bin_op_simd(left: &[u8], right: &[u8], result: &mut [u8], op: F) -where - F: Fn(u8x64, u8x64) -> u8x64, -{ - let left_simd = u8x64::from_slice_unaligned_unchecked(left); - let right_simd = u8x64::from_slice_unaligned_unchecked(right); - let simd_result = op(left_simd, right_simd); - simd_result.write_to_slice_unaligned_unchecked(result); -} - -#[cfg(all(test, feature = "test_utils"))] +#[cfg(test)] mod tests { use std::collections::HashSet; use super::*; - use crate::util::test_util::seedable_rng; - use rand::Rng; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; #[test] fn test_round_upto_multiple_of_64() { @@ -168,10 +146,14 @@ mod tests { assert!(!get_bit(&[0b01001001, 0b01010010], 15)); } + pub fn seedable_rng() -> StdRng { + StdRng::seed_from_u64(42) + } + #[test] fn test_get_bit_raw() { const NUM_BYTE: usize = 10; - let mut buf = vec![0; NUM_BYTE]; + let mut buf = [0; NUM_BYTE]; let mut expected = vec![]; let mut rng = seedable_rng(); for i in 0..8 * NUM_BYTE { @@ -279,7 +261,6 @@ mod tests { } #[test] - #[cfg(all(any(target_arch = "x86", target_arch = "x86_64")))] fn test_ceil() { assert_eq!(ceil(0, 1), 0); assert_eq!(ceil(1, 1), 1); @@ -293,28 +274,4 @@ mod tests { assert_eq!(ceil(10, 10000000000), 1); assert_eq!(ceil(10000000000, 1000000000), 10); } - - #[test] - #[cfg(feature = "simd")] - fn test_bitwise_and_simd() { - let buf1 = [0b00110011u8; 64]; - let buf2 = [0b11110000u8; 64]; - let mut buf3 = [0b00000000; 64]; - unsafe { bitwise_bin_op_simd(&buf1, &buf2, &mut buf3, |a, b| a & b) }; - for i in buf3.iter() { - assert_eq!(&0b00110000u8, i); - } - } - - #[test] - #[cfg(feature = "simd")] - fn test_bitwise_or_simd() { - let buf1 = [0b00110011u8; 64]; - let buf2 = [0b11110000u8; 64]; - let mut buf3 = [0b00000000; 64]; - unsafe { bitwise_bin_op_simd(&buf1, &buf2, &mut buf3, |a, b| a | b) }; - for i in buf3.iter() { - assert_eq!(&0b11110011u8, i); - } - } } diff --git a/arrow-buffer/src/util/mod.rs b/arrow-buffer/src/util/mod.rs new file mode 100644 index 000000000000..9023fe4a035d --- /dev/null +++ b/arrow-buffer/src/util/mod.rs @@ -0,0 +1,21 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod bit_chunk_iterator; +pub mod bit_iterator; +pub mod bit_mask; +pub mod bit_util; diff --git a/arrow-cast/Cargo.toml b/arrow-cast/Cargo.toml new file mode 100644 index 000000000000..19b857297d14 --- /dev/null +++ b/arrow-cast/Cargo.toml @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "arrow-cast" +version = { workspace = true } +description = "Cast kernel and utilities for Apache Arrow" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "arrow_cast" +path = "src/lib.rs" +bench = false + +[package.metadata.docs.rs] +features = ["prettyprint"] + +[features] +prettyprint = ["comfy-table"] + +[dependencies] +arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +arrow-data = { workspace = true } +arrow-schema = { workspace = true } +arrow-select = { workspace = true } +chrono = { workspace = true } +half = { version = "2.1", default-features = false } +num = { version = "0.4", default-features = false, features = ["std"] } +lexical-core = { version = "^0.8", default-features = false, features = ["write-integers", "write-floats", "parse-integers", "parse-floats"] } +comfy-table = { version = "7.0", optional = true, default-features = false } +base64 = "0.21" + +[dev-dependencies] +criterion = { version = "0.5", default-features = false } +half = { version = "2.1", default-features = false } +rand = "0.8" + +[build-dependencies] + +[[bench]] +name = "parse_timestamp" +harness = false + +[[bench]] +name = "parse_time" +harness = false + +[[bench]] +name = "parse_date" +harness = false + +[[bench]] +name = "parse_decimal" +harness = false diff --git a/arrow-cast/benches/parse_date.rs b/arrow-cast/benches/parse_date.rs new file mode 100644 index 000000000000..e05d38d2f853 --- /dev/null +++ b/arrow-cast/benches/parse_date.rs @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::types::Date32Type; +use arrow_cast::parse::Parser; +use criterion::*; + +fn criterion_benchmark(c: &mut Criterion) { + let timestamps = ["2020-09-08", "2020-9-8", "2020-09-8", "2020-9-08"]; + + for timestamp in timestamps { + let t = black_box(timestamp); + c.bench_function(t, |b| { + b.iter(|| Date32Type::parse(t).unwrap()); + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-cast/benches/parse_decimal.rs b/arrow-cast/benches/parse_decimal.rs new file mode 100644 index 000000000000..5682859dd25a --- /dev/null +++ b/arrow-cast/benches/parse_decimal.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::types::Decimal256Type; +use arrow_cast::parse::parse_decimal; +use criterion::*; + +fn criterion_benchmark(c: &mut Criterion) { + let decimals = [ + "123.123", + "123.1234", + "123.1", + "123", + "-123.123", + "-123.1234", + "-123.1", + "-123", + "0.0000123", + "12.", + "-12.", + "00.1", + "-00.1", + "12345678912345678.1234", + "-12345678912345678.1234", + "99999999999999999.999", + "-99999999999999999.999", + ".123", + "-.123", + "123.", + "-123.", + ]; + + for decimal in decimals { + let d = black_box(decimal); + c.bench_function(d, |b| { + b.iter(|| parse_decimal::(d, 20, 3).unwrap()); + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-cast/benches/parse_time.rs b/arrow-cast/benches/parse_time.rs new file mode 100644 index 000000000000..d28b9c7c613d --- /dev/null +++ b/arrow-cast/benches/parse_time.rs @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_cast::parse::string_to_time_nanoseconds; +use criterion::*; + +fn criterion_benchmark(c: &mut Criterion) { + let timestamps = [ + "9:50", + "09:50", + "09:50 PM", + "9:50:12 AM", + "09:50:12 PM", + "09:50:12.123456789", + "9:50:12.123456789", + "09:50:12.123456789 PM", + ]; + + for timestamp in timestamps { + let t = black_box(timestamp); + c.bench_function(t, |b| { + b.iter(|| string_to_time_nanoseconds(t).unwrap()); + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow/src/util/serialization.rs b/arrow-cast/benches/parse_timestamp.rs similarity index 51% rename from arrow/src/util/serialization.rs rename to arrow-cast/benches/parse_timestamp.rs index 14d67ca117c4..d3ab41863e70 100644 --- a/arrow/src/util/serialization.rs +++ b/arrow-cast/benches/parse_timestamp.rs @@ -15,19 +15,30 @@ // specific language governing permissions and limitations // under the License. -/// Converts numeric type to a `String` -pub fn lexical_to_string(n: N) -> String { - let mut buf = Vec::::with_capacity(N::FORMATTED_SIZE_DECIMAL); - unsafe { - // JUSTIFICATION - // Benefit - // Allows using the faster serializer lexical core and convert to string - // Soundness - // Length of buf is set as written length afterwards. lexical_core - // creates a valid string, so doesn't need to be checked. - let slice = std::slice::from_raw_parts_mut(buf.as_mut_ptr(), buf.capacity()); - let len = lexical_core::write(n, slice).len(); - buf.set_len(len); - String::from_utf8_unchecked(buf) +use arrow_cast::parse::string_to_timestamp_nanos; +use criterion::*; + +fn criterion_benchmark(c: &mut Criterion) { + let timestamps = [ + "2020-09-08", + "2020-09-08T13:42:29", + "2020-09-08T13:42:29.190", + "2020-09-08T13:42:29.190855", + "2020-09-08T13:42:29.190855999", + "2020-09-08T13:42:29+00:00", + "2020-09-08T13:42:29.190+00:00", + "2020-09-08T13:42:29.190855+00:00", + "2020-09-08T13:42:29.190855999-05:00", + "2020-09-08T13:42:29.190855Z", + ]; + + for timestamp in timestamps { + let t = black_box(timestamp); + c.bench_function(t, |b| { + b.iter(|| string_to_timestamp_nanos(t).unwrap()); + }); } } + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-cast/src/base64.rs b/arrow-cast/src/base64.rs new file mode 100644 index 000000000000..e109c8112480 --- /dev/null +++ b/arrow-cast/src/base64.rs @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Functions for Base64 encoding/decoding + +use arrow_array::{Array, GenericBinaryArray, GenericStringArray, OffsetSizeTrait}; +use arrow_buffer::OffsetBuffer; +use arrow_schema::ArrowError; +use base64::encoded_len; +use base64::engine::Config; + +pub use base64::prelude::*; + +/// Bas64 encode each element of `array` with the provided `engine` +pub fn b64_encode( + engine: &E, + array: &GenericBinaryArray, +) -> GenericStringArray { + let lengths = array.offsets().windows(2).map(|w| { + let len = w[1].as_usize() - w[0].as_usize(); + encoded_len(len, engine.config().encode_padding()).unwrap() + }); + let offsets = OffsetBuffer::::from_lengths(lengths); + let buffer_len = offsets.last().unwrap().as_usize(); + let mut buffer = vec![0_u8; buffer_len]; + let mut offset = 0; + + for i in 0..array.len() { + let len = engine + .encode_slice(array.value(i), &mut buffer[offset..]) + .unwrap(); + offset += len; + } + assert_eq!(offset, buffer_len); + + // Safety: Base64 is valid UTF-8 + unsafe { GenericStringArray::new_unchecked(offsets, buffer.into(), array.nulls().cloned()) } +} + +/// Base64 decode each element of `array` with the provided `engine` +pub fn b64_decode( + engine: &E, + array: &GenericBinaryArray, +) -> Result, ArrowError> { + let estimated_len = array.values().len(); // This is an overestimate + let mut buffer = vec![0; estimated_len]; + + let mut offsets = Vec::with_capacity(array.len() + 1); + offsets.push(O::usize_as(0)); + let mut offset = 0; + + for v in array.iter() { + if let Some(v) = v { + let len = engine.decode_slice(v, &mut buffer[offset..]).unwrap(); + // This cannot overflow as `len` is less than `v.len()` and `a` is valid + offset += len; + } + offsets.push(O::usize_as(offset)); + } + + // Safety: offsets monotonically increasing by construction + let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) }; + + Ok(GenericBinaryArray::new( + offsets, + buffer.into(), + array.nulls().cloned(), + )) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::BinaryArray; + use base64::prelude::{BASE64_STANDARD, BASE64_STANDARD_NO_PAD}; + use rand::{thread_rng, Rng}; + + fn test_engine(e: &E, a: &BinaryArray) { + let encoded = b64_encode(e, a); + encoded.to_data().validate_full().unwrap(); + + let to_decode = encoded.into(); + let decoded = b64_decode(e, &to_decode).unwrap(); + decoded.to_data().validate_full().unwrap(); + + assert_eq!(&decoded, a); + } + + #[test] + fn test_b64() { + let mut rng = thread_rng(); + let len = rng.gen_range(1024..1050); + let data: BinaryArray = (0..len) + .map(|_| { + let len = rng.gen_range(0..16); + Some((0..len).map(|_| rng.gen()).collect::>()) + }) + .collect(); + + test_engine(&BASE64_STANDARD, &data); + test_engine(&BASE64_STANDARD_NO_PAD, &data); + } +} diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs new file mode 100644 index 000000000000..a75354cf9b35 --- /dev/null +++ b/arrow-cast/src/cast.rs @@ -0,0 +1,9450 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines cast kernels for `ArrayRef`, to convert `Array`s between +//! supported datatypes. +//! +//! Example: +//! +//! ``` +//! use arrow_array::*; +//! use arrow_cast::cast; +//! use arrow_schema::DataType; +//! use std::sync::Arc; +//! use arrow_array::types::Float64Type; +//! use arrow_array::cast::AsArray; +//! +//! let a = Int32Array::from(vec![5, 6, 7]); +//! let array = Arc::new(a) as ArrayRef; +//! let b = cast(&array, &DataType::Float64).unwrap(); +//! let c = b.as_primitive::(); +//! assert_eq!(5.0, c.value(0)); +//! assert_eq!(6.0, c.value(1)); +//! assert_eq!(7.0, c.value(2)); +//! ``` + +use chrono::{NaiveTime, Offset, TimeZone, Utc}; +use std::cmp::Ordering; +use std::sync::Arc; + +use crate::display::{ArrayFormatter, FormatOptions}; +use crate::parse::{ + parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month, + string_to_datetime, Parser, +}; +use arrow_array::{builder::*, cast::*, temporal_conversions::*, timezone::Tz, types::*, *}; +use arrow_buffer::{i256, ArrowNativeType, OffsetBuffer}; +use arrow_data::transform::MutableArrayData; +use arrow_data::ArrayData; +use arrow_schema::*; +use arrow_select::take::take; +use num::cast::AsPrimitive; +use num::{NumCast, ToPrimitive}; + +/// CastOptions provides a way to override the default cast behaviors +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct CastOptions<'a> { + /// how to handle cast failures, either return NULL (safe=true) or return ERR (safe=false) + pub safe: bool, + /// Formatting options when casting from temporal types to string + pub format_options: FormatOptions<'a>, +} + +impl<'a> Default for CastOptions<'a> { + fn default() -> Self { + Self { + safe: true, + format_options: FormatOptions::default(), + } + } +} + +/// Return true if a value of type `from_type` can be cast into a value of `to_type`. +/// +/// See [`cast_with_options`] for more information +pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { + use self::DataType::*; + use self::IntervalUnit::*; + use self::TimeUnit::*; + if from_type == to_type { + return true; + } + + match (from_type, to_type) { + ( + Null, + Boolean + | Int8 + | UInt8 + | Int16 + | UInt16 + | Int32 + | UInt32 + | Float32 + | Date32 + | Time32(_) + | Int64 + | UInt64 + | Float64 + | Date64 + | Timestamp(_, _) + | Time64(_) + | Duration(_) + | Interval(_) + | FixedSizeBinary(_) + | Binary + | Utf8 + | LargeBinary + | LargeUtf8 + | List(_) + | LargeList(_) + | FixedSizeList(_, _) + | Struct(_) + | Map(_, _) + | Dictionary(_, _), + ) => true, + // Dictionary/List conditions should be put in front of others + (Dictionary(_, from_value_type), Dictionary(_, to_value_type)) => { + can_cast_types(from_value_type, to_value_type) + } + (Dictionary(_, value_type), _) => can_cast_types(value_type, to_type), + (_, Dictionary(_, value_type)) => can_cast_types(from_type, value_type), + (List(list_from) | LargeList(list_from), List(list_to) | LargeList(list_to)) => { + can_cast_types(list_from.data_type(), list_to.data_type()) + } + (List(list_from) | LargeList(list_from), Utf8 | LargeUtf8) => { + can_cast_types(list_from.data_type(), to_type) + } + (List(list_from) | LargeList(list_from), FixedSizeList(list_to, _)) => { + can_cast_types(list_from.data_type(), list_to.data_type()) + } + (List(_), _) => false, + (FixedSizeList(list_from,_), List(list_to)) => { + list_from.data_type() == list_to.data_type() + } + (FixedSizeList(list_from,_), LargeList(list_to)) => { + list_from.data_type() == list_to.data_type() + } + (_, List(list_to)) => can_cast_types(from_type, list_to.data_type()), + (_, LargeList(list_to)) => can_cast_types(from_type, list_to.data_type()), + // cast one decimal type to another decimal type + (Decimal128(_, _), Decimal128(_, _)) => true, + (Decimal256(_, _), Decimal256(_, _)) => true, + (Decimal128(_, _), Decimal256(_, _)) => true, + (Decimal256(_, _), Decimal128(_, _)) => true, + // unsigned integer to decimal + (UInt8 | UInt16 | UInt32 | UInt64, Decimal128(_, _)) | + (UInt8 | UInt16 | UInt32 | UInt64, Decimal256(_, _)) | + // signed numeric to decimal + (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal128(_, _)) | + (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal256(_, _)) | + // decimal to unsigned numeric + (Decimal128(_, _) | Decimal256(_, _), UInt8 | UInt16 | UInt32 | UInt64) | + // decimal to signed numeric + (Decimal128(_, _) | Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) => true, + // decimal to Utf8 + (Decimal128(_, _) | Decimal256(_, _), Utf8 | LargeUtf8) => true, + // Utf8 to decimal + (Utf8 | LargeUtf8, Decimal128(_, _) | Decimal256(_, _)) => true, + (Struct(_), _) => false, + (_, Struct(_)) => false, + (_, Boolean) => { + DataType::is_integer(from_type) || + DataType::is_floating(from_type) + || from_type == &Utf8 + || from_type == &LargeUtf8 + } + (Boolean, _) => { + DataType::is_integer(to_type) || DataType::is_floating(to_type) || to_type == &Utf8 || to_type == &LargeUtf8 + } + + (Binary, LargeBinary | Utf8 | LargeUtf8 | FixedSizeBinary(_)) => true, + (LargeBinary, Binary | Utf8 | LargeUtf8 | FixedSizeBinary(_)) => true, + (FixedSizeBinary(_), Binary | LargeBinary) => true, + ( + Utf8 | LargeUtf8, + Binary + | LargeBinary + | Utf8 + | LargeUtf8 + | Date32 + | Date64 + | Time32(Second) + | Time32(Millisecond) + | Time64(Microsecond) + | Time64(Nanosecond) + | Timestamp(Second, _) + | Timestamp(Millisecond, _) + | Timestamp(Microsecond, _) + | Timestamp(Nanosecond, _) + | Interval(_), + ) => true, + (Utf8 | LargeUtf8, _) => to_type.is_numeric() && to_type != &Float16, + (_, Utf8 | LargeUtf8) => from_type.is_primitive(), + + (_, Binary | LargeBinary) => from_type.is_integer(), + + // start numeric casts + ( + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float16 | Float32 | Float64, + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float16 | Float32 | Float64, + ) => true, + // end numeric casts + + // temporal casts + (Int32, Date32 | Date64 | Time32(_)) => true, + (Date32, Int32 | Int64) => true, + (Time32(_), Int32) => true, + (Int64, Date64 | Date32 | Time64(_)) => true, + (Date64, Int64 | Int32) => true, + (Time64(_), Int64) => true, + (Date32 | Date64, Date32 | Date64) => true, + // time casts + (Time32(_), Time32(_)) => true, + (Time32(_), Time64(_)) => true, + (Time64(_), Time64(_)) => true, + (Time64(_), Time32(to_unit)) => { + matches!(to_unit, Second | Millisecond) + } + (Timestamp(_, _), _) if to_type.is_numeric() => true, + (_, Timestamp(_, _)) if from_type.is_numeric() => true, + (Date64, Timestamp(_, None)) => true, + (Date32, Timestamp(_, None)) => true, + ( + Timestamp(_, _), + Timestamp(_, _) + | Date32 + | Date64 + | Time32(Second) + | Time32(Millisecond) + | Time64(Microsecond) + | Time64(Nanosecond), + ) => true, + (Int64, Duration(_)) => true, + (Duration(_), Int64) => true, + (Interval(from_type), Int64) => { + match from_type { + YearMonth => true, + DayTime => true, + MonthDayNano => false, // Native type is i128 + } + } + (Int32, Interval(to_type)) => match to_type { + YearMonth => true, + DayTime => false, + MonthDayNano => false, + }, + (Int64, Interval(to_type)) => match to_type { + YearMonth => false, + DayTime => true, + MonthDayNano => false, + }, + (Duration(_), Interval(MonthDayNano)) => true, + (Interval(MonthDayNano), Duration(_)) => true, + (Interval(YearMonth), Interval(MonthDayNano)) => true, + (Interval(DayTime), Interval(MonthDayNano)) => true, + (_, _) => false, + } +} + +/// Cast `array` to the provided data type and return a new Array with type `to_type`, if possible. +/// +/// See [`cast_with_options`] for more information +pub fn cast(array: &dyn Array, to_type: &DataType) -> Result { + cast_with_options(array, to_type, &CastOptions::default()) +} + +fn cast_integer_to_decimal< + T: ArrowPrimitiveType, + D: DecimalType + ArrowPrimitiveType, + M, +>( + array: &PrimitiveArray, + precision: u8, + scale: i8, + base: M, + cast_options: &CastOptions, +) -> Result +where + ::Native: AsPrimitive, + M: ArrowNativeTypeOp, +{ + let scale_factor = base.pow_checked(scale.unsigned_abs() as u32).map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast to {:?}({}, {}). The scale causes overflow.", + D::PREFIX, + precision, + scale, + )) + })?; + + let array = if scale < 0 { + match cast_options.safe { + true => array.unary_opt::<_, D>(|v| { + v.as_().div_checked(scale_factor).ok().and_then(|v| { + (D::validate_decimal_precision(v, precision).is_ok()).then_some(v) + }) + }), + false => array.try_unary::<_, D, _>(|v| { + v.as_() + .div_checked(scale_factor) + .and_then(|v| D::validate_decimal_precision(v, precision).map(|_| v)) + })?, + } + } else { + match cast_options.safe { + true => array.unary_opt::<_, D>(|v| { + v.as_().mul_checked(scale_factor).ok().and_then(|v| { + (D::validate_decimal_precision(v, precision).is_ok()).then_some(v) + }) + }), + false => array.try_unary::<_, D, _>(|v| { + v.as_() + .mul_checked(scale_factor) + .and_then(|v| D::validate_decimal_precision(v, precision).map(|_| v)) + })?, + } + }; + + Ok(Arc::new(array.with_precision_and_scale(precision, scale)?)) +} + +fn cast_floating_point_to_decimal128( + array: &PrimitiveArray, + precision: u8, + scale: i8, + cast_options: &CastOptions, +) -> Result +where + ::Native: AsPrimitive, +{ + let mul = 10_f64.powi(scale as i32); + + if cast_options.safe { + array + .unary_opt::<_, Decimal128Type>(|v| { + (mul * v.as_()) + .round() + .to_i128() + .filter(|v| Decimal128Type::validate_decimal_precision(*v, precision).is_ok()) + }) + .with_precision_and_scale(precision, scale) + .map(|a| Arc::new(a) as ArrayRef) + } else { + array + .try_unary::<_, Decimal128Type, _>(|v| { + (mul * v.as_()) + .round() + .to_i128() + .ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast to {}({}, {}). Overflowing on {:?}", + Decimal128Type::PREFIX, + precision, + scale, + v + )) + }) + .and_then(|v| { + Decimal128Type::validate_decimal_precision(v, precision).map(|_| v) + }) + })? + .with_precision_and_scale(precision, scale) + .map(|a| Arc::new(a) as ArrayRef) + } +} + +fn cast_floating_point_to_decimal256( + array: &PrimitiveArray, + precision: u8, + scale: i8, + cast_options: &CastOptions, +) -> Result +where + ::Native: AsPrimitive, +{ + let mul = 10_f64.powi(scale as i32); + + if cast_options.safe { + array + .unary_opt::<_, Decimal256Type>(|v| { + i256::from_f64((v.as_() * mul).round()) + .filter(|v| Decimal256Type::validate_decimal_precision(*v, precision).is_ok()) + }) + .with_precision_and_scale(precision, scale) + .map(|a| Arc::new(a) as ArrayRef) + } else { + array + .try_unary::<_, Decimal256Type, _>(|v| { + i256::from_f64((v.as_() * mul).round()) + .ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast to {}({}, {}). Overflowing on {:?}", + Decimal256Type::PREFIX, + precision, + scale, + v + )) + }) + .and_then(|v| { + Decimal256Type::validate_decimal_precision(v, precision).map(|_| v) + }) + })? + .with_precision_and_scale(precision, scale) + .map(|a| Arc::new(a) as ArrayRef) + } +} + +/// Cast the array from interval year month to month day nano +fn cast_interval_year_month_to_interval_month_day_nano( + array: &dyn Array, + _cast_options: &CastOptions, +) -> Result { + let array = array.as_primitive::(); + + Ok(Arc::new(array.unary::<_, IntervalMonthDayNanoType>(|v| { + let months = IntervalYearMonthType::to_months(v); + IntervalMonthDayNanoType::make_value(months, 0, 0) + }))) +} + +/// Cast the array from interval day time to month day nano +fn cast_interval_day_time_to_interval_month_day_nano( + array: &dyn Array, + _cast_options: &CastOptions, +) -> Result { + let array = array.as_primitive::(); + let mul = 1_000_000; + + Ok(Arc::new(array.unary::<_, IntervalMonthDayNanoType>(|v| { + let (days, ms) = IntervalDayTimeType::to_parts(v); + IntervalMonthDayNanoType::make_value(0, days, ms as i64 * mul) + }))) +} + +/// Cast the array from interval to duration +fn cast_month_day_nano_to_duration>( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + let array = array.as_primitive::(); + let scale = match D::DATA_TYPE { + DataType::Duration(TimeUnit::Second) => 1_000_000_000, + DataType::Duration(TimeUnit::Millisecond) => 1_000_000, + DataType::Duration(TimeUnit::Microsecond) => 1_000, + DataType::Duration(TimeUnit::Nanosecond) => 1, + _ => unreachable!(), + }; + + if cast_options.safe { + let iter = array + .iter() + .map(|v| v.and_then(|v| (v >> 64 == 0).then_some((v as i64) / scale))); + Ok(Arc::new(unsafe { + PrimitiveArray::::from_trusted_len_iter(iter) + })) + } else { + let vec = array + .iter() + .map(|v| { + v.map(|v| match v >> 64 { + 0 => Ok((v as i64) / scale), + _ => Err(ArrowError::ComputeError( + "Cannot convert interval containing non-zero months or days to duration" + .to_string(), + )), + }) + .transpose() + }) + .collect::, _>>()?; + Ok(Arc::new(unsafe { + PrimitiveArray::::from_trusted_len_iter(vec.iter()) + })) + } +} + +/// Cast the array from duration and interval +fn cast_duration_to_interval>( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + let array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::ComputeError( + "Internal Error: Cannot cast duration to DurationArray of expected type" + .to_string(), + ) + })?; + + let scale = match array.data_type() { + DataType::Duration(TimeUnit::Second) => 1_000_000_000, + DataType::Duration(TimeUnit::Millisecond) => 1_000_000, + DataType::Duration(TimeUnit::Microsecond) => 1_000, + DataType::Duration(TimeUnit::Nanosecond) => 1, + _ => unreachable!(), + }; + + if cast_options.safe { + let iter = array + .iter() + .map(|v| v.and_then(|v| v.checked_mul(scale).map(|v| v as i128))); + Ok(Arc::new(unsafe { + PrimitiveArray::::from_trusted_len_iter(iter) + })) + } else { + let vec = array + .iter() + .map(|v| { + v.map(|v| { + if let Ok(v) = v.mul_checked(scale) { + Ok(v as i128) + } else { + Err(ArrowError::ComputeError(format!( + "Cannot cast to {:?}. Overflowing on {:?}", + IntervalMonthDayNanoType::DATA_TYPE, + v + ))) + } + }) + .transpose() + }) + .collect::, _>>()?; + Ok(Arc::new(unsafe { + PrimitiveArray::::from_trusted_len_iter(vec.iter()) + })) + } +} + +/// Cast the primitive array using [`PrimitiveArray::reinterpret_cast`] +fn cast_reinterpret_arrays>( + array: &dyn Array, +) -> Result { + Ok(Arc::new(array.as_primitive::().reinterpret_cast::())) +} + +fn cast_decimal_to_integer( + array: &dyn Array, + base: D::Native, + scale: i8, + cast_options: &CastOptions, +) -> Result +where + T: ArrowPrimitiveType, + ::Native: NumCast, + D: DecimalType + ArrowPrimitiveType, + ::Native: ArrowNativeTypeOp + ToPrimitive, +{ + let array = array.as_primitive::(); + + let div: D::Native = base.pow_checked(scale as u32).map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast to {:?}. The scale {} causes overflow.", + D::PREFIX, + scale, + )) + })?; + + let mut value_builder = PrimitiveBuilder::::with_capacity(array.len()); + + if cast_options.safe { + for i in 0..array.len() { + if array.is_null(i) { + value_builder.append_null(); + } else { + let v = array + .value(i) + .div_checked(div) + .ok() + .and_then(::from::); + + value_builder.append_option(v); + } + } + } else { + for i in 0..array.len() { + if array.is_null(i) { + value_builder.append_null(); + } else { + let v = array.value(i).div_checked(div)?; + + let value = ::from::(v).ok_or_else(|| { + ArrowError::CastError(format!( + "value of {:?} is out of range {}", + v, + T::DATA_TYPE + )) + })?; + + value_builder.append_value(value); + } + } + } + Ok(Arc::new(value_builder.finish())) +} + +// cast the decimal array to floating-point array +fn cast_decimal_to_float( + array: &dyn Array, + op: F, +) -> Result +where + F: Fn(D::Native) -> T::Native, +{ + let array = array.as_primitive::(); + let array = array.unary::<_, T>(op); + Ok(Arc::new(array)) +} + +fn make_timestamp_array( + array: &PrimitiveArray, + unit: TimeUnit, + tz: Option>, +) -> ArrayRef { + match unit { + TimeUnit::Second => Arc::new( + array + .reinterpret_cast::() + .with_timezone_opt(tz), + ), + TimeUnit::Millisecond => Arc::new( + array + .reinterpret_cast::() + .with_timezone_opt(tz), + ), + TimeUnit::Microsecond => Arc::new( + array + .reinterpret_cast::() + .with_timezone_opt(tz), + ), + TimeUnit::Nanosecond => Arc::new( + array + .reinterpret_cast::() + .with_timezone_opt(tz), + ), + } +} + +fn as_time_res_with_timezone( + v: i64, + tz: Option, +) -> Result { + let time = match tz { + Some(tz) => as_datetime_with_timezone::(v, tz).map(|d| d.time()), + None => as_datetime::(v).map(|d| d.time()), + }; + + time.ok_or_else(|| { + ArrowError::CastError(format!( + "Failed to create naive time with {} {}", + std::any::type_name::(), + v + )) + }) +} + +/// Cast `array` to the provided data type and return a new Array with type `to_type`, if possible. +/// +/// Accepts [`CastOptions`] to specify cast behavior. +/// +/// ## Behavior +/// * Boolean to Utf8: `true` => '1', `false` => `0` +/// * Utf8 to boolean: `true`, `yes`, `on`, `1` => `true`, `false`, `no`, `off`, `0` => `false`, +/// short variants are accepted, other strings return null or error +/// * Utf8 to numeric: strings that can't be parsed to numbers return null, float strings +/// in integer casts return null +/// * Numeric to boolean: 0 returns `false`, any other value returns `true` +/// * List to List: the underlying data type is cast +/// * List to FixedSizeList: the underlying data type is cast. If safe is true and a list element +/// has the wrong length it will be replaced with NULL, otherwise an error will be returned +/// * Primitive to List: a list array with 1 value per slot is created +/// * Date32 and Date64: precision lost when going to higher interval +/// * Time32 and Time64: precision lost when going to higher interval +/// * Timestamp and Date{32|64}: precision lost when going to higher interval +/// * Temporal to/from backing primitive: zero-copy with data type change +/// * Casting from `float32/float64` to `Decimal(precision, scale)` rounds to the `scale` decimals +/// (i.e. casting `6.4999` to Decimal(10, 1) becomes `6.5`). Prior to version `26.0.0`, +/// casting would truncate instead (i.e. outputs `6.4` instead) +/// +/// Unsupported Casts +/// * To or from `StructArray` +/// * List to primitive +/// * Interval and duration +pub fn cast_with_options( + array: &dyn Array, + to_type: &DataType, + cast_options: &CastOptions, +) -> Result { + use DataType::*; + let from_type = array.data_type(); + // clone array if types are the same + if from_type == to_type { + return Ok(make_array(array.to_data())); + } + match (from_type, to_type) { + ( + Null, + Boolean + | Int8 + | UInt8 + | Int16 + | UInt16 + | Int32 + | UInt32 + | Float32 + | Date32 + | Time32(_) + | Int64 + | UInt64 + | Float64 + | Date64 + | Timestamp(_, _) + | Time64(_) + | Duration(_) + | Interval(_) + | FixedSizeBinary(_) + | Binary + | Utf8 + | LargeBinary + | LargeUtf8 + | List(_) + | LargeList(_) + | FixedSizeList(_, _) + | Struct(_) + | Map(_, _) + | Dictionary(_, _), + ) => Ok(new_null_array(to_type, array.len())), + (Dictionary(index_type, _), _) => match **index_type { + Int8 => dictionary_cast::(array, to_type, cast_options), + Int16 => dictionary_cast::(array, to_type, cast_options), + Int32 => dictionary_cast::(array, to_type, cast_options), + Int64 => dictionary_cast::(array, to_type, cast_options), + UInt8 => dictionary_cast::(array, to_type, cast_options), + UInt16 => dictionary_cast::(array, to_type, cast_options), + UInt32 => dictionary_cast::(array, to_type, cast_options), + UInt64 => dictionary_cast::(array, to_type, cast_options), + _ => Err(ArrowError::CastError(format!( + "Casting from dictionary type {from_type:?} to {to_type:?} not supported", + ))), + }, + (_, Dictionary(index_type, value_type)) => match **index_type { + Int8 => cast_to_dictionary::(array, value_type, cast_options), + Int16 => cast_to_dictionary::(array, value_type, cast_options), + Int32 => cast_to_dictionary::(array, value_type, cast_options), + Int64 => cast_to_dictionary::(array, value_type, cast_options), + UInt8 => cast_to_dictionary::(array, value_type, cast_options), + UInt16 => cast_to_dictionary::(array, value_type, cast_options), + UInt32 => cast_to_dictionary::(array, value_type, cast_options), + UInt64 => cast_to_dictionary::(array, value_type, cast_options), + _ => Err(ArrowError::CastError(format!( + "Casting from type {from_type:?} to dictionary type {to_type:?} not supported", + ))), + }, + (List(_), List(to)) => cast_list_values::(array, to, cast_options), + (LargeList(_), LargeList(to)) => cast_list_values::(array, to, cast_options), + (List(_), LargeList(list_to)) => cast_list::(array, list_to, cast_options), + (LargeList(_), List(list_to)) => cast_list::(array, list_to, cast_options), + (List(_), FixedSizeList(field, size)) => { + let array = array.as_list::(); + cast_list_to_fixed_size_list::(array, field, *size, cast_options) + } + (LargeList(_), FixedSizeList(field, size)) => { + let array = array.as_list::(); + cast_list_to_fixed_size_list::(array, field, *size, cast_options) + } + (List(_) | LargeList(_), _) => match to_type { + Utf8 => value_to_string::(array, cast_options), + LargeUtf8 => value_to_string::(array, cast_options), + _ => Err(ArrowError::CastError( + "Cannot cast list to non-list data types".to_string(), + )), + }, + (FixedSizeList(list_from, _), List(list_to)) => { + if list_to.data_type() != list_from.data_type() { + Err(ArrowError::CastError( + "cannot cast fixed-size-list to list with different child data".into(), + )) + } else { + cast_fixed_size_list_to_list::(array) + } + } + (FixedSizeList(list_from, _), LargeList(list_to)) => { + if list_to.data_type() != list_from.data_type() { + Err(ArrowError::CastError( + "cannot cast fixed-size-list to largelist with different child data".into(), + )) + } else { + cast_fixed_size_list_to_list::(array) + } + } + (_, List(ref to)) => cast_values_to_list::(array, to, cast_options), + (_, LargeList(ref to)) => cast_values_to_list::(array, to, cast_options), + (Decimal128(_, s1), Decimal128(p2, s2)) => { + cast_decimal_to_decimal_same_type::( + array.as_primitive(), + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal256(_, s1), Decimal256(p2, s2)) => { + cast_decimal_to_decimal_same_type::( + array.as_primitive(), + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal128(_, s1), Decimal256(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal256(_, s1), Decimal128(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal128(_, scale), _) if !to_type.is_temporal() => { + // cast decimal to other type + match to_type { + UInt8 => cast_decimal_to_integer::( + array, + 10_i128, + *scale, + cast_options, + ), + UInt16 => cast_decimal_to_integer::( + array, + 10_i128, + *scale, + cast_options, + ), + UInt32 => cast_decimal_to_integer::( + array, + 10_i128, + *scale, + cast_options, + ), + UInt64 => cast_decimal_to_integer::( + array, + 10_i128, + *scale, + cast_options, + ), + Int8 => cast_decimal_to_integer::( + array, + 10_i128, + *scale, + cast_options, + ), + Int16 => cast_decimal_to_integer::( + array, + 10_i128, + *scale, + cast_options, + ), + Int32 => cast_decimal_to_integer::( + array, + 10_i128, + *scale, + cast_options, + ), + Int64 => cast_decimal_to_integer::( + array, + 10_i128, + *scale, + cast_options, + ), + Float32 => cast_decimal_to_float::(array, |x| { + (x as f64 / 10_f64.powi(*scale as i32)) as f32 + }), + Float64 => cast_decimal_to_float::(array, |x| { + x as f64 / 10_f64.powi(*scale as i32) + }), + Utf8 => value_to_string::(array, cast_options), + LargeUtf8 => value_to_string::(array, cast_options), + Null => Ok(new_null_array(to_type, array.len())), + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported" + ))), + } + } + (Decimal256(_, scale), _) if !to_type.is_temporal() => { + // cast decimal to other type + match to_type { + UInt8 => cast_decimal_to_integer::( + array, + i256::from_i128(10_i128), + *scale, + cast_options, + ), + UInt16 => cast_decimal_to_integer::( + array, + i256::from_i128(10_i128), + *scale, + cast_options, + ), + UInt32 => cast_decimal_to_integer::( + array, + i256::from_i128(10_i128), + *scale, + cast_options, + ), + UInt64 => cast_decimal_to_integer::( + array, + i256::from_i128(10_i128), + *scale, + cast_options, + ), + Int8 => cast_decimal_to_integer::( + array, + i256::from_i128(10_i128), + *scale, + cast_options, + ), + Int16 => cast_decimal_to_integer::( + array, + i256::from_i128(10_i128), + *scale, + cast_options, + ), + Int32 => cast_decimal_to_integer::( + array, + i256::from_i128(10_i128), + *scale, + cast_options, + ), + Int64 => cast_decimal_to_integer::( + array, + i256::from_i128(10_i128), + *scale, + cast_options, + ), + Float32 => cast_decimal_to_float::(array, |x| { + (x.to_f64().unwrap() / 10_f64.powi(*scale as i32)) as f32 + }), + Float64 => cast_decimal_to_float::(array, |x| { + x.to_f64().unwrap() / 10_f64.powi(*scale as i32) + }), + Utf8 => value_to_string::(array, cast_options), + LargeUtf8 => value_to_string::(array, cast_options), + Null => Ok(new_null_array(to_type, array.len())), + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported" + ))), + } + } + (_, Decimal128(precision, scale)) if !from_type.is_temporal() => { + // cast data to decimal + match from_type { + UInt8 => cast_integer_to_decimal::<_, Decimal128Type, _>( + array.as_primitive::(), + *precision, + *scale, + 10_i128, + cast_options, + ), + UInt16 => cast_integer_to_decimal::<_, Decimal128Type, _>( + array.as_primitive::(), + *precision, + *scale, + 10_i128, + cast_options, + ), + UInt32 => cast_integer_to_decimal::<_, Decimal128Type, _>( + array.as_primitive::(), + *precision, + *scale, + 10_i128, + cast_options, + ), + UInt64 => cast_integer_to_decimal::<_, Decimal128Type, _>( + array.as_primitive::(), + *precision, + *scale, + 10_i128, + cast_options, + ), + Int8 => cast_integer_to_decimal::<_, Decimal128Type, _>( + array.as_primitive::(), + *precision, + *scale, + 10_i128, + cast_options, + ), + Int16 => cast_integer_to_decimal::<_, Decimal128Type, _>( + array.as_primitive::(), + *precision, + *scale, + 10_i128, + cast_options, + ), + Int32 => cast_integer_to_decimal::<_, Decimal128Type, _>( + array.as_primitive::(), + *precision, + *scale, + 10_i128, + cast_options, + ), + Int64 => cast_integer_to_decimal::<_, Decimal128Type, _>( + array.as_primitive::(), + *precision, + *scale, + 10_i128, + cast_options, + ), + Float32 => cast_floating_point_to_decimal128( + array.as_primitive::(), + *precision, + *scale, + cast_options, + ), + Float64 => cast_floating_point_to_decimal128( + array.as_primitive::(), + *precision, + *scale, + cast_options, + ), + Utf8 => cast_string_to_decimal::( + array, + *precision, + *scale, + cast_options, + ), + LargeUtf8 => cast_string_to_decimal::( + array, + *precision, + *scale, + cast_options, + ), + Null => Ok(new_null_array(to_type, array.len())), + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported" + ))), + } + } + (_, Decimal256(precision, scale)) if !from_type.is_temporal() => { + // cast data to decimal + match from_type { + UInt8 => cast_integer_to_decimal::<_, Decimal256Type, _>( + array.as_primitive::(), + *precision, + *scale, + i256::from_i128(10_i128), + cast_options, + ), + UInt16 => cast_integer_to_decimal::<_, Decimal256Type, _>( + array.as_primitive::(), + *precision, + *scale, + i256::from_i128(10_i128), + cast_options, + ), + UInt32 => cast_integer_to_decimal::<_, Decimal256Type, _>( + array.as_primitive::(), + *precision, + *scale, + i256::from_i128(10_i128), + cast_options, + ), + UInt64 => cast_integer_to_decimal::<_, Decimal256Type, _>( + array.as_primitive::(), + *precision, + *scale, + i256::from_i128(10_i128), + cast_options, + ), + Int8 => cast_integer_to_decimal::<_, Decimal256Type, _>( + array.as_primitive::(), + *precision, + *scale, + i256::from_i128(10_i128), + cast_options, + ), + Int16 => cast_integer_to_decimal::<_, Decimal256Type, _>( + array.as_primitive::(), + *precision, + *scale, + i256::from_i128(10_i128), + cast_options, + ), + Int32 => cast_integer_to_decimal::<_, Decimal256Type, _>( + array.as_primitive::(), + *precision, + *scale, + i256::from_i128(10_i128), + cast_options, + ), + Int64 => cast_integer_to_decimal::<_, Decimal256Type, _>( + array.as_primitive::(), + *precision, + *scale, + i256::from_i128(10_i128), + cast_options, + ), + Float32 => cast_floating_point_to_decimal256( + array.as_primitive::(), + *precision, + *scale, + cast_options, + ), + Float64 => cast_floating_point_to_decimal256( + array.as_primitive::(), + *precision, + *scale, + cast_options, + ), + Utf8 => cast_string_to_decimal::( + array, + *precision, + *scale, + cast_options, + ), + LargeUtf8 => cast_string_to_decimal::( + array, + *precision, + *scale, + cast_options, + ), + Null => Ok(new_null_array(to_type, array.len())), + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported" + ))), + } + } + (Struct(_), _) => Err(ArrowError::CastError( + "Cannot cast from struct to other types".to_string(), + )), + (_, Struct(_)) => Err(ArrowError::CastError( + "Cannot cast to struct from other types".to_string(), + )), + (_, Boolean) => match from_type { + UInt8 => cast_numeric_to_bool::(array), + UInt16 => cast_numeric_to_bool::(array), + UInt32 => cast_numeric_to_bool::(array), + UInt64 => cast_numeric_to_bool::(array), + Int8 => cast_numeric_to_bool::(array), + Int16 => cast_numeric_to_bool::(array), + Int32 => cast_numeric_to_bool::(array), + Int64 => cast_numeric_to_bool::(array), + Float16 => cast_numeric_to_bool::(array), + Float32 => cast_numeric_to_bool::(array), + Float64 => cast_numeric_to_bool::(array), + Utf8 => cast_utf8_to_boolean::(array, cast_options), + LargeUtf8 => cast_utf8_to_boolean::(array, cast_options), + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + (Boolean, _) => match to_type { + UInt8 => cast_bool_to_numeric::(array, cast_options), + UInt16 => cast_bool_to_numeric::(array, cast_options), + UInt32 => cast_bool_to_numeric::(array, cast_options), + UInt64 => cast_bool_to_numeric::(array, cast_options), + Int8 => cast_bool_to_numeric::(array, cast_options), + Int16 => cast_bool_to_numeric::(array, cast_options), + Int32 => cast_bool_to_numeric::(array, cast_options), + Int64 => cast_bool_to_numeric::(array, cast_options), + Float16 => cast_bool_to_numeric::(array, cast_options), + Float32 => cast_bool_to_numeric::(array, cast_options), + Float64 => cast_bool_to_numeric::(array, cast_options), + Utf8 => value_to_string::(array, cast_options), + LargeUtf8 => value_to_string::(array, cast_options), + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + (Utf8, _) => match to_type { + UInt8 => parse_string::(array, cast_options), + UInt16 => parse_string::(array, cast_options), + UInt32 => parse_string::(array, cast_options), + UInt64 => parse_string::(array, cast_options), + Int8 => parse_string::(array, cast_options), + Int16 => parse_string::(array, cast_options), + Int32 => parse_string::(array, cast_options), + Int64 => parse_string::(array, cast_options), + Float32 => parse_string::(array, cast_options), + Float64 => parse_string::(array, cast_options), + Date32 => parse_string::(array, cast_options), + Date64 => parse_string::(array, cast_options), + Binary => Ok(Arc::new(BinaryArray::from( + array.as_string::().clone(), + ))), + LargeBinary => { + let binary = BinaryArray::from(array.as_string::().clone()); + cast_byte_container::(&binary) + } + LargeUtf8 => cast_byte_container::(array), + Time32(TimeUnit::Second) => parse_string::(array, cast_options), + Time32(TimeUnit::Millisecond) => { + parse_string::(array, cast_options) + } + Time64(TimeUnit::Microsecond) => { + parse_string::(array, cast_options) + } + Time64(TimeUnit::Nanosecond) => { + parse_string::(array, cast_options) + } + Timestamp(TimeUnit::Second, to_tz) => { + cast_string_to_timestamp::(array, to_tz, cast_options) + } + Timestamp(TimeUnit::Millisecond, to_tz) => cast_string_to_timestamp::< + i32, + TimestampMillisecondType, + >(array, to_tz, cast_options), + Timestamp(TimeUnit::Microsecond, to_tz) => cast_string_to_timestamp::< + i32, + TimestampMicrosecondType, + >(array, to_tz, cast_options), + Timestamp(TimeUnit::Nanosecond, to_tz) => { + cast_string_to_timestamp::(array, to_tz, cast_options) + } + Interval(IntervalUnit::YearMonth) => { + cast_string_to_year_month_interval::(array, cast_options) + } + Interval(IntervalUnit::DayTime) => { + cast_string_to_day_time_interval::(array, cast_options) + } + Interval(IntervalUnit::MonthDayNano) => { + cast_string_to_month_day_nano_interval::(array, cast_options) + } + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + (LargeUtf8, _) => match to_type { + UInt8 => parse_string::(array, cast_options), + UInt16 => parse_string::(array, cast_options), + UInt32 => parse_string::(array, cast_options), + UInt64 => parse_string::(array, cast_options), + Int8 => parse_string::(array, cast_options), + Int16 => parse_string::(array, cast_options), + Int32 => parse_string::(array, cast_options), + Int64 => parse_string::(array, cast_options), + Float32 => parse_string::(array, cast_options), + Float64 => parse_string::(array, cast_options), + Date32 => parse_string::(array, cast_options), + Date64 => parse_string::(array, cast_options), + Utf8 => cast_byte_container::(array), + Binary => { + let large_binary = LargeBinaryArray::from(array.as_string::().clone()); + cast_byte_container::(&large_binary) + } + LargeBinary => Ok(Arc::new(LargeBinaryArray::from( + array.as_string::().clone(), + ))), + Time32(TimeUnit::Second) => parse_string::(array, cast_options), + Time32(TimeUnit::Millisecond) => { + parse_string::(array, cast_options) + } + Time64(TimeUnit::Microsecond) => { + parse_string::(array, cast_options) + } + Time64(TimeUnit::Nanosecond) => { + parse_string::(array, cast_options) + } + Timestamp(TimeUnit::Second, to_tz) => { + cast_string_to_timestamp::(array, to_tz, cast_options) + } + Timestamp(TimeUnit::Millisecond, to_tz) => cast_string_to_timestamp::< + i64, + TimestampMillisecondType, + >(array, to_tz, cast_options), + Timestamp(TimeUnit::Microsecond, to_tz) => cast_string_to_timestamp::< + i64, + TimestampMicrosecondType, + >(array, to_tz, cast_options), + Timestamp(TimeUnit::Nanosecond, to_tz) => { + cast_string_to_timestamp::(array, to_tz, cast_options) + } + Interval(IntervalUnit::YearMonth) => { + cast_string_to_year_month_interval::(array, cast_options) + } + Interval(IntervalUnit::DayTime) => { + cast_string_to_day_time_interval::(array, cast_options) + } + Interval(IntervalUnit::MonthDayNano) => { + cast_string_to_month_day_nano_interval::(array, cast_options) + } + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + (Binary, _) => match to_type { + Utf8 => cast_binary_to_string::(array, cast_options), + LargeUtf8 => { + let array = cast_binary_to_string::(array, cast_options)?; + cast_byte_container::(array.as_ref()) + } + LargeBinary => cast_byte_container::(array), + FixedSizeBinary(size) => { + cast_binary_to_fixed_size_binary::(array, *size, cast_options) + } + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + (LargeBinary, _) => match to_type { + Utf8 => { + let array = cast_binary_to_string::(array, cast_options)?; + cast_byte_container::(array.as_ref()) + } + LargeUtf8 => cast_binary_to_string::(array, cast_options), + Binary => cast_byte_container::(array), + FixedSizeBinary(size) => { + cast_binary_to_fixed_size_binary::(array, *size, cast_options) + } + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + (FixedSizeBinary(size), _) => match to_type { + Binary => cast_fixed_size_binary_to_binary::(array, *size), + LargeBinary => cast_fixed_size_binary_to_binary::(array, *size), + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + (from_type, LargeUtf8) if from_type.is_primitive() => { + value_to_string::(array, cast_options) + } + (from_type, Utf8) if from_type.is_primitive() => { + value_to_string::(array, cast_options) + } + (from_type, Binary) if from_type.is_integer() => match from_type { + UInt8 => cast_numeric_to_binary::(array), + UInt16 => cast_numeric_to_binary::(array), + UInt32 => cast_numeric_to_binary::(array), + UInt64 => cast_numeric_to_binary::(array), + Int8 => cast_numeric_to_binary::(array), + Int16 => cast_numeric_to_binary::(array), + Int32 => cast_numeric_to_binary::(array), + Int64 => cast_numeric_to_binary::(array), + _ => unreachable!(), + }, + (from_type, LargeBinary) if from_type.is_integer() => match from_type { + UInt8 => cast_numeric_to_binary::(array), + UInt16 => cast_numeric_to_binary::(array), + UInt32 => cast_numeric_to_binary::(array), + UInt64 => cast_numeric_to_binary::(array), + Int8 => cast_numeric_to_binary::(array), + Int16 => cast_numeric_to_binary::(array), + Int32 => cast_numeric_to_binary::(array), + Int64 => cast_numeric_to_binary::(array), + _ => unreachable!(), + }, + // start numeric casts + (UInt8, UInt16) => cast_numeric_arrays::(array, cast_options), + (UInt8, UInt32) => cast_numeric_arrays::(array, cast_options), + (UInt8, UInt64) => cast_numeric_arrays::(array, cast_options), + (UInt8, Int8) => cast_numeric_arrays::(array, cast_options), + (UInt8, Int16) => cast_numeric_arrays::(array, cast_options), + (UInt8, Int32) => cast_numeric_arrays::(array, cast_options), + (UInt8, Int64) => cast_numeric_arrays::(array, cast_options), + (UInt8, Float16) => cast_numeric_arrays::(array, cast_options), + (UInt8, Float32) => cast_numeric_arrays::(array, cast_options), + (UInt8, Float64) => cast_numeric_arrays::(array, cast_options), + + (UInt16, UInt8) => cast_numeric_arrays::(array, cast_options), + (UInt16, UInt32) => cast_numeric_arrays::(array, cast_options), + (UInt16, UInt64) => cast_numeric_arrays::(array, cast_options), + (UInt16, Int8) => cast_numeric_arrays::(array, cast_options), + (UInt16, Int16) => cast_numeric_arrays::(array, cast_options), + (UInt16, Int32) => cast_numeric_arrays::(array, cast_options), + (UInt16, Int64) => cast_numeric_arrays::(array, cast_options), + (UInt16, Float16) => cast_numeric_arrays::(array, cast_options), + (UInt16, Float32) => cast_numeric_arrays::(array, cast_options), + (UInt16, Float64) => cast_numeric_arrays::(array, cast_options), + + (UInt32, UInt8) => cast_numeric_arrays::(array, cast_options), + (UInt32, UInt16) => cast_numeric_arrays::(array, cast_options), + (UInt32, UInt64) => cast_numeric_arrays::(array, cast_options), + (UInt32, Int8) => cast_numeric_arrays::(array, cast_options), + (UInt32, Int16) => cast_numeric_arrays::(array, cast_options), + (UInt32, Int32) => cast_numeric_arrays::(array, cast_options), + (UInt32, Int64) => cast_numeric_arrays::(array, cast_options), + (UInt32, Float16) => cast_numeric_arrays::(array, cast_options), + (UInt32, Float32) => cast_numeric_arrays::(array, cast_options), + (UInt32, Float64) => cast_numeric_arrays::(array, cast_options), + + (UInt64, UInt8) => cast_numeric_arrays::(array, cast_options), + (UInt64, UInt16) => cast_numeric_arrays::(array, cast_options), + (UInt64, UInt32) => cast_numeric_arrays::(array, cast_options), + (UInt64, Int8) => cast_numeric_arrays::(array, cast_options), + (UInt64, Int16) => cast_numeric_arrays::(array, cast_options), + (UInt64, Int32) => cast_numeric_arrays::(array, cast_options), + (UInt64, Int64) => cast_numeric_arrays::(array, cast_options), + (UInt64, Float16) => cast_numeric_arrays::(array, cast_options), + (UInt64, Float32) => cast_numeric_arrays::(array, cast_options), + (UInt64, Float64) => cast_numeric_arrays::(array, cast_options), + + (Int8, UInt8) => cast_numeric_arrays::(array, cast_options), + (Int8, UInt16) => cast_numeric_arrays::(array, cast_options), + (Int8, UInt32) => cast_numeric_arrays::(array, cast_options), + (Int8, UInt64) => cast_numeric_arrays::(array, cast_options), + (Int8, Int16) => cast_numeric_arrays::(array, cast_options), + (Int8, Int32) => cast_numeric_arrays::(array, cast_options), + (Int8, Int64) => cast_numeric_arrays::(array, cast_options), + (Int8, Float16) => cast_numeric_arrays::(array, cast_options), + (Int8, Float32) => cast_numeric_arrays::(array, cast_options), + (Int8, Float64) => cast_numeric_arrays::(array, cast_options), + + (Int16, UInt8) => cast_numeric_arrays::(array, cast_options), + (Int16, UInt16) => cast_numeric_arrays::(array, cast_options), + (Int16, UInt32) => cast_numeric_arrays::(array, cast_options), + (Int16, UInt64) => cast_numeric_arrays::(array, cast_options), + (Int16, Int8) => cast_numeric_arrays::(array, cast_options), + (Int16, Int32) => cast_numeric_arrays::(array, cast_options), + (Int16, Int64) => cast_numeric_arrays::(array, cast_options), + (Int16, Float16) => cast_numeric_arrays::(array, cast_options), + (Int16, Float32) => cast_numeric_arrays::(array, cast_options), + (Int16, Float64) => cast_numeric_arrays::(array, cast_options), + + (Int32, UInt8) => cast_numeric_arrays::(array, cast_options), + (Int32, UInt16) => cast_numeric_arrays::(array, cast_options), + (Int32, UInt32) => cast_numeric_arrays::(array, cast_options), + (Int32, UInt64) => cast_numeric_arrays::(array, cast_options), + (Int32, Int8) => cast_numeric_arrays::(array, cast_options), + (Int32, Int16) => cast_numeric_arrays::(array, cast_options), + (Int32, Int64) => cast_numeric_arrays::(array, cast_options), + (Int32, Float16) => cast_numeric_arrays::(array, cast_options), + (Int32, Float32) => cast_numeric_arrays::(array, cast_options), + (Int32, Float64) => cast_numeric_arrays::(array, cast_options), + + (Int64, UInt8) => cast_numeric_arrays::(array, cast_options), + (Int64, UInt16) => cast_numeric_arrays::(array, cast_options), + (Int64, UInt32) => cast_numeric_arrays::(array, cast_options), + (Int64, UInt64) => cast_numeric_arrays::(array, cast_options), + (Int64, Int8) => cast_numeric_arrays::(array, cast_options), + (Int64, Int16) => cast_numeric_arrays::(array, cast_options), + (Int64, Int32) => cast_numeric_arrays::(array, cast_options), + (Int64, Float16) => cast_numeric_arrays::(array, cast_options), + (Int64, Float32) => cast_numeric_arrays::(array, cast_options), + (Int64, Float64) => cast_numeric_arrays::(array, cast_options), + + (Float16, UInt8) => cast_numeric_arrays::(array, cast_options), + (Float16, UInt16) => cast_numeric_arrays::(array, cast_options), + (Float16, UInt32) => cast_numeric_arrays::(array, cast_options), + (Float16, UInt64) => cast_numeric_arrays::(array, cast_options), + (Float16, Int8) => cast_numeric_arrays::(array, cast_options), + (Float16, Int16) => cast_numeric_arrays::(array, cast_options), + (Float16, Int32) => cast_numeric_arrays::(array, cast_options), + (Float16, Int64) => cast_numeric_arrays::(array, cast_options), + (Float16, Float32) => cast_numeric_arrays::(array, cast_options), + (Float16, Float64) => cast_numeric_arrays::(array, cast_options), + + (Float32, UInt8) => cast_numeric_arrays::(array, cast_options), + (Float32, UInt16) => cast_numeric_arrays::(array, cast_options), + (Float32, UInt32) => cast_numeric_arrays::(array, cast_options), + (Float32, UInt64) => cast_numeric_arrays::(array, cast_options), + (Float32, Int8) => cast_numeric_arrays::(array, cast_options), + (Float32, Int16) => cast_numeric_arrays::(array, cast_options), + (Float32, Int32) => cast_numeric_arrays::(array, cast_options), + (Float32, Int64) => cast_numeric_arrays::(array, cast_options), + (Float32, Float16) => cast_numeric_arrays::(array, cast_options), + (Float32, Float64) => cast_numeric_arrays::(array, cast_options), + + (Float64, UInt8) => cast_numeric_arrays::(array, cast_options), + (Float64, UInt16) => cast_numeric_arrays::(array, cast_options), + (Float64, UInt32) => cast_numeric_arrays::(array, cast_options), + (Float64, UInt64) => cast_numeric_arrays::(array, cast_options), + (Float64, Int8) => cast_numeric_arrays::(array, cast_options), + (Float64, Int16) => cast_numeric_arrays::(array, cast_options), + (Float64, Int32) => cast_numeric_arrays::(array, cast_options), + (Float64, Int64) => cast_numeric_arrays::(array, cast_options), + (Float64, Float16) => cast_numeric_arrays::(array, cast_options), + (Float64, Float32) => cast_numeric_arrays::(array, cast_options), + // end numeric casts + + // temporal casts + (Int32, Date32) => cast_reinterpret_arrays::(array), + (Int32, Date64) => cast_with_options( + &cast_with_options(array, &Date32, cast_options)?, + &Date64, + cast_options, + ), + (Int32, Time32(TimeUnit::Second)) => { + cast_reinterpret_arrays::(array) + } + (Int32, Time32(TimeUnit::Millisecond)) => { + cast_reinterpret_arrays::(array) + } + // No support for microsecond/nanosecond with i32 + (Date32, Int32) => cast_reinterpret_arrays::(array), + (Date32, Int64) => cast_with_options( + &cast_with_options(array, &Int32, cast_options)?, + &Int64, + cast_options, + ), + (Time32(TimeUnit::Second), Int32) => { + cast_reinterpret_arrays::(array) + } + (Time32(TimeUnit::Millisecond), Int32) => { + cast_reinterpret_arrays::(array) + } + (Int64, Date64) => cast_reinterpret_arrays::(array), + (Int64, Date32) => cast_with_options( + &cast_with_options(array, &Int32, cast_options)?, + &Date32, + cast_options, + ), + // No support for second/milliseconds with i64 + (Int64, Time64(TimeUnit::Microsecond)) => { + cast_reinterpret_arrays::(array) + } + (Int64, Time64(TimeUnit::Nanosecond)) => { + cast_reinterpret_arrays::(array) + } + + (Date64, Int64) => cast_reinterpret_arrays::(array), + (Date64, Int32) => cast_with_options( + &cast_with_options(array, &Int64, cast_options)?, + &Int32, + cast_options, + ), + (Time64(TimeUnit::Microsecond), Int64) => { + cast_reinterpret_arrays::(array) + } + (Time64(TimeUnit::Nanosecond), Int64) => { + cast_reinterpret_arrays::(array) + } + (Date32, Date64) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Date64Type>(|x| x as i64 * MILLISECONDS_IN_DAY), + )), + (Date64, Date32) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Date32Type>(|x| (x / MILLISECONDS_IN_DAY) as i32), + )), + + (Time32(TimeUnit::Second), Time32(TimeUnit::Millisecond)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time32MillisecondType>(|x| x * MILLISECONDS as i32), + )), + (Time32(TimeUnit::Second), Time64(TimeUnit::Microsecond)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time64MicrosecondType>(|x| x as i64 * MICROSECONDS), + )), + (Time32(TimeUnit::Second), Time64(TimeUnit::Nanosecond)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time64NanosecondType>(|x| x as i64 * NANOSECONDS), + )), + + (Time32(TimeUnit::Millisecond), Time32(TimeUnit::Second)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time32SecondType>(|x| x / MILLISECONDS as i32), + )), + (Time32(TimeUnit::Millisecond), Time64(TimeUnit::Microsecond)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time64MicrosecondType>(|x| x as i64 * (MICROSECONDS / MILLISECONDS)), + )), + (Time32(TimeUnit::Millisecond), Time64(TimeUnit::Nanosecond)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time64NanosecondType>(|x| x as i64 * (MICROSECONDS / NANOSECONDS)), + )), + + (Time64(TimeUnit::Microsecond), Time32(TimeUnit::Second)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time32SecondType>(|x| (x / MICROSECONDS) as i32), + )), + (Time64(TimeUnit::Microsecond), Time32(TimeUnit::Millisecond)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time32MillisecondType>(|x| (x / (MICROSECONDS / MILLISECONDS)) as i32), + )), + (Time64(TimeUnit::Microsecond), Time64(TimeUnit::Nanosecond)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time64NanosecondType>(|x| x * (NANOSECONDS / MICROSECONDS)), + )), + + (Time64(TimeUnit::Nanosecond), Time32(TimeUnit::Second)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time32SecondType>(|x| (x / NANOSECONDS) as i32), + )), + (Time64(TimeUnit::Nanosecond), Time32(TimeUnit::Millisecond)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time32MillisecondType>(|x| (x / (NANOSECONDS / MILLISECONDS)) as i32), + )), + (Time64(TimeUnit::Nanosecond), Time64(TimeUnit::Microsecond)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Time64MicrosecondType>(|x| x / (NANOSECONDS / MICROSECONDS)), + )), + + // Timestamp to integer/floating/decimals + (Timestamp(TimeUnit::Second, _), _) if to_type.is_numeric() => { + let array = cast_reinterpret_arrays::(array)?; + cast_with_options(&array, to_type, cast_options) + } + (Timestamp(TimeUnit::Millisecond, _), _) if to_type.is_numeric() => { + let array = cast_reinterpret_arrays::(array)?; + cast_with_options(&array, to_type, cast_options) + } + (Timestamp(TimeUnit::Microsecond, _), _) if to_type.is_numeric() => { + let array = cast_reinterpret_arrays::(array)?; + cast_with_options(&array, to_type, cast_options) + } + (Timestamp(TimeUnit::Nanosecond, _), _) if to_type.is_numeric() => { + let array = cast_reinterpret_arrays::(array)?; + cast_with_options(&array, to_type, cast_options) + } + + (_, Timestamp(unit, tz)) if from_type.is_numeric() => { + let array = cast_with_options(array, &Int64, cast_options)?; + Ok(make_timestamp_array( + array.as_primitive(), + unit.clone(), + tz.clone(), + )) + } + + (Timestamp(from_unit, from_tz), Timestamp(to_unit, to_tz)) => { + let array = cast_with_options(array, &Int64, cast_options)?; + let time_array = array.as_primitive::(); + let from_size = time_unit_multiple(from_unit); + let to_size = time_unit_multiple(to_unit); + // we either divide or multiply, depending on size of each unit + // units are never the same when the types are the same + let converted = match from_size.cmp(&to_size) { + Ordering::Greater => { + let divisor = from_size / to_size; + time_array.unary::<_, Int64Type>(|o| o / divisor) + } + Ordering::Equal => time_array.clone(), + Ordering::Less => { + let mul = to_size / from_size; + if cast_options.safe { + time_array.unary_opt::<_, Int64Type>(|o| o.checked_mul(mul)) + } else { + time_array.try_unary::<_, Int64Type, _>(|o| o.mul_checked(mul))? + } + } + }; + // Normalize timezone + let adjusted = match (from_tz, to_tz) { + // Only this case needs to be adjusted because we're casting from + // unknown time offset to some time offset, we want the time to be + // unchanged. + // + // i.e. Timestamp('2001-01-01T00:00', None) -> Timestamp('2001-01-01T00:00', '+0700') + (None, Some(to_tz)) => { + let to_tz: Tz = to_tz.parse()?; + match to_unit { + TimeUnit::Second => adjust_timestamp_to_timezone::( + converted, + &to_tz, + cast_options, + )?, + TimeUnit::Millisecond => adjust_timestamp_to_timezone::< + TimestampMillisecondType, + >( + converted, &to_tz, cast_options + )?, + TimeUnit::Microsecond => adjust_timestamp_to_timezone::< + TimestampMicrosecondType, + >( + converted, &to_tz, cast_options + )?, + TimeUnit::Nanosecond => adjust_timestamp_to_timezone::< + TimestampNanosecondType, + >( + converted, &to_tz, cast_options + )?, + } + } + _ => converted, + }; + Ok(make_timestamp_array( + &adjusted, + to_unit.clone(), + to_tz.clone(), + )) + } + (Timestamp(from_unit, _), Date32) => { + let array = cast_with_options(array, &Int64, cast_options)?; + let time_array = array.as_primitive::(); + let from_size = time_unit_multiple(from_unit) * SECONDS_IN_DAY; + + let mut b = Date32Builder::with_capacity(array.len()); + + for i in 0..array.len() { + if time_array.is_null(i) { + b.append_null(); + } else { + b.append_value( + num::integer::div_floor::(time_array.value(i), from_size) as i32, + ); + } + } + + Ok(Arc::new(b.finish()) as ArrayRef) + } + (Timestamp(TimeUnit::Second, _), Date64) => Ok(Arc::new(match cast_options.safe { + true => { + // change error to None + array + .as_primitive::() + .unary_opt::<_, Date64Type>(|x| x.checked_mul(MILLISECONDS)) + } + false => array + .as_primitive::() + .try_unary::<_, Date64Type, _>(|x| x.mul_checked(MILLISECONDS))?, + })), + (Timestamp(TimeUnit::Millisecond, _), Date64) => { + cast_reinterpret_arrays::(array) + } + (Timestamp(TimeUnit::Microsecond, _), Date64) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Date64Type>(|x| x / (MICROSECONDS / MILLISECONDS)), + )), + (Timestamp(TimeUnit::Nanosecond, _), Date64) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, Date64Type>(|x| x / (NANOSECONDS / MILLISECONDS)), + )), + (Timestamp(TimeUnit::Second, tz), Time64(TimeUnit::Microsecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time64MicrosecondType, ArrowError>(|x| { + Ok(time_to_time64us(as_time_res_with_timezone::< + TimestampSecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Second, tz), Time64(TimeUnit::Nanosecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time64NanosecondType, ArrowError>(|x| { + Ok(time_to_time64ns(as_time_res_with_timezone::< + TimestampSecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Millisecond, tz), Time64(TimeUnit::Microsecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time64MicrosecondType, ArrowError>(|x| { + Ok(time_to_time64us(as_time_res_with_timezone::< + TimestampMillisecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Millisecond, tz), Time64(TimeUnit::Nanosecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time64NanosecondType, ArrowError>(|x| { + Ok(time_to_time64ns(as_time_res_with_timezone::< + TimestampMillisecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Microsecond, tz), Time64(TimeUnit::Microsecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time64MicrosecondType, ArrowError>(|x| { + Ok(time_to_time64us(as_time_res_with_timezone::< + TimestampMicrosecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Microsecond, tz), Time64(TimeUnit::Nanosecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time64NanosecondType, ArrowError>(|x| { + Ok(time_to_time64ns(as_time_res_with_timezone::< + TimestampMicrosecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Nanosecond, tz), Time64(TimeUnit::Microsecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time64MicrosecondType, ArrowError>(|x| { + Ok(time_to_time64us(as_time_res_with_timezone::< + TimestampNanosecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Nanosecond, tz), Time64(TimeUnit::Nanosecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time64NanosecondType, ArrowError>(|x| { + Ok(time_to_time64ns(as_time_res_with_timezone::< + TimestampNanosecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Second, tz), Time32(TimeUnit::Second)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time32SecondType, ArrowError>(|x| { + Ok(time_to_time32s(as_time_res_with_timezone::< + TimestampSecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Second, tz), Time32(TimeUnit::Millisecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time32MillisecondType, ArrowError>(|x| { + Ok(time_to_time32ms(as_time_res_with_timezone::< + TimestampSecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Millisecond, tz), Time32(TimeUnit::Second)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time32SecondType, ArrowError>(|x| { + Ok(time_to_time32s(as_time_res_with_timezone::< + TimestampMillisecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Millisecond, tz), Time32(TimeUnit::Millisecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time32MillisecondType, ArrowError>(|x| { + Ok(time_to_time32ms(as_time_res_with_timezone::< + TimestampMillisecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Microsecond, tz), Time32(TimeUnit::Second)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time32SecondType, ArrowError>(|x| { + Ok(time_to_time32s(as_time_res_with_timezone::< + TimestampMicrosecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Microsecond, tz), Time32(TimeUnit::Millisecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time32MillisecondType, ArrowError>(|x| { + Ok(time_to_time32ms(as_time_res_with_timezone::< + TimestampMicrosecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Nanosecond, tz), Time32(TimeUnit::Second)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time32SecondType, ArrowError>(|x| { + Ok(time_to_time32s(as_time_res_with_timezone::< + TimestampNanosecondType, + >(x, tz)?)) + })?, + )) + } + (Timestamp(TimeUnit::Nanosecond, tz), Time32(TimeUnit::Millisecond)) => { + let tz = tz.as_ref().map(|tz| tz.parse()).transpose()?; + Ok(Arc::new( + array + .as_primitive::() + .try_unary::<_, Time32MillisecondType, ArrowError>(|x| { + Ok(time_to_time32ms(as_time_res_with_timezone::< + TimestampNanosecondType, + >(x, tz)?)) + })?, + )) + } + + (Date64, Timestamp(TimeUnit::Second, None)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, TimestampSecondType>(|x| x / MILLISECONDS), + )), + (Date64, Timestamp(TimeUnit::Millisecond, None)) => { + cast_reinterpret_arrays::(array) + } + (Date64, Timestamp(TimeUnit::Microsecond, None)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, TimestampMicrosecondType>(|x| x * (MICROSECONDS / MILLISECONDS)), + )), + (Date64, Timestamp(TimeUnit::Nanosecond, None)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, TimestampNanosecondType>(|x| x * (NANOSECONDS / MILLISECONDS)), + )), + (Date32, Timestamp(TimeUnit::Second, None)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, TimestampSecondType>(|x| (x as i64) * SECONDS_IN_DAY), + )), + (Date32, Timestamp(TimeUnit::Millisecond, None)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, TimestampMillisecondType>(|x| (x as i64) * MILLISECONDS_IN_DAY), + )), + (Date32, Timestamp(TimeUnit::Microsecond, None)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, TimestampMicrosecondType>(|x| (x as i64) * MICROSECONDS_IN_DAY), + )), + (Date32, Timestamp(TimeUnit::Nanosecond, None)) => Ok(Arc::new( + array + .as_primitive::() + .unary::<_, TimestampNanosecondType>(|x| (x as i64) * NANOSECONDS_IN_DAY), + )), + (Int64, Duration(TimeUnit::Second)) => { + cast_reinterpret_arrays::(array) + } + (Int64, Duration(TimeUnit::Millisecond)) => { + cast_reinterpret_arrays::(array) + } + (Int64, Duration(TimeUnit::Microsecond)) => { + cast_reinterpret_arrays::(array) + } + (Int64, Duration(TimeUnit::Nanosecond)) => { + cast_reinterpret_arrays::(array) + } + + (Duration(TimeUnit::Second), Int64) => { + cast_reinterpret_arrays::(array) + } + (Duration(TimeUnit::Millisecond), Int64) => { + cast_reinterpret_arrays::(array) + } + (Duration(TimeUnit::Microsecond), Int64) => { + cast_reinterpret_arrays::(array) + } + (Duration(TimeUnit::Nanosecond), Int64) => { + cast_reinterpret_arrays::(array) + } + (Duration(TimeUnit::Second), Interval(IntervalUnit::MonthDayNano)) => { + cast_duration_to_interval::(array, cast_options) + } + (Duration(TimeUnit::Millisecond), Interval(IntervalUnit::MonthDayNano)) => { + cast_duration_to_interval::(array, cast_options) + } + (Duration(TimeUnit::Microsecond), Interval(IntervalUnit::MonthDayNano)) => { + cast_duration_to_interval::(array, cast_options) + } + (Duration(TimeUnit::Nanosecond), Interval(IntervalUnit::MonthDayNano)) => { + cast_duration_to_interval::(array, cast_options) + } + (Interval(IntervalUnit::MonthDayNano), Duration(TimeUnit::Second)) => { + cast_month_day_nano_to_duration::(array, cast_options) + } + (Interval(IntervalUnit::MonthDayNano), Duration(TimeUnit::Millisecond)) => { + cast_month_day_nano_to_duration::(array, cast_options) + } + (Interval(IntervalUnit::MonthDayNano), Duration(TimeUnit::Microsecond)) => { + cast_month_day_nano_to_duration::(array, cast_options) + } + (Interval(IntervalUnit::MonthDayNano), Duration(TimeUnit::Nanosecond)) => { + cast_month_day_nano_to_duration::(array, cast_options) + } + (Interval(IntervalUnit::YearMonth), Interval(IntervalUnit::MonthDayNano)) => { + cast_interval_year_month_to_interval_month_day_nano(array, cast_options) + } + (Interval(IntervalUnit::DayTime), Interval(IntervalUnit::MonthDayNano)) => { + cast_interval_day_time_to_interval_month_day_nano(array, cast_options) + } + (Interval(IntervalUnit::YearMonth), Int64) => { + cast_numeric_arrays::(array, cast_options) + } + (Interval(IntervalUnit::DayTime), Int64) => { + cast_reinterpret_arrays::(array) + } + (Int32, Interval(IntervalUnit::YearMonth)) => { + cast_reinterpret_arrays::(array) + } + (Int64, Interval(IntervalUnit::DayTime)) => { + cast_reinterpret_arrays::(array) + } + (_, _) => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + } +} + +/// Get the time unit as a multiple of a second +const fn time_unit_multiple(unit: &TimeUnit) -> i64 { + match unit { + TimeUnit::Second => 1, + TimeUnit::Millisecond => MILLISECONDS, + TimeUnit::Microsecond => MICROSECONDS, + TimeUnit::Nanosecond => NANOSECONDS, + } +} + +/// A utility trait that provides checked conversions between +/// decimal types inspired by [`NumCast`] +trait DecimalCast: Sized { + fn to_i128(self) -> Option; + + fn to_i256(self) -> Option; + + fn from_decimal(n: T) -> Option; +} + +impl DecimalCast for i128 { + fn to_i128(self) -> Option { + Some(self) + } + + fn to_i256(self) -> Option { + Some(i256::from_i128(self)) + } + + fn from_decimal(n: T) -> Option { + n.to_i128() + } +} + +impl DecimalCast for i256 { + fn to_i128(self) -> Option { + self.to_i128() + } + + fn to_i256(self) -> Option { + Some(self) + } + + fn from_decimal(n: T) -> Option { + n.to_i256() + } +} + +fn cast_decimal_to_decimal_error( + output_precision: u8, + output_scale: i8, +) -> impl Fn(::Native) -> ArrowError +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + move |x: I::Native| { + ArrowError::CastError(format!( + "Cannot cast to {}({}, {}). Overflowing on {:?}", + O::PREFIX, + output_precision, + output_scale, + x + )) + } +} + +fn convert_to_smaller_scale_decimal( + array: &PrimitiveArray, + input_scale: i8, + output_precision: u8, + output_scale: i8, + cast_options: &CastOptions, +) -> Result, ArrowError> +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + let error = cast_decimal_to_decimal_error::(output_precision, output_scale); + let div = I::Native::from_decimal(10_i128) + .unwrap() + .pow_checked((input_scale - output_scale) as u32)?; + + let half = div.div_wrapping(I::Native::from_usize(2).unwrap()); + let half_neg = half.neg_wrapping(); + + let f = |x: I::Native| { + // div is >= 10 and so this cannot overflow + let d = x.div_wrapping(div); + let r = x.mod_wrapping(div); + + // Round result + let adjusted = match x >= I::Native::ZERO { + true if r >= half => d.add_wrapping(I::Native::ONE), + false if r <= half_neg => d.sub_wrapping(I::Native::ONE), + _ => d, + }; + O::Native::from_decimal(adjusted) + }; + + Ok(match cast_options.safe { + true => array.unary_opt(f), + false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?, + }) +} + +fn convert_to_bigger_or_equal_scale_decimal( + array: &PrimitiveArray, + input_scale: i8, + output_precision: u8, + output_scale: i8, + cast_options: &CastOptions, +) -> Result, ArrowError> +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + let error = cast_decimal_to_decimal_error::(output_precision, output_scale); + let mul = O::Native::from_decimal(10_i128) + .unwrap() + .pow_checked((output_scale - input_scale) as u32)?; + + let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok()); + + Ok(match cast_options.safe { + true => array.unary_opt(f), + false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?, + }) +} + +// Only support one type of decimal cast operations +fn cast_decimal_to_decimal_same_type( + array: &PrimitiveArray, + input_scale: i8, + output_precision: u8, + output_scale: i8, + cast_options: &CastOptions, +) -> Result +where + T: DecimalType, + T::Native: DecimalCast + ArrowNativeTypeOp, +{ + let array: PrimitiveArray = match input_scale.cmp(&output_scale) { + Ordering::Equal => { + // the scale doesn't change, the native value don't need to be changed + array.clone() + } + Ordering::Greater => convert_to_smaller_scale_decimal::( + array, + input_scale, + output_precision, + output_scale, + cast_options, + )?, + Ordering::Less => { + // input_scale < output_scale + convert_to_bigger_or_equal_scale_decimal::( + array, + input_scale, + output_precision, + output_scale, + cast_options, + )? + } + }; + + Ok(Arc::new(array.with_precision_and_scale( + output_precision, + output_scale, + )?)) +} + +// Support two different types of decimal cast operations +fn cast_decimal_to_decimal( + array: &PrimitiveArray, + input_scale: i8, + output_precision: u8, + output_scale: i8, + cast_options: &CastOptions, +) -> Result +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + let array: PrimitiveArray = if input_scale > output_scale { + convert_to_smaller_scale_decimal::( + array, + input_scale, + output_precision, + output_scale, + cast_options, + )? + } else { + convert_to_bigger_or_equal_scale_decimal::( + array, + input_scale, + output_precision, + output_scale, + cast_options, + )? + }; + + Ok(Arc::new(array.with_precision_and_scale( + output_precision, + output_scale, + )?)) +} + +/// Convert Array into a PrimitiveArray of type, and apply numeric cast +fn cast_numeric_arrays( + from: &dyn Array, + cast_options: &CastOptions, +) -> Result +where + FROM: ArrowPrimitiveType, + TO: ArrowPrimitiveType, + FROM::Native: NumCast, + TO::Native: NumCast, +{ + if cast_options.safe { + // If the value can't be casted to the `TO::Native`, return null + Ok(Arc::new(numeric_cast::( + from.as_primitive::(), + ))) + } else { + // If the value can't be casted to the `TO::Native`, return error + Ok(Arc::new(try_numeric_cast::( + from.as_primitive::(), + )?)) + } +} + +// Natural cast between numeric types +// If the value of T can't be casted to R, will throw error +fn try_numeric_cast(from: &PrimitiveArray) -> Result, ArrowError> +where + T: ArrowPrimitiveType, + R: ArrowPrimitiveType, + T::Native: NumCast, + R::Native: NumCast, +{ + from.try_unary(|value| { + num::cast::cast::(value).ok_or_else(|| { + ArrowError::CastError(format!( + "Can't cast value {:?} to type {}", + value, + R::DATA_TYPE + )) + }) + }) +} + +// Natural cast between numeric types +// If the value of T can't be casted to R, it will be converted to null +fn numeric_cast(from: &PrimitiveArray) -> PrimitiveArray +where + T: ArrowPrimitiveType, + R: ArrowPrimitiveType, + T::Native: NumCast, + R::Native: NumCast, +{ + from.unary_opt::<_, R>(num::cast::cast::) +} + +fn value_to_string( + array: &dyn Array, + options: &CastOptions, +) -> Result { + let mut builder = GenericStringBuilder::::new(); + let formatter = ArrayFormatter::try_new(array, &options.format_options)?; + let nulls = array.nulls(); + for i in 0..array.len() { + match nulls.map(|x| x.is_null(i)).unwrap_or_default() { + true => builder.append_null(), + false => { + formatter.value(i).write(&mut builder)?; + // tell the builder the row is finished + builder.append_value(""); + } + } + } + Ok(Arc::new(builder.finish())) +} + +fn cast_numeric_to_binary( + array: &dyn Array, +) -> Result { + let array = array.as_primitive::(); + let size = std::mem::size_of::(); + let offsets = OffsetBuffer::from_lengths(std::iter::repeat(size).take(array.len())); + Ok(Arc::new(GenericBinaryArray::::new( + offsets, + array.values().inner().clone(), + array.nulls().cloned(), + ))) +} + +/// Parse UTF-8 +fn parse_string( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + let string_array = array.as_string::(); + let array = if cast_options.safe { + let iter = string_array.iter().map(|x| x.and_then(P::parse)); + + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + unsafe { PrimitiveArray::

::from_trusted_len_iter(iter) } + } else { + let v = string_array + .iter() + .map(|x| match x { + Some(v) => P::parse(v).ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast string '{}' to value of {:?} type", + v, + P::DATA_TYPE + )) + }), + None => Ok(P::Native::default()), + }) + .collect::, ArrowError>>()?; + PrimitiveArray::new(v.into(), string_array.nulls().cloned()) + }; + + Ok(Arc::new(array) as ArrayRef) +} + +/// Casts generic string arrays to an ArrowTimestampType (TimeStampNanosecondArray, etc.) +fn cast_string_to_timestamp( + array: &dyn Array, + to_tz: &Option>, + cast_options: &CastOptions, +) -> Result { + let array = array.as_string::(); + let out: PrimitiveArray = match to_tz { + Some(tz) => { + let tz: Tz = tz.as_ref().parse()?; + cast_string_to_timestamp_impl(array, &tz, cast_options)? + } + None => cast_string_to_timestamp_impl(array, &Utc, cast_options)?, + }; + Ok(Arc::new(out.with_timezone_opt(to_tz.clone()))) +} + +fn cast_string_to_timestamp_impl( + array: &GenericStringArray, + tz: &Tz, + cast_options: &CastOptions, +) -> Result, ArrowError> { + if cast_options.safe { + let iter = array.iter().map(|v| { + v.and_then(|v| { + let naive = string_to_datetime(tz, v).ok()?.naive_utc(); + T::make_value(naive) + }) + }); + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + + Ok(unsafe { PrimitiveArray::from_trusted_len_iter(iter) }) + } else { + let vec = array + .iter() + .map(|v| { + v.map(|v| { + let naive = string_to_datetime(tz, v)?.naive_utc(); + T::make_value(naive).ok_or_else(|| { + ArrowError::CastError(format!( + "Overflow converting {naive} to {:?}", + T::UNIT + )) + }) + }) + .transpose() + }) + .collect::>, _>>()?; + + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + Ok(unsafe { PrimitiveArray::from_trusted_len_iter(vec.iter()) }) + } +} + +fn cast_string_to_interval( + array: &dyn Array, + cast_options: &CastOptions, + parse_function: F, +) -> Result +where + Offset: OffsetSizeTrait, + ArrowType: ArrowPrimitiveType, + F: Fn(&str) -> Result + Copy, +{ + let string_array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let interval_array = if cast_options.safe { + let iter = string_array + .iter() + .map(|v| v.and_then(|v| parse_function(v).ok())); + + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + unsafe { PrimitiveArray::::from_trusted_len_iter(iter) } + } else { + let vec = string_array + .iter() + .map(|v| v.map(parse_function).transpose()) + .collect::, ArrowError>>()?; + + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + unsafe { PrimitiveArray::::from_trusted_len_iter(vec) } + }; + Ok(Arc::new(interval_array) as ArrayRef) +} + +fn cast_string_to_year_month_interval( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + cast_string_to_interval::( + array, + cast_options, + parse_interval_year_month, + ) +} + +fn cast_string_to_day_time_interval( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + cast_string_to_interval::( + array, + cast_options, + parse_interval_day_time, + ) +} + +fn cast_string_to_month_day_nano_interval( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + cast_string_to_interval::( + array, + cast_options, + parse_interval_month_day_nano, + ) +} + +fn adjust_timestamp_to_timezone( + array: PrimitiveArray, + to_tz: &Tz, + cast_options: &CastOptions, +) -> Result, ArrowError> { + let adjust = |o| { + let local = as_datetime::(o)?; + let offset = to_tz.offset_from_local_datetime(&local).single()?; + T::make_value(local - offset.fix()) + }; + let adjusted = if cast_options.safe { + array.unary_opt::<_, Int64Type>(adjust) + } else { + array.try_unary::<_, Int64Type, _>(|o| { + adjust(o).ok_or_else(|| { + ArrowError::CastError("Cannot cast timezone to different timezone".to_string()) + }) + })? + }; + Ok(adjusted) +} + +/// Casts Utf8 to Boolean +fn cast_utf8_to_boolean( + from: &dyn Array, + cast_options: &CastOptions, +) -> Result +where + OffsetSize: OffsetSizeTrait, +{ + let array = from + .as_any() + .downcast_ref::>() + .unwrap(); + + let output_array = array + .iter() + .map(|value| match value { + Some(value) => match value.to_ascii_lowercase().trim() { + "t" | "tr" | "tru" | "true" | "y" | "ye" | "yes" | "on" | "1" => Ok(Some(true)), + "f" | "fa" | "fal" | "fals" | "false" | "n" | "no" | "of" | "off" | "0" => { + Ok(Some(false)) + } + invalid_value => match cast_options.safe { + true => Ok(None), + false => Err(ArrowError::CastError(format!( + "Cannot cast value '{invalid_value}' to value of Boolean type", + ))), + }, + }, + None => Ok(None), + }) + .collect::>()?; + + Ok(Arc::new(output_array)) +} + +/// Parses given string to specified decimal native (i128/i256) based on given +/// scale. Returns an `Err` if it cannot parse given string. +fn parse_string_to_decimal_native( + value_str: &str, + scale: usize, +) -> Result +where + T::Native: DecimalCast + ArrowNativeTypeOp, +{ + let value_str = value_str.trim(); + let parts: Vec<&str> = value_str.split('.').collect(); + if parts.len() > 2 { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid decimal format: {value_str:?}" + ))); + } + + let (negative, first_part) = if parts[0].is_empty() { + (false, parts[0]) + } else { + match parts[0].as_bytes()[0] { + b'-' => (true, &parts[0][1..]), + b'+' => (false, &parts[0][1..]), + _ => (false, parts[0]), + } + }; + + let integers = first_part.trim_start_matches('0'); + let decimals = if parts.len() == 2 { parts[1] } else { "" }; + + if !integers.is_empty() && !integers.as_bytes()[0].is_ascii_digit() { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid decimal format: {value_str:?}" + ))); + } + + if !decimals.is_empty() && !decimals.as_bytes()[0].is_ascii_digit() { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid decimal format: {value_str:?}" + ))); + } + + // Adjust decimal based on scale + let mut number_decimals = if decimals.len() > scale { + let decimal_number = i256::from_string(decimals).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!("Cannot parse decimal format: {value_str}")) + })?; + + let div = i256::from_i128(10_i128).pow_checked((decimals.len() - scale) as u32)?; + + let half = div.div_wrapping(i256::from_i128(2)); + let half_neg = half.neg_wrapping(); + + let d = decimal_number.div_wrapping(div); + let r = decimal_number.mod_wrapping(div); + + // Round result + let adjusted = match decimal_number >= i256::ZERO { + true if r >= half => d.add_wrapping(i256::ONE), + false if r <= half_neg => d.sub_wrapping(i256::ONE), + _ => d, + }; + + let integers = if !integers.is_empty() { + i256::from_string(integers) + .ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Cannot parse decimal format: {value_str}" + )) + }) + .map(|v| v.mul_wrapping(i256::from_i128(10_i128).pow_wrapping(scale as u32)))? + } else { + i256::ZERO + }; + + format!("{}", integers.add_wrapping(adjusted)) + } else { + let padding = if scale > decimals.len() { scale } else { 0 }; + + let decimals = format!("{decimals:0( + from: &GenericStringArray, + precision: u8, + scale: i8, + cast_options: &CastOptions, +) -> Result, ArrowError> +where + T: DecimalType, + T::Native: DecimalCast + ArrowNativeTypeOp, +{ + if cast_options.safe { + let iter = from.iter().map(|v| { + v.and_then(|v| parse_string_to_decimal_native::(v, scale as usize).ok()) + .and_then(|v| { + T::validate_decimal_precision(v, precision) + .is_ok() + .then_some(v) + }) + }); + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + Ok(unsafe { + PrimitiveArray::::from_trusted_len_iter(iter) + .with_precision_and_scale(precision, scale)? + }) + } else { + let vec = from + .iter() + .map(|v| { + v.map(|v| { + parse_string_to_decimal_native::(v, scale as usize) + .map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast string '{}' to value of {:?} type", + v, + T::DATA_TYPE, + )) + }) + .and_then(|v| T::validate_decimal_precision(v, precision).map(|_| v)) + }) + .transpose() + }) + .collect::, _>>()?; + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + Ok(unsafe { + PrimitiveArray::::from_trusted_len_iter(vec.iter()) + .with_precision_and_scale(precision, scale)? + }) + } +} + +/// Cast Utf8 to decimal +fn cast_string_to_decimal( + from: &dyn Array, + precision: u8, + scale: i8, + cast_options: &CastOptions, +) -> Result +where + T: DecimalType, + T::Native: DecimalCast + ArrowNativeTypeOp, +{ + if scale < 0 { + return Err(ArrowError::InvalidArgumentError(format!( + "Cannot cast string to decimal with negative scale {scale}" + ))); + } + + if scale > T::MAX_SCALE { + return Err(ArrowError::InvalidArgumentError(format!( + "Cannot cast string to decimal greater than maximum scale {}", + T::MAX_SCALE + ))); + } + + Ok(Arc::new(string_to_decimal_cast::( + from.as_any() + .downcast_ref::>() + .unwrap(), + precision, + scale, + cast_options, + )?)) +} + +/// Cast numeric types to Boolean +/// +/// Any zero value returns `false` while non-zero returns `true` +fn cast_numeric_to_bool(from: &dyn Array) -> Result +where + FROM: ArrowPrimitiveType, +{ + numeric_to_bool_cast::(from.as_primitive::()).map(|to| Arc::new(to) as ArrayRef) +} + +fn numeric_to_bool_cast(from: &PrimitiveArray) -> Result +where + T: ArrowPrimitiveType + ArrowPrimitiveType, +{ + let mut b = BooleanBuilder::with_capacity(from.len()); + + for i in 0..from.len() { + if from.is_null(i) { + b.append_null(); + } else if from.value(i) != T::default_value() { + b.append_value(true); + } else { + b.append_value(false); + } + } + + Ok(b.finish()) +} + +/// Cast Boolean types to numeric +/// +/// `false` returns 0 while `true` returns 1 +fn cast_bool_to_numeric( + from: &dyn Array, + cast_options: &CastOptions, +) -> Result +where + TO: ArrowPrimitiveType, + TO::Native: num::cast::NumCast, +{ + Ok(Arc::new(bool_to_numeric_cast::( + from.as_any().downcast_ref::().unwrap(), + cast_options, + ))) +} + +fn bool_to_numeric_cast(from: &BooleanArray, _cast_options: &CastOptions) -> PrimitiveArray +where + T: ArrowPrimitiveType, + T::Native: num::NumCast, +{ + let iter = (0..from.len()).map(|i| { + if from.is_null(i) { + None + } else if from.value(i) { + // a workaround to cast a primitive to T::Native, infallible + num::cast::cast(1) + } else { + Some(T::default_value()) + } + }); + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from a Range + unsafe { PrimitiveArray::::from_trusted_len_iter(iter) } +} + +/// Attempts to cast an `ArrayDictionary` with index type K into +/// `to_type` for supported types. +/// +/// K is the key type +fn dictionary_cast( + array: &dyn Array, + to_type: &DataType, + cast_options: &CastOptions, +) -> Result { + use DataType::*; + + match to_type { + Dictionary(to_index_type, to_value_type) => { + let dict_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::ComputeError( + "Internal Error: Cannot cast dictionary to DictionaryArray of expected type".to_string(), + ) + })?; + + let keys_array: ArrayRef = + Arc::new(PrimitiveArray::::from(dict_array.keys().to_data())); + let values_array = dict_array.values(); + let cast_keys = cast_with_options(&keys_array, to_index_type, cast_options)?; + let cast_values = cast_with_options(values_array, to_value_type, cast_options)?; + + // Failure to cast keys (because they don't fit in the + // target type) results in NULL values; + if cast_keys.null_count() > keys_array.null_count() { + return Err(ArrowError::ComputeError(format!( + "Could not convert {} dictionary indexes from {:?} to {:?}", + cast_keys.null_count() - keys_array.null_count(), + keys_array.data_type(), + to_index_type + ))); + } + + let data = cast_keys.into_data(); + let builder = data + .into_builder() + .data_type(to_type.clone()) + .child_data(vec![cast_values.into_data()]); + + // Safety + // Cast keys are still valid + let data = unsafe { builder.build_unchecked() }; + + // create the appropriate array type + let new_array: ArrayRef = match **to_index_type { + Int8 => Arc::new(DictionaryArray::::from(data)), + Int16 => Arc::new(DictionaryArray::::from(data)), + Int32 => Arc::new(DictionaryArray::::from(data)), + Int64 => Arc::new(DictionaryArray::::from(data)), + UInt8 => Arc::new(DictionaryArray::::from(data)), + UInt16 => Arc::new(DictionaryArray::::from(data)), + UInt32 => Arc::new(DictionaryArray::::from(data)), + UInt64 => Arc::new(DictionaryArray::::from(data)), + _ => { + return Err(ArrowError::CastError(format!( + "Unsupported type {to_index_type:?} for dictionary index" + ))); + } + }; + + Ok(new_array) + } + _ => unpack_dictionary::(array, to_type, cast_options), + } +} + +// Unpack a dictionary where the keys are of type into a flattened array of type to_type +fn unpack_dictionary( + array: &dyn Array, + to_type: &DataType, + cast_options: &CastOptions, +) -> Result +where + K: ArrowDictionaryKeyType, +{ + let dict_array = array.as_dictionary::(); + let cast_dict_values = cast_with_options(dict_array.values(), to_type, cast_options)?; + take(cast_dict_values.as_ref(), dict_array.keys(), None) +} + +/// Attempts to encode an array into an `ArrayDictionary` with index +/// type K and value (dictionary) type value_type +/// +/// K is the key type +fn cast_to_dictionary( + array: &dyn Array, + dict_value_type: &DataType, + cast_options: &CastOptions, +) -> Result { + use DataType::*; + + match *dict_value_type { + Int8 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + Int16 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + Int32 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + Int64 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + UInt8 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + UInt16 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + UInt32 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + UInt64 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + Decimal128(_, _) => { + pack_numeric_to_dictionary::(array, dict_value_type, cast_options) + } + Decimal256(_, _) => { + pack_numeric_to_dictionary::(array, dict_value_type, cast_options) + } + Utf8 => pack_byte_to_dictionary::>(array, cast_options), + LargeUtf8 => pack_byte_to_dictionary::>(array, cast_options), + Binary => pack_byte_to_dictionary::>(array, cast_options), + LargeBinary => pack_byte_to_dictionary::>(array, cast_options), + _ => Err(ArrowError::CastError(format!( + "Unsupported output type for dictionary packing: {dict_value_type:?}" + ))), + } +} + +// Packs the data from the primitive array of type to a +// DictionaryArray with keys of type K and values of value_type V +fn pack_numeric_to_dictionary( + array: &dyn Array, + dict_value_type: &DataType, + cast_options: &CastOptions, +) -> Result +where + K: ArrowDictionaryKeyType, + V: ArrowPrimitiveType, +{ + // attempt to cast the source array values to the target value type (the dictionary values type) + let cast_values = cast_with_options(array, dict_value_type, cast_options)?; + let values = cast_values.as_primitive::(); + + let mut b = PrimitiveDictionaryBuilder::::with_capacity(values.len(), values.len()); + + // copy each element one at a time + for i in 0..values.len() { + if values.is_null(i) { + b.append_null(); + } else { + b.append(values.value(i))?; + } + } + Ok(Arc::new(b.finish())) +} + +// Packs the data as a GenericByteDictionaryBuilder, if possible, with the +// key types of K +fn pack_byte_to_dictionary( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result +where + K: ArrowDictionaryKeyType, + T: ByteArrayType, +{ + let cast_values = cast_with_options(array, &T::DATA_TYPE, cast_options)?; + let values = cast_values + .as_any() + .downcast_ref::>() + .unwrap(); + let mut b = GenericByteDictionaryBuilder::::with_capacity(values.len(), 1024, 1024); + + // copy each element one at a time + for i in 0..values.len() { + if values.is_null(i) { + b.append_null(); + } else { + b.append(values.value(i))?; + } + } + Ok(Arc::new(b.finish())) +} + +/// Helper function that takes a primitive array and casts to a (generic) list array. +fn cast_values_to_list( + array: &dyn Array, + to: &FieldRef, + cast_options: &CastOptions, +) -> Result { + let values = cast_with_options(array, to.data_type(), cast_options)?; + let offsets = OffsetBuffer::from_lengths(std::iter::repeat(1).take(values.len())); + let list = GenericListArray::::new(to.clone(), offsets, values, None); + Ok(Arc::new(list)) +} + +/// A specified helper to cast from `GenericBinaryArray` to `GenericStringArray` when they have same +/// offset size so re-encoding offset is unnecessary. +fn cast_binary_to_string( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + let array = array + .as_any() + .downcast_ref::>>() + .unwrap(); + + match GenericStringArray::::try_from_binary(array.clone()) { + Ok(a) => Ok(Arc::new(a)), + Err(e) => match cast_options.safe { + true => { + // Fallback to slow method to convert invalid sequences to nulls + let mut builder = + GenericStringBuilder::::with_capacity(array.len(), array.value_data().len()); + + let iter = array + .iter() + .map(|v| v.and_then(|v| std::str::from_utf8(v).ok())); + + builder.extend(iter); + Ok(Arc::new(builder.finish())) + } + false => Err(e), + }, + } +} + +/// Helper function to cast from one `BinaryArray` or 'LargeBinaryArray' to 'FixedSizeBinaryArray'. +fn cast_binary_to_fixed_size_binary( + array: &dyn Array, + byte_width: i32, + cast_options: &CastOptions, +) -> Result { + let array = array.as_binary::(); + let mut builder = FixedSizeBinaryBuilder::with_capacity(array.len(), byte_width); + + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + match builder.append_value(array.value(i)) { + Ok(_) => {} + Err(e) => match cast_options.safe { + true => builder.append_null(), + false => return Err(e), + }, + } + } + } + + Ok(Arc::new(builder.finish())) +} + +/// Helper function to cast from 'FixedSizeBinaryArray' to one `BinaryArray` or 'LargeBinaryArray'. +/// If the target one is too large for the source array it will return an Error. +fn cast_fixed_size_binary_to_binary( + array: &dyn Array, + byte_width: i32, +) -> Result { + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + + let offsets: i128 = byte_width as i128 * array.len() as i128; + + let is_binary = matches!(GenericBinaryType::::DATA_TYPE, DataType::Binary); + if is_binary && offsets > i32::MAX as i128 { + return Err(ArrowError::ComputeError( + "FixedSizeBinary array too large to cast to Binary array".to_string(), + )); + } else if !is_binary && offsets > i64::MAX as i128 { + return Err(ArrowError::ComputeError( + "FixedSizeBinary array too large to cast to LargeBinary array".to_string(), + )); + } + + let mut builder = GenericBinaryBuilder::::with_capacity(array.len(), array.len()); + + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + builder.append_value(array.value(i)); + } + } + + Ok(Arc::new(builder.finish())) +} + +/// Helper function to cast from one `ByteArrayType` to another and vice versa. +/// If the target one (e.g., `LargeUtf8`) is too large for the source array it will return an Error. +fn cast_byte_container(array: &dyn Array) -> Result +where + FROM: ByteArrayType, + TO: ByteArrayType, + FROM::Offset: OffsetSizeTrait + ToPrimitive, + TO::Offset: OffsetSizeTrait + NumCast, +{ + let data = array.to_data(); + assert_eq!(data.data_type(), &FROM::DATA_TYPE); + let str_values_buf = data.buffers()[1].clone(); + let offsets = data.buffers()[0].typed_data::(); + + let mut offset_builder = BufferBuilder::::new(offsets.len()); + offsets + .iter() + .try_for_each::<_, Result<_, ArrowError>>(|offset| { + let offset = + <::Offset as NumCast>::from(*offset).ok_or_else(|| { + ArrowError::ComputeError(format!( + "{}{} array too large to cast to {}{} array", + FROM::Offset::PREFIX, + FROM::PREFIX, + TO::Offset::PREFIX, + TO::PREFIX + )) + })?; + offset_builder.append(offset); + Ok(()) + })?; + + let offset_buffer = offset_builder.finish(); + + let dtype = TO::DATA_TYPE; + + let builder = ArrayData::builder(dtype) + .offset(array.offset()) + .len(array.len()) + .add_buffer(offset_buffer) + .add_buffer(str_values_buf) + .nulls(data.nulls().cloned()); + + let array_data = unsafe { builder.build_unchecked() }; + + Ok(Arc::new(GenericByteArray::::from(array_data))) +} + +fn cast_fixed_size_list_to_list(array: &dyn Array) -> Result +where + OffsetSize: OffsetSizeTrait, +{ + let fixed_size_list: &FixedSizeListArray = array.as_fixed_size_list(); + let list: GenericListArray = fixed_size_list.clone().into(); + Ok(Arc::new(list)) +} + +fn cast_list_to_fixed_size_list( + array: &GenericListArray, + field: &FieldRef, + size: i32, + cast_options: &CastOptions, +) -> Result +where + OffsetSize: OffsetSizeTrait, +{ + let cap = array.len() * size as usize; + + let mut nulls = (cast_options.safe || array.null_count() != 0).then(|| { + let mut buffer = BooleanBufferBuilder::new(array.len()); + match array.nulls() { + Some(n) => buffer.append_buffer(n.inner()), + None => buffer.append_n(array.len(), true), + } + buffer + }); + + // Nulls in FixedSizeListArray take up space and so we must pad the values + let values = array.values().to_data(); + let mut mutable = MutableArrayData::new(vec![&values], cast_options.safe, cap); + // The end position in values of the last incorrectly-sized list slice + let mut last_pos = 0; + for (idx, w) in array.offsets().windows(2).enumerate() { + let start_pos = w[0].as_usize(); + let end_pos = w[1].as_usize(); + let len = end_pos - start_pos; + + if len != size as usize { + if cast_options.safe || array.is_null(idx) { + if last_pos != start_pos { + // Extend with valid slices + mutable.extend(0, last_pos, start_pos); + } + // Pad this slice with nulls + mutable.extend_nulls(size as _); + nulls.as_mut().unwrap().set_bit(idx, false); + // Set last_pos to the end of this slice's values + last_pos = end_pos + } else { + return Err(ArrowError::CastError(format!( + "Cannot cast to FixedSizeList({size}): value at index {idx} has length {len}", + ))); + } + } + } + + let values = match last_pos { + 0 => array.values().slice(0, cap), // All slices were the correct length + _ => { + if mutable.len() != cap { + // Remaining slices were all correct length + let remaining = cap - mutable.len(); + mutable.extend(0, last_pos, last_pos + remaining) + } + make_array(mutable.freeze()) + } + }; + + // Cast the inner values if necessary + let values = cast_with_options(values.as_ref(), field.data_type(), cast_options)?; + + // Construct the FixedSizeListArray + let nulls = nulls.map(|mut x| x.finish().into()); + let array = FixedSizeListArray::new(field.clone(), size, values, nulls); + Ok(Arc::new(array)) +} + +/// Helper function that takes an Generic list container and casts the inner datatype. +fn cast_list_values( + array: &dyn Array, + to: &FieldRef, + cast_options: &CastOptions, +) -> Result { + let list = array.as_list::(); + let values = cast_with_options(list.values(), to.data_type(), cast_options)?; + Ok(Arc::new(GenericListArray::::new( + to.clone(), + list.offsets().clone(), + values, + list.nulls().cloned(), + ))) +} + +/// Cast the container type of List/Largelist array along with the inner datatype +fn cast_list( + array: &dyn Array, + field: &FieldRef, + cast_options: &CastOptions, +) -> Result { + let list = array.as_list::(); + let values = list.values(); + let offsets = list.offsets(); + let nulls = list.nulls().cloned(); + + if !O::IS_LARGE && values.len() > i32::MAX as usize { + return Err(ArrowError::ComputeError( + "LargeList too large to cast to List".into(), + )); + } + + // Recursively cast values + let values = cast_with_options(values, field.data_type(), cast_options)?; + let offsets: Vec<_> = offsets.iter().map(|x| O::usize_as(x.as_usize())).collect(); + + // Safety: valid offsets and checked for overflow + let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) }; + + Ok(Arc::new(GenericListArray::::new( + field.clone(), + offsets, + values, + nulls, + ))) +} + +#[cfg(test)] +mod tests { + use arrow_buffer::{Buffer, NullBuffer}; + use half::f16; + + use super::*; + + macro_rules! generate_cast_test_case { + ($INPUT_ARRAY: expr, $OUTPUT_TYPE_ARRAY: ident, $OUTPUT_TYPE: expr, $OUTPUT_VALUES: expr) => { + let output = + $OUTPUT_TYPE_ARRAY::from($OUTPUT_VALUES).with_data_type($OUTPUT_TYPE.clone()); + + // assert cast type + let input_array_type = $INPUT_ARRAY.data_type(); + assert!(can_cast_types(input_array_type, $OUTPUT_TYPE)); + let result = cast($INPUT_ARRAY, $OUTPUT_TYPE).unwrap(); + assert_eq!($OUTPUT_TYPE, result.data_type()); + assert_eq!(result.as_ref(), &output); + + let cast_option = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let result = cast_with_options($INPUT_ARRAY, $OUTPUT_TYPE, &cast_option).unwrap(); + assert_eq!($OUTPUT_TYPE, result.data_type()); + assert_eq!(result.as_ref(), &output); + }; + } + + fn create_decimal_array( + array: Vec>, + precision: u8, + scale: i8, + ) -> Result { + array + .into_iter() + .collect::() + .with_precision_and_scale(precision, scale) + } + + fn create_decimal256_array( + array: Vec>, + precision: u8, + scale: i8, + ) -> Result { + array + .into_iter() + .collect::() + .with_precision_and_scale(precision, scale) + } + + #[test] + #[cfg(not(feature = "force_validate"))] + #[should_panic( + expected = "Cannot cast to Decimal128(20, 3). Overflowing on 57896044618658097711785492504343953926634992332820282019728792003956564819967" + )] + fn test_cast_decimal_to_decimal_round_with_error() { + // decimal256 to decimal128 overflow + let array = vec![ + Some(i256::from_i128(1123454)), + Some(i256::from_i128(2123456)), + Some(i256::from_i128(-3123453)), + Some(i256::from_i128(-3123456)), + None, + Some(i256::MAX), + Some(i256::MIN), + ]; + let input_decimal_array = create_decimal256_array(array, 76, 4).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + let input_type = DataType::Decimal256(76, 4); + let output_type = DataType::Decimal128(20, 3); + assert!(can_cast_types(&input_type, &output_type)); + generate_cast_test_case!( + &array, + Decimal128Array, + &output_type, + vec![ + Some(112345_i128), + Some(212346_i128), + Some(-312345_i128), + Some(-312346_i128), + None, + None, + None, + ] + ); + } + + #[test] + #[cfg(not(feature = "force_validate"))] + fn test_cast_decimal_to_decimal_round() { + let array = vec![ + Some(1123454), + Some(2123456), + Some(-3123453), + Some(-3123456), + None, + ]; + let array = create_decimal_array(array, 20, 4).unwrap(); + // decimal128 to decimal128 + let input_type = DataType::Decimal128(20, 4); + let output_type = DataType::Decimal128(20, 3); + assert!(can_cast_types(&input_type, &output_type)); + generate_cast_test_case!( + &array, + Decimal128Array, + &output_type, + vec![ + Some(112345_i128), + Some(212346_i128), + Some(-312345_i128), + Some(-312346_i128), + None + ] + ); + + // decimal128 to decimal256 + let input_type = DataType::Decimal128(20, 4); + let output_type = DataType::Decimal256(20, 3); + assert!(can_cast_types(&input_type, &output_type)); + generate_cast_test_case!( + &array, + Decimal256Array, + &output_type, + vec![ + Some(i256::from_i128(112345_i128)), + Some(i256::from_i128(212346_i128)), + Some(i256::from_i128(-312345_i128)), + Some(i256::from_i128(-312346_i128)), + None + ] + ); + + // decimal256 + let array = vec![ + Some(i256::from_i128(1123454)), + Some(i256::from_i128(2123456)), + Some(i256::from_i128(-3123453)), + Some(i256::from_i128(-3123456)), + None, + ]; + let array = create_decimal256_array(array, 20, 4).unwrap(); + + // decimal256 to decimal256 + let input_type = DataType::Decimal256(20, 4); + let output_type = DataType::Decimal256(20, 3); + assert!(can_cast_types(&input_type, &output_type)); + generate_cast_test_case!( + &array, + Decimal256Array, + &output_type, + vec![ + Some(i256::from_i128(112345_i128)), + Some(i256::from_i128(212346_i128)), + Some(i256::from_i128(-312345_i128)), + Some(i256::from_i128(-312346_i128)), + None + ] + ); + // decimal256 to decimal128 + let input_type = DataType::Decimal256(20, 4); + let output_type = DataType::Decimal128(20, 3); + assert!(can_cast_types(&input_type, &output_type)); + generate_cast_test_case!( + &array, + Decimal128Array, + &output_type, + vec![ + Some(112345_i128), + Some(212346_i128), + Some(-312345_i128), + Some(-312346_i128), + None + ] + ); + } + + #[test] + fn test_cast_decimal128_to_decimal128() { + let input_type = DataType::Decimal128(20, 3); + let output_type = DataType::Decimal128(20, 4); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = create_decimal_array(array, 20, 3).unwrap(); + generate_cast_test_case!( + &array, + Decimal128Array, + &output_type, + vec![ + Some(11234560_i128), + Some(21234560_i128), + Some(31234560_i128), + None + ] + ); + // negative test + let array = vec![Some(123456), None]; + let array = create_decimal_array(array, 10, 0).unwrap(); + let result = cast(&array, &DataType::Decimal128(2, 2)); + assert!(result.is_ok()); + let array = result.unwrap(); + let array: &Decimal128Array = array.as_primitive(); + let err = array.validate_decimal_precision(2); + assert_eq!("Invalid argument error: 12345600 is too large to store in a Decimal128 of precision 2. Max is 99", + err.unwrap_err().to_string()); + } + + #[test] + fn test_cast_decimal128_to_decimal128_overflow() { + let input_type = DataType::Decimal128(38, 3); + let output_type = DataType::Decimal128(38, 38); + assert!(can_cast_types(&input_type, &output_type)); + + let array = vec![Some(i128::MAX)]; + let array = create_decimal_array(array, 38, 3).unwrap(); + let result = cast_with_options( + &array, + &output_type, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!("Cast error: Cannot cast to Decimal128(38, 38). Overflowing on 170141183460469231731687303715884105727", + result.unwrap_err().to_string()); + } + + #[test] + fn test_cast_decimal128_to_decimal256_overflow() { + let input_type = DataType::Decimal128(38, 3); + let output_type = DataType::Decimal256(76, 76); + assert!(can_cast_types(&input_type, &output_type)); + + let array = vec![Some(i128::MAX)]; + let array = create_decimal_array(array, 38, 3).unwrap(); + let result = cast_with_options( + &array, + &output_type, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!("Cast error: Cannot cast to Decimal256(76, 76). Overflowing on 170141183460469231731687303715884105727", + result.unwrap_err().to_string()); + } + + #[test] + fn test_cast_decimal128_to_decimal256() { + let input_type = DataType::Decimal128(20, 3); + let output_type = DataType::Decimal256(20, 4); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = create_decimal_array(array, 20, 3).unwrap(); + generate_cast_test_case!( + &array, + Decimal256Array, + &output_type, + vec![ + Some(i256::from_i128(11234560_i128)), + Some(i256::from_i128(21234560_i128)), + Some(i256::from_i128(31234560_i128)), + None + ] + ); + } + + #[test] + fn test_cast_decimal256_to_decimal128_overflow() { + let input_type = DataType::Decimal256(76, 5); + let output_type = DataType::Decimal128(38, 7); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(i256::from_i128(i128::MAX))]; + let array = create_decimal256_array(array, 76, 5).unwrap(); + let result = cast_with_options( + &array, + &output_type, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!("Cast error: Cannot cast to Decimal128(38, 7). Overflowing on 170141183460469231731687303715884105727", + result.unwrap_err().to_string()); + } + + #[test] + fn test_cast_decimal256_to_decimal256_overflow() { + let input_type = DataType::Decimal256(76, 5); + let output_type = DataType::Decimal256(76, 55); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(i256::from_i128(i128::MAX))]; + let array = create_decimal256_array(array, 76, 5).unwrap(); + let result = cast_with_options( + &array, + &output_type, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!("Cast error: Cannot cast to Decimal256(76, 55). Overflowing on 170141183460469231731687303715884105727", + result.unwrap_err().to_string()); + } + + #[test] + fn test_cast_decimal256_to_decimal128() { + let input_type = DataType::Decimal256(20, 3); + let output_type = DataType::Decimal128(20, 4); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![ + Some(i256::from_i128(1123456)), + Some(i256::from_i128(2123456)), + Some(i256::from_i128(3123456)), + None, + ]; + let array = create_decimal256_array(array, 20, 3).unwrap(); + generate_cast_test_case!( + &array, + Decimal128Array, + &output_type, + vec![ + Some(11234560_i128), + Some(21234560_i128), + Some(31234560_i128), + None + ] + ); + } + + #[test] + fn test_cast_decimal256_to_decimal256() { + let input_type = DataType::Decimal256(20, 3); + let output_type = DataType::Decimal256(20, 4); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![ + Some(i256::from_i128(1123456)), + Some(i256::from_i128(2123456)), + Some(i256::from_i128(3123456)), + None, + ]; + let array = create_decimal256_array(array, 20, 3).unwrap(); + generate_cast_test_case!( + &array, + Decimal256Array, + &output_type, + vec![ + Some(i256::from_i128(11234560_i128)), + Some(i256::from_i128(21234560_i128)), + Some(i256::from_i128(31234560_i128)), + None + ] + ); + } + + #[test] + fn test_cast_decimal_to_numeric() { + let value_array: Vec> = vec![Some(125), Some(225), Some(325), None, Some(525)]; + let array = create_decimal_array(value_array, 38, 2).unwrap(); + // u8 + generate_cast_test_case!( + &array, + UInt8Array, + &DataType::UInt8, + vec![Some(1_u8), Some(2_u8), Some(3_u8), None, Some(5_u8)] + ); + // u16 + generate_cast_test_case!( + &array, + UInt16Array, + &DataType::UInt16, + vec![Some(1_u16), Some(2_u16), Some(3_u16), None, Some(5_u16)] + ); + // u32 + generate_cast_test_case!( + &array, + UInt32Array, + &DataType::UInt32, + vec![Some(1_u32), Some(2_u32), Some(3_u32), None, Some(5_u32)] + ); + // u64 + generate_cast_test_case!( + &array, + UInt64Array, + &DataType::UInt64, + vec![Some(1_u64), Some(2_u64), Some(3_u64), None, Some(5_u64)] + ); + // i8 + generate_cast_test_case!( + &array, + Int8Array, + &DataType::Int8, + vec![Some(1_i8), Some(2_i8), Some(3_i8), None, Some(5_i8)] + ); + // i16 + generate_cast_test_case!( + &array, + Int16Array, + &DataType::Int16, + vec![Some(1_i16), Some(2_i16), Some(3_i16), None, Some(5_i16)] + ); + // i32 + generate_cast_test_case!( + &array, + Int32Array, + &DataType::Int32, + vec![Some(1_i32), Some(2_i32), Some(3_i32), None, Some(5_i32)] + ); + // i64 + generate_cast_test_case!( + &array, + Int64Array, + &DataType::Int64, + vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)] + ); + // f32 + generate_cast_test_case!( + &array, + Float32Array, + &DataType::Float32, + vec![ + Some(1.25_f32), + Some(2.25_f32), + Some(3.25_f32), + None, + Some(5.25_f32) + ] + ); + // f64 + generate_cast_test_case!( + &array, + Float64Array, + &DataType::Float64, + vec![ + Some(1.25_f64), + Some(2.25_f64), + Some(3.25_f64), + None, + Some(5.25_f64) + ] + ); + + // overflow test: out of range of max u8 + let value_array: Vec> = vec![Some(51300)]; + let array = create_decimal_array(value_array, 38, 2).unwrap(); + let casted_array = cast_with_options( + &array, + &DataType::UInt8, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!( + "Cast error: value of 513 is out of range UInt8".to_string(), + casted_array.unwrap_err().to_string() + ); + + let casted_array = cast_with_options( + &array, + &DataType::UInt8, + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + // overflow test: out of range of max i8 + let value_array: Vec> = vec![Some(24400)]; + let array = create_decimal_array(value_array, 38, 2).unwrap(); + let casted_array = cast_with_options( + &array, + &DataType::Int8, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!( + "Cast error: value of 244 is out of range Int8".to_string(), + casted_array.unwrap_err().to_string() + ); + + let casted_array = cast_with_options( + &array, + &DataType::Int8, + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + // loss the precision: convert decimal to f32、f64 + // f32 + // 112345678_f32 and 112345679_f32 are same, so the 112345679_f32 will lose precision. + let value_array: Vec> = vec![ + Some(125), + Some(225), + Some(325), + None, + Some(525), + Some(112345678), + Some(112345679), + ]; + let array = create_decimal_array(value_array, 38, 2).unwrap(); + generate_cast_test_case!( + &array, + Float32Array, + &DataType::Float32, + vec![ + Some(1.25_f32), + Some(2.25_f32), + Some(3.25_f32), + None, + Some(5.25_f32), + Some(1_123_456.7_f32), + Some(1_123_456.7_f32) + ] + ); + + // f64 + // 112345678901234568_f64 and 112345678901234560_f64 are same, so the 112345678901234568_f64 will lose precision. + let value_array: Vec> = vec![ + Some(125), + Some(225), + Some(325), + None, + Some(525), + Some(112345678901234568), + Some(112345678901234560), + ]; + let array = create_decimal_array(value_array, 38, 2).unwrap(); + generate_cast_test_case!( + &array, + Float64Array, + &DataType::Float64, + vec![ + Some(1.25_f64), + Some(2.25_f64), + Some(3.25_f64), + None, + Some(5.25_f64), + Some(1_123_456_789_012_345.6_f64), + Some(1_123_456_789_012_345.6_f64), + ] + ); + } + + #[test] + fn test_cast_decimal256_to_numeric() { + let value_array: Vec> = vec![ + Some(i256::from_i128(125)), + Some(i256::from_i128(225)), + Some(i256::from_i128(325)), + None, + Some(i256::from_i128(525)), + ]; + let array = create_decimal256_array(value_array, 38, 2).unwrap(); + // u8 + generate_cast_test_case!( + &array, + UInt8Array, + &DataType::UInt8, + vec![Some(1_u8), Some(2_u8), Some(3_u8), None, Some(5_u8)] + ); + // u16 + generate_cast_test_case!( + &array, + UInt16Array, + &DataType::UInt16, + vec![Some(1_u16), Some(2_u16), Some(3_u16), None, Some(5_u16)] + ); + // u32 + generate_cast_test_case!( + &array, + UInt32Array, + &DataType::UInt32, + vec![Some(1_u32), Some(2_u32), Some(3_u32), None, Some(5_u32)] + ); + // u64 + generate_cast_test_case!( + &array, + UInt64Array, + &DataType::UInt64, + vec![Some(1_u64), Some(2_u64), Some(3_u64), None, Some(5_u64)] + ); + // i8 + generate_cast_test_case!( + &array, + Int8Array, + &DataType::Int8, + vec![Some(1_i8), Some(2_i8), Some(3_i8), None, Some(5_i8)] + ); + // i16 + generate_cast_test_case!( + &array, + Int16Array, + &DataType::Int16, + vec![Some(1_i16), Some(2_i16), Some(3_i16), None, Some(5_i16)] + ); + // i32 + generate_cast_test_case!( + &array, + Int32Array, + &DataType::Int32, + vec![Some(1_i32), Some(2_i32), Some(3_i32), None, Some(5_i32)] + ); + // i64 + generate_cast_test_case!( + &array, + Int64Array, + &DataType::Int64, + vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)] + ); + // f32 + generate_cast_test_case!( + &array, + Float32Array, + &DataType::Float32, + vec![ + Some(1.25_f32), + Some(2.25_f32), + Some(3.25_f32), + None, + Some(5.25_f32) + ] + ); + // f64 + generate_cast_test_case!( + &array, + Float64Array, + &DataType::Float64, + vec![ + Some(1.25_f64), + Some(2.25_f64), + Some(3.25_f64), + None, + Some(5.25_f64) + ] + ); + + // overflow test: out of range of max i8 + let value_array: Vec> = vec![Some(i256::from_i128(24400))]; + let array = create_decimal256_array(value_array, 38, 2).unwrap(); + let casted_array = cast_with_options( + &array, + &DataType::Int8, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!( + "Cast error: value of 244 is out of range Int8".to_string(), + casted_array.unwrap_err().to_string() + ); + + let casted_array = cast_with_options( + &array, + &DataType::Int8, + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + // loss the precision: convert decimal to f32、f64 + // f32 + // 112345678_f32 and 112345679_f32 are same, so the 112345679_f32 will lose precision. + let value_array: Vec> = vec![ + Some(i256::from_i128(125)), + Some(i256::from_i128(225)), + Some(i256::from_i128(325)), + None, + Some(i256::from_i128(525)), + Some(i256::from_i128(112345678)), + Some(i256::from_i128(112345679)), + ]; + let array = create_decimal256_array(value_array, 76, 2).unwrap(); + generate_cast_test_case!( + &array, + Float32Array, + &DataType::Float32, + vec![ + Some(1.25_f32), + Some(2.25_f32), + Some(3.25_f32), + None, + Some(5.25_f32), + Some(1_123_456.7_f32), + Some(1_123_456.7_f32) + ] + ); + + // f64 + // 112345678901234568_f64 and 112345678901234560_f64 are same, so the 112345678901234568_f64 will lose precision. + let value_array: Vec> = vec![ + Some(i256::from_i128(125)), + Some(i256::from_i128(225)), + Some(i256::from_i128(325)), + None, + Some(i256::from_i128(525)), + Some(i256::from_i128(112345678901234568)), + Some(i256::from_i128(112345678901234560)), + ]; + let array = create_decimal256_array(value_array, 76, 2).unwrap(); + generate_cast_test_case!( + &array, + Float64Array, + &DataType::Float64, + vec![ + Some(1.25_f64), + Some(2.25_f64), + Some(3.25_f64), + None, + Some(5.25_f64), + Some(1_123_456_789_012_345.6_f64), + Some(1_123_456_789_012_345.6_f64), + ] + ); + } + + #[test] + fn test_cast_numeric_to_decimal128() { + let decimal_type = DataType::Decimal128(38, 6); + // u8, u16, u32, u64 + let input_datas = vec![ + Arc::new(UInt8Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // u8 + Arc::new(UInt16Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // u16 + Arc::new(UInt32Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // u32 + Arc::new(UInt64Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // u64 + ]; + + for array in input_datas { + generate_cast_test_case!( + &array, + Decimal128Array, + &decimal_type, + vec![ + Some(1000000_i128), + Some(2000000_i128), + Some(3000000_i128), + None, + Some(5000000_i128) + ] + ); + } + + // i8, i16, i32, i64 + let input_datas = vec![ + Arc::new(Int8Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i8 + Arc::new(Int16Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i16 + Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i32 + Arc::new(Int64Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i64 + ]; + for array in input_datas { + generate_cast_test_case!( + &array, + Decimal128Array, + &decimal_type, + vec![ + Some(1000000_i128), + Some(2000000_i128), + Some(3000000_i128), + None, + Some(5000000_i128) + ] + ); + } + + // test u8 to decimal type with overflow the result type + // the 100 will be converted to 1000_i128, but it is out of range for max value in the precision 3. + let array = UInt8Array::from(vec![1, 2, 3, 4, 100]); + let casted_array = cast(&array, &DataType::Decimal128(3, 1)); + assert!(casted_array.is_ok()); + let array = casted_array.unwrap(); + let array: &Decimal128Array = array.as_primitive(); + assert!(array.is_null(4)); + + // test i8 to decimal type with overflow the result type + // the 100 will be converted to 1000_i128, but it is out of range for max value in the precision 3. + let array = Int8Array::from(vec![1, 2, 3, 4, 100]); + let casted_array = cast(&array, &DataType::Decimal128(3, 1)); + assert!(casted_array.is_ok()); + let array = casted_array.unwrap(); + let array: &Decimal128Array = array.as_primitive(); + assert!(array.is_null(4)); + + // test f32 to decimal type + let array = Float32Array::from(vec![ + Some(1.1), + Some(2.2), + Some(4.4), + None, + Some(1.123_456_4), // round down + Some(1.123_456_7), // round up + ]); + let array = Arc::new(array) as ArrayRef; + generate_cast_test_case!( + &array, + Decimal128Array, + &decimal_type, + vec![ + Some(1100000_i128), + Some(2200000_i128), + Some(4400000_i128), + None, + Some(1123456_i128), // round down + Some(1123457_i128), // round up + ] + ); + + // test f64 to decimal type + let array = Float64Array::from(vec![ + Some(1.1), + Some(2.2), + Some(4.4), + None, + Some(1.123_456_489_123_4), // round up + Some(1.123_456_789_123_4), // round up + Some(1.123_456_489_012_345_6), // round down + Some(1.123_456_789_012_345_6), // round up + ]); + generate_cast_test_case!( + &array, + Decimal128Array, + &decimal_type, + vec![ + Some(1100000_i128), + Some(2200000_i128), + Some(4400000_i128), + None, + Some(1123456_i128), // round down + Some(1123457_i128), // round up + Some(1123456_i128), // round down + Some(1123457_i128), // round up + ] + ); + } + + #[test] + fn test_cast_numeric_to_decimal256() { + let decimal_type = DataType::Decimal256(76, 6); + // u8, u16, u32, u64 + let input_datas = vec![ + Arc::new(UInt8Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // u8 + Arc::new(UInt16Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // u16 + Arc::new(UInt32Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // u32 + Arc::new(UInt64Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // u64 + ]; + + for array in input_datas { + generate_cast_test_case!( + &array, + Decimal256Array, + &decimal_type, + vec![ + Some(i256::from_i128(1000000_i128)), + Some(i256::from_i128(2000000_i128)), + Some(i256::from_i128(3000000_i128)), + None, + Some(i256::from_i128(5000000_i128)) + ] + ); + } + + // i8, i16, i32, i64 + let input_datas = vec![ + Arc::new(Int8Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i8 + Arc::new(Int16Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i16 + Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i32 + Arc::new(Int64Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i64 + ]; + for array in input_datas { + generate_cast_test_case!( + &array, + Decimal256Array, + &decimal_type, + vec![ + Some(i256::from_i128(1000000_i128)), + Some(i256::from_i128(2000000_i128)), + Some(i256::from_i128(3000000_i128)), + None, + Some(i256::from_i128(5000000_i128)) + ] + ); + } + + // test i8 to decimal type with overflow the result type + // the 100 will be converted to 1000_i128, but it is out of range for max value in the precision 3. + let array = Int8Array::from(vec![1, 2, 3, 4, 100]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast(&array, &DataType::Decimal256(3, 1)); + assert!(casted_array.is_ok()); + let array = casted_array.unwrap(); + let array: &Decimal256Array = array.as_primitive(); + assert!(array.is_null(4)); + + // test f32 to decimal type + let array = Float32Array::from(vec![ + Some(1.1), + Some(2.2), + Some(4.4), + None, + Some(1.123_456_4), // round down + Some(1.123_456_7), // round up + ]); + generate_cast_test_case!( + &array, + Decimal256Array, + &decimal_type, + vec![ + Some(i256::from_i128(1100000_i128)), + Some(i256::from_i128(2200000_i128)), + Some(i256::from_i128(4400000_i128)), + None, + Some(i256::from_i128(1123456_i128)), // round down + Some(i256::from_i128(1123457_i128)), // round up + ] + ); + + // test f64 to decimal type + let array = Float64Array::from(vec![ + Some(1.1), + Some(2.2), + Some(4.4), + None, + Some(1.123_456_489_123_4), // round down + Some(1.123_456_789_123_4), // round up + Some(1.123_456_489_012_345_6), // round down + Some(1.123_456_789_012_345_6), // round up + ]); + generate_cast_test_case!( + &array, + Decimal256Array, + &decimal_type, + vec![ + Some(i256::from_i128(1100000_i128)), + Some(i256::from_i128(2200000_i128)), + Some(i256::from_i128(4400000_i128)), + None, + Some(i256::from_i128(1123456_i128)), // round down + Some(i256::from_i128(1123457_i128)), // round up + Some(i256::from_i128(1123456_i128)), // round down + Some(i256::from_i128(1123457_i128)), // round up + ] + ); + } + + #[test] + fn test_cast_i32_to_f64() { + let array = Int32Array::from(vec![5, 6, 7, 8, 9]); + let b = cast(&array, &DataType::Float64).unwrap(); + let c = b.as_primitive::(); + assert_eq!(5.0, c.value(0)); + assert_eq!(6.0, c.value(1)); + assert_eq!(7.0, c.value(2)); + assert_eq!(8.0, c.value(3)); + assert_eq!(9.0, c.value(4)); + } + + #[test] + fn test_cast_i32_to_u8() { + let array = Int32Array::from(vec![-5, 6, -7, 8, 100000000]); + let b = cast(&array, &DataType::UInt8).unwrap(); + let c = b.as_primitive::(); + assert!(!c.is_valid(0)); + assert_eq!(6, c.value(1)); + assert!(!c.is_valid(2)); + assert_eq!(8, c.value(3)); + // overflows return None + assert!(!c.is_valid(4)); + } + + #[test] + #[should_panic(expected = "Can't cast value -5 to type UInt8")] + fn test_cast_int32_to_u8_with_error() { + let array = Int32Array::from(vec![-5, 6, -7, 8, 100000000]); + // overflow with the error + let cast_option = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let result = cast_with_options(&array, &DataType::UInt8, &cast_option); + assert!(result.is_err()); + result.unwrap(); + } + + #[test] + fn test_cast_i32_to_u8_sliced() { + let array = Int32Array::from(vec![-5, 6, -7, 8, 100000000]); + assert_eq!(0, array.offset()); + let array = array.slice(2, 3); + let b = cast(&array, &DataType::UInt8).unwrap(); + assert_eq!(3, b.len()); + let c = b.as_primitive::(); + assert!(!c.is_valid(0)); + assert_eq!(8, c.value(1)); + // overflows return None + assert!(!c.is_valid(2)); + } + + #[test] + fn test_cast_i32_to_i32() { + let array = Int32Array::from(vec![5, 6, 7, 8, 9]); + let b = cast(&array, &DataType::Int32).unwrap(); + let c = b.as_primitive::(); + assert_eq!(5, c.value(0)); + assert_eq!(6, c.value(1)); + assert_eq!(7, c.value(2)); + assert_eq!(8, c.value(3)); + assert_eq!(9, c.value(4)); + } + + #[test] + fn test_cast_i32_to_list_i32() { + let array = Int32Array::from(vec![5, 6, 7, 8, 9]); + let b = cast( + &array, + &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + ) + .unwrap(); + assert_eq!(5, b.len()); + let arr = b.as_list::(); + assert_eq!(&[0, 1, 2, 3, 4, 5], arr.value_offsets()); + assert_eq!(1, arr.value_length(0)); + assert_eq!(1, arr.value_length(1)); + assert_eq!(1, arr.value_length(2)); + assert_eq!(1, arr.value_length(3)); + assert_eq!(1, arr.value_length(4)); + let c = arr.values().as_primitive::(); + assert_eq!(5, c.value(0)); + assert_eq!(6, c.value(1)); + assert_eq!(7, c.value(2)); + assert_eq!(8, c.value(3)); + assert_eq!(9, c.value(4)); + } + + #[test] + fn test_cast_i32_to_list_i32_nullable() { + let array = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]); + let b = cast( + &array, + &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + ) + .unwrap(); + assert_eq!(5, b.len()); + assert_eq!(0, b.null_count()); + let arr = b.as_list::(); + assert_eq!(&[0, 1, 2, 3, 4, 5], arr.value_offsets()); + assert_eq!(1, arr.value_length(0)); + assert_eq!(1, arr.value_length(1)); + assert_eq!(1, arr.value_length(2)); + assert_eq!(1, arr.value_length(3)); + assert_eq!(1, arr.value_length(4)); + + let c = arr.values().as_primitive::(); + assert_eq!(1, c.null_count()); + assert_eq!(5, c.value(0)); + assert!(!c.is_valid(1)); + assert_eq!(7, c.value(2)); + assert_eq!(8, c.value(3)); + assert_eq!(9, c.value(4)); + } + + #[test] + fn test_cast_i32_to_list_f64_nullable_sliced() { + let array = Int32Array::from(vec![Some(5), None, Some(7), Some(8), None, Some(10)]); + let array = array.slice(2, 4); + let b = cast( + &array, + &DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), + ) + .unwrap(); + assert_eq!(4, b.len()); + assert_eq!(0, b.null_count()); + let arr = b.as_list::(); + assert_eq!(&[0, 1, 2, 3, 4], arr.value_offsets()); + assert_eq!(1, arr.value_length(0)); + assert_eq!(1, arr.value_length(1)); + assert_eq!(1, arr.value_length(2)); + assert_eq!(1, arr.value_length(3)); + let c = arr.values().as_primitive::(); + assert_eq!(1, c.null_count()); + assert_eq!(7.0, c.value(0)); + assert_eq!(8.0, c.value(1)); + assert!(!c.is_valid(2)); + assert_eq!(10.0, c.value(3)); + } + + #[test] + fn test_cast_utf8_to_i32() { + let array = StringArray::from(vec!["5", "6", "seven", "8", "9.1"]); + let b = cast(&array, &DataType::Int32).unwrap(); + let c = b.as_primitive::(); + assert_eq!(5, c.value(0)); + assert_eq!(6, c.value(1)); + assert!(!c.is_valid(2)); + assert_eq!(8, c.value(3)); + assert!(!c.is_valid(4)); + } + + #[test] + fn test_cast_with_options_utf8_to_i32() { + let array = StringArray::from(vec!["5", "6", "seven", "8", "9.1"]); + let result = cast_with_options( + &array, + &DataType::Int32, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + match result { + Ok(_) => panic!("expected error"), + Err(e) => { + assert!( + e.to_string() + .contains("Cast error: Cannot cast string 'seven' to value of Int32 type",), + "Error: {e}" + ) + } + } + } + + #[test] + fn test_cast_utf8_to_bool() { + let strings = StringArray::from(vec!["true", "false", "invalid", " Y ", ""]); + let casted = cast(&strings, &DataType::Boolean).unwrap(); + let expected = BooleanArray::from(vec![Some(true), Some(false), None, Some(true), None]); + assert_eq!(*as_boolean_array(&casted), expected); + } + + #[test] + fn test_cast_with_options_utf8_to_bool() { + let strings = StringArray::from(vec!["true", "false", "invalid", " Y ", ""]); + let casted = cast_with_options( + &strings, + &DataType::Boolean, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + match casted { + Ok(_) => panic!("expected error"), + Err(e) => { + assert!(e + .to_string() + .contains("Cast error: Cannot cast value 'invalid' to value of Boolean type")) + } + } + } + + #[test] + fn test_cast_bool_to_i32() { + let array = BooleanArray::from(vec![Some(true), Some(false), None]); + let b = cast(&array, &DataType::Int32).unwrap(); + let c = b.as_primitive::(); + assert_eq!(1, c.value(0)); + assert_eq!(0, c.value(1)); + assert!(!c.is_valid(2)); + } + + #[test] + fn test_cast_bool_to_utf8() { + let array = BooleanArray::from(vec![Some(true), Some(false), None]); + let b = cast(&array, &DataType::Utf8).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!("true", c.value(0)); + assert_eq!("false", c.value(1)); + assert!(!c.is_valid(2)); + } + + #[test] + fn test_cast_bool_to_large_utf8() { + let array = BooleanArray::from(vec![Some(true), Some(false), None]); + let b = cast(&array, &DataType::LargeUtf8).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!("true", c.value(0)); + assert_eq!("false", c.value(1)); + assert!(!c.is_valid(2)); + } + + #[test] + fn test_cast_bool_to_f64() { + let array = BooleanArray::from(vec![Some(true), Some(false), None]); + let b = cast(&array, &DataType::Float64).unwrap(); + let c = b.as_primitive::(); + assert_eq!(1.0, c.value(0)); + assert_eq!(0.0, c.value(1)); + assert!(!c.is_valid(2)); + } + + #[test] + fn test_cast_integer_to_timestamp() { + let array = Int64Array::from(vec![Some(2), Some(10), None]); + let expected = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + let array = Int8Array::from(vec![Some(2), Some(10), None]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + + let array = Int16Array::from(vec![Some(2), Some(10), None]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + + let array = Int32Array::from(vec![Some(2), Some(10), None]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + + let array = UInt8Array::from(vec![Some(2), Some(10), None]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + + let array = UInt16Array::from(vec![Some(2), Some(10), None]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + + let array = UInt32Array::from(vec![Some(2), Some(10), None]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + + let array = UInt64Array::from(vec![Some(2), Some(10), None]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + } + + #[test] + fn test_cast_timestamp_to_integer() { + let array = TimestampMillisecondArray::from(vec![Some(5), Some(1), None]) + .with_timezone("UTC".to_string()); + let expected = cast(&array, &DataType::Int64).unwrap(); + + let actual = cast(&cast(&array, &DataType::Int8).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + + let actual = cast(&cast(&array, &DataType::Int16).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + + let actual = cast(&cast(&array, &DataType::Int32).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + + let actual = cast(&cast(&array, &DataType::UInt8).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + + let actual = cast(&cast(&array, &DataType::UInt16).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + + let actual = cast(&cast(&array, &DataType::UInt32).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + + let actual = cast(&cast(&array, &DataType::UInt64).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + } + + #[test] + fn test_cast_floating_to_timestamp() { + let array = Int64Array::from(vec![Some(2), Some(10), None]); + let expected = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + let array = Float16Array::from(vec![ + Some(f16::from_f32(2.0)), + Some(f16::from_f32(10.6)), + None, + ]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + + let array = Float32Array::from(vec![Some(2.0), Some(10.6), None]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + + let array = Float64Array::from(vec![Some(2.1), Some(10.2), None]); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + } + + #[test] + fn test_cast_timestamp_to_floating() { + let array = TimestampMillisecondArray::from(vec![Some(5), Some(1), None]) + .with_timezone("UTC".to_string()); + let expected = cast(&array, &DataType::Int64).unwrap(); + + let actual = cast(&cast(&array, &DataType::Float16).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + + let actual = cast(&cast(&array, &DataType::Float32).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + + let actual = cast(&cast(&array, &DataType::Float64).unwrap(), &DataType::Int64).unwrap(); + assert_eq!(&actual, &expected); + } + + #[test] + fn test_cast_decimal_to_timestamp() { + let array = Int64Array::from(vec![Some(2), Some(10), None]); + let expected = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + let array = Decimal128Array::from(vec![Some(200), Some(1000), None]) + .with_precision_and_scale(4, 2) + .unwrap(); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + + let array = Decimal256Array::from(vec![ + Some(i256::from_i128(2000)), + Some(i256::from_i128(10000)), + None, + ]) + .with_precision_and_scale(5, 3) + .unwrap(); + let actual = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + + assert_eq!(&actual, &expected); + } + + #[test] + fn test_cast_timestamp_to_decimal() { + let array = TimestampMillisecondArray::from(vec![Some(5), Some(1), None]) + .with_timezone("UTC".to_string()); + let expected = cast(&array, &DataType::Int64).unwrap(); + + let actual = cast( + &cast(&array, &DataType::Decimal128(5, 2)).unwrap(), + &DataType::Int64, + ) + .unwrap(); + assert_eq!(&actual, &expected); + + let actual = cast( + &cast(&array, &DataType::Decimal256(10, 5)).unwrap(), + &DataType::Int64, + ) + .unwrap(); + assert_eq!(&actual, &expected); + } + + #[test] + fn test_cast_list_i32_to_list_u16() { + let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 100000000]).into_data(); + + let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); + + // Construct a list array from the above two + // [[0,0,0], [-1, -2, -1], [2, 100000000]] + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build() + .unwrap(); + let list_array = ListArray::from(list_data); + + let cast_array = cast( + &list_array, + &DataType::List(Arc::new(Field::new("item", DataType::UInt16, true))), + ) + .unwrap(); + + // For the ListArray itself, there are no null values (as there were no nulls when they went in) + // + // 3 negative values should get lost when casting to unsigned, + // 1 value should overflow + assert_eq!(0, cast_array.null_count()); + + // offsets should be the same + let array = cast_array.as_list::(); + assert_eq!(list_array.value_offsets(), array.value_offsets()); + + assert_eq!(DataType::UInt16, array.value_type()); + assert_eq!(3, array.value_length(0)); + assert_eq!(3, array.value_length(1)); + assert_eq!(2, array.value_length(2)); + + // expect 4 nulls: negative numbers and overflow + let u16arr = array.values().as_primitive::(); + assert_eq!(4, u16arr.null_count()); + + // expect 4 nulls: negative numbers and overflow + let expected: UInt16Array = + vec![Some(0), Some(0), Some(0), None, None, None, Some(2), None] + .into_iter() + .collect(); + + assert_eq!(u16arr, &expected); + } + + #[test] + fn test_cast_list_i32_to_list_timestamp() { + // Construct a value array + let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 8, 100000000]).into_data(); + + let value_offsets = Buffer::from_slice_ref([0, 3, 6, 9]); + + // Construct a list array from the above two + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build() + .unwrap(); + let list_array = Arc::new(ListArray::from(list_data)) as ArrayRef; + + let actual = cast( + &list_array, + &DataType::List(Arc::new(Field::new( + "item", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ))), + ) + .unwrap(); + + let expected = cast( + &cast( + &list_array, + &DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + ) + .unwrap(), + &DataType::List(Arc::new(Field::new( + "item", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ))), + ) + .unwrap(); + + assert_eq!(&actual, &expected); + } + + #[test] + fn test_cast_date32_to_date64() { + let a = Date32Array::from(vec![10000, 17890]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Date64).unwrap(); + let c = b.as_primitive::(); + assert_eq!(864000000000, c.value(0)); + assert_eq!(1545696000000, c.value(1)); + } + + #[test] + fn test_cast_date64_to_date32() { + let a = Date64Array::from(vec![Some(864000000005), Some(1545696000001), None]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Date32).unwrap(); + let c = b.as_primitive::(); + assert_eq!(10000, c.value(0)); + assert_eq!(17890, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_string_to_timestamp() { + let a1 = Arc::new(StringArray::from(vec![ + Some("2020-09-08T12:00:00.123456789+00:00"), + Some("Not a valid date"), + None, + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("2020-09-08T12:00:00.123456789+00:00"), + Some("Not a valid date"), + None, + ])) as ArrayRef; + for array in &[a1, a2] { + for time_unit in &[ + TimeUnit::Second, + TimeUnit::Millisecond, + TimeUnit::Microsecond, + TimeUnit::Nanosecond, + ] { + let to_type = DataType::Timestamp(time_unit.clone(), None); + let b = cast(array, &to_type).unwrap(); + + match time_unit { + TimeUnit::Second => { + let c = b.as_primitive::(); + assert_eq!(1599566400, c.value(0)); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + } + TimeUnit::Millisecond => { + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(1599566400123, c.value(0)); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + } + TimeUnit::Microsecond => { + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(1599566400123456, c.value(0)); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + } + TimeUnit::Nanosecond => { + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(1599566400123456789, c.value(0)); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + } + } + + let options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!( + err.to_string(), + "Parser error: Error parsing timestamp from 'Not a valid date': error parsing date" + ); + } + } + } + + #[test] + fn test_cast_string_to_timestamp_overflow() { + let array = StringArray::from(vec!["9800-09-08T12:00:00.123456789"]); + let result = cast(&array, &DataType::Timestamp(TimeUnit::Second, None)).unwrap(); + let result = result.as_primitive::(); + assert_eq!(result.values(), &[247112596800]); + } + + #[test] + fn test_cast_string_to_date32() { + let a1 = Arc::new(StringArray::from(vec![ + Some("2018-12-25"), + Some("Not a valid date"), + None, + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("2018-12-25"), + Some("Not a valid date"), + None, + ])) as ArrayRef; + for array in &[a1, a2] { + let to_type = DataType::Date32; + let b = cast(array, &to_type).unwrap(); + let c = b.as_primitive::(); + assert_eq!(17890, c.value(0)); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + + let options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!( + err.to_string(), + "Cast error: Cannot cast string 'Not a valid date' to value of Date32 type" + ); + } + } + + #[test] + fn test_cast_string_format_yyyymmdd_to_date32() { + let a = Arc::new(StringArray::from(vec![ + Some("2020-12-25"), + Some("20201117"), + ])) as ArrayRef; + + let to_type = DataType::Date32; + let options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let result = cast_with_options(&a, &to_type, &options).unwrap(); + let c = result.as_primitive::(); + assert_eq!( + chrono::NaiveDate::from_ymd_opt(2020, 12, 25), + c.value_as_date(0) + ); + assert_eq!( + chrono::NaiveDate::from_ymd_opt(2020, 11, 17), + c.value_as_date(1) + ); + } + + #[test] + fn test_cast_string_to_time32second() { + let a1 = Arc::new(StringArray::from(vec![ + Some("08:08:35.091323414"), + Some("08:08:60.091323414"), // leap second + Some("08:08:61.091323414"), // not valid + Some("Not a valid time"), + None, + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("08:08:35.091323414"), + Some("08:08:60.091323414"), // leap second + Some("08:08:61.091323414"), // not valid + Some("Not a valid time"), + None, + ])) as ArrayRef; + for array in &[a1, a2] { + let to_type = DataType::Time32(TimeUnit::Second); + let b = cast(array, &to_type).unwrap(); + let c = b.as_primitive::(); + assert_eq!(29315, c.value(0)); + assert_eq!(29340, c.value(1)); + assert!(c.is_null(2)); + assert!(c.is_null(3)); + assert!(c.is_null(4)); + + let options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!(err.to_string(), "Cast error: Cannot cast string '08:08:61.091323414' to value of Time32(Second) type"); + } + } + + #[test] + fn test_cast_string_to_time32millisecond() { + let a1 = Arc::new(StringArray::from(vec![ + Some("08:08:35.091323414"), + Some("08:08:60.091323414"), // leap second + Some("08:08:61.091323414"), // not valid + Some("Not a valid time"), + None, + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("08:08:35.091323414"), + Some("08:08:60.091323414"), // leap second + Some("08:08:61.091323414"), // not valid + Some("Not a valid time"), + None, + ])) as ArrayRef; + for array in &[a1, a2] { + let to_type = DataType::Time32(TimeUnit::Millisecond); + let b = cast(array, &to_type).unwrap(); + let c = b.as_primitive::(); + assert_eq!(29315091, c.value(0)); + assert_eq!(29340091, c.value(1)); + assert!(c.is_null(2)); + assert!(c.is_null(3)); + assert!(c.is_null(4)); + + let options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!(err.to_string(), "Cast error: Cannot cast string '08:08:61.091323414' to value of Time32(Millisecond) type"); + } + } + + #[test] + fn test_cast_string_to_time64microsecond() { + let a1 = Arc::new(StringArray::from(vec![ + Some("08:08:35.091323414"), + Some("Not a valid time"), + None, + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("08:08:35.091323414"), + Some("Not a valid time"), + None, + ])) as ArrayRef; + for array in &[a1, a2] { + let to_type = DataType::Time64(TimeUnit::Microsecond); + let b = cast(array, &to_type).unwrap(); + let c = b.as_primitive::(); + assert_eq!(29315091323, c.value(0)); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + + let options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!(err.to_string(), "Cast error: Cannot cast string 'Not a valid time' to value of Time64(Microsecond) type"); + } + } + + #[test] + fn test_cast_string_to_time64nanosecond() { + let a1 = Arc::new(StringArray::from(vec![ + Some("08:08:35.091323414"), + Some("Not a valid time"), + None, + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("08:08:35.091323414"), + Some("Not a valid time"), + None, + ])) as ArrayRef; + for array in &[a1, a2] { + let to_type = DataType::Time64(TimeUnit::Nanosecond); + let b = cast(array, &to_type).unwrap(); + let c = b.as_primitive::(); + assert_eq!(29315091323414, c.value(0)); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + + let options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!(err.to_string(), "Cast error: Cannot cast string 'Not a valid time' to value of Time64(Nanosecond) type"); + } + } + + #[test] + fn test_cast_string_to_date64() { + let a1 = Arc::new(StringArray::from(vec![ + Some("2020-09-08T12:00:00"), + Some("Not a valid date"), + None, + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("2020-09-08T12:00:00"), + Some("Not a valid date"), + None, + ])) as ArrayRef; + for array in &[a1, a2] { + let to_type = DataType::Date64; + let b = cast(array, &to_type).unwrap(); + let c = b.as_primitive::(); + assert_eq!(1599566400000, c.value(0)); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + + let options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!( + err.to_string(), + "Cast error: Cannot cast string 'Not a valid date' to value of Date64 type" + ); + } + } + + macro_rules! test_safe_string_to_interval { + ($data_vec:expr, $interval_unit:expr, $array_ty:ty, $expect_vec:expr) => { + let source_string_array = Arc::new(StringArray::from($data_vec.clone())) as ArrayRef; + + let options = CastOptions { + safe: true, + format_options: FormatOptions::default(), + }; + + let target_interval_array = cast_with_options( + &source_string_array.clone(), + &DataType::Interval($interval_unit), + &options, + ) + .unwrap() + .as_any() + .downcast_ref::<$array_ty>() + .unwrap() + .clone() as $array_ty; + + let target_string_array = + cast_with_options(&target_interval_array, &DataType::Utf8, &options) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .clone(); + + let expect_string_array = StringArray::from($expect_vec); + + assert_eq!(target_string_array, expect_string_array); + + let target_large_string_array = + cast_with_options(&target_interval_array, &DataType::LargeUtf8, &options) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .clone(); + + let expect_large_string_array = LargeStringArray::from($expect_vec); + + assert_eq!(target_large_string_array, expect_large_string_array); + }; + } + + #[test] + fn test_cast_string_to_interval_year_month() { + test_safe_string_to_interval!( + vec![ + Some("1 year 1 month"), + Some("1.5 years 13 month"), + Some("30 days"), + Some("31 days"), + Some("2 months 31 days"), + Some("2 months 31 days 1 second"), + Some("foobar"), + ], + IntervalUnit::YearMonth, + IntervalYearMonthArray, + vec![ + Some("1 years 1 mons 0 days 0 hours 0 mins 0.00 secs"), + Some("2 years 7 mons 0 days 0 hours 0 mins 0.00 secs"), + None, + None, + None, + None, + None, + ] + ); + } + + #[test] + fn test_cast_string_to_interval_day_time() { + test_safe_string_to_interval!( + vec![ + Some("1 year 1 month"), + Some("1.5 years 13 month"), + Some("30 days"), + Some("1 day 2 second 3.5 milliseconds"), + Some("foobar"), + ], + IntervalUnit::DayTime, + IntervalDayTimeArray, + vec![ + Some("0 years 0 mons 390 days 0 hours 0 mins 0.000 secs"), + Some("0 years 0 mons 930 days 0 hours 0 mins 0.000 secs"), + Some("0 years 0 mons 30 days 0 hours 0 mins 0.000 secs"), + None, + None, + ] + ); + } + + #[test] + fn test_cast_string_to_interval_month_day_nano() { + test_safe_string_to_interval!( + vec![ + Some("1 year 1 month 1 day"), + None, + Some("1.5 years 13 month 35 days 1.4 milliseconds"), + Some("3 days"), + Some("8 seconds"), + None, + Some("1 day 29800 milliseconds"), + Some("3 months 1 second"), + Some("6 minutes 120 second"), + Some("2 years 39 months 9 days 19 hours 1 minute 83 seconds 399222 milliseconds"), + Some("foobar"), + ], + IntervalUnit::MonthDayNano, + IntervalMonthDayNanoArray, + vec![ + Some("0 years 13 mons 1 days 0 hours 0 mins 0.000000000 secs"), + None, + Some("0 years 31 mons 35 days 0 hours 0 mins 0.001400000 secs"), + Some("0 years 0 mons 3 days 0 hours 0 mins 0.000000000 secs"), + Some("0 years 0 mons 0 days 0 hours 0 mins 8.000000000 secs"), + None, + Some("0 years 0 mons 1 days 0 hours 0 mins 29.800000000 secs"), + Some("0 years 3 mons 0 days 0 hours 0 mins 1.000000000 secs"), + Some("0 years 0 mons 0 days 0 hours 8 mins 0.000000000 secs"), + Some("0 years 63 mons 9 days 19 hours 9 mins 2.222000000 secs"), + None, + ] + ); + } + + macro_rules! test_unsafe_string_to_interval_err { + ($data_vec:expr, $interval_unit:expr, $error_msg:expr) => { + let string_array = Arc::new(StringArray::from($data_vec.clone())) as ArrayRef; + let options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let arrow_err = cast_with_options( + &string_array.clone(), + &DataType::Interval($interval_unit), + &options, + ) + .unwrap_err(); + assert_eq!($error_msg, arrow_err.to_string()); + }; + } + + #[test] + fn test_cast_string_to_interval_err() { + test_unsafe_string_to_interval_err!( + vec![Some("foobar")], + IntervalUnit::YearMonth, + r#"Not yet implemented: Unsupported Interval Expression with value "foobar""# + ); + test_unsafe_string_to_interval_err!( + vec![Some("foobar")], + IntervalUnit::DayTime, + r#"Not yet implemented: Unsupported Interval Expression with value "foobar""# + ); + test_unsafe_string_to_interval_err!( + vec![Some("foobar")], + IntervalUnit::MonthDayNano, + r#"Not yet implemented: Unsupported Interval Expression with value "foobar""# + ); + test_unsafe_string_to_interval_err!( + vec![Some("2 months 31 days 1 second")], + IntervalUnit::YearMonth, + r#"Cast error: Cannot cast 2 months 31 days 1 second to IntervalYearMonth. Only year and month fields are allowed."# + ); + test_unsafe_string_to_interval_err!( + vec![Some("1 day 1.5 milliseconds")], + IntervalUnit::DayTime, + r#"Cast error: Cannot cast 1 day 1.5 milliseconds to IntervalDayTime because the nanos part isn't multiple of milliseconds"# + ); + + // overflow + test_unsafe_string_to_interval_err!( + vec![Some(format!( + "{} century {} year {} month", + i64::MAX - 2, + i64::MAX - 2, + i64::MAX - 2 + ))], + IntervalUnit::DayTime, + format!( + "Compute error: Overflow happened on: {} * 100", + i64::MAX - 2 + ) + ); + test_unsafe_string_to_interval_err!( + vec![Some(format!( + "{} year {} month {} day", + i64::MAX - 2, + i64::MAX - 2, + i64::MAX - 2 + ))], + IntervalUnit::MonthDayNano, + format!("Compute error: Overflow happened on: {} * 12", i64::MAX - 2) + ); + } + + #[test] + fn test_cast_binary_to_fixed_size_binary() { + let bytes_1 = "Hiiii".as_bytes(); + let bytes_2 = "Hello".as_bytes(); + + let binary_data = vec![Some(bytes_1), Some(bytes_2), None]; + let a1 = Arc::new(BinaryArray::from(binary_data.clone())) as ArrayRef; + let a2 = Arc::new(LargeBinaryArray::from(binary_data)) as ArrayRef; + + let array_ref = cast(&a1, &DataType::FixedSizeBinary(5)).unwrap(); + let down_cast = array_ref + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(bytes_1, down_cast.value(0)); + assert_eq!(bytes_2, down_cast.value(1)); + assert!(down_cast.is_null(2)); + + let array_ref = cast(&a2, &DataType::FixedSizeBinary(5)).unwrap(); + let down_cast = array_ref + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(bytes_1, down_cast.value(0)); + assert_eq!(bytes_2, down_cast.value(1)); + assert!(down_cast.is_null(2)); + + // test error cases when the length of binary are not same + let bytes_1 = "Hi".as_bytes(); + let bytes_2 = "Hello".as_bytes(); + + let binary_data = vec![Some(bytes_1), Some(bytes_2), None]; + let a1 = Arc::new(BinaryArray::from(binary_data.clone())) as ArrayRef; + let a2 = Arc::new(LargeBinaryArray::from(binary_data)) as ArrayRef; + + let array_ref = cast_with_options( + &a1, + &DataType::FixedSizeBinary(5), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert!(array_ref.is_err()); + + let array_ref = cast_with_options( + &a2, + &DataType::FixedSizeBinary(5), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert!(array_ref.is_err()); + } + + #[test] + fn test_fixed_size_binary_to_binary() { + let bytes_1 = "Hiiii".as_bytes(); + let bytes_2 = "Hello".as_bytes(); + + let binary_data = vec![Some(bytes_1), Some(bytes_2), None]; + let a1 = Arc::new(FixedSizeBinaryArray::from(binary_data.clone())) as ArrayRef; + + let array_ref = cast(&a1, &DataType::Binary).unwrap(); + let down_cast = array_ref.as_binary::(); + assert_eq!(bytes_1, down_cast.value(0)); + assert_eq!(bytes_2, down_cast.value(1)); + assert!(down_cast.is_null(2)); + + let array_ref = cast(&a1, &DataType::LargeBinary).unwrap(); + let down_cast = array_ref.as_binary::(); + assert_eq!(bytes_1, down_cast.value(0)); + assert_eq!(bytes_2, down_cast.value(1)); + assert!(down_cast.is_null(2)); + } + + #[test] + fn test_numeric_to_binary() { + let a = Int16Array::from(vec![Some(1), Some(511), None]); + + let array_ref = cast(&a, &DataType::Binary).unwrap(); + let down_cast = array_ref.as_binary::(); + assert_eq!(&1_i16.to_le_bytes(), down_cast.value(0)); + assert_eq!(&511_i16.to_le_bytes(), down_cast.value(1)); + assert!(down_cast.is_null(2)); + + let a = Int64Array::from(vec![Some(-1), Some(123456789), None]); + + let array_ref = cast(&a, &DataType::Binary).unwrap(); + let down_cast = array_ref.as_binary::(); + assert_eq!(&(-1_i64).to_le_bytes(), down_cast.value(0)); + assert_eq!(&123456789_i64.to_le_bytes(), down_cast.value(1)); + assert!(down_cast.is_null(2)); + } + + #[test] + fn test_numeric_to_large_binary() { + let a = Int16Array::from(vec![Some(1), Some(511), None]); + + let array_ref = cast(&a, &DataType::LargeBinary).unwrap(); + let down_cast = array_ref.as_binary::(); + assert_eq!(&1_i16.to_le_bytes(), down_cast.value(0)); + assert_eq!(&511_i16.to_le_bytes(), down_cast.value(1)); + assert!(down_cast.is_null(2)); + + let a = Int64Array::from(vec![Some(-1), Some(123456789), None]); + + let array_ref = cast(&a, &DataType::LargeBinary).unwrap(); + let down_cast = array_ref.as_binary::(); + assert_eq!(&(-1_i64).to_le_bytes(), down_cast.value(0)); + assert_eq!(&123456789_i64.to_le_bytes(), down_cast.value(1)); + assert!(down_cast.is_null(2)); + } + + #[test] + fn test_cast_date32_to_int32() { + let array = Date32Array::from(vec![10000, 17890]); + let b = cast(&array, &DataType::Int32).unwrap(); + let c = b.as_primitive::(); + assert_eq!(10000, c.value(0)); + assert_eq!(17890, c.value(1)); + } + + #[test] + fn test_cast_int32_to_date32() { + let array = Int32Array::from(vec![10000, 17890]); + let b = cast(&array, &DataType::Date32).unwrap(); + let c = b.as_primitive::(); + assert_eq!(10000, c.value(0)); + assert_eq!(17890, c.value(1)); + } + + #[test] + fn test_cast_timestamp_to_date32() { + let array = + TimestampMillisecondArray::from(vec![Some(864000000005), Some(1545696000001), None]) + .with_timezone("UTC".to_string()); + let b = cast(&array, &DataType::Date32).unwrap(); + let c = b.as_primitive::(); + assert_eq!(10000, c.value(0)); + assert_eq!(17890, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_timestamp_to_date64() { + let array = + TimestampMillisecondArray::from(vec![Some(864000000005), Some(1545696000001), None]); + let b = cast(&array, &DataType::Date64).unwrap(); + let c = b.as_primitive::(); + assert_eq!(864000000005, c.value(0)); + assert_eq!(1545696000001, c.value(1)); + assert!(c.is_null(2)); + + let array = TimestampSecondArray::from(vec![Some(864000000005), Some(1545696000001)]); + let b = cast(&array, &DataType::Date64).unwrap(); + let c = b.as_primitive::(); + assert_eq!(864000000005000, c.value(0)); + assert_eq!(1545696000001000, c.value(1)); + + // test overflow, safe cast + let array = TimestampSecondArray::from(vec![Some(i64::MAX)]); + let b = cast(&array, &DataType::Date64).unwrap(); + assert!(b.is_null(0)); + // test overflow, unsafe cast + let array = TimestampSecondArray::from(vec![Some(i64::MAX)]); + let options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let b = cast_with_options(&array, &DataType::Date64, &options); + assert!(b.is_err()); + } + + #[test] + fn test_cast_timestamp_to_time64() { + // test timestamp secs + let array = TimestampSecondArray::from(vec![Some(86405), Some(1), None]) + .with_timezone("+01:00".to_string()); + let b = cast(&array, &DataType::Time64(TimeUnit::Microsecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000000, c.value(0)); + assert_eq!(3601000000, c.value(1)); + assert!(c.is_null(2)); + let b = cast(&array, &DataType::Time64(TimeUnit::Nanosecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000000000, c.value(0)); + assert_eq!(3601000000000, c.value(1)); + assert!(c.is_null(2)); + + // test timestamp milliseconds + let a = TimestampMillisecondArray::from(vec![Some(86405000), Some(1000), None]) + .with_timezone("+01:00".to_string()); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Time64(TimeUnit::Microsecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000000, c.value(0)); + assert_eq!(3601000000, c.value(1)); + assert!(c.is_null(2)); + let b = cast(&array, &DataType::Time64(TimeUnit::Nanosecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000000000, c.value(0)); + assert_eq!(3601000000000, c.value(1)); + assert!(c.is_null(2)); + + // test timestamp microseconds + let a = TimestampMicrosecondArray::from(vec![Some(86405000000), Some(1000000), None]) + .with_timezone("+01:00".to_string()); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Time64(TimeUnit::Microsecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000000, c.value(0)); + assert_eq!(3601000000, c.value(1)); + assert!(c.is_null(2)); + let b = cast(&array, &DataType::Time64(TimeUnit::Nanosecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000000000, c.value(0)); + assert_eq!(3601000000000, c.value(1)); + assert!(c.is_null(2)); + + // test timestamp nanoseconds + let a = TimestampNanosecondArray::from(vec![Some(86405000000000), Some(1000000000), None]) + .with_timezone("+01:00".to_string()); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Time64(TimeUnit::Microsecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000000, c.value(0)); + assert_eq!(3601000000, c.value(1)); + assert!(c.is_null(2)); + let b = cast(&array, &DataType::Time64(TimeUnit::Nanosecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000000000, c.value(0)); + assert_eq!(3601000000000, c.value(1)); + assert!(c.is_null(2)); + + // test overflow + let a = + TimestampSecondArray::from(vec![Some(i64::MAX)]).with_timezone("+01:00".to_string()); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Time64(TimeUnit::Microsecond)); + assert!(b.is_err()); + let b = cast(&array, &DataType::Time64(TimeUnit::Nanosecond)); + assert!(b.is_err()); + let b = cast(&array, &DataType::Time64(TimeUnit::Millisecond)); + assert!(b.is_err()); + } + + #[test] + fn test_cast_timestamp_to_time32() { + // test timestamp secs + let a = TimestampSecondArray::from(vec![Some(86405), Some(1), None]) + .with_timezone("+01:00".to_string()); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Time32(TimeUnit::Second)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605, c.value(0)); + assert_eq!(3601, c.value(1)); + assert!(c.is_null(2)); + let b = cast(&array, &DataType::Time32(TimeUnit::Millisecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000, c.value(0)); + assert_eq!(3601000, c.value(1)); + assert!(c.is_null(2)); + + // test timestamp milliseconds + let a = TimestampMillisecondArray::from(vec![Some(86405000), Some(1000), None]) + .with_timezone("+01:00".to_string()); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Time32(TimeUnit::Second)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605, c.value(0)); + assert_eq!(3601, c.value(1)); + assert!(c.is_null(2)); + let b = cast(&array, &DataType::Time32(TimeUnit::Millisecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000, c.value(0)); + assert_eq!(3601000, c.value(1)); + assert!(c.is_null(2)); + + // test timestamp microseconds + let a = TimestampMicrosecondArray::from(vec![Some(86405000000), Some(1000000), None]) + .with_timezone("+01:00".to_string()); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Time32(TimeUnit::Second)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605, c.value(0)); + assert_eq!(3601, c.value(1)); + assert!(c.is_null(2)); + let b = cast(&array, &DataType::Time32(TimeUnit::Millisecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000, c.value(0)); + assert_eq!(3601000, c.value(1)); + assert!(c.is_null(2)); + + // test timestamp nanoseconds + let a = TimestampNanosecondArray::from(vec![Some(86405000000000), Some(1000000000), None]) + .with_timezone("+01:00".to_string()); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Time32(TimeUnit::Second)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605, c.value(0)); + assert_eq!(3601, c.value(1)); + assert!(c.is_null(2)); + let b = cast(&array, &DataType::Time32(TimeUnit::Millisecond)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(3605000, c.value(0)); + assert_eq!(3601000, c.value(1)); + assert!(c.is_null(2)); + + // test overflow + let a = + TimestampSecondArray::from(vec![Some(i64::MAX)]).with_timezone("+01:00".to_string()); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Time32(TimeUnit::Second)); + assert!(b.is_err()); + let b = cast(&array, &DataType::Time32(TimeUnit::Millisecond)); + assert!(b.is_err()); + } + + // Cast Timestamp(_, None) -> Timestamp(_, Some(timezone)) + #[test] + fn test_cast_timestamp_with_timezone_1() { + let string_array: Arc = Arc::new(StringArray::from(vec![ + Some("2000-01-01T00:00:00.123456789"), + Some("2010-01-01T00:00:00.123456789"), + None, + ])); + let to_type = DataType::Timestamp(TimeUnit::Nanosecond, None); + let timestamp_array = cast(&string_array, &to_type).unwrap(); + + let to_type = DataType::Timestamp(TimeUnit::Microsecond, Some("+0700".into())); + let timestamp_array = cast(×tamp_array, &to_type).unwrap(); + + let string_array = cast(×tamp_array, &DataType::Utf8).unwrap(); + let result = string_array.as_string::(); + assert_eq!("2000-01-01T00:00:00.123456+07:00", result.value(0)); + assert_eq!("2010-01-01T00:00:00.123456+07:00", result.value(1)); + assert!(result.is_null(2)); + } + + // Cast Timestamp(_, Some(timezone)) -> Timestamp(_, None) + #[test] + fn test_cast_timestamp_with_timezone_2() { + let string_array: Arc = Arc::new(StringArray::from(vec![ + Some("2000-01-01T07:00:00.123456789"), + Some("2010-01-01T07:00:00.123456789"), + None, + ])); + let to_type = DataType::Timestamp(TimeUnit::Millisecond, Some("+0700".into())); + let timestamp_array = cast(&string_array, &to_type).unwrap(); + + // Check intermediate representation is correct + let string_array = cast(×tamp_array, &DataType::Utf8).unwrap(); + let result = string_array.as_string::(); + assert_eq!("2000-01-01T07:00:00.123+07:00", result.value(0)); + assert_eq!("2010-01-01T07:00:00.123+07:00", result.value(1)); + assert!(result.is_null(2)); + + let to_type = DataType::Timestamp(TimeUnit::Nanosecond, None); + let timestamp_array = cast(×tamp_array, &to_type).unwrap(); + + let string_array = cast(×tamp_array, &DataType::Utf8).unwrap(); + let result = string_array.as_string::(); + assert_eq!("2000-01-01T00:00:00.123", result.value(0)); + assert_eq!("2010-01-01T00:00:00.123", result.value(1)); + assert!(result.is_null(2)); + } + + // Cast Timestamp(_, Some(timezone)) -> Timestamp(_, Some(timezone)) + #[test] + fn test_cast_timestamp_with_timezone_3() { + let string_array: Arc = Arc::new(StringArray::from(vec![ + Some("2000-01-01T07:00:00.123456789"), + Some("2010-01-01T07:00:00.123456789"), + None, + ])); + let to_type = DataType::Timestamp(TimeUnit::Microsecond, Some("+0700".into())); + let timestamp_array = cast(&string_array, &to_type).unwrap(); + + // Check intermediate representation is correct + let string_array = cast(×tamp_array, &DataType::Utf8).unwrap(); + let result = string_array.as_string::(); + assert_eq!("2000-01-01T07:00:00.123456+07:00", result.value(0)); + assert_eq!("2010-01-01T07:00:00.123456+07:00", result.value(1)); + assert!(result.is_null(2)); + + let to_type = DataType::Timestamp(TimeUnit::Second, Some("-08:00".into())); + let timestamp_array = cast(×tamp_array, &to_type).unwrap(); + + let string_array = cast(×tamp_array, &DataType::Utf8).unwrap(); + let result = string_array.as_string::(); + assert_eq!("1999-12-31T16:00:00-08:00", result.value(0)); + assert_eq!("2009-12-31T16:00:00-08:00", result.value(1)); + assert!(result.is_null(2)); + } + + #[test] + fn test_cast_date64_to_timestamp() { + let array = Date64Array::from(vec![Some(864000000005), Some(1545696000001), None]); + let b = cast(&array, &DataType::Timestamp(TimeUnit::Second, None)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(864000000, c.value(0)); + assert_eq!(1545696000, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_date64_to_timestamp_ms() { + let array = Date64Array::from(vec![Some(864000000005), Some(1545696000001), None]); + let b = cast(&array, &DataType::Timestamp(TimeUnit::Millisecond, None)).unwrap(); + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(864000000005, c.value(0)); + assert_eq!(1545696000001, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_date64_to_timestamp_us() { + let array = Date64Array::from(vec![Some(864000000005), Some(1545696000001), None]); + let b = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(864000000005000, c.value(0)); + assert_eq!(1545696000001000, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_date64_to_timestamp_ns() { + let array = Date64Array::from(vec![Some(864000000005), Some(1545696000001), None]); + let b = cast(&array, &DataType::Timestamp(TimeUnit::Nanosecond, None)).unwrap(); + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(864000000005000000, c.value(0)); + assert_eq!(1545696000001000000, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_timestamp_to_i64() { + let array = + TimestampMillisecondArray::from(vec![Some(864000000005), Some(1545696000001), None]) + .with_timezone("UTC".to_string()); + let b = cast(&array, &DataType::Int64).unwrap(); + let c = b.as_primitive::(); + assert_eq!(&DataType::Int64, c.data_type()); + assert_eq!(864000000005, c.value(0)); + assert_eq!(1545696000001, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_date32_to_string() { + let array = Date32Array::from(vec![10000, 17890]); + let b = cast(&array, &DataType::Utf8).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(&DataType::Utf8, c.data_type()); + assert_eq!("1997-05-19", c.value(0)); + assert_eq!("2018-12-25", c.value(1)); + } + + #[test] + fn test_cast_date64_to_string() { + let array = Date64Array::from(vec![10000 * 86400000, 17890 * 86400000]); + let b = cast(&array, &DataType::Utf8).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(&DataType::Utf8, c.data_type()); + assert_eq!("1997-05-19T00:00:00", c.value(0)); + assert_eq!("2018-12-25T00:00:00", c.value(1)); + } + + #[test] + fn test_cast_timestamp_to_strings() { + // "2018-12-25T00:00:02.001", "1997-05-19T00:00:03.005", None + let array = + TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), None]); + let out = cast(&array, &DataType::Utf8).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!( + out, + vec![ + Some("1997-05-19T00:00:03.005"), + Some("2018-12-25T00:00:02.001"), + None + ] + ); + let out = cast(&array, &DataType::LargeUtf8).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!( + out, + vec![ + Some("1997-05-19T00:00:03.005"), + Some("2018-12-25T00:00:02.001"), + None + ] + ); + } + + #[test] + fn test_cast_timestamp_to_strings_opt() { + let ts_format = "%Y-%m-%d %H:%M:%S%.6f"; + let tz = "+0545"; // UTC + 0545 is Asia/Kathmandu + let cast_options = CastOptions { + safe: true, + format_options: FormatOptions::default() + .with_timestamp_format(Some(ts_format)) + .with_timestamp_tz_format(Some(ts_format)), + }; + // "2018-12-25T00:00:02.001", "1997-05-19T00:00:03.005", None + let array_without_tz = + TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), None]); + let out = cast_with_options(&array_without_tz, &DataType::Utf8, &cast_options).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!( + out, + vec![ + Some("1997-05-19 00:00:03.005000"), + Some("2018-12-25 00:00:02.001000"), + None + ] + ); + let out = + cast_with_options(&array_without_tz, &DataType::LargeUtf8, &cast_options).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!( + out, + vec![ + Some("1997-05-19 00:00:03.005000"), + Some("2018-12-25 00:00:02.001000"), + None + ] + ); + + let array_with_tz = + TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), None]) + .with_timezone(tz.to_string()); + let out = cast_with_options(&array_with_tz, &DataType::Utf8, &cast_options).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!( + out, + vec![ + Some("1997-05-19 05:45:03.005000"), + Some("2018-12-25 05:45:02.001000"), + None + ] + ); + let out = cast_with_options(&array_with_tz, &DataType::LargeUtf8, &cast_options).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!( + out, + vec![ + Some("1997-05-19 05:45:03.005000"), + Some("2018-12-25 05:45:02.001000"), + None + ] + ); + } + + #[test] + fn test_cast_between_timestamps() { + let array = + TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), None]); + let b = cast(&array, &DataType::Timestamp(TimeUnit::Second, None)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(864000003, c.value(0)); + assert_eq!(1545696002, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_duration_to_i64() { + let base = vec![5, 6, 7, 8, 100000000]; + + let duration_arrays = vec![ + Arc::new(DurationNanosecondArray::from(base.clone())) as ArrayRef, + Arc::new(DurationMicrosecondArray::from(base.clone())) as ArrayRef, + Arc::new(DurationMillisecondArray::from(base.clone())) as ArrayRef, + Arc::new(DurationSecondArray::from(base.clone())) as ArrayRef, + ]; + + for arr in duration_arrays { + assert!(can_cast_types(arr.data_type(), &DataType::Int64)); + let result = cast(&arr, &DataType::Int64).unwrap(); + let result = result.as_primitive::(); + assert_eq!(base.as_slice(), result.values()); + } + } + + #[test] + fn test_cast_interval_to_i64() { + let base = vec![5, 6, 7, 8]; + + let interval_arrays = vec![ + Arc::new(IntervalDayTimeArray::from(base.clone())) as ArrayRef, + Arc::new(IntervalYearMonthArray::from( + base.iter().map(|x| *x as i32).collect::>(), + )) as ArrayRef, + ]; + + for arr in interval_arrays { + assert!(can_cast_types(arr.data_type(), &DataType::Int64)); + let result = cast(&arr, &DataType::Int64).unwrap(); + let result = result.as_primitive::(); + assert_eq!(base.as_slice(), result.values()); + } + } + + #[test] + fn test_cast_to_strings() { + let a = Int32Array::from(vec![1, 2, 3]); + let out = cast(&a, &DataType::Utf8).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!(out, vec![Some("1"), Some("2"), Some("3")]); + let out = cast(&a, &DataType::LargeUtf8).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!(out, vec![Some("1"), Some("2"), Some("3")]); + } + + #[test] + fn test_str_to_str_casts() { + for data in [ + vec![Some("foo"), Some("bar"), Some("ham")], + vec![Some("foo"), None, Some("bar")], + ] { + let a = LargeStringArray::from(data.clone()); + let to = cast(&a, &DataType::Utf8).unwrap(); + let expect = a + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + let out = to + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!(expect, out); + + let a = StringArray::from(data); + let to = cast(&a, &DataType::LargeUtf8).unwrap(); + let expect = a + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + let out = to + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!(expect, out); + } + } + + #[test] + fn test_cast_from_f64() { + let f64_values: Vec = vec![ + i64::MIN as f64, + i32::MIN as f64, + i16::MIN as f64, + i8::MIN as f64, + 0_f64, + u8::MAX as f64, + u16::MAX as f64, + u32::MAX as f64, + u64::MAX as f64, + ]; + let f64_array: ArrayRef = Arc::new(Float64Array::from(f64_values)); + + let f64_expected = vec![ + -9223372036854776000.0, + -2147483648.0, + -32768.0, + -128.0, + 0.0, + 255.0, + 65535.0, + 4294967295.0, + 18446744073709552000.0, + ]; + assert_eq!( + f64_expected, + get_cast_values::(&f64_array, &DataType::Float64) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let f32_expected = vec![ + -9223372000000000000.0, + -2147483600.0, + -32768.0, + -128.0, + 0.0, + 255.0, + 65535.0, + 4294967300.0, + 18446744000000000000.0, + ]; + assert_eq!( + f32_expected, + get_cast_values::(&f64_array, &DataType::Float32) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let f16_expected = vec![ + f16::from_f64(-9223372000000000000.0), + f16::from_f64(-2147483600.0), + f16::from_f64(-32768.0), + f16::from_f64(-128.0), + f16::from_f64(0.0), + f16::from_f64(255.0), + f16::from_f64(65535.0), + f16::from_f64(4294967300.0), + f16::from_f64(18446744000000000000.0), + ]; + assert_eq!( + f16_expected, + get_cast_values::(&f64_array, &DataType::Float16) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let i64_expected = vec![ + "-9223372036854775808", + "-2147483648", + "-32768", + "-128", + "0", + "255", + "65535", + "4294967295", + "null", + ]; + assert_eq!( + i64_expected, + get_cast_values::(&f64_array, &DataType::Int64) + ); + + let i32_expected = vec![ + "null", + "-2147483648", + "-32768", + "-128", + "0", + "255", + "65535", + "null", + "null", + ]; + assert_eq!( + i32_expected, + get_cast_values::(&f64_array, &DataType::Int32) + ); + + let i16_expected = vec![ + "null", "null", "-32768", "-128", "0", "255", "null", "null", "null", + ]; + assert_eq!( + i16_expected, + get_cast_values::(&f64_array, &DataType::Int16) + ); + + let i8_expected = vec![ + "null", "null", "null", "-128", "0", "null", "null", "null", "null", + ]; + assert_eq!( + i8_expected, + get_cast_values::(&f64_array, &DataType::Int8) + ); + + let u64_expected = vec![ + "null", + "null", + "null", + "null", + "0", + "255", + "65535", + "4294967295", + "null", + ]; + assert_eq!( + u64_expected, + get_cast_values::(&f64_array, &DataType::UInt64) + ); + + let u32_expected = vec![ + "null", + "null", + "null", + "null", + "0", + "255", + "65535", + "4294967295", + "null", + ]; + assert_eq!( + u32_expected, + get_cast_values::(&f64_array, &DataType::UInt32) + ); + + let u16_expected = vec![ + "null", "null", "null", "null", "0", "255", "65535", "null", "null", + ]; + assert_eq!( + u16_expected, + get_cast_values::(&f64_array, &DataType::UInt16) + ); + + let u8_expected = vec![ + "null", "null", "null", "null", "0", "255", "null", "null", "null", + ]; + assert_eq!( + u8_expected, + get_cast_values::(&f64_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_f32() { + let f32_values: Vec = vec![ + i32::MIN as f32, + i32::MIN as f32, + i16::MIN as f32, + i8::MIN as f32, + 0_f32, + u8::MAX as f32, + u16::MAX as f32, + u32::MAX as f32, + u32::MAX as f32, + ]; + let f32_array: ArrayRef = Arc::new(Float32Array::from(f32_values)); + + let f64_expected = vec![ + "-2147483648.0", + "-2147483648.0", + "-32768.0", + "-128.0", + "0.0", + "255.0", + "65535.0", + "4294967296.0", + "4294967296.0", + ]; + assert_eq!( + f64_expected, + get_cast_values::(&f32_array, &DataType::Float64) + ); + + let f32_expected = vec![ + "-2147483600.0", + "-2147483600.0", + "-32768.0", + "-128.0", + "0.0", + "255.0", + "65535.0", + "4294967300.0", + "4294967300.0", + ]; + assert_eq!( + f32_expected, + get_cast_values::(&f32_array, &DataType::Float32) + ); + + let f16_expected = vec![ + "-inf", "-inf", "-32768.0", "-128.0", "0.0", "255.0", "inf", "inf", "inf", + ]; + assert_eq!( + f16_expected, + get_cast_values::(&f32_array, &DataType::Float16) + ); + + let i64_expected = vec![ + "-2147483648", + "-2147483648", + "-32768", + "-128", + "0", + "255", + "65535", + "4294967296", + "4294967296", + ]; + assert_eq!( + i64_expected, + get_cast_values::(&f32_array, &DataType::Int64) + ); + + let i32_expected = vec![ + "-2147483648", + "-2147483648", + "-32768", + "-128", + "0", + "255", + "65535", + "null", + "null", + ]; + assert_eq!( + i32_expected, + get_cast_values::(&f32_array, &DataType::Int32) + ); + + let i16_expected = vec![ + "null", "null", "-32768", "-128", "0", "255", "null", "null", "null", + ]; + assert_eq!( + i16_expected, + get_cast_values::(&f32_array, &DataType::Int16) + ); + + let i8_expected = vec![ + "null", "null", "null", "-128", "0", "null", "null", "null", "null", + ]; + assert_eq!( + i8_expected, + get_cast_values::(&f32_array, &DataType::Int8) + ); + + let u64_expected = vec![ + "null", + "null", + "null", + "null", + "0", + "255", + "65535", + "4294967296", + "4294967296", + ]; + assert_eq!( + u64_expected, + get_cast_values::(&f32_array, &DataType::UInt64) + ); + + let u32_expected = vec![ + "null", "null", "null", "null", "0", "255", "65535", "null", "null", + ]; + assert_eq!( + u32_expected, + get_cast_values::(&f32_array, &DataType::UInt32) + ); + + let u16_expected = vec![ + "null", "null", "null", "null", "0", "255", "65535", "null", "null", + ]; + assert_eq!( + u16_expected, + get_cast_values::(&f32_array, &DataType::UInt16) + ); + + let u8_expected = vec![ + "null", "null", "null", "null", "0", "255", "null", "null", "null", + ]; + assert_eq!( + u8_expected, + get_cast_values::(&f32_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_uint64() { + let u64_values: Vec = vec![ + 0, + u8::MAX as u64, + u16::MAX as u64, + u32::MAX as u64, + u64::MAX, + ]; + let u64_array: ArrayRef = Arc::new(UInt64Array::from(u64_values)); + + let f64_expected = vec![0.0, 255.0, 65535.0, 4294967295.0, 18446744073709552000.0]; + assert_eq!( + f64_expected, + get_cast_values::(&u64_array, &DataType::Float64) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let f32_expected = vec![0.0, 255.0, 65535.0, 4294967300.0, 18446744000000000000.0]; + assert_eq!( + f32_expected, + get_cast_values::(&u64_array, &DataType::Float32) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let f16_expected = vec![ + f16::from_f64(0.0), + f16::from_f64(255.0), + f16::from_f64(65535.0), + f16::from_f64(4294967300.0), + f16::from_f64(18446744000000000000.0), + ]; + assert_eq!( + f16_expected, + get_cast_values::(&u64_array, &DataType::Float16) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let i64_expected = vec!["0", "255", "65535", "4294967295", "null"]; + assert_eq!( + i64_expected, + get_cast_values::(&u64_array, &DataType::Int64) + ); + + let i32_expected = vec!["0", "255", "65535", "null", "null"]; + assert_eq!( + i32_expected, + get_cast_values::(&u64_array, &DataType::Int32) + ); + + let i16_expected = vec!["0", "255", "null", "null", "null"]; + assert_eq!( + i16_expected, + get_cast_values::(&u64_array, &DataType::Int16) + ); + + let i8_expected = vec!["0", "null", "null", "null", "null"]; + assert_eq!( + i8_expected, + get_cast_values::(&u64_array, &DataType::Int8) + ); + + let u64_expected = vec!["0", "255", "65535", "4294967295", "18446744073709551615"]; + assert_eq!( + u64_expected, + get_cast_values::(&u64_array, &DataType::UInt64) + ); + + let u32_expected = vec!["0", "255", "65535", "4294967295", "null"]; + assert_eq!( + u32_expected, + get_cast_values::(&u64_array, &DataType::UInt32) + ); + + let u16_expected = vec!["0", "255", "65535", "null", "null"]; + assert_eq!( + u16_expected, + get_cast_values::(&u64_array, &DataType::UInt16) + ); + + let u8_expected = vec!["0", "255", "null", "null", "null"]; + assert_eq!( + u8_expected, + get_cast_values::(&u64_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_uint32() { + let u32_values: Vec = vec![0, u8::MAX as u32, u16::MAX as u32, u32::MAX]; + let u32_array: ArrayRef = Arc::new(UInt32Array::from(u32_values)); + + let f64_expected = vec!["0.0", "255.0", "65535.0", "4294967295.0"]; + assert_eq!( + f64_expected, + get_cast_values::(&u32_array, &DataType::Float64) + ); + + let f32_expected = vec!["0.0", "255.0", "65535.0", "4294967300.0"]; + assert_eq!( + f32_expected, + get_cast_values::(&u32_array, &DataType::Float32) + ); + + let f16_expected = vec!["0.0", "255.0", "inf", "inf"]; + assert_eq!( + f16_expected, + get_cast_values::(&u32_array, &DataType::Float16) + ); + + let i64_expected = vec!["0", "255", "65535", "4294967295"]; + assert_eq!( + i64_expected, + get_cast_values::(&u32_array, &DataType::Int64) + ); + + let i32_expected = vec!["0", "255", "65535", "null"]; + assert_eq!( + i32_expected, + get_cast_values::(&u32_array, &DataType::Int32) + ); + + let i16_expected = vec!["0", "255", "null", "null"]; + assert_eq!( + i16_expected, + get_cast_values::(&u32_array, &DataType::Int16) + ); + + let i8_expected = vec!["0", "null", "null", "null"]; + assert_eq!( + i8_expected, + get_cast_values::(&u32_array, &DataType::Int8) + ); + + let u64_expected = vec!["0", "255", "65535", "4294967295"]; + assert_eq!( + u64_expected, + get_cast_values::(&u32_array, &DataType::UInt64) + ); + + let u32_expected = vec!["0", "255", "65535", "4294967295"]; + assert_eq!( + u32_expected, + get_cast_values::(&u32_array, &DataType::UInt32) + ); + + let u16_expected = vec!["0", "255", "65535", "null"]; + assert_eq!( + u16_expected, + get_cast_values::(&u32_array, &DataType::UInt16) + ); + + let u8_expected = vec!["0", "255", "null", "null"]; + assert_eq!( + u8_expected, + get_cast_values::(&u32_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_uint16() { + let u16_values: Vec = vec![0, u8::MAX as u16, u16::MAX]; + let u16_array: ArrayRef = Arc::new(UInt16Array::from(u16_values)); + + let f64_expected = vec!["0.0", "255.0", "65535.0"]; + assert_eq!( + f64_expected, + get_cast_values::(&u16_array, &DataType::Float64) + ); + + let f32_expected = vec!["0.0", "255.0", "65535.0"]; + assert_eq!( + f32_expected, + get_cast_values::(&u16_array, &DataType::Float32) + ); + + let f16_expected = vec!["0.0", "255.0", "inf"]; + assert_eq!( + f16_expected, + get_cast_values::(&u16_array, &DataType::Float16) + ); + + let i64_expected = vec!["0", "255", "65535"]; + assert_eq!( + i64_expected, + get_cast_values::(&u16_array, &DataType::Int64) + ); + + let i32_expected = vec!["0", "255", "65535"]; + assert_eq!( + i32_expected, + get_cast_values::(&u16_array, &DataType::Int32) + ); + + let i16_expected = vec!["0", "255", "null"]; + assert_eq!( + i16_expected, + get_cast_values::(&u16_array, &DataType::Int16) + ); + + let i8_expected = vec!["0", "null", "null"]; + assert_eq!( + i8_expected, + get_cast_values::(&u16_array, &DataType::Int8) + ); + + let u64_expected = vec!["0", "255", "65535"]; + assert_eq!( + u64_expected, + get_cast_values::(&u16_array, &DataType::UInt64) + ); + + let u32_expected = vec!["0", "255", "65535"]; + assert_eq!( + u32_expected, + get_cast_values::(&u16_array, &DataType::UInt32) + ); + + let u16_expected = vec!["0", "255", "65535"]; + assert_eq!( + u16_expected, + get_cast_values::(&u16_array, &DataType::UInt16) + ); + + let u8_expected = vec!["0", "255", "null"]; + assert_eq!( + u8_expected, + get_cast_values::(&u16_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_uint8() { + let u8_values: Vec = vec![0, u8::MAX]; + let u8_array: ArrayRef = Arc::new(UInt8Array::from(u8_values)); + + let f64_expected = vec!["0.0", "255.0"]; + assert_eq!( + f64_expected, + get_cast_values::(&u8_array, &DataType::Float64) + ); + + let f32_expected = vec!["0.0", "255.0"]; + assert_eq!( + f32_expected, + get_cast_values::(&u8_array, &DataType::Float32) + ); + + let f16_expected = vec!["0.0", "255.0"]; + assert_eq!( + f16_expected, + get_cast_values::(&u8_array, &DataType::Float16) + ); + + let i64_expected = vec!["0", "255"]; + assert_eq!( + i64_expected, + get_cast_values::(&u8_array, &DataType::Int64) + ); + + let i32_expected = vec!["0", "255"]; + assert_eq!( + i32_expected, + get_cast_values::(&u8_array, &DataType::Int32) + ); + + let i16_expected = vec!["0", "255"]; + assert_eq!( + i16_expected, + get_cast_values::(&u8_array, &DataType::Int16) + ); + + let i8_expected = vec!["0", "null"]; + assert_eq!( + i8_expected, + get_cast_values::(&u8_array, &DataType::Int8) + ); + + let u64_expected = vec!["0", "255"]; + assert_eq!( + u64_expected, + get_cast_values::(&u8_array, &DataType::UInt64) + ); + + let u32_expected = vec!["0", "255"]; + assert_eq!( + u32_expected, + get_cast_values::(&u8_array, &DataType::UInt32) + ); + + let u16_expected = vec!["0", "255"]; + assert_eq!( + u16_expected, + get_cast_values::(&u8_array, &DataType::UInt16) + ); + + let u8_expected = vec!["0", "255"]; + assert_eq!( + u8_expected, + get_cast_values::(&u8_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_int64() { + let i64_values: Vec = vec![ + i64::MIN, + i32::MIN as i64, + i16::MIN as i64, + i8::MIN as i64, + 0, + i8::MAX as i64, + i16::MAX as i64, + i32::MAX as i64, + i64::MAX, + ]; + let i64_array: ArrayRef = Arc::new(Int64Array::from(i64_values)); + + let f64_expected = vec![ + -9223372036854776000.0, + -2147483648.0, + -32768.0, + -128.0, + 0.0, + 127.0, + 32767.0, + 2147483647.0, + 9223372036854776000.0, + ]; + assert_eq!( + f64_expected, + get_cast_values::(&i64_array, &DataType::Float64) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let f32_expected = vec![ + -9223372000000000000.0, + -2147483600.0, + -32768.0, + -128.0, + 0.0, + 127.0, + 32767.0, + 2147483600.0, + 9223372000000000000.0, + ]; + assert_eq!( + f32_expected, + get_cast_values::(&i64_array, &DataType::Float32) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let f16_expected = vec![ + f16::from_f64(-9223372000000000000.0), + f16::from_f64(-2147483600.0), + f16::from_f64(-32768.0), + f16::from_f64(-128.0), + f16::from_f64(0.0), + f16::from_f64(127.0), + f16::from_f64(32767.0), + f16::from_f64(2147483600.0), + f16::from_f64(9223372000000000000.0), + ]; + assert_eq!( + f16_expected, + get_cast_values::(&i64_array, &DataType::Float16) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let i64_expected = vec![ + "-9223372036854775808", + "-2147483648", + "-32768", + "-128", + "0", + "127", + "32767", + "2147483647", + "9223372036854775807", + ]; + assert_eq!( + i64_expected, + get_cast_values::(&i64_array, &DataType::Int64) + ); + + let i32_expected = vec![ + "null", + "-2147483648", + "-32768", + "-128", + "0", + "127", + "32767", + "2147483647", + "null", + ]; + assert_eq!( + i32_expected, + get_cast_values::(&i64_array, &DataType::Int32) + ); + + assert_eq!( + i32_expected, + get_cast_values::(&i64_array, &DataType::Date32) + ); + + let i16_expected = vec![ + "null", "null", "-32768", "-128", "0", "127", "32767", "null", "null", + ]; + assert_eq!( + i16_expected, + get_cast_values::(&i64_array, &DataType::Int16) + ); + + let i8_expected = vec![ + "null", "null", "null", "-128", "0", "127", "null", "null", "null", + ]; + assert_eq!( + i8_expected, + get_cast_values::(&i64_array, &DataType::Int8) + ); + + let u64_expected = vec![ + "null", + "null", + "null", + "null", + "0", + "127", + "32767", + "2147483647", + "9223372036854775807", + ]; + assert_eq!( + u64_expected, + get_cast_values::(&i64_array, &DataType::UInt64) + ); + + let u32_expected = vec![ + "null", + "null", + "null", + "null", + "0", + "127", + "32767", + "2147483647", + "null", + ]; + assert_eq!( + u32_expected, + get_cast_values::(&i64_array, &DataType::UInt32) + ); + + let u16_expected = vec![ + "null", "null", "null", "null", "0", "127", "32767", "null", "null", + ]; + assert_eq!( + u16_expected, + get_cast_values::(&i64_array, &DataType::UInt16) + ); + + let u8_expected = vec![ + "null", "null", "null", "null", "0", "127", "null", "null", "null", + ]; + assert_eq!( + u8_expected, + get_cast_values::(&i64_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_int32() { + let i32_values: Vec = vec![ + i32::MIN, + i16::MIN as i32, + i8::MIN as i32, + 0, + i8::MAX as i32, + i16::MAX as i32, + i32::MAX, + ]; + let i32_array: ArrayRef = Arc::new(Int32Array::from(i32_values)); + + let f64_expected = vec![ + "-2147483648.0", + "-32768.0", + "-128.0", + "0.0", + "127.0", + "32767.0", + "2147483647.0", + ]; + assert_eq!( + f64_expected, + get_cast_values::(&i32_array, &DataType::Float64) + ); + + let f32_expected = vec![ + "-2147483600.0", + "-32768.0", + "-128.0", + "0.0", + "127.0", + "32767.0", + "2147483600.0", + ]; + assert_eq!( + f32_expected, + get_cast_values::(&i32_array, &DataType::Float32) + ); + + let f16_expected = vec![ + f16::from_f64(-2147483600.0), + f16::from_f64(-32768.0), + f16::from_f64(-128.0), + f16::from_f64(0.0), + f16::from_f64(127.0), + f16::from_f64(32767.0), + f16::from_f64(2147483600.0), + ]; + assert_eq!( + f16_expected, + get_cast_values::(&i32_array, &DataType::Float16) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let i16_expected = vec!["null", "-32768", "-128", "0", "127", "32767", "null"]; + assert_eq!( + i16_expected, + get_cast_values::(&i32_array, &DataType::Int16) + ); + + let i8_expected = vec!["null", "null", "-128", "0", "127", "null", "null"]; + assert_eq!( + i8_expected, + get_cast_values::(&i32_array, &DataType::Int8) + ); + + let u64_expected = vec!["null", "null", "null", "0", "127", "32767", "2147483647"]; + assert_eq!( + u64_expected, + get_cast_values::(&i32_array, &DataType::UInt64) + ); + + let u32_expected = vec!["null", "null", "null", "0", "127", "32767", "2147483647"]; + assert_eq!( + u32_expected, + get_cast_values::(&i32_array, &DataType::UInt32) + ); + + let u16_expected = vec!["null", "null", "null", "0", "127", "32767", "null"]; + assert_eq!( + u16_expected, + get_cast_values::(&i32_array, &DataType::UInt16) + ); + + let u8_expected = vec!["null", "null", "null", "0", "127", "null", "null"]; + assert_eq!( + u8_expected, + get_cast_values::(&i32_array, &DataType::UInt8) + ); + + // The date32 to date64 cast increases the numerical values in order to keep the same dates. + let i64_expected = vec![ + "-185542587187200000", + "-2831155200000", + "-11059200000", + "0", + "10972800000", + "2831068800000", + "185542587100800000", + ]; + assert_eq!( + i64_expected, + get_cast_values::(&i32_array, &DataType::Date64) + ); + } + + #[test] + fn test_cast_from_int16() { + let i16_values: Vec = vec![i16::MIN, i8::MIN as i16, 0, i8::MAX as i16, i16::MAX]; + let i16_array: ArrayRef = Arc::new(Int16Array::from(i16_values)); + + let f64_expected = vec!["-32768.0", "-128.0", "0.0", "127.0", "32767.0"]; + assert_eq!( + f64_expected, + get_cast_values::(&i16_array, &DataType::Float64) + ); + + let f32_expected = vec!["-32768.0", "-128.0", "0.0", "127.0", "32767.0"]; + assert_eq!( + f32_expected, + get_cast_values::(&i16_array, &DataType::Float32) + ); + + let f16_expected = vec![ + f16::from_f64(-32768.0), + f16::from_f64(-128.0), + f16::from_f64(0.0), + f16::from_f64(127.0), + f16::from_f64(32767.0), + ]; + assert_eq!( + f16_expected, + get_cast_values::(&i16_array, &DataType::Float16) + .iter() + .map(|i| i.parse::().unwrap()) + .collect::>() + ); + + let i64_expected = vec!["-32768", "-128", "0", "127", "32767"]; + assert_eq!( + i64_expected, + get_cast_values::(&i16_array, &DataType::Int64) + ); + + let i32_expected = vec!["-32768", "-128", "0", "127", "32767"]; + assert_eq!( + i32_expected, + get_cast_values::(&i16_array, &DataType::Int32) + ); + + let i16_expected = vec!["-32768", "-128", "0", "127", "32767"]; + assert_eq!( + i16_expected, + get_cast_values::(&i16_array, &DataType::Int16) + ); + + let i8_expected = vec!["null", "-128", "0", "127", "null"]; + assert_eq!( + i8_expected, + get_cast_values::(&i16_array, &DataType::Int8) + ); + + let u64_expected = vec!["null", "null", "0", "127", "32767"]; + assert_eq!( + u64_expected, + get_cast_values::(&i16_array, &DataType::UInt64) + ); + + let u32_expected = vec!["null", "null", "0", "127", "32767"]; + assert_eq!( + u32_expected, + get_cast_values::(&i16_array, &DataType::UInt32) + ); + + let u16_expected = vec!["null", "null", "0", "127", "32767"]; + assert_eq!( + u16_expected, + get_cast_values::(&i16_array, &DataType::UInt16) + ); + + let u8_expected = vec!["null", "null", "0", "127", "null"]; + assert_eq!( + u8_expected, + get_cast_values::(&i16_array, &DataType::UInt8) + ); + } + + #[test] + fn test_cast_from_date32() { + let i32_values: Vec = vec![ + i32::MIN, + i16::MIN as i32, + i8::MIN as i32, + 0, + i8::MAX as i32, + i16::MAX as i32, + i32::MAX, + ]; + let date32_array: ArrayRef = Arc::new(Date32Array::from(i32_values)); + + let i64_expected = vec![ + "-2147483648", + "-32768", + "-128", + "0", + "127", + "32767", + "2147483647", + ]; + assert_eq!( + i64_expected, + get_cast_values::(&date32_array, &DataType::Int64) + ); + } + + #[test] + fn test_cast_from_int8() { + let i8_values: Vec = vec![i8::MIN, 0, i8::MAX]; + let i8_array = Int8Array::from(i8_values); + + let f64_expected = vec!["-128.0", "0.0", "127.0"]; + assert_eq!( + f64_expected, + get_cast_values::(&i8_array, &DataType::Float64) + ); + + let f32_expected = vec!["-128.0", "0.0", "127.0"]; + assert_eq!( + f32_expected, + get_cast_values::(&i8_array, &DataType::Float32) + ); + + let f16_expected = vec!["-128.0", "0.0", "127.0"]; + assert_eq!( + f16_expected, + get_cast_values::(&i8_array, &DataType::Float16) + ); + + let i64_expected = vec!["-128", "0", "127"]; + assert_eq!( + i64_expected, + get_cast_values::(&i8_array, &DataType::Int64) + ); + + let i32_expected = vec!["-128", "0", "127"]; + assert_eq!( + i32_expected, + get_cast_values::(&i8_array, &DataType::Int32) + ); + + let i16_expected = vec!["-128", "0", "127"]; + assert_eq!( + i16_expected, + get_cast_values::(&i8_array, &DataType::Int16) + ); + + let i8_expected = vec!["-128", "0", "127"]; + assert_eq!( + i8_expected, + get_cast_values::(&i8_array, &DataType::Int8) + ); + + let u64_expected = vec!["null", "0", "127"]; + assert_eq!( + u64_expected, + get_cast_values::(&i8_array, &DataType::UInt64) + ); + + let u32_expected = vec!["null", "0", "127"]; + assert_eq!( + u32_expected, + get_cast_values::(&i8_array, &DataType::UInt32) + ); + + let u16_expected = vec!["null", "0", "127"]; + assert_eq!( + u16_expected, + get_cast_values::(&i8_array, &DataType::UInt16) + ); + + let u8_expected = vec!["null", "0", "127"]; + assert_eq!( + u8_expected, + get_cast_values::(&i8_array, &DataType::UInt8) + ); + } + + /// Convert `array` into a vector of strings by casting to data type dt + fn get_cast_values(array: &dyn Array, dt: &DataType) -> Vec + where + T: ArrowPrimitiveType, + { + let c = cast(array, dt).unwrap(); + let a = c.as_primitive::(); + let mut v: Vec = vec![]; + for i in 0..array.len() { + if a.is_null(i) { + v.push("null".to_string()) + } else { + v.push(format!("{:?}", a.value(i))); + } + } + v + } + + #[test] + fn test_cast_utf8_dict() { + // FROM a dictionary with of Utf8 values + use DataType::*; + + let mut builder = StringDictionaryBuilder::::new(); + builder.append("one").unwrap(); + builder.append_null(); + builder.append("three").unwrap(); + let array: ArrayRef = Arc::new(builder.finish()); + + let expected = vec!["one", "null", "three"]; + + // Test casting TO StringArray + let cast_type = Utf8; + let cast_array = cast(&array, &cast_type).expect("cast to UTF-8 failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + // Test casting TO Dictionary (with different index sizes) + + let cast_type = Dictionary(Box::new(Int16), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(Int32), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(Int64), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(UInt8), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(UInt16), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(UInt32), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + let cast_type = Dictionary(Box::new(UInt64), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + } + + #[test] + fn test_cast_dict_to_dict_bad_index_value_primitive() { + use DataType::*; + // test converting from an array that has indexes of a type + // that are out of bounds for a particular other kind of + // index. + + let mut builder = PrimitiveDictionaryBuilder::::new(); + + // add 200 distinct values (which can be stored by a + // dictionary indexed by int32, but not a dictionary indexed + // with int8) + for i in 0..200 { + builder.append(i).unwrap(); + } + let array: ArrayRef = Arc::new(builder.finish()); + + let cast_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + let res = cast(&array, &cast_type); + assert!(res.is_err()); + let actual_error = format!("{res:?}"); + let expected_error = "Could not convert 72 dictionary indexes from Int32 to Int8"; + assert!( + actual_error.contains(expected_error), + "did not find expected error '{actual_error}' in actual error '{expected_error}'" + ); + } + + #[test] + fn test_cast_dict_to_dict_bad_index_value_utf8() { + use DataType::*; + // Same test as test_cast_dict_to_dict_bad_index_value but use + // string values (and encode the expected behavior here); + + let mut builder = StringDictionaryBuilder::::new(); + + // add 200 distinct values (which can be stored by a + // dictionary indexed by int32, but not a dictionary indexed + // with int8) + for i in 0..200 { + let val = format!("val{i}"); + builder.append(&val).unwrap(); + } + let array = builder.finish(); + + let cast_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + let res = cast(&array, &cast_type); + assert!(res.is_err()); + let actual_error = format!("{res:?}"); + let expected_error = "Could not convert 72 dictionary indexes from Int32 to Int8"; + assert!( + actual_error.contains(expected_error), + "did not find expected error '{actual_error}' in actual error '{expected_error}'" + ); + } + + #[test] + fn test_cast_primitive_dict() { + // FROM a dictionary with of INT32 values + use DataType::*; + + let mut builder = PrimitiveDictionaryBuilder::::new(); + builder.append(1).unwrap(); + builder.append_null(); + builder.append(3).unwrap(); + let array: ArrayRef = Arc::new(builder.finish()); + + let expected = vec!["1", "null", "3"]; + + // Test casting TO PrimitiveArray, different dictionary type + let cast_array = cast(&array, &Utf8).expect("cast to UTF-8 failed"); + assert_eq!(array_to_strings(&cast_array), expected); + assert_eq!(cast_array.data_type(), &Utf8); + + let cast_array = cast(&array, &Int64).expect("cast to int64 failed"); + assert_eq!(array_to_strings(&cast_array), expected); + assert_eq!(cast_array.data_type(), &Int64); + } + + #[test] + fn test_cast_primitive_array_to_dict() { + use DataType::*; + + let mut builder = PrimitiveBuilder::::new(); + builder.append_value(1); + builder.append_null(); + builder.append_value(3); + let array: ArrayRef = Arc::new(builder.finish()); + + let expected = vec!["1", "null", "3"]; + + // Cast to a dictionary (same value type, Int32) + let cast_type = Dictionary(Box::new(UInt8), Box::new(Int32)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + + // Cast to a dictionary (different value type, Int8) + let cast_type = Dictionary(Box::new(UInt8), Box::new(Int8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + } + + #[test] + fn test_cast_string_array_to_dict() { + use DataType::*; + + let array = Arc::new(StringArray::from(vec![Some("one"), None, Some("three")])) as ArrayRef; + + let expected = vec!["one", "null", "three"]; + + // Cast to a dictionary (same value type, Utf8) + let cast_type = Dictionary(Box::new(UInt8), Box::new(Utf8)); + let cast_array = cast(&array, &cast_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(array_to_strings(&cast_array), expected); + } + + #[test] + fn test_cast_null_array_to_from_decimal_array() { + let data_type = DataType::Decimal128(12, 4); + let array = new_null_array(&DataType::Null, 4); + assert_eq!(array.data_type(), &DataType::Null); + let cast_array = cast(&array, &data_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), &data_type); + for i in 0..4 { + assert!(cast_array.is_null(i)); + } + + let array = new_null_array(&data_type, 4); + assert_eq!(array.data_type(), &data_type); + let cast_array = cast(&array, &DataType::Null).expect("cast failed"); + assert_eq!(cast_array.data_type(), &DataType::Null); + assert_eq!(cast_array.len(), 4); + assert_eq!(cast_array.logical_nulls().unwrap().null_count(), 4); + } + + #[test] + fn test_cast_null_array_from_and_to_primitive_array() { + macro_rules! typed_test { + ($ARR_TYPE:ident, $DATATYPE:ident, $TYPE:tt) => {{ + { + let array = Arc::new(NullArray::new(6)) as ArrayRef; + let expected = $ARR_TYPE::from(vec![None; 6]); + let cast_type = DataType::$DATATYPE; + let cast_array = cast(&array, &cast_type).expect("cast failed"); + let cast_array = cast_array.as_primitive::<$TYPE>(); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(cast_array, &expected); + } + }}; + } + + typed_test!(Int16Array, Int16, Int16Type); + typed_test!(Int32Array, Int32, Int32Type); + typed_test!(Int64Array, Int64, Int64Type); + + typed_test!(UInt16Array, UInt16, UInt16Type); + typed_test!(UInt32Array, UInt32, UInt32Type); + typed_test!(UInt64Array, UInt64, UInt64Type); + + typed_test!(Float32Array, Float32, Float32Type); + typed_test!(Float64Array, Float64, Float64Type); + + typed_test!(Date32Array, Date32, Date32Type); + typed_test!(Date64Array, Date64, Date64Type); + } + + fn cast_from_null_to_other(data_type: &DataType) { + // Cast from null to data_type + { + let array = new_null_array(&DataType::Null, 4); + assert_eq!(array.data_type(), &DataType::Null); + let cast_array = cast(&array, data_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), data_type); + for i in 0..4 { + assert!(cast_array.is_null(i)); + } + } + } + + #[test] + fn test_cast_null_from_and_to_variable_sized() { + cast_from_null_to_other(&DataType::Utf8); + cast_from_null_to_other(&DataType::LargeUtf8); + cast_from_null_to_other(&DataType::Binary); + cast_from_null_to_other(&DataType::LargeBinary); + } + + #[test] + fn test_cast_null_from_and_to_nested_type() { + // Cast null from and to map + let data_type = DataType::Map( + Arc::new(Field::new_struct( + "entry", + vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ], + false, + )), + false, + ); + cast_from_null_to_other(&data_type); + + // Cast null from and to list + let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + cast_from_null_to_other(&data_type); + let data_type = DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, true))); + cast_from_null_to_other(&data_type); + let data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 4); + cast_from_null_to_other(&data_type); + + // Cast null from and to dictionary + let values = vec![None, None, None, None] as Vec>; + let array: DictionaryArray = values.into_iter().collect(); + let array = Arc::new(array) as ArrayRef; + let data_type = array.data_type().to_owned(); + cast_from_null_to_other(&data_type); + + // Cast null from and to struct + let data_type = DataType::Struct(vec![Field::new("data", DataType::Int64, false)].into()); + cast_from_null_to_other(&data_type); + } + + /// Print the `DictionaryArray` `array` as a vector of strings + fn array_to_strings(array: &ArrayRef) -> Vec { + let options = FormatOptions::new().with_null("null"); + let formatter = ArrayFormatter::try_new(array.as_ref(), &options).unwrap(); + (0..array.len()) + .map(|i| formatter.value(i).to_string()) + .collect() + } + + #[test] + fn test_cast_utf8_to_date32() { + use chrono::NaiveDate; + let from_ymd = chrono::NaiveDate::from_ymd_opt; + let since = chrono::NaiveDate::signed_duration_since; + + let a = StringArray::from(vec![ + "2000-01-01", // valid date with leading 0s + "2000-2-2", // valid date without leading 0s + "2000-00-00", // invalid month and day + "2000-01-01T12:00:00", // date + time is invalid + "2000", // just a year is invalid + ]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Date32).unwrap(); + let c = b.as_primitive::(); + + // test valid inputs + let date_value = since( + NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(), + from_ymd(1970, 1, 1).unwrap(), + ) + .num_days() as i32; + assert!(c.is_valid(0)); // "2000-01-01" + assert_eq!(date_value, c.value(0)); + + let date_value = since( + NaiveDate::from_ymd_opt(2000, 2, 2).unwrap(), + from_ymd(1970, 1, 1).unwrap(), + ) + .num_days() as i32; + assert!(c.is_valid(1)); // "2000-2-2" + assert_eq!(date_value, c.value(1)); + + // test invalid inputs + assert!(!c.is_valid(2)); // "2000-00-00" + assert!(!c.is_valid(3)); // "2000-01-01T12:00:00" + assert!(!c.is_valid(4)); // "2000" + } + + #[test] + fn test_cast_utf8_to_date64() { + let a = StringArray::from(vec![ + "2000-01-01T12:00:00", // date + time valid + "2020-12-15T12:34:56", // date + time valid + "2020-2-2T12:34:56", // valid date time without leading 0s + "2000-00-00T12:00:00", // invalid month and day + "2000-01-01 12:00:00", // missing the 'T' + "2000-01-01", // just a date is invalid + ]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Date64).unwrap(); + let c = b.as_primitive::(); + + // test valid inputs + assert!(c.is_valid(0)); // "2000-01-01T12:00:00" + assert_eq!(946728000000, c.value(0)); + assert!(c.is_valid(1)); // "2020-12-15T12:34:56" + assert_eq!(1608035696000, c.value(1)); + assert!(!c.is_valid(2)); // "2020-2-2T12:34:56" + + assert!(!c.is_valid(3)); // "2000-00-00T12:00:00" + assert!(c.is_valid(4)); // "2000-01-01 12:00:00" + assert_eq!(946728000000, c.value(4)); + assert!(c.is_valid(5)); // "2000-01-01" + assert_eq!(946684800000, c.value(5)); + } + + #[test] + fn test_can_cast_types_fixed_size_list_to_list() { + // DataType::List + let array1 = Arc::new(make_fixed_size_list_array()) as ArrayRef; + assert!(can_cast_types( + array1.data_type(), + &DataType::List(Arc::new(Field::new("", DataType::Int32, false))) + )); + + // DataType::LargeList + let array2 = Arc::new(make_fixed_size_list_array_for_large_list()) as ArrayRef; + assert!(can_cast_types( + array2.data_type(), + &DataType::LargeList(Arc::new(Field::new("", DataType::Int64, false))) + )); + } + + #[test] + fn test_cast_fixed_size_list_to_list() { + // DataType::List + let array1 = Arc::new(make_fixed_size_list_array()) as ArrayRef; + let list_array1 = cast( + &array1, + &DataType::List(Arc::new(Field::new("", DataType::Int32, false))), + ) + .unwrap(); + let actual = list_array1.as_any().downcast_ref::().unwrap(); + let expected = array1 + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(expected.values(), actual.values()); + assert_eq!(expected.len(), actual.len()); + + // DataType::LargeList + let array2 = Arc::new(make_fixed_size_list_array_for_large_list()) as ArrayRef; + let list_array2 = cast( + &array2, + &DataType::LargeList(Arc::new(Field::new("", DataType::Int64, false))), + ) + .unwrap(); + let actual = list_array2 + .as_any() + .downcast_ref::() + .unwrap(); + let expected = array2 + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(expected.values(), actual.values()); + assert_eq!(expected.len(), actual.len()); + + // Cast previous LargeList to List + let array3 = Arc::new(actual.clone()) as ArrayRef; + let list_array3 = cast( + &array3, + &DataType::List(Arc::new(Field::new("", DataType::Int64, false))), + ) + .unwrap(); + let actual = list_array3.as_any().downcast_ref::().unwrap(); + let expected = array3.as_any().downcast_ref::().unwrap(); + assert_eq!(expected.values(), actual.values()); + } + + #[test] + fn test_cast_list_containers() { + // large-list to list + let array = Arc::new(make_large_list_array()) as ArrayRef; + let list_array = cast( + &array, + &DataType::List(Arc::new(Field::new("", DataType::Int32, false))), + ) + .unwrap(); + let actual = list_array.as_any().downcast_ref::().unwrap(); + let expected = array.as_any().downcast_ref::().unwrap(); + + assert_eq!(&expected.value(0), &actual.value(0)); + assert_eq!(&expected.value(1), &actual.value(1)); + assert_eq!(&expected.value(2), &actual.value(2)); + + // list to large-list + let array = Arc::new(make_list_array()) as ArrayRef; + let large_list_array = cast( + &array, + &DataType::LargeList(Arc::new(Field::new("", DataType::Int32, false))), + ) + .unwrap(); + let actual = large_list_array + .as_any() + .downcast_ref::() + .unwrap(); + let expected = array.as_any().downcast_ref::().unwrap(); + + assert_eq!(&expected.value(0), &actual.value(0)); + assert_eq!(&expected.value(1), &actual.value(1)); + assert_eq!(&expected.value(2), &actual.value(2)); + } + + #[test] + fn test_cast_list_to_fsl() { + // There four noteworthy cases we should handle: + // 1. No nulls + // 2. Nulls that are always empty + // 3. Nulls that have varying lengths + // 4. Nulls that are correctly sized (same as target list size) + + // Non-null case + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let values = vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5), Some(6)]), + ]; + let array = Arc::new(ListArray::from_iter_primitive::( + values.clone(), + )) as ArrayRef; + let expected = Arc::new(FixedSizeListArray::from_iter_primitive::( + values, 3, + )) as ArrayRef; + let actual = cast(array.as_ref(), &DataType::FixedSizeList(field.clone(), 3)).unwrap(); + assert_eq!(expected.as_ref(), actual.as_ref()); + + // Null cases + // Array is [[1, 2, 3], null, [4, 5, 6], null] + let cases = [ + ( + // Zero-length nulls + vec![1, 2, 3, 4, 5, 6], + vec![3, 0, 3, 0], + ), + ( + // Varying-length nulls + vec![1, 2, 3, 0, 0, 4, 5, 6, 0], + vec![3, 2, 3, 1], + ), + ( + // Correctly-sized nulls + vec![1, 2, 3, 0, 0, 0, 4, 5, 6, 0, 0, 0], + vec![3, 3, 3, 3], + ), + ( + // Mixed nulls + vec![1, 2, 3, 4, 5, 6, 0, 0, 0], + vec![3, 0, 3, 3], + ), + ]; + let null_buffer = NullBuffer::from(vec![true, false, true, false]); + + let expected = Arc::new(FixedSizeListArray::from_iter_primitive::( + vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5), Some(6)]), + None, + ], + 3, + )) as ArrayRef; + + for (values, lengths) in cases.iter() { + let array = Arc::new(ListArray::new( + field.clone(), + OffsetBuffer::from_lengths(lengths.clone()), + Arc::new(Int32Array::from(values.clone())), + Some(null_buffer.clone()), + )) as ArrayRef; + let actual = cast(array.as_ref(), &DataType::FixedSizeList(field.clone(), 3)).unwrap(); + assert_eq!(expected.as_ref(), actual.as_ref()); + } + } + + #[test] + fn test_cast_list_to_fsl_safety() { + let values = vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + Some(vec![Some(6), Some(7), Some(8), Some(9)]), + Some(vec![Some(3), Some(4), Some(5)]), + ]; + let array = Arc::new(ListArray::from_iter_primitive::( + values.clone(), + )) as ArrayRef; + + let res = cast_with_options( + array.as_ref(), + &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 3), + &CastOptions { + safe: false, + ..Default::default() + }, + ); + assert!(res.is_err()); + assert!(format!("{:?}", res) + .contains("Cannot cast to FixedSizeList(3): value at index 1 has length 2")); + + // When safe=true (default), the cast will fill nulls for lists that are + // too short and truncate lists that are too long. + let res = cast( + array.as_ref(), + &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 3), + ) + .unwrap(); + let expected = Arc::new(FixedSizeListArray::from_iter_primitive::( + vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, // Too short -> replaced with null + None, // Too long -> replaced with null + Some(vec![Some(3), Some(4), Some(5)]), + ], + 3, + )) as ArrayRef; + assert_eq!(expected.as_ref(), res.as_ref()); + } + + #[test] + fn test_cast_large_list_to_fsl() { + let values = vec![Some(vec![Some(1), Some(2)]), Some(vec![Some(3), Some(4)])]; + let array = Arc::new(LargeListArray::from_iter_primitive::( + values.clone(), + )) as ArrayRef; + let expected = Arc::new(FixedSizeListArray::from_iter_primitive::( + values, 2, + )) as ArrayRef; + let actual = cast( + array.as_ref(), + &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 2), + ) + .unwrap(); + assert_eq!(expected.as_ref(), actual.as_ref()); + } + + #[test] + fn test_cast_list_to_fsl_subcast() { + let array = Arc::new(LargeListArray::from_iter_primitive::( + vec![ + Some(vec![Some(1), Some(2)]), + Some(vec![Some(3), Some(i32::MAX)]), + ], + )) as ArrayRef; + let expected = Arc::new(FixedSizeListArray::from_iter_primitive::( + vec![ + Some(vec![Some(1), Some(2)]), + Some(vec![Some(3), Some(i32::MAX as i64)]), + ], + 2, + )) as ArrayRef; + let actual = cast( + array.as_ref(), + &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int64, true)), 2), + ) + .unwrap(); + assert_eq!(expected.as_ref(), actual.as_ref()); + + let res = cast_with_options( + array.as_ref(), + &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int16, true)), 2), + &CastOptions { + safe: false, + ..Default::default() + }, + ); + assert!(res.is_err()); + assert!(format!("{:?}", res).contains("Can't cast value 2147483647 to type Int16")); + } + + #[test] + fn test_cast_list_to_fsl_empty() { + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let array = new_empty_array(&DataType::List(field.clone())); + + let target_type = DataType::FixedSizeList(field.clone(), 3); + let expected = new_empty_array(&target_type); + + let actual = cast(array.as_ref(), &target_type).unwrap(); + assert_eq!(expected.as_ref(), actual.as_ref()); + } + + fn make_list_array() -> ListArray { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) + .build() + .unwrap(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); + + // Construct a list array from the above two + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build() + .unwrap(); + ListArray::from(list_data) + } + + fn make_large_list_array() -> LargeListArray { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) + .build() + .unwrap(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8]); + + // Construct a list array from the above two + let list_data_type = + DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, true))); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build() + .unwrap(); + LargeListArray::from(list_data) + } + + fn make_fixed_size_list_array() -> FixedSizeListArray { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) + .build() + .unwrap(); + + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 4); + let list_data = ArrayData::builder(list_data_type) + .len(2) + .add_child_data(value_data) + .build() + .unwrap(); + FixedSizeListArray::from(list_data) + } + + fn make_fixed_size_list_array_for_large_list() -> FixedSizeListArray { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int64) + .len(8) + .add_buffer(Buffer::from_slice_ref([0i64, 1, 2, 3, 4, 5, 6, 7])) + .build() + .unwrap(); + + let list_data_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int64, true)), 4); + let list_data = ArrayData::builder(list_data_type) + .len(2) + .add_child_data(value_data) + .build() + .unwrap(); + FixedSizeListArray::from(list_data) + } + + #[test] + fn test_utf8_cast_offsets() { + // test if offset of the array is taken into account during cast + let str_array = StringArray::from(vec!["a", "b", "c"]); + let str_array = str_array.slice(1, 2); + + let out = cast(&str_array, &DataType::LargeUtf8).unwrap(); + + let large_str_array = out.as_any().downcast_ref::().unwrap(); + let strs = large_str_array.into_iter().flatten().collect::>(); + assert_eq!(strs, &["b", "c"]) + } + + #[test] + fn test_list_cast_offsets() { + // test if offset of the array is taken into account during cast + let array1 = make_list_array().slice(1, 2); + let array2 = Arc::new(make_list_array()) as ArrayRef; + + let dt = DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, true))); + let out1 = cast(&array1, &dt).unwrap(); + let out2 = cast(&array2, &dt).unwrap(); + + assert_eq!(&out1, &out2.slice(1, 2)) + } + + #[test] + fn test_list_to_string() { + let str_array = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "g", "h"]); + let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); + let value_data = str_array.into_data(); + + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build() + .unwrap(); + let array = Arc::new(ListArray::from(list_data)) as ArrayRef; + + let out = cast(&array, &DataType::Utf8).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .flatten() + .collect::>(); + assert_eq!(&out, &vec!["[a, b, c]", "[d, e, f]", "[g, h]"]); + + let out = cast(&array, &DataType::LargeUtf8).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .flatten() + .collect::>(); + assert_eq!(&out, &vec!["[a, b, c]", "[d, e, f]", "[g, h]"]); + + let array = Arc::new(make_list_array()) as ArrayRef; + let out = cast(&array, &DataType::Utf8).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .flatten() + .collect::>(); + assert_eq!(&out, &vec!["[0, 1, 2]", "[3, 4, 5]", "[6, 7]"]); + + let array = Arc::new(make_large_list_array()) as ArrayRef; + let out = cast(&array, &DataType::LargeUtf8).unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .flatten() + .collect::>(); + assert_eq!(&out, &vec!["[0, 1, 2]", "[3, 4, 5]", "[6, 7]"]); + } + + #[test] + fn test_cast_f64_to_decimal128() { + // to reproduce https://github.com/apache/arrow-rs/issues/2997 + + let decimal_type = DataType::Decimal128(18, 2); + let array = Float64Array::from(vec![ + Some(0.0699999999), + Some(0.0659999999), + Some(0.0650000000), + Some(0.0649999999), + ]); + let array = Arc::new(array) as ArrayRef; + generate_cast_test_case!( + &array, + Decimal128Array, + &decimal_type, + vec![ + Some(7_i128), // round up + Some(7_i128), // round up + Some(7_i128), // round up + Some(6_i128), // round down + ] + ); + + let decimal_type = DataType::Decimal128(18, 3); + let array = Float64Array::from(vec![ + Some(0.0699999999), + Some(0.0659999999), + Some(0.0650000000), + Some(0.0649999999), + ]); + let array = Arc::new(array) as ArrayRef; + generate_cast_test_case!( + &array, + Decimal128Array, + &decimal_type, + vec![ + Some(70_i128), // round up + Some(66_i128), // round up + Some(65_i128), // round down + Some(65_i128), // round up + ] + ); + } + + #[test] + fn test_cast_numeric_to_decimal128_overflow() { + let array = Int64Array::from(vec![i64::MAX]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(38, 30), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(38, 30), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_err()); + } + + #[test] + fn test_cast_numeric_to_decimal256_overflow() { + let array = Int64Array::from(vec![i64::MAX]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(76, 76), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(76, 76), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_err()); + } + + #[test] + fn test_cast_floating_point_to_decimal128_precision_overflow() { + let array = Float64Array::from(vec![1.1]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(2, 2), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(2, 2), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + let err = casted_array.unwrap_err().to_string(); + let expected_error = "Invalid argument error: 110 is too large to store in a Decimal128 of precision 2. Max is 99"; + assert!( + err.contains(expected_error), + "did not find expected error '{expected_error}' in actual error '{err}'" + ); + } + + #[test] + fn test_cast_floating_point_to_decimal256_precision_overflow() { + let array = Float64Array::from(vec![1.1]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(2, 2), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(2, 2), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + let err = casted_array.unwrap_err().to_string(); + let expected_error = "Invalid argument error: 110 is too large to store in a Decimal256 of precision 2. Max is 99"; + assert!( + err.contains(expected_error), + "did not find expected error '{expected_error}' in actual error '{err}'" + ); + } + + #[test] + fn test_cast_floating_point_to_decimal128_overflow() { + let array = Float64Array::from(vec![f64::MAX]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(38, 30), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(38, 30), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + let err = casted_array.unwrap_err().to_string(); + let expected_error = "Cast error: Cannot cast to Decimal128(38, 30)"; + assert!( + err.contains(expected_error), + "did not find expected error '{expected_error}' in actual error '{err}'" + ); + } + + #[test] + fn test_cast_floating_point_to_decimal256_overflow() { + let array = Float64Array::from(vec![f64::MAX]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(76, 50), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(76, 50), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + let err = casted_array.unwrap_err().to_string(); + let expected_error = "Cast error: Cannot cast to Decimal256(76, 50)"; + assert!( + err.contains(expected_error), + "did not find expected error '{expected_error}' in actual error '{err}'" + ); + } + + #[test] + fn test_cast_decimal128_to_decimal128_negative_scale() { + let input_type = DataType::Decimal128(20, 0); + let output_type = DataType::Decimal128(20, -1); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123450), Some(2123455), Some(3123456), None]; + let input_decimal_array = create_decimal_array(array, 20, 0).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + generate_cast_test_case!( + &array, + Decimal128Array, + &output_type, + vec![ + Some(112345_i128), + Some(212346_i128), + Some(312346_i128), + None + ] + ); + + let casted_array = cast(&array, &output_type).unwrap(); + let decimal_arr = casted_array.as_primitive::(); + + assert_eq!("1123450", decimal_arr.value_as_string(0)); + assert_eq!("2123460", decimal_arr.value_as_string(1)); + assert_eq!("3123460", decimal_arr.value_as_string(2)); + } + + #[test] + fn test_cast_numeric_to_decimal128_negative() { + let decimal_type = DataType::Decimal128(38, -1); + let array = Arc::new(Int32Array::from(vec![ + Some(1123456), + Some(2123456), + Some(3123456), + ])) as ArrayRef; + + let casted_array = cast(&array, &decimal_type).unwrap(); + let decimal_arr = casted_array.as_primitive::(); + + assert_eq!("1123450", decimal_arr.value_as_string(0)); + assert_eq!("2123450", decimal_arr.value_as_string(1)); + assert_eq!("3123450", decimal_arr.value_as_string(2)); + + let array = Arc::new(Float32Array::from(vec![ + Some(1123.456), + Some(2123.456), + Some(3123.456), + ])) as ArrayRef; + + let casted_array = cast(&array, &decimal_type).unwrap(); + let decimal_arr = casted_array.as_primitive::(); + + assert_eq!("1120", decimal_arr.value_as_string(0)); + assert_eq!("2120", decimal_arr.value_as_string(1)); + assert_eq!("3120", decimal_arr.value_as_string(2)); + } + + #[test] + fn test_cast_decimal128_to_decimal128_negative() { + let input_type = DataType::Decimal128(10, -1); + let output_type = DataType::Decimal128(10, -2); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(123)]; + let input_decimal_array = create_decimal_array(array, 10, -1).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + generate_cast_test_case!(&array, Decimal128Array, &output_type, vec![Some(12_i128),]); + + let casted_array = cast(&array, &output_type).unwrap(); + let decimal_arr = casted_array.as_primitive::(); + + assert_eq!("1200", decimal_arr.value_as_string(0)); + + let array = vec![Some(125)]; + let input_decimal_array = create_decimal_array(array, 10, -1).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + generate_cast_test_case!(&array, Decimal128Array, &output_type, vec![Some(13_i128),]); + + let casted_array = cast(&array, &output_type).unwrap(); + let decimal_arr = casted_array.as_primitive::(); + + assert_eq!("1300", decimal_arr.value_as_string(0)); + } + + #[test] + fn test_cast_decimal128_to_decimal256_negative() { + let input_type = DataType::Decimal128(10, 3); + let output_type = DataType::Decimal256(10, 5); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(i128::MAX), Some(i128::MIN)]; + let input_decimal_array = create_decimal_array(array, 10, 3).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + + let hundred = i256::from_i128(100); + generate_cast_test_case!( + &array, + Decimal256Array, + &output_type, + vec![ + Some(i256::from_i128(i128::MAX).mul_wrapping(hundred)), + Some(i256::from_i128(i128::MIN).mul_wrapping(hundred)) + ] + ); + } + + #[test] + fn test_parse_string_to_decimal() { + assert_eq!( + Decimal128Type::format_decimal( + parse_string_to_decimal_native::("123.45", 2).unwrap(), + 38, + 2, + ), + "123.45" + ); + assert_eq!( + Decimal128Type::format_decimal( + parse_string_to_decimal_native::("12345", 2).unwrap(), + 38, + 2, + ), + "12345.00" + ); + assert_eq!( + Decimal128Type::format_decimal( + parse_string_to_decimal_native::("0.12345", 2).unwrap(), + 38, + 2, + ), + "0.12" + ); + assert_eq!( + Decimal128Type::format_decimal( + parse_string_to_decimal_native::(".12345", 2).unwrap(), + 38, + 2, + ), + "0.12" + ); + assert_eq!( + Decimal128Type::format_decimal( + parse_string_to_decimal_native::(".1265", 2).unwrap(), + 38, + 2, + ), + "0.13" + ); + assert_eq!( + Decimal128Type::format_decimal( + parse_string_to_decimal_native::(".1265", 2).unwrap(), + 38, + 2, + ), + "0.13" + ); + + assert_eq!( + Decimal256Type::format_decimal( + parse_string_to_decimal_native::("123.45", 3).unwrap(), + 38, + 3, + ), + "123.450" + ); + assert_eq!( + Decimal256Type::format_decimal( + parse_string_to_decimal_native::("12345", 3).unwrap(), + 38, + 3, + ), + "12345.000" + ); + assert_eq!( + Decimal256Type::format_decimal( + parse_string_to_decimal_native::("0.12345", 3).unwrap(), + 38, + 3, + ), + "0.123" + ); + assert_eq!( + Decimal256Type::format_decimal( + parse_string_to_decimal_native::(".12345", 3).unwrap(), + 38, + 3, + ), + "0.123" + ); + assert_eq!( + Decimal256Type::format_decimal( + parse_string_to_decimal_native::(".1265", 3).unwrap(), + 38, + 3, + ), + "0.127" + ); + } + + fn test_cast_string_to_decimal(array: ArrayRef) { + // Decimal128 + let output_type = DataType::Decimal128(38, 2); + assert!(can_cast_types(array.data_type(), &output_type)); + + let casted_array = cast(&array, &output_type).unwrap(); + let decimal_arr = casted_array.as_primitive::(); + + assert_eq!("123.45", decimal_arr.value_as_string(0)); + assert_eq!("1.23", decimal_arr.value_as_string(1)); + assert_eq!("0.12", decimal_arr.value_as_string(2)); + assert_eq!("0.13", decimal_arr.value_as_string(3)); + assert_eq!("1.26", decimal_arr.value_as_string(4)); + assert_eq!("12345.00", decimal_arr.value_as_string(5)); + assert_eq!("12345.00", decimal_arr.value_as_string(6)); + assert_eq!("0.12", decimal_arr.value_as_string(7)); + assert_eq!("12.23", decimal_arr.value_as_string(8)); + assert!(decimal_arr.is_null(9)); + assert_eq!("0.00", decimal_arr.value_as_string(10)); + assert_eq!("0.00", decimal_arr.value_as_string(11)); + assert!(decimal_arr.is_null(12)); + assert_eq!("-1.23", decimal_arr.value_as_string(13)); + assert_eq!("-1.24", decimal_arr.value_as_string(14)); + assert_eq!("0.00", decimal_arr.value_as_string(15)); + assert_eq!("-123.00", decimal_arr.value_as_string(16)); + assert_eq!("-123.23", decimal_arr.value_as_string(17)); + assert_eq!("-0.12", decimal_arr.value_as_string(18)); + assert_eq!("1.23", decimal_arr.value_as_string(19)); + assert_eq!("1.24", decimal_arr.value_as_string(20)); + assert_eq!("0.00", decimal_arr.value_as_string(21)); + assert_eq!("123.00", decimal_arr.value_as_string(22)); + assert_eq!("123.23", decimal_arr.value_as_string(23)); + assert_eq!("0.12", decimal_arr.value_as_string(24)); + assert!(decimal_arr.is_null(25)); + assert!(decimal_arr.is_null(26)); + assert!(decimal_arr.is_null(27)); + + // Decimal256 + let output_type = DataType::Decimal256(76, 3); + assert!(can_cast_types(array.data_type(), &output_type)); + + let casted_array = cast(&array, &output_type).unwrap(); + let decimal_arr = casted_array.as_primitive::(); + + assert_eq!("123.450", decimal_arr.value_as_string(0)); + assert_eq!("1.235", decimal_arr.value_as_string(1)); + assert_eq!("0.123", decimal_arr.value_as_string(2)); + assert_eq!("0.127", decimal_arr.value_as_string(3)); + assert_eq!("1.263", decimal_arr.value_as_string(4)); + assert_eq!("12345.000", decimal_arr.value_as_string(5)); + assert_eq!("12345.000", decimal_arr.value_as_string(6)); + assert_eq!("0.123", decimal_arr.value_as_string(7)); + assert_eq!("12.234", decimal_arr.value_as_string(8)); + assert!(decimal_arr.is_null(9)); + assert_eq!("0.000", decimal_arr.value_as_string(10)); + assert_eq!("0.000", decimal_arr.value_as_string(11)); + assert!(decimal_arr.is_null(12)); + assert_eq!("-1.235", decimal_arr.value_as_string(13)); + assert_eq!("-1.236", decimal_arr.value_as_string(14)); + assert_eq!("0.000", decimal_arr.value_as_string(15)); + assert_eq!("-123.000", decimal_arr.value_as_string(16)); + assert_eq!("-123.234", decimal_arr.value_as_string(17)); + assert_eq!("-0.123", decimal_arr.value_as_string(18)); + assert_eq!("1.235", decimal_arr.value_as_string(19)); + assert_eq!("1.236", decimal_arr.value_as_string(20)); + assert_eq!("0.000", decimal_arr.value_as_string(21)); + assert_eq!("123.000", decimal_arr.value_as_string(22)); + assert_eq!("123.234", decimal_arr.value_as_string(23)); + assert_eq!("0.123", decimal_arr.value_as_string(24)); + assert!(decimal_arr.is_null(25)); + assert!(decimal_arr.is_null(26)); + assert!(decimal_arr.is_null(27)); + } + + #[test] + fn test_cast_utf8_to_decimal() { + let str_array = StringArray::from(vec![ + Some("123.45"), + Some("1.2345"), + Some("0.12345"), + Some("0.1267"), + Some("1.263"), + Some("12345.0"), + Some("12345"), + Some("000.123"), + Some("12.234000"), + None, + Some(""), + Some(" "), + None, + Some("-1.23499999"), + Some("-1.23599999"), + Some("-0.00001"), + Some("-123"), + Some("-123.234000"), + Some("-000.123"), + Some("+1.23499999"), + Some("+1.23599999"), + Some("+0.00001"), + Some("+123"), + Some("+123.234000"), + Some("+000.123"), + Some("1.-23499999"), + Some("-1.-23499999"), + Some("--1.23499999"), + ]); + let array = Arc::new(str_array) as ArrayRef; + + test_cast_string_to_decimal(array); + } + + #[test] + fn test_cast_large_utf8_to_decimal() { + let str_array = LargeStringArray::from(vec![ + Some("123.45"), + Some("1.2345"), + Some("0.12345"), + Some("0.1267"), + Some("1.263"), + Some("12345.0"), + Some("12345"), + Some("000.123"), + Some("12.234000"), + None, + Some(""), + Some(" "), + None, + Some("-1.23499999"), + Some("-1.23599999"), + Some("-0.00001"), + Some("-123"), + Some("-123.234000"), + Some("-000.123"), + Some("+1.23499999"), + Some("+1.23599999"), + Some("+0.00001"), + Some("+123"), + Some("+123.234000"), + Some("+000.123"), + Some("1.-23499999"), + Some("-1.-23499999"), + Some("--1.23499999"), + ]); + let array = Arc::new(str_array) as ArrayRef; + + test_cast_string_to_decimal(array); + } + + #[test] + fn test_cast_invalid_utf8_to_decimal() { + let str_array = StringArray::from(vec!["4.4.5", ". 0.123"]); + let array = Arc::new(str_array) as ArrayRef; + + // Safe cast + let output_type = DataType::Decimal128(38, 2); + let casted_array = cast(&array, &output_type).unwrap(); + assert!(casted_array.is_null(0)); + assert!(casted_array.is_null(1)); + + let output_type = DataType::Decimal256(76, 2); + let casted_array = cast(&array, &output_type).unwrap(); + assert!(casted_array.is_null(0)); + assert!(casted_array.is_null(1)); + + // Non-safe cast + let output_type = DataType::Decimal128(38, 2); + let str_array = StringArray::from(vec!["4.4.5"]); + let array = Arc::new(str_array) as ArrayRef; + let option = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + let casted_err = cast_with_options(&array, &output_type, &option).unwrap_err(); + assert!(casted_err + .to_string() + .contains("Cannot cast string '4.4.5' to value of Decimal128(38, 10) type")); + + let str_array = StringArray::from(vec![". 0.123"]); + let array = Arc::new(str_array) as ArrayRef; + let casted_err = cast_with_options(&array, &output_type, &option).unwrap_err(); + assert!(casted_err + .to_string() + .contains("Cannot cast string '. 0.123' to value of Decimal128(38, 10) type")); + } + + fn test_cast_string_to_decimal128_overflow(overflow_array: ArrayRef) { + let output_type = DataType::Decimal128(38, 2); + let casted_array = cast(&overflow_array, &output_type).unwrap(); + let decimal_arr = casted_array.as_primitive::(); + + assert!(decimal_arr.is_null(0)); + assert!(decimal_arr.is_null(1)); + assert!(decimal_arr.is_null(2)); + assert_eq!( + "999999999999999999999999999999999999.99", + decimal_arr.value_as_string(3) + ); + assert_eq!( + "100000000000000000000000000000000000.00", + decimal_arr.value_as_string(4) + ); + } + + #[test] + fn test_cast_string_to_decimal128_precision_overflow() { + let array = StringArray::from(vec!["1000".to_string()]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(10, 8), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let err = cast_with_options( + &array, + &DataType::Decimal128(10, 8), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal128 of precision 10. Max is 9999999999", err.unwrap_err().to_string()); + } + + #[test] + fn test_cast_utf8_to_decimal128_overflow() { + let overflow_str_array = StringArray::from(vec![ + i128::MAX.to_string(), + i128::MIN.to_string(), + "99999999999999999999999999999999999999".to_string(), + "999999999999999999999999999999999999.99".to_string(), + "99999999999999999999999999999999999.999".to_string(), + ]); + let overflow_array = Arc::new(overflow_str_array) as ArrayRef; + + test_cast_string_to_decimal128_overflow(overflow_array); + } + + #[test] + fn test_cast_large_utf8_to_decimal128_overflow() { + let overflow_str_array = LargeStringArray::from(vec![ + i128::MAX.to_string(), + i128::MIN.to_string(), + "99999999999999999999999999999999999999".to_string(), + "999999999999999999999999999999999999.99".to_string(), + "99999999999999999999999999999999999.999".to_string(), + ]); + let overflow_array = Arc::new(overflow_str_array) as ArrayRef; + + test_cast_string_to_decimal128_overflow(overflow_array); + } + + fn test_cast_string_to_decimal256_overflow(overflow_array: ArrayRef) { + let output_type = DataType::Decimal256(76, 2); + let casted_array = cast(&overflow_array, &output_type).unwrap(); + let decimal_arr = casted_array.as_primitive::(); + + assert_eq!( + "170141183460469231731687303715884105727.00", + decimal_arr.value_as_string(0) + ); + assert_eq!( + "-170141183460469231731687303715884105728.00", + decimal_arr.value_as_string(1) + ); + assert_eq!( + "99999999999999999999999999999999999999.00", + decimal_arr.value_as_string(2) + ); + assert_eq!( + "999999999999999999999999999999999999.99", + decimal_arr.value_as_string(3) + ); + assert_eq!( + "100000000000000000000000000000000000.00", + decimal_arr.value_as_string(4) + ); + assert!(decimal_arr.is_null(5)); + assert!(decimal_arr.is_null(6)); + } + + #[test] + fn test_cast_string_to_decimal256_precision_overflow() { + let array = StringArray::from(vec!["1000".to_string()]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(10, 8), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let err = cast_with_options( + &array, + &DataType::Decimal256(10, 8), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal256 of precision 10. Max is 9999999999", err.unwrap_err().to_string()); + } + + #[test] + fn test_cast_utf8_to_decimal256_overflow() { + let overflow_str_array = StringArray::from(vec![ + i128::MAX.to_string(), + i128::MIN.to_string(), + "99999999999999999999999999999999999999".to_string(), + "999999999999999999999999999999999999.99".to_string(), + "99999999999999999999999999999999999.999".to_string(), + i256::MAX.to_string(), + i256::MIN.to_string(), + ]); + let overflow_array = Arc::new(overflow_str_array) as ArrayRef; + + test_cast_string_to_decimal256_overflow(overflow_array); + } + + #[test] + fn test_cast_large_utf8_to_decimal256_overflow() { + let overflow_str_array = LargeStringArray::from(vec![ + i128::MAX.to_string(), + i128::MIN.to_string(), + "99999999999999999999999999999999999999".to_string(), + "999999999999999999999999999999999999.99".to_string(), + "99999999999999999999999999999999999.999".to_string(), + i256::MAX.to_string(), + i256::MIN.to_string(), + ]); + let overflow_array = Arc::new(overflow_str_array) as ArrayRef; + + test_cast_string_to_decimal256_overflow(overflow_array); + } + + #[test] + fn test_cast_date32_to_timestamp() { + let a = Date32Array::from(vec![Some(18628), Some(18993), None]); // 2021-1-1, 2022-1-1 + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Timestamp(TimeUnit::Second, None)).unwrap(); + let c = b.as_primitive::(); + assert_eq!(1609459200, c.value(0)); + assert_eq!(1640995200, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_date32_to_timestamp_ms() { + let a = Date32Array::from(vec![Some(18628), Some(18993), None]); // 2021-1-1, 2022-1-1 + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Timestamp(TimeUnit::Millisecond, None)).unwrap(); + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(1609459200000, c.value(0)); + assert_eq!(1640995200000, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_date32_to_timestamp_us() { + let a = Date32Array::from(vec![Some(18628), Some(18993), None]); // 2021-1-1, 2022-1-1 + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(1609459200000000, c.value(0)); + assert_eq!(1640995200000000, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_cast_date32_to_timestamp_ns() { + let a = Date32Array::from(vec![Some(18628), Some(18993), None]); // 2021-1-1, 2022-1-1 + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Timestamp(TimeUnit::Nanosecond, None)).unwrap(); + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(1609459200000000000, c.value(0)); + assert_eq!(1640995200000000000, c.value(1)); + assert!(c.is_null(2)); + } + + #[test] + fn test_timezone_cast() { + let a = StringArray::from(vec![ + "2000-01-01T12:00:00", // date + time valid + "2020-12-15T12:34:56", // date + time valid + ]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Timestamp(TimeUnit::Nanosecond, None)).unwrap(); + let v = b.as_primitive::(); + + assert_eq!(v.value(0), 946728000000000000); + assert_eq!(v.value(1), 1608035696000000000); + + let b = cast( + &b, + &DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())), + ) + .unwrap(); + let v = b.as_primitive::(); + + assert_eq!(v.value(0), 946728000000000000); + assert_eq!(v.value(1), 1608035696000000000); + + let b = cast( + &b, + &DataType::Timestamp(TimeUnit::Millisecond, Some("+02:00".into())), + ) + .unwrap(); + let v = b.as_primitive::(); + + assert_eq!(v.value(0), 946728000000); + assert_eq!(v.value(1), 1608035696000); + } + + #[test] + fn test_cast_utf8_to_timestamp() { + fn test_tz(tz: Arc) { + let valid = StringArray::from(vec![ + "2023-01-01 04:05:06.789000-08:00", + "2023-01-01 04:05:06.789000-07:00", + "2023-01-01 04:05:06.789 -0800", + "2023-01-01 04:05:06.789 -08:00", + "2023-01-01 040506 +0730", + "2023-01-01 040506 +07:30", + "2023-01-01 04:05:06.789", + "2023-01-01 04:05:06", + "2023-01-01", + ]); + + let array = Arc::new(valid) as ArrayRef; + let b = cast_with_options( + &array, + &DataType::Timestamp(TimeUnit::Nanosecond, Some(tz.clone())), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ) + .unwrap(); + + let tz = tz.as_ref().parse().unwrap(); + + let as_tz = + |v: i64| as_datetime_with_timezone::(v, tz).unwrap(); + + let as_utc = |v: &i64| as_tz(*v).naive_utc().to_string(); + let as_local = |v: &i64| as_tz(*v).naive_local().to_string(); + + let values = b.as_primitive::().values(); + let utc_results: Vec<_> = values.iter().map(as_utc).collect(); + let local_results: Vec<_> = values.iter().map(as_local).collect(); + + // Absolute timestamps should be parsed preserving the same UTC instant + assert_eq!( + &utc_results[..6], + &[ + "2023-01-01 12:05:06.789".to_string(), + "2023-01-01 11:05:06.789".to_string(), + "2023-01-01 12:05:06.789".to_string(), + "2023-01-01 12:05:06.789".to_string(), + "2022-12-31 20:35:06".to_string(), + "2022-12-31 20:35:06".to_string(), + ] + ); + // Non-absolute timestamps should be parsed preserving the same local instant + assert_eq!( + &local_results[6..], + &[ + "2023-01-01 04:05:06.789".to_string(), + "2023-01-01 04:05:06".to_string(), + "2023-01-01 00:00:00".to_string() + ] + ) + } + + test_tz("+00:00".into()); + test_tz("+02:00".into()); + } + + #[test] + fn test_cast_invalid_utf8() { + let v1: &[u8] = b"\xFF invalid"; + let v2: &[u8] = b"\x00 Foo"; + let s = BinaryArray::from(vec![v1, v2]); + let options = CastOptions { + safe: true, + format_options: FormatOptions::default(), + }; + let array = cast_with_options(&s, &DataType::Utf8, &options).unwrap(); + let a = array.as_string::(); + a.to_data().validate_full().unwrap(); + + assert_eq!(a.null_count(), 1); + assert_eq!(a.len(), 2); + assert!(a.is_null(0)); + assert_eq!(a.value(0), ""); + assert_eq!(a.value(1), "\x00 Foo"); + } + + #[test] + fn test_cast_utf8_to_timestamptz() { + let valid = StringArray::from(vec!["2023-01-01"]); + + let array = Arc::new(valid) as ArrayRef; + let b = cast( + &array, + &DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())), + ) + .unwrap(); + + let expect = DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())); + + assert_eq!(b.data_type(), &expect); + let c = b + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(1672531200000000000, c.value(0)); + } + + #[test] + fn test_cast_decimal_to_utf8() { + fn test_decimal_to_string( + output_type: DataType, + array: PrimitiveArray, + ) { + let b = cast(&array, &output_type).unwrap(); + + assert_eq!(b.data_type(), &output_type); + let c = b.as_string::(); + + assert_eq!("1123.454", c.value(0)); + assert_eq!("2123.456", c.value(1)); + assert_eq!("-3123.453", c.value(2)); + assert_eq!("-3123.456", c.value(3)); + assert_eq!("0.000", c.value(4)); + assert_eq!("0.123", c.value(5)); + assert_eq!("1234.567", c.value(6)); + assert_eq!("-1234.567", c.value(7)); + assert!(c.is_null(8)); + } + let array128: Vec> = vec![ + Some(1123454), + Some(2123456), + Some(-3123453), + Some(-3123456), + Some(0), + Some(123), + Some(123456789), + Some(-123456789), + None, + ]; + + let array256: Vec> = array128.iter().map(|v| v.map(i256::from_i128)).collect(); + + test_decimal_to_string::( + DataType::Utf8, + create_decimal_array(array128.clone(), 7, 3).unwrap(), + ); + test_decimal_to_string::( + DataType::LargeUtf8, + create_decimal_array(array128, 7, 3).unwrap(), + ); + test_decimal_to_string::( + DataType::Utf8, + create_decimal256_array(array256.clone(), 7, 3).unwrap(), + ); + test_decimal_to_string::( + DataType::LargeUtf8, + create_decimal256_array(array256, 7, 3).unwrap(), + ); + } + + #[test] + fn test_cast_numeric_to_decimal128_precision_overflow() { + let array = Int64Array::from(vec![1234567]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal128(7, 3), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let err = cast_with_options( + &array, + &DataType::Decimal128(7, 3), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!("Invalid argument error: 1234567000 is too large to store in a Decimal128 of precision 7. Max is 9999999", err.unwrap_err().to_string()); + } + + #[test] + fn test_cast_numeric_to_decimal256_precision_overflow() { + let array = Int64Array::from(vec![1234567]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Decimal256(7, 3), + &CastOptions { + safe: true, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_ok()); + assert!(casted_array.unwrap().is_null(0)); + + let err = cast_with_options( + &array, + &DataType::Decimal256(7, 3), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!("Invalid argument error: 1234567000 is too large to store in a Decimal256 of precision 7. Max is 9999999", err.unwrap_err().to_string()); + } + + /// helper function to test casting from duration to interval + fn cast_from_duration_to_interval>( + array: Vec, + cast_options: &CastOptions, + ) -> Result, ArrowError> { + let array = PrimitiveArray::::new(array.into(), None); + let array = Arc::new(array) as ArrayRef; + let interval = DataType::Interval(IntervalUnit::MonthDayNano); + let out = cast_with_options(&array, &interval, cast_options)?; + let out = out.as_primitive::().clone(); + Ok(out) + } + + #[test] + fn test_cast_from_duration_to_interval() { + // from duration second to interval month day nano + let array = vec![1234567]; + let casted_array = + cast_from_duration_to_interval::(array, &CastOptions::default()) + .unwrap(); + assert_eq!( + casted_array.data_type(), + &DataType::Interval(IntervalUnit::MonthDayNano) + ); + assert_eq!(casted_array.value(0), 1234567000000000); + + let array = vec![i64::MAX]; + let casted_array = cast_from_duration_to_interval::( + array.clone(), + &CastOptions::default(), + ) + .unwrap(); + assert!(!casted_array.is_valid(0)); + + let casted_array = cast_from_duration_to_interval::( + array, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_err()); + + // from duration millisecond to interval month day nano + let array = vec![1234567]; + let casted_array = cast_from_duration_to_interval::( + array, + &CastOptions::default(), + ) + .unwrap(); + assert_eq!( + casted_array.data_type(), + &DataType::Interval(IntervalUnit::MonthDayNano) + ); + assert_eq!(casted_array.value(0), 1234567000000); + + let array = vec![i64::MAX]; + let casted_array = cast_from_duration_to_interval::( + array.clone(), + &CastOptions::default(), + ) + .unwrap(); + assert!(!casted_array.is_valid(0)); + + let casted_array = cast_from_duration_to_interval::( + array, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_err()); + + // from duration microsecond to interval month day nano + let array = vec![1234567]; + let casted_array = cast_from_duration_to_interval::( + array, + &CastOptions::default(), + ) + .unwrap(); + assert_eq!( + casted_array.data_type(), + &DataType::Interval(IntervalUnit::MonthDayNano) + ); + assert_eq!(casted_array.value(0), 1234567000); + + let array = vec![i64::MAX]; + let casted_array = cast_from_duration_to_interval::( + array.clone(), + &CastOptions::default(), + ) + .unwrap(); + assert!(!casted_array.is_valid(0)); + + let casted_array = cast_from_duration_to_interval::( + array, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert!(casted_array.is_err()); + + // from duration nanosecond to interval month day nano + let array = vec![1234567]; + let casted_array = cast_from_duration_to_interval::( + array, + &CastOptions::default(), + ) + .unwrap(); + assert_eq!( + casted_array.data_type(), + &DataType::Interval(IntervalUnit::MonthDayNano) + ); + assert_eq!(casted_array.value(0), 1234567); + + let array = vec![i64::MAX]; + let casted_array = cast_from_duration_to_interval::( + array, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ) + .unwrap(); + assert_eq!(casted_array.value(0), 9223372036854775807); + } + + /// helper function to test casting from interval to duration + fn cast_from_interval_to_duration( + array: &IntervalMonthDayNanoArray, + cast_options: &CastOptions, + ) -> Result, ArrowError> { + let casted_array = cast_with_options(&array, &T::DATA_TYPE, cast_options)?; + casted_array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::ComputeError(format!("Failed to downcast to {}", T::DATA_TYPE)) + }) + .cloned() + } + + #[test] + fn test_cast_from_interval_to_duration() { + let nullable = CastOptions::default(); + let fallible = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + + // from interval month day nano to duration second + let array = vec![1234567].into(); + let casted_array: DurationSecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); + assert_eq!(casted_array.value(0), 0); + + let array = vec![i128::MAX].into(); + let casted_array: DurationSecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); + assert!(!casted_array.is_valid(0)); + + let res = cast_from_interval_to_duration::(&array, &fallible); + assert!(res.is_err()); + + // from interval month day nano to duration millisecond + let array = vec![1234567].into(); + let casted_array: DurationMillisecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); + assert_eq!(casted_array.value(0), 1); + + let array = vec![i128::MAX].into(); + let casted_array: DurationMillisecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); + assert!(!casted_array.is_valid(0)); + + let res = cast_from_interval_to_duration::(&array, &fallible); + assert!(res.is_err()); + + // from interval month day nano to duration microsecond + let array = vec![1234567].into(); + let casted_array: DurationMicrosecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); + assert_eq!(casted_array.value(0), 1234); + + let array = vec![i128::MAX].into(); + let casted_array = + cast_from_interval_to_duration::(&array, &nullable).unwrap(); + assert!(!casted_array.is_valid(0)); + + let casted_array = + cast_from_interval_to_duration::(&array, &fallible); + assert!(casted_array.is_err()); + + // from interval month day nano to duration nanosecond + let array = vec![1234567].into(); + let casted_array: DurationNanosecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); + assert_eq!(casted_array.value(0), 1234567); + + let array = vec![i128::MAX].into(); + let casted_array: DurationNanosecondArray = + cast_from_interval_to_duration(&array, &nullable).unwrap(); + assert!(!casted_array.is_valid(0)); + + let casted_array = + cast_from_interval_to_duration::(&array, &fallible); + assert!(casted_array.is_err()); + + let array = vec![ + IntervalMonthDayNanoType::make_value(0, 1, 0), + IntervalMonthDayNanoType::make_value(-1, 0, 0), + IntervalMonthDayNanoType::make_value(1, 1, 0), + IntervalMonthDayNanoType::make_value(1, 0, 1), + IntervalMonthDayNanoType::make_value(0, 0, -1), + ] + .into(); + let casted_array = + cast_from_interval_to_duration::(&array, &nullable).unwrap(); + assert!(!casted_array.is_valid(0)); + assert!(!casted_array.is_valid(1)); + assert!(!casted_array.is_valid(2)); + assert!(!casted_array.is_valid(3)); + assert!(casted_array.is_valid(4)); + assert_eq!(casted_array.value(4), -1); + } + + /// helper function to test casting from interval year month to interval month day nano + fn cast_from_interval_year_month_to_interval_month_day_nano( + array: Vec, + cast_options: &CastOptions, + ) -> Result, ArrowError> { + let array = PrimitiveArray::::from(array); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Interval(IntervalUnit::MonthDayNano), + cast_options, + )?; + casted_array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ComputeError( + "Failed to downcast to IntervalMonthDayNanoArray".to_string(), + ) + }) + .cloned() + } + + #[test] + fn test_cast_from_interval_year_month_to_interval_month_day_nano() { + // from interval year month to interval month day nano + let array = vec![1234567]; + let casted_array = cast_from_interval_year_month_to_interval_month_day_nano( + array, + &CastOptions::default(), + ) + .unwrap(); + assert_eq!( + casted_array.data_type(), + &DataType::Interval(IntervalUnit::MonthDayNano) + ); + assert_eq!(casted_array.value(0), 97812474910747780469848774134464512); + } + + /// helper function to test casting from interval day time to interval month day nano + fn cast_from_interval_day_time_to_interval_month_day_nano( + array: Vec, + cast_options: &CastOptions, + ) -> Result, ArrowError> { + let array = PrimitiveArray::::from(array); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Interval(IntervalUnit::MonthDayNano), + cast_options, + )?; + Ok(casted_array + .as_primitive::() + .clone()) + } + + #[test] + fn test_cast_from_interval_day_time_to_interval_month_day_nano() { + // from interval day time to interval month day nano + let array = vec![123]; + let casted_array = + cast_from_interval_day_time_to_interval_month_day_nano(array, &CastOptions::default()) + .unwrap(); + assert_eq!( + casted_array.data_type(), + &DataType::Interval(IntervalUnit::MonthDayNano) + ); + assert_eq!(casted_array.value(0), 123000000); + } + + #[test] + fn test_cast_below_unixtimestamp() { + let valid = StringArray::from(vec![ + "1900-01-03 23:59:59", + "1969-12-31 00:00:01", + "1989-12-31 00:00:01", + ]); + + let array = Arc::new(valid) as ArrayRef; + let casted_array = cast_with_options( + &array, + &DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ) + .unwrap(); + + let ts_array = casted_array + .as_primitive::() + .values() + .iter() + .map(|ts| ts / 1_000_000) + .collect::>(); + + let array = TimestampMillisecondArray::from(ts_array).with_timezone("UTC".to_string()); + let casted_array = cast(&array, &DataType::Date32).unwrap(); + let date_array = casted_array.as_primitive::(); + let casted_array = cast(&date_array, &DataType::Utf8).unwrap(); + let string_array = casted_array.as_string::(); + assert_eq!("1900-01-03", string_array.value(0)); + assert_eq!("1969-12-31", string_array.value(1)); + assert_eq!("1989-12-31", string_array.value(2)); + } + + #[test] + fn test_nested_list() { + let mut list = ListBuilder::new(Int32Builder::new()); + list.append_value([Some(1), Some(2), Some(3)]); + list.append_value([Some(4), None, Some(6)]); + let list = list.finish(); + + let to_field = Field::new("nested", list.data_type().clone(), false); + let to = DataType::List(Arc::new(to_field)); + let out = cast(&list, &to).unwrap(); + let opts = FormatOptions::default().with_null("null"); + let formatted = ArrayFormatter::try_new(out.as_ref(), &opts).unwrap(); + + assert_eq!(formatted.value(0).to_string(), "[[1], [2], [3]]"); + assert_eq!(formatted.value(1).to_string(), "[[4], [null], [6]]"); + } + + #[test] + fn test_nested_list_cast() { + let mut builder = ListBuilder::new(ListBuilder::new(Int32Builder::new())); + builder.append_value([Some([Some(1), Some(2), None]), None]); + builder.append_value([None, Some([]), None]); + builder.append_null(); + builder.append_value([Some([Some(2), Some(3)])]); + let start = builder.finish(); + + let mut builder = LargeListBuilder::new(LargeListBuilder::new(Int8Builder::new())); + builder.append_value([Some([Some(1), Some(2), None]), None]); + builder.append_value([None, Some([]), None]); + builder.append_null(); + builder.append_value([Some([Some(2), Some(3)])]); + let expected = builder.finish(); + + let actual = cast(&start, expected.data_type()).unwrap(); + assert_eq!(actual.as_ref(), &expected); + } + + const CAST_OPTIONS: CastOptions<'static> = CastOptions { + safe: true, + format_options: FormatOptions::new(), + }; + + #[test] + #[allow(clippy::assertions_on_constants)] + fn test_const_options() { + assert!(CAST_OPTIONS.safe) + } + + #[test] + fn test_list_format_options() { + let options = CastOptions { + safe: false, + format_options: FormatOptions::default().with_null("null"), + }; + let array = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(0), Some(1), Some(2)]), + Some(vec![Some(0), None, Some(2)]), + ]); + let a = cast_with_options(&array, &DataType::Utf8, &options).unwrap(); + let r: Vec<_> = a.as_string::().iter().map(|x| x.unwrap()).collect(); + assert_eq!(r, &["[0, 1, 2]", "[0, null, 2]"]); + } + #[test] + fn test_cast_string_to_timestamp_invalid_tz() { + // content after Z should be ignored + let bad_timestamp = "2023-12-05T21:58:10.45ZZTOP"; + let array = StringArray::from(vec![Some(bad_timestamp)]); + + let data_types = [ + DataType::Timestamp(TimeUnit::Second, None), + DataType::Timestamp(TimeUnit::Millisecond, None), + DataType::Timestamp(TimeUnit::Microsecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, None), + ]; + + let cast_options = CastOptions { + safe: false, + ..Default::default() + }; + + for dt in data_types { + assert_eq!( + cast_with_options(&array, &dt, &cast_options).unwrap_err().to_string(), + "Parser error: Invalid timezone \"ZZTOP\": only offset based timezones supported without chrono-tz feature" + ); + } + } +} diff --git a/arrow-cast/src/display.rs b/arrow-cast/src/display.rs new file mode 100644 index 000000000000..edf7c9394c88 --- /dev/null +++ b/arrow-cast/src/display.rs @@ -0,0 +1,1118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Functions for printing array values, as strings, for debugging +//! purposes. See the `pretty` crate for additional functions for +//! record batch pretty printing. + +use std::fmt::{Display, Formatter, Write}; +use std::ops::Range; + +use arrow_array::cast::*; +use arrow_array::temporal_conversions::*; +use arrow_array::timezone::Tz; +use arrow_array::types::*; +use arrow_array::*; +use arrow_buffer::ArrowNativeType; +use arrow_schema::*; +use chrono::{NaiveDate, NaiveDateTime, SecondsFormat, TimeZone, Utc}; +use lexical_core::FormattedSize; + +type TimeFormat<'a> = Option<&'a str>; + +/// Format for displaying durations +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum DurationFormat { + /// ISO 8601 - `P198DT72932.972880S` + ISO8601, + /// A human readable representation - `198 days 16 hours 34 mins 15.407810000 secs` + Pretty, +} + +/// Options for formatting arrays +/// +/// By default nulls are formatted as `""` and temporal types formatted +/// according to RFC3339 +/// +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct FormatOptions<'a> { + /// If set to `true` any formatting errors will be written to the output + /// instead of being converted into a [`std::fmt::Error`] + safe: bool, + /// Format string for nulls + null: &'a str, + /// Date format for date arrays + date_format: TimeFormat<'a>, + /// Format for DateTime arrays + datetime_format: TimeFormat<'a>, + /// Timestamp format for timestamp arrays + timestamp_format: TimeFormat<'a>, + /// Timestamp format for timestamp with timezone arrays + timestamp_tz_format: TimeFormat<'a>, + /// Time format for time arrays + time_format: TimeFormat<'a>, + /// Duration format + duration_format: DurationFormat, +} + +impl<'a> Default for FormatOptions<'a> { + fn default() -> Self { + Self::new() + } +} + +impl<'a> FormatOptions<'a> { + pub const fn new() -> Self { + Self { + safe: true, + null: "", + date_format: None, + datetime_format: None, + timestamp_format: None, + timestamp_tz_format: None, + time_format: None, + duration_format: DurationFormat::ISO8601, + } + } + + /// If set to `true` any formatting errors will be written to the output + /// instead of being converted into a [`std::fmt::Error`] + pub const fn with_display_error(mut self, safe: bool) -> Self { + self.safe = safe; + self + } + + /// Overrides the string used to represent a null + /// + /// Defaults to `""` + pub const fn with_null(self, null: &'a str) -> Self { + Self { null, ..self } + } + + /// Overrides the format used for [`DataType::Date32`] columns + pub const fn with_date_format(self, date_format: Option<&'a str>) -> Self { + Self { + date_format, + ..self + } + } + + /// Overrides the format used for [`DataType::Date64`] columns + pub const fn with_datetime_format(self, datetime_format: Option<&'a str>) -> Self { + Self { + datetime_format, + ..self + } + } + + /// Overrides the format used for [`DataType::Timestamp`] columns without a timezone + pub const fn with_timestamp_format(self, timestamp_format: Option<&'a str>) -> Self { + Self { + timestamp_format, + ..self + } + } + + /// Overrides the format used for [`DataType::Timestamp`] columns with a timezone + pub const fn with_timestamp_tz_format(self, timestamp_tz_format: Option<&'a str>) -> Self { + Self { + timestamp_tz_format, + ..self + } + } + + /// Overrides the format used for [`DataType::Time32`] and [`DataType::Time64`] columns + pub const fn with_time_format(self, time_format: Option<&'a str>) -> Self { + Self { + time_format, + ..self + } + } + + /// Overrides the format used for duration columns + /// + /// Defaults to [`DurationFormat::ISO8601`] + pub const fn with_duration_format(self, duration_format: DurationFormat) -> Self { + Self { + duration_format, + ..self + } + } +} + +/// Implements [`Display`] for a specific array value +pub struct ValueFormatter<'a> { + idx: usize, + formatter: &'a ArrayFormatter<'a>, +} + +impl<'a> ValueFormatter<'a> { + /// Writes this value to the provided [`Write`] + /// + /// Note: this ignores [`FormatOptions::with_display_error`] and + /// will return an error on formatting issue + pub fn write(&self, s: &mut dyn Write) -> Result<(), ArrowError> { + match self.formatter.format.write(self.idx, s) { + Ok(_) => Ok(()), + Err(FormatError::Arrow(e)) => Err(e), + Err(FormatError::Format(_)) => Err(ArrowError::CastError("Format error".to_string())), + } + } + + /// Fallibly converts this to a string + pub fn try_to_string(&self) -> Result { + let mut s = String::new(); + self.write(&mut s)?; + Ok(s) + } +} + +impl<'a> Display for ValueFormatter<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self.formatter.format.write(self.idx, f) { + Ok(()) => Ok(()), + Err(FormatError::Arrow(e)) if self.formatter.safe => { + write!(f, "ERROR: {e}") + } + Err(_) => Err(std::fmt::Error), + } + } +} + +/// A string formatter for an [`Array`] +/// +/// This can be used with [`std::write`] to write type-erased `dyn Array` +/// +/// ``` +/// # use std::fmt::{Display, Formatter, Write}; +/// # use arrow_array::{Array, ArrayRef, Int32Array}; +/// # use arrow_cast::display::{ArrayFormatter, FormatOptions}; +/// # use arrow_schema::ArrowError; +/// struct MyContainer { +/// values: ArrayRef, +/// } +/// +/// impl Display for MyContainer { +/// fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { +/// let options = FormatOptions::default(); +/// let formatter = ArrayFormatter::try_new(self.values.as_ref(), &options) +/// .map_err(|_| std::fmt::Error)?; +/// +/// let mut iter = 0..self.values.len(); +/// if let Some(idx) = iter.next() { +/// write!(f, "{}", formatter.value(idx))?; +/// } +/// for idx in iter { +/// write!(f, ", {}", formatter.value(idx))?; +/// } +/// Ok(()) +/// } +/// } +/// ``` +/// +/// [`ValueFormatter::write`] can also be used to get a semantic error, instead of the +/// opaque [`std::fmt::Error`] +/// +/// ``` +/// # use std::fmt::Write; +/// # use arrow_array::Array; +/// # use arrow_cast::display::{ArrayFormatter, FormatOptions}; +/// # use arrow_schema::ArrowError; +/// fn format_array( +/// f: &mut dyn Write, +/// array: &dyn Array, +/// options: &FormatOptions, +/// ) -> Result<(), ArrowError> { +/// let formatter = ArrayFormatter::try_new(array, options)?; +/// for i in 0..array.len() { +/// formatter.value(i).write(f)? +/// } +/// Ok(()) +/// } +/// ``` +/// +pub struct ArrayFormatter<'a> { + format: Box, + safe: bool, +} + +impl<'a> ArrayFormatter<'a> { + /// Returns an [`ArrayFormatter`] that can be used to format `array` + /// + /// This returns an error if an array of the given data type cannot be formatted + pub fn try_new(array: &'a dyn Array, options: &FormatOptions<'a>) -> Result { + Ok(Self { + format: make_formatter(array, options)?, + safe: options.safe, + }) + } + + /// Returns a [`ValueFormatter`] that implements [`Display`] for + /// the value of the array at `idx` + pub fn value(&self, idx: usize) -> ValueFormatter<'_> { + ValueFormatter { + formatter: self, + idx, + } + } +} + +fn make_formatter<'a>( + array: &'a dyn Array, + options: &FormatOptions<'a>, +) -> Result, ArrowError> { + downcast_primitive_array! { + array => array_format(array, options), + DataType::Null => array_format(as_null_array(array), options), + DataType::Boolean => array_format(as_boolean_array(array), options), + DataType::Utf8 => array_format(array.as_string::(), options), + DataType::LargeUtf8 => array_format(array.as_string::(), options), + DataType::Binary => array_format(array.as_binary::(), options), + DataType::LargeBinary => array_format(array.as_binary::(), options), + DataType::FixedSizeBinary(_) => { + let a = array.as_any().downcast_ref::().unwrap(); + array_format(a, options) + } + DataType::Dictionary(_, _) => downcast_dictionary_array! { + array => array_format(array, options), + _ => unreachable!() + } + DataType::List(_) => array_format(as_generic_list_array::(array), options), + DataType::LargeList(_) => array_format(as_generic_list_array::(array), options), + DataType::FixedSizeList(_, _) => { + let a = array.as_any().downcast_ref::().unwrap(); + array_format(a, options) + } + DataType::Struct(_) => array_format(as_struct_array(array), options), + DataType::Map(_, _) => array_format(as_map_array(array), options), + DataType::Union(_, _) => array_format(as_union_array(array), options), + DataType::RunEndEncoded(_, _) => downcast_run_array! { + array => array_format(array, options), + _ => unreachable!() + }, + d => Err(ArrowError::NotYetImplemented(format!("formatting {d} is not yet supported"))), + } +} + +/// Either an [`ArrowError`] or [`std::fmt::Error`] +enum FormatError { + Format(std::fmt::Error), + Arrow(ArrowError), +} + +type FormatResult = Result<(), FormatError>; + +impl From for FormatError { + fn from(value: std::fmt::Error) -> Self { + Self::Format(value) + } +} + +impl From for FormatError { + fn from(value: ArrowError) -> Self { + Self::Arrow(value) + } +} + +/// [`Display`] but accepting an index +trait DisplayIndex { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult; +} + +/// [`DisplayIndex`] with additional state +trait DisplayIndexState<'a> { + type State; + + fn prepare(&self, options: &FormatOptions<'a>) -> Result; + + fn write(&self, state: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult; +} + +impl<'a, T: DisplayIndex> DisplayIndexState<'a> for T { + type State = (); + + fn prepare(&self, _options: &FormatOptions<'a>) -> Result { + Ok(()) + } + + fn write(&self, _: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + DisplayIndex::write(self, idx, f) + } +} + +struct ArrayFormat<'a, F: DisplayIndexState<'a>> { + state: F::State, + array: F, + null: &'a str, +} + +fn array_format<'a, F>( + array: F, + options: &FormatOptions<'a>, +) -> Result, ArrowError> +where + F: DisplayIndexState<'a> + Array + 'a, +{ + let state = array.prepare(options)?; + Ok(Box::new(ArrayFormat { + state, + array, + null: options.null, + })) +} + +impl<'a, F: DisplayIndexState<'a> + Array> DisplayIndex for ArrayFormat<'a, F> { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + if self.array.is_null(idx) { + if !self.null.is_empty() { + f.write_str(self.null)? + } + return Ok(()); + } + DisplayIndexState::write(&self.array, &self.state, idx, f) + } +} + +impl<'a> DisplayIndex for &'a BooleanArray { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + write!(f, "{}", self.value(idx))?; + Ok(()) + } +} + +impl<'a> DisplayIndexState<'a> for &'a NullArray { + type State = &'a str; + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + Ok(options.null) + } + + fn write(&self, state: &Self::State, _idx: usize, f: &mut dyn Write) -> FormatResult { + f.write_str(state)?; + Ok(()) + } +} + +macro_rules! primitive_display { + ($($t:ty),+) => { + $(impl<'a> DisplayIndex for &'a PrimitiveArray<$t> + { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + let value = self.value(idx); + let mut buffer = [0u8; <$t as ArrowPrimitiveType>::Native::FORMATTED_SIZE]; + // SAFETY: + // buffer is T::FORMATTED_SIZE + let b = unsafe { lexical_core::write_unchecked(value, &mut buffer) }; + // Lexical core produces valid UTF-8 + let s = unsafe { std::str::from_utf8_unchecked(b) }; + f.write_str(s)?; + Ok(()) + } + })+ + }; +} + +primitive_display!(Int8Type, Int16Type, Int32Type, Int64Type); +primitive_display!(UInt8Type, UInt16Type, UInt32Type, UInt64Type); +primitive_display!(Float32Type, Float64Type); + +impl<'a> DisplayIndex for &'a PrimitiveArray { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + write!(f, "{}", self.value(idx))?; + Ok(()) + } +} + +macro_rules! decimal_display { + ($($t:ty),+) => { + $(impl<'a> DisplayIndexState<'a> for &'a PrimitiveArray<$t> { + type State = (u8, i8); + + fn prepare(&self, _options: &FormatOptions<'a>) -> Result { + Ok((self.precision(), self.scale())) + } + + fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + write!(f, "{}", <$t>::format_decimal(self.values()[idx], s.0, s.1))?; + Ok(()) + } + })+ + }; +} + +decimal_display!(Decimal128Type, Decimal256Type); + +fn write_timestamp( + f: &mut dyn Write, + naive: NaiveDateTime, + timezone: Option, + format: Option<&str>, +) -> FormatResult { + match timezone { + Some(tz) => { + let date = Utc.from_utc_datetime(&naive).with_timezone(&tz); + match format { + Some(s) => write!(f, "{}", date.format(s))?, + None => write!(f, "{}", date.to_rfc3339_opts(SecondsFormat::AutoSi, true))?, + } + } + None => match format { + Some(s) => write!(f, "{}", naive.format(s))?, + None => write!(f, "{naive:?}")?, + }, + } + Ok(()) +} + +macro_rules! timestamp_display { + ($($t:ty),+) => { + $(impl<'a> DisplayIndexState<'a> for &'a PrimitiveArray<$t> { + type State = (Option, TimeFormat<'a>); + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + match self.data_type() { + DataType::Timestamp(_, Some(tz)) => Ok((Some(tz.parse()?), options.timestamp_tz_format)), + DataType::Timestamp(_, None) => Ok((None, options.timestamp_format)), + _ => unreachable!(), + } + } + + fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + let value = self.value(idx); + let naive = as_datetime::<$t>(value).ok_or_else(|| { + ArrowError::CastError(format!( + "Failed to convert {} to datetime for {}", + value, + self.data_type() + )) + })?; + + write_timestamp(f, naive, s.0, s.1.clone()) + } + })+ + }; +} + +timestamp_display!( + TimestampSecondType, + TimestampMillisecondType, + TimestampMicrosecondType, + TimestampNanosecondType +); + +macro_rules! temporal_display { + ($convert:ident, $format:ident, $t:ty) => { + impl<'a> DisplayIndexState<'a> for &'a PrimitiveArray<$t> { + type State = TimeFormat<'a>; + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + Ok(options.$format) + } + + fn write(&self, fmt: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + let value = self.value(idx); + let naive = $convert(value as _).ok_or_else(|| { + ArrowError::CastError(format!( + "Failed to convert {} to temporal for {}", + value, + self.data_type() + )) + })?; + + match fmt { + Some(s) => write!(f, "{}", naive.format(s))?, + None => write!(f, "{naive:?}")?, + } + Ok(()) + } + } + }; +} + +#[inline] +fn date32_to_date(value: i32) -> Option { + Some(date32_to_datetime(value)?.date()) +} + +temporal_display!(date32_to_date, date_format, Date32Type); +temporal_display!(date64_to_datetime, datetime_format, Date64Type); +temporal_display!(time32s_to_time, time_format, Time32SecondType); +temporal_display!(time32ms_to_time, time_format, Time32MillisecondType); +temporal_display!(time64us_to_time, time_format, Time64MicrosecondType); +temporal_display!(time64ns_to_time, time_format, Time64NanosecondType); + +macro_rules! duration_display { + ($convert:ident, $t:ty, $scale:tt) => { + impl<'a> DisplayIndexState<'a> for &'a PrimitiveArray<$t> { + type State = DurationFormat; + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + Ok(options.duration_format) + } + + fn write(&self, fmt: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + let v = self.value(idx); + match fmt { + DurationFormat::ISO8601 => write!(f, "{}", $convert(v))?, + DurationFormat::Pretty => duration_fmt!(f, v, $scale)?, + } + Ok(()) + } + } + }; +} + +macro_rules! duration_fmt { + ($f:ident, $v:expr, 0) => {{ + let secs = $v; + let mins = secs / 60; + let hours = mins / 60; + let days = hours / 24; + + let secs = secs - (mins * 60); + let mins = mins - (hours * 60); + let hours = hours - (days * 24); + write!($f, "{days} days {hours} hours {mins} mins {secs} secs") + }}; + ($f:ident, $v:expr, $scale:tt) => {{ + let subsec = $v; + let secs = subsec / 10_i64.pow($scale); + let mins = secs / 60; + let hours = mins / 60; + let days = hours / 24; + + let subsec = subsec - (secs * 10_i64.pow($scale)); + let secs = secs - (mins * 60); + let mins = mins - (hours * 60); + let hours = hours - (days * 24); + match subsec.is_negative() { + true => { + write!( + $f, + concat!("{} days {} hours {} mins -{}.{:0", $scale, "} secs"), + days, + hours, + mins, + secs.abs(), + subsec.abs() + ) + } + false => { + write!( + $f, + concat!("{} days {} hours {} mins {}.{:0", $scale, "} secs"), + days, hours, mins, secs, subsec + ) + } + } + }}; +} + +duration_display!(duration_s_to_duration, DurationSecondType, 0); +duration_display!(duration_ms_to_duration, DurationMillisecondType, 3); +duration_display!(duration_us_to_duration, DurationMicrosecondType, 6); +duration_display!(duration_ns_to_duration, DurationNanosecondType, 9); + +impl<'a> DisplayIndex for &'a PrimitiveArray { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + let interval = self.value(idx) as f64; + let years = (interval / 12_f64).floor(); + let month = interval - (years * 12_f64); + + write!( + f, + "{years} years {month} mons 0 days 0 hours 0 mins 0.00 secs", + )?; + Ok(()) + } +} + +impl<'a> DisplayIndex for &'a PrimitiveArray { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + let value: u64 = self.value(idx) as u64; + + let days_parts: i32 = ((value & 0xFFFFFFFF00000000) >> 32) as i32; + let milliseconds_part: i32 = (value & 0xFFFFFFFF) as i32; + + let secs = milliseconds_part / 1_000; + let mins = secs / 60; + let hours = mins / 60; + + let secs = secs - (mins * 60); + let mins = mins - (hours * 60); + + let milliseconds = milliseconds_part % 1_000; + + let secs_sign = if secs < 0 || milliseconds < 0 { + "-" + } else { + "" + }; + + write!( + f, + "0 years 0 mons {} days {} hours {} mins {}{}.{:03} secs", + days_parts, + hours, + mins, + secs_sign, + secs.abs(), + milliseconds.abs(), + )?; + Ok(()) + } +} + +impl<'a> DisplayIndex for &'a PrimitiveArray { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + let value: u128 = self.value(idx) as u128; + + let months_part: i32 = ((value & 0xFFFFFFFF000000000000000000000000) >> 96) as i32; + let days_part: i32 = ((value & 0xFFFFFFFF0000000000000000) >> 64) as i32; + let nanoseconds_part: i64 = (value & 0xFFFFFFFFFFFFFFFF) as i64; + + let secs = nanoseconds_part / 1_000_000_000; + let mins = secs / 60; + let hours = mins / 60; + + let secs = secs - (mins * 60); + let mins = mins - (hours * 60); + + let nanoseconds = nanoseconds_part % 1_000_000_000; + + let secs_sign = if secs < 0 || nanoseconds < 0 { "-" } else { "" }; + + write!( + f, + "0 years {} mons {} days {} hours {} mins {}{}.{:09} secs", + months_part, + days_part, + hours, + mins, + secs_sign, + secs.abs(), + nanoseconds.abs(), + )?; + Ok(()) + } +} + +impl<'a, O: OffsetSizeTrait> DisplayIndex for &'a GenericStringArray { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + write!(f, "{}", self.value(idx))?; + Ok(()) + } +} + +impl<'a, O: OffsetSizeTrait> DisplayIndex for &'a GenericBinaryArray { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + let v = self.value(idx); + for byte in v { + write!(f, "{byte:02x}")?; + } + Ok(()) + } +} + +impl<'a> DisplayIndex for &'a FixedSizeBinaryArray { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + let v = self.value(idx); + for byte in v { + write!(f, "{byte:02x}")?; + } + Ok(()) + } +} + +impl<'a, K: ArrowDictionaryKeyType> DisplayIndexState<'a> for &'a DictionaryArray { + type State = Box; + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + make_formatter(self.values().as_ref(), options) + } + + fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + let value_idx = self.keys().values()[idx].as_usize(); + s.as_ref().write(value_idx, f) + } +} + +impl<'a, K: RunEndIndexType> DisplayIndexState<'a> for &'a RunArray { + type State = Box; + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + make_formatter(self.values().as_ref(), options) + } + + fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + let value_idx = self.get_physical_index(idx); + s.as_ref().write(value_idx, f) + } +} + +fn write_list( + f: &mut dyn Write, + mut range: Range, + values: &dyn DisplayIndex, +) -> FormatResult { + f.write_char('[')?; + if let Some(idx) = range.next() { + values.write(idx, f)?; + } + for idx in range { + write!(f, ", ")?; + values.write(idx, f)?; + } + f.write_char(']')?; + Ok(()) +} + +impl<'a, O: OffsetSizeTrait> DisplayIndexState<'a> for &'a GenericListArray { + type State = Box; + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + make_formatter(self.values().as_ref(), options) + } + + fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + let offsets = self.value_offsets(); + let end = offsets[idx + 1].as_usize(); + let start = offsets[idx].as_usize(); + write_list(f, start..end, s.as_ref()) + } +} + +impl<'a> DisplayIndexState<'a> for &'a FixedSizeListArray { + type State = (usize, Box); + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + let values = make_formatter(self.values().as_ref(), options)?; + let length = self.value_length(); + Ok((length as usize, values)) + } + + fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + let start = idx * s.0; + let end = start + s.0; + write_list(f, start..end, s.1.as_ref()) + } +} + +/// Pairs a boxed [`DisplayIndex`] with its field name +type FieldDisplay<'a> = (&'a str, Box); + +impl<'a> DisplayIndexState<'a> for &'a StructArray { + type State = Vec>; + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + let fields = match (*self).data_type() { + DataType::Struct(f) => f, + _ => unreachable!(), + }; + + self.columns() + .iter() + .zip(fields) + .map(|(a, f)| { + let format = make_formatter(a.as_ref(), options)?; + Ok((f.name().as_str(), format)) + }) + .collect() + } + + fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + let mut iter = s.iter(); + f.write_char('{')?; + if let Some((name, display)) = iter.next() { + write!(f, "{name}: ")?; + display.as_ref().write(idx, f)?; + } + for (name, display) in iter { + write!(f, ", {name}: ")?; + display.as_ref().write(idx, f)?; + } + f.write_char('}')?; + Ok(()) + } +} + +impl<'a> DisplayIndexState<'a> for &'a MapArray { + type State = (Box, Box); + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + let keys = make_formatter(self.keys().as_ref(), options)?; + let values = make_formatter(self.values().as_ref(), options)?; + Ok((keys, values)) + } + + fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + let offsets = self.value_offsets(); + let end = offsets[idx + 1].as_usize(); + let start = offsets[idx].as_usize(); + let mut iter = start..end; + + f.write_char('{')?; + if let Some(idx) = iter.next() { + s.0.write(idx, f)?; + write!(f, ": ")?; + s.1.write(idx, f)?; + } + + for idx in iter { + write!(f, ", ")?; + s.0.write(idx, f)?; + write!(f, ": ")?; + s.1.write(idx, f)?; + } + + f.write_char('}')?; + Ok(()) + } +} + +impl<'a> DisplayIndexState<'a> for &'a UnionArray { + type State = ( + Vec)>>, + UnionMode, + ); + + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + let (fields, mode) = match (*self).data_type() { + DataType::Union(fields, mode) => (fields, mode), + _ => unreachable!(), + }; + + let max_id = fields.iter().map(|(id, _)| id).max().unwrap_or_default() as usize; + let mut out: Vec> = (0..max_id + 1).map(|_| None).collect(); + for (i, field) in fields.iter() { + let formatter = make_formatter(self.child(i).as_ref(), options)?; + out[i as usize] = Some((field.name().as_str(), formatter)) + } + Ok((out, *mode)) + } + + fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { + let id = self.type_id(idx); + let idx = match s.1 { + UnionMode::Dense => self.value_offset(idx), + UnionMode::Sparse => idx, + }; + let (name, field) = s.0[id as usize].as_ref().unwrap(); + + write!(f, "{{{name}=")?; + field.write(idx, f)?; + f.write_char('}')?; + Ok(()) + } +} + +/// Get the value at the given row in an array as a String. +/// +/// Note this function is quite inefficient and is unlikely to be +/// suitable for converting large arrays or record batches. +/// +/// Please see [`ArrayFormatter`] for a more performant interface +pub fn array_value_to_string(column: &dyn Array, row: usize) -> Result { + let options = FormatOptions::default().with_display_error(true); + let formatter = ArrayFormatter::try_new(column, &options)?; + Ok(formatter.value(row).to_string()) +} + +/// Converts numeric type to a `String` +pub fn lexical_to_string(n: N) -> String { + let mut buf = Vec::::with_capacity(N::FORMATTED_SIZE_DECIMAL); + unsafe { + // JUSTIFICATION + // Benefit + // Allows using the faster serializer lexical core and convert to string + // Soundness + // Length of buf is set as written length afterwards. lexical_core + // creates a valid string, so doesn't need to be checked. + let slice = std::slice::from_raw_parts_mut(buf.as_mut_ptr(), buf.capacity()); + let len = lexical_core::write(n, slice).len(); + buf.set_len(len); + String::from_utf8_unchecked(buf) + } +} + +#[cfg(test)] +mod tests { + use arrow_array::builder::StringRunBuilder; + + use super::*; + + /// Test to verify options can be constant. See #4580 + const TEST_CONST_OPTIONS: FormatOptions<'static> = FormatOptions::new() + .with_date_format(Some("foo")) + .with_timestamp_format(Some("404")); + + #[test] + fn test_const_options() { + assert_eq!(TEST_CONST_OPTIONS.date_format, Some("foo")); + } + + #[test] + fn test_map_array_to_string() { + let keys = vec!["a", "b", "c", "d", "e", "f", "g", "h"]; + let values_data = UInt32Array::from(vec![0u32, 10, 20, 30, 40, 50, 60, 70]); + + // Construct a buffer for value offsets, for the nested array: + // [[a, b, c], [d, e, f], [g, h]] + let entry_offsets = [0, 3, 6, 8]; + + let map_array = + MapArray::new_from_strings(keys.clone().into_iter(), &values_data, &entry_offsets) + .unwrap(); + assert_eq!( + "{d: 30, e: 40, f: 50}", + array_value_to_string(&map_array, 1).unwrap() + ); + } + + fn format_array(array: &dyn Array, fmt: &FormatOptions) -> Vec { + let fmt = ArrayFormatter::try_new(array, fmt).unwrap(); + (0..array.len()).map(|x| fmt.value(x).to_string()).collect() + } + + #[test] + fn test_array_value_to_string_duration() { + let iso_fmt = FormatOptions::new(); + let pretty_fmt = FormatOptions::new().with_duration_format(DurationFormat::Pretty); + + let array = DurationNanosecondArray::from(vec![ + 1, + -1, + 1000, + -1000, + (45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34) * 1_000_000_000 + 123456789, + -(45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34) * 1_000_000_000 - 123456789, + ]); + let iso = format_array(&array, &iso_fmt); + let pretty = format_array(&array, &pretty_fmt); + + assert_eq!(iso[0], "PT0.000000001S"); + assert_eq!(pretty[0], "0 days 0 hours 0 mins 0.000000001 secs"); + assert_eq!(iso[1], "-PT0.000000001S"); + assert_eq!(pretty[1], "0 days 0 hours 0 mins -0.000000001 secs"); + assert_eq!(iso[2], "PT0.000001S"); + assert_eq!(pretty[2], "0 days 0 hours 0 mins 0.000001000 secs"); + assert_eq!(iso[3], "-PT0.000001S"); + assert_eq!(pretty[3], "0 days 0 hours 0 mins -0.000001000 secs"); + assert_eq!(iso[4], "P45DT50554.123456789S"); + assert_eq!(pretty[4], "45 days 14 hours 2 mins 34.123456789 secs"); + assert_eq!(iso[5], "-P45DT50554.123456789S"); + assert_eq!(pretty[5], "-45 days -14 hours -2 mins -34.123456789 secs"); + + let array = DurationMicrosecondArray::from(vec![ + 1, + -1, + 1000, + -1000, + (45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34) * 1_000_000 + 123456, + -(45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34) * 1_000_000 - 123456, + ]); + let iso = format_array(&array, &iso_fmt); + let pretty = format_array(&array, &pretty_fmt); + + assert_eq!(iso[0], "PT0.000001S"); + assert_eq!(pretty[0], "0 days 0 hours 0 mins 0.000001 secs"); + assert_eq!(iso[1], "-PT0.000001S"); + assert_eq!(pretty[1], "0 days 0 hours 0 mins -0.000001 secs"); + assert_eq!(iso[2], "PT0.001S"); + assert_eq!(pretty[2], "0 days 0 hours 0 mins 0.001000 secs"); + assert_eq!(iso[3], "-PT0.001S"); + assert_eq!(pretty[3], "0 days 0 hours 0 mins -0.001000 secs"); + assert_eq!(iso[4], "P45DT50554.123456S"); + assert_eq!(pretty[4], "45 days 14 hours 2 mins 34.123456 secs"); + assert_eq!(iso[5], "-P45DT50554.123456S"); + assert_eq!(pretty[5], "-45 days -14 hours -2 mins -34.123456 secs"); + + let array = DurationMillisecondArray::from(vec![ + 1, + -1, + 1000, + -1000, + (45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34) * 1_000 + 123, + -(45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34) * 1_000 - 123, + ]); + let iso = format_array(&array, &iso_fmt); + let pretty = format_array(&array, &pretty_fmt); + + assert_eq!(iso[0], "PT0.001S"); + assert_eq!(pretty[0], "0 days 0 hours 0 mins 0.001 secs"); + assert_eq!(iso[1], "-PT0.001S"); + assert_eq!(pretty[1], "0 days 0 hours 0 mins -0.001 secs"); + assert_eq!(iso[2], "PT1S"); + assert_eq!(pretty[2], "0 days 0 hours 0 mins 1.000 secs"); + assert_eq!(iso[3], "-PT1S"); + assert_eq!(pretty[3], "0 days 0 hours 0 mins -1.000 secs"); + assert_eq!(iso[4], "P45DT50554.123S"); + assert_eq!(pretty[4], "45 days 14 hours 2 mins 34.123 secs"); + assert_eq!(iso[5], "-P45DT50554.123S"); + assert_eq!(pretty[5], "-45 days -14 hours -2 mins -34.123 secs"); + + let array = DurationSecondArray::from(vec![ + 1, + -1, + 1000, + -1000, + 45 * 60 * 60 * 24 + 14 * 60 * 60 + 2 * 60 + 34, + -45 * 60 * 60 * 24 - 14 * 60 * 60 - 2 * 60 - 34, + ]); + let iso = format_array(&array, &iso_fmt); + let pretty = format_array(&array, &pretty_fmt); + + assert_eq!(iso[0], "PT1S"); + assert_eq!(pretty[0], "0 days 0 hours 0 mins 1 secs"); + assert_eq!(iso[1], "-PT1S"); + assert_eq!(pretty[1], "0 days 0 hours 0 mins -1 secs"); + assert_eq!(iso[2], "PT1000S"); + assert_eq!(pretty[2], "0 days 0 hours 16 mins 40 secs"); + assert_eq!(iso[3], "-PT1000S"); + assert_eq!(pretty[3], "0 days 0 hours -16 mins -40 secs"); + assert_eq!(iso[4], "P45DT50554S"); + assert_eq!(pretty[4], "45 days 14 hours 2 mins 34 secs"); + assert_eq!(iso[5], "-P45DT50554S"); + assert_eq!(pretty[5], "-45 days -14 hours -2 mins -34 secs"); + } + + #[test] + fn test_null() { + let array = NullArray::new(2); + let options = FormatOptions::new().with_null("NULL"); + let formatted = format_array(&array, &options); + assert_eq!(formatted, &["NULL".to_string(), "NULL".to_string()]) + } + + #[test] + fn test_string_run_arry_to_string() { + let mut builder = StringRunBuilder::::new(); + + builder.append_value("input_value"); + builder.append_value("input_value"); + builder.append_value("input_value"); + builder.append_value("input_value1"); + + let map_array = builder.finish(); + assert_eq!("input_value", array_value_to_string(&map_array, 1).unwrap()); + assert_eq!( + "input_value1", + array_value_to_string(&map_array, 3).unwrap() + ); + } +} diff --git a/arrow-cast/src/lib.rs b/arrow-cast/src/lib.rs new file mode 100644 index 000000000000..71ebe6c0ed8b --- /dev/null +++ b/arrow-cast/src/lib.rs @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Cast kernel for [Apache Arrow](https://docs.rs/arrow) + +pub mod cast; +pub use cast::*; +pub mod display; +pub mod parse; +#[cfg(feature = "prettyprint")] +pub mod pretty; + +pub mod base64; diff --git a/arrow-cast/src/parse.rs b/arrow-cast/src/parse.rs new file mode 100644 index 000000000000..3d2e47ed95a4 --- /dev/null +++ b/arrow-cast/src/parse.rs @@ -0,0 +1,2288 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::timezone::Tz; +use arrow_array::types::*; +use arrow_array::{ArrowNativeTypeOp, ArrowPrimitiveType}; +use arrow_buffer::ArrowNativeType; +use arrow_schema::ArrowError; +use chrono::prelude::*; +use half::f16; +use std::str::FromStr; + +/// Parse nanoseconds from the first `N` values in digits, subtracting the offset `O` +#[inline] +fn parse_nanos(digits: &[u8]) -> u32 { + digits[..N] + .iter() + .fold(0_u32, |acc, v| acc * 10 + v.wrapping_sub(O) as u32) + * 10_u32.pow((9 - N) as _) +} + +/// Helper for parsing RFC3339 timestamps +struct TimestampParser { + /// The timestamp bytes to parse minus `b'0'` + /// + /// This makes interpretation as an integer inexpensive + digits: [u8; 32], + /// A mask containing a `1` bit where the corresponding byte is a valid ASCII digit + mask: u32, +} + +impl TimestampParser { + fn new(bytes: &[u8]) -> Self { + let mut digits = [0; 32]; + let mut mask = 0; + + // Treating all bytes the same way, helps LLVM vectorise this correctly + for (idx, (o, i)) in digits.iter_mut().zip(bytes).enumerate() { + *o = i.wrapping_sub(b'0'); + mask |= ((*o < 10) as u32) << idx + } + + Self { digits, mask } + } + + /// Returns true if the byte at `idx` in the original string equals `b` + fn test(&self, idx: usize, b: u8) -> bool { + self.digits[idx] == b.wrapping_sub(b'0') + } + + /// Parses a date of the form `1997-01-31` + fn date(&self) -> Option { + if self.mask & 0b1111111111 != 0b1101101111 || !self.test(4, b'-') || !self.test(7, b'-') { + return None; + } + + let year = self.digits[0] as u16 * 1000 + + self.digits[1] as u16 * 100 + + self.digits[2] as u16 * 10 + + self.digits[3] as u16; + + let month = self.digits[5] * 10 + self.digits[6]; + let day = self.digits[8] * 10 + self.digits[9]; + + NaiveDate::from_ymd_opt(year as _, month as _, day as _) + } + + /// Parses a time of any of forms + /// - `09:26:56` + /// - `09:26:56.123` + /// - `09:26:56.123456` + /// - `09:26:56.123456789` + /// - `092656` + /// + /// Returning the end byte offset + fn time(&self) -> Option<(NaiveTime, usize)> { + // Make a NaiveTime handling leap seconds + let time = |hour, min, sec, nano| match sec { + 60 => { + let nano = 1_000_000_000 + nano; + NaiveTime::from_hms_nano_opt(hour as _, min as _, 59, nano) + } + _ => NaiveTime::from_hms_nano_opt(hour as _, min as _, sec as _, nano), + }; + + match (self.mask >> 11) & 0b11111111 { + // 09:26:56 + 0b11011011 if self.test(13, b':') && self.test(16, b':') => { + let hour = self.digits[11] * 10 + self.digits[12]; + let minute = self.digits[14] * 10 + self.digits[15]; + let second = self.digits[17] * 10 + self.digits[18]; + + match self.test(19, b'.') { + true => { + let digits = (self.mask >> 20).trailing_ones(); + let nanos = match digits { + 0 => return None, + 1 => parse_nanos::<1, 0>(&self.digits[20..21]), + 2 => parse_nanos::<2, 0>(&self.digits[20..22]), + 3 => parse_nanos::<3, 0>(&self.digits[20..23]), + 4 => parse_nanos::<4, 0>(&self.digits[20..24]), + 5 => parse_nanos::<5, 0>(&self.digits[20..25]), + 6 => parse_nanos::<6, 0>(&self.digits[20..26]), + 7 => parse_nanos::<7, 0>(&self.digits[20..27]), + 8 => parse_nanos::<8, 0>(&self.digits[20..28]), + _ => parse_nanos::<9, 0>(&self.digits[20..29]), + }; + Some((time(hour, minute, second, nanos)?, 20 + digits as usize)) + } + false => Some((time(hour, minute, second, 0)?, 19)), + } + } + // 092656 + 0b111111 => { + let hour = self.digits[11] * 10 + self.digits[12]; + let minute = self.digits[13] * 10 + self.digits[14]; + let second = self.digits[15] * 10 + self.digits[16]; + let time = time(hour, minute, second, 0)?; + Some((time, 17)) + } + _ => None, + } + } +} + +/// Accepts a string and parses it relative to the provided `timezone` +/// +/// In addition to RFC3339 / ISO8601 standard timestamps, it also +/// accepts strings that use a space ` ` to separate the date and time +/// as well as strings that have no explicit timezone offset. +/// +/// Examples of accepted inputs: +/// * `1997-01-31T09:26:56.123Z` # RCF3339 +/// * `1997-01-31T09:26:56.123-05:00` # RCF3339 +/// * `1997-01-31 09:26:56.123-05:00` # close to RCF3339 but with a space rather than T +/// * `2023-01-01 04:05:06.789 -08` # close to RCF3339, no fractional seconds or time separator +/// * `1997-01-31T09:26:56.123` # close to RCF3339 but no timezone offset specified +/// * `1997-01-31 09:26:56.123` # close to RCF3339 but uses a space and no timezone offset +/// * `1997-01-31 09:26:56` # close to RCF3339, no fractional seconds +/// * `1997-01-31 092656` # close to RCF3339, no fractional seconds +/// * `1997-01-31 092656+04:00` # close to RCF3339, no fractional seconds or time separator +/// * `1997-01-31` # close to RCF3339, only date no time +/// +/// [IANA timezones] are only supported if the `arrow-array/chrono-tz` feature is enabled +/// +/// * `2023-01-01 040506 America/Los_Angeles` +/// +/// If a timestamp is ambiguous, for example as a result of daylight-savings time, an error +/// will be returned +/// +/// Some formats supported by PostgresSql +/// are not supported, like +/// +/// * "2023-01-01 04:05:06.789 +07:30:00", +/// * "2023-01-01 040506 +07:30:00", +/// * "2023-01-01 04:05:06.789 PST", +/// +/// [IANA timezones]: https://www.iana.org/time-zones +pub fn string_to_datetime(timezone: &T, s: &str) -> Result, ArrowError> { + let err = + |ctx: &str| ArrowError::ParseError(format!("Error parsing timestamp from '{s}': {ctx}")); + + let bytes = s.as_bytes(); + if bytes.len() < 10 { + return Err(err("timestamp must contain at least 10 characters")); + } + + let parser = TimestampParser::new(bytes); + let date = parser.date().ok_or_else(|| err("error parsing date"))?; + if bytes.len() == 10 { + let datetime = date.and_time(NaiveTime::from_hms_opt(0, 0, 0).unwrap()); + return timezone + .from_local_datetime(&datetime) + .single() + .ok_or_else(|| err("error computing timezone offset")); + } + + if !parser.test(10, b'T') && !parser.test(10, b't') && !parser.test(10, b' ') { + return Err(err("invalid timestamp separator")); + } + + let (time, mut tz_offset) = parser.time().ok_or_else(|| err("error parsing time"))?; + let datetime = date.and_time(time); + + if tz_offset == 32 { + // Decimal overrun + while tz_offset < bytes.len() && bytes[tz_offset].is_ascii_digit() { + tz_offset += 1; + } + } + + if bytes.len() <= tz_offset { + return timezone + .from_local_datetime(&datetime) + .single() + .ok_or_else(|| err("error computing timezone offset")); + } + + if (bytes[tz_offset] == b'z' || bytes[tz_offset] == b'Z') && tz_offset == bytes.len() - 1 { + return Ok(timezone.from_utc_datetime(&datetime)); + } + + // Parse remainder of string as timezone + let parsed_tz: Tz = s[tz_offset..].trim_start().parse()?; + let parsed = parsed_tz + .from_local_datetime(&datetime) + .single() + .ok_or_else(|| err("error computing timezone offset"))?; + + Ok(parsed.with_timezone(timezone)) +} + +/// Accepts a string in RFC3339 / ISO8601 standard format and some +/// variants and converts it to a nanosecond precision timestamp. +/// +/// See [`string_to_datetime`] for the full set of supported formats +/// +/// Implements the `to_timestamp` function to convert a string to a +/// timestamp, following the model of spark SQL’s to_`timestamp`. +/// +/// Internally, this function uses the `chrono` library for the +/// datetime parsing +/// +/// We hope to extend this function in the future with a second +/// parameter to specifying the format string. +/// +/// ## Timestamp Precision +/// +/// Function uses the maximum precision timestamps supported by +/// Arrow (nanoseconds stored as a 64-bit integer) timestamps. This +/// means the range of dates that timestamps can represent is ~1677 AD +/// to 2262 AM +/// +/// ## Timezone / Offset Handling +/// +/// Numerical values of timestamps are stored compared to offset UTC. +/// +/// This function interprets string without an explicit time zone as timestamps +/// relative to UTC, see [`string_to_datetime`] for alternative semantics +/// +/// In particular: +/// +/// ``` +/// # use arrow_cast::parse::string_to_timestamp_nanos; +/// // Note all three of these timestamps are parsed as the same value +/// let a = string_to_timestamp_nanos("1997-01-31 09:26:56.123Z").unwrap(); +/// let b = string_to_timestamp_nanos("1997-01-31T09:26:56.123").unwrap(); +/// let c = string_to_timestamp_nanos("1997-01-31T14:26:56.123+05:00").unwrap(); +/// +/// assert_eq!(a, b); +/// assert_eq!(b, c); +/// ``` +/// +#[inline] +pub fn string_to_timestamp_nanos(s: &str) -> Result { + to_timestamp_nanos(string_to_datetime(&Utc, s)?.naive_utc()) +} + +/// Fallible conversion of [`NaiveDateTime`] to `i64` nanoseconds +#[inline] +fn to_timestamp_nanos(dt: NaiveDateTime) -> Result { + dt.timestamp_nanos_opt() + .ok_or_else(|| ArrowError::ParseError(ERR_NANOSECONDS_NOT_SUPPORTED.to_string())) +} + +/// Accepts a string in ISO8601 standard format and some +/// variants and converts it to nanoseconds since midnight. +/// +/// Examples of accepted inputs: +/// * `09:26:56.123 AM` +/// * `23:59:59` +/// * `6:00 pm` +// +/// Internally, this function uses the `chrono` library for the +/// time parsing +/// +/// ## Timezone / Offset Handling +/// +/// This function does not support parsing strings with a timezone +/// or offset specified, as it considers only time since midnight. +pub fn string_to_time_nanoseconds(s: &str) -> Result { + let nt = string_to_time(s) + .ok_or_else(|| ArrowError::ParseError(format!("Failed to parse \'{s}\' as time")))?; + Ok(nt.num_seconds_from_midnight() as i64 * 1_000_000_000 + nt.nanosecond() as i64) +} + +fn string_to_time(s: &str) -> Option { + let bytes = s.as_bytes(); + if bytes.len() < 4 { + return None; + } + + let (am, bytes) = match bytes.get(bytes.len() - 3..) { + Some(b" AM" | b" am" | b" Am" | b" aM") => (Some(true), &bytes[..bytes.len() - 3]), + Some(b" PM" | b" pm" | b" pM" | b" Pm") => (Some(false), &bytes[..bytes.len() - 3]), + _ => (None, bytes), + }; + + if bytes.len() < 4 { + return None; + } + + let mut digits = [b'0'; 6]; + + // Extract hour + let bytes = match (bytes[1], bytes[2]) { + (b':', _) => { + digits[1] = bytes[0]; + &bytes[2..] + } + (_, b':') => { + digits[0] = bytes[0]; + digits[1] = bytes[1]; + &bytes[3..] + } + _ => return None, + }; + + if bytes.len() < 2 { + return None; // Minutes required + } + + // Extract minutes + digits[2] = bytes[0]; + digits[3] = bytes[1]; + + let nanoseconds = match bytes.get(2) { + Some(b':') => { + if bytes.len() < 5 { + return None; + } + + // Extract seconds + digits[4] = bytes[3]; + digits[5] = bytes[4]; + + // Extract sub-seconds if any + match bytes.get(5) { + Some(b'.') => { + let decimal = &bytes[6..]; + if decimal.iter().any(|x| !x.is_ascii_digit()) { + return None; + } + match decimal.len() { + 0 => return None, + 1 => parse_nanos::<1, b'0'>(decimal), + 2 => parse_nanos::<2, b'0'>(decimal), + 3 => parse_nanos::<3, b'0'>(decimal), + 4 => parse_nanos::<4, b'0'>(decimal), + 5 => parse_nanos::<5, b'0'>(decimal), + 6 => parse_nanos::<6, b'0'>(decimal), + 7 => parse_nanos::<7, b'0'>(decimal), + 8 => parse_nanos::<8, b'0'>(decimal), + _ => parse_nanos::<9, b'0'>(decimal), + } + } + Some(_) => return None, + None => 0, + } + } + Some(_) => return None, + None => 0, + }; + + digits.iter_mut().for_each(|x| *x = x.wrapping_sub(b'0')); + if digits.iter().any(|x| *x > 9) { + return None; + } + + let hour = match (digits[0] * 10 + digits[1], am) { + (12, Some(true)) => 0, // 12:00 AM -> 00:00 + (h @ 1..=11, Some(true)) => h, // 1:00 AM -> 01:00 + (12, Some(false)) => 12, // 12:00 PM -> 12:00 + (h @ 1..=11, Some(false)) => h + 12, // 1:00 PM -> 13:00 + (_, Some(_)) => return None, + (h, None) => h, + }; + + // Handle leap second + let (second, nanoseconds) = match digits[4] * 10 + digits[5] { + 60 => (59, nanoseconds + 1_000_000_000), + s => (s, nanoseconds), + }; + + NaiveTime::from_hms_nano_opt( + hour as _, + (digits[2] * 10 + digits[3]) as _, + second as _, + nanoseconds, + ) +} + +/// Specialized parsing implementations +/// used by csv and json reader +pub trait Parser: ArrowPrimitiveType { + fn parse(string: &str) -> Option; + + fn parse_formatted(string: &str, _format: &str) -> Option { + Self::parse(string) + } +} + +impl Parser for Float16Type { + fn parse(string: &str) -> Option { + lexical_core::parse(string.as_bytes()) + .ok() + .map(f16::from_f32) + } +} + +impl Parser for Float32Type { + fn parse(string: &str) -> Option { + lexical_core::parse(string.as_bytes()).ok() + } +} + +impl Parser for Float64Type { + fn parse(string: &str) -> Option { + lexical_core::parse(string.as_bytes()).ok() + } +} + +macro_rules! parser_primitive { + ($t:ty) => { + impl Parser for $t { + fn parse(string: &str) -> Option { + lexical_core::parse::(string.as_bytes()).ok() + } + } + }; +} +parser_primitive!(UInt64Type); +parser_primitive!(UInt32Type); +parser_primitive!(UInt16Type); +parser_primitive!(UInt8Type); +parser_primitive!(Int64Type); +parser_primitive!(Int32Type); +parser_primitive!(Int16Type); +parser_primitive!(Int8Type); + +impl Parser for TimestampNanosecondType { + fn parse(string: &str) -> Option { + string_to_timestamp_nanos(string).ok() + } +} + +impl Parser for TimestampMicrosecondType { + fn parse(string: &str) -> Option { + let nanos = string_to_timestamp_nanos(string).ok(); + nanos.map(|x| x / 1000) + } +} + +impl Parser for TimestampMillisecondType { + fn parse(string: &str) -> Option { + let nanos = string_to_timestamp_nanos(string).ok(); + nanos.map(|x| x / 1_000_000) + } +} + +impl Parser for TimestampSecondType { + fn parse(string: &str) -> Option { + let nanos = string_to_timestamp_nanos(string).ok(); + nanos.map(|x| x / 1_000_000_000) + } +} + +impl Parser for Time64NanosecondType { + // Will truncate any fractions of a nanosecond + fn parse(string: &str) -> Option { + string_to_time_nanoseconds(string) + .ok() + .or_else(|| string.parse::().ok()) + } + + fn parse_formatted(string: &str, format: &str) -> Option { + let nt = NaiveTime::parse_from_str(string, format).ok()?; + Some(nt.num_seconds_from_midnight() as i64 * 1_000_000_000 + nt.nanosecond() as i64) + } +} + +impl Parser for Time64MicrosecondType { + // Will truncate any fractions of a microsecond + fn parse(string: &str) -> Option { + string_to_time_nanoseconds(string) + .ok() + .map(|nanos| nanos / 1_000) + .or_else(|| string.parse::().ok()) + } + + fn parse_formatted(string: &str, format: &str) -> Option { + let nt = NaiveTime::parse_from_str(string, format).ok()?; + Some(nt.num_seconds_from_midnight() as i64 * 1_000_000 + nt.nanosecond() as i64 / 1_000) + } +} + +impl Parser for Time32MillisecondType { + // Will truncate any fractions of a millisecond + fn parse(string: &str) -> Option { + string_to_time_nanoseconds(string) + .ok() + .map(|nanos| (nanos / 1_000_000) as i32) + .or_else(|| string.parse::().ok()) + } + + fn parse_formatted(string: &str, format: &str) -> Option { + let nt = NaiveTime::parse_from_str(string, format).ok()?; + Some(nt.num_seconds_from_midnight() as i32 * 1_000 + nt.nanosecond() as i32 / 1_000_000) + } +} + +impl Parser for Time32SecondType { + // Will truncate any fractions of a second + fn parse(string: &str) -> Option { + string_to_time_nanoseconds(string) + .ok() + .map(|nanos| (nanos / 1_000_000_000) as i32) + .or_else(|| string.parse::().ok()) + } + + fn parse_formatted(string: &str, format: &str) -> Option { + let nt = NaiveTime::parse_from_str(string, format).ok()?; + Some(nt.num_seconds_from_midnight() as i32 + nt.nanosecond() as i32 / 1_000_000_000) + } +} + +/// Number of days between 0001-01-01 and 1970-01-01 +const EPOCH_DAYS_FROM_CE: i32 = 719_163; + +/// Error message if nanosecond conversion request beyond supported interval +const ERR_NANOSECONDS_NOT_SUPPORTED: &str = "The dates that can be represented as nanoseconds have to be between 1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804"; + +fn parse_date(string: &str) -> Option { + if string.len() > 10 { + return None; + } + let mut digits = [0; 10]; + let mut mask = 0; + + // Treating all bytes the same way, helps LLVM vectorise this correctly + for (idx, (o, i)) in digits.iter_mut().zip(string.bytes()).enumerate() { + *o = i.wrapping_sub(b'0'); + mask |= ((*o < 10) as u16) << idx + } + + const HYPHEN: u8 = b'-'.wrapping_sub(b'0'); + + // refer to https://www.rfc-editor.org/rfc/rfc3339#section-3 + if digits[4] != HYPHEN { + let (year, month, day) = match (mask, string.len()) { + (0b11111111, 8) => ( + digits[0] as u16 * 1000 + + digits[1] as u16 * 100 + + digits[2] as u16 * 10 + + digits[3] as u16, + digits[4] * 10 + digits[5], + digits[6] * 10 + digits[7], + ), + _ => return None, + }; + return NaiveDate::from_ymd_opt(year as _, month as _, day as _); + } + + let (month, day) = match mask { + 0b1101101111 => { + if digits[7] != HYPHEN { + return None; + } + (digits[5] * 10 + digits[6], digits[8] * 10 + digits[9]) + } + 0b101101111 => { + if digits[7] != HYPHEN { + return None; + } + (digits[5] * 10 + digits[6], digits[8]) + } + 0b110101111 => { + if digits[6] != HYPHEN { + return None; + } + (digits[5], digits[7] * 10 + digits[8]) + } + 0b10101111 => { + if digits[6] != HYPHEN { + return None; + } + (digits[5], digits[7]) + } + _ => return None, + }; + + let year = + digits[0] as u16 * 1000 + digits[1] as u16 * 100 + digits[2] as u16 * 10 + digits[3] as u16; + + NaiveDate::from_ymd_opt(year as _, month as _, day as _) +} + +impl Parser for Date32Type { + fn parse(string: &str) -> Option { + let date = parse_date(string)?; + Some(date.num_days_from_ce() - EPOCH_DAYS_FROM_CE) + } + + fn parse_formatted(string: &str, format: &str) -> Option { + let date = NaiveDate::parse_from_str(string, format).ok()?; + Some(date.num_days_from_ce() - EPOCH_DAYS_FROM_CE) + } +} + +impl Parser for Date64Type { + fn parse(string: &str) -> Option { + if string.len() <= 10 { + let date = parse_date(string)?; + Some(NaiveDateTime::new(date, NaiveTime::default()).timestamp_millis()) + } else { + let date_time = string_to_datetime(&Utc, string).ok()?; + Some(date_time.timestamp_millis()) + } + } + + fn parse_formatted(string: &str, format: &str) -> Option { + use chrono::format::Fixed; + use chrono::format::StrftimeItems; + let fmt = StrftimeItems::new(format); + let has_zone = fmt.into_iter().any(|item| match item { + chrono::format::Item::Fixed(fixed_item) => matches!( + fixed_item, + Fixed::RFC2822 + | Fixed::RFC3339 + | Fixed::TimezoneName + | Fixed::TimezoneOffsetColon + | Fixed::TimezoneOffsetColonZ + | Fixed::TimezoneOffset + | Fixed::TimezoneOffsetZ + ), + _ => false, + }); + if has_zone { + let date_time = chrono::DateTime::parse_from_str(string, format).ok()?; + Some(date_time.timestamp_millis()) + } else { + let date_time = NaiveDateTime::parse_from_str(string, format).ok()?; + Some(date_time.timestamp_millis()) + } + } +} + +/// Parse the string format decimal value to i128/i256 format and checking the precision and scale. +/// The result value can't be out of bounds. +pub fn parse_decimal( + s: &str, + precision: u8, + scale: i8, +) -> Result { + let mut result = T::Native::usize_as(0); + let mut fractionals = 0; + let mut digits = 0; + let base = T::Native::usize_as(10); + + let bs = s.as_bytes(); + let (bs, negative) = match bs.first() { + Some(b'-') => (&bs[1..], true), + Some(b'+') => (&bs[1..], false), + _ => (bs, false), + }; + + if bs.is_empty() { + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); + } + + let mut bs = bs.iter(); + // Overflow checks are not required if 10^(precision - 1) <= T::MAX holds. + // Thus, if we validate the precision correctly, we can skip overflow checks. + while let Some(b) = bs.next() { + match b { + b'0'..=b'9' => { + if digits == 0 && *b == b'0' { + // Ignore leading zeros. + continue; + } + digits += 1; + result = result.mul_wrapping(base); + result = result.add_wrapping(T::Native::usize_as((b - b'0') as usize)); + } + b'.' => { + for b in bs.by_ref() { + if !b.is_ascii_digit() { + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); + } + if fractionals == scale { + // We have processed all the digits that we need. All that + // is left is to validate that the rest of the string contains + // valid digits. + continue; + } + fractionals += 1; + digits += 1; + result = result.mul_wrapping(base); + result = result.add_wrapping(T::Native::usize_as((b - b'0') as usize)); + } + + // Fail on "." + if digits == 0 { + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); + } + } + _ => { + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); + } + } + } + + if fractionals < scale { + let exp = scale - fractionals; + if exp as u8 + digits > precision { + return Err(ArrowError::ParseError("parse decimal overflow".to_string())); + } + let mul = base.pow_wrapping(exp as _); + result = result.mul_wrapping(mul); + } else if digits > precision { + return Err(ArrowError::ParseError("parse decimal overflow".to_string())); + } + + Ok(if negative { + result.neg_wrapping() + } else { + result + }) +} + +pub fn parse_interval_year_month( + value: &str, +) -> Result<::Native, ArrowError> { + let config = IntervalParseConfig::new(IntervalUnit::Year); + let interval = Interval::parse(value, &config)?; + + let months = interval.to_year_months().map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast {value} to IntervalYearMonth. Only year and month fields are allowed." + )) + })?; + + Ok(IntervalYearMonthType::make_value(0, months)) +} + +pub fn parse_interval_day_time( + value: &str, +) -> Result<::Native, ArrowError> { + let config = IntervalParseConfig::new(IntervalUnit::Day); + let interval = Interval::parse(value, &config)?; + + let (days, millis) = interval.to_day_time().map_err(|_| ArrowError::CastError(format!( + "Cannot cast {value} to IntervalDayTime because the nanos part isn't multiple of milliseconds" + )))?; + + Ok(IntervalDayTimeType::make_value(days, millis)) +} + +pub fn parse_interval_month_day_nano( + value: &str, +) -> Result<::Native, ArrowError> { + let config = IntervalParseConfig::new(IntervalUnit::Month); + let interval = Interval::parse(value, &config)?; + + let (months, days, nanos) = interval.to_month_day_nanos(); + + Ok(IntervalMonthDayNanoType::make_value(months, days, nanos)) +} + +const NANOS_PER_MILLIS: i64 = 1_000_000; +const NANOS_PER_SECOND: i64 = 1_000 * NANOS_PER_MILLIS; +const NANOS_PER_MINUTE: i64 = 60 * NANOS_PER_SECOND; +const NANOS_PER_HOUR: i64 = 60 * NANOS_PER_MINUTE; +#[cfg(test)] +const NANOS_PER_DAY: i64 = 24 * NANOS_PER_HOUR; + +#[rustfmt::skip] +#[derive(Clone, Copy)] +#[repr(u16)] +enum IntervalUnit { + Century = 0b_0000_0000_0001, + Decade = 0b_0000_0000_0010, + Year = 0b_0000_0000_0100, + Month = 0b_0000_0000_1000, + Week = 0b_0000_0001_0000, + Day = 0b_0000_0010_0000, + Hour = 0b_0000_0100_0000, + Minute = 0b_0000_1000_0000, + Second = 0b_0001_0000_0000, + Millisecond = 0b_0010_0000_0000, + Microsecond = 0b_0100_0000_0000, + Nanosecond = 0b_1000_0000_0000, +} + +impl FromStr for IntervalUnit { + type Err = ArrowError; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "century" | "centuries" => Ok(Self::Century), + "decade" | "decades" => Ok(Self::Decade), + "year" | "years" => Ok(Self::Year), + "month" | "months" => Ok(Self::Month), + "week" | "weeks" => Ok(Self::Week), + "day" | "days" => Ok(Self::Day), + "hour" | "hours" => Ok(Self::Hour), + "minute" | "minutes" => Ok(Self::Minute), + "second" | "seconds" => Ok(Self::Second), + "millisecond" | "milliseconds" => Ok(Self::Millisecond), + "microsecond" | "microseconds" => Ok(Self::Microsecond), + "nanosecond" | "nanoseconds" => Ok(Self::Nanosecond), + _ => Err(ArrowError::NotYetImplemented(format!( + "Unknown interval type: {s}" + ))), + } + } +} + +pub type MonthDayNano = (i32, i32, i64); + +/// Chosen based on the number of decimal digits in 1 week in nanoseconds +const INTERVAL_PRECISION: u32 = 15; + +#[derive(Clone, Copy, Debug, PartialEq)] +struct IntervalAmount { + /// The integer component of the interval amount + integer: i64, + /// The fractional component multiplied by 10^INTERVAL_PRECISION + frac: i64, +} + +#[cfg(test)] +impl IntervalAmount { + fn new(integer: i64, frac: i64) -> Self { + Self { integer, frac } + } +} + +impl FromStr for IntervalAmount { + type Err = ArrowError; + + fn from_str(s: &str) -> Result { + match s.split_once('.') { + Some((integer, frac)) + if frac.len() <= INTERVAL_PRECISION as usize + && !frac.is_empty() + && !frac.starts_with('-') => + { + // integer will be "" for values like ".5" + // and "-" for values like "-.5" + let explicit_neg = integer.starts_with('-'); + let integer = if integer.is_empty() || integer == "-" { + Ok(0) + } else { + integer.parse::().map_err(|_| { + ArrowError::ParseError(format!("Failed to parse {s} as interval amount")) + }) + }?; + + let frac_unscaled = frac.parse::().map_err(|_| { + ArrowError::ParseError(format!("Failed to parse {s} as interval amount")) + })?; + + // scale fractional part by interval precision + let frac = frac_unscaled * 10_i64.pow(INTERVAL_PRECISION - frac.len() as u32); + + // propagate the sign of the integer part to the fractional part + let frac = if integer < 0 || explicit_neg { + -frac + } else { + frac + }; + + let result = Self { integer, frac }; + + Ok(result) + } + Some((_, frac)) if frac.starts_with('-') => Err(ArrowError::ParseError(format!( + "Failed to parse {s} as interval amount" + ))), + Some((_, frac)) if frac.len() > INTERVAL_PRECISION as usize => { + Err(ArrowError::ParseError(format!( + "{s} exceeds the precision available for interval amount" + ))) + } + Some(_) | None => { + let integer = s.parse::().map_err(|_| { + ArrowError::ParseError(format!("Failed to parse {s} as interval amount")) + })?; + + let result = Self { integer, frac: 0 }; + Ok(result) + } + } + } +} + +#[derive(Debug, Default, PartialEq)] +struct Interval { + months: i32, + days: i32, + nanos: i64, +} + +impl Interval { + fn new(months: i32, days: i32, nanos: i64) -> Self { + Self { + months, + days, + nanos, + } + } + + fn to_year_months(&self) -> Result { + match (self.months, self.days, self.nanos) { + (months, days, nanos) if days == 0 && nanos == 0 => Ok(months), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Unable to represent interval with days and nanos as year-months: {:?}", + self + ))), + } + } + + fn to_day_time(&self) -> Result<(i32, i32), ArrowError> { + let days = self.months.mul_checked(30)?.add_checked(self.days)?; + + match self.nanos { + nanos if nanos % NANOS_PER_MILLIS == 0 => { + let millis = (self.nanos / 1_000_000).try_into().map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "Unable to represent {} nanos as milliseconds in a signed 32-bit integer", + self.nanos + )) + })?; + + Ok((days, millis)) + } + nanos => Err(ArrowError::InvalidArgumentError(format!( + "Unable to represent {nanos} as milliseconds" + ))), + } + } + + fn to_month_day_nanos(&self) -> (i32, i32, i64) { + (self.months, self.days, self.nanos) + } + + /// Parse string value in traditional Postgres format such as + /// `1 year 2 months 3 days 4 hours 5 minutes 6 seconds` + fn parse(value: &str, config: &IntervalParseConfig) -> Result { + let components = parse_interval_components(value, config)?; + + components + .into_iter() + .try_fold(Self::default(), |result, (amount, unit)| { + result.add(amount, unit) + }) + } + + /// Interval addition following Postgres behavior. Fractional units will be spilled into smaller units. + /// When the interval unit is larger than months, the result is rounded to total months and not spilled to days/nanos. + /// Fractional parts of weeks and days are represented using days and nanoseconds. + /// e.g. INTERVAL '0.5 MONTH' = 15 days, INTERVAL '1.5 MONTH' = 1 month 15 days + /// e.g. INTERVAL '0.5 DAY' = 12 hours, INTERVAL '1.5 DAY' = 1 day 12 hours + /// [Postgres reference](https://www.postgresql.org/docs/15/datatype-datetime.html#DATATYPE-INTERVAL-INPUT:~:text=Field%20values%20can,fractional%20on%20output.) + fn add(&self, amount: IntervalAmount, unit: IntervalUnit) -> Result { + let result = match unit { + IntervalUnit::Century => { + let months_int = amount.integer.mul_checked(100)?.mul_checked(12)?; + let month_frac = amount.frac * 12 / 10_i64.pow(INTERVAL_PRECISION - 2); + let months = months_int + .add_checked(month_frac)? + .try_into() + .map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} centuries as months in a signed 32-bit integer", + &amount.integer + )) + })?; + + Self::new(self.months.add_checked(months)?, self.days, self.nanos) + } + IntervalUnit::Decade => { + let months_int = amount.integer.mul_checked(10)?.mul_checked(12)?; + + let month_frac = amount.frac * 12 / 10_i64.pow(INTERVAL_PRECISION - 1); + let months = months_int + .add_checked(month_frac)? + .try_into() + .map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} decades as months in a signed 32-bit integer", + &amount.integer + )) + })?; + + Self::new(self.months.add_checked(months)?, self.days, self.nanos) + } + IntervalUnit::Year => { + let months_int = amount.integer.mul_checked(12)?; + let month_frac = amount.frac * 12 / 10_i64.pow(INTERVAL_PRECISION); + let months = months_int + .add_checked(month_frac)? + .try_into() + .map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} years as months in a signed 32-bit integer", + &amount.integer + )) + })?; + + Self::new(self.months.add_checked(months)?, self.days, self.nanos) + } + IntervalUnit::Month => { + let months = amount.integer.try_into().map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} months in a signed 32-bit integer", + &amount.integer + )) + })?; + + let days = amount.frac * 3 / 10_i64.pow(INTERVAL_PRECISION - 1); + let days = days.try_into().map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} months as days in a signed 32-bit integer", + amount.frac / 10_i64.pow(INTERVAL_PRECISION) + )) + })?; + + Self::new( + self.months.add_checked(months)?, + self.days.add_checked(days)?, + self.nanos, + ) + } + IntervalUnit::Week => { + let days = amount.integer.mul_checked(7)?.try_into().map_err(|_| { + ArrowError::ParseError(format!( + "Unable to represent {} weeks as days in a signed 32-bit integer", + &amount.integer + )) + })?; + + let nanos = amount.frac * 7 * 24 * 6 * 6 / 10_i64.pow(INTERVAL_PRECISION - 11); + + Self::new( + self.months, + self.days.add_checked(days)?, + self.nanos.add_checked(nanos)?, + ) + } + IntervalUnit::Day => { + let days = amount.integer.try_into().map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "Unable to represent {} days in a signed 32-bit integer", + amount.integer + )) + })?; + + let nanos = amount.frac * 24 * 6 * 6 / 10_i64.pow(INTERVAL_PRECISION - 11); + + Self::new( + self.months, + self.days.add_checked(days)?, + self.nanos.add_checked(nanos)?, + ) + } + IntervalUnit::Hour => { + let nanos_int = amount.integer.mul_checked(NANOS_PER_HOUR)?; + let nanos_frac = amount.frac * 6 * 6 / 10_i64.pow(INTERVAL_PRECISION - 11); + let nanos = nanos_int.add_checked(nanos_frac)?; + + Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) + } + IntervalUnit::Minute => { + let nanos_int = amount.integer.mul_checked(NANOS_PER_MINUTE)?; + let nanos_frac = amount.frac * 6 / 10_i64.pow(INTERVAL_PRECISION - 10); + + let nanos = nanos_int.add_checked(nanos_frac)?; + + Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) + } + IntervalUnit::Second => { + let nanos_int = amount.integer.mul_checked(NANOS_PER_SECOND)?; + let nanos_frac = amount.frac / 10_i64.pow(INTERVAL_PRECISION - 9); + let nanos = nanos_int.add_checked(nanos_frac)?; + + Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) + } + IntervalUnit::Millisecond => { + let nanos_int = amount.integer.mul_checked(NANOS_PER_MILLIS)?; + let nanos_frac = amount.frac / 10_i64.pow(INTERVAL_PRECISION - 6); + let nanos = nanos_int.add_checked(nanos_frac)?; + + Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) + } + IntervalUnit::Microsecond => { + let nanos_int = amount.integer.mul_checked(1_000)?; + let nanos_frac = amount.frac / 10_i64.pow(INTERVAL_PRECISION - 3); + let nanos = nanos_int.add_checked(nanos_frac)?; + + Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) + } + IntervalUnit::Nanosecond => { + let nanos_int = amount.integer; + let nanos_frac = amount.frac / 10_i64.pow(INTERVAL_PRECISION); + let nanos = nanos_int.add_checked(nanos_frac)?; + + Interval::new(self.months, self.days, self.nanos.add_checked(nanos)?) + } + }; + + Ok(result) + } +} + +struct IntervalParseConfig { + /// The default unit to use if none is specified + /// e.g. `INTERVAL 1` represents `INTERVAL 1 SECOND` when default_unit = IntervalType::Second + default_unit: IntervalUnit, +} + +impl IntervalParseConfig { + fn new(default_unit: IntervalUnit) -> Self { + Self { default_unit } + } +} + +/// parse the string into a vector of interval components i.e. (amount, unit) tuples +fn parse_interval_components( + value: &str, + config: &IntervalParseConfig, +) -> Result, ArrowError> { + let parts = value.split_whitespace(); + + let raw_amounts = parts.clone().step_by(2); + let raw_units = parts.skip(1).step_by(2); + + // parse amounts + let (amounts, invalid_amounts) = raw_amounts + .map(IntervalAmount::from_str) + .partition::, _>(Result::is_ok); + + // invalid amounts? + if !invalid_amounts.is_empty() { + return Err(ArrowError::NotYetImplemented(format!( + "Unsupported Interval Expression with value {value:?}" + ))); + } + + // parse units + let (units, invalid_units): (Vec<_>, Vec<_>) = raw_units + .clone() + .map(IntervalUnit::from_str) + .partition(Result::is_ok); + + // invalid units? + if !invalid_units.is_empty() { + return Err(ArrowError::ParseError(format!( + "Invalid input syntax for type interval: {value:?}" + ))); + } + + // collect parsed results + let amounts = amounts.into_iter().map(Result::unwrap).collect::>(); + let units = units.into_iter().map(Result::unwrap).collect::>(); + + // if only an amount is specified, use the default unit + if amounts.len() == 1 && units.is_empty() { + return Ok(vec![(amounts[0], config.default_unit)]); + }; + + // duplicate units? + let mut observed_interval_types = 0; + for (unit, raw_unit) in units.iter().zip(raw_units) { + if observed_interval_types & (*unit as u16) != 0 { + return Err(ArrowError::ParseError(format!( + "Invalid input syntax for type interval: {value:?}. Repeated type '{raw_unit}'", + ))); + } + + observed_interval_types |= *unit as u16; + } + + let result = amounts.iter().copied().zip(units.iter().copied()); + + Ok(result.collect::>()) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::temporal_conversions::date32_to_datetime; + use arrow_array::timezone::Tz; + use arrow_buffer::i256; + + #[test] + fn test_parse_nanos() { + assert_eq!(parse_nanos::<3, 0>(&[1, 2, 3]), 123_000_000); + assert_eq!(parse_nanos::<5, 0>(&[1, 2, 3, 4, 5]), 123_450_000); + assert_eq!(parse_nanos::<6, b'0'>(b"123456"), 123_456_000); + } + + #[test] + fn string_to_timestamp_timezone() { + // Explicit timezone + assert_eq!( + 1599572549190855000, + parse_timestamp("2020-09-08T13:42:29.190855+00:00").unwrap() + ); + assert_eq!( + 1599572549190855000, + parse_timestamp("2020-09-08T13:42:29.190855Z").unwrap() + ); + assert_eq!( + 1599572549000000000, + parse_timestamp("2020-09-08T13:42:29Z").unwrap() + ); // no fractional part + assert_eq!( + 1599590549190855000, + parse_timestamp("2020-09-08T13:42:29.190855-05:00").unwrap() + ); + } + + #[test] + fn string_to_timestamp_timezone_space() { + // Ensure space rather than T between time and date is accepted + assert_eq!( + 1599572549190855000, + parse_timestamp("2020-09-08 13:42:29.190855+00:00").unwrap() + ); + assert_eq!( + 1599572549190855000, + parse_timestamp("2020-09-08 13:42:29.190855Z").unwrap() + ); + assert_eq!( + 1599572549000000000, + parse_timestamp("2020-09-08 13:42:29Z").unwrap() + ); // no fractional part + assert_eq!( + 1599590549190855000, + parse_timestamp("2020-09-08 13:42:29.190855-05:00").unwrap() + ); + } + + #[test] + #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function: mktime + fn string_to_timestamp_no_timezone() { + // This test is designed to succeed in regardless of the local + // timezone the test machine is running. Thus it is still + // somewhat susceptible to bugs in the use of chrono + let naive_datetime = NaiveDateTime::new( + NaiveDate::from_ymd_opt(2020, 9, 8).unwrap(), + NaiveTime::from_hms_nano_opt(13, 42, 29, 190855000).unwrap(), + ); + + // Ensure both T and ' ' variants work + assert_eq!( + naive_datetime.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08T13:42:29.190855").unwrap() + ); + + assert_eq!( + naive_datetime.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08 13:42:29.190855").unwrap() + ); + + // Also ensure that parsing timestamps with no fractional + // second part works as well + let naive_datetime_whole_secs = NaiveDateTime::new( + NaiveDate::from_ymd_opt(2020, 9, 8).unwrap(), + NaiveTime::from_hms_opt(13, 42, 29).unwrap(), + ); + + // Ensure both T and ' ' variants work + assert_eq!( + naive_datetime_whole_secs.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08T13:42:29").unwrap() + ); + + assert_eq!( + naive_datetime_whole_secs.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08 13:42:29").unwrap() + ); + + // ensure without time work + // no time, should be the nano second at + // 2020-09-08 0:0:0 + let naive_datetime_no_time = NaiveDateTime::new( + NaiveDate::from_ymd_opt(2020, 9, 8).unwrap(), + NaiveTime::from_hms_opt(0, 0, 0).unwrap(), + ); + + assert_eq!( + naive_datetime_no_time.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08").unwrap() + ) + } + + #[test] + fn string_to_timestamp_chrono() { + let cases = [ + "2020-09-08T13:42:29Z", + "1969-01-01T00:00:00.1Z", + "2020-09-08T12:00:12.12345678+00:00", + "2020-09-08T12:00:12+00:00", + "2020-09-08T12:00:12.1+00:00", + "2020-09-08T12:00:12.12+00:00", + "2020-09-08T12:00:12.123+00:00", + "2020-09-08T12:00:12.1234+00:00", + "2020-09-08T12:00:12.12345+00:00", + "2020-09-08T12:00:12.123456+00:00", + "2020-09-08T12:00:12.1234567+00:00", + "2020-09-08T12:00:12.12345678+00:00", + "2020-09-08T12:00:12.123456789+00:00", + "2020-09-08T12:00:12.12345678912z", + "2020-09-08T12:00:12.123456789123Z", + "2020-09-08T12:00:12.123456789123+02:00", + "2020-09-08T12:00:12.12345678912345Z", + "2020-09-08T12:00:12.1234567891234567+02:00", + "2020-09-08T12:00:60Z", + "2020-09-08T12:00:60.123Z", + "2020-09-08T12:00:60.123456+02:00", + "2020-09-08T12:00:60.1234567891234567+02:00", + "2020-09-08T12:00:60.999999999+02:00", + "2020-09-08t12:00:12.12345678+00:00", + "2020-09-08t12:00:12+00:00", + "2020-09-08t12:00:12Z", + ]; + + for case in cases { + let chrono = DateTime::parse_from_rfc3339(case).unwrap(); + let chrono_utc = chrono.with_timezone(&Utc); + + let custom = string_to_datetime(&Utc, case).unwrap(); + assert_eq!(chrono_utc, custom) + } + } + + #[test] + fn string_to_timestamp_naive() { + let cases = [ + "2018-11-13T17:11:10.011375885995", + "2030-12-04T17:11:10.123", + "2030-12-04T17:11:10.1234", + "2030-12-04T17:11:10.123456", + ]; + for case in cases { + let chrono = NaiveDateTime::parse_from_str(case, "%Y-%m-%dT%H:%M:%S%.f").unwrap(); + let custom = string_to_datetime(&Utc, case).unwrap(); + assert_eq!(chrono, custom.naive_utc()) + } + } + + #[test] + fn string_to_timestamp_invalid() { + // Test parsing invalid formats + let cases = [ + ("", "timestamp must contain at least 10 characters"), + ("SS", "timestamp must contain at least 10 characters"), + ("Wed, 18 Feb 2015 23:16:09 GMT", "error parsing date"), + ("1997-01-31H09:26:56.123Z", "invalid timestamp separator"), + ("1997-01-31 09:26:56.123Z", "error parsing time"), + ("1997:01:31T09:26:56.123Z", "error parsing date"), + ("1997:1:31T09:26:56.123Z", "error parsing date"), + ("1997-01-32T09:26:56.123Z", "error parsing date"), + ("1997-13-32T09:26:56.123Z", "error parsing date"), + ("1997-02-29T09:26:56.123Z", "error parsing date"), + ("2015-02-30T17:35:20-08:00", "error parsing date"), + ("1997-01-10T9:26:56.123Z", "error parsing time"), + ("2015-01-20T25:35:20-08:00", "error parsing time"), + ("1997-01-10T09:61:56.123Z", "error parsing time"), + ("1997-01-10T09:61:90.123Z", "error parsing time"), + ("1997-01-10T12:00:6.123Z", "error parsing time"), + ("1997-01-31T092656.123Z", "error parsing time"), + ("1997-01-10T12:00:06.", "error parsing time"), + ("1997-01-10T12:00:06. ", "error parsing time"), + ]; + + for (s, ctx) in cases { + let expected = format!("Parser error: Error parsing timestamp from '{s}': {ctx}"); + let actual = string_to_datetime(&Utc, s).unwrap_err().to_string(); + assert_eq!(actual, expected) + } + } + + // Parse a timestamp to timestamp int with a useful human readable error message + fn parse_timestamp(s: &str) -> Result { + let result = string_to_timestamp_nanos(s); + if let Err(e) = &result { + eprintln!("Error parsing timestamp '{s}': {e:?}"); + } + result + } + + #[test] + fn string_without_timezone_to_timestamp() { + // string without timezone should always output the same regardless the local or session timezone + + let naive_datetime = NaiveDateTime::new( + NaiveDate::from_ymd_opt(2020, 9, 8).unwrap(), + NaiveTime::from_hms_nano_opt(13, 42, 29, 190855000).unwrap(), + ); + + // Ensure both T and ' ' variants work + assert_eq!( + naive_datetime.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08T13:42:29.190855").unwrap() + ); + + assert_eq!( + naive_datetime.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08 13:42:29.190855").unwrap() + ); + + let naive_datetime = NaiveDateTime::new( + NaiveDate::from_ymd_opt(2020, 9, 8).unwrap(), + NaiveTime::from_hms_nano_opt(13, 42, 29, 0).unwrap(), + ); + + // Ensure both T and ' ' variants work + assert_eq!( + naive_datetime.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08T13:42:29").unwrap() + ); + + assert_eq!( + naive_datetime.timestamp_nanos_opt().unwrap(), + parse_timestamp("2020-09-08 13:42:29").unwrap() + ); + + let tz: Tz = "+02:00".parse().unwrap(); + let date = string_to_datetime(&tz, "2020-09-08 13:42:29").unwrap(); + let utc = date.naive_utc().to_string(); + assert_eq!(utc, "2020-09-08 11:42:29"); + let local = date.naive_local().to_string(); + assert_eq!(local, "2020-09-08 13:42:29"); + + let date = string_to_datetime(&tz, "2020-09-08 13:42:29Z").unwrap(); + let utc = date.naive_utc().to_string(); + assert_eq!(utc, "2020-09-08 13:42:29"); + let local = date.naive_local().to_string(); + assert_eq!(local, "2020-09-08 15:42:29"); + + let dt = + NaiveDateTime::parse_from_str("2020-09-08T13:42:29Z", "%Y-%m-%dT%H:%M:%SZ").unwrap(); + let local: Tz = "+08:00".parse().unwrap(); + + // Parsed as offset from UTC + let date = string_to_datetime(&local, "2020-09-08T13:42:29Z").unwrap(); + assert_eq!(dt, date.naive_utc()); + assert_ne!(dt, date.naive_local()); + + // Parsed as offset from local + let date = string_to_datetime(&local, "2020-09-08 13:42:29").unwrap(); + assert_eq!(dt, date.naive_local()); + assert_ne!(dt, date.naive_utc()); + } + + #[test] + fn parse_date32() { + let cases = [ + "2020-09-08", + "2020-9-8", + "2020-09-8", + "2020-9-08", + "2020-12-1", + "1690-2-5", + ]; + for case in cases { + let v = date32_to_datetime(Date32Type::parse(case).unwrap()).unwrap(); + let expected: NaiveDate = case.parse().unwrap(); + assert_eq!(v.date(), expected); + } + + let err_cases = [ + "", + "80-01-01", + "342", + "Foo", + "2020-09-08-03", + "2020--04-03", + "2020--", + ]; + for case in err_cases { + assert_eq!(Date32Type::parse(case), None); + } + } + + #[test] + fn parse_time64_nanos() { + assert_eq!( + Time64NanosecondType::parse("02:10:01.1234567899999999"), + Some(7_801_123_456_789) + ); + assert_eq!( + Time64NanosecondType::parse("02:10:01.1234567"), + Some(7_801_123_456_700) + ); + assert_eq!( + Time64NanosecondType::parse("2:10:01.1234567"), + Some(7_801_123_456_700) + ); + assert_eq!( + Time64NanosecondType::parse("12:10:01.123456789 AM"), + Some(601_123_456_789) + ); + assert_eq!( + Time64NanosecondType::parse("12:10:01.123456789 am"), + Some(601_123_456_789) + ); + assert_eq!( + Time64NanosecondType::parse("2:10:01.12345678 PM"), + Some(51_001_123_456_780) + ); + assert_eq!( + Time64NanosecondType::parse("2:10:01.12345678 pm"), + Some(51_001_123_456_780) + ); + assert_eq!( + Time64NanosecondType::parse("02:10:01"), + Some(7_801_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("2:10:01"), + Some(7_801_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("12:10:01 AM"), + Some(601_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("12:10:01 am"), + Some(601_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("2:10:01 PM"), + Some(51_001_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("2:10:01 pm"), + Some(51_001_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("02:10"), + Some(7_800_000_000_000) + ); + assert_eq!(Time64NanosecondType::parse("2:10"), Some(7_800_000_000_000)); + assert_eq!( + Time64NanosecondType::parse("12:10 AM"), + Some(600_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("12:10 am"), + Some(600_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("2:10 PM"), + Some(51_000_000_000_000) + ); + assert_eq!( + Time64NanosecondType::parse("2:10 pm"), + Some(51_000_000_000_000) + ); + + // parse directly as nanoseconds + assert_eq!(Time64NanosecondType::parse("1"), Some(1)); + + // leap second + assert_eq!( + Time64NanosecondType::parse("23:59:60"), + Some(86_400_000_000_000) + ); + + // custom format + assert_eq!( + Time64NanosecondType::parse_formatted("02 - 10 - 01 - .1234567", "%H - %M - %S - %.f"), + Some(7_801_123_456_700) + ); + } + + #[test] + fn parse_time64_micros() { + // expected formats + assert_eq!( + Time64MicrosecondType::parse("02:10:01.1234"), + Some(7_801_123_400) + ); + assert_eq!( + Time64MicrosecondType::parse("2:10:01.1234"), + Some(7_801_123_400) + ); + assert_eq!( + Time64MicrosecondType::parse("12:10:01.123456 AM"), + Some(601_123_456) + ); + assert_eq!( + Time64MicrosecondType::parse("12:10:01.123456 am"), + Some(601_123_456) + ); + assert_eq!( + Time64MicrosecondType::parse("2:10:01.12345 PM"), + Some(51_001_123_450) + ); + assert_eq!( + Time64MicrosecondType::parse("2:10:01.12345 pm"), + Some(51_001_123_450) + ); + assert_eq!( + Time64MicrosecondType::parse("02:10:01"), + Some(7_801_000_000) + ); + assert_eq!(Time64MicrosecondType::parse("2:10:01"), Some(7_801_000_000)); + assert_eq!( + Time64MicrosecondType::parse("12:10:01 AM"), + Some(601_000_000) + ); + assert_eq!( + Time64MicrosecondType::parse("12:10:01 am"), + Some(601_000_000) + ); + assert_eq!( + Time64MicrosecondType::parse("2:10:01 PM"), + Some(51_001_000_000) + ); + assert_eq!( + Time64MicrosecondType::parse("2:10:01 pm"), + Some(51_001_000_000) + ); + assert_eq!(Time64MicrosecondType::parse("02:10"), Some(7_800_000_000)); + assert_eq!(Time64MicrosecondType::parse("2:10"), Some(7_800_000_000)); + assert_eq!(Time64MicrosecondType::parse("12:10 AM"), Some(600_000_000)); + assert_eq!(Time64MicrosecondType::parse("12:10 am"), Some(600_000_000)); + assert_eq!( + Time64MicrosecondType::parse("2:10 PM"), + Some(51_000_000_000) + ); + assert_eq!( + Time64MicrosecondType::parse("2:10 pm"), + Some(51_000_000_000) + ); + + // parse directly as microseconds + assert_eq!(Time64MicrosecondType::parse("1"), Some(1)); + + // leap second + assert_eq!( + Time64MicrosecondType::parse("23:59:60"), + Some(86_400_000_000) + ); + + // custom format + assert_eq!( + Time64MicrosecondType::parse_formatted("02 - 10 - 01 - .1234", "%H - %M - %S - %.f"), + Some(7_801_123_400) + ); + } + + #[test] + fn parse_time32_millis() { + // expected formats + assert_eq!(Time32MillisecondType::parse("02:10:01.1"), Some(7_801_100)); + assert_eq!(Time32MillisecondType::parse("2:10:01.1"), Some(7_801_100)); + assert_eq!( + Time32MillisecondType::parse("12:10:01.123 AM"), + Some(601_123) + ); + assert_eq!( + Time32MillisecondType::parse("12:10:01.123 am"), + Some(601_123) + ); + assert_eq!( + Time32MillisecondType::parse("2:10:01.12 PM"), + Some(51_001_120) + ); + assert_eq!( + Time32MillisecondType::parse("2:10:01.12 pm"), + Some(51_001_120) + ); + assert_eq!(Time32MillisecondType::parse("02:10:01"), Some(7_801_000)); + assert_eq!(Time32MillisecondType::parse("2:10:01"), Some(7_801_000)); + assert_eq!(Time32MillisecondType::parse("12:10:01 AM"), Some(601_000)); + assert_eq!(Time32MillisecondType::parse("12:10:01 am"), Some(601_000)); + assert_eq!(Time32MillisecondType::parse("2:10:01 PM"), Some(51_001_000)); + assert_eq!(Time32MillisecondType::parse("2:10:01 pm"), Some(51_001_000)); + assert_eq!(Time32MillisecondType::parse("02:10"), Some(7_800_000)); + assert_eq!(Time32MillisecondType::parse("2:10"), Some(7_800_000)); + assert_eq!(Time32MillisecondType::parse("12:10 AM"), Some(600_000)); + assert_eq!(Time32MillisecondType::parse("12:10 am"), Some(600_000)); + assert_eq!(Time32MillisecondType::parse("2:10 PM"), Some(51_000_000)); + assert_eq!(Time32MillisecondType::parse("2:10 pm"), Some(51_000_000)); + + // parse directly as milliseconds + assert_eq!(Time32MillisecondType::parse("1"), Some(1)); + + // leap second + assert_eq!(Time32MillisecondType::parse("23:59:60"), Some(86_400_000)); + + // custom format + assert_eq!( + Time32MillisecondType::parse_formatted("02 - 10 - 01 - .1", "%H - %M - %S - %.f"), + Some(7_801_100) + ); + } + + #[test] + fn parse_time32_secs() { + // expected formats + assert_eq!(Time32SecondType::parse("02:10:01.1"), Some(7_801)); + assert_eq!(Time32SecondType::parse("02:10:01"), Some(7_801)); + assert_eq!(Time32SecondType::parse("2:10:01"), Some(7_801)); + assert_eq!(Time32SecondType::parse("12:10:01 AM"), Some(601)); + assert_eq!(Time32SecondType::parse("12:10:01 am"), Some(601)); + assert_eq!(Time32SecondType::parse("2:10:01 PM"), Some(51_001)); + assert_eq!(Time32SecondType::parse("2:10:01 pm"), Some(51_001)); + assert_eq!(Time32SecondType::parse("02:10"), Some(7_800)); + assert_eq!(Time32SecondType::parse("2:10"), Some(7_800)); + assert_eq!(Time32SecondType::parse("12:10 AM"), Some(600)); + assert_eq!(Time32SecondType::parse("12:10 am"), Some(600)); + assert_eq!(Time32SecondType::parse("2:10 PM"), Some(51_000)); + assert_eq!(Time32SecondType::parse("2:10 pm"), Some(51_000)); + + // parse directly as seconds + assert_eq!(Time32SecondType::parse("1"), Some(1)); + + // leap second + assert_eq!(Time32SecondType::parse("23:59:60"), Some(86400)); + + // custom format + assert_eq!( + Time32SecondType::parse_formatted("02 - 10 - 01", "%H - %M - %S"), + Some(7_801) + ); + } + + #[test] + fn test_string_to_time_invalid() { + let cases = [ + "25:00", + "9:00:", + "009:00", + "09:0:00", + "25:00:00", + "13:00 AM", + "13:00 PM", + "12:00. AM", + "09:0:00", + "09:01:0", + "09:01:1", + "9:1:0", + "09:01:0", + "1:00.123", + "1:00:00.123f", + " 9:00:00", + ":09:00", + "T9:00:00", + "AM", + ]; + for case in cases { + assert!(string_to_time(case).is_none(), "{case}"); + } + } + + #[test] + fn test_string_to_time_chrono() { + let cases = [ + ("1:00", "%H:%M"), + ("12:00", "%H:%M"), + ("13:00", "%H:%M"), + ("24:00", "%H:%M"), + ("1:00:00", "%H:%M:%S"), + ("12:00:30", "%H:%M:%S"), + ("13:00:59", "%H:%M:%S"), + ("24:00:60", "%H:%M:%S"), + ("09:00:00", "%H:%M:%S%.f"), + ("0:00:30.123456", "%H:%M:%S%.f"), + ("0:00 AM", "%I:%M %P"), + ("1:00 AM", "%I:%M %P"), + ("12:00 AM", "%I:%M %P"), + ("13:00 AM", "%I:%M %P"), + ("0:00 PM", "%I:%M %P"), + ("1:00 PM", "%I:%M %P"), + ("12:00 PM", "%I:%M %P"), + ("13:00 PM", "%I:%M %P"), + ("1:00 pM", "%I:%M %P"), + ("1:00 Pm", "%I:%M %P"), + ("1:00 aM", "%I:%M %P"), + ("1:00 Am", "%I:%M %P"), + ("1:00:30.123456 PM", "%I:%M:%S%.f %P"), + ("1:00:30.123456789 PM", "%I:%M:%S%.f %P"), + ("1:00:30.123456789123 PM", "%I:%M:%S%.f %P"), + ("1:00:30.1234 PM", "%I:%M:%S%.f %P"), + ("1:00:30.123456 PM", "%I:%M:%S%.f %P"), + ("1:00:30.123456789123456789 PM", "%I:%M:%S%.f %P"), + ("1:00:30.12F456 PM", "%I:%M:%S%.f %P"), + ]; + for (s, format) in cases { + let chrono = NaiveTime::parse_from_str(s, format).ok(); + let custom = string_to_time(s); + assert_eq!(chrono, custom, "{s}"); + } + } + + #[test] + fn test_parse_interval() { + let config = IntervalParseConfig::new(IntervalUnit::Month); + + assert_eq!( + Interval::new(1i32, 0i32, 0i64), + Interval::parse("1 month", &config).unwrap(), + ); + + assert_eq!( + Interval::new(2i32, 0i32, 0i64), + Interval::parse("2 month", &config).unwrap(), + ); + + assert_eq!( + Interval::new(-1i32, -18i32, -(NANOS_PER_DAY / 5)), + Interval::parse("-1.5 months -3.2 days", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, 15i32, 0), + Interval::parse("0.5 months", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, 15i32, 0), + Interval::parse(".5 months", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, -15i32, 0), + Interval::parse("-0.5 months", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, -15i32, 0), + Interval::parse("-.5 months", &config).unwrap(), + ); + + assert_eq!( + Interval::new(2i32, 10i32, 9 * NANOS_PER_HOUR), + Interval::parse("2.1 months 7.25 days 3 hours", &config).unwrap(), + ); + + assert_eq!( + Interval::parse("1 centurys 1 month", &config) + .unwrap_err() + .to_string(), + r#"Parser error: Invalid input syntax for type interval: "1 centurys 1 month""# + ); + + assert_eq!( + Interval::new(37i32, 0i32, 0i64), + Interval::parse("3 year 1 month", &config).unwrap(), + ); + + assert_eq!( + Interval::new(35i32, 0i32, 0i64), + Interval::parse("3 year -1 month", &config).unwrap(), + ); + + assert_eq!( + Interval::new(-37i32, 0i32, 0i64), + Interval::parse("-3 year -1 month", &config).unwrap(), + ); + + assert_eq!( + Interval::new(-35i32, 0i32, 0i64), + Interval::parse("-3 year 1 month", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, 5i32, 0i64), + Interval::parse("5 days", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, 7i32, 3 * NANOS_PER_HOUR), + Interval::parse("7 days 3 hours", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, 7i32, 5 * NANOS_PER_MINUTE), + Interval::parse("7 days 5 minutes", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, 7i32, -5 * NANOS_PER_MINUTE), + Interval::parse("7 days -5 minutes", &config).unwrap(), + ); + + assert_eq!( + Interval::new(0i32, -7i32, 5 * NANOS_PER_HOUR), + Interval::parse("-7 days 5 hours", &config).unwrap(), + ); + + assert_eq!( + Interval::new( + 0i32, + -7i32, + -5 * NANOS_PER_HOUR - 5 * NANOS_PER_MINUTE - 5 * NANOS_PER_SECOND + ), + Interval::parse("-7 days -5 hours -5 minutes -5 seconds", &config).unwrap(), + ); + + assert_eq!( + Interval::new(12i32, 0i32, 25 * NANOS_PER_MILLIS), + Interval::parse("1 year 25 millisecond", &config).unwrap(), + ); + + assert_eq!( + Interval::new( + 12i32, + 1i32, + (NANOS_PER_SECOND as f64 * 0.000000001_f64) as i64 + ), + Interval::parse("1 year 1 day 0.000000001 seconds", &config).unwrap(), + ); + + assert_eq!( + Interval::new(12i32, 1i32, NANOS_PER_MILLIS / 10), + Interval::parse("1 year 1 day 0.1 milliseconds", &config).unwrap(), + ); + + assert_eq!( + Interval::new(12i32, 1i32, 1000i64), + Interval::parse("1 year 1 day 1 microsecond", &config).unwrap(), + ); + + assert_eq!( + Interval::new(12i32, 1i32, 1i64), + Interval::parse("1 year 1 day 1 nanoseconds", &config).unwrap(), + ); + + assert_eq!( + Interval::new(1i32, 0i32, -NANOS_PER_SECOND), + Interval::parse("1 month -1 second", &config).unwrap(), + ); + + assert_eq!( + Interval::new( + -13i32, + -8i32, + -NANOS_PER_HOUR + - NANOS_PER_MINUTE + - NANOS_PER_SECOND + - (1.11_f64 * NANOS_PER_MILLIS as f64) as i64 + ), + Interval::parse( + "-1 year -1 month -1 week -1 day -1 hour -1 minute -1 second -1.11 millisecond", + &config + ) + .unwrap(), + ); + } + + #[test] + fn test_duplicate_interval_type() { + let config = IntervalParseConfig::new(IntervalUnit::Month); + + let err = Interval::parse("1 month 1 second 1 second", &config) + .expect_err("parsing interval should have failed"); + assert_eq!( + r#"ParseError("Invalid input syntax for type interval: \"1 month 1 second 1 second\". Repeated type 'second'")"#, + format!("{err:?}") + ); + + // test with singular and plural forms + let err = Interval::parse("1 century 2 centuries", &config) + .expect_err("parsing interval should have failed"); + assert_eq!( + r#"ParseError("Invalid input syntax for type interval: \"1 century 2 centuries\". Repeated type 'centuries'")"#, + format!("{err:?}") + ); + } + + #[test] + fn test_interval_amount_parsing() { + // integer + let result = IntervalAmount::from_str("123").unwrap(); + let expected = IntervalAmount::new(123, 0); + + assert_eq!(result, expected); + + // positive w/ fractional + let result = IntervalAmount::from_str("0.3").unwrap(); + let expected = IntervalAmount::new(0, 3 * 10_i64.pow(INTERVAL_PRECISION - 1)); + + assert_eq!(result, expected); + + // negative w/ fractional + let result = IntervalAmount::from_str("-3.5").unwrap(); + let expected = IntervalAmount::new(-3, -5 * 10_i64.pow(INTERVAL_PRECISION - 1)); + + assert_eq!(result, expected); + + // invalid: missing fractional + let result = IntervalAmount::from_str("3."); + assert!(result.is_err()); + + // invalid: sign in fractional + let result = IntervalAmount::from_str("3.-5"); + assert!(result.is_err()); + } + + #[test] + fn test_interval_precision() { + let config = IntervalParseConfig::new(IntervalUnit::Month); + + let result = Interval::parse("100000.1 days", &config).unwrap(); + let expected = Interval::new(0_i32, 100_000_i32, NANOS_PER_DAY / 10); + + assert_eq!(result, expected); + } + + #[test] + fn test_interval_addition() { + // add 4.1 centuries + let start = Interval::new(1, 2, 3); + let expected = Interval::new(4921, 2, 3); + + let result = start + .add( + IntervalAmount::new(4, 10_i64.pow(INTERVAL_PRECISION - 1)), + IntervalUnit::Century, + ) + .unwrap(); + + assert_eq!(result, expected); + + // add 10.25 decades + let start = Interval::new(1, 2, 3); + let expected = Interval::new(1231, 2, 3); + + let result = start + .add( + IntervalAmount::new(10, 25 * 10_i64.pow(INTERVAL_PRECISION - 2)), + IntervalUnit::Decade, + ) + .unwrap(); + + assert_eq!(result, expected); + + // add 30.3 years (reminder: Postgres logic does not spill to days/nanos when interval is larger than a month) + let start = Interval::new(1, 2, 3); + let expected = Interval::new(364, 2, 3); + + let result = start + .add( + IntervalAmount::new(30, 3 * 10_i64.pow(INTERVAL_PRECISION - 1)), + IntervalUnit::Year, + ) + .unwrap(); + + assert_eq!(result, expected); + + // add 1.5 months + let start = Interval::new(1, 2, 3); + let expected = Interval::new(2, 17, 3); + + let result = start + .add( + IntervalAmount::new(1, 5 * 10_i64.pow(INTERVAL_PRECISION - 1)), + IntervalUnit::Month, + ) + .unwrap(); + + assert_eq!(result, expected); + + // add -2 weeks + let start = Interval::new(1, 25, 3); + let expected = Interval::new(1, 11, 3); + + let result = start + .add(IntervalAmount::new(-2, 0), IntervalUnit::Week) + .unwrap(); + + assert_eq!(result, expected); + + // add 2.2 days + let start = Interval::new(12, 15, 3); + let expected = Interval::new(12, 17, 3 + 17_280 * NANOS_PER_SECOND); + + let result = start + .add( + IntervalAmount::new(2, 2 * 10_i64.pow(INTERVAL_PRECISION - 1)), + IntervalUnit::Day, + ) + .unwrap(); + + assert_eq!(result, expected); + + // add 12.5 hours + let start = Interval::new(1, 2, 3); + let expected = Interval::new(1, 2, 3 + 45_000 * NANOS_PER_SECOND); + + let result = start + .add( + IntervalAmount::new(12, 5 * 10_i64.pow(INTERVAL_PRECISION - 1)), + IntervalUnit::Hour, + ) + .unwrap(); + + assert_eq!(result, expected); + + // add -1.5 minutes + let start = Interval::new(0, 0, -3); + let expected = Interval::new(0, 0, -90_000_000_000 - 3); + + let result = start + .add( + IntervalAmount::new(-1, -5 * 10_i64.pow(INTERVAL_PRECISION - 1)), + IntervalUnit::Minute, + ) + .unwrap(); + + assert_eq!(result, expected); + } + + #[test] + fn string_to_timestamp_old() { + parse_timestamp("1677-06-14T07:29:01.256") + .map_err(|e| assert!(e.to_string().ends_with(ERR_NANOSECONDS_NOT_SUPPORTED))) + .unwrap_err(); + } + + #[test] + fn test_parse_decimal_with_parameter() { + let tests = [ + ("0", 0i128), + ("123.123", 123123i128), + ("123.1234", 123123i128), + ("123.1", 123100i128), + ("123", 123000i128), + ("-123.123", -123123i128), + ("-123.1234", -123123i128), + ("-123.1", -123100i128), + ("-123", -123000i128), + ("0.0000123", 0i128), + ("12.", 12000i128), + ("-12.", -12000i128), + ("00.1", 100i128), + ("-00.1", -100i128), + ("12345678912345678.1234", 12345678912345678123i128), + ("-12345678912345678.1234", -12345678912345678123i128), + ("99999999999999999.999", 99999999999999999999i128), + ("-99999999999999999.999", -99999999999999999999i128), + (".123", 123i128), + ("-.123", -123i128), + ("123.", 123000i128), + ("-123.", -123000i128), + ]; + for (s, i) in tests { + let result_128 = parse_decimal::(s, 20, 3); + assert_eq!(i, result_128.unwrap()); + let result_256 = parse_decimal::(s, 20, 3); + assert_eq!(i256::from_i128(i), result_256.unwrap()); + } + let can_not_parse_tests = ["123,123", ".", "123.123.123", "", "+", "-"]; + for s in can_not_parse_tests { + let result_128 = parse_decimal::(s, 20, 3); + assert_eq!( + format!("Parser error: can't parse the string value {s} to decimal"), + result_128.unwrap_err().to_string() + ); + let result_256 = parse_decimal::(s, 20, 3); + assert_eq!( + format!("Parser error: can't parse the string value {s} to decimal"), + result_256.unwrap_err().to_string() + ); + } + let overflow_parse_tests = ["12345678", "12345678.9", "99999999.99"]; + for s in overflow_parse_tests { + let result_128 = parse_decimal::(s, 10, 3); + let expected_128 = "Parser error: parse decimal overflow"; + let actual_128 = result_128.unwrap_err().to_string(); + + assert!( + actual_128.contains(expected_128), + "actual: '{actual_128}', expected: '{expected_128}'" + ); + + let result_256 = parse_decimal::(s, 10, 3); + let expected_256 = "Parser error: parse decimal overflow"; + let actual_256 = result_256.unwrap_err().to_string(); + + assert!( + actual_256.contains(expected_256), + "actual: '{actual_256}', expected: '{expected_256}'" + ); + } + + let edge_tests_128 = [ + ( + "99999999999999999999999999999999999999", + 99999999999999999999999999999999999999i128, + 0, + ), + ( + "999999999999999999999999999999999999.99", + 99999999999999999999999999999999999999i128, + 2, + ), + ( + "9999999999999999999999999.9999999999999", + 99999999999999999999999999999999999999i128, + 13, + ), + ( + "9999999999999999999999999", + 99999999999999999999999990000000000000i128, + 13, + ), + ( + "0.99999999999999999999999999999999999999", + 99999999999999999999999999999999999999i128, + 38, + ), + ]; + for (s, i, scale) in edge_tests_128 { + let result_128 = parse_decimal::(s, 38, scale); + assert_eq!(i, result_128.unwrap()); + } + let edge_tests_256 = [ + ( + "9999999999999999999999999999999999999999999999999999999999999999999999999999", + i256::from_string( + "9999999999999999999999999999999999999999999999999999999999999999999999999999", + ) + .unwrap(), + 0, + ), + ( + "999999999999999999999999999999999999999999999999999999999999999999999999.9999", + i256::from_string( + "9999999999999999999999999999999999999999999999999999999999999999999999999999", + ) + .unwrap(), + 4, + ), + ( + "99999999999999999999999999999999999999999999999999.99999999999999999999999999", + i256::from_string( + "9999999999999999999999999999999999999999999999999999999999999999999999999999", + ) + .unwrap(), + 26, + ), + ( + "99999999999999999999999999999999999999999999999999", + i256::from_string( + "9999999999999999999999999999999999999999999999999900000000000000000000000000", + ) + .unwrap(), + 26, + ), + ]; + for (s, i, scale) in edge_tests_256 { + let result = parse_decimal::(s, 76, scale); + assert_eq!(i, result.unwrap()); + } + } +} diff --git a/arrow/src/util/pretty.rs b/arrow-cast/src/pretty.rs similarity index 56% rename from arrow/src/util/pretty.rs rename to arrow-cast/src/pretty.rs index b0013619b50c..550afa9f739d 100644 --- a/arrow/src/util/pretty.rs +++ b/arrow-cast/src/pretty.rs @@ -15,44 +15,60 @@ // specific language governing permissions and limitations // under the License. -//! Utilities for printing record batches. Note this module is not +//! Utilities for pretty printing record batches. Note this module is not //! available unless `feature = "prettyprint"` is enabled. -use crate::{array::ArrayRef, record_batch::RecordBatch}; +use crate::display::{ArrayFormatter, FormatOptions}; +use arrow_array::{Array, ArrayRef, RecordBatch}; +use arrow_schema::ArrowError; use comfy_table::{Cell, Table}; use std::fmt::Display; -use crate::error::Result; - -use super::display::array_value_to_string; +/// Create a visual representation of record batches +pub fn pretty_format_batches(results: &[RecordBatch]) -> Result { + let options = FormatOptions::default().with_display_error(true); + pretty_format_batches_with_options(results, &options) +} -///! Create a visual representation of record batches -pub fn pretty_format_batches(results: &[RecordBatch]) -> Result { - create_table(results) +/// Create a visual representation of record batches +pub fn pretty_format_batches_with_options( + results: &[RecordBatch], + options: &FormatOptions, +) -> Result { + create_table(results, options) } -///! Create a visual representation of columns +/// Create a visual representation of columns pub fn pretty_format_columns( col_name: &str, results: &[ArrayRef], -) -> Result { - create_column(col_name, results) +) -> Result { + let options = FormatOptions::default().with_display_error(true); + pretty_format_columns_with_options(col_name, results, &options) +} + +pub fn pretty_format_columns_with_options( + col_name: &str, + results: &[ArrayRef], + options: &FormatOptions, +) -> Result { + create_column(col_name, results, options) } -///! Prints a visual representation of record batches to stdout -pub fn print_batches(results: &[RecordBatch]) -> Result<()> { - println!("{}", create_table(results)?); +/// Prints a visual representation of record batches to stdout +pub fn print_batches(results: &[RecordBatch]) -> Result<(), ArrowError> { + println!("{}", pretty_format_batches(results)?); Ok(()) } -///! Prints a visual representation of a list of column to stdout -pub fn print_columns(col_name: &str, results: &[ArrayRef]) -> Result<()> { - println!("{}", create_column(col_name, results)?); +/// Prints a visual representation of a list of column to stdout +pub fn print_columns(col_name: &str, results: &[ArrayRef]) -> Result<(), ArrowError> { + println!("{}", pretty_format_columns(col_name, results)?); Ok(()) } -///! Convert a series of record batches into a table -fn create_table(results: &[RecordBatch]) -> Result { +/// Convert a series of record batches into a table +fn create_table(results: &[RecordBatch], options: &FormatOptions) -> Result { let mut table = Table::new(); table.load_preset("||--+-++| ++++++"); @@ -64,16 +80,21 @@ fn create_table(results: &[RecordBatch]) -> Result
{ let mut header = Vec::new(); for field in schema.fields() { - header.push(Cell::new(&field.name())); + header.push(Cell::new(field.name())); } table.set_header(header); for batch in results { + let formatters = batch + .columns() + .iter() + .map(|c| ArrayFormatter::try_new(c.as_ref(), options)) + .collect::, ArrowError>>()?; + for row in 0..batch.num_rows() { let mut cells = Vec::new(); - for col in 0..batch.num_columns() { - let column = batch.column(col); - cells.push(Cell::new(&array_value_to_string(column, row)?)); + for formatter in &formatters { + cells.push(Cell::new(formatter.value(row))); } table.add_row(cells); } @@ -82,7 +103,11 @@ fn create_table(results: &[RecordBatch]) -> Result
{ Ok(table) } -fn create_column(field: &str, columns: &[ArrayRef]) -> Result
{ +fn create_column( + field: &str, + columns: &[ArrayRef], + options: &FormatOptions, +) -> Result { let mut table = Table::new(); table.load_preset("||--+-++| ++++++"); @@ -94,8 +119,9 @@ fn create_column(field: &str, columns: &[ArrayRef]) -> Result
{ table.set_header(header); for col in columns { + let formatter = ArrayFormatter::try_new(col.as_ref(), options)?; for row in 0..col.len() { - let cells = vec![Cell::new(&array_value_to_string(col, row)?)]; + let cells = vec![Cell::new(formatter.value(row))]; table.add_row(cells); } } @@ -105,28 +131,20 @@ fn create_column(field: &str, columns: &[ArrayRef]) -> Result
{ #[cfg(test)] mod tests { - use crate::{ - array::{ - self, new_null_array, Array, Date32Array, Date64Array, - FixedSizeBinaryBuilder, Float16Array, Int32Array, PrimitiveBuilder, - StringArray, StringBuilder, StringDictionaryBuilder, StructArray, - Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, - Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UnionArray, UnionBuilder, - }, - buffer::Buffer, - datatypes::{DataType, Field, Float64Type, Int32Type, Schema, UnionMode}, - }; use super::*; - use crate::array::{Decimal128Array, FixedSizeListBuilder}; + use crate::display::array_value_to_string; + use arrow_array::builder::*; + use arrow_array::types::*; + use arrow_array::*; + use arrow_buffer::Buffer; + use arrow_schema::*; + use half::f16; use std::fmt::Write; use std::sync::Arc; - use half::f16; - #[test] - fn test_pretty_format_batches() -> Result<()> { + fn test_pretty_format_batches() { // define a schema. let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, true), @@ -150,9 +168,10 @@ mod tests { Some(100), ])), ], - )?; + ) + .unwrap(); - let table = pretty_format_batches(&[batch])?.to_string(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let expected = vec![ "+---+-----+", @@ -167,13 +186,11 @@ mod tests { let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n{}", table); - - Ok(()) + assert_eq!(expected, actual, "Actual result:\n{table}"); } #[test] - fn test_pretty_format_columns() -> Result<()> { + fn test_pretty_format_columns() { let columns = vec![ Arc::new(array::StringArray::from(vec![ Some("a"), @@ -184,18 +201,16 @@ mod tests { Arc::new(array::StringArray::from(vec![Some("e"), None, Some("g")])), ]; - let table = pretty_format_columns("a", &columns)?.to_string(); + let table = pretty_format_columns("a", &columns).unwrap().to_string(); let expected = vec![ - "+---+", "| a |", "+---+", "| a |", "| b |", "| |", "| d |", "| e |", - "| |", "| g |", "+---+", + "+---+", "| a |", "+---+", "| a |", "| b |", "| |", "| d |", "| e |", "| |", + "| g |", "+---+", ]; let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n{}", table); - - Ok(()) + assert_eq!(expected, actual, "Actual result:\n{table}"); } #[test] @@ -231,28 +246,25 @@ mod tests { let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n{:#?}", table); + assert_eq!(expected, actual, "Actual result:\n{table:#?}"); } #[test] - fn test_pretty_format_dictionary() -> Result<()> { + fn test_pretty_format_dictionary() { // define a schema. - let field_type = - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); - let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); + let field = Field::new_dictionary("d1", DataType::Int32, DataType::Utf8, true); + let schema = Arc::new(Schema::new(vec![field])); - let keys_builder = PrimitiveBuilder::::with_capacity(10); - let values_builder = StringBuilder::new(); - let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder); + let mut builder = StringDictionaryBuilder::::new(); - builder.append("one")?; + builder.append_value("one"); builder.append_null(); - builder.append("three")?; + builder.append_value("three"); let array = Arc::new(builder.finish()); - let batch = RecordBatch::try_new(schema, vec![array])?; + let batch = RecordBatch::try_new(schema, vec![array]).unwrap(); - let table = pretty_format_batches(&[batch])?.to_string(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let expected = vec![ "+-------+", @@ -266,18 +278,14 @@ mod tests { let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n{}", table); - - Ok(()) + assert_eq!(expected, actual, "Actual result:\n{table}"); } #[test] - fn test_pretty_format_fixed_size_list() -> Result<()> { + fn test_pretty_format_fixed_size_list() { // define a schema. - let field_type = DataType::FixedSizeList( - Box::new(Field::new("item", DataType::Int32, true)), - 3, - ); + let field_type = + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 3); let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); let keys_builder = Int32Array::builder(3); @@ -292,8 +300,8 @@ mod tests { let array = Arc::new(builder.finish()); - let batch = RecordBatch::try_new(schema, vec![array])?; - let table = pretty_format_batches(&[batch])?.to_string(); + let batch = RecordBatch::try_new(schema, vec![array]).unwrap(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let expected = vec![ "+-----------+", "| d1 |", @@ -306,27 +314,25 @@ mod tests { let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n{}", table); - - Ok(()) + assert_eq!(expected, actual, "Actual result:\n{table}"); } #[test] - fn test_pretty_format_fixed_size_binary() -> Result<()> { + fn test_pretty_format_fixed_size_binary() { // define a schema. let field_type = DataType::FixedSizeBinary(3); let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); let mut builder = FixedSizeBinaryBuilder::with_capacity(3, 3); - builder.append_value(&[1, 2, 3]).unwrap(); + builder.append_value([1, 2, 3]).unwrap(); builder.append_null(); - builder.append_value(&[7, 8, 9]).unwrap(); + builder.append_value([7, 8, 9]).unwrap(); let array = Arc::new(builder.finish()); - let batch = RecordBatch::try_new(schema, vec![array])?; - let table = pretty_format_batches(&[batch])?.to_string(); + let batch = RecordBatch::try_new(schema, vec![array]).unwrap(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let expected = vec![ "+--------+", "| d1 |", @@ -339,9 +345,7 @@ mod tests { let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n{}", table); - - Ok(()) + assert_eq!(expected, actual, "Actual result:\n{table}"); } /// Generate an array with type $ARRAYTYPE with a numeric value of @@ -368,17 +372,49 @@ mod tests { let expected = $EXPECTED_RESULT; let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n\n{:#?}\n\n", actual); + assert_eq!(expected, actual, "Actual result:\n\n{actual:#?}\n\n"); }; } + fn timestamp_batch(timezone: &str, value: T::Native) -> RecordBatch { + let mut builder = PrimitiveBuilder::::with_capacity(10); + builder.append_value(value); + builder.append_null(); + let array = builder.finish(); + let array = array.with_timezone(timezone); + + let schema = Arc::new(Schema::new(vec![Field::new( + "f", + array.data_type().clone(), + true, + )])); + RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap() + } + + #[test] + fn test_pretty_format_timestamp_second_with_fixed_offset_timezone() { + let batch = timestamp_batch::("+08:00", 11111111); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); + + let expected = vec![ + "+---------------------------+", + "| f |", + "+---------------------------+", + "| 1970-05-09T22:25:11+08:00 |", + "| |", + "+---------------------------+", + ]; + let actual: Vec<&str> = table.lines().collect(); + assert_eq!(expected, actual, "Actual result:\n\n{actual:#?}\n\n"); + } + #[test] fn test_pretty_format_timestamp_second() { let expected = vec![ "+---------------------+", "| f |", "+---------------------+", - "| 1970-05-09 14:25:11 |", + "| 1970-05-09T14:25:11 |", "| |", "+---------------------+", ]; @@ -391,7 +427,7 @@ mod tests { "+-------------------------+", "| f |", "+-------------------------+", - "| 1970-01-01 03:05:11.111 |", + "| 1970-01-01T03:05:11.111 |", "| |", "+-------------------------+", ]; @@ -404,7 +440,7 @@ mod tests { "+----------------------------+", "| f |", "+----------------------------+", - "| 1970-01-01 00:00:11.111111 |", + "| 1970-01-01T00:00:11.111111 |", "| |", "+----------------------------+", ]; @@ -417,7 +453,7 @@ mod tests { "+-------------------------------+", "| f |", "+-------------------------------+", - "| 1970-01-01 00:00:00.011111111 |", + "| 1970-01-01T00:00:00.011111111 |", "| |", "+-------------------------------+", ]; @@ -440,12 +476,12 @@ mod tests { #[test] fn test_pretty_format_date_64() { let expected = vec![ - "+------------+", - "| f |", - "+------------+", - "| 2005-03-18 |", - "| |", - "+------------+", + "+---------------------+", + "| f |", + "+---------------------+", + "| 2005-03-18T01:58:20 |", + "| |", + "+---------------------+", ]; check_datetime!(Date64Array, 1111111100000, expected); } @@ -503,7 +539,7 @@ mod tests { } #[test] - fn test_int_display() -> Result<()> { + fn test_int_display() { let array = Arc::new(Int32Array::from(vec![6, 3])) as ArrayRef; let actual_one = array_value_to_string(&array, 0).unwrap(); let expected_one = "6"; @@ -512,11 +548,10 @@ mod tests { let expected_two = "3"; assert_eq!(actual_one, expected_one); assert_eq!(actual_two, expected_two); - Ok(()) } #[test] - fn test_decimal_display() -> Result<()> { + fn test_decimal_display() { let precision = 10; let scale = 2; @@ -534,9 +569,9 @@ mod tests { true, )])); - let batch = RecordBatch::try_new(schema, vec![dm])?; + let batch = RecordBatch::try_new(schema, vec![dm]).unwrap(); - let table = pretty_format_batches(&[batch])?.to_string(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let expected = vec![ "+-------+", @@ -550,13 +585,11 @@ mod tests { ]; let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n{}", table); - - Ok(()) + assert_eq!(expected, actual, "Actual result:\n{table}"); } #[test] - fn test_decimal_display_zero_scale() -> Result<()> { + fn test_decimal_display_zero_scale() { let precision = 5; let scale = 0; @@ -574,33 +607,31 @@ mod tests { true, )])); - let batch = RecordBatch::try_new(schema, vec![dm])?; + let batch = RecordBatch::try_new(schema, vec![dm]).unwrap(); - let table = pretty_format_batches(&[batch])?.to_string(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let expected = vec![ - "+------+", "| f |", "+------+", "| 101 |", "| |", "| 200 |", - "| 3040 |", "+------+", + "+------+", "| f |", "+------+", "| 101 |", "| |", "| 200 |", "| 3040 |", + "+------+", ]; let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n{}", table); - - Ok(()) + assert_eq!(expected, actual, "Actual result:\n{table}"); } #[test] - fn test_pretty_format_struct() -> Result<()> { + fn test_pretty_format_struct() { let schema = Schema::new(vec![ - Field::new( + Field::new_struct( "c1", - DataType::Struct(vec![ - Field::new("c11", DataType::Int32, false), - Field::new( + vec![ + Field::new("c11", DataType::Int32, true), + Field::new_struct( "c12", - DataType::Struct(vec![Field::new("c121", DataType::Utf8, false)]), + vec![Field::new("c121", DataType::Utf8, false)], false, ), - ]), + ], false, ), Field::new("c2", DataType::Utf8, false), @@ -608,47 +639,43 @@ mod tests { let c1 = StructArray::from(vec![ ( - Field::new("c11", DataType::Int32, false), + Arc::new(Field::new("c11", DataType::Int32, true)), Arc::new(Int32Array::from(vec![Some(1), None, Some(5)])) as ArrayRef, ), ( - Field::new( + Arc::new(Field::new_struct( "c12", - DataType::Struct(vec![Field::new("c121", DataType::Utf8, false)]), + vec![Field::new("c121", DataType::Utf8, false)], false, - ), + )), Arc::new(StructArray::from(vec![( - Field::new("c121", DataType::Utf8, false), - Arc::new(StringArray::from(vec![Some("e"), Some("f"), Some("g")])) - as ArrayRef, + Arc::new(Field::new("c121", DataType::Utf8, false)), + Arc::new(StringArray::from(vec![Some("e"), Some("f"), Some("g")])) as ArrayRef, )])) as ArrayRef, ), ]); let c2 = StringArray::from(vec![Some("a"), Some("b"), Some("c")]); let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)]) - .unwrap(); + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)]).unwrap(); - let table = pretty_format_batches(&[batch])?.to_string(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let expected = vec![ - r#"+-------------------------------------+----+"#, - r#"| c1 | c2 |"#, - r#"+-------------------------------------+----+"#, - r#"| {"c11": 1, "c12": {"c121": "e"}} | a |"#, - r#"| {"c11": null, "c12": {"c121": "f"}} | b |"#, - r#"| {"c11": 5, "c12": {"c121": "g"}} | c |"#, - r#"+-------------------------------------+----+"#, + "+--------------------------+----+", + "| c1 | c2 |", + "+--------------------------+----+", + "| {c11: 1, c12: {c121: e}} | a |", + "| {c11: , c12: {c121: f}} | b |", + "| {c11: 5, c12: {c121: g}} | c |", + "+--------------------------+----+", ]; let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n{}", table); - - Ok(()) + assert_eq!(expected, actual, "Actual result:\n{table}"); } #[test] - fn test_pretty_format_dense_union() -> Result<()> { + fn test_pretty_format_dense_union() { let mut builder = UnionBuilder::new_dense(); builder.append::("a", 1).unwrap(); builder.append::("b", 3.2234).unwrap(); @@ -656,22 +683,18 @@ mod tests { builder.append_null::("a").unwrap(); let union = builder.build().unwrap(); - let schema = Schema::new(vec![Field::new( + let schema = Schema::new(vec![Field::new_union( "Teamsters", - DataType::Union( - vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Float64, false), - ], - vec![0, 1], - UnionMode::Dense, - ), - false, + vec![0, 1], + vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Float64, false), + ], + UnionMode::Dense, )]); - let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap(); - let table = pretty_format_batches(&[batch])?.to_string(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let actual: Vec<&str> = table.lines().collect(); let expected = vec![ "+------------+", @@ -685,11 +708,10 @@ mod tests { ]; assert_eq!(expected, actual); - Ok(()) } #[test] - fn test_pretty_format_sparse_union() -> Result<()> { + fn test_pretty_format_sparse_union() { let mut builder = UnionBuilder::new_sparse(); builder.append::("a", 1).unwrap(); builder.append::("b", 3.2234).unwrap(); @@ -697,22 +719,18 @@ mod tests { builder.append_null::("a").unwrap(); let union = builder.build().unwrap(); - let schema = Schema::new(vec![Field::new( + let schema = Schema::new(vec![Field::new_union( "Teamsters", - DataType::Union( - vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Float64, false), - ], - vec![0, 1], - UnionMode::Sparse, - ), - false, + vec![0, 1], + vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Float64, false), + ], + UnionMode::Sparse, )]); - let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap(); - let table = pretty_format_batches(&[batch])?.to_string(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let actual: Vec<&str> = table.lines().collect(); let expected = vec![ "+------------+", @@ -726,11 +744,10 @@ mod tests { ]; assert_eq!(expected, actual); - Ok(()) } #[test] - fn test_pretty_format_nested_union() -> Result<()> { + fn test_pretty_format_nested_union() { //Inner UnionArray let mut builder = UnionBuilder::new_dense(); builder.append::("b", 1).unwrap(); @@ -740,22 +757,19 @@ mod tests { builder.append_null::("c").unwrap(); let inner = builder.build().unwrap(); - let inner_field = Field::new( + let inner_field = Field::new_union( "European Union", - DataType::Union( - vec![ - Field::new("b", DataType::Int32, false), - Field::new("c", DataType::Float64, false), - ], - vec![0, 1], - UnionMode::Dense, - ), - false, + vec![0, 1], + vec![ + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Float64, false), + ], + UnionMode::Dense, ); // Can't use UnionBuilder with non-primitive types, so manually build outer UnionArray let a_array = Int32Array::from(vec![None, None, None, Some(1234), Some(23)]); - let type_ids = Buffer::from_slice_ref(&[1_i8, 1, 0, 0, 1]); + let type_ids = Buffer::from_slice_ref([1_i8, 1, 0, 0, 1]); let children: Vec<(Field, Arc)> = vec![ (Field::new("a", DataType::Int32, true), Arc::new(a_array)), @@ -764,19 +778,15 @@ mod tests { let outer = UnionArray::try_new(&[0, 1], type_ids, None, children).unwrap(); - let schema = Schema::new(vec![Field::new( + let schema = Schema::new(vec![Field::new_union( "Teamsters", - DataType::Union( - vec![Field::new("a", DataType::Int32, true), inner_field], - vec![0, 1], - UnionMode::Sparse, - ), - false, + vec![0, 1], + vec![Field::new("a", DataType::Int32, true), inner_field], + UnionMode::Sparse, )]); - let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(outer)]).unwrap(); - let table = pretty_format_batches(&[batch])?.to_string(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(outer)]).unwrap(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let actual: Vec<&str> = table.lines().collect(); let expected = vec![ "+-----------------------------+", @@ -790,11 +800,10 @@ mod tests { "+-----------------------------+", ]; assert_eq!(expected, actual); - Ok(()) } #[test] - fn test_writing_formatted_batches() -> Result<()> { + fn test_writing_formatted_batches() { // define a schema. let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, true), @@ -818,12 +827,13 @@ mod tests { Some(100), ])), ], - )?; + ) + .unwrap(); let mut buf = String::new(); - write!(&mut buf, "{}", pretty_format_batches(&[batch])?).unwrap(); + write!(&mut buf, "{}", pretty_format_batches(&[batch]).unwrap()).unwrap(); - let s = vec![ + let s = [ "+---+-----+", "| a | b |", "+---+-----+", @@ -835,12 +845,10 @@ mod tests { ]; let expected = s.join("\n"); assert_eq!(expected, buf); - - Ok(()) } #[test] - fn test_float16_display() -> Result<()> { + fn test_float16_display() { let values = vec![ Some(f16::from_f32(f32::NAN)), Some(f16::from_f32(4.0)), @@ -854,18 +862,144 @@ mod tests { true, )])); - let batch = RecordBatch::try_new(schema, vec![array])?; + let batch = RecordBatch::try_new(schema, vec![array]).unwrap(); - let table = pretty_format_batches(&[batch])?.to_string(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let expected = vec![ - "+------+", "| f16 |", "+------+", "| NaN |", "| 4 |", "| -inf |", - "+------+", + "+------+", "| f16 |", "+------+", "| NaN |", "| 4 |", "| -inf |", "+------+", ]; let actual: Vec<&str> = table.lines().collect(); - assert_eq!(expected, actual, "Actual result:\n{}", table); + assert_eq!(expected, actual, "Actual result:\n{table}"); + } + + #[test] + fn test_pretty_format_interval_day_time() { + let arr = Arc::new(arrow_array::IntervalDayTimeArray::from(vec![ + Some(-600000), + Some(4294966295), + Some(4294967295), + Some(1), + Some(10), + Some(100), + ])); + + let schema = Arc::new(Schema::new(vec![Field::new( + "IntervalDayTime", + arr.data_type().clone(), + true, + )])); + + let batch = RecordBatch::try_new(schema, vec![arr]).unwrap(); + + let table = pretty_format_batches(&[batch]).unwrap().to_string(); + + let expected = vec![ + "+----------------------------------------------------+", + "| IntervalDayTime |", + "+----------------------------------------------------+", + "| 0 years 0 mons -1 days 0 hours -10 mins 0.000 secs |", + "| 0 years 0 mons 0 days 0 hours 0 mins -1.001 secs |", + "| 0 years 0 mons 0 days 0 hours 0 mins -0.001 secs |", + "| 0 years 0 mons 0 days 0 hours 0 mins 0.001 secs |", + "| 0 years 0 mons 0 days 0 hours 0 mins 0.010 secs |", + "| 0 years 0 mons 0 days 0 hours 0 mins 0.100 secs |", + "+----------------------------------------------------+", + ]; + + let actual: Vec<&str> = table.lines().collect(); + + assert_eq!(expected, actual, "Actual result:\n{table}"); + } + + #[test] + fn test_pretty_format_interval_month_day_nano_array() { + let arr = Arc::new(arrow_array::IntervalMonthDayNanoArray::from(vec![ + Some(-600000000000), + Some(18446744072709551615), + Some(18446744073709551615), + Some(1), + Some(10), + Some(100), + Some(1_000), + Some(10_000), + Some(100_000), + Some(1_000_000), + Some(10_000_000), + Some(100_000_000), + Some(1_000_000_000), + ])); + + let schema = Arc::new(Schema::new(vec![Field::new( + "IntervalMonthDayNano", + arr.data_type().clone(), + true, + )])); + + let batch = RecordBatch::try_new(schema, vec![arr]).unwrap(); + + let table = pretty_format_batches(&[batch]).unwrap().to_string(); + + let expected = vec![ + "+-----------------------------------------------------------+", + "| IntervalMonthDayNano |", + "+-----------------------------------------------------------+", + "| 0 years -1 mons -1 days 0 hours -10 mins 0.000000000 secs |", + "| 0 years 0 mons 0 days 0 hours 0 mins -1.000000001 secs |", + "| 0 years 0 mons 0 days 0 hours 0 mins -0.000000001 secs |", + "| 0 years 0 mons 0 days 0 hours 0 mins 0.000000001 secs |", + "| 0 years 0 mons 0 days 0 hours 0 mins 0.000000010 secs |", + "| 0 years 0 mons 0 days 0 hours 0 mins 0.000000100 secs |", + "| 0 years 0 mons 0 days 0 hours 0 mins 0.000001000 secs |", + "| 0 years 0 mons 0 days 0 hours 0 mins 0.000010000 secs |", + "| 0 years 0 mons 0 days 0 hours 0 mins 0.000100000 secs |", + "| 0 years 0 mons 0 days 0 hours 0 mins 0.001000000 secs |", + "| 0 years 0 mons 0 days 0 hours 0 mins 0.010000000 secs |", + "| 0 years 0 mons 0 days 0 hours 0 mins 0.100000000 secs |", + "| 0 years 0 mons 0 days 0 hours 0 mins 1.000000000 secs |", + "+-----------------------------------------------------------+", + ]; + + let actual: Vec<&str> = table.lines().collect(); + + assert_eq!(expected, actual, "Actual result:\n{table}"); + } + + #[test] + fn test_format_options() { + let options = FormatOptions::default().with_null("null"); + let array = Int32Array::from(vec![Some(1), Some(2), None, Some(3), Some(4)]); + let batch = RecordBatch::try_from_iter([("my_column_name", Arc::new(array) as _)]).unwrap(); + + let column = pretty_format_columns_with_options( + "my_column_name", + &[batch.column(0).clone()], + &options, + ) + .unwrap() + .to_string(); + + let batch = pretty_format_batches_with_options(&[batch], &options) + .unwrap() + .to_string(); + + let expected = vec![ + "+----------------+", + "| my_column_name |", + "+----------------+", + "| 1 |", + "| 2 |", + "| null |", + "| 3 |", + "| 4 |", + "+----------------+", + ]; + + let actual: Vec<&str> = column.lines().collect(); + assert_eq!(expected, actual, "Actual result:\n{column}"); - Ok(()) + let actual: Vec<&str> = batch.lines().collect(); + assert_eq!(expected, actual, "Actual result:\n{batch}"); } } diff --git a/arrow-csv/Cargo.toml b/arrow-csv/Cargo.toml new file mode 100644 index 000000000000..d29c85c56cfd --- /dev/null +++ b/arrow-csv/Cargo.toml @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "arrow-csv" +version = { workspace = true } +description = "Support for parsing CSV format to and from the Arrow format" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "arrow_csv" +path = "src/lib.rs" +bench = false + +[dependencies] +arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +arrow-cast = { workspace = true } +arrow-data = { workspace = true } +arrow-schema = { workspace = true } +chrono = { workspace = true } +csv = { version = "1.1", default-features = false } +csv-core = { version = "0.1" } +lazy_static = { version = "1.4", default-features = false } +lexical-core = { version = "^0.8", default-features = false } +regex = { version = "1.7.0", default-features = false, features = ["std", "unicode", "perf"] } + +[dev-dependencies] +tempfile = "3.3" +futures = "0.3" +tokio = { version = "1.27", default-features = false, features = ["io-util"] } +bytes = "1.4" diff --git a/arrow-csv/examples/README.md b/arrow-csv/examples/README.md new file mode 100644 index 000000000000..340413e76d94 --- /dev/null +++ b/arrow-csv/examples/README.md @@ -0,0 +1,21 @@ + + +# Examples +- [`csv_calculation.rs`](csv_calculation.rs): performs a simple calculation using the CSV reader \ No newline at end of file diff --git a/arrow-csv/examples/csv_calculation.rs b/arrow-csv/examples/csv_calculation.rs new file mode 100644 index 000000000000..6ce963e2b012 --- /dev/null +++ b/arrow-csv/examples/csv_calculation.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::cast::AsArray; +use arrow_array::types::Int16Type; +use arrow_csv::ReaderBuilder; + +use arrow_schema::{DataType, Field, Schema}; +use std::fs::File; +use std::sync::Arc; + +fn main() { + // read csv from file + let file = File::open("arrow-csv/test/data/example.csv").unwrap(); + let csv_schema = Schema::new(vec![ + Field::new("c1", DataType::Int16, true), + Field::new("c2", DataType::Float32, true), + Field::new("c3", DataType::Utf8, true), + Field::new("c4", DataType::Boolean, true), + ]); + let mut reader = ReaderBuilder::new(Arc::new(csv_schema)) + .with_header(true) + .build(file) + .unwrap(); + + match reader.next() { + Some(r) => match r { + Ok(r) => { + // get the column(0) max value + let col = r.column(0).as_primitive::(); + let max = col.iter().max().flatten(); + println!("max value column(0): {max:?}") + } + Err(e) => { + println!("{e:?}"); + } + }, + None => { + println!("csv is empty"); + } + } +} diff --git a/arrow-csv/src/lib.rs b/arrow-csv/src/lib.rs new file mode 100644 index 000000000000..e6dc69935199 --- /dev/null +++ b/arrow-csv/src/lib.rs @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Transfer data between the Arrow memory format and CSV (comma-separated values). + +pub mod reader; +pub mod writer; + +pub use self::reader::infer_schema_from_files; +pub use self::reader::Reader; +pub use self::reader::ReaderBuilder; +pub use self::writer::Writer; +pub use self::writer::WriterBuilder; +use arrow_schema::ArrowError; + +fn map_csv_error(error: csv::Error) -> ArrowError { + match error.kind() { + csv::ErrorKind::Io(error) => ArrowError::CsvError(error.to_string()), + csv::ErrorKind::Utf8 { pos: _, err } => ArrowError::CsvError(format!( + "Encountered UTF-8 error while reading CSV file: {err}" + )), + csv::ErrorKind::UnequalLengths { + expected_len, len, .. + } => ArrowError::CsvError(format!( + "Encountered unequal lengths between records on CSV file. Expected {len} \ + records, found {expected_len} records" + )), + _ => ArrowError::CsvError("Error reading CSV file".to_string()), + } +} diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs new file mode 100644 index 000000000000..83c8965fdf8a --- /dev/null +++ b/arrow-csv/src/reader/mod.rs @@ -0,0 +1,2359 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! CSV Reader +//! +//! # Basic Usage +//! +//! This CSV reader allows CSV files to be read into the Arrow memory model. Records are +//! loaded in batches and are then converted from row-based data to columnar data. +//! +//! Example: +//! +//! ``` +//! # use arrow_schema::*; +//! # use arrow_csv::{Reader, ReaderBuilder}; +//! # use std::fs::File; +//! # use std::sync::Arc; +//! +//! let schema = Schema::new(vec![ +//! Field::new("city", DataType::Utf8, false), +//! Field::new("lat", DataType::Float64, false), +//! Field::new("lng", DataType::Float64, false), +//! ]); +//! +//! let file = File::open("test/data/uk_cities.csv").unwrap(); +//! +//! let mut csv = ReaderBuilder::new(Arc::new(schema)).build(file).unwrap(); +//! let batch = csv.next().unwrap().unwrap(); +//! ``` +//! +//! # Async Usage +//! +//! The lower-level [`Decoder`] can be integrated with various forms of async data streams, +//! and is designed to be agnostic to the various different kinds of async IO primitives found +//! within the Rust ecosystem. +//! +//! For example, see below for how it can be used with an arbitrary `Stream` of `Bytes` +//! +//! ``` +//! # use std::task::{Poll, ready}; +//! # use bytes::{Buf, Bytes}; +//! # use arrow_schema::ArrowError; +//! # use futures::stream::{Stream, StreamExt}; +//! # use arrow_array::RecordBatch; +//! # use arrow_csv::reader::Decoder; +//! # +//! fn decode_stream + Unpin>( +//! mut decoder: Decoder, +//! mut input: S, +//! ) -> impl Stream> { +//! let mut buffered = Bytes::new(); +//! futures::stream::poll_fn(move |cx| { +//! loop { +//! if buffered.is_empty() { +//! if let Some(b) = ready!(input.poll_next_unpin(cx)) { +//! buffered = b; +//! } +//! // Note: don't break on `None` as the decoder needs +//! // to be called with an empty array to delimit the +//! // final record +//! } +//! let decoded = match decoder.decode(buffered.as_ref()) { +//! Ok(0) => break, +//! Ok(decoded) => decoded, +//! Err(e) => return Poll::Ready(Some(Err(e))), +//! }; +//! buffered.advance(decoded); +//! } +//! +//! Poll::Ready(decoder.flush().transpose()) +//! }) +//! } +//! +//! ``` +//! +//! In a similar vein, it can also be used with tokio-based IO primitives +//! +//! ``` +//! # use std::pin::Pin; +//! # use std::task::{Poll, ready}; +//! # use futures::Stream; +//! # use tokio::io::AsyncBufRead; +//! # use arrow_array::RecordBatch; +//! # use arrow_csv::reader::Decoder; +//! # use arrow_schema::ArrowError; +//! fn decode_stream( +//! mut decoder: Decoder, +//! mut reader: R, +//! ) -> impl Stream> { +//! futures::stream::poll_fn(move |cx| { +//! loop { +//! let b = match ready!(Pin::new(&mut reader).poll_fill_buf(cx)) { +//! Ok(b) => b, +//! Err(e) => return Poll::Ready(Some(Err(e.into()))), +//! }; +//! let decoded = match decoder.decode(b) { +//! // Note: the decoder needs to be called with an empty +//! // array to delimit the final record +//! Ok(0) => break, +//! Ok(decoded) => decoded, +//! Err(e) => return Poll::Ready(Some(Err(e))), +//! }; +//! Pin::new(&mut reader).consume(decoded); +//! } +//! +//! Poll::Ready(decoder.flush().transpose()) +//! }) +//! } +//! ``` +//! + +mod records; + +use arrow_array::builder::PrimitiveBuilder; +use arrow_array::types::*; +use arrow_array::*; +use arrow_cast::parse::{parse_decimal, string_to_datetime, Parser}; +use arrow_schema::*; +use chrono::{TimeZone, Utc}; +use csv::StringRecord; +use lazy_static::lazy_static; +use regex::{Regex, RegexSet}; +use std::fmt::{self, Debug}; +use std::fs::File; +use std::io::{BufRead, BufReader as StdBufReader, Read, Seek, SeekFrom}; +use std::sync::Arc; + +use crate::map_csv_error; +use crate::reader::records::{RecordDecoder, StringRecords}; +use arrow_array::timezone::Tz; + +lazy_static! { + /// Order should match [`InferredDataType`] + static ref REGEX_SET: RegexSet = RegexSet::new([ + r"(?i)^(true)$|^(false)$(?-i)", //BOOLEAN + r"^-?(\d+)$", //INTEGER + r"^-?((\d*\.\d+|\d+\.\d*)([eE]-?\d+)?|\d+([eE]-?\d+))$", //DECIMAL + r"^\d{4}-\d\d-\d\d$", //DATE32 + r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d(?:[^\d\.].*)?$", //Timestamp(Second) + r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,3}(?:[^\d].*)?$", //Timestamp(Millisecond) + r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,6}(?:[^\d].*)?$", //Timestamp(Microsecond) + r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,9}(?:[^\d].*)?$", //Timestamp(Nanosecond) + ]).unwrap(); +} + +/// A wrapper over `Option` to check if the value is `NULL`. +#[derive(Debug, Clone, Default)] +struct NullRegex(Option); + +impl NullRegex { + /// Returns true if the value should be considered as `NULL` according to + /// the provided regular expression. + #[inline] + fn is_null(&self, s: &str) -> bool { + match &self.0 { + Some(r) => r.is_match(s), + None => s.is_empty(), + } + } +} + +#[derive(Default, Copy, Clone)] +struct InferredDataType { + /// Packed booleans indicating type + /// + /// 0 - Boolean + /// 1 - Integer + /// 2 - Float64 + /// 3 - Date32 + /// 4 - Timestamp(Second) + /// 5 - Timestamp(Millisecond) + /// 6 - Timestamp(Microsecond) + /// 7 - Timestamp(Nanosecond) + /// 8 - Utf8 + packed: u16, +} + +impl InferredDataType { + /// Returns the inferred data type + fn get(&self) -> DataType { + match self.packed { + 0 => DataType::Null, + 1 => DataType::Boolean, + 2 => DataType::Int64, + 4 | 6 => DataType::Float64, // Promote Int64 to Float64 + b if b != 0 && (b & !0b11111000) == 0 => match b.leading_zeros() { + // Promote to highest precision temporal type + 8 => DataType::Timestamp(TimeUnit::Nanosecond, None), + 9 => DataType::Timestamp(TimeUnit::Microsecond, None), + 10 => DataType::Timestamp(TimeUnit::Millisecond, None), + 11 => DataType::Timestamp(TimeUnit::Second, None), + 12 => DataType::Date32, + _ => unreachable!(), + }, + _ => DataType::Utf8, + } + } + + /// Updates the [`InferredDataType`] with the given string + fn update(&mut self, string: &str) { + self.packed |= if string.starts_with('"') { + 1 << 8 // Utf8 + } else if let Some(m) = REGEX_SET.matches(string).into_iter().next() { + 1 << m + } else { + 1 << 8 // Utf8 + } + } +} + +/// The format specification for the CSV file +#[derive(Debug, Clone, Default)] +pub struct Format { + header: bool, + delimiter: Option, + escape: Option, + quote: Option, + terminator: Option, + null_regex: NullRegex, +} + +impl Format { + pub fn with_header(mut self, has_header: bool) -> Self { + self.header = has_header; + self + } + + pub fn with_delimiter(mut self, delimiter: u8) -> Self { + self.delimiter = Some(delimiter); + self + } + + pub fn with_escape(mut self, escape: u8) -> Self { + self.escape = Some(escape); + self + } + + pub fn with_quote(mut self, quote: u8) -> Self { + self.quote = Some(quote); + self + } + + pub fn with_terminator(mut self, terminator: u8) -> Self { + self.terminator = Some(terminator); + self + } + + /// Provide a regex to match null values, defaults to `^$` + pub fn with_null_regex(mut self, null_regex: Regex) -> Self { + self.null_regex = NullRegex(Some(null_regex)); + self + } + + /// Infer schema of CSV records from the provided `reader` + /// + /// If `max_records` is `None`, all records will be read, otherwise up to `max_records` + /// records are read to infer the schema + /// + /// Returns inferred schema and number of records read + pub fn infer_schema( + &self, + reader: R, + max_records: Option, + ) -> Result<(Schema, usize), ArrowError> { + let mut csv_reader = self.build_reader(reader); + + // get or create header names + // when has_header is false, creates default column names with column_ prefix + let headers: Vec = if self.header { + let headers = &csv_reader.headers().map_err(map_csv_error)?.clone(); + headers.iter().map(|s| s.to_string()).collect() + } else { + let first_record_count = &csv_reader.headers().map_err(map_csv_error)?.len(); + (0..*first_record_count) + .map(|i| format!("column_{}", i + 1)) + .collect() + }; + + let header_length = headers.len(); + // keep track of inferred field types + let mut column_types: Vec = vec![Default::default(); header_length]; + + let mut records_count = 0; + + let mut record = StringRecord::new(); + let max_records = max_records.unwrap_or(usize::MAX); + while records_count < max_records { + if !csv_reader.read_record(&mut record).map_err(map_csv_error)? { + break; + } + records_count += 1; + + // Note since we may be looking at a sample of the data, we make the safe assumption that + // they could be nullable + for (i, column_type) in column_types.iter_mut().enumerate().take(header_length) { + if let Some(string) = record.get(i) { + if !self.null_regex.is_null(string) { + column_type.update(string) + } + } + } + } + + // build schema from inference results + let fields: Fields = column_types + .iter() + .zip(&headers) + .map(|(inferred, field_name)| Field::new(field_name, inferred.get(), true)) + .collect(); + + Ok((Schema::new(fields), records_count)) + } + + /// Build a [`csv::Reader`] for this [`Format`] + fn build_reader(&self, reader: R) -> csv::Reader { + let mut builder = csv::ReaderBuilder::new(); + builder.has_headers(self.header); + + if let Some(c) = self.delimiter { + builder.delimiter(c); + } + builder.escape(self.escape); + if let Some(c) = self.quote { + builder.quote(c); + } + if let Some(t) = self.terminator { + builder.terminator(csv::Terminator::Any(t)); + } + builder.from_reader(reader) + } + + /// Build a [`csv_core::Reader`] for this [`Format`] + fn build_parser(&self) -> csv_core::Reader { + let mut builder = csv_core::ReaderBuilder::new(); + builder.escape(self.escape); + + if let Some(c) = self.delimiter { + builder.delimiter(c); + } + if let Some(c) = self.quote { + builder.quote(c); + } + if let Some(t) = self.terminator { + builder.terminator(csv_core::Terminator::Any(t)); + } + builder.build() + } +} + +/// Infer the schema of a CSV file by reading through the first n records of the file, +/// with `max_read_records` controlling the maximum number of records to read. +/// +/// If `max_read_records` is not set, the whole file is read to infer its schema. +/// +/// Return inferred schema and number of records used for inference. This function does not change +/// reader cursor offset. +/// +/// The inferred schema will always have each field set as nullable. +#[deprecated(note = "Use Format::infer_schema")] +#[allow(deprecated)] +pub fn infer_file_schema( + mut reader: R, + delimiter: u8, + max_read_records: Option, + has_header: bool, +) -> Result<(Schema, usize), ArrowError> { + let saved_offset = reader.stream_position()?; + let r = infer_reader_schema(&mut reader, delimiter, max_read_records, has_header)?; + // return the reader seek back to the start + reader.seek(SeekFrom::Start(saved_offset))?; + Ok(r) +} + +/// Infer schema of CSV records provided by struct that implements `Read` trait. +/// +/// `max_read_records` controlling the maximum number of records to read. If `max_read_records` is +/// not set, all records are read to infer the schema. +/// +/// Return inferred schema and number of records used for inference. +#[deprecated(note = "Use Format::infer_schema")] +pub fn infer_reader_schema( + reader: R, + delimiter: u8, + max_read_records: Option, + has_header: bool, +) -> Result<(Schema, usize), ArrowError> { + let format = Format { + delimiter: Some(delimiter), + header: has_header, + ..Default::default() + }; + format.infer_schema(reader, max_read_records) +} + +/// Infer schema from a list of CSV files by reading through first n records +/// with `max_read_records` controlling the maximum number of records to read. +/// +/// Files will be read in the given order until n records have been reached. +/// +/// If `max_read_records` is not set, all files will be read fully to infer the schema. +pub fn infer_schema_from_files( + files: &[String], + delimiter: u8, + max_read_records: Option, + has_header: bool, +) -> Result { + let mut schemas = vec![]; + let mut records_to_read = max_read_records.unwrap_or(usize::MAX); + let format = Format { + delimiter: Some(delimiter), + header: has_header, + ..Default::default() + }; + + for fname in files.iter() { + let f = File::open(fname)?; + let (schema, records_read) = format.infer_schema(f, Some(records_to_read))?; + if records_read == 0 { + continue; + } + schemas.push(schema.clone()); + records_to_read -= records_read; + if records_to_read == 0 { + break; + } + } + + Schema::try_merge(schemas) +} + +// optional bounds of the reader, of the form (min line, max line). +type Bounds = Option<(usize, usize)>; + +/// CSV file reader using [`std::io::BufReader`] +pub type Reader = BufReader>; + +/// CSV file reader +pub struct BufReader { + /// File reader + reader: R, + + /// The decoder + decoder: Decoder, +} + +impl fmt::Debug for BufReader +where + R: BufRead, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Reader") + .field("decoder", &self.decoder) + .finish() + } +} + +impl Reader { + /// Returns the schema of the reader, useful for getting the schema without reading + /// record batches + pub fn schema(&self) -> SchemaRef { + match &self.decoder.projection { + Some(projection) => { + let fields = self.decoder.schema.fields(); + let projected = projection.iter().map(|i| fields[*i].clone()); + Arc::new(Schema::new(projected.collect::())) + } + None => self.decoder.schema.clone(), + } + } +} + +impl BufReader { + fn read(&mut self) -> Result, ArrowError> { + loop { + let buf = self.reader.fill_buf()?; + let decoded = self.decoder.decode(buf)?; + self.reader.consume(decoded); + // Yield if decoded no bytes or the decoder is full + // + // The capacity check avoids looping around and potentially + // blocking reading data in fill_buf that isn't needed + // to flush the next batch + if decoded == 0 || self.decoder.capacity() == 0 { + break; + } + } + + self.decoder.flush() + } +} + +impl Iterator for BufReader { + type Item = Result; + + fn next(&mut self) -> Option { + self.read().transpose() + } +} + +impl RecordBatchReader for BufReader { + fn schema(&self) -> SchemaRef { + self.decoder.schema.clone() + } +} + +/// A push-based interface for decoding CSV data from an arbitrary byte stream +/// +/// See [`Reader`] for a higher-level interface for interface with [`Read`] +/// +/// The push-based interface facilitates integration with sources that yield arbitrarily +/// delimited bytes ranges, such as [`BufRead`], or a chunked byte stream received from +/// object storage +/// +/// ``` +/// # use std::io::BufRead; +/// # use arrow_array::RecordBatch; +/// # use arrow_csv::ReaderBuilder; +/// # use arrow_schema::{ArrowError, SchemaRef}; +/// # +/// fn read_from_csv( +/// mut reader: R, +/// schema: SchemaRef, +/// batch_size: usize, +/// ) -> Result>, ArrowError> { +/// let mut decoder = ReaderBuilder::new(schema) +/// .with_batch_size(batch_size) +/// .build_decoder(); +/// +/// let mut next = move || { +/// loop { +/// let buf = reader.fill_buf()?; +/// let decoded = decoder.decode(buf)?; +/// if decoded == 0 { +/// break; +/// } +/// +/// // Consume the number of bytes read +/// reader.consume(decoded); +/// } +/// decoder.flush() +/// }; +/// Ok(std::iter::from_fn(move || next().transpose())) +/// } +/// ``` +#[derive(Debug)] +pub struct Decoder { + /// Explicit schema for the CSV file + schema: SchemaRef, + + /// Optional projection for which columns to load (zero-based column indices) + projection: Option>, + + /// Number of records per batch + batch_size: usize, + + /// Rows to skip + to_skip: usize, + + /// Current line number + line_number: usize, + + /// End line number + end: usize, + + /// A decoder for [`StringRecords`] + record_decoder: RecordDecoder, + + /// Check if the string matches this pattern for `NULL`. + null_regex: NullRegex, +} + +impl Decoder { + /// Decode records from `buf` returning the number of bytes read + /// + /// This method returns once `batch_size` objects have been parsed since the + /// last call to [`Self::flush`], or `buf` is exhausted. Any remaining bytes + /// should be included in the next call to [`Self::decode`] + /// + /// There is no requirement that `buf` contains a whole number of records, facilitating + /// integration with arbitrary byte streams, such as that yielded by [`BufRead`] or + /// network sources such as object storage + pub fn decode(&mut self, buf: &[u8]) -> Result { + if self.to_skip != 0 { + // Skip in units of `to_read` to avoid over-allocating buffers + let to_skip = self.to_skip.min(self.batch_size); + let (skipped, bytes) = self.record_decoder.decode(buf, to_skip)?; + self.to_skip -= skipped; + self.record_decoder.clear(); + return Ok(bytes); + } + + let to_read = self.batch_size.min(self.end - self.line_number) - self.record_decoder.len(); + let (_, bytes) = self.record_decoder.decode(buf, to_read)?; + Ok(bytes) + } + + /// Flushes the currently buffered data to a [`RecordBatch`] + /// + /// This should only be called after [`Self::decode`] has returned `Ok(0)`, + /// otherwise may return an error if part way through decoding a record + /// + /// Returns `Ok(None)` if no buffered data + pub fn flush(&mut self) -> Result, ArrowError> { + if self.record_decoder.is_empty() { + return Ok(None); + } + + let rows = self.record_decoder.flush()?; + let batch = parse( + &rows, + self.schema.fields(), + Some(self.schema.metadata.clone()), + self.projection.as_ref(), + self.line_number, + &self.null_regex, + )?; + self.line_number += rows.len(); + Ok(Some(batch)) + } + + /// Returns the number of records that can be read before requiring a call to [`Self::flush`] + pub fn capacity(&self) -> usize { + self.batch_size - self.record_decoder.len() + } +} + +/// Parses a slice of [`StringRecords`] into a [RecordBatch] +fn parse( + rows: &StringRecords<'_>, + fields: &Fields, + metadata: Option>, + projection: Option<&Vec>, + line_number: usize, + null_regex: &NullRegex, +) -> Result { + let projection: Vec = match projection { + Some(v) => v.clone(), + None => fields.iter().enumerate().map(|(i, _)| i).collect(), + }; + + let arrays: Result, _> = projection + .iter() + .map(|i| { + let i = *i; + let field = &fields[i]; + match field.data_type() { + DataType::Boolean => build_boolean_array(line_number, rows, i, null_regex), + DataType::Decimal128(precision, scale) => build_decimal_array::( + line_number, + rows, + i, + *precision, + *scale, + null_regex, + ), + DataType::Decimal256(precision, scale) => build_decimal_array::( + line_number, + rows, + i, + *precision, + *scale, + null_regex, + ), + DataType::Int8 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Int16 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Int32 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Int64 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::UInt8 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::UInt16 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::UInt32 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::UInt64 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Float32 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Float64 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Date32 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Date64 => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Time32(TimeUnit::Second) => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Time32(TimeUnit::Millisecond) => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Time64(TimeUnit::Microsecond) => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Time64(TimeUnit::Nanosecond) => { + build_primitive_array::(line_number, rows, i, null_regex) + } + DataType::Timestamp(TimeUnit::Second, tz) => { + build_timestamp_array::( + line_number, + rows, + i, + tz.as_deref(), + null_regex, + ) + } + DataType::Timestamp(TimeUnit::Millisecond, tz) => { + build_timestamp_array::( + line_number, + rows, + i, + tz.as_deref(), + null_regex, + ) + } + DataType::Timestamp(TimeUnit::Microsecond, tz) => { + build_timestamp_array::( + line_number, + rows, + i, + tz.as_deref(), + null_regex, + ) + } + DataType::Timestamp(TimeUnit::Nanosecond, tz) => { + build_timestamp_array::( + line_number, + rows, + i, + tz.as_deref(), + null_regex, + ) + } + DataType::Null => Ok(Arc::new(NullArray::builder(rows.len()).finish()) as ArrayRef), + DataType::Utf8 => Ok(Arc::new( + rows.iter() + .map(|row| { + let s = row.get(i); + (!null_regex.is_null(s)).then_some(s) + }) + .collect::(), + ) as ArrayRef), + DataType::Dictionary(key_type, value_type) + if value_type.as_ref() == &DataType::Utf8 => + { + match key_type.as_ref() { + DataType::Int8 => Ok(Arc::new( + rows.iter() + .map(|row| row.get(i)) + .collect::>(), + ) as ArrayRef), + DataType::Int16 => Ok(Arc::new( + rows.iter() + .map(|row| row.get(i)) + .collect::>(), + ) as ArrayRef), + DataType::Int32 => Ok(Arc::new( + rows.iter() + .map(|row| row.get(i)) + .collect::>(), + ) as ArrayRef), + DataType::Int64 => Ok(Arc::new( + rows.iter() + .map(|row| row.get(i)) + .collect::>(), + ) as ArrayRef), + DataType::UInt8 => Ok(Arc::new( + rows.iter() + .map(|row| row.get(i)) + .collect::>(), + ) as ArrayRef), + DataType::UInt16 => Ok(Arc::new( + rows.iter() + .map(|row| row.get(i)) + .collect::>(), + ) as ArrayRef), + DataType::UInt32 => Ok(Arc::new( + rows.iter() + .map(|row| row.get(i)) + .collect::>(), + ) as ArrayRef), + DataType::UInt64 => Ok(Arc::new( + rows.iter() + .map(|row| row.get(i)) + .collect::>(), + ) as ArrayRef), + _ => Err(ArrowError::ParseError(format!( + "Unsupported dictionary key type {key_type:?}" + ))), + } + } + other => Err(ArrowError::ParseError(format!( + "Unsupported data type {other:?}" + ))), + } + }) + .collect(); + + let projected_fields: Fields = projection.iter().map(|i| fields[*i].clone()).collect(); + + let projected_schema = Arc::new(match metadata { + None => Schema::new(projected_fields), + Some(metadata) => Schema::new_with_metadata(projected_fields, metadata), + }); + + arrays.and_then(|arr| { + RecordBatch::try_new_with_options( + projected_schema, + arr, + &RecordBatchOptions::new() + .with_match_field_names(true) + .with_row_count(Some(rows.len())), + ) + }) +} + +fn parse_bool(string: &str) -> Option { + if string.eq_ignore_ascii_case("false") { + Some(false) + } else if string.eq_ignore_ascii_case("true") { + Some(true) + } else { + None + } +} + +// parse the column string to an Arrow Array +fn build_decimal_array( + _line_number: usize, + rows: &StringRecords<'_>, + col_idx: usize, + precision: u8, + scale: i8, + null_regex: &NullRegex, +) -> Result { + let mut decimal_builder = PrimitiveBuilder::::with_capacity(rows.len()); + for row in rows.iter() { + let s = row.get(col_idx); + if null_regex.is_null(s) { + // append null + decimal_builder.append_null(); + } else { + let decimal_value: Result = parse_decimal::(s, precision, scale); + match decimal_value { + Ok(v) => { + decimal_builder.append_value(v); + } + Err(e) => { + return Err(e); + } + } + } + } + Ok(Arc::new( + decimal_builder + .finish() + .with_precision_and_scale(precision, scale)?, + )) +} + +// parses a specific column (col_idx) into an Arrow Array. +fn build_primitive_array( + line_number: usize, + rows: &StringRecords<'_>, + col_idx: usize, + null_regex: &NullRegex, +) -> Result { + rows.iter() + .enumerate() + .map(|(row_index, row)| { + let s = row.get(col_idx); + if null_regex.is_null(s) { + return Ok(None); + } + + match T::parse(s) { + Some(e) => Ok(Some(e)), + None => Err(ArrowError::ParseError(format!( + // TODO: we should surface the underlying error here. + "Error while parsing value {} for column {} at line {}", + s, + col_idx, + line_number + row_index + ))), + } + }) + .collect::, ArrowError>>() + .map(|e| Arc::new(e) as ArrayRef) +} + +fn build_timestamp_array( + line_number: usize, + rows: &StringRecords<'_>, + col_idx: usize, + timezone: Option<&str>, + null_regex: &NullRegex, +) -> Result { + Ok(Arc::new(match timezone { + Some(timezone) => { + let tz: Tz = timezone.parse()?; + build_timestamp_array_impl::(line_number, rows, col_idx, &tz, null_regex)? + .with_timezone(timezone) + } + None => build_timestamp_array_impl::(line_number, rows, col_idx, &Utc, null_regex)?, + })) +} + +fn build_timestamp_array_impl( + line_number: usize, + rows: &StringRecords<'_>, + col_idx: usize, + timezone: &Tz, + null_regex: &NullRegex, +) -> Result, ArrowError> { + rows.iter() + .enumerate() + .map(|(row_index, row)| { + let s = row.get(col_idx); + if null_regex.is_null(s) { + return Ok(None); + } + + let date = string_to_datetime(timezone, s) + .and_then(|date| match T::UNIT { + TimeUnit::Second => Ok(date.timestamp()), + TimeUnit::Millisecond => Ok(date.timestamp_millis()), + TimeUnit::Microsecond => Ok(date.timestamp_micros()), + TimeUnit::Nanosecond => date.timestamp_nanos_opt().ok_or_else(|| { + ArrowError::ParseError(format!( + "{} would overflow 64-bit signed nanoseconds", + date.to_rfc3339(), + )) + }), + }) + .map_err(|e| { + ArrowError::ParseError(format!( + "Error parsing column {col_idx} at line {}: {}", + line_number + row_index, + e + )) + })?; + Ok(Some(date)) + }) + .collect() +} + +// parses a specific column (col_idx) into an Arrow Array. +fn build_boolean_array( + line_number: usize, + rows: &StringRecords<'_>, + col_idx: usize, + null_regex: &NullRegex, +) -> Result { + rows.iter() + .enumerate() + .map(|(row_index, row)| { + let s = row.get(col_idx); + if null_regex.is_null(s) { + return Ok(None); + } + let parsed = parse_bool(s); + match parsed { + Some(e) => Ok(Some(e)), + None => Err(ArrowError::ParseError(format!( + // TODO: we should surface the underlying error here. + "Error while parsing value {} for column {} at line {}", + s, + col_idx, + line_number + row_index + ))), + } + }) + .collect::>() + .map(|e| Arc::new(e) as ArrayRef) +} + +/// CSV file reader builder +#[derive(Debug)] +pub struct ReaderBuilder { + /// Schema of the CSV file + schema: SchemaRef, + /// Format of the CSV file + format: Format, + /// Batch size (number of records to load each time) + /// + /// The default batch size when using the `ReaderBuilder` is 1024 records + batch_size: usize, + /// The bounds over which to scan the reader. `None` starts from 0 and runs until EOF. + bounds: Bounds, + /// Optional projection for which columns to load (zero-based column indices) + projection: Option>, +} + +impl ReaderBuilder { + /// Create a new builder for configuring CSV parsing options. + /// + /// To convert a builder into a reader, call `ReaderBuilder::build` + /// + /// # Example + /// + /// ``` + /// # use arrow_csv::{Reader, ReaderBuilder}; + /// # use std::fs::File; + /// # use std::io::Seek; + /// # use std::sync::Arc; + /// # use arrow_csv::reader::Format; + /// # + /// let mut file = File::open("test/data/uk_cities_with_headers.csv").unwrap(); + /// // Infer the schema with the first 100 records + /// let (schema, _) = Format::default().infer_schema(&mut file, Some(100)).unwrap(); + /// file.rewind().unwrap(); + /// + /// // create a builder + /// ReaderBuilder::new(Arc::new(schema)).build(file).unwrap(); + /// ``` + pub fn new(schema: SchemaRef) -> ReaderBuilder { + Self { + schema, + format: Format::default(), + batch_size: 1024, + bounds: None, + projection: None, + } + } + + /// Set whether the CSV file has headers + #[deprecated(note = "Use with_header")] + #[doc(hidden)] + pub fn has_header(mut self, has_header: bool) -> Self { + self.format.header = has_header; + self + } + + /// Set whether the CSV file has a header + pub fn with_header(mut self, has_header: bool) -> Self { + self.format.header = has_header; + self + } + + /// Overrides the [`Format`] of this [`ReaderBuilder] + pub fn with_format(mut self, format: Format) -> Self { + self.format = format; + self + } + + /// Set the CSV file's column delimiter as a byte character + pub fn with_delimiter(mut self, delimiter: u8) -> Self { + self.format.delimiter = Some(delimiter); + self + } + + pub fn with_escape(mut self, escape: u8) -> Self { + self.format.escape = Some(escape); + self + } + + pub fn with_quote(mut self, quote: u8) -> Self { + self.format.quote = Some(quote); + self + } + + pub fn with_terminator(mut self, terminator: u8) -> Self { + self.format.terminator = Some(terminator); + self + } + + /// Provide a regex to match null values, defaults to `^$` + pub fn with_null_regex(mut self, null_regex: Regex) -> Self { + self.format.null_regex = NullRegex(Some(null_regex)); + self + } + + /// Set the batch size (number of records to load at one time) + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + /// Set the bounds over which to scan the reader. + /// `start` and `end` are line numbers. + pub fn with_bounds(mut self, start: usize, end: usize) -> Self { + self.bounds = Some((start, end)); + self + } + + /// Set the reader's column projection + pub fn with_projection(mut self, projection: Vec) -> Self { + self.projection = Some(projection); + self + } + + /// Create a new `Reader` from a non-buffered reader + /// + /// If `R: BufRead` consider using [`Self::build_buffered`] to avoid unnecessary additional + /// buffering, as internally this method wraps `reader` in [`std::io::BufReader`] + pub fn build(self, reader: R) -> Result, ArrowError> { + self.build_buffered(StdBufReader::new(reader)) + } + + /// Create a new `BufReader` from a buffered reader + pub fn build_buffered(self, reader: R) -> Result, ArrowError> { + Ok(BufReader { + reader, + decoder: self.build_decoder(), + }) + } + + /// Builds a decoder that can be used to decode CSV from an arbitrary byte stream + pub fn build_decoder(self) -> Decoder { + let delimiter = self.format.build_parser(); + let record_decoder = RecordDecoder::new(delimiter, self.schema.fields().len()); + + let header = self.format.header as usize; + + let (start, end) = match self.bounds { + Some((start, end)) => (start + header, end + header), + None => (header, usize::MAX), + }; + + Decoder { + schema: self.schema, + to_skip: start, + record_decoder, + line_number: start, + end, + projection: self.projection, + batch_size: self.batch_size, + null_regex: self.format.null_regex, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::io::{Cursor, Write}; + use tempfile::NamedTempFile; + + use arrow_array::cast::AsArray; + + #[test] + fn test_csv() { + let schema = Arc::new(Schema::new(vec![ + Field::new("city", DataType::Utf8, false), + Field::new("lat", DataType::Float64, false), + Field::new("lng", DataType::Float64, false), + ])); + + let file = File::open("test/data/uk_cities.csv").unwrap(); + let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap(); + assert_eq!(schema, csv.schema()); + let batch = csv.next().unwrap().unwrap(); + assert_eq!(37, batch.num_rows()); + assert_eq!(3, batch.num_columns()); + + // access data from a primitive array + let lat = batch.column(1).as_primitive::(); + assert_eq!(57.653484, lat.value(0)); + + // access data from a string array (ListArray) + let city = batch.column(0).as_string::(); + + assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13)); + } + + #[test] + fn test_csv_schema_metadata() { + let mut metadata = std::collections::HashMap::new(); + metadata.insert("foo".to_owned(), "bar".to_owned()); + let schema = Arc::new(Schema::new_with_metadata( + vec![ + Field::new("city", DataType::Utf8, false), + Field::new("lat", DataType::Float64, false), + Field::new("lng", DataType::Float64, false), + ], + metadata.clone(), + )); + + let file = File::open("test/data/uk_cities.csv").unwrap(); + + let mut csv = ReaderBuilder::new(schema.clone()).build(file).unwrap(); + assert_eq!(schema, csv.schema()); + let batch = csv.next().unwrap().unwrap(); + assert_eq!(37, batch.num_rows()); + assert_eq!(3, batch.num_columns()); + + assert_eq!(&metadata, batch.schema().metadata()); + } + + #[test] + fn test_csv_reader_with_decimal() { + let schema = Arc::new(Schema::new(vec![ + Field::new("city", DataType::Utf8, false), + Field::new("lat", DataType::Decimal128(38, 6), false), + Field::new("lng", DataType::Decimal256(76, 6), false), + ])); + + let file = File::open("test/data/decimal_test.csv").unwrap(); + + let mut csv = ReaderBuilder::new(schema).build(file).unwrap(); + let batch = csv.next().unwrap().unwrap(); + // access data from a primitive array + let lat = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!("57.653484", lat.value_as_string(0)); + assert_eq!("53.002666", lat.value_as_string(1)); + assert_eq!("52.412811", lat.value_as_string(2)); + assert_eq!("51.481583", lat.value_as_string(3)); + assert_eq!("12.123456", lat.value_as_string(4)); + assert_eq!("50.760000", lat.value_as_string(5)); + assert_eq!("0.123000", lat.value_as_string(6)); + assert_eq!("123.000000", lat.value_as_string(7)); + assert_eq!("123.000000", lat.value_as_string(8)); + assert_eq!("-50.760000", lat.value_as_string(9)); + + let lng = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!("-3.335724", lng.value_as_string(0)); + assert_eq!("-2.179404", lng.value_as_string(1)); + assert_eq!("-1.778197", lng.value_as_string(2)); + assert_eq!("-3.179090", lng.value_as_string(3)); + assert_eq!("-3.179090", lng.value_as_string(4)); + assert_eq!("0.290472", lng.value_as_string(5)); + assert_eq!("0.290472", lng.value_as_string(6)); + assert_eq!("0.290472", lng.value_as_string(7)); + assert_eq!("0.290472", lng.value_as_string(8)); + assert_eq!("0.290472", lng.value_as_string(9)); + } + + #[test] + fn test_csv_from_buf_reader() { + let schema = Schema::new(vec![ + Field::new("city", DataType::Utf8, false), + Field::new("lat", DataType::Float64, false), + Field::new("lng", DataType::Float64, false), + ]); + + let file_with_headers = File::open("test/data/uk_cities_with_headers.csv").unwrap(); + let file_without_headers = File::open("test/data/uk_cities.csv").unwrap(); + let both_files = file_with_headers + .chain(Cursor::new("\n".to_string())) + .chain(file_without_headers); + let mut csv = ReaderBuilder::new(Arc::new(schema)) + .with_header(true) + .build(both_files) + .unwrap(); + let batch = csv.next().unwrap().unwrap(); + assert_eq!(74, batch.num_rows()); + assert_eq!(3, batch.num_columns()); + } + + #[test] + fn test_csv_with_schema_inference() { + let mut file = File::open("test/data/uk_cities_with_headers.csv").unwrap(); + + let (schema, _) = Format::default() + .with_header(true) + .infer_schema(&mut file, None) + .unwrap(); + + file.rewind().unwrap(); + let builder = ReaderBuilder::new(Arc::new(schema)).with_header(true); + + let mut csv = builder.build(file).unwrap(); + let expected_schema = Schema::new(vec![ + Field::new("city", DataType::Utf8, true), + Field::new("lat", DataType::Float64, true), + Field::new("lng", DataType::Float64, true), + ]); + assert_eq!(Arc::new(expected_schema), csv.schema()); + let batch = csv.next().unwrap().unwrap(); + assert_eq!(37, batch.num_rows()); + assert_eq!(3, batch.num_columns()); + + // access data from a primitive array + let lat = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(57.653484, lat.value(0)); + + // access data from a string array (ListArray) + let city = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13)); + } + + #[test] + fn test_csv_with_schema_inference_no_headers() { + let mut file = File::open("test/data/uk_cities.csv").unwrap(); + + let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap(); + file.rewind().unwrap(); + + let mut csv = ReaderBuilder::new(Arc::new(schema)).build(file).unwrap(); + + // csv field names should be 'column_{number}' + let schema = csv.schema(); + assert_eq!("column_1", schema.field(0).name()); + assert_eq!("column_2", schema.field(1).name()); + assert_eq!("column_3", schema.field(2).name()); + let batch = csv.next().unwrap().unwrap(); + let batch_schema = batch.schema(); + + assert_eq!(schema, batch_schema); + assert_eq!(37, batch.num_rows()); + assert_eq!(3, batch.num_columns()); + + // access data from a primitive array + let lat = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(57.653484, lat.value(0)); + + // access data from a string array (ListArray) + let city = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13)); + } + + #[test] + fn test_csv_builder_with_bounds() { + let mut file = File::open("test/data/uk_cities.csv").unwrap(); + + // Set the bounds to the lines 0, 1 and 2. + let (schema, _) = Format::default().infer_schema(&mut file, None).unwrap(); + file.rewind().unwrap(); + let mut csv = ReaderBuilder::new(Arc::new(schema)) + .with_bounds(0, 2) + .build(file) + .unwrap(); + let batch = csv.next().unwrap().unwrap(); + + // access data from a string array (ListArray) + let city = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + // The value on line 0 is within the bounds + assert_eq!("Elgin, Scotland, the UK", city.value(0)); + + // The value on line 13 is outside of the bounds. Therefore + // the call to .value() will panic. + let result = std::panic::catch_unwind(|| city.value(13)); + assert!(result.is_err()); + } + + #[test] + fn test_csv_with_projection() { + let schema = Arc::new(Schema::new(vec![ + Field::new("city", DataType::Utf8, false), + Field::new("lat", DataType::Float64, false), + Field::new("lng", DataType::Float64, false), + ])); + + let file = File::open("test/data/uk_cities.csv").unwrap(); + + let mut csv = ReaderBuilder::new(schema) + .with_projection(vec![0, 1]) + .build(file) + .unwrap(); + + let projected_schema = Arc::new(Schema::new(vec![ + Field::new("city", DataType::Utf8, false), + Field::new("lat", DataType::Float64, false), + ])); + assert_eq!(projected_schema, csv.schema()); + let batch = csv.next().unwrap().unwrap(); + assert_eq!(projected_schema, batch.schema()); + assert_eq!(37, batch.num_rows()); + assert_eq!(2, batch.num_columns()); + } + + #[test] + fn test_csv_with_dictionary() { + let schema = Arc::new(Schema::new(vec![ + Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false), + Field::new("lat", DataType::Float64, false), + Field::new("lng", DataType::Float64, false), + ])); + + let file = File::open("test/data/uk_cities.csv").unwrap(); + + let mut csv = ReaderBuilder::new(schema) + .with_projection(vec![0, 1]) + .build(file) + .unwrap(); + + let projected_schema = Arc::new(Schema::new(vec![ + Field::new_dictionary("city", DataType::Int32, DataType::Utf8, false), + Field::new("lat", DataType::Float64, false), + ])); + assert_eq!(projected_schema, csv.schema()); + let batch = csv.next().unwrap().unwrap(); + assert_eq!(projected_schema, batch.schema()); + assert_eq!(37, batch.num_rows()); + assert_eq!(2, batch.num_columns()); + + let strings = arrow_cast::cast(batch.column(0), &DataType::Utf8).unwrap(); + let strings = strings.as_string::(); + + assert_eq!(strings.value(0), "Elgin, Scotland, the UK"); + assert_eq!(strings.value(4), "Eastbourne, East Sussex, UK"); + assert_eq!(strings.value(29), "Uckfield, East Sussex, UK"); + } + + #[test] + fn test_nulls() { + let schema = Arc::new(Schema::new(vec![ + Field::new("c_int", DataType::UInt64, false), + Field::new("c_float", DataType::Float32, true), + Field::new("c_string", DataType::Utf8, true), + Field::new("c_bool", DataType::Boolean, false), + ])); + + let file = File::open("test/data/null_test.csv").unwrap(); + + let mut csv = ReaderBuilder::new(schema) + .with_header(true) + .build(file) + .unwrap(); + + let batch = csv.next().unwrap().unwrap(); + + assert!(!batch.column(1).is_null(0)); + assert!(!batch.column(1).is_null(1)); + assert!(batch.column(1).is_null(2)); + assert!(!batch.column(1).is_null(3)); + assert!(!batch.column(1).is_null(4)); + } + + #[test] + fn test_init_nulls() { + let schema = Arc::new(Schema::new(vec![ + Field::new("c_int", DataType::UInt64, true), + Field::new("c_float", DataType::Float32, true), + Field::new("c_string", DataType::Utf8, true), + Field::new("c_bool", DataType::Boolean, true), + Field::new("c_null", DataType::Null, true), + ])); + let file = File::open("test/data/init_null_test.csv").unwrap(); + + let mut csv = ReaderBuilder::new(schema) + .with_header(true) + .build(file) + .unwrap(); + + let batch = csv.next().unwrap().unwrap(); + + assert!(batch.column(1).is_null(0)); + assert!(!batch.column(1).is_null(1)); + assert!(batch.column(1).is_null(2)); + assert!(!batch.column(1).is_null(3)); + assert!(!batch.column(1).is_null(4)); + } + + #[test] + fn test_init_nulls_with_inference() { + let format = Format::default().with_header(true).with_delimiter(b','); + + let mut file = File::open("test/data/init_null_test.csv").unwrap(); + let (schema, _) = format.infer_schema(&mut file, None).unwrap(); + file.rewind().unwrap(); + + let expected_schema = Schema::new(vec![ + Field::new("c_int", DataType::Int64, true), + Field::new("c_float", DataType::Float64, true), + Field::new("c_string", DataType::Utf8, true), + Field::new("c_bool", DataType::Boolean, true), + Field::new("c_null", DataType::Null, true), + ]); + assert_eq!(schema, expected_schema); + + let mut csv = ReaderBuilder::new(Arc::new(schema)) + .with_format(format) + .build(file) + .unwrap(); + + let batch = csv.next().unwrap().unwrap(); + + assert!(batch.column(1).is_null(0)); + assert!(!batch.column(1).is_null(1)); + assert!(batch.column(1).is_null(2)); + assert!(!batch.column(1).is_null(3)); + assert!(!batch.column(1).is_null(4)); + } + + #[test] + fn test_custom_nulls() { + let schema = Arc::new(Schema::new(vec![ + Field::new("c_int", DataType::UInt64, true), + Field::new("c_float", DataType::Float32, true), + Field::new("c_string", DataType::Utf8, true), + Field::new("c_bool", DataType::Boolean, true), + ])); + + let file = File::open("test/data/custom_null_test.csv").unwrap(); + + let null_regex = Regex::new("^nil$").unwrap(); + + let mut csv = ReaderBuilder::new(schema) + .with_header(true) + .with_null_regex(null_regex) + .build(file) + .unwrap(); + + let batch = csv.next().unwrap().unwrap(); + + // "nil"s should be NULL + assert!(batch.column(0).is_null(1)); + assert!(batch.column(1).is_null(2)); + assert!(batch.column(3).is_null(4)); + assert!(batch.column(2).is_null(3)); + assert!(!batch.column(2).is_null(4)); + } + + #[test] + fn test_nulls_with_inference() { + let mut file = File::open("test/data/various_types.csv").unwrap(); + let format = Format::default().with_header(true).with_delimiter(b'|'); + + let (schema, _) = format.infer_schema(&mut file, None).unwrap(); + file.rewind().unwrap(); + + let builder = ReaderBuilder::new(Arc::new(schema)) + .with_format(format) + .with_batch_size(512) + .with_projection(vec![0, 1, 2, 3, 4, 5]); + + let mut csv = builder.build(file).unwrap(); + let batch = csv.next().unwrap().unwrap(); + + assert_eq!(7, batch.num_rows()); + assert_eq!(6, batch.num_columns()); + + let schema = batch.schema(); + + assert_eq!(&DataType::Int64, schema.field(0).data_type()); + assert_eq!(&DataType::Float64, schema.field(1).data_type()); + assert_eq!(&DataType::Float64, schema.field(2).data_type()); + assert_eq!(&DataType::Boolean, schema.field(3).data_type()); + assert_eq!(&DataType::Date32, schema.field(4).data_type()); + assert_eq!( + &DataType::Timestamp(TimeUnit::Second, None), + schema.field(5).data_type() + ); + + let names: Vec<&str> = schema.fields().iter().map(|x| x.name().as_str()).collect(); + assert_eq!( + names, + vec![ + "c_int", + "c_float", + "c_string", + "c_bool", + "c_date", + "c_datetime" + ] + ); + + assert!(schema.field(0).is_nullable()); + assert!(schema.field(1).is_nullable()); + assert!(schema.field(2).is_nullable()); + assert!(schema.field(3).is_nullable()); + assert!(schema.field(4).is_nullable()); + assert!(schema.field(5).is_nullable()); + + assert!(!batch.column(1).is_null(0)); + assert!(!batch.column(1).is_null(1)); + assert!(batch.column(1).is_null(2)); + assert!(!batch.column(1).is_null(3)); + assert!(!batch.column(1).is_null(4)); + } + + #[test] + fn test_custom_nulls_with_inference() { + let mut file = File::open("test/data/custom_null_test.csv").unwrap(); + + let null_regex = Regex::new("^nil$").unwrap(); + + let format = Format::default() + .with_header(true) + .with_null_regex(null_regex); + + let (schema, _) = format.infer_schema(&mut file, None).unwrap(); + file.rewind().unwrap(); + + let expected_schema = Schema::new(vec![ + Field::new("c_int", DataType::Int64, true), + Field::new("c_float", DataType::Float64, true), + Field::new("c_string", DataType::Utf8, true), + Field::new("c_bool", DataType::Boolean, true), + ]); + + assert_eq!(schema, expected_schema); + + let builder = ReaderBuilder::new(Arc::new(schema)) + .with_format(format) + .with_batch_size(512) + .with_projection(vec![0, 1, 2, 3]); + + let mut csv = builder.build(file).unwrap(); + let batch = csv.next().unwrap().unwrap(); + + assert_eq!(5, batch.num_rows()); + assert_eq!(4, batch.num_columns()); + + assert_eq!(batch.schema().as_ref(), &expected_schema); + } + + #[test] + fn test_parse_invalid_csv() { + let file = File::open("test/data/various_types_invalid.csv").unwrap(); + + let schema = Schema::new(vec![ + Field::new("c_int", DataType::UInt64, false), + Field::new("c_float", DataType::Float32, false), + Field::new("c_string", DataType::Utf8, false), + Field::new("c_bool", DataType::Boolean, false), + ]); + + let builder = ReaderBuilder::new(Arc::new(schema)) + .with_header(true) + .with_delimiter(b'|') + .with_batch_size(512) + .with_projection(vec![0, 1, 2, 3]); + + let mut csv = builder.build(file).unwrap(); + match csv.next() { + Some(e) => match e { + Err(e) => assert_eq!( + "ParseError(\"Error while parsing value 4.x4 for column 1 at line 4\")", + format!("{e:?}") + ), + Ok(_) => panic!("should have failed"), + }, + None => panic!("should have failed"), + } + } + + /// Infer the data type of a record + fn infer_field_schema(string: &str) -> DataType { + let mut v = InferredDataType::default(); + v.update(string); + v.get() + } + + #[test] + fn test_infer_field_schema() { + assert_eq!(infer_field_schema("A"), DataType::Utf8); + assert_eq!(infer_field_schema("\"123\""), DataType::Utf8); + assert_eq!(infer_field_schema("10"), DataType::Int64); + assert_eq!(infer_field_schema("10.2"), DataType::Float64); + assert_eq!(infer_field_schema(".2"), DataType::Float64); + assert_eq!(infer_field_schema("2."), DataType::Float64); + assert_eq!(infer_field_schema("true"), DataType::Boolean); + assert_eq!(infer_field_schema("trUe"), DataType::Boolean); + assert_eq!(infer_field_schema("false"), DataType::Boolean); + assert_eq!(infer_field_schema("2020-11-08"), DataType::Date32); + assert_eq!( + infer_field_schema("2020-11-08T14:20:01"), + DataType::Timestamp(TimeUnit::Second, None) + ); + assert_eq!( + infer_field_schema("2020-11-08 14:20:01"), + DataType::Timestamp(TimeUnit::Second, None) + ); + assert_eq!( + infer_field_schema("2020-11-08 14:20:01"), + DataType::Timestamp(TimeUnit::Second, None) + ); + assert_eq!(infer_field_schema("-5.13"), DataType::Float64); + assert_eq!(infer_field_schema("0.1300"), DataType::Float64); + assert_eq!( + infer_field_schema("2021-12-19 13:12:30.921"), + DataType::Timestamp(TimeUnit::Millisecond, None) + ); + assert_eq!( + infer_field_schema("2021-12-19T13:12:30.123456789"), + DataType::Timestamp(TimeUnit::Nanosecond, None) + ); + } + + #[test] + fn parse_date32() { + assert_eq!(Date32Type::parse("1970-01-01").unwrap(), 0); + assert_eq!(Date32Type::parse("2020-03-15").unwrap(), 18336); + assert_eq!(Date32Type::parse("1945-05-08").unwrap(), -9004); + } + + #[test] + fn parse_time() { + assert_eq!( + Time64NanosecondType::parse("12:10:01.123456789 AM"), + Some(601_123_456_789) + ); + assert_eq!( + Time64MicrosecondType::parse("12:10:01.123456 am"), + Some(601_123_456) + ); + assert_eq!( + Time32MillisecondType::parse("2:10:01.12 PM"), + Some(51_001_120) + ); + assert_eq!(Time32SecondType::parse("2:10:01 pm"), Some(51_001)); + } + + #[test] + fn parse_date64() { + assert_eq!(Date64Type::parse("1970-01-01T00:00:00").unwrap(), 0); + assert_eq!( + Date64Type::parse("2018-11-13T17:11:10").unwrap(), + 1542129070000 + ); + assert_eq!( + Date64Type::parse("2018-11-13T17:11:10.011").unwrap(), + 1542129070011 + ); + assert_eq!( + Date64Type::parse("1900-02-28T12:34:56").unwrap(), + -2203932304000 + ); + assert_eq!( + Date64Type::parse_formatted("1900-02-28 12:34:56", "%Y-%m-%d %H:%M:%S").unwrap(), + -2203932304000 + ); + assert_eq!( + Date64Type::parse_formatted("1900-02-28 12:34:56+0030", "%Y-%m-%d %H:%M:%S%z").unwrap(), + -2203932304000 - (30 * 60 * 1000) + ); + } + + fn test_parse_timestamp_impl( + timezone: Option>, + expected: &[i64], + ) { + let csv = [ + "1970-01-01T00:00:00", + "1970-01-01T00:00:00Z", + "1970-01-01T00:00:00+02:00", + ] + .join("\n"); + let schema = Arc::new(Schema::new(vec![Field::new( + "field", + DataType::Timestamp(T::UNIT, timezone.clone()), + true, + )])); + + let mut decoder = ReaderBuilder::new(schema).build_decoder(); + + let decoded = decoder.decode(csv.as_bytes()).unwrap(); + assert_eq!(decoded, csv.len()); + decoder.decode(&[]).unwrap(); + + let batch = decoder.flush().unwrap().unwrap(); + assert_eq!(batch.num_columns(), 1); + assert_eq!(batch.num_rows(), 3); + let col = batch.column(0).as_primitive::(); + assert_eq!(col.values(), expected); + assert_eq!(col.data_type(), &DataType::Timestamp(T::UNIT, timezone)); + } + + #[test] + fn test_parse_timestamp() { + test_parse_timestamp_impl::(None, &[0, 0, -7_200_000_000_000]); + test_parse_timestamp_impl::( + Some("+00:00".into()), + &[0, 0, -7_200_000_000_000], + ); + test_parse_timestamp_impl::( + Some("-05:00".into()), + &[18_000_000_000_000, 0, -7_200_000_000_000], + ); + test_parse_timestamp_impl::( + Some("-03".into()), + &[10_800_000_000, 0, -7_200_000_000], + ); + test_parse_timestamp_impl::( + Some("-03".into()), + &[10_800_000, 0, -7_200_000], + ); + test_parse_timestamp_impl::(Some("-03".into()), &[10_800, 0, -7_200]); + } + + #[test] + fn test_infer_schema_from_multiple_files() { + let mut csv1 = NamedTempFile::new().unwrap(); + let mut csv2 = NamedTempFile::new().unwrap(); + let csv3 = NamedTempFile::new().unwrap(); // empty csv file should be skipped + let mut csv4 = NamedTempFile::new().unwrap(); + writeln!(csv1, "c1,c2,c3").unwrap(); + writeln!(csv1, "1,\"foo\",0.5").unwrap(); + writeln!(csv1, "3,\"bar\",1").unwrap(); + writeln!(csv1, "3,\"bar\",2e-06").unwrap(); + // reading csv2 will set c2 to optional + writeln!(csv2, "c1,c2,c3,c4").unwrap(); + writeln!(csv2, "10,,3.14,true").unwrap(); + // reading csv4 will set c3 to optional + writeln!(csv4, "c1,c2,c3").unwrap(); + writeln!(csv4, "10,\"foo\",").unwrap(); + + let schema = infer_schema_from_files( + &[ + csv3.path().to_str().unwrap().to_string(), + csv1.path().to_str().unwrap().to_string(), + csv2.path().to_str().unwrap().to_string(), + csv4.path().to_str().unwrap().to_string(), + ], + b',', + Some(4), // only csv1 and csv2 should be read + true, + ) + .unwrap(); + + assert_eq!(schema.fields().len(), 4); + assert!(schema.field(0).is_nullable()); + assert!(schema.field(1).is_nullable()); + assert!(schema.field(2).is_nullable()); + assert!(schema.field(3).is_nullable()); + + assert_eq!(&DataType::Int64, schema.field(0).data_type()); + assert_eq!(&DataType::Utf8, schema.field(1).data_type()); + assert_eq!(&DataType::Float64, schema.field(2).data_type()); + assert_eq!(&DataType::Boolean, schema.field(3).data_type()); + } + + #[test] + fn test_bounded() { + let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]); + let data = vec![ + vec!["0"], + vec!["1"], + vec!["2"], + vec!["3"], + vec!["4"], + vec!["5"], + vec!["6"], + ]; + + let data = data + .iter() + .map(|x| x.join(",")) + .collect::>() + .join("\n"); + let data = data.as_bytes(); + + let reader = std::io::Cursor::new(data); + + let mut csv = ReaderBuilder::new(Arc::new(schema)) + .with_batch_size(2) + .with_projection(vec![0]) + .with_bounds(2, 6) + .build_buffered(reader) + .unwrap(); + + let batch = csv.next().unwrap().unwrap(); + let a = batch.column(0); + let a = a.as_any().downcast_ref::().unwrap(); + assert_eq!(a, &UInt32Array::from(vec![2, 3])); + + let batch = csv.next().unwrap().unwrap(); + let a = batch.column(0); + let a = a.as_any().downcast_ref::().unwrap(); + assert_eq!(a, &UInt32Array::from(vec![4, 5])); + + assert!(csv.next().is_none()); + } + + #[test] + fn test_empty_projection() { + let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]); + let data = vec![vec!["0"], vec!["1"]]; + + let data = data + .iter() + .map(|x| x.join(",")) + .collect::>() + .join("\n"); + + let mut csv = ReaderBuilder::new(Arc::new(schema)) + .with_batch_size(2) + .with_projection(vec![]) + .build_buffered(Cursor::new(data.as_bytes())) + .unwrap(); + + let batch = csv.next().unwrap().unwrap(); + assert_eq!(batch.columns().len(), 0); + assert_eq!(batch.num_rows(), 2); + + assert!(csv.next().is_none()); + } + + #[test] + fn test_parsing_bool() { + // Encode the expected behavior of boolean parsing + assert_eq!(Some(true), parse_bool("true")); + assert_eq!(Some(true), parse_bool("tRUe")); + assert_eq!(Some(true), parse_bool("True")); + assert_eq!(Some(true), parse_bool("TRUE")); + assert_eq!(None, parse_bool("t")); + assert_eq!(None, parse_bool("T")); + assert_eq!(None, parse_bool("")); + + assert_eq!(Some(false), parse_bool("false")); + assert_eq!(Some(false), parse_bool("fALse")); + assert_eq!(Some(false), parse_bool("False")); + assert_eq!(Some(false), parse_bool("FALSE")); + assert_eq!(None, parse_bool("f")); + assert_eq!(None, parse_bool("F")); + assert_eq!(None, parse_bool("")); + } + + #[test] + fn test_parsing_float() { + assert_eq!(Some(12.34), Float64Type::parse("12.34")); + assert_eq!(Some(-12.34), Float64Type::parse("-12.34")); + assert_eq!(Some(12.0), Float64Type::parse("12")); + assert_eq!(Some(0.0), Float64Type::parse("0")); + assert_eq!(Some(2.0), Float64Type::parse("2.")); + assert_eq!(Some(0.2), Float64Type::parse(".2")); + assert!(Float64Type::parse("nan").unwrap().is_nan()); + assert!(Float64Type::parse("NaN").unwrap().is_nan()); + assert!(Float64Type::parse("inf").unwrap().is_infinite()); + assert!(Float64Type::parse("inf").unwrap().is_sign_positive()); + assert!(Float64Type::parse("-inf").unwrap().is_infinite()); + assert!(Float64Type::parse("-inf").unwrap().is_sign_negative()); + assert_eq!(None, Float64Type::parse("")); + assert_eq!(None, Float64Type::parse("dd")); + assert_eq!(None, Float64Type::parse("12.34.56")); + } + + #[test] + fn test_non_std_quote() { + let schema = Schema::new(vec![ + Field::new("text1", DataType::Utf8, false), + Field::new("text2", DataType::Utf8, false), + ]); + let builder = ReaderBuilder::new(Arc::new(schema)) + .with_header(false) + .with_quote(b'~'); // default is ", change to ~ + + let mut csv_text = Vec::new(); + let mut csv_writer = std::io::Cursor::new(&mut csv_text); + for index in 0..10 { + let text1 = format!("id{index:}"); + let text2 = format!("value{index:}"); + csv_writer + .write_fmt(format_args!("~{text1}~,~{text2}~\r\n")) + .unwrap(); + } + let mut csv_reader = std::io::Cursor::new(&csv_text); + let mut reader = builder.build(&mut csv_reader).unwrap(); + let batch = reader.next().unwrap().unwrap(); + let col0 = batch.column(0); + assert_eq!(col0.len(), 10); + let col0_arr = col0.as_any().downcast_ref::().unwrap(); + assert_eq!(col0_arr.value(0), "id0"); + let col1 = batch.column(1); + assert_eq!(col1.len(), 10); + let col1_arr = col1.as_any().downcast_ref::().unwrap(); + assert_eq!(col1_arr.value(5), "value5"); + } + + #[test] + fn test_non_std_escape() { + let schema = Schema::new(vec![ + Field::new("text1", DataType::Utf8, false), + Field::new("text2", DataType::Utf8, false), + ]); + let builder = ReaderBuilder::new(Arc::new(schema)) + .with_header(false) + .with_escape(b'\\'); // default is None, change to \ + + let mut csv_text = Vec::new(); + let mut csv_writer = std::io::Cursor::new(&mut csv_text); + for index in 0..10 { + let text1 = format!("id{index:}"); + let text2 = format!("value\\\"{index:}"); + csv_writer + .write_fmt(format_args!("\"{text1}\",\"{text2}\"\r\n")) + .unwrap(); + } + let mut csv_reader = std::io::Cursor::new(&csv_text); + let mut reader = builder.build(&mut csv_reader).unwrap(); + let batch = reader.next().unwrap().unwrap(); + let col0 = batch.column(0); + assert_eq!(col0.len(), 10); + let col0_arr = col0.as_any().downcast_ref::().unwrap(); + assert_eq!(col0_arr.value(0), "id0"); + let col1 = batch.column(1); + assert_eq!(col1.len(), 10); + let col1_arr = col1.as_any().downcast_ref::().unwrap(); + assert_eq!(col1_arr.value(5), "value\"5"); + } + + #[test] + fn test_non_std_terminator() { + let schema = Schema::new(vec![ + Field::new("text1", DataType::Utf8, false), + Field::new("text2", DataType::Utf8, false), + ]); + let builder = ReaderBuilder::new(Arc::new(schema)) + .with_header(false) + .with_terminator(b'\n'); // default is CRLF, change to LF + + let mut csv_text = Vec::new(); + let mut csv_writer = std::io::Cursor::new(&mut csv_text); + for index in 0..10 { + let text1 = format!("id{index:}"); + let text2 = format!("value{index:}"); + csv_writer + .write_fmt(format_args!("\"{text1}\",\"{text2}\"\n")) + .unwrap(); + } + let mut csv_reader = std::io::Cursor::new(&csv_text); + let mut reader = builder.build(&mut csv_reader).unwrap(); + let batch = reader.next().unwrap().unwrap(); + let col0 = batch.column(0); + assert_eq!(col0.len(), 10); + let col0_arr = col0.as_any().downcast_ref::().unwrap(); + assert_eq!(col0_arr.value(0), "id0"); + let col1 = batch.column(1); + assert_eq!(col1.len(), 10); + let col1_arr = col1.as_any().downcast_ref::().unwrap(); + assert_eq!(col1_arr.value(5), "value5"); + } + + #[test] + fn test_header_bounds() { + let csv = "a,b\na,b\na,b\na,b\na,b\n"; + let tests = [ + (None, false, 5), + (None, true, 4), + (Some((0, 4)), false, 4), + (Some((1, 4)), false, 3), + (Some((0, 4)), true, 4), + (Some((1, 4)), true, 3), + ]; + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("a", DataType::Utf8, false), + ])); + + for (idx, (bounds, has_header, expected)) in tests.into_iter().enumerate() { + let mut reader = ReaderBuilder::new(schema.clone()).with_header(has_header); + if let Some((start, end)) = bounds { + reader = reader.with_bounds(start, end); + } + let b = reader + .build_buffered(Cursor::new(csv.as_bytes())) + .unwrap() + .next() + .unwrap() + .unwrap(); + assert_eq!(b.num_rows(), expected, "{idx}"); + } + } + + #[test] + fn test_null_boolean() { + let csv = "true,false\nFalse,True\n,True\nFalse,"; + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Boolean, true), + Field::new("a", DataType::Boolean, true), + ])); + + let b = ReaderBuilder::new(schema) + .build_buffered(Cursor::new(csv.as_bytes())) + .unwrap() + .next() + .unwrap() + .unwrap(); + + assert_eq!(b.num_rows(), 4); + assert_eq!(b.num_columns(), 2); + + let c = b.column(0).as_boolean(); + assert_eq!(c.null_count(), 1); + assert!(c.value(0)); + assert!(!c.value(1)); + assert!(c.is_null(2)); + assert!(!c.value(3)); + + let c = b.column(1).as_boolean(); + assert_eq!(c.null_count(), 1); + assert!(!c.value(0)); + assert!(c.value(1)); + assert!(c.value(2)); + assert!(c.is_null(3)); + } + + #[test] + fn test_buffered() { + let tests = [ + ("test/data/uk_cities.csv", false, 37), + ("test/data/various_types.csv", true, 7), + ("test/data/decimal_test.csv", false, 10), + ]; + + for (path, has_header, expected_rows) in tests { + let (schema, _) = Format::default() + .infer_schema(File::open(path).unwrap(), None) + .unwrap(); + let schema = Arc::new(schema); + + for batch_size in [1, 4] { + for capacity in [1, 3, 7, 100] { + let reader = ReaderBuilder::new(schema.clone()) + .with_batch_size(batch_size) + .with_header(has_header) + .build(File::open(path).unwrap()) + .unwrap(); + + let expected = reader.collect::, _>>().unwrap(); + + assert_eq!( + expected.iter().map(|x| x.num_rows()).sum::(), + expected_rows + ); + + let buffered = + std::io::BufReader::with_capacity(capacity, File::open(path).unwrap()); + + let reader = ReaderBuilder::new(schema.clone()) + .with_batch_size(batch_size) + .with_header(has_header) + .build_buffered(buffered) + .unwrap(); + + let actual = reader.collect::, _>>().unwrap(); + assert_eq!(expected, actual) + } + } + } + } + + fn err_test(csv: &[u8], expected: &str) { + let schema = Arc::new(Schema::new(vec![ + Field::new("text1", DataType::Utf8, true), + Field::new("text2", DataType::Utf8, true), + ])); + let buffer = std::io::BufReader::with_capacity(2, Cursor::new(csv)); + let b = ReaderBuilder::new(schema) + .with_batch_size(2) + .build_buffered(buffer) + .unwrap(); + let err = b.collect::, _>>().unwrap_err().to_string(); + assert_eq!(err, expected) + } + + #[test] + fn test_invalid_utf8() { + err_test( + b"sdf,dsfg\ndfd,hgh\xFFue\n,sds\nFalhghse,", + "Csv error: Encountered invalid UTF-8 data for line 2 and field 2", + ); + + err_test( + b"sdf,dsfg\ndksdk,jf\nd\xFFfd,hghue\n,sds\nFalhghse,", + "Csv error: Encountered invalid UTF-8 data for line 3 and field 1", + ); + + err_test( + b"sdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF", + "Csv error: Encountered invalid UTF-8 data for line 5 and field 2", + ); + + err_test( + b"\xFFsdf,dsfg\ndksdk,jf\ndsdsfd,hghue\n,sds\nFalhghse,\xFF", + "Csv error: Encountered invalid UTF-8 data for line 1 and field 1", + ); + } + + struct InstrumentedRead { + r: R, + fill_count: usize, + fill_sizes: Vec, + } + + impl InstrumentedRead { + fn new(r: R) -> Self { + Self { + r, + fill_count: 0, + fill_sizes: vec![], + } + } + } + + impl Seek for InstrumentedRead { + fn seek(&mut self, pos: SeekFrom) -> std::io::Result { + self.r.seek(pos) + } + } + + impl Read for InstrumentedRead { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + self.r.read(buf) + } + } + + impl BufRead for InstrumentedRead { + fn fill_buf(&mut self) -> std::io::Result<&[u8]> { + self.fill_count += 1; + let buf = self.r.fill_buf()?; + self.fill_sizes.push(buf.len()); + Ok(buf) + } + + fn consume(&mut self, amt: usize) { + self.r.consume(amt) + } + } + + #[test] + fn test_io() { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ])); + let csv = "foo,bar\nbaz,foo\na,b\nc,d"; + let mut read = InstrumentedRead::new(Cursor::new(csv.as_bytes())); + let reader = ReaderBuilder::new(schema) + .with_batch_size(3) + .build_buffered(&mut read) + .unwrap(); + + let batches = reader.collect::, _>>().unwrap(); + assert_eq!(batches.len(), 2); + assert_eq!(batches[0].num_rows(), 3); + assert_eq!(batches[1].num_rows(), 1); + + // Expect 4 calls to fill_buf + // 1. Read first 3 rows + // 2. Read final row + // 3. Delimit and flush final row + // 4. Iterator finished + assert_eq!(&read.fill_sizes, &[23, 3, 0, 0]); + assert_eq!(read.fill_count, 4); + } + + #[test] + fn test_inference() { + let cases: &[(&[&str], DataType)] = &[ + (&[], DataType::Null), + (&["false", "12"], DataType::Utf8), + (&["12", "cupcakes"], DataType::Utf8), + (&["12", "12.4"], DataType::Float64), + (&["14050", "24332"], DataType::Int64), + (&["14050.0", "true"], DataType::Utf8), + (&["14050", "2020-03-19 00:00:00"], DataType::Utf8), + (&["14050", "2340.0", "2020-03-19 00:00:00"], DataType::Utf8), + ( + &["2020-03-19 02:00:00", "2020-03-19 00:00:00"], + DataType::Timestamp(TimeUnit::Second, None), + ), + (&["2020-03-19", "2020-03-20"], DataType::Date32), + ( + &["2020-03-19", "2020-03-19 02:00:00", "2020-03-19 00:00:00"], + DataType::Timestamp(TimeUnit::Second, None), + ), + ( + &[ + "2020-03-19", + "2020-03-19 02:00:00", + "2020-03-19 00:00:00.000", + ], + DataType::Timestamp(TimeUnit::Millisecond, None), + ), + ( + &[ + "2020-03-19", + "2020-03-19 02:00:00", + "2020-03-19 00:00:00.000000", + ], + DataType::Timestamp(TimeUnit::Microsecond, None), + ), + ( + &["2020-03-19 02:00:00+02:00", "2020-03-19 02:00:00Z"], + DataType::Timestamp(TimeUnit::Second, None), + ), + ( + &[ + "2020-03-19", + "2020-03-19 02:00:00+02:00", + "2020-03-19 02:00:00Z", + "2020-03-19 02:00:00.12Z", + ], + DataType::Timestamp(TimeUnit::Millisecond, None), + ), + ( + &[ + "2020-03-19", + "2020-03-19 02:00:00.000000000", + "2020-03-19 00:00:00.000000", + ], + DataType::Timestamp(TimeUnit::Nanosecond, None), + ), + ]; + + for (values, expected) in cases { + let mut t = InferredDataType::default(); + for v in *values { + t.update(v) + } + assert_eq!(&t.get(), expected, "{values:?}") + } + } +} diff --git a/arrow-csv/src/reader/records.rs b/arrow-csv/src/reader/records.rs new file mode 100644 index 000000000000..877cfb3ee653 --- /dev/null +++ b/arrow-csv/src/reader/records.rs @@ -0,0 +1,362 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::ArrowError; +use csv_core::{ReadRecordResult, Reader}; + +/// The estimated length of a field in bytes +const AVERAGE_FIELD_SIZE: usize = 8; + +/// The minimum amount of data in a single read +const MIN_CAPACITY: usize = 1024; + +/// [`RecordDecoder`] provides a push-based interface to decoder [`StringRecords`] +#[derive(Debug)] +pub struct RecordDecoder { + delimiter: Reader, + + /// The expected number of fields per row + num_columns: usize, + + /// The current line number + line_number: usize, + + /// Offsets delimiting field start positions + offsets: Vec, + + /// The current offset into `self.offsets` + /// + /// We track this independently of Vec to avoid re-zeroing memory + offsets_len: usize, + + /// The number of fields read for the current record + current_field: usize, + + /// The number of rows buffered + num_rows: usize, + + /// Decoded field data + data: Vec, + + /// Offsets into data + /// + /// We track this independently of Vec to avoid re-zeroing memory + data_len: usize, +} + +impl RecordDecoder { + pub fn new(delimiter: Reader, num_columns: usize) -> Self { + Self { + delimiter, + num_columns, + line_number: 1, + offsets: vec![], + offsets_len: 1, // The first offset is always 0 + current_field: 0, + data_len: 0, + data: vec![], + num_rows: 0, + } + } + + /// Decodes records from `input` returning the number of records and bytes read + /// + /// Note: this expects to be called with an empty `input` to signal EOF + pub fn decode(&mut self, input: &[u8], to_read: usize) -> Result<(usize, usize), ArrowError> { + if to_read == 0 { + return Ok((0, 0)); + } + + // Reserve sufficient capacity in offsets + self.offsets + .resize(self.offsets_len + to_read * self.num_columns, 0); + + // The current offset into `input` + let mut input_offset = 0; + + // The number of rows decoded in this pass + let mut read = 0; + + loop { + // Reserve necessary space in output data based on best estimate + let remaining_rows = to_read - read; + let capacity = remaining_rows * self.num_columns * AVERAGE_FIELD_SIZE; + let estimated_data = capacity.max(MIN_CAPACITY); + self.data.resize(self.data_len + estimated_data, 0); + + // Try to read a record + loop { + let (result, bytes_read, bytes_written, end_positions) = + self.delimiter.read_record( + &input[input_offset..], + &mut self.data[self.data_len..], + &mut self.offsets[self.offsets_len..], + ); + + self.current_field += end_positions; + self.offsets_len += end_positions; + input_offset += bytes_read; + self.data_len += bytes_written; + + match result { + ReadRecordResult::End | ReadRecordResult::InputEmpty => { + // Reached end of input + return Ok((read, input_offset)); + } + // Need to allocate more capacity + ReadRecordResult::OutputFull => break, + ReadRecordResult::OutputEndsFull => { + return Err(ArrowError::CsvError(format!( + "incorrect number of fields for line {}, expected {} got more than {}", + self.line_number, self.num_columns, self.current_field + ))); + } + ReadRecordResult::Record => { + if self.current_field != self.num_columns { + return Err(ArrowError::CsvError(format!( + "incorrect number of fields for line {}, expected {} got {}", + self.line_number, self.num_columns, self.current_field + ))); + } + read += 1; + self.current_field = 0; + self.line_number += 1; + self.num_rows += 1; + + if read == to_read { + // Read sufficient rows + return Ok((read, input_offset)); + } + + if input.len() == input_offset { + // Input exhausted, need to read more + // Without this read_record will interpret the empty input + // byte array as indicating the end of the file + return Ok((read, input_offset)); + } + } + } + } + } + } + + /// Returns the current number of buffered records + pub fn len(&self) -> usize { + self.num_rows + } + + /// Returns true if the decoder is empty + pub fn is_empty(&self) -> bool { + self.num_rows == 0 + } + + /// Clears the current contents of the decoder + pub fn clear(&mut self) { + // This does not reset current_field to allow clearing part way through a record + self.offsets_len = 1; + self.data_len = 0; + self.num_rows = 0; + } + + /// Flushes the current contents of the reader + pub fn flush(&mut self) -> Result, ArrowError> { + if self.current_field != 0 { + return Err(ArrowError::CsvError( + "Cannot flush part way through record".to_string(), + )); + } + + // csv_core::Reader writes end offsets relative to the start of the row + // Therefore scan through and offset these based on the cumulative row offsets + let mut row_offset = 0; + self.offsets[1..self.offsets_len] + .chunks_exact_mut(self.num_columns) + .for_each(|row| { + let offset = row_offset; + row.iter_mut().for_each(|x| { + *x += offset; + row_offset = *x; + }); + }); + + // Need to truncate data t1o the actual amount of data read + let data = std::str::from_utf8(&self.data[..self.data_len]).map_err(|e| { + let valid_up_to = e.valid_up_to(); + + // We can't use binary search because of empty fields + let idx = self.offsets[..self.offsets_len] + .iter() + .rposition(|x| *x <= valid_up_to) + .unwrap(); + + let field = idx % self.num_columns + 1; + let line_offset = self.line_number - self.num_rows; + let line = line_offset + idx / self.num_columns; + + ArrowError::CsvError(format!( + "Encountered invalid UTF-8 data for line {line} and field {field}" + )) + })?; + + let offsets = &self.offsets[..self.offsets_len]; + let num_rows = self.num_rows; + + // Reset state + self.offsets_len = 1; + self.data_len = 0; + self.num_rows = 0; + + Ok(StringRecords { + num_rows, + num_columns: self.num_columns, + offsets, + data, + }) + } +} + +/// A collection of parsed, UTF-8 CSV records +#[derive(Debug)] +pub struct StringRecords<'a> { + num_columns: usize, + num_rows: usize, + offsets: &'a [usize], + data: &'a str, +} + +impl<'a> StringRecords<'a> { + fn get(&self, index: usize) -> StringRecord<'a> { + let field_idx = index * self.num_columns; + StringRecord { + data: self.data, + offsets: &self.offsets[field_idx..field_idx + self.num_columns + 1], + } + } + + pub fn len(&self) -> usize { + self.num_rows + } + + pub fn iter(&self) -> impl Iterator> + '_ { + (0..self.num_rows).map(|x| self.get(x)) + } +} + +/// A single parsed, UTF-8 CSV record +#[derive(Debug, Clone, Copy)] +pub struct StringRecord<'a> { + data: &'a str, + offsets: &'a [usize], +} + +impl<'a> StringRecord<'a> { + pub fn get(&self, index: usize) -> &'a str { + let end = self.offsets[index + 1]; + let start = self.offsets[index]; + + // SAFETY: + // Parsing produces offsets at valid byte boundaries + unsafe { self.data.get_unchecked(start..end) } + } +} + +#[cfg(test)] +mod tests { + use crate::reader::records::RecordDecoder; + use csv_core::Reader; + use std::io::{BufRead, BufReader, Cursor}; + + #[test] + fn test_basic() { + let csv = [ + "foo,bar,baz", + "a,b,c", + "12,3,5", + "\"asda\"\"asas\",\"sdffsnsd\", as", + ] + .join("\n"); + + let mut expected = vec![ + vec!["foo", "bar", "baz"], + vec!["a", "b", "c"], + vec!["12", "3", "5"], + vec!["asda\"asas", "sdffsnsd", " as"], + ] + .into_iter(); + + let mut reader = BufReader::with_capacity(3, Cursor::new(csv.as_bytes())); + let mut decoder = RecordDecoder::new(Reader::new(), 3); + + loop { + let to_read = 3; + let mut read = 0; + loop { + let buf = reader.fill_buf().unwrap(); + let (records, bytes) = decoder.decode(buf, to_read - read).unwrap(); + + reader.consume(bytes); + read += records; + + if read == to_read || bytes == 0 { + break; + } + } + if read == 0 { + break; + } + + let b = decoder.flush().unwrap(); + b.iter().zip(&mut expected).for_each(|(record, expected)| { + let actual = (0..3) + .map(|field_idx| record.get(field_idx)) + .collect::>(); + assert_eq!(actual, expected) + }); + } + assert!(expected.next().is_none()); + } + + #[test] + fn test_invalid_fields() { + let csv = "a,b\nb,c\na\n"; + let mut decoder = RecordDecoder::new(Reader::new(), 2); + let err = decoder.decode(csv.as_bytes(), 4).unwrap_err().to_string(); + + let expected = "Csv error: incorrect number of fields for line 3, expected 2 got 1"; + + assert_eq!(err, expected); + + // Test with initial skip + let mut decoder = RecordDecoder::new(Reader::new(), 2); + let (skipped, bytes) = decoder.decode(csv.as_bytes(), 1).unwrap(); + assert_eq!(skipped, 1); + decoder.clear(); + + let remaining = &csv.as_bytes()[bytes..]; + let err = decoder.decode(remaining, 3).unwrap_err().to_string(); + assert_eq!(err, expected); + } + + #[test] + fn test_skip_insufficient_rows() { + let csv = "a\nv\n"; + let mut decoder = RecordDecoder::new(Reader::new(), 1); + let (read, bytes) = decoder.decode(csv.as_bytes(), 3).unwrap(); + assert_eq!(read, 2); + assert_eq!(bytes, csv.len()); + } +} diff --git a/arrow-csv/src/writer.rs b/arrow-csv/src/writer.rs new file mode 100644 index 000000000000..a31a1d5e8c13 --- /dev/null +++ b/arrow-csv/src/writer.rs @@ -0,0 +1,762 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! CSV Writer +//! +//! This CSV writer allows Arrow data (in record batches) to be written as CSV files. +//! The writer does not support writing `ListArray` and `StructArray`. +//! +//! Example: +//! +//! ``` +//! # use arrow_array::*; +//! # use arrow_array::types::*; +//! # use arrow_csv::Writer; +//! # use arrow_schema::*; +//! # use std::sync::Arc; +//! +//! let schema = Schema::new(vec![ +//! Field::new("c1", DataType::Utf8, false), +//! Field::new("c2", DataType::Float64, true), +//! Field::new("c3", DataType::UInt32, false), +//! Field::new("c4", DataType::Boolean, true), +//! ]); +//! let c1 = StringArray::from(vec![ +//! "Lorem ipsum dolor sit amet", +//! "consectetur adipiscing elit", +//! "sed do eiusmod tempor", +//! ]); +//! let c2 = PrimitiveArray::::from(vec![ +//! Some(123.564532), +//! None, +//! Some(-556132.25), +//! ]); +//! let c3 = PrimitiveArray::::from(vec![3, 2, 1]); +//! let c4 = BooleanArray::from(vec![Some(true), Some(false), None]); +//! +//! let batch = RecordBatch::try_new( +//! Arc::new(schema), +//! vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)], +//! ) +//! .unwrap(); +//! +//! let mut output = Vec::with_capacity(1024); +//! +//! let mut writer = Writer::new(&mut output); +//! let batches = vec![&batch, &batch]; +//! for batch in batches { +//! writer.write(batch).unwrap(); +//! } +//! ``` + +use arrow_array::*; +use arrow_cast::display::*; +use arrow_schema::*; +use csv::ByteRecord; +use std::io::Write; + +use crate::map_csv_error; +const DEFAULT_NULL_VALUE: &str = ""; + +/// A CSV writer +#[derive(Debug)] +pub struct Writer { + /// The object to write to + writer: csv::Writer, + /// Whether file should be written with headers, defaults to `true` + has_headers: bool, + /// The date format for date arrays, defaults to RFC3339 + date_format: Option, + /// The datetime format for datetime arrays, defaults to RFC3339 + datetime_format: Option, + /// The timestamp format for timestamp arrays, defaults to RFC3339 + timestamp_format: Option, + /// The timestamp format for timestamp (with timezone) arrays, defaults to RFC3339 + timestamp_tz_format: Option, + /// The time format for time arrays, defaults to RFC3339 + time_format: Option, + /// Is the beginning-of-writer + beginning: bool, + /// The value to represent null entries, defaults to [`DEFAULT_NULL_VALUE`] + null_value: Option, +} + +impl Writer { + /// Create a new CsvWriter from a writable object, with default options + pub fn new(writer: W) -> Self { + let delimiter = b','; + WriterBuilder::new().with_delimiter(delimiter).build(writer) + } + + /// Write a vector of record batches to a writable object + pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { + let num_columns = batch.num_columns(); + if self.beginning { + if self.has_headers { + let mut headers: Vec = Vec::with_capacity(num_columns); + batch + .schema() + .fields() + .iter() + .for_each(|field| headers.push(field.name().to_string())); + self.writer + .write_record(&headers[..]) + .map_err(map_csv_error)?; + } + self.beginning = false; + } + + let options = FormatOptions::default() + .with_null(self.null_value.as_deref().unwrap_or(DEFAULT_NULL_VALUE)) + .with_date_format(self.date_format.as_deref()) + .with_datetime_format(self.datetime_format.as_deref()) + .with_timestamp_format(self.timestamp_format.as_deref()) + .with_timestamp_tz_format(self.timestamp_tz_format.as_deref()) + .with_time_format(self.time_format.as_deref()); + + let converters = batch + .columns() + .iter() + .map(|a| match a.data_type() { + d if d.is_nested() => Err(ArrowError::CsvError(format!( + "Nested type {} is not supported in CSV", + a.data_type() + ))), + DataType::Binary | DataType::LargeBinary => Err(ArrowError::CsvError( + "Binary data cannot be written to CSV".to_string(), + )), + _ => ArrayFormatter::try_new(a.as_ref(), &options), + }) + .collect::, ArrowError>>()?; + + let mut buffer = String::with_capacity(1024); + let mut byte_record = ByteRecord::with_capacity(1024, converters.len()); + + for row_idx in 0..batch.num_rows() { + byte_record.clear(); + for (col_idx, converter) in converters.iter().enumerate() { + buffer.clear(); + converter.value(row_idx).write(&mut buffer).map_err(|e| { + ArrowError::CsvError(format!( + "Error processing row {}, col {}: {e}", + row_idx + 1, + col_idx + 1 + )) + })?; + byte_record.push_field(buffer.as_bytes()); + } + + self.writer + .write_byte_record(&byte_record) + .map_err(map_csv_error)?; + } + self.writer.flush()?; + + Ok(()) + } + + /// Unwraps this `Writer`, returning the underlying writer. + pub fn into_inner(self) -> W { + // Safe to call `unwrap` since `write` always flushes the writer. + self.writer.into_inner().unwrap() + } +} + +impl RecordBatchWriter for Writer { + fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { + self.write(batch) + } + + fn close(self) -> Result<(), ArrowError> { + Ok(()) + } +} + +/// A CSV writer builder +#[derive(Clone, Debug)] +pub struct WriterBuilder { + /// Optional column delimiter. Defaults to `b','` + delimiter: u8, + /// Whether to write column names as file headers. Defaults to `true` + has_header: bool, + /// Optional quote character. Defaults to `b'"'` + quote: u8, + /// Optional escape character. Defaults to `b'\\'` + escape: u8, + /// Enable double quote escapes. Defaults to `true` + double_quote: bool, + /// Optional date format for date arrays + date_format: Option, + /// Optional datetime format for datetime arrays + datetime_format: Option, + /// Optional timestamp format for timestamp arrays + timestamp_format: Option, + /// Optional timestamp format for timestamp with timezone arrays + timestamp_tz_format: Option, + /// Optional time format for time arrays + time_format: Option, + /// Optional value to represent null + null_value: Option, +} + +impl Default for WriterBuilder { + fn default() -> Self { + WriterBuilder { + delimiter: b',', + has_header: true, + quote: b'"', + escape: b'\\', + double_quote: true, + date_format: None, + datetime_format: None, + timestamp_format: None, + timestamp_tz_format: None, + time_format: None, + null_value: None, + } + } +} + +impl WriterBuilder { + /// Create a new builder for configuring CSV writing options. + /// + /// To convert a builder into a writer, call `WriterBuilder::build` + /// + /// # Example + /// + /// ``` + /// # use arrow_csv::{Writer, WriterBuilder}; + /// # use std::fs::File; + /// + /// fn example() -> Writer { + /// let file = File::create("target/out.csv").unwrap(); + /// + /// // create a builder that doesn't write headers + /// let builder = WriterBuilder::new().with_header(false); + /// let writer = builder.build(file); + /// + /// writer + /// } + /// ``` + pub fn new() -> Self { + Self::default() + } + + /// Set whether to write headers + #[deprecated(note = "Use Self::with_header")] + #[doc(hidden)] + pub fn has_headers(mut self, has_headers: bool) -> Self { + self.has_header = has_headers; + self + } + + /// Set whether to write the CSV file with a header + pub fn with_header(mut self, header: bool) -> Self { + self.has_header = header; + self + } + + /// Returns `true` if this writer is configured to write a header + pub fn header(&self) -> bool { + self.has_header + } + + /// Set the CSV file's column delimiter as a byte character + pub fn with_delimiter(mut self, delimiter: u8) -> Self { + self.delimiter = delimiter; + self + } + + /// Get the CSV file's column delimiter as a byte character + pub fn delimiter(&self) -> u8 { + self.delimiter + } + + /// Set the CSV file's quote character as a byte character + pub fn with_quote(mut self, quote: u8) -> Self { + self.quote = quote; + self + } + + /// Get the CSV file's quote character as a byte character + pub fn quote(&self) -> u8 { + self.quote + } + + /// Set the CSV file's escape character as a byte character + /// + /// In some variants of CSV, quotes are escaped using a special escape + /// character like `\` (instead of escaping quotes by doubling them). + /// + /// By default, writing these idiosyncratic escapes is disabled, and is + /// only used when `double_quote` is disabled. + pub fn with_escape(mut self, escape: u8) -> Self { + self.escape = escape; + self + } + + /// Get the CSV file's escape character as a byte character + pub fn escape(&self) -> u8 { + self.escape + } + + /// Set whether to enable double quote escapes + /// + /// When enabled (which is the default), quotes are escaped by doubling + /// them. e.g., `"` escapes to `""`. + /// + /// When disabled, quotes are escaped with the escape character (which + /// is `\\` by default). + pub fn with_double_quote(mut self, double_quote: bool) -> Self { + self.double_quote = double_quote; + self + } + + /// Get whether double quote escapes are enabled + pub fn double_quote(&self) -> bool { + self.double_quote + } + + /// Set the CSV file's date format + pub fn with_date_format(mut self, format: String) -> Self { + self.date_format = Some(format); + self + } + + /// Get the CSV file's date format if set, defaults to RFC3339 + pub fn date_format(&self) -> Option<&str> { + self.date_format.as_deref() + } + + /// Set the CSV file's datetime format + pub fn with_datetime_format(mut self, format: String) -> Self { + self.datetime_format = Some(format); + self + } + + /// Get the CSV file's datetime format if set, defaults to RFC3339 + pub fn datetime_format(&self) -> Option<&str> { + self.datetime_format.as_deref() + } + + /// Set the CSV file's time format + pub fn with_time_format(mut self, format: String) -> Self { + self.time_format = Some(format); + self + } + + /// Get the CSV file's datetime time if set, defaults to RFC3339 + pub fn time_format(&self) -> Option<&str> { + self.time_format.as_deref() + } + + /// Set the CSV file's timestamp format + pub fn with_timestamp_format(mut self, format: String) -> Self { + self.timestamp_format = Some(format); + self + } + + /// Get the CSV file's timestamp format if set, defaults to RFC3339 + pub fn timestamp_format(&self) -> Option<&str> { + self.timestamp_format.as_deref() + } + + /// Set the value to represent null in output + pub fn with_null(mut self, null_value: String) -> Self { + self.null_value = Some(null_value); + self + } + + /// Get the value to represent null in output + pub fn null(&self) -> &str { + self.null_value.as_deref().unwrap_or(DEFAULT_NULL_VALUE) + } + + /// Use RFC3339 format for date/time/timestamps (default) + #[deprecated(note = "Use WriterBuilder::default()")] + pub fn with_rfc3339(mut self) -> Self { + self.date_format = None; + self.datetime_format = None; + self.time_format = None; + self.timestamp_format = None; + self.timestamp_tz_format = None; + self + } + + /// Create a new `Writer` + pub fn build(self, writer: W) -> Writer { + let mut builder = csv::WriterBuilder::new(); + let writer = builder + .delimiter(self.delimiter) + .quote(self.quote) + .double_quote(self.double_quote) + .escape(self.escape) + .from_writer(writer); + Writer { + writer, + beginning: true, + has_headers: self.has_header, + date_format: self.date_format, + datetime_format: self.datetime_format, + time_format: self.time_format, + timestamp_format: self.timestamp_format, + timestamp_tz_format: self.timestamp_tz_format, + null_value: self.null_value, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::ReaderBuilder; + use arrow_array::builder::{Decimal128Builder, Decimal256Builder}; + use arrow_array::types::*; + use arrow_buffer::i256; + use std::io::{Cursor, Read, Seek}; + use std::sync::Arc; + + #[test] + fn test_write_csv() { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::Float64, true), + Field::new("c3", DataType::UInt32, false), + Field::new("c4", DataType::Boolean, true), + Field::new("c5", DataType::Timestamp(TimeUnit::Millisecond, None), true), + Field::new("c6", DataType::Time32(TimeUnit::Second), false), + Field::new_dictionary("c7", DataType::Int32, DataType::Utf8, false), + ]); + + let c1 = StringArray::from(vec![ + "Lorem ipsum dolor sit amet", + "consectetur adipiscing elit", + "sed do eiusmod tempor", + ]); + let c2 = + PrimitiveArray::::from(vec![Some(123.564532), None, Some(-556132.25)]); + let c3 = PrimitiveArray::::from(vec![3, 2, 1]); + let c4 = BooleanArray::from(vec![Some(true), Some(false), None]); + let c5 = + TimestampMillisecondArray::from(vec![None, Some(1555584887378), Some(1555555555555)]); + let c6 = Time32SecondArray::from(vec![1234, 24680, 85563]); + let c7: DictionaryArray = + vec!["cupcakes", "cupcakes", "foo"].into_iter().collect(); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(c1), + Arc::new(c2), + Arc::new(c3), + Arc::new(c4), + Arc::new(c5), + Arc::new(c6), + Arc::new(c7), + ], + ) + .unwrap(); + + let mut file = tempfile::tempfile().unwrap(); + + let mut writer = Writer::new(&mut file); + let batches = vec![&batch, &batch]; + for batch in batches { + writer.write(batch).unwrap(); + } + drop(writer); + + // check that file was written successfully + file.rewind().unwrap(); + let mut buffer: Vec = vec![]; + file.read_to_end(&mut buffer).unwrap(); + + let expected = r#"c1,c2,c3,c4,c5,c6,c7 +Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34,cupcakes +consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378,06:51:20,cupcakes +sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555,23:46:03,foo +Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34,cupcakes +consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378,06:51:20,cupcakes +sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555,23:46:03,foo +"#; + assert_eq!(expected.to_string(), String::from_utf8(buffer).unwrap()); + } + + #[test] + fn test_write_csv_decimal() { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Decimal128(38, 6), true), + Field::new("c2", DataType::Decimal256(76, 6), true), + ]); + + let mut c1_builder = Decimal128Builder::new().with_data_type(DataType::Decimal128(38, 6)); + c1_builder.extend(vec![Some(-3335724), Some(2179404), None, Some(290472)]); + let c1 = c1_builder.finish(); + + let mut c2_builder = Decimal256Builder::new().with_data_type(DataType::Decimal256(76, 6)); + c2_builder.extend(vec![ + Some(i256::from_i128(-3335724)), + Some(i256::from_i128(2179404)), + None, + Some(i256::from_i128(290472)), + ]); + let c2 = c2_builder.finish(); + + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)]).unwrap(); + + let mut file = tempfile::tempfile().unwrap(); + + let mut writer = Writer::new(&mut file); + let batches = vec![&batch, &batch]; + for batch in batches { + writer.write(batch).unwrap(); + } + drop(writer); + + // check that file was written successfully + file.rewind().unwrap(); + let mut buffer: Vec = vec![]; + file.read_to_end(&mut buffer).unwrap(); + + let expected = r#"c1,c2 +-3.335724,-3.335724 +2.179404,2.179404 +, +0.290472,0.290472 +-3.335724,-3.335724 +2.179404,2.179404 +, +0.290472,0.290472 +"#; + assert_eq!(expected.to_string(), String::from_utf8(buffer).unwrap()); + } + + #[test] + fn test_write_csv_custom_options() { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::Float64, true), + Field::new("c3", DataType::UInt32, false), + Field::new("c4", DataType::Boolean, true), + Field::new("c6", DataType::Time32(TimeUnit::Second), false), + ]); + + let c1 = StringArray::from(vec![ + "Lorem ipsum \ndolor sit amet", + "consectetur \"adipiscing\" elit", + "sed do eiusmod tempor", + ]); + let c2 = + PrimitiveArray::::from(vec![Some(123.564532), None, Some(-556132.25)]); + let c3 = PrimitiveArray::::from(vec![3, 2, 1]); + let c4 = BooleanArray::from(vec![Some(true), Some(false), None]); + let c6 = Time32SecondArray::from(vec![1234, 24680, 85563]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(c1), + Arc::new(c2), + Arc::new(c3), + Arc::new(c4), + Arc::new(c6), + ], + ) + .unwrap(); + + let mut file = tempfile::tempfile().unwrap(); + + let builder = WriterBuilder::new() + .with_header(false) + .with_delimiter(b'|') + .with_quote(b'\'') + .with_null("NULL".to_string()) + .with_time_format("%r".to_string()); + let mut writer = builder.build(&mut file); + let batches = vec![&batch]; + for batch in batches { + writer.write(batch).unwrap(); + } + drop(writer); + + // check that file was written successfully + file.rewind().unwrap(); + let mut buffer: Vec = vec![]; + file.read_to_end(&mut buffer).unwrap(); + + assert_eq!( + "'Lorem ipsum \ndolor sit amet'|123.564532|3|true|12:20:34 AM\nconsectetur \"adipiscing\" elit|NULL|2|false|06:51:20 AM\nsed do eiusmod tempor|-556132.25|1|NULL|11:46:03 PM\n" + .to_string(), + String::from_utf8(buffer).unwrap() + ); + + let mut file = tempfile::tempfile().unwrap(); + + let builder = WriterBuilder::new() + .with_header(true) + .with_double_quote(false) + .with_escape(b'$'); + let mut writer = builder.build(&mut file); + let batches = vec![&batch]; + for batch in batches { + writer.write(batch).unwrap(); + } + drop(writer); + + file.rewind().unwrap(); + let mut buffer: Vec = vec![]; + file.read_to_end(&mut buffer).unwrap(); + + assert_eq!( + "c1,c2,c3,c4,c6\n\"Lorem ipsum \ndolor sit amet\",123.564532,3,true,00:20:34\n\"consectetur $\"adipiscing$\" elit\",,2,false,06:51:20\nsed do eiusmod tempor,-556132.25,1,,23:46:03\n" + .to_string(), + String::from_utf8(buffer).unwrap() + ); + } + + #[test] + fn test_conversion_consistency() { + // test if we can serialize and deserialize whilst retaining the same type information/ precision + + let schema = Schema::new(vec![ + Field::new("c1", DataType::Date32, false), + Field::new("c2", DataType::Date64, false), + Field::new("c3", DataType::Timestamp(TimeUnit::Nanosecond, None), false), + ]); + + let nanoseconds = vec![ + 1599566300000000000, + 1599566200000000000, + 1599566100000000000, + ]; + let c1 = Date32Array::from(vec![3, 2, 1]); + let c2 = Date64Array::from(vec![3, 2, 1]); + let c3 = TimestampNanosecondArray::from(nanoseconds.clone()); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(c1), Arc::new(c2), Arc::new(c3)], + ) + .unwrap(); + + let builder = WriterBuilder::new().with_header(false); + + let mut buf: Cursor> = Default::default(); + // drop the writer early to release the borrow. + { + let mut writer = builder.build(&mut buf); + writer.write(&batch).unwrap(); + } + buf.set_position(0); + + let mut reader = ReaderBuilder::new(Arc::new(schema)) + .with_batch_size(3) + .build_buffered(buf) + .unwrap(); + + let rb = reader.next().unwrap().unwrap(); + let c1 = rb.column(0).as_any().downcast_ref::().unwrap(); + let c2 = rb.column(1).as_any().downcast_ref::().unwrap(); + let c3 = rb + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + + let actual = c1.into_iter().collect::>(); + let expected = vec![Some(3), Some(2), Some(1)]; + assert_eq!(actual, expected); + let actual = c2.into_iter().collect::>(); + let expected = vec![Some(3), Some(2), Some(1)]; + assert_eq!(actual, expected); + let actual = c3.into_iter().collect::>(); + let expected = nanoseconds.into_iter().map(Some).collect::>(); + assert_eq!(actual, expected); + } + + #[test] + fn test_write_csv_invalid_cast() { + let schema = Schema::new(vec![ + Field::new("c0", DataType::UInt32, false), + Field::new("c1", DataType::Date64, false), + ]); + + let c0 = UInt32Array::from(vec![Some(123), Some(234)]); + let c1 = Date64Array::from(vec![Some(1926632005177), Some(1926632005177685347)]); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c0), Arc::new(c1)]).unwrap(); + + let mut file = tempfile::tempfile().unwrap(); + let mut writer = Writer::new(&mut file); + let batches = vec![&batch, &batch]; + + for batch in batches { + let err = writer.write(batch).unwrap_err().to_string(); + assert_eq!(err, "Csv error: Error processing row 2, col 2: Cast error: Failed to convert 1926632005177685347 to temporal for Date64") + } + drop(writer); + } + + #[test] + fn test_write_csv_using_rfc3339() { + let schema = Schema::new(vec![ + Field::new( + "c1", + DataType::Timestamp(TimeUnit::Millisecond, Some("+00:00".into())), + true, + ), + Field::new("c2", DataType::Timestamp(TimeUnit::Millisecond, None), true), + Field::new("c3", DataType::Date32, false), + Field::new("c4", DataType::Time32(TimeUnit::Second), false), + ]); + + let c1 = TimestampMillisecondArray::from(vec![Some(1555584887378), Some(1635577147000)]) + .with_timezone("+00:00".to_string()); + let c2 = TimestampMillisecondArray::from(vec![Some(1555584887378), Some(1635577147000)]); + let c3 = Date32Array::from(vec![3, 2]); + let c4 = Time32SecondArray::from(vec![1234, 24680]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)], + ) + .unwrap(); + + let mut file = tempfile::tempfile().unwrap(); + + let builder = WriterBuilder::new(); + let mut writer = builder.build(&mut file); + let batches = vec![&batch]; + for batch in batches { + writer.write(batch).unwrap(); + } + drop(writer); + + file.rewind().unwrap(); + let mut buffer: Vec = vec![]; + file.read_to_end(&mut buffer).unwrap(); + + assert_eq!( + "c1,c2,c3,c4 +2019-04-18T10:54:47.378Z,2019-04-18T10:54:47.378,1970-01-04,00:20:34 +2021-10-30T06:59:07Z,2021-10-30T06:59:07,1970-01-03,06:51:20\n", + String::from_utf8(buffer).unwrap() + ); + } +} diff --git a/arrow-csv/test/data/custom_null_test.csv b/arrow-csv/test/data/custom_null_test.csv new file mode 100644 index 000000000000..39f9fc4b3eff --- /dev/null +++ b/arrow-csv/test/data/custom_null_test.csv @@ -0,0 +1,6 @@ +c_int,c_float,c_string,c_bool +1,1.1,"1.11",True +nil,2.2,"2.22",TRUE +3,nil,"3.33",true +4,4.4,nil,False +5,6.6,"",nil diff --git a/arrow/test/data/decimal_test.csv b/arrow-csv/test/data/decimal_test.csv similarity index 100% rename from arrow/test/data/decimal_test.csv rename to arrow-csv/test/data/decimal_test.csv diff --git a/arrow-csv/test/data/example.csv b/arrow-csv/test/data/example.csv new file mode 100644 index 000000000000..0c03cee84528 --- /dev/null +++ b/arrow-csv/test/data/example.csv @@ -0,0 +1,4 @@ +c1,c2,c3,c4 +1,1.1,"hong kong",true +3,323.12,"XiAn",false +10,131323.12,"cheng du",false \ No newline at end of file diff --git a/arrow-csv/test/data/init_null_test.csv b/arrow-csv/test/data/init_null_test.csv new file mode 100644 index 000000000000..f7d8a299645d --- /dev/null +++ b/arrow-csv/test/data/init_null_test.csv @@ -0,0 +1,6 @@ +c_int,c_float,c_string,c_bool,c_null +,,,, +2,2.2,"a",TRUE, +3,,"b",true, +4,4.4,,False, +5,6.6,"",FALSE, \ No newline at end of file diff --git a/arrow/test/data/null_test.csv b/arrow-csv/test/data/null_test.csv similarity index 100% rename from arrow/test/data/null_test.csv rename to arrow-csv/test/data/null_test.csv diff --git a/arrow/test/data/uk_cities.csv b/arrow-csv/test/data/uk_cities.csv similarity index 100% rename from arrow/test/data/uk_cities.csv rename to arrow-csv/test/data/uk_cities.csv diff --git a/arrow/test/data/uk_cities_with_headers.csv b/arrow-csv/test/data/uk_cities_with_headers.csv similarity index 100% rename from arrow/test/data/uk_cities_with_headers.csv rename to arrow-csv/test/data/uk_cities_with_headers.csv diff --git a/arrow/test/data/various_types.csv b/arrow-csv/test/data/various_types.csv similarity index 100% rename from arrow/test/data/various_types.csv rename to arrow-csv/test/data/various_types.csv diff --git a/arrow/test/data/various_types_invalid.csv b/arrow-csv/test/data/various_types_invalid.csv similarity index 100% rename from arrow/test/data/various_types_invalid.csv rename to arrow-csv/test/data/various_types_invalid.csv diff --git a/arrow-data/Cargo.toml b/arrow-data/Cargo.toml new file mode 100644 index 000000000000..c83f867523d5 --- /dev/null +++ b/arrow-data/Cargo.toml @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "arrow-data" +version = { workspace = true } +description = "Array data abstractions for Apache Arrow" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "arrow_data" +path = "src/lib.rs" +bench = false + +[features] +# force_validate runs full data validation for all arrays that are created +# this is not enabled by default as it is too computationally expensive +# but is run as part of our CI checks +force_validate = [] +# Enable ffi support +ffi = ["arrow-schema/ffi"] + +[package.metadata.docs.rs] +features = ["ffi"] + +[dependencies] + +arrow-buffer = { workspace = true } +arrow-schema = { workspace = true } + +num = { version = "0.4", default-features = false, features = ["std"] } +half = { version = "2.1", default-features = false } + +[dev-dependencies] + +[build-dependencies] diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs new file mode 100644 index 000000000000..10c53c549e2b --- /dev/null +++ b/arrow-data/src/data.rs @@ -0,0 +1,2122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Contains [`ArrayData`], a generic representation of Arrow array data which encapsulates +//! common attributes and operations for Arrow array. + +use crate::bit_iterator::BitSliceIterator; +use arrow_buffer::buffer::{BooleanBuffer, NullBuffer}; +use arrow_buffer::{bit_util, i256, ArrowNativeType, Buffer, MutableBuffer}; +use arrow_schema::{ArrowError, DataType, UnionMode}; +use std::convert::TryInto; +use std::mem; +use std::ops::Range; +use std::sync::Arc; + +use crate::equal; + +/// A collection of [`Buffer`] +#[doc(hidden)] +#[deprecated(note = "Use [Buffer]")] +pub type Buffers<'a> = &'a [Buffer]; + +#[inline] +pub(crate) fn contains_nulls( + null_bit_buffer: Option<&NullBuffer>, + offset: usize, + len: usize, +) -> bool { + match null_bit_buffer { + Some(buffer) => { + match BitSliceIterator::new(buffer.validity(), buffer.offset() + offset, len).next() { + Some((start, end)) => start != 0 || end != len, + None => len != 0, // No non-null values + } + } + None => false, // No null buffer + } +} + +#[inline] +pub(crate) fn count_nulls( + null_bit_buffer: Option<&NullBuffer>, + offset: usize, + len: usize, +) -> usize { + if let Some(buf) = null_bit_buffer { + let buffer = buf.buffer(); + len - buffer.count_set_bits_offset(offset + buf.offset(), len) + } else { + 0 + } +} + +/// creates 2 [`MutableBuffer`]s with a given `capacity` (in slots). +#[inline] +pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuffer; 2] { + let empty_buffer = MutableBuffer::new(0); + match data_type { + DataType::Null => [empty_buffer, MutableBuffer::new(0)], + DataType::Boolean => { + let bytes = bit_util::ceil(capacity, 8); + let buffer = MutableBuffer::new(bytes); + [buffer, empty_buffer] + } + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Date32 + | DataType::Time32(_) + | DataType::Date64 + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Timestamp(_, _) + | DataType::Interval(_) => [ + MutableBuffer::new(capacity * data_type.primitive_width().unwrap()), + empty_buffer, + ], + DataType::Utf8 | DataType::Binary => { + let mut buffer = MutableBuffer::new((1 + capacity) * mem::size_of::()); + // safety: `unsafe` code assumes that this buffer is initialized with one element + buffer.push(0i32); + [buffer, MutableBuffer::new(capacity * mem::size_of::())] + } + DataType::LargeUtf8 | DataType::LargeBinary => { + let mut buffer = MutableBuffer::new((1 + capacity) * mem::size_of::()); + // safety: `unsafe` code assumes that this buffer is initialized with one element + buffer.push(0i64); + [buffer, MutableBuffer::new(capacity * mem::size_of::())] + } + DataType::List(_) | DataType::Map(_, _) => { + // offset buffer always starts with a zero + let mut buffer = MutableBuffer::new((1 + capacity) * mem::size_of::()); + buffer.push(0i32); + [buffer, empty_buffer] + } + DataType::LargeList(_) => { + // offset buffer always starts with a zero + let mut buffer = MutableBuffer::new((1 + capacity) * mem::size_of::()); + buffer.push(0i64); + [buffer, empty_buffer] + } + DataType::FixedSizeBinary(size) => { + [MutableBuffer::new(capacity * *size as usize), empty_buffer] + } + DataType::Dictionary(k, _) => [ + MutableBuffer::new(capacity * k.primitive_width().unwrap()), + empty_buffer, + ], + DataType::FixedSizeList(_, _) | DataType::Struct(_) | DataType::RunEndEncoded(_, _) => { + [empty_buffer, MutableBuffer::new(0)] + } + DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => [ + MutableBuffer::new(capacity * mem::size_of::()), + empty_buffer, + ], + DataType::Union(_, mode) => { + let type_ids = MutableBuffer::new(capacity * mem::size_of::()); + match mode { + UnionMode::Sparse => [type_ids, empty_buffer], + UnionMode::Dense => { + let offsets = MutableBuffer::new(capacity * mem::size_of::()); + [type_ids, offsets] + } + } + } + } +} + +/// Maps 2 [`MutableBuffer`]s into a vector of [Buffer]s whose size depends on `data_type`. +#[inline] +pub(crate) fn into_buffers( + data_type: &DataType, + buffer1: MutableBuffer, + buffer2: MutableBuffer, +) -> Vec { + match data_type { + DataType::Null | DataType::Struct(_) | DataType::FixedSizeList(_, _) => vec![], + DataType::Utf8 | DataType::Binary | DataType::LargeUtf8 | DataType::LargeBinary => { + vec![buffer1.into(), buffer2.into()] + } + DataType::Union(_, mode) => { + match mode { + // Based on Union's DataTypeLayout + UnionMode::Sparse => vec![buffer1.into()], + UnionMode::Dense => vec![buffer1.into(), buffer2.into()], + } + } + _ => vec![buffer1.into()], + } +} + +/// A generic representation of Arrow array data which encapsulates common attributes and +/// operations for Arrow array. Specific operations for different arrays types (e.g., +/// primitive, list, struct) are implemented in `Array`. +/// +/// # Memory Layout +/// +/// `ArrayData` has references to one or more underlying data buffers +/// and optional child ArrayData, depending on type as illustrated +/// below. Bitmaps are not shown for simplicity but they are stored +/// similarly to the buffers. +/// +/// ```text +/// offset +/// points to +/// ┌───────────────────┐ start of ┌───────┐ Different +/// │ │ data │ │ ArrayData may +/// │ArrayData { │ │.... │ also refers to +/// │ data_type: ... │ ─ ─ ─ ─▶│1234 │ ┌ ─ the same +/// │ offset: ... ─ ─ ─│─ ┘ │4372 │ underlying +/// │ len: ... ─ ─ ─│─ ┐ │4888 │ │ buffer with different offset/len +/// │ buffers: [ │ │5882 │◀─ +/// │ ... │ │ │4323 │ +/// │ ] │ ─ ─ ─ ─▶│4859 │ +/// │ child_data: [ │ │.... │ +/// │ ... │ │ │ +/// │ ] │ └───────┘ +/// │} │ +/// │ │ Shared Buffer uses +/// │ │ │ bytes::Bytes to hold +/// └───────────────────┘ actual data values +/// ┌ ─ ─ ┘ +/// +/// ▼ +/// ┌───────────────────┐ +/// │ArrayData { │ +/// │ ... │ +/// │} │ +/// │ │ +/// └───────────────────┘ +/// +/// Child ArrayData may also have its own buffers and children +/// ``` + +#[derive(Debug, Clone)] +pub struct ArrayData { + /// The data type for this array data + data_type: DataType, + + /// The number of elements in this array data + len: usize, + + /// The offset into this array data, in number of items + offset: usize, + + /// The buffers for this array data. Note that depending on the array types, this + /// could hold different kinds of buffers (e.g., value buffer, value offset buffer) + /// at different positions. + buffers: Vec, + + /// The child(ren) of this array. Only non-empty for nested types, currently + /// `ListArray` and `StructArray`. + child_data: Vec, + + /// The null bitmap. A `None` value for this indicates all values are non-null in + /// this array. + nulls: Option, +} + +pub type ArrayDataRef = Arc; + +impl ArrayData { + /// Create a new ArrayData instance; + /// + /// If `null_count` is not specified, the number of nulls in + /// null_bit_buffer is calculated. + /// + /// If the number of nulls is 0 then the null_bit_buffer + /// is set to `None`. + /// + /// # Safety + /// + /// The input values *must* form a valid Arrow array for + /// `data_type`, or undefined behavior can result. + /// + /// Note: This is a low level API and most users of the arrow + /// crate should create arrays using the methods in the `array` + /// module. + pub unsafe fn new_unchecked( + data_type: DataType, + len: usize, + null_count: Option, + null_bit_buffer: Option, + offset: usize, + buffers: Vec, + child_data: Vec, + ) -> Self { + ArrayDataBuilder { + data_type, + len, + null_count, + null_bit_buffer, + nulls: None, + offset, + buffers, + child_data, + } + .build_unchecked() + } + + /// Create a new ArrayData, validating that the provided buffers form a valid + /// Arrow array of the specified data type. + /// + /// If the number of nulls in `null_bit_buffer` is 0 then the null_bit_buffer + /// is set to `None`. + /// + /// Internally this calls through to [`Self::validate_data`] + /// + /// Note: This is a low level API and most users of the arrow crate should create + /// arrays using the builders found in [arrow_array](https://docs.rs/arrow-array) + pub fn try_new( + data_type: DataType, + len: usize, + null_bit_buffer: Option, + offset: usize, + buffers: Vec, + child_data: Vec, + ) -> Result { + // we must check the length of `null_bit_buffer` first + // because we use this buffer to calculate `null_count` + // in `Self::new_unchecked`. + if let Some(null_bit_buffer) = null_bit_buffer.as_ref() { + let needed_len = bit_util::ceil(len + offset, 8); + if null_bit_buffer.len() < needed_len { + return Err(ArrowError::InvalidArgumentError(format!( + "null_bit_buffer size too small. got {} needed {}", + null_bit_buffer.len(), + needed_len + ))); + } + } + // Safety justification: `validate_full` is called below + let new_self = unsafe { + Self::new_unchecked( + data_type, + len, + None, + null_bit_buffer, + offset, + buffers, + child_data, + ) + }; + + // As the data is not trusted, do a full validation of its contents + // We don't need to validate children as we can assume that the + // [`ArrayData`] in `child_data` have already been validated through + // a call to `ArrayData::try_new` or created using unsafe + new_self.validate_data()?; + Ok(new_self) + } + + /// Returns a builder to construct a [`ArrayData`] instance of the same [`DataType`] + #[inline] + pub const fn builder(data_type: DataType) -> ArrayDataBuilder { + ArrayDataBuilder::new(data_type) + } + + /// Returns a reference to the [`DataType`] of this [`ArrayData`] + #[inline] + pub const fn data_type(&self) -> &DataType { + &self.data_type + } + + /// Returns the [`Buffer`] storing data for this [`ArrayData`] + pub fn buffers(&self) -> &[Buffer] { + &self.buffers + } + + /// Returns a slice of children [`ArrayData`]. This will be non + /// empty for type such as lists and structs. + pub fn child_data(&self) -> &[ArrayData] { + &self.child_data[..] + } + + /// Returns whether the element at index `i` is null + #[inline] + pub fn is_null(&self, i: usize) -> bool { + match &self.nulls { + Some(v) => v.is_null(i), + None => false, + } + } + + /// Returns a reference to the null buffer of this [`ArrayData`] if any + /// + /// Note: [`ArrayData::offset`] does NOT apply to the returned [`NullBuffer`] + #[inline] + pub fn nulls(&self) -> Option<&NullBuffer> { + self.nulls.as_ref() + } + + /// Returns whether the element at index `i` is not null + #[inline] + pub fn is_valid(&self, i: usize) -> bool { + !self.is_null(i) + } + + /// Returns the length (i.e., number of elements) of this [`ArrayData`]. + #[inline] + pub const fn len(&self) -> usize { + self.len + } + + /// Returns whether this [`ArrayData`] is empty + #[inline] + pub const fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Returns the offset of this [`ArrayData`] + #[inline] + pub const fn offset(&self) -> usize { + self.offset + } + + /// Returns the total number of nulls in this array + #[inline] + pub fn null_count(&self) -> usize { + self.nulls + .as_ref() + .map(|x| x.null_count()) + .unwrap_or_default() + } + + /// Returns the total number of bytes of memory occupied by the + /// buffers owned by this [`ArrayData`] and all of its + /// children. (See also diagram on [`ArrayData`]). + /// + /// Note that this [`ArrayData`] may only refer to a subset of the + /// data in the underlying [`Buffer`]s (due to `offset` and + /// `length`), but the size returned includes the entire size of + /// the buffers. + /// + /// If multiple [`ArrayData`]s refer to the same underlying + /// [`Buffer`]s they will both report the same size. + pub fn get_buffer_memory_size(&self) -> usize { + let mut size = 0; + for buffer in &self.buffers { + size += buffer.capacity(); + } + if let Some(bitmap) = &self.nulls { + size += bitmap.buffer().capacity() + } + for child in &self.child_data { + size += child.get_buffer_memory_size(); + } + size + } + + /// Returns the total number of the bytes of memory occupied by + /// the buffers by this slice of [`ArrayData`] (See also diagram on [`ArrayData`]). + /// + /// This is approximately the number of bytes if a new + /// [`ArrayData`] was formed by creating new [`Buffer`]s with + /// exactly the data needed. + /// + /// For example, a [`DataType::Int64`] with `100` elements, + /// [`Self::get_slice_memory_size`] would return `100 * 8 = 800`. If + /// the [`ArrayData`] was then [`Self::slice`]ed to refer to its + /// first `20` elements, then [`Self::get_slice_memory_size`] on the + /// sliced [`ArrayData`] would return `20 * 8 = 160`. + pub fn get_slice_memory_size(&self) -> Result { + let mut result: usize = 0; + let layout = layout(&self.data_type); + + for spec in layout.buffers.iter() { + match spec { + BufferSpec::FixedWidth { byte_width, .. } => { + let buffer_size = self.len.checked_mul(*byte_width).ok_or_else(|| { + ArrowError::ComputeError( + "Integer overflow computing buffer size".to_string(), + ) + })?; + result += buffer_size; + } + BufferSpec::VariableWidth => { + let buffer_len: usize; + match self.data_type { + DataType::Utf8 | DataType::Binary => { + let offsets = self.typed_offsets::()?; + buffer_len = (offsets[self.len] - offsets[0] ) as usize; + } + DataType::LargeUtf8 | DataType::LargeBinary => { + let offsets = self.typed_offsets::()?; + buffer_len = (offsets[self.len] - offsets[0]) as usize; + } + _ => { + return Err(ArrowError::NotYetImplemented(format!( + "Invalid data type for VariableWidth buffer. Expected Utf8, LargeUtf8, Binary or LargeBinary. Got {}", + self.data_type + ))) + } + }; + result += buffer_len; + } + BufferSpec::BitMap => { + let buffer_size = bit_util::ceil(self.len, 8); + result += buffer_size; + } + BufferSpec::AlwaysNull => { + // Nothing to do + } + } + } + + if self.nulls().is_some() { + result += bit_util::ceil(self.len, 8); + } + + for child in &self.child_data { + result += child.get_slice_memory_size()?; + } + Ok(result) + } + + /// Returns the total number of bytes of memory occupied + /// physically by this [`ArrayData`] and all its [`Buffer`]s and + /// children. (See also diagram on [`ArrayData`]). + /// + /// Equivalent to: + /// `size_of_val(self)` + + /// [`Self::get_buffer_memory_size`] + + /// `size_of_val(child)` for all children + pub fn get_array_memory_size(&self) -> usize { + let mut size = mem::size_of_val(self); + + // Calculate rest of the fields top down which contain actual data + for buffer in &self.buffers { + size += mem::size_of::(); + size += buffer.capacity(); + } + if let Some(nulls) = &self.nulls { + size += nulls.buffer().capacity(); + } + for child in &self.child_data { + size += child.get_array_memory_size(); + } + + size + } + + /// Creates a zero-copy slice of itself. This creates a new + /// [`ArrayData`] pointing at the same underlying [`Buffer`]s with a + /// different offset and len + /// + /// # Panics + /// + /// Panics if `offset + length > self.len()`. + pub fn slice(&self, offset: usize, length: usize) -> ArrayData { + assert!((offset + length) <= self.len()); + + if let DataType::Struct(_) = self.data_type() { + // Slice into children + let new_offset = self.offset + offset; + let new_data = ArrayData { + data_type: self.data_type().clone(), + len: length, + offset: new_offset, + buffers: self.buffers.clone(), + // Slice child data, to propagate offsets down to them + child_data: self + .child_data() + .iter() + .map(|data| data.slice(offset, length)) + .collect(), + nulls: self.nulls.as_ref().map(|x| x.slice(offset, length)), + }; + + new_data + } else { + let mut new_data = self.clone(); + + new_data.len = length; + new_data.offset = offset + self.offset; + new_data.nulls = self.nulls.as_ref().map(|x| x.slice(offset, length)); + + new_data + } + } + + /// Returns the `buffer` as a slice of type `T` starting at self.offset + /// # Panics + /// This function panics if: + /// * the buffer is not byte-aligned with type T, or + /// * the datatype is `Boolean` (it corresponds to a bit-packed buffer where the offset is not applicable) + pub fn buffer(&self, buffer: usize) -> &[T] { + &self.buffers()[buffer].typed_data()[self.offset..] + } + + /// Returns a new [`ArrayData`] valid for `data_type` containing `len` null values + pub fn new_null(data_type: &DataType, len: usize) -> Self { + let bit_len = bit_util::ceil(len, 8); + let zeroed = |len: usize| Buffer::from(MutableBuffer::from_len_zeroed(len)); + + let (buffers, child_data, has_nulls) = match data_type.primitive_width() { + Some(width) => (vec![zeroed(width * len)], vec![], true), + None => match data_type { + DataType::Null => (vec![], vec![], false), + DataType::Boolean => (vec![zeroed(bit_len)], vec![], true), + DataType::Binary | DataType::Utf8 => { + (vec![zeroed((len + 1) * 4), zeroed(0)], vec![], true) + } + DataType::LargeBinary | DataType::LargeUtf8 => { + (vec![zeroed((len + 1) * 8), zeroed(0)], vec![], true) + } + DataType::FixedSizeBinary(i) => (vec![zeroed(*i as usize * len)], vec![], true), + DataType::List(f) | DataType::Map(f, _) => ( + vec![zeroed((len + 1) * 4)], + vec![ArrayData::new_empty(f.data_type())], + true, + ), + DataType::LargeList(f) => ( + vec![zeroed((len + 1) * 8)], + vec![ArrayData::new_empty(f.data_type())], + true, + ), + DataType::FixedSizeList(f, list_len) => ( + vec![], + vec![ArrayData::new_null(f.data_type(), *list_len as usize * len)], + true, + ), + DataType::Struct(fields) => ( + vec![], + fields + .iter() + .map(|f| Self::new_null(f.data_type(), len)) + .collect(), + true, + ), + DataType::Dictionary(k, v) => ( + vec![zeroed(k.primitive_width().unwrap() * len)], + vec![ArrayData::new_empty(v.as_ref())], + true, + ), + DataType::Union(f, mode) => { + let (id, _) = f.iter().next().unwrap(); + let ids = Buffer::from_iter(std::iter::repeat(id).take(len)); + let buffers = match mode { + UnionMode::Sparse => vec![ids], + UnionMode::Dense => { + let end_offset = i32::from_usize(len).unwrap(); + vec![ids, Buffer::from_iter(0_i32..end_offset)] + } + }; + + let children = f + .iter() + .enumerate() + .map(|(idx, (_, f))| { + if idx == 0 || *mode == UnionMode::Sparse { + Self::new_null(f.data_type(), len) + } else { + Self::new_empty(f.data_type()) + } + }) + .collect(); + + (buffers, children, false) + } + DataType::RunEndEncoded(r, v) => { + let runs = match r.data_type() { + DataType::Int16 => { + let i = i16::from_usize(len).expect("run overflow"); + Buffer::from_slice_ref([i]) + } + DataType::Int32 => { + let i = i32::from_usize(len).expect("run overflow"); + Buffer::from_slice_ref([i]) + } + DataType::Int64 => { + let i = i64::from_usize(len).expect("run overflow"); + Buffer::from_slice_ref([i]) + } + dt => unreachable!("Invalid run ends data type {dt}"), + }; + + let builder = ArrayData::builder(r.data_type().clone()) + .len(1) + .buffers(vec![runs]); + + // SAFETY: + // Valid by construction + let runs = unsafe { builder.build_unchecked() }; + ( + vec![], + vec![runs, ArrayData::new_null(v.data_type(), 1)], + false, + ) + } + d => unreachable!("{d}"), + }, + }; + + let mut builder = ArrayDataBuilder::new(data_type.clone()) + .len(len) + .buffers(buffers) + .child_data(child_data); + + if has_nulls { + builder = builder.nulls(Some(NullBuffer::new_null(len))) + } + + // SAFETY: + // Data valid by construction + unsafe { builder.build_unchecked() } + } + + /// Returns a new empty [ArrayData] valid for `data_type`. + pub fn new_empty(data_type: &DataType) -> Self { + Self::new_null(data_type, 0) + } + + /// Verifies that the buffers meet the minimum alignment requirements for the data type + /// + /// Buffers that are not adequately aligned will be copied to a new aligned allocation + /// + /// This can be useful for when interacting with data sent over IPC or FFI, that may + /// not meet the minimum alignment requirements + pub fn align_buffers(&mut self) { + let layout = layout(&self.data_type); + for (buffer, spec) in self.buffers.iter_mut().zip(&layout.buffers) { + if let BufferSpec::FixedWidth { alignment, .. } = spec { + if buffer.as_ptr().align_offset(*alignment) != 0 { + *buffer = Buffer::from_slice_ref(buffer.as_ref()) + } + } + } + } + + /// "cheap" validation of an `ArrayData`. Ensures buffers are + /// sufficiently sized to store `len` + `offset` total elements of + /// `data_type` and performs other inexpensive consistency checks. + /// + /// This check is "cheap" in the sense that it does not validate the + /// contents of the buffers (e.g. that all offsets for UTF8 arrays + /// are within the bounds of the values buffer). + /// + /// See [ArrayData::validate_data] to validate fully the offset content + /// and the validity of utf8 data + pub fn validate(&self) -> Result<(), ArrowError> { + // Need at least this mich space in each buffer + let len_plus_offset = self.len + self.offset; + + // Check that the data layout conforms to the spec + let layout = layout(&self.data_type); + + if !layout.can_contain_null_mask && self.nulls.is_some() { + return Err(ArrowError::InvalidArgumentError(format!( + "Arrays of type {:?} cannot contain a null bitmask", + self.data_type, + ))); + } + + if self.buffers.len() != layout.buffers.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Expected {} buffers in array of type {:?}, got {}", + layout.buffers.len(), + self.data_type, + self.buffers.len(), + ))); + } + + for (i, (buffer, spec)) in self.buffers.iter().zip(layout.buffers.iter()).enumerate() { + match spec { + BufferSpec::FixedWidth { + byte_width, + alignment, + } => { + let min_buffer_size = len_plus_offset.saturating_mul(*byte_width); + + if buffer.len() < min_buffer_size { + return Err(ArrowError::InvalidArgumentError(format!( + "Need at least {} bytes in buffers[{}] in array of type {:?}, but got {}", + min_buffer_size, i, self.data_type, buffer.len() + ))); + } + + let align_offset = buffer.as_ptr().align_offset(*alignment); + if align_offset != 0 { + return Err(ArrowError::InvalidArgumentError(format!( + "Misaligned buffers[{i}] in array of type {:?}, offset from expected alignment of {alignment} by {}", + self.data_type, align_offset.min(alignment - align_offset) + ))); + } + } + BufferSpec::VariableWidth => { + // not cheap to validate (need to look at the + // data). Partially checked in validate_offsets + // called below. Can check with `validate_full` + } + BufferSpec::BitMap => { + let min_buffer_size = bit_util::ceil(len_plus_offset, 8); + if buffer.len() < min_buffer_size { + return Err(ArrowError::InvalidArgumentError(format!( + "Need at least {} bytes for bitmap in buffers[{}] in array of type {:?}, but got {}", + min_buffer_size, i, self.data_type, buffer.len() + ))); + } + } + BufferSpec::AlwaysNull => { + // Nothing to validate + } + } + } + + // check null bit buffer size + if let Some(nulls) = self.nulls() { + if nulls.null_count() > self.len { + return Err(ArrowError::InvalidArgumentError(format!( + "null_count {} for an array exceeds length of {} elements", + nulls.null_count(), + self.len + ))); + } + + let actual_len = nulls.validity().len(); + let needed_len = bit_util::ceil(len_plus_offset, 8); + if actual_len < needed_len { + return Err(ArrowError::InvalidArgumentError(format!( + "null_bit_buffer size too small. got {actual_len} needed {needed_len}", + ))); + } + + if nulls.len() != self.len { + return Err(ArrowError::InvalidArgumentError(format!( + "null buffer incorrect size. got {} expected {}", + nulls.len(), + self.len + ))); + } + } + + self.validate_child_data()?; + + // Additional Type specific checks + match &self.data_type { + DataType::Utf8 | DataType::Binary => { + self.validate_offsets::(self.buffers[1].len())?; + } + DataType::LargeUtf8 | DataType::LargeBinary => { + self.validate_offsets::(self.buffers[1].len())?; + } + DataType::Dictionary(key_type, _value_type) => { + // At the moment, constructing a DictionaryArray will also check this + if !DataType::is_dictionary_key_type(key_type) { + return Err(ArrowError::InvalidArgumentError(format!( + "Dictionary key type must be integer, but was {key_type}" + ))); + } + } + DataType::RunEndEncoded(run_ends_type, _) => { + if run_ends_type.is_nullable() { + return Err(ArrowError::InvalidArgumentError( + "The nullable should be set to false for the field defining run_ends array.".to_string() + )); + } + if !DataType::is_run_ends_type(run_ends_type.data_type()) { + return Err(ArrowError::InvalidArgumentError(format!( + "RunArray run_ends types must be Int16, Int32 or Int64, but was {}", + run_ends_type.data_type() + ))); + } + } + _ => {} + }; + + Ok(()) + } + + /// Returns a reference to the data in `buffer` as a typed slice + /// (typically `&[i32]` or `&[i64]`) after validating. The + /// returned slice is guaranteed to have at least `self.len + 1` + /// entries. + /// + /// For an empty array, the `buffer` can also be empty. + fn typed_offsets(&self) -> Result<&[T], ArrowError> { + // An empty list-like array can have 0 offsets + if self.len == 0 && self.buffers[0].is_empty() { + return Ok(&[]); + } + + self.typed_buffer(0, self.len + 1) + } + + /// Returns a reference to the data in `buffers[idx]` as a typed slice after validating + fn typed_buffer( + &self, + idx: usize, + len: usize, + ) -> Result<&[T], ArrowError> { + let buffer = &self.buffers[idx]; + + let required_len = (len + self.offset) * mem::size_of::(); + + if buffer.len() < required_len { + return Err(ArrowError::InvalidArgumentError(format!( + "Buffer {} of {} isn't large enough. Expected {} bytes got {}", + idx, + self.data_type, + required_len, + buffer.len() + ))); + } + + Ok(&buffer.typed_data::()[self.offset..self.offset + len]) + } + + /// Does a cheap sanity check that the `self.len` values in `buffer` are valid + /// offsets (of type T) into some other buffer of `values_length` bytes long + fn validate_offsets( + &self, + values_length: usize, + ) -> Result<(), ArrowError> { + // Justification: buffer size was validated above + let offsets = self.typed_offsets::()?; + if offsets.is_empty() { + return Ok(()); + } + + let first_offset = offsets[0].to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Error converting offset[0] ({}) to usize for {}", + offsets[0], self.data_type + )) + })?; + + let last_offset = offsets[self.len].to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Error converting offset[{}] ({}) to usize for {}", + self.len, offsets[self.len], self.data_type + )) + })?; + + if first_offset > values_length { + return Err(ArrowError::InvalidArgumentError(format!( + "First offset {} of {} is larger than values length {}", + first_offset, self.data_type, values_length, + ))); + } + + if last_offset > values_length { + return Err(ArrowError::InvalidArgumentError(format!( + "Last offset {} of {} is larger than values length {}", + last_offset, self.data_type, values_length, + ))); + } + + if first_offset > last_offset { + return Err(ArrowError::InvalidArgumentError(format!( + "First offset {} in {} is smaller than last offset {}", + first_offset, self.data_type, last_offset, + ))); + } + + Ok(()) + } + + /// Validates the layout of `child_data` ArrayData structures + fn validate_child_data(&self) -> Result<(), ArrowError> { + match &self.data_type { + DataType::List(field) | DataType::Map(field, _) => { + let values_data = self.get_single_valid_child_data(field.data_type())?; + self.validate_offsets::(values_data.len)?; + Ok(()) + } + DataType::LargeList(field) => { + let values_data = self.get_single_valid_child_data(field.data_type())?; + self.validate_offsets::(values_data.len)?; + Ok(()) + } + DataType::FixedSizeList(field, list_size) => { + let values_data = self.get_single_valid_child_data(field.data_type())?; + + let list_size: usize = (*list_size).try_into().map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "{} has a negative list_size {}", + self.data_type, list_size + )) + })?; + + let expected_values_len = self.len + .checked_mul(list_size) + .expect("integer overflow computing expected number of expected values in FixedListSize"); + + if values_data.len < expected_values_len { + return Err(ArrowError::InvalidArgumentError(format!( + "Values length {} is less than the length ({}) multiplied by the value size ({}) for {}", + values_data.len, list_size, list_size, self.data_type + ))); + } + + Ok(()) + } + DataType::Struct(fields) => { + self.validate_num_child_data(fields.len())?; + for (i, field) in fields.iter().enumerate() { + let field_data = self.get_valid_child_data(i, field.data_type())?; + + // Ensure child field has sufficient size + if field_data.len < self.len { + return Err(ArrowError::InvalidArgumentError(format!( + "{} child array #{} for field {} has length smaller than expected for struct array ({} < {})", + self.data_type, i, field.name(), field_data.len, self.len + ))); + } + } + Ok(()) + } + DataType::RunEndEncoded(run_ends_field, values_field) => { + self.validate_num_child_data(2)?; + let run_ends_data = self.get_valid_child_data(0, run_ends_field.data_type())?; + let values_data = self.get_valid_child_data(1, values_field.data_type())?; + if run_ends_data.len != values_data.len { + return Err(ArrowError::InvalidArgumentError(format!( + "The run_ends array length should be the same as values array length. Run_ends array length is {}, values array length is {}", + run_ends_data.len, values_data.len + ))); + } + if run_ends_data.nulls.is_some() { + return Err(ArrowError::InvalidArgumentError( + "Found null values in run_ends array. The run_ends array should not have null values.".to_string(), + )); + } + Ok(()) + } + DataType::Union(fields, mode) => { + self.validate_num_child_data(fields.len())?; + + for (i, (_, field)) in fields.iter().enumerate() { + let field_data = self.get_valid_child_data(i, field.data_type())?; + + if mode == &UnionMode::Sparse && field_data.len < (self.len + self.offset) { + return Err(ArrowError::InvalidArgumentError(format!( + "Sparse union child array #{} has length smaller than expected for union array ({} < {})", + i, field_data.len, self.len + self.offset + ))); + } + } + Ok(()) + } + DataType::Dictionary(_key_type, value_type) => { + self.get_single_valid_child_data(value_type)?; + Ok(()) + } + _ => { + // other types do not have child data + if !self.child_data.is_empty() { + return Err(ArrowError::InvalidArgumentError(format!( + "Expected no child arrays for type {} but got {}", + self.data_type, + self.child_data.len() + ))); + } + Ok(()) + } + } + } + + /// Ensures that this array data has a single child_data with the + /// expected type, and calls `validate()` on it. Returns a + /// reference to that child_data + fn get_single_valid_child_data( + &self, + expected_type: &DataType, + ) -> Result<&ArrayData, ArrowError> { + self.validate_num_child_data(1)?; + self.get_valid_child_data(0, expected_type) + } + + /// Returns `Err` if self.child_data does not have exactly `expected_len` elements + fn validate_num_child_data(&self, expected_len: usize) -> Result<(), ArrowError> { + if self.child_data.len() != expected_len { + Err(ArrowError::InvalidArgumentError(format!( + "Value data for {} should contain {} child data array(s), had {}", + self.data_type, + expected_len, + self.child_data.len() + ))) + } else { + Ok(()) + } + } + + /// Ensures that `child_data[i]` has the expected type, calls + /// `validate()` on it, and returns a reference to that child_data + fn get_valid_child_data( + &self, + i: usize, + expected_type: &DataType, + ) -> Result<&ArrayData, ArrowError> { + let values_data = self.child_data.get(i).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "{} did not have enough child arrays. Expected at least {} but had only {}", + self.data_type, + i + 1, + self.child_data.len() + )) + })?; + + if expected_type != &values_data.data_type { + return Err(ArrowError::InvalidArgumentError(format!( + "Child type mismatch for {}. Expected {} but child data had {}", + self.data_type, expected_type, values_data.data_type + ))); + } + + values_data.validate()?; + Ok(values_data) + } + + /// Validate that the data contained within this [`ArrayData`] is valid + /// + /// 1. Null count is correct + /// 2. All offsets are valid + /// 3. All String data is valid UTF-8 + /// 4. All dictionary offsets are valid + /// + /// Internally this calls: + /// + /// * [`Self::validate`] + /// * [`Self::validate_nulls`] + /// * [`Self::validate_values`] + /// + /// Note: this does not recurse into children, for a recursive variant + /// see [`Self::validate_full`] + pub fn validate_data(&self) -> Result<(), ArrowError> { + self.validate()?; + + self.validate_nulls()?; + self.validate_values()?; + Ok(()) + } + + /// Performs a full recursive validation of this [`ArrayData`] and all its children + /// + /// This is equivalent to calling [`Self::validate_data`] on this [`ArrayData`] + /// and all its children recursively + pub fn validate_full(&self) -> Result<(), ArrowError> { + self.validate_data()?; + // validate all children recursively + self.child_data + .iter() + .enumerate() + .try_for_each(|(i, child_data)| { + child_data.validate_full().map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "{} child #{} invalid: {}", + self.data_type, i, e + )) + }) + })?; + Ok(()) + } + + /// Validates the values stored within this [`ArrayData`] are valid + /// without recursing into child [`ArrayData`] + /// + /// Does not (yet) check + /// 1. Union type_ids are valid see [#85](https://github.com/apache/arrow-rs/issues/85) + /// Validates the the null count is correct and that any + /// nullability requirements of its children are correct + pub fn validate_nulls(&self) -> Result<(), ArrowError> { + if let Some(nulls) = &self.nulls { + let actual = nulls.len() - nulls.inner().count_set_bits(); + if actual != nulls.null_count() { + return Err(ArrowError::InvalidArgumentError(format!( + "null_count value ({}) doesn't match actual number of nulls in array ({})", + nulls.null_count(), + actual + ))); + } + } + + // In general non-nullable children should not contain nulls, however, for certain + // types, such as StructArray and FixedSizeList, nulls in the parent take up + // space in the child. As such we permit nulls in the children in the corresponding + // positions for such types + match &self.data_type { + DataType::List(f) | DataType::LargeList(f) | DataType::Map(f, _) => { + if !f.is_nullable() { + self.validate_non_nullable(None, &self.child_data[0])? + } + } + DataType::FixedSizeList(field, len) => { + let child = &self.child_data[0]; + if !field.is_nullable() { + match &self.nulls { + Some(nulls) => { + let element_len = *len as usize; + let expanded = nulls.expand(element_len); + self.validate_non_nullable(Some(&expanded), child)?; + } + None => self.validate_non_nullable(None, child)?, + } + } + } + DataType::Struct(fields) => { + for (field, child) in fields.iter().zip(&self.child_data) { + if !field.is_nullable() { + self.validate_non_nullable(self.nulls(), child)? + } + } + } + _ => {} + } + + Ok(()) + } + + /// Verifies that `child` contains no nulls not present in `mask` + fn validate_non_nullable( + &self, + mask: Option<&NullBuffer>, + child: &ArrayData, + ) -> Result<(), ArrowError> { + let mask = match mask { + Some(mask) => mask, + None => { + return match child.null_count() { + 0 => Ok(()), + _ => Err(ArrowError::InvalidArgumentError(format!( + "non-nullable child of type {} contains nulls not present in parent {}", + child.data_type, self.data_type + ))), + } + } + }; + + match child.nulls() { + Some(nulls) if !mask.contains(nulls) => Err(ArrowError::InvalidArgumentError(format!( + "non-nullable child of type {} contains nulls not present in parent", + child.data_type + ))), + _ => Ok(()), + } + } + + /// Validates the values stored within this [`ArrayData`] are valid + /// without recursing into child [`ArrayData`] + /// + /// Does not (yet) check + /// 1. Union type_ids are valid see [#85](https://github.com/apache/arrow-rs/issues/85) + pub fn validate_values(&self) -> Result<(), ArrowError> { + match &self.data_type { + DataType::Utf8 => self.validate_utf8::(), + DataType::LargeUtf8 => self.validate_utf8::(), + DataType::Binary => self.validate_offsets_full::(self.buffers[1].len()), + DataType::LargeBinary => self.validate_offsets_full::(self.buffers[1].len()), + DataType::List(_) | DataType::Map(_, _) => { + let child = &self.child_data[0]; + self.validate_offsets_full::(child.len) + } + DataType::LargeList(_) => { + let child = &self.child_data[0]; + self.validate_offsets_full::(child.len) + } + DataType::Union(_, _) => { + // Validate Union Array as part of implementing new Union semantics + // See comments in `ArrayData::validate()` + // https://github.com/apache/arrow-rs/issues/85 + // + // TODO file follow on ticket for full union validation + Ok(()) + } + DataType::Dictionary(key_type, _value_type) => { + let dictionary_length: i64 = self.child_data[0].len.try_into().unwrap(); + let max_value = dictionary_length - 1; + match key_type.as_ref() { + DataType::UInt8 => self.check_bounds::(max_value), + DataType::UInt16 => self.check_bounds::(max_value), + DataType::UInt32 => self.check_bounds::(max_value), + DataType::UInt64 => self.check_bounds::(max_value), + DataType::Int8 => self.check_bounds::(max_value), + DataType::Int16 => self.check_bounds::(max_value), + DataType::Int32 => self.check_bounds::(max_value), + DataType::Int64 => self.check_bounds::(max_value), + _ => unreachable!(), + } + } + DataType::RunEndEncoded(run_ends, _values) => { + let run_ends_data = self.child_data()[0].clone(); + match run_ends.data_type() { + DataType::Int16 => run_ends_data.check_run_ends::(), + DataType::Int32 => run_ends_data.check_run_ends::(), + DataType::Int64 => run_ends_data.check_run_ends::(), + _ => unreachable!(), + } + } + _ => { + // No extra validation check required for other types + Ok(()) + } + } + } + + /// Calls the `validate(item_index, range)` function for each of + /// the ranges specified in the arrow offsets buffer of type + /// `T`. Also validates that each offset is smaller than + /// `offset_limit` + /// + /// For an empty array, the offsets buffer can either be empty + /// or contain a single `0`. + /// + /// For example, the offsets buffer contained `[1, 2, 4]`, this + /// function would call `validate([1,2])`, and `validate([2,4])` + fn validate_each_offset(&self, offset_limit: usize, validate: V) -> Result<(), ArrowError> + where + T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, + V: Fn(usize, Range) -> Result<(), ArrowError>, + { + self.typed_offsets::()? + .iter() + .enumerate() + .map(|(i, x)| { + // check if the offset can be converted to usize + let r = x.to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Offset invariant failure: Could not convert offset {x} to usize at position {i}"))} + ); + // check if the offset exceeds the limit + match r { + Ok(n) if n <= offset_limit => Ok((i, n)), + Ok(_) => Err(ArrowError::InvalidArgumentError(format!( + "Offset invariant failure: offset at position {i} out of bounds: {x} > {offset_limit}")) + ), + Err(e) => Err(e), + } + }) + .scan(0_usize, |start, end| { + // check offsets are monotonically increasing + match end { + Ok((i, end)) if *start <= end => { + let range = Some(Ok((i, *start..end))); + *start = end; + range + } + Ok((i, end)) => Some(Err(ArrowError::InvalidArgumentError(format!( + "Offset invariant failure: non-monotonic offset at slot {}: {} > {}", + i - 1, start, end)) + )), + Err(err) => Some(Err(err)), + } + }) + .skip(1) // the first element is meaningless + .try_for_each(|res: Result<(usize, Range), ArrowError>| { + let (item_index, range) = res?; + validate(item_index-1, range) + }) + } + + /// Ensures that all strings formed by the offsets in `buffers[0]` + /// into `buffers[1]` are valid utf8 sequences + fn validate_utf8(&self) -> Result<(), ArrowError> + where + T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, + { + let values_buffer = &self.buffers[1].as_slice(); + if let Ok(values_str) = std::str::from_utf8(values_buffer) { + // Validate Offsets are correct + self.validate_each_offset::(values_buffer.len(), |string_index, range| { + if !values_str.is_char_boundary(range.start) + || !values_str.is_char_boundary(range.end) + { + return Err(ArrowError::InvalidArgumentError(format!( + "incomplete utf-8 byte sequence from index {string_index}" + ))); + } + Ok(()) + }) + } else { + // find specific offset that failed utf8 validation + self.validate_each_offset::(values_buffer.len(), |string_index, range| { + std::str::from_utf8(&values_buffer[range.clone()]).map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "Invalid UTF8 sequence at string index {string_index} ({range:?}): {e}" + )) + })?; + Ok(()) + }) + } + } + + /// Ensures that all offsets in `buffers[0]` into `buffers[1]` are + /// between `0` and `offset_limit` + fn validate_offsets_full(&self, offset_limit: usize) -> Result<(), ArrowError> + where + T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, + { + self.validate_each_offset::(offset_limit, |_string_index, _range| { + // No validation applied to each value, but the iteration + // itself applies bounds checking to each range + Ok(()) + }) + } + + /// Validates that each value in self.buffers (typed as T) + /// is within the range [0, max_value], inclusive + fn check_bounds(&self, max_value: i64) -> Result<(), ArrowError> + where + T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, + { + let required_len = self.len + self.offset; + let buffer = &self.buffers[0]; + + // This should have been checked as part of `validate()` prior + // to calling `validate_full()` but double check to be sure + assert!(buffer.len() / mem::size_of::() >= required_len); + + // Justification: buffer size was validated above + let indexes: &[T] = &buffer.typed_data::()[self.offset..self.offset + self.len]; + + indexes.iter().enumerate().try_for_each(|(i, &dict_index)| { + // Do not check the value is null (value can be arbitrary) + if self.is_null(i) { + return Ok(()); + } + let dict_index: i64 = dict_index.try_into().map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "Value at position {i} out of bounds: {dict_index} (can not convert to i64)" + )) + })?; + + if dict_index < 0 || dict_index > max_value { + return Err(ArrowError::InvalidArgumentError(format!( + "Value at position {i} out of bounds: {dict_index} (should be in [0, {max_value}])" + ))); + } + Ok(()) + }) + } + + /// Validates that each value in run_ends array is positive and strictly increasing. + fn check_run_ends(&self) -> Result<(), ArrowError> + where + T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, + { + let values = self.typed_buffer::(0, self.len)?; + let mut prev_value: i64 = 0_i64; + values.iter().enumerate().try_for_each(|(ix, &inp_value)| { + let value: i64 = inp_value.try_into().map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "Value at position {ix} out of bounds: {inp_value} (can not convert to i64)" + )) + })?; + if value <= 0_i64 { + return Err(ArrowError::InvalidArgumentError(format!( + "The values in run_ends array should be strictly positive. Found value {value} at index {ix} that does not match the criteria." + ))); + } + if ix > 0 && value <= prev_value { + return Err(ArrowError::InvalidArgumentError(format!( + "The values in run_ends array should be strictly increasing. Found value {value} at index {ix} with previous value {prev_value} that does not match the criteria." + ))); + } + + prev_value = value; + Ok(()) + })?; + + if prev_value.as_usize() < (self.offset + self.len) { + return Err(ArrowError::InvalidArgumentError(format!( + "The offset + length of array should be less or equal to last value in the run_ends array. The last value of run_ends array is {prev_value} and offset + length of array is {}.", + self.offset + self.len + ))); + } + Ok(()) + } + + /// Returns true if this `ArrayData` is equal to `other`, using pointer comparisons + /// to determine buffer equality. This is cheaper than `PartialEq::eq` but may + /// return false when the arrays are logically equal + pub fn ptr_eq(&self, other: &Self) -> bool { + if self.offset != other.offset + || self.len != other.len + || self.data_type != other.data_type + || self.buffers.len() != other.buffers.len() + || self.child_data.len() != other.child_data.len() + { + return false; + } + + match (&self.nulls, &other.nulls) { + (Some(a), Some(b)) if !a.inner().ptr_eq(b.inner()) => return false, + (Some(_), None) | (None, Some(_)) => return false, + _ => {} + }; + + if !self + .buffers + .iter() + .zip(other.buffers.iter()) + .all(|(a, b)| a.as_ptr() == b.as_ptr()) + { + return false; + } + + self.child_data + .iter() + .zip(other.child_data.iter()) + .all(|(a, b)| a.ptr_eq(b)) + } + + /// Converts this [`ArrayData`] into an [`ArrayDataBuilder`] + pub fn into_builder(self) -> ArrayDataBuilder { + self.into() + } +} + +/// Return the expected [`DataTypeLayout`] Arrays of this data +/// type are expected to have +pub fn layout(data_type: &DataType) -> DataTypeLayout { + // based on C/C++ implementation in + // https://github.com/apache/arrow/blob/661c7d749150905a63dd3b52e0a04dac39030d95/cpp/src/arrow/type.h (and .cc) + use arrow_schema::IntervalUnit::*; + + match data_type { + DataType::Null => DataTypeLayout { + buffers: vec![], + can_contain_null_mask: false, + }, + DataType::Boolean => DataTypeLayout { + buffers: vec![BufferSpec::BitMap], + can_contain_null_mask: true, + }, + DataType::Int8 => DataTypeLayout::new_fixed_width::(), + DataType::Int16 => DataTypeLayout::new_fixed_width::(), + DataType::Int32 => DataTypeLayout::new_fixed_width::(), + DataType::Int64 => DataTypeLayout::new_fixed_width::(), + DataType::UInt8 => DataTypeLayout::new_fixed_width::(), + DataType::UInt16 => DataTypeLayout::new_fixed_width::(), + DataType::UInt32 => DataTypeLayout::new_fixed_width::(), + DataType::UInt64 => DataTypeLayout::new_fixed_width::(), + DataType::Float16 => DataTypeLayout::new_fixed_width::(), + DataType::Float32 => DataTypeLayout::new_fixed_width::(), + DataType::Float64 => DataTypeLayout::new_fixed_width::(), + DataType::Timestamp(_, _) => DataTypeLayout::new_fixed_width::(), + DataType::Date32 => DataTypeLayout::new_fixed_width::(), + DataType::Date64 => DataTypeLayout::new_fixed_width::(), + DataType::Time32(_) => DataTypeLayout::new_fixed_width::(), + DataType::Time64(_) => DataTypeLayout::new_fixed_width::(), + DataType::Interval(YearMonth) => DataTypeLayout::new_fixed_width::(), + DataType::Interval(DayTime) => DataTypeLayout::new_fixed_width::(), + DataType::Interval(MonthDayNano) => DataTypeLayout::new_fixed_width::(), + DataType::Duration(_) => DataTypeLayout::new_fixed_width::(), + DataType::Decimal128(_, _) => DataTypeLayout::new_fixed_width::(), + DataType::Decimal256(_, _) => DataTypeLayout::new_fixed_width::(), + DataType::FixedSizeBinary(size) => { + let spec = BufferSpec::FixedWidth { + byte_width: (*size).try_into().unwrap(), + alignment: mem::align_of::(), + }; + DataTypeLayout { + buffers: vec![spec], + can_contain_null_mask: true, + } + } + DataType::Binary => DataTypeLayout::new_binary::(), + DataType::LargeBinary => DataTypeLayout::new_binary::(), + DataType::Utf8 => DataTypeLayout::new_binary::(), + DataType::LargeUtf8 => DataTypeLayout::new_binary::(), + DataType::FixedSizeList(_, _) => DataTypeLayout::new_empty(), // all in child data + DataType::List(_) => DataTypeLayout::new_fixed_width::(), + DataType::LargeList(_) => DataTypeLayout::new_fixed_width::(), + DataType::Map(_, _) => DataTypeLayout::new_fixed_width::(), + DataType::Struct(_) => DataTypeLayout::new_empty(), // all in child data, + DataType::RunEndEncoded(_, _) => DataTypeLayout::new_empty(), // all in child data, + DataType::Union(_, mode) => { + let type_ids = BufferSpec::FixedWidth { + byte_width: mem::size_of::(), + alignment: mem::align_of::(), + }; + + DataTypeLayout { + buffers: match mode { + UnionMode::Sparse => { + vec![type_ids] + } + UnionMode::Dense => { + vec![ + type_ids, + BufferSpec::FixedWidth { + byte_width: mem::size_of::(), + alignment: mem::align_of::(), + }, + ] + } + }, + can_contain_null_mask: false, + } + } + DataType::Dictionary(key_type, _value_type) => layout(key_type), + } +} + +/// Layout specification for a data type +#[derive(Debug, PartialEq, Eq)] +// Note: Follows structure from C++: https://github.com/apache/arrow/blob/master/cpp/src/arrow/type.h#L91 +pub struct DataTypeLayout { + /// A vector of buffer layout specifications, one for each expected buffer + pub buffers: Vec, + + /// Can contain a null bitmask + pub can_contain_null_mask: bool, +} + +impl DataTypeLayout { + /// Describes a basic numeric array where each element has type `T` + pub fn new_fixed_width() -> Self { + Self { + buffers: vec![BufferSpec::FixedWidth { + byte_width: mem::size_of::(), + alignment: mem::align_of::(), + }], + can_contain_null_mask: true, + } + } + + /// Describes arrays which have no data of their own + /// (e.g. FixedSizeList). Note such arrays may still have a Null + /// Bitmap + pub fn new_empty() -> Self { + Self { + buffers: vec![], + can_contain_null_mask: true, + } + } + + /// Describes a basic numeric array where each element has a fixed + /// with offset buffer of type `T`, followed by a + /// variable width data buffer + pub fn new_binary() -> Self { + Self { + buffers: vec![ + // offsets + BufferSpec::FixedWidth { + byte_width: mem::size_of::(), + alignment: mem::align_of::(), + }, + // values + BufferSpec::VariableWidth, + ], + can_contain_null_mask: true, + } + } +} + +/// Layout specification for a single data type buffer +#[derive(Debug, PartialEq, Eq)] +pub enum BufferSpec { + /// Each element is a fixed width primitive, with the given `byte_width` and `alignment` + /// + /// `alignment` is the alignment required by Rust for an array of the corresponding primitive, + /// see [`Layout::array`](std::alloc::Layout::array) and [`std::mem::align_of`]. + /// + /// Arrow-rs requires that all buffers have at least this alignment, to allow for + /// [slice](std::slice) based APIs. Alignment in excess of this is not required to allow + /// for array slicing and interoperability with `Vec`, which cannot be over-aligned. + /// + /// Note that these alignment requirements will vary between architectures + FixedWidth { byte_width: usize, alignment: usize }, + /// Variable width, such as string data for utf8 data + VariableWidth, + /// Buffer holds a bitmap. + /// + /// Note: Unlike the C++ implementation, the null/validity buffer + /// is handled specially rather than as another of the buffers in + /// the spec, so this variant is only used for the Boolean type. + BitMap, + /// Buffer is always null. Unused currently in Rust implementation, + /// (used in C++ for Union type) + #[allow(dead_code)] + AlwaysNull, +} + +impl PartialEq for ArrayData { + fn eq(&self, other: &Self) -> bool { + equal::equal(self, other) + } +} + +/// Builder for `ArrayData` type +#[derive(Debug)] +pub struct ArrayDataBuilder { + data_type: DataType, + len: usize, + null_count: Option, + null_bit_buffer: Option, + nulls: Option, + offset: usize, + buffers: Vec, + child_data: Vec, +} + +impl ArrayDataBuilder { + #[inline] + pub const fn new(data_type: DataType) -> Self { + Self { + data_type, + len: 0, + null_count: None, + null_bit_buffer: None, + nulls: None, + offset: 0, + buffers: vec![], + child_data: vec![], + } + } + + pub fn data_type(self, data_type: DataType) -> Self { + Self { data_type, ..self } + } + + #[inline] + #[allow(clippy::len_without_is_empty)] + pub const fn len(mut self, n: usize) -> Self { + self.len = n; + self + } + + pub fn nulls(mut self, nulls: Option) -> Self { + self.nulls = nulls; + self.null_count = None; + self.null_bit_buffer = None; + self + } + + pub fn null_count(mut self, null_count: usize) -> Self { + self.null_count = Some(null_count); + self + } + + pub fn null_bit_buffer(mut self, buf: Option) -> Self { + self.nulls = None; + self.null_bit_buffer = buf; + self + } + + #[inline] + pub const fn offset(mut self, n: usize) -> Self { + self.offset = n; + self + } + + pub fn buffers(mut self, v: Vec) -> Self { + self.buffers = v; + self + } + + pub fn add_buffer(mut self, b: Buffer) -> Self { + self.buffers.push(b); + self + } + + pub fn child_data(mut self, v: Vec) -> Self { + self.child_data = v; + self + } + + pub fn add_child_data(mut self, r: ArrayData) -> Self { + self.child_data.push(r); + self + } + + /// Creates an array data, without any validation + /// + /// # Safety + /// + /// The same caveats as [`ArrayData::new_unchecked`] + /// apply. + #[allow(clippy::let_and_return)] + pub unsafe fn build_unchecked(self) -> ArrayData { + let data = self.build_impl(); + // Provide a force_validate mode + #[cfg(feature = "force_validate")] + data.validate_data().unwrap(); + data + } + + /// Same as [`Self::build_unchecked`] but ignoring `force_validate` feature flag + unsafe fn build_impl(self) -> ArrayData { + let nulls = self.nulls.or_else(|| { + let buffer = self.null_bit_buffer?; + let buffer = BooleanBuffer::new(buffer, self.offset, self.len); + Some(match self.null_count { + Some(n) => NullBuffer::new_unchecked(buffer, n), + None => NullBuffer::new(buffer), + }) + }); + + ArrayData { + data_type: self.data_type, + len: self.len, + offset: self.offset, + buffers: self.buffers, + child_data: self.child_data, + nulls: nulls.filter(|b| b.null_count() != 0), + } + } + + /// Creates an array data, validating all inputs + pub fn build(self) -> Result { + let data = unsafe { self.build_impl() }; + data.validate_data()?; + Ok(data) + } + + /// Creates an array data, validating all inputs, and aligning any buffers + /// + /// Rust requires that arrays are aligned to their corresponding primitive, + /// see [`Layout::array`](std::alloc::Layout::array) and [`std::mem::align_of`]. + /// + /// [`ArrayData`] therefore requires that all buffers have at least this alignment, + /// to allow for [slice](std::slice) based APIs. See [`BufferSpec::FixedWidth`]. + /// + /// As this alignment is architecture specific, and not guaranteed by all arrow implementations, + /// this method is provided to automatically copy buffers to a new correctly aligned allocation + /// when necessary, making it useful when interacting with buffers produced by other systems, + /// e.g. IPC or FFI. + /// + /// This is unlike `[Self::build`] which will instead return an error on encountering + /// insufficiently aligned buffers. + pub fn build_aligned(self) -> Result { + let mut data = unsafe { self.build_impl() }; + data.align_buffers(); + data.validate_data()?; + Ok(data) + } +} + +impl From for ArrayDataBuilder { + fn from(d: ArrayData) -> Self { + Self { + data_type: d.data_type, + len: d.len, + offset: d.offset, + buffers: d.buffers, + child_data: d.child_data, + nulls: d.nulls, + null_bit_buffer: None, + null_count: None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_schema::{Field, UnionFields}; + + // See arrow/tests/array_data_validation.rs for test of array validation + + /// returns a buffer initialized with some constant value for tests + fn make_i32_buffer(n: usize) -> Buffer { + Buffer::from_slice_ref(vec![42i32; n]) + } + + /// returns a buffer initialized with some constant value for tests + fn make_f32_buffer(n: usize) -> Buffer { + Buffer::from_slice_ref(vec![42f32; n]) + } + + #[test] + fn test_builder() { + // Buffer needs to be at least 25 long + let v = (0..25).collect::>(); + let b1 = Buffer::from_slice_ref(&v); + let arr_data = ArrayData::builder(DataType::Int32) + .len(20) + .offset(5) + .add_buffer(b1) + .null_bit_buffer(Some(Buffer::from(vec![ + 0b01011111, 0b10110101, 0b01100011, 0b00011110, + ]))) + .build() + .unwrap(); + + assert_eq!(20, arr_data.len()); + assert_eq!(10, arr_data.null_count()); + assert_eq!(5, arr_data.offset()); + assert_eq!(1, arr_data.buffers().len()); + assert_eq!( + Buffer::from_slice_ref(&v).as_slice(), + arr_data.buffers()[0].as_slice() + ); + } + + #[test] + fn test_builder_with_child_data() { + let child_arr_data = ArrayData::try_new( + DataType::Int32, + 5, + None, + 0, + vec![Buffer::from_slice_ref([1i32, 2, 3, 4, 5])], + vec![], + ) + .unwrap(); + + let field = Arc::new(Field::new("x", DataType::Int32, true)); + let data_type = DataType::Struct(vec![field].into()); + + let arr_data = ArrayData::builder(data_type) + .len(5) + .offset(0) + .add_child_data(child_arr_data.clone()) + .build() + .unwrap(); + + assert_eq!(5, arr_data.len()); + assert_eq!(1, arr_data.child_data().len()); + assert_eq!(child_arr_data, arr_data.child_data()[0]); + } + + #[test] + fn test_null_count() { + let mut bit_v: [u8; 2] = [0; 2]; + bit_util::set_bit(&mut bit_v, 0); + bit_util::set_bit(&mut bit_v, 3); + bit_util::set_bit(&mut bit_v, 10); + let arr_data = ArrayData::builder(DataType::Int32) + .len(16) + .add_buffer(make_i32_buffer(16)) + .null_bit_buffer(Some(Buffer::from(bit_v))) + .build() + .unwrap(); + assert_eq!(13, arr_data.null_count()); + + // Test with offset + let mut bit_v: [u8; 2] = [0; 2]; + bit_util::set_bit(&mut bit_v, 0); + bit_util::set_bit(&mut bit_v, 3); + bit_util::set_bit(&mut bit_v, 10); + let arr_data = ArrayData::builder(DataType::Int32) + .len(12) + .offset(2) + .add_buffer(make_i32_buffer(14)) // requires at least 14 bytes of space, + .null_bit_buffer(Some(Buffer::from(bit_v))) + .build() + .unwrap(); + assert_eq!(10, arr_data.null_count()); + } + + #[test] + fn test_null_buffer_ref() { + let mut bit_v: [u8; 2] = [0; 2]; + bit_util::set_bit(&mut bit_v, 0); + bit_util::set_bit(&mut bit_v, 3); + bit_util::set_bit(&mut bit_v, 10); + let arr_data = ArrayData::builder(DataType::Int32) + .len(16) + .add_buffer(make_i32_buffer(16)) + .null_bit_buffer(Some(Buffer::from(bit_v))) + .build() + .unwrap(); + assert!(arr_data.nulls().is_some()); + assert_eq!(&bit_v, arr_data.nulls().unwrap().validity()); + } + + #[test] + fn test_slice() { + let mut bit_v: [u8; 2] = [0; 2]; + bit_util::set_bit(&mut bit_v, 0); + bit_util::set_bit(&mut bit_v, 3); + bit_util::set_bit(&mut bit_v, 10); + let data = ArrayData::builder(DataType::Int32) + .len(16) + .add_buffer(make_i32_buffer(16)) + .null_bit_buffer(Some(Buffer::from(bit_v))) + .build() + .unwrap(); + let new_data = data.slice(1, 15); + assert_eq!(data.len() - 1, new_data.len()); + assert_eq!(1, new_data.offset()); + assert_eq!(data.null_count(), new_data.null_count()); + + // slice of a slice (removes one null) + let new_data = new_data.slice(1, 14); + assert_eq!(data.len() - 2, new_data.len()); + assert_eq!(2, new_data.offset()); + assert_eq!(data.null_count() - 1, new_data.null_count()); + } + + #[test] + fn test_equality() { + let int_data = ArrayData::builder(DataType::Int32) + .len(1) + .add_buffer(make_i32_buffer(1)) + .build() + .unwrap(); + + let float_data = ArrayData::builder(DataType::Float32) + .len(1) + .add_buffer(make_f32_buffer(1)) + .build() + .unwrap(); + assert_ne!(int_data, float_data); + assert!(!int_data.ptr_eq(&float_data)); + assert!(int_data.ptr_eq(&int_data)); + + #[allow(clippy::redundant_clone)] + let int_data_clone = int_data.clone(); + assert_eq!(int_data, int_data_clone); + assert!(int_data.ptr_eq(&int_data_clone)); + assert!(int_data_clone.ptr_eq(&int_data)); + + let int_data_slice = int_data_clone.slice(1, 0); + assert!(int_data_slice.ptr_eq(&int_data_slice)); + assert!(!int_data.ptr_eq(&int_data_slice)); + assert!(!int_data_slice.ptr_eq(&int_data)); + + let data_buffer = Buffer::from_slice_ref("abcdef".as_bytes()); + let offsets_buffer = Buffer::from_slice_ref([0_i32, 2_i32, 2_i32, 5_i32]); + let string_data = ArrayData::try_new( + DataType::Utf8, + 3, + Some(Buffer::from_iter(vec![true, false, true])), + 0, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + + assert_ne!(float_data, string_data); + assert!(!float_data.ptr_eq(&string_data)); + + assert!(string_data.ptr_eq(&string_data)); + + #[allow(clippy::redundant_clone)] + let string_data_cloned = string_data.clone(); + assert!(string_data_cloned.ptr_eq(&string_data)); + assert!(string_data.ptr_eq(&string_data_cloned)); + + let string_data_slice = string_data.slice(1, 2); + assert!(string_data_slice.ptr_eq(&string_data_slice)); + assert!(!string_data_slice.ptr_eq(&string_data)) + } + + #[test] + fn test_slice_memory_size() { + let mut bit_v: [u8; 2] = [0; 2]; + bit_util::set_bit(&mut bit_v, 0); + bit_util::set_bit(&mut bit_v, 3); + bit_util::set_bit(&mut bit_v, 10); + let data = ArrayData::builder(DataType::Int32) + .len(16) + .add_buffer(make_i32_buffer(16)) + .null_bit_buffer(Some(Buffer::from(bit_v))) + .build() + .unwrap(); + let new_data = data.slice(1, 14); + assert_eq!( + data.get_slice_memory_size().unwrap() - 8, + new_data.get_slice_memory_size().unwrap() + ); + let data_buffer = Buffer::from_slice_ref("abcdef".as_bytes()); + let offsets_buffer = Buffer::from_slice_ref([0_i32, 2_i32, 2_i32, 5_i32]); + let string_data = ArrayData::try_new( + DataType::Utf8, + 3, + Some(Buffer::from_iter(vec![true, false, true])), + 0, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + let string_data_slice = string_data.slice(1, 2); + //4 bytes of offset and 2 bytes of data reduced by slicing. + assert_eq!( + string_data.get_slice_memory_size().unwrap() - 6, + string_data_slice.get_slice_memory_size().unwrap() + ); + } + + #[test] + fn test_count_nulls() { + let buffer = Buffer::from(vec![0b00010110, 0b10011111]); + let buffer = NullBuffer::new(BooleanBuffer::new(buffer, 0, 16)); + let count = count_nulls(Some(&buffer), 0, 16); + assert_eq!(count, 7); + + let count = count_nulls(Some(&buffer), 4, 8); + assert_eq!(count, 3); + } + + #[test] + fn test_contains_nulls() { + let buffer: Buffer = + MutableBuffer::from_iter([false, false, false, true, true, false]).into(); + let buffer = NullBuffer::new(BooleanBuffer::new(buffer, 0, 6)); + assert!(contains_nulls(Some(&buffer), 0, 6)); + assert!(contains_nulls(Some(&buffer), 0, 3)); + assert!(!contains_nulls(Some(&buffer), 3, 2)); + assert!(!contains_nulls(Some(&buffer), 0, 0)); + } + + #[test] + fn test_into_buffers() { + let data_types = vec![ + DataType::Union(UnionFields::empty(), UnionMode::Dense), + DataType::Union(UnionFields::empty(), UnionMode::Sparse), + ]; + + for data_type in data_types { + let buffers = new_buffers(&data_type, 0); + let [buffer1, buffer2] = buffers; + let buffers = into_buffers(&data_type, buffer1, buffer2); + + let layout = layout(&data_type); + assert_eq!(buffers.len(), layout.buffers.len()); + } + } + + #[test] + fn test_alignment() { + let buffer = Buffer::from_vec(vec![1_i32, 2_i32, 3_i32]); + let sliced = buffer.slice(1); + + let mut data = ArrayData { + data_type: DataType::Int32, + len: 0, + offset: 0, + buffers: vec![buffer], + child_data: vec![], + nulls: None, + }; + data.validate_full().unwrap(); + + data.buffers[0] = sliced; + let err = data.validate().unwrap_err(); + + assert_eq!( + err.to_string(), + "Invalid argument error: Misaligned buffers[0] in array of type Int32, offset from expected alignment of 4 by 1" + ); + + data.align_buffers(); + data.validate_full().unwrap(); + } +} diff --git a/arrow-data/src/decimal.rs b/arrow-data/src/decimal.rs new file mode 100644 index 000000000000..74279bfb9af1 --- /dev/null +++ b/arrow-data/src/decimal.rs @@ -0,0 +1,781 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_buffer::i256; +use arrow_schema::ArrowError; + +pub use arrow_schema::{ + DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, + DECIMAL_DEFAULT_SCALE, +}; + +// MAX decimal256 value of little-endian format for each precision. +// Each element is the max value of signed 256-bit integer for the specified precision which +// is encoded to the 32-byte width format of little-endian. +pub(crate) const MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [i256; 76] = [ + i256::from_le_bytes([ + 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, + ]), + i256::from_le_bytes([ + 99, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, + ]), + i256::from_le_bytes([ + 231, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, + ]), + i256::from_le_bytes([ + 15, 39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, + ]), + i256::from_le_bytes([ + 159, 134, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, + ]), + i256::from_le_bytes([ + 63, 66, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, + ]), + i256::from_le_bytes([ + 127, 150, 152, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 224, 245, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 201, 154, 59, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 227, 11, 84, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 231, 118, 72, 23, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 15, 165, 212, 232, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 159, 114, 78, 24, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 63, 122, 16, 243, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 127, 198, 164, 126, 141, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 192, 111, 242, 134, 35, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 137, 93, 120, 69, 99, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 99, 167, 179, 182, 224, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 231, 137, 4, 35, 199, 138, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 15, 99, 45, 94, 199, 107, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 159, 222, 197, 173, 201, 53, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 63, 178, 186, 201, 224, 25, 30, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 127, 246, 74, 225, 199, 2, 45, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 160, 237, 204, 206, 27, 194, 211, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 73, 72, 1, 20, 22, 149, 69, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 227, 210, 12, 200, 220, 210, 183, 82, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 231, 60, 128, 208, 159, 60, 46, 59, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 15, 97, 2, 37, 62, 94, 206, 79, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 159, 202, 23, 114, 109, 174, 15, 30, 67, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 63, 234, 237, 116, 70, 208, 156, 44, 159, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 127, 38, 75, 145, 192, 34, 32, 190, 55, 126, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 128, 239, 172, 133, 91, 65, 109, 45, 238, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 9, 91, 193, 56, 147, 141, 68, 198, 77, 49, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 99, 142, 141, 55, 192, 135, 173, 190, 9, 237, 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 231, 143, 135, 43, 130, 77, 199, 114, 97, 66, 19, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 15, 159, 75, 179, 21, 7, 201, 123, 206, 151, 192, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 159, 54, 244, 0, 217, 70, 218, 213, 16, 238, 133, 7, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 63, 34, 138, 9, 122, 196, 134, 90, 168, 76, 59, 75, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 127, 86, 101, 95, 196, 172, 67, 137, 147, 254, 80, 240, 2, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 96, 245, 185, 171, 191, 164, 92, 195, 241, 41, 99, 29, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 201, 149, 67, 181, 124, 111, 158, 161, 113, 163, 223, 37, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 227, 217, 163, 20, 223, 90, 48, 80, 112, 98, 188, 122, 11, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 231, 130, 102, 206, 182, 140, 227, 33, 99, 216, 91, 203, 114, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 15, 29, 1, 16, 36, 127, 227, 82, 223, 115, 150, 241, 123, 4, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 159, 34, 11, 160, 104, 247, 226, 60, 185, 134, 224, 111, 215, 44, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 63, 90, 111, 64, 22, 170, 221, 96, 60, 67, 197, 94, 106, 192, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 127, 134, 89, 132, 222, 164, 168, 200, 91, 160, 180, 179, 39, 132, + 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 64, 127, 43, 177, 112, 150, 214, 149, 67, 14, 5, 141, 41, + 175, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 137, 248, 178, 235, 102, 224, 97, 218, 163, 142, 50, 130, + 159, 215, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 99, 181, 253, 52, 5, 196, 210, 135, 102, 146, 249, 21, 59, + 108, 68, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 231, 21, 233, 17, 52, 168, 59, 78, 1, 184, 191, 219, 78, 58, + 172, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 15, 219, 26, 179, 8, 146, 84, 14, 13, 48, 125, 149, 20, 71, + 186, 26, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 159, 142, 12, 255, 86, 180, 77, 143, 130, 224, 227, 214, 205, + 198, 70, 11, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 63, 146, 125, 246, 101, 11, 9, 153, 25, 197, 230, 100, 10, + 196, 195, 112, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 127, 182, 231, 160, 251, 113, 90, 250, 255, 178, 3, 241, 103, + 168, 165, 103, 104, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 32, 13, 73, 212, 115, 136, 199, 255, 253, 36, 106, 15, + 148, 120, 12, 20, 4, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 73, 131, 218, 74, 134, 84, 203, 253, 235, 113, 37, 154, + 200, 181, 124, 200, 40, 0, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 227, 32, 137, 236, 62, 77, 241, 233, 55, 115, 118, 5, + 214, 25, 223, 212, 151, 1, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 231, 72, 91, 61, 117, 4, 109, 35, 47, 128, 160, 54, 92, + 2, 183, 80, 238, 15, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 15, 217, 144, 101, 148, 44, 66, 98, 215, 1, 69, 34, 154, + 23, 38, 39, 79, 159, 0, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 159, 122, 168, 247, 203, 189, 149, 214, 105, 18, 178, + 86, 5, 236, 124, 135, 23, 57, 6, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 63, 202, 148, 172, 247, 105, 217, 97, 34, 184, 244, 98, + 53, 56, 225, 74, 235, 58, 62, 0, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 127, 230, 207, 189, 172, 35, 126, 210, 87, 49, 143, 221, + 21, 50, 204, 236, 48, 77, 110, 2, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 0, 31, 106, 191, 100, 237, 56, 110, 237, 151, 167, + 218, 244, 249, 63, 233, 3, 79, 24, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 9, 54, 37, 122, 239, 69, 57, 78, 70, 239, 139, 138, + 144, 195, 127, 28, 39, 22, 243, 0, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 99, 28, 116, 197, 90, 187, 60, 14, 191, 88, 119, + 105, 165, 163, 253, 28, 135, 221, 126, 9, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 231, 27, 137, 182, 139, 81, 95, 142, 118, 119, 169, + 30, 118, 100, 232, 33, 71, 167, 244, 94, 0, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 15, 23, 91, 33, 117, 47, 185, 143, 161, 170, 158, + 50, 157, 236, 19, 83, 199, 136, 142, 181, 3, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 159, 230, 142, 77, 147, 218, 59, 157, 79, 170, 50, + 250, 35, 62, 199, 62, 201, 87, 145, 23, 37, 0, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 63, 2, 149, 7, 193, 137, 86, 36, 28, 167, 250, 197, + 103, 109, 200, 115, 220, 109, 173, 235, 114, 1, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 127, 22, 210, 75, 138, 97, 97, 107, 25, 135, 202, + 187, 13, 70, 212, 133, 156, 74, 198, 52, 125, 14, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 224, 52, 246, 102, 207, 205, 49, 254, 70, 233, + 85, 137, 188, 74, 58, 29, 234, 190, 15, 228, 144, 0, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 201, 16, 158, 5, 26, 10, 242, 237, 197, 28, + 91, 93, 93, 235, 70, 36, 37, 117, 157, 232, 168, 5, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 227, 167, 44, 56, 4, 101, 116, 75, 187, 31, + 143, 165, 165, 49, 197, 106, 115, 147, 38, 22, 153, 56, 0, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 231, 142, 190, 49, 42, 242, 139, 242, 80, 61, + 151, 119, 120, 240, 179, 43, 130, 194, 129, 221, 250, 53, 2, + ]), + i256::from_le_bytes([ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 15, 149, 113, 241, 165, 117, 119, 121, 41, + 101, 232, 171, 180, 100, 7, 181, 21, 153, 17, 167, 204, 27, 22, + ]), +]; + +// MIN decimal256 value of little-endian format for each precision. +// Each element is the min value of signed 256-bit integer for the specified precision which +// is encoded to the 76-byte width format of little-endian. +pub(crate) const MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [i256; 76] = [ + i256::from_le_bytes([ + 247, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 157, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 25, 252, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 241, 216, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 97, 121, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 193, 189, 240, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 129, 105, 103, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 31, 10, 250, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 54, 101, 196, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 28, 244, 171, 253, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 24, 137, 183, 232, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 240, 90, 43, 23, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 96, 141, 177, 231, 246, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 192, 133, 239, 12, 165, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 128, 57, 91, 129, 114, 252, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 63, 144, 13, 121, 220, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 118, 162, 135, 186, 156, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 156, 88, 76, 73, 31, 242, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 24, 118, 251, 220, 56, 117, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 240, 156, 210, 161, 56, 148, 250, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 96, 33, 58, 82, 54, 202, 201, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 192, 77, 69, 54, 31, 230, 225, 253, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 128, 9, 181, 30, 56, 253, 210, 234, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 95, 18, 51, 49, 228, 61, 44, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 182, 183, 254, 235, 233, 106, 186, 247, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 28, 45, 243, 55, 35, 45, 72, 173, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 24, 195, 127, 47, 96, 195, 209, 196, 252, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 240, 158, 253, 218, 193, 161, 49, 176, 223, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 96, 53, 232, 141, 146, 81, 240, 225, 188, 254, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 192, 21, 18, 139, 185, 47, 99, 211, 96, 243, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 128, 217, 180, 110, 63, 221, 223, 65, 200, 129, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 127, 16, 83, 122, 164, 190, 146, 210, 17, 251, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 246, 164, 62, 199, 108, 114, 187, 57, 178, 206, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 156, 113, 114, 200, 63, 120, 82, 65, 246, 18, 254, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 24, 112, 120, 212, 125, 178, 56, 141, 158, 189, 236, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 240, 96, 180, 76, 234, 248, 54, 132, 49, 104, 63, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 96, 201, 11, 255, 38, 185, 37, 42, 239, 17, 122, 248, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 192, 221, 117, 246, 133, 59, 121, 165, 87, 179, 196, 180, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 128, 169, 154, 160, 59, 83, 188, 118, 108, 1, 175, 15, 253, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 159, 10, 70, 84, 64, 91, 163, 60, 14, 214, 156, 226, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 54, 106, 188, 74, 131, 144, 97, 94, 142, 92, 32, 218, 254, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 28, 38, 92, 235, 32, 165, 207, 175, 143, 157, 67, 133, 244, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 24, 125, 153, 49, 73, 115, 28, 222, 156, 39, 164, 52, 141, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 240, 226, 254, 239, 219, 128, 28, 173, 32, 140, 105, 14, 132, 251, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 96, 221, 244, 95, 151, 8, 29, 195, 70, 121, 31, 144, 40, 211, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 192, 165, 144, 191, 233, 85, 34, 159, 195, 188, 58, 161, 149, 63, 254, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 128, 121, 166, 123, 33, 91, 87, 55, 164, 95, 75, 76, 216, 123, 238, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 191, 128, 212, 78, 143, 105, 41, 106, 188, 241, 250, 114, 214, 80, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 118, 7, 77, 20, 153, 31, 158, 37, 92, 113, 205, 125, 96, 40, 249, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 156, 74, 2, 203, 250, 59, 45, 120, 153, 109, 6, 234, 196, 147, 187, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 24, 234, 22, 238, 203, 87, 196, 177, 254, 71, 64, 36, 177, 197, 83, 253, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 240, 36, 229, 76, 247, 109, 171, 241, 242, 207, 130, 106, 235, 184, 69, + 229, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 96, 113, 243, 0, 169, 75, 178, 112, 125, 31, 28, 41, 50, 57, 185, 244, + 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 192, 109, 130, 9, 154, 244, 246, 102, 230, 58, 25, 155, 245, 59, 60, 143, + 245, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 128, 73, 24, 95, 4, 142, 165, 5, 0, 77, 252, 14, 152, 87, 90, 152, 151, + 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 223, 242, 182, 43, 140, 119, 56, 0, 2, 219, 149, 240, 107, 135, 243, + 235, 251, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 182, 124, 37, 181, 121, 171, 52, 2, 20, 142, 218, 101, 55, 74, 131, + 55, 215, 255, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 28, 223, 118, 19, 193, 178, 14, 22, 200, 140, 137, 250, 41, 230, 32, + 43, 104, 254, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 24, 183, 164, 194, 138, 251, 146, 220, 208, 127, 95, 201, 163, 253, + 72, 175, 17, 240, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 240, 38, 111, 154, 107, 211, 189, 157, 40, 254, 186, 221, 101, 232, + 217, 216, 176, 96, 255, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 96, 133, 87, 8, 52, 66, 106, 41, 150, 237, 77, 169, 250, 19, 131, 120, + 232, 198, 249, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 192, 53, 107, 83, 8, 150, 38, 158, 221, 71, 11, 157, 202, 199, 30, + 181, 20, 197, 193, 255, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 128, 25, 48, 66, 83, 220, 129, 45, 168, 206, 112, 34, 234, 205, 51, + 19, 207, 178, 145, 253, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 255, 224, 149, 64, 155, 18, 199, 145, 18, 104, 88, 37, 11, 6, 192, + 22, 252, 176, 231, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 246, 201, 218, 133, 16, 186, 198, 177, 185, 16, 116, 117, 111, 60, + 128, 227, 216, 233, 12, 255, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 156, 227, 139, 58, 165, 68, 195, 241, 64, 167, 136, 150, 90, 92, 2, + 227, 120, 34, 129, 246, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 24, 228, 118, 73, 116, 174, 160, 113, 137, 136, 86, 225, 137, 155, + 23, 222, 184, 88, 11, 161, 255, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 240, 232, 164, 222, 138, 208, 70, 112, 94, 85, 97, 205, 98, 19, + 236, 172, 56, 119, 113, 74, 252, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 96, 25, 113, 178, 108, 37, 196, 98, 176, 85, 205, 5, 220, 193, 56, + 193, 54, 168, 110, 232, 218, 255, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 192, 253, 106, 248, 62, 118, 169, 219, 227, 88, 5, 58, 152, 146, + 55, 140, 35, 146, 82, 20, 141, 254, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 128, 233, 45, 180, 117, 158, 158, 148, 230, 120, 53, 68, 242, 185, + 43, 122, 99, 181, 57, 203, 130, 241, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 31, 203, 9, 153, 48, 50, 206, 1, 185, 22, 170, 118, 67, 181, + 197, 226, 21, 65, 240, 27, 111, 255, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 54, 239, 97, 250, 229, 245, 13, 18, 58, 227, 164, 162, 162, 20, + 185, 219, 218, 138, 98, 23, 87, 250, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 28, 88, 211, 199, 251, 154, 139, 180, 68, 224, 112, 90, 90, 206, + 58, 149, 140, 108, 217, 233, 102, 199, 255, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 24, 113, 65, 206, 213, 13, 116, 13, 175, 194, 104, 136, 135, 15, + 76, 212, 125, 61, 126, 34, 5, 202, 253, + ]), + i256::from_le_bytes([ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 240, 106, 142, 14, 90, 138, 136, 134, 214, 154, 23, 84, 75, 155, + 248, 74, 234, 102, 238, 88, 51, 228, 233, + ]), +]; + +/// `MAX_DECIMAL_FOR_EACH_PRECISION[p]` holds the maximum `i128` value that can +/// be stored in [arrow_schema::DataType::Decimal128] value of precision `p` +pub const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ + 9, + 99, + 999, + 9999, + 99999, + 999999, + 9999999, + 99999999, + 999999999, + 9999999999, + 99999999999, + 999999999999, + 9999999999999, + 99999999999999, + 999999999999999, + 9999999999999999, + 99999999999999999, + 999999999999999999, + 9999999999999999999, + 99999999999999999999, + 999999999999999999999, + 9999999999999999999999, + 99999999999999999999999, + 999999999999999999999999, + 9999999999999999999999999, + 99999999999999999999999999, + 999999999999999999999999999, + 9999999999999999999999999999, + 99999999999999999999999999999, + 999999999999999999999999999999, + 9999999999999999999999999999999, + 99999999999999999999999999999999, + 999999999999999999999999999999999, + 9999999999999999999999999999999999, + 99999999999999999999999999999999999, + 999999999999999999999999999999999999, + 9999999999999999999999999999999999999, + 99999999999999999999999999999999999999, +]; + +/// `MIN_DECIMAL_FOR_EACH_PRECISION[p]` holds the minimum `i128` value that can +/// be stored in a [arrow_schema::DataType::Decimal128] value of precision `p` +pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ + -9, + -99, + -999, + -9999, + -99999, + -999999, + -9999999, + -99999999, + -999999999, + -9999999999, + -99999999999, + -999999999999, + -9999999999999, + -99999999999999, + -999999999999999, + -9999999999999999, + -99999999999999999, + -999999999999999999, + -9999999999999999999, + -99999999999999999999, + -999999999999999999999, + -9999999999999999999999, + -99999999999999999999999, + -999999999999999999999999, + -9999999999999999999999999, + -99999999999999999999999999, + -999999999999999999999999999, + -9999999999999999999999999999, + -99999999999999999999999999999, + -999999999999999999999999999999, + -9999999999999999999999999999999, + -99999999999999999999999999999999, + -999999999999999999999999999999999, + -9999999999999999999999999999999999, + -99999999999999999999999999999999999, + -999999999999999999999999999999999999, + -9999999999999999999999999999999999999, + -99999999999999999999999999999999999999, +]; + +/// Validates that the specified `i128` value can be properly +/// interpreted as a Decimal number with precision `precision` +#[inline] +pub fn validate_decimal_precision(value: i128, precision: u8) -> Result<(), ArrowError> { + if precision > DECIMAL128_MAX_PRECISION { + return Err(ArrowError::InvalidArgumentError(format!( + "Max precision of a Decimal128 is {DECIMAL128_MAX_PRECISION}, but got {precision}", + ))); + } + + let max = MAX_DECIMAL_FOR_EACH_PRECISION[usize::from(precision) - 1]; + let min = MIN_DECIMAL_FOR_EACH_PRECISION[usize::from(precision) - 1]; + + if value > max { + Err(ArrowError::InvalidArgumentError(format!( + "{value} is too large to store in a Decimal128 of precision {precision}. Max is {max}" + ))) + } else if value < min { + Err(ArrowError::InvalidArgumentError(format!( + "{value} is too small to store in a Decimal128 of precision {precision}. Min is {min}" + ))) + } else { + Ok(()) + } +} + +/// Validates that the specified `i256` of value can be properly +/// interpreted as a Decimal256 number with precision `precision` +#[inline] +pub fn validate_decimal256_precision(value: i256, precision: u8) -> Result<(), ArrowError> { + if precision > DECIMAL256_MAX_PRECISION { + return Err(ArrowError::InvalidArgumentError(format!( + "Max precision of a Decimal256 is {DECIMAL256_MAX_PRECISION}, but got {precision}", + ))); + } + let max = MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[usize::from(precision) - 1]; + let min = MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[usize::from(precision) - 1]; + + if value > max { + Err(ArrowError::InvalidArgumentError(format!( + "{value:?} is too large to store in a Decimal256 of precision {precision}. Max is {max:?}" + ))) + } else if value < min { + Err(ArrowError::InvalidArgumentError(format!( + "{value:?} is too small to store in a Decimal256 of precision {precision}. Min is {min:?}" + ))) + } else { + Ok(()) + } +} diff --git a/arrow/src/array/equal/boolean.rs b/arrow-data/src/equal/boolean.rs similarity index 65% rename from arrow/src/array/equal/boolean.rs rename to arrow-data/src/equal/boolean.rs index fddf21b963ad..addae936f118 100644 --- a/arrow/src/array/equal/boolean.rs +++ b/arrow-data/src/equal/boolean.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::array::{data::contains_nulls, ArrayData}; -use crate::util::bit_iterator::BitIndexIterator; -use crate::util::bit_util::get_bit; +use crate::bit_iterator::BitIndexIterator; +use crate::data::{contains_nulls, ArrayData}; +use arrow_buffer::bit_util::get_bit; use super::utils::{equal_bits, equal_len}; @@ -33,7 +33,7 @@ pub(super) fn boolean_equal( let lhs_values = lhs.buffers()[0].as_slice(); let rhs_values = rhs.buffers()[0].as_slice(); - let contains_nulls = contains_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let contains_nulls = contains_nulls(lhs.nulls(), lhs_start, len); if !contains_nulls { // Optimize performance for starting offset at u8 boundary. @@ -76,42 +76,12 @@ pub(super) fn boolean_equal( ) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); + let lhs_nulls = lhs.nulls().unwrap(); - let lhs_start = lhs.offset() + lhs_start; - let rhs_start = rhs.offset() + rhs_start; - - BitIndexIterator::new(lhs_null_bytes, lhs_start, len).all(|i| { - let lhs_pos = lhs_start + i; - let rhs_pos = rhs_start + i; + BitIndexIterator::new(lhs_nulls.validity(), lhs_start + lhs_nulls.offset(), len).all(|i| { + let lhs_pos = lhs_start + lhs.offset() + i; + let rhs_pos = rhs_start + rhs.offset() + i; get_bit(lhs_values, lhs_pos) == get_bit(rhs_values, rhs_pos) }) } } - -#[cfg(test)] -mod tests { - use crate::array::{Array, BooleanArray}; - - #[test] - fn test_boolean_slice() { - let array = BooleanArray::from(vec![true; 32]); - let slice = array.slice(4, 12); - assert_eq!(slice.data(), slice.data()); - - let slice = array.slice(8, 12); - assert_eq!(slice.data(), slice.data()); - - let slice = array.slice(8, 24); - assert_eq!(slice.data(), slice.data()); - } - - #[test] - fn test_sliced_nullable_boolean_array() { - let a = BooleanArray::from(vec![None; 32]); - let b = BooleanArray::from(vec![true; 32]); - let slice_a = a.slice(1, 12); - let slice_b = b.slice(1, 12); - assert_ne!(slice_a.data(), slice_b.data()); - } -} diff --git a/arrow/src/array/equal/dictionary.rs b/arrow-data/src/equal/dictionary.rs similarity index 75% rename from arrow/src/array/equal/dictionary.rs rename to arrow-data/src/equal/dictionary.rs index 4c9bcf798760..1d9c4b8d964f 100644 --- a/arrow/src/array/equal/dictionary.rs +++ b/arrow-data/src/equal/dictionary.rs @@ -15,9 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::array::{data::count_nulls, ArrayData}; -use crate::datatypes::ArrowNativeType; -use crate::util::bit_util::get_bit; +use crate::data::{contains_nulls, ArrayData}; +use arrow_buffer::ArrowNativeType; use super::equal_range; @@ -34,10 +33,9 @@ pub(super) fn dictionary_equal( let lhs_values = &lhs.child_data()[0]; let rhs_values = &rhs.child_data()[0]; - let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); - let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); - - if lhs_null_count == 0 && rhs_null_count == 0 { + // Only checking one null mask here because by the time the control flow reaches + // this point, the equality of the two masks would have already been verified. + if !contains_nulls(lhs.nulls(), lhs_start, len) { (0..len).all(|i| { let lhs_pos = lhs_start + i; let rhs_pos = rhs_start + i; @@ -52,14 +50,14 @@ pub(super) fn dictionary_equal( }) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); + let lhs_nulls = lhs.nulls().unwrap(); + let rhs_nulls = rhs.nulls().unwrap(); (0..len).all(|i| { let lhs_pos = lhs_start + i; let rhs_pos = rhs_start + i; - let lhs_is_null = !get_bit(lhs_null_bytes, lhs_pos + lhs.offset()); - let rhs_is_null = !get_bit(rhs_null_bytes, rhs_pos + rhs.offset()); + let lhs_is_null = lhs_nulls.is_null(lhs_pos); + let rhs_is_null = rhs_nulls.is_null(rhs_pos); lhs_is_null || (lhs_is_null == rhs_is_null) diff --git a/arrow-data/src/equal/fixed_binary.rs b/arrow-data/src/equal/fixed_binary.rs new file mode 100644 index 000000000000..0778d77e2fdd --- /dev/null +++ b/arrow-data/src/equal/fixed_binary.rs @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::bit_iterator::BitSliceIterator; +use crate::contains_nulls; +use crate::data::ArrayData; +use crate::equal::primitive::NULL_SLICES_SELECTIVITY_THRESHOLD; +use arrow_schema::DataType; + +use super::utils::equal_len; + +pub(super) fn fixed_binary_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let size = match lhs.data_type() { + DataType::FixedSizeBinary(i) => *i as usize, + _ => unreachable!(), + }; + + let lhs_values = &lhs.buffers()[0].as_slice()[lhs.offset() * size..]; + let rhs_values = &rhs.buffers()[0].as_slice()[rhs.offset() * size..]; + + // Only checking one null mask here because by the time the control flow reaches + // this point, the equality of the two masks would have already been verified. + if !contains_nulls(lhs.nulls(), lhs_start, len) { + equal_len( + lhs_values, + rhs_values, + size * lhs_start, + size * rhs_start, + size * len, + ) + } else { + let selectivity_frac = lhs.null_count() as f64 / lhs.len() as f64; + + if selectivity_frac >= NULL_SLICES_SELECTIVITY_THRESHOLD { + // get a ref of the null buffer bytes, to use in testing for nullness + let lhs_nulls = lhs.nulls().unwrap(); + let rhs_nulls = rhs.nulls().unwrap(); + // with nulls, we need to compare item by item whenever it is not null + (0..len).all(|i| { + let lhs_pos = lhs_start + i; + let rhs_pos = rhs_start + i; + + let lhs_is_null = lhs_nulls.is_null(lhs_pos); + let rhs_is_null = rhs_nulls.is_null(rhs_pos); + + lhs_is_null + || (lhs_is_null == rhs_is_null) + && equal_len( + lhs_values, + rhs_values, + lhs_pos * size, + rhs_pos * size, + size, // 1 * size since we are comparing a single entry + ) + }) + } else { + let lhs_nulls = lhs.nulls().unwrap(); + let lhs_slices_iter = + BitSliceIterator::new(lhs_nulls.validity(), lhs_start + lhs_nulls.offset(), len); + let rhs_nulls = rhs.nulls().unwrap(); + let rhs_slices_iter = + BitSliceIterator::new(rhs_nulls.validity(), rhs_start + rhs_nulls.offset(), len); + + lhs_slices_iter + .zip(rhs_slices_iter) + .all(|((l_start, l_end), (r_start, r_end))| { + l_start == r_start + && l_end == r_end + && equal_len( + lhs_values, + rhs_values, + (lhs_start + l_start) * size, + (rhs_start + r_start) * size, + (l_end - l_start) * size, + ) + }) + } + } +} diff --git a/arrow/src/array/equal/fixed_list.rs b/arrow-data/src/equal/fixed_list.rs similarity index 75% rename from arrow/src/array/equal/fixed_list.rs rename to arrow-data/src/equal/fixed_list.rs index 82a347c86574..4b79e5c33fab 100644 --- a/arrow/src/array/equal/fixed_list.rs +++ b/arrow-data/src/equal/fixed_list.rs @@ -15,9 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::array::{data::count_nulls, ArrayData}; -use crate::datatypes::DataType; -use crate::util::bit_util::get_bit; +use crate::data::{contains_nulls, ArrayData}; +use arrow_schema::DataType; use super::equal_range; @@ -36,10 +35,9 @@ pub(super) fn fixed_list_equal( let lhs_values = &lhs.child_data()[0]; let rhs_values = &rhs.child_data()[0]; - let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); - let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); - - if lhs_null_count == 0 && rhs_null_count == 0 { + // Only checking one null mask here because by the time the control flow reaches + // this point, the equality of the two masks would have already been verified. + if !contains_nulls(lhs.nulls(), lhs_start, len) { equal_range( lhs_values, rhs_values, @@ -49,15 +47,15 @@ pub(super) fn fixed_list_equal( ) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); + let lhs_nulls = lhs.nulls().unwrap(); + let rhs_nulls = rhs.nulls().unwrap(); // with nulls, we need to compare item by item whenever it is not null (0..len).all(|i| { let lhs_pos = lhs_start + i; let rhs_pos = rhs_start + i; - let lhs_is_null = !get_bit(lhs_null_bytes, lhs_pos + lhs.offset()); - let rhs_is_null = !get_bit(rhs_null_bytes, rhs_pos + rhs.offset()); + let lhs_is_null = lhs_nulls.is_null(lhs_pos); + let rhs_is_null = rhs_nulls.is_null(rhs_pos); lhs_is_null || (lhs_is_null == rhs_is_null) diff --git a/arrow/src/array/equal/list.rs b/arrow-data/src/equal/list.rs similarity index 69% rename from arrow/src/array/equal/list.rs rename to arrow-data/src/equal/list.rs index b3bca9a69228..cc4ba3cacf9f 100644 --- a/arrow/src/array/equal/list.rs +++ b/arrow-data/src/equal/list.rs @@ -15,15 +15,13 @@ // specific language governing permissions and limitations // under the License. -use crate::{ - array::ArrayData, - array::{data::count_nulls, OffsetSizeTrait}, - util::bit_util::get_bit, -}; +use crate::data::{count_nulls, ArrayData}; +use arrow_buffer::ArrowNativeType; +use num::Integer; use super::equal_range; -fn lengths_equal(lhs: &[T], rhs: &[T]) -> bool { +fn lengths_equal(lhs: &[T], rhs: &[T]) -> bool { // invariant from `base_equal` debug_assert_eq!(lhs.len(), rhs.len()); @@ -45,7 +43,7 @@ fn lengths_equal(lhs: &[T], rhs: &[T]) -> bool { }) } -pub(super) fn list_equal( +pub(super) fn list_equal( lhs: &ArrayData, rhs: &ArrayData, lhs_start: usize, @@ -91,8 +89,8 @@ pub(super) fn list_equal( let lhs_values = &lhs.child_data()[0]; let rhs_values = &rhs.child_data()[0]; - let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); - let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); + let lhs_null_count = count_nulls(lhs.nulls(), lhs_start, len); + let rhs_null_count = count_nulls(rhs.nulls(), rhs_start, len); if lhs_null_count != rhs_null_count { return false; @@ -113,8 +111,8 @@ pub(super) fn list_equal( ) } else { // get a ref of the parent null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs.null_buffer().unwrap().as_slice(); - let rhs_null_bytes = rhs.null_buffer().unwrap().as_slice(); + let lhs_nulls = lhs.nulls().unwrap(); + let rhs_nulls = rhs.nulls().unwrap(); // with nulls, we need to compare item by item whenever it is not null // TODO: Could potentially compare runs of not NULL values @@ -122,8 +120,8 @@ pub(super) fn list_equal( let lhs_pos = lhs_start + i; let rhs_pos = rhs_start + i; - let lhs_is_null = !get_bit(lhs_null_bytes, lhs_pos + lhs.offset()); - let rhs_is_null = !get_bit(rhs_null_bytes, rhs_pos + rhs.offset()); + let lhs_is_null = lhs_nulls.is_null(lhs_pos); + let rhs_is_null = rhs_nulls.is_null(rhs_pos); if lhs_is_null != rhs_is_null { return false; @@ -149,52 +147,3 @@ pub(super) fn list_equal( }) } } - -#[cfg(test)] -mod tests { - use crate::{ - array::{Array, Int64Builder, ListArray, ListBuilder}, - datatypes::Int32Type, - }; - - #[test] - fn list_array_non_zero_nulls() { - // Tests handling of list arrays with non-empty null ranges - let mut builder = ListBuilder::new(Int64Builder::with_capacity(10)); - builder.values().append_value(1); - builder.values().append_value(2); - builder.values().append_value(3); - builder.append(true); - builder.append(false); - let array1 = builder.finish(); - - let mut builder = ListBuilder::new(Int64Builder::with_capacity(10)); - builder.values().append_value(1); - builder.values().append_value(2); - builder.values().append_value(3); - builder.append(true); - builder.values().append_null(); - builder.values().append_null(); - builder.append(false); - let array2 = builder.finish(); - - assert_eq!(array1, array2); - } - - #[test] - fn test_list_different_offsets() { - let a = ListArray::from_iter_primitive::([ - Some([Some(0), Some(0)]), - Some([Some(1), Some(2)]), - Some([None, None]), - ]); - let b = ListArray::from_iter_primitive::([ - Some([Some(1), Some(2)]), - Some([None, None]), - Some([None, None]), - ]); - let a_slice = a.slice(1, 2); - let b_slice = b.slice(0, 2); - assert_eq!(&a_slice, &b_slice); - } -} diff --git a/arrow-data/src/equal/mod.rs b/arrow-data/src/equal/mod.rs new file mode 100644 index 000000000000..b279546474a0 --- /dev/null +++ b/arrow-data/src/equal/mod.rs @@ -0,0 +1,151 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module containing functionality to compute array equality. +//! This module uses [ArrayData] and does not +//! depend on dynamic casting of `Array`. + +use crate::data::ArrayData; +use arrow_buffer::i256; +use arrow_schema::{DataType, IntervalUnit}; +use half::f16; + +mod boolean; +mod dictionary; +mod fixed_binary; +mod fixed_list; +mod list; +mod null; +mod primitive; +mod run; +mod structure; +mod union; +mod utils; +mod variable_size; + +// these methods assume the same type, len and null count. +// For this reason, they are not exposed and are instead used +// to build the generic functions below (`equal_range` and `equal`). +use boolean::boolean_equal; +use dictionary::dictionary_equal; +use fixed_binary::fixed_binary_equal; +use fixed_list::fixed_list_equal; +use list::list_equal; +use null::null_equal; +use primitive::primitive_equal; +use structure::struct_equal; +use union::union_equal; +use variable_size::variable_sized_equal; + +use self::run::run_equal; + +/// Compares the values of two [ArrayData] starting at `lhs_start` and `rhs_start` respectively +/// for `len` slots. +#[inline] +fn equal_values( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + match lhs.data_type() { + DataType::Null => null_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::Boolean => boolean_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt8 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt16 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt32 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int8 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int16 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int32 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Float32 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Float64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Decimal128(_, _) => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Decimal256(_, _) => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => { + primitive_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::Date64 + | DataType::Interval(IntervalUnit::DayTime) + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Interval(IntervalUnit::MonthDayNano) => { + primitive_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::Utf8 | DataType::Binary => { + variable_sized_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::LargeUtf8 | DataType::LargeBinary => { + variable_sized_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::FixedSizeBinary(_) => fixed_binary_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::List(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::LargeList(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::FixedSizeList(_, _) => fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::Union(_, _) => union_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::Dictionary(data_type, _) => match data_type.as_ref() { + DataType::Int8 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int16 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int32 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int64 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt8 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt16 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt32 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt64 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + _ => unreachable!(), + }, + DataType::Float16 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Map(_, _) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::RunEndEncoded(_, _) => run_equal(lhs, rhs, lhs_start, rhs_start, len), + } +} + +fn equal_range( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + utils::equal_nulls(lhs, rhs, lhs_start, rhs_start, len) + && equal_values(lhs, rhs, lhs_start, rhs_start, len) +} + +/// Logically compares two [ArrayData]. +/// Two arrays are logically equal if and only if: +/// * their data types are equal +/// * their lengths are equal +/// * their null counts are equal +/// * their null bitmaps are equal +/// * each of their items are equal +/// two items are equal when their in-memory representation is physically equal (i.e. same bit content). +/// The physical comparison depend on the data type. +/// # Panics +/// This function may panic whenever any of the [ArrayData] does not follow the Arrow specification. +/// (e.g. wrong number of buffers, buffer `len` does not correspond to the declared `len`) +pub fn equal(lhs: &ArrayData, rhs: &ArrayData) -> bool { + utils::base_equal(lhs, rhs) + && lhs.null_count() == rhs.null_count() + && utils::equal_nulls(lhs, rhs, 0, 0, lhs.len()) + && equal_values(lhs, rhs, 0, 0, lhs.len()) +} + +// See arrow/tests/array_equal.rs for tests diff --git a/arrow/src/array/equal/null.rs b/arrow-data/src/equal/null.rs similarity index 97% rename from arrow/src/array/equal/null.rs rename to arrow-data/src/equal/null.rs index f287a382507a..1478e448cec2 100644 --- a/arrow/src/array/equal/null.rs +++ b/arrow-data/src/equal/null.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::array::ArrayData; +use crate::data::ArrayData; #[inline] pub(super) fn null_equal( diff --git a/arrow-data/src/equal/primitive.rs b/arrow-data/src/equal/primitive.rs new file mode 100644 index 000000000000..e92fdd2ba23b --- /dev/null +++ b/arrow-data/src/equal/primitive.rs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::bit_iterator::BitSliceIterator; +use crate::contains_nulls; +use std::mem::size_of; + +use crate::data::ArrayData; + +use super::utils::equal_len; + +pub(crate) const NULL_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.4; + +pub(super) fn primitive_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let byte_width = size_of::(); + let lhs_values = &lhs.buffers()[0].as_slice()[lhs.offset() * byte_width..]; + let rhs_values = &rhs.buffers()[0].as_slice()[rhs.offset() * byte_width..]; + + // Only checking one null mask here because by the time the control flow reaches + // this point, the equality of the two masks would have already been verified. + if !contains_nulls(lhs.nulls(), lhs_start, len) { + // without nulls, we just need to compare slices + equal_len( + lhs_values, + rhs_values, + lhs_start * byte_width, + rhs_start * byte_width, + len * byte_width, + ) + } else { + let selectivity_frac = lhs.null_count() as f64 / lhs.len() as f64; + + if selectivity_frac >= NULL_SLICES_SELECTIVITY_THRESHOLD { + // get a ref of the null buffer bytes, to use in testing for nullness + let lhs_nulls = lhs.nulls().unwrap(); + let rhs_nulls = rhs.nulls().unwrap(); + // with nulls, we need to compare item by item whenever it is not null + (0..len).all(|i| { + let lhs_pos = lhs_start + i; + let rhs_pos = rhs_start + i; + let lhs_is_null = lhs_nulls.is_null(lhs_pos); + let rhs_is_null = rhs_nulls.is_null(rhs_pos); + + lhs_is_null + || (lhs_is_null == rhs_is_null) + && equal_len( + lhs_values, + rhs_values, + lhs_pos * byte_width, + rhs_pos * byte_width, + byte_width, // 1 * byte_width since we are comparing a single entry + ) + }) + } else { + let lhs_nulls = lhs.nulls().unwrap(); + let lhs_slices_iter = + BitSliceIterator::new(lhs_nulls.validity(), lhs_start + lhs_nulls.offset(), len); + let rhs_nulls = rhs.nulls().unwrap(); + let rhs_slices_iter = + BitSliceIterator::new(rhs_nulls.validity(), rhs_start + rhs_nulls.offset(), len); + + lhs_slices_iter + .zip(rhs_slices_iter) + .all(|((l_start, l_end), (r_start, r_end))| { + l_start == r_start + && l_end == r_end + && equal_len( + lhs_values, + rhs_values, + (lhs_start + l_start) * byte_width, + (rhs_start + r_start) * byte_width, + (l_end - l_start) * byte_width, + ) + }) + } + } +} diff --git a/arrow-data/src/equal/run.rs b/arrow-data/src/equal/run.rs new file mode 100644 index 000000000000..ede172c999fd --- /dev/null +++ b/arrow-data/src/equal/run.rs @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::data::ArrayData; + +use super::equal_range; + +/// The current implementation of comparison of run array support physical comparison. +/// Comparing run encoded array based on logical indices (`lhs_start`, `rhs_start`) will +/// be time consuming as converting from logical index to physical index cannot be done +/// in constant time. The current comparison compares the underlying physical arrays. +pub(super) fn run_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + if lhs_start != 0 + || rhs_start != 0 + || (lhs.len() != len && rhs.len() != len) + || lhs.offset() > 0 + || rhs.offset() > 0 + { + unimplemented!("Logical comparison for run array not supported.") + } + + if lhs.len() != rhs.len() { + return false; + } + + let lhs_run_ends_array = lhs.child_data().get(0).unwrap(); + let lhs_values_array = lhs.child_data().get(1).unwrap(); + + let rhs_run_ends_array = rhs.child_data().get(0).unwrap(); + let rhs_values_array = rhs.child_data().get(1).unwrap(); + + if lhs_run_ends_array.len() != rhs_run_ends_array.len() { + return false; + } + + if lhs_values_array.len() != rhs_values_array.len() { + return false; + } + + // check run ends array are equal. The length of the physical array + // is used to validate the child arrays. + let run_ends_equal = equal_range( + lhs_run_ends_array, + rhs_run_ends_array, + lhs_start, + rhs_start, + lhs_run_ends_array.len(), + ); + + // if run ends array are not the same return early without validating + // values array. + if !run_ends_equal { + return false; + } + + // check values array are equal + equal_range( + lhs_values_array, + rhs_values_array, + lhs_start, + rhs_start, + rhs_values_array.len(), + ) +} diff --git a/arrow/src/array/equal/structure.rs b/arrow-data/src/equal/structure.rs similarity index 74% rename from arrow/src/array/equal/structure.rs rename to arrow-data/src/equal/structure.rs index 0f943e40cac6..e4751c26f489 100644 --- a/arrow/src/array/equal/structure.rs +++ b/arrow-data/src/equal/structure.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::{array::data::count_nulls, array::ArrayData, util::bit_util::get_bit}; +use crate::data::{contains_nulls, ArrayData}; use super::equal_range; @@ -43,23 +43,21 @@ pub(super) fn struct_equal( rhs_start: usize, len: usize, ) -> bool { - // we have to recalculate null counts from the null buffers - let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); - let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); - - if lhs_null_count == 0 && rhs_null_count == 0 { + // Only checking one null mask here because by the time the control flow reaches + // this point, the equality of the two masks would have already been verified. + if !contains_nulls(lhs.nulls(), lhs_start, len) { equal_child_values(lhs, rhs, lhs_start, rhs_start, len) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); + let lhs_nulls = lhs.nulls().unwrap(); + let rhs_nulls = rhs.nulls().unwrap(); // with nulls, we need to compare item by item whenever it is not null (0..len).all(|i| { let lhs_pos = lhs_start + i; let rhs_pos = rhs_start + i; // if both struct and child had no null buffers, - let lhs_is_null = !get_bit(lhs_null_bytes, lhs_pos + lhs.offset()); - let rhs_is_null = !get_bit(rhs_null_bytes, rhs_pos + rhs.offset()); + let lhs_is_null = lhs_nulls.is_null(lhs_pos); + let rhs_is_null = rhs_nulls.is_null(rhs_pos); if lhs_is_null != rhs_is_null { return false; diff --git a/arrow/src/array/equal/union.rs b/arrow-data/src/equal/union.rs similarity index 79% rename from arrow/src/array/equal/union.rs rename to arrow-data/src/equal/union.rs index e8b9d27b6f0f..62de276e507f 100644 --- a/arrow/src/array/equal/union.rs +++ b/arrow-data/src/equal/union.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::{array::ArrayData, datatypes::DataType, datatypes::UnionMode}; +use crate::data::ArrayData; +use arrow_schema::{DataType, UnionFields, UnionMode}; use super::equal_range; @@ -27,8 +28,8 @@ fn equal_dense( rhs_type_ids: &[i8], lhs_offsets: &[i32], rhs_offsets: &[i32], - lhs_field_type_ids: &[i8], - rhs_field_type_ids: &[i8], + lhs_fields: &UnionFields, + rhs_fields: &UnionFields, ) -> bool { let offsets = lhs_offsets.iter().zip(rhs_offsets.iter()); @@ -37,13 +38,13 @@ fn equal_dense( .zip(rhs_type_ids.iter()) .zip(offsets) .all(|((l_type_id, r_type_id), (l_offset, r_offset))| { - let lhs_child_index = lhs_field_type_ids + let lhs_child_index = lhs_fields .iter() - .position(|r| r == l_type_id) + .position(|(r, _)| r == *l_type_id) .unwrap(); - let rhs_child_index = rhs_field_type_ids + let rhs_child_index = rhs_fields .iter() - .position(|r| r == r_type_id) + .position(|(r, _)| r == *r_type_id) .unwrap(); let lhs_values = &lhs.child_data()[lhs_child_index]; let rhs_values = &rhs.child_data()[rhs_child_index]; @@ -69,7 +70,13 @@ fn equal_sparse( .iter() .zip(rhs.child_data()) .all(|(lhs_values, rhs_values)| { - equal_range(lhs_values, rhs_values, lhs_start, rhs_start, len) + equal_range( + lhs_values, + rhs_values, + lhs_start + lhs.offset(), + rhs_start + rhs.offset(), + len, + ) }) } @@ -88,8 +95,8 @@ pub(super) fn union_equal( match (lhs.data_type(), rhs.data_type()) { ( - DataType::Union(_, lhs_type_ids, UnionMode::Dense), - DataType::Union(_, rhs_type_ids, UnionMode::Dense), + DataType::Union(lhs_fields, UnionMode::Dense), + DataType::Union(rhs_fields, UnionMode::Dense), ) => { let lhs_offsets = lhs.buffer::(1); let rhs_offsets = rhs.buffer::(1); @@ -105,14 +112,11 @@ pub(super) fn union_equal( rhs_type_id_range, lhs_offsets_range, rhs_offsets_range, - lhs_type_ids, - rhs_type_ids, + lhs_fields, + rhs_fields, ) } - ( - DataType::Union(_, _, UnionMode::Sparse), - DataType::Union(_, _, UnionMode::Sparse), - ) => { + (DataType::Union(_, UnionMode::Sparse), DataType::Union(_, UnionMode::Sparse)) => { lhs_type_id_range == rhs_type_id_range && equal_sparse(lhs, rhs, lhs_start, rhs_start, len) } diff --git a/arrow/src/array/equal/utils.rs b/arrow-data/src/equal/utils.rs similarity index 75% rename from arrow/src/array/equal/utils.rs rename to arrow-data/src/equal/utils.rs index 449055d366ec..cc81943756d2 100644 --- a/arrow/src/array/equal/utils.rs +++ b/arrow-data/src/equal/utils.rs @@ -15,10 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::array::data::contains_nulls; -use crate::array::ArrayData; -use crate::datatypes::DataType; -use crate::util::bit_chunk_iterator::BitChunks; +use crate::data::{contains_nulls, ArrayData}; +use arrow_buffer::bit_chunk_iterator::BitChunks; +use arrow_schema::DataType; // whether bits along the positions are equal // `lhs_start`, `rhs_start` and `len` are _measured in bits_. @@ -30,16 +29,9 @@ pub(super) fn equal_bits( rhs_start: usize, len: usize, ) -> bool { - let lhs = BitChunks::new(lhs_values, lhs_start, len); - let rhs = BitChunks::new(rhs_values, rhs_start, len); - - for (a, b) in lhs.iter().zip(rhs.iter()) { - if a != b { - return false; - } - } - - lhs.remainder_bits() == rhs.remainder_bits() + let lhs = BitChunks::new(lhs_values, lhs_start, len).iter_padded(); + let rhs = BitChunks::new(rhs_values, rhs_start, len).iter_padded(); + lhs.zip(rhs).all(|(a, b)| a == b) } #[inline] @@ -50,15 +42,16 @@ pub(super) fn equal_nulls( rhs_start: usize, len: usize, ) -> bool { - let lhs_offset = lhs_start + lhs.offset(); - let rhs_offset = rhs_start + rhs.offset(); - - match (lhs.null_buffer(), rhs.null_buffer()) { - (Some(lhs), Some(rhs)) => { - equal_bits(lhs.as_slice(), rhs.as_slice(), lhs_offset, rhs_offset, len) - } - (Some(lhs), None) => !contains_nulls(Some(lhs), lhs_offset, len), - (None, Some(rhs)) => !contains_nulls(Some(rhs), rhs_offset, len), + match (lhs.nulls(), rhs.nulls()) { + (Some(lhs), Some(rhs)) => equal_bits( + lhs.validity(), + rhs.validity(), + lhs.offset() + lhs_start, + rhs.offset() + rhs_start, + len, + ), + (Some(lhs), None) => !contains_nulls(Some(lhs), lhs_start, len), + (None, Some(rhs)) => !contains_nulls(Some(rhs), rhs_start, len), (None, None) => true, } } @@ -66,7 +59,7 @@ pub(super) fn equal_nulls( #[inline] pub(super) fn base_equal(lhs: &ArrayData, rhs: &ArrayData) -> bool { let equal_type = match (lhs.data_type(), rhs.data_type()) { - (DataType::Union(l_fields, _, l_mode), DataType::Union(r_fields, _, r_mode)) => { + (DataType::Union(l_fields, l_mode), DataType::Union(r_fields, r_mode)) => { l_fields == r_fields && l_mode == r_mode } (DataType::Map(l_field, l_sorted), DataType::Map(r_field, r_sorted)) => { @@ -80,11 +73,9 @@ pub(super) fn base_equal(lhs: &ArrayData, rhs: &ArrayData) -> bool { let r_value_field = r_fields.get(1).unwrap(); // We don't enforce the equality of field names - let data_type_equal = l_key_field.data_type() - == r_key_field.data_type() + let data_type_equal = l_key_field.data_type() == r_key_field.data_type() && l_value_field.data_type() == r_value_field.data_type(); - let nullability_equal = l_key_field.is_nullable() - == r_key_field.is_nullable() + let nullability_equal = l_key_field.is_nullable() == r_key_field.is_nullable() && l_value_field.is_nullable() == r_value_field.is_nullable(); let metadata_equal = l_key_field.metadata() == r_key_field.metadata() && l_value_field.metadata() == r_value_field.metadata(); diff --git a/arrow/src/array/equal/variable_size.rs b/arrow-data/src/equal/variable_size.rs similarity index 71% rename from arrow/src/array/equal/variable_size.rs rename to arrow-data/src/equal/variable_size.rs index f40f79e404ac..92f00818b4a0 100644 --- a/arrow/src/array/equal/variable_size.rs +++ b/arrow-data/src/equal/variable_size.rs @@ -15,15 +15,13 @@ // specific language governing permissions and limitations // under the License. -use crate::util::bit_util::get_bit; -use crate::{ - array::data::count_nulls, - array::{ArrayData, OffsetSizeTrait}, -}; +use crate::data::{contains_nulls, ArrayData}; +use arrow_buffer::ArrowNativeType; +use num::Integer; use super::utils::equal_len; -fn offset_value_equal( +fn offset_value_equal( lhs_values: &[u8], rhs_values: &[u8], lhs_offsets: &[T], @@ -32,8 +30,8 @@ fn offset_value_equal( rhs_pos: usize, len: usize, ) -> bool { - let lhs_start = lhs_offsets[lhs_pos].to_usize().unwrap(); - let rhs_start = rhs_offsets[rhs_pos].to_usize().unwrap(); + let lhs_start = lhs_offsets[lhs_pos].as_usize(); + let rhs_start = rhs_offsets[rhs_pos].as_usize(); let lhs_len = lhs_offsets[lhs_pos + len] - lhs_offsets[lhs_pos]; let rhs_len = rhs_offsets[rhs_pos + len] - rhs_offsets[rhs_pos]; @@ -47,7 +45,7 @@ fn offset_value_equal( ) } -pub(super) fn variable_sized_equal( +pub(super) fn variable_sized_equal( lhs: &ArrayData, rhs: &ArrayData, lhs_start: usize, @@ -61,14 +59,9 @@ pub(super) fn variable_sized_equal( let lhs_values = lhs.buffers()[1].as_slice(); let rhs_values = rhs.buffers()[1].as_slice(); - let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); - let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); - - if lhs_null_count == 0 - && rhs_null_count == 0 - && !lhs_values.is_empty() - && !rhs_values.is_empty() - { + // Only checking one null mask here because by the time the control flow reaches + // this point, the equality of the two masks would have already been verified. + if !contains_nulls(lhs.nulls(), lhs_start, len) { offset_value_equal( lhs_values, rhs_values, @@ -84,15 +77,8 @@ pub(super) fn variable_sized_equal( let rhs_pos = rhs_start + i; // the null bits can still be `None`, indicating that the value is valid. - let lhs_is_null = !lhs - .null_buffer() - .map(|v| get_bit(v.as_slice(), lhs.offset() + lhs_pos)) - .unwrap_or(true); - - let rhs_is_null = !rhs - .null_buffer() - .map(|v| get_bit(v.as_slice(), rhs.offset() + rhs_pos)) - .unwrap_or(true); + let lhs_is_null = lhs.nulls().map(|v| v.is_null(lhs_pos)).unwrap_or_default(); + let rhs_is_null = rhs.nulls().map(|v| v.is_null(rhs_pos)).unwrap_or_default(); lhs_is_null || (lhs_is_null == rhs_is_null) diff --git a/arrow-data/src/ffi.rs b/arrow-data/src/ffi.rs new file mode 100644 index 000000000000..589f7dac6d19 --- /dev/null +++ b/arrow-data/src/ffi.rs @@ -0,0 +1,329 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Contains declarations to bind to the [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html). + +use crate::bit_mask::set_bits; +use crate::{layout, ArrayData}; +use arrow_buffer::buffer::NullBuffer; +use arrow_buffer::{Buffer, MutableBuffer}; +use arrow_schema::DataType; +use std::ffi::c_void; + +/// ABI-compatible struct for ArrowArray from C Data Interface +/// See +/// +/// ``` +/// # use arrow_data::ArrayData; +/// # use arrow_data::ffi::FFI_ArrowArray; +/// fn export_array(array: &ArrayData) -> FFI_ArrowArray { +/// FFI_ArrowArray::new(array) +/// } +/// ``` +#[repr(C)] +#[derive(Debug)] +pub struct FFI_ArrowArray { + length: i64, + null_count: i64, + offset: i64, + n_buffers: i64, + n_children: i64, + buffers: *mut *const c_void, + children: *mut *mut FFI_ArrowArray, + dictionary: *mut FFI_ArrowArray, + release: Option, + // When exported, this MUST contain everything that is owned by this array. + // for example, any buffer pointed to in `buffers` must be here, as well + // as the `buffers` pointer itself. + // In other words, everything in [FFI_ArrowArray] must be owned by + // `private_data` and can assume that they do not outlive `private_data`. + private_data: *mut c_void, +} + +impl Drop for FFI_ArrowArray { + fn drop(&mut self) { + match self.release { + None => (), + Some(release) => unsafe { release(self) }, + }; + } +} + +unsafe impl Send for FFI_ArrowArray {} +unsafe impl Sync for FFI_ArrowArray {} + +// callback used to drop [FFI_ArrowArray] when it is exported +unsafe extern "C" fn release_array(array: *mut FFI_ArrowArray) { + if array.is_null() { + return; + } + let array = &mut *array; + + // take ownership of `private_data`, therefore dropping it` + let private = Box::from_raw(array.private_data as *mut ArrayPrivateData); + for child in private.children.iter() { + let _ = Box::from_raw(*child); + } + if !private.dictionary.is_null() { + let _ = Box::from_raw(private.dictionary); + } + + array.release = None; +} + +/// Aligns the provided `nulls` to the provided `data_offset` +/// +/// This is a temporary measure until offset is removed from ArrayData (#1799) +fn align_nulls(data_offset: usize, nulls: Option<&NullBuffer>) -> Option { + let nulls = nulls?; + if data_offset == nulls.offset() { + // Underlying buffer is already aligned + return Some(nulls.buffer().clone()); + } + if data_offset == 0 { + return Some(nulls.inner().sliced()); + } + let mut builder = MutableBuffer::new_null(data_offset + nulls.len()); + set_bits( + builder.as_slice_mut(), + nulls.validity(), + data_offset, + nulls.offset(), + nulls.len(), + ); + Some(builder.into()) +} + +struct ArrayPrivateData { + #[allow(dead_code)] + buffers: Vec>, + buffers_ptr: Box<[*const c_void]>, + children: Box<[*mut FFI_ArrowArray]>, + dictionary: *mut FFI_ArrowArray, +} + +impl FFI_ArrowArray { + /// creates a new `FFI_ArrowArray` from existing data. + pub fn new(data: &ArrayData) -> Self { + let data_layout = layout(data.data_type()); + + let buffers = if data_layout.can_contain_null_mask { + // * insert the null buffer at the start + // * make all others `Option`. + std::iter::once(align_nulls(data.offset(), data.nulls())) + .chain(data.buffers().iter().map(|b| Some(b.clone()))) + .collect::>() + } else { + data.buffers().iter().map(|b| Some(b.clone())).collect() + }; + + // `n_buffers` is the number of buffers by the spec. + let n_buffers = { + data_layout.buffers.len() + { + // If the layout has a null buffer by Arrow spec. + // Note that even the array doesn't have a null buffer because it has + // no null value, we still need to count 1 here to follow the spec. + usize::from(data_layout.can_contain_null_mask) + } + } as i64; + + let buffers_ptr = buffers + .iter() + .flat_map(|maybe_buffer| match maybe_buffer { + // note that `raw_data` takes into account the buffer's offset + Some(b) => Some(b.as_ptr() as *const c_void), + // This is for null buffer. We only put a null pointer for + // null buffer if by spec it can contain null mask. + None if data_layout.can_contain_null_mask => Some(std::ptr::null()), + None => None, + }) + .collect::>(); + + let empty = vec![]; + let (child_data, dictionary) = match data.data_type() { + DataType::Dictionary(_, _) => ( + empty.as_slice(), + Box::into_raw(Box::new(FFI_ArrowArray::new(&data.child_data()[0]))), + ), + _ => (data.child_data(), std::ptr::null_mut()), + }; + + let children = child_data + .iter() + .map(|child| Box::into_raw(Box::new(FFI_ArrowArray::new(child)))) + .collect::>(); + let n_children = children.len() as i64; + + // As in the IPC format, emit null_count = length for Null type + let null_count = match data.data_type() { + DataType::Null => data.len(), + _ => data.null_count(), + }; + + // create the private data owning everything. + // any other data must be added here, e.g. via a struct, to track lifetime. + let mut private_data = Box::new(ArrayPrivateData { + buffers, + buffers_ptr, + children, + dictionary, + }); + + Self { + length: data.len() as i64, + null_count: null_count as i64, + offset: data.offset() as i64, + n_buffers, + n_children, + buffers: private_data.buffers_ptr.as_mut_ptr(), + children: private_data.children.as_mut_ptr(), + dictionary, + release: Some(release_array), + private_data: Box::into_raw(private_data) as *mut c_void, + } + } + + /// Takes ownership of the pointed to [`FFI_ArrowArray`] + /// + /// This acts to [move] the data out of `array`, setting the release callback to NULL + /// + /// # Safety + /// + /// * `array` must be [valid] for reads and writes + /// * `array` must be properly aligned + /// * `array` must point to a properly initialized value of [`FFI_ArrowArray`] + /// + /// [move]: https://arrow.apache.org/docs/format/CDataInterface.html#moving-an-array + /// [valid]: https://doc.rust-lang.org/std/ptr/index.html#safety + pub unsafe fn from_raw(array: *mut FFI_ArrowArray) -> Self { + std::ptr::replace(array, Self::empty()) + } + + /// create an empty `FFI_ArrowArray`, which can be used to import data into + pub fn empty() -> Self { + Self { + length: 0, + null_count: 0, + offset: 0, + n_buffers: 0, + n_children: 0, + buffers: std::ptr::null_mut(), + children: std::ptr::null_mut(), + dictionary: std::ptr::null_mut(), + release: None, + private_data: std::ptr::null_mut(), + } + } + + /// the length of the array + #[inline] + pub fn len(&self) -> usize { + self.length as usize + } + + /// whether the array is empty + #[inline] + pub fn is_empty(&self) -> bool { + self.length == 0 + } + + /// Whether the array has been released + #[inline] + pub fn is_released(&self) -> bool { + self.release.is_none() + } + + /// the offset of the array + #[inline] + pub fn offset(&self) -> usize { + self.offset as usize + } + + /// the null count of the array + #[inline] + pub fn null_count(&self) -> usize { + self.null_count as usize + } + + /// Returns the buffer at the provided index + /// + /// # Panic + /// Panics if index exceeds the number of buffers or the buffer is not correctly aligned + #[inline] + pub fn buffer(&self, index: usize) -> *const u8 { + assert!(!self.buffers.is_null()); + assert!(index < self.num_buffers()); + // SAFETY: + // If buffers is not null must be valid for reads up to num_buffers + unsafe { std::ptr::read_unaligned((self.buffers as *mut *const u8).add(index)) } + } + + /// Returns the number of buffers + #[inline] + pub fn num_buffers(&self) -> usize { + self.n_buffers as _ + } + + /// Returns the child at the provided index + #[inline] + pub fn child(&self, index: usize) -> &FFI_ArrowArray { + assert!(!self.children.is_null()); + assert!(index < self.num_children()); + // Safety: + // If children is not null must be valid for reads up to num_children + unsafe { + let child = std::ptr::read_unaligned(self.children.add(index)); + child.as_ref().unwrap() + } + } + + /// Returns the number of children + #[inline] + pub fn num_children(&self) -> usize { + self.n_children as _ + } + + /// Returns the dictionary if any + #[inline] + pub fn dictionary(&self) -> Option<&Self> { + // Safety: + // If dictionary is not null should be valid for reads of `Self` + unsafe { self.dictionary.as_ref() } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // More tests located in top-level arrow crate + + #[test] + fn null_array_n_buffers() { + let data = ArrayData::new_null(&DataType::Null, 10); + + let ffi_array = FFI_ArrowArray::new(&data); + assert_eq!(0, ffi_array.n_buffers); + + let private_data = + unsafe { Box::from_raw(ffi_array.private_data as *mut ArrayPrivateData) }; + + assert_eq!(0, private_data.buffers_ptr.len()); + + Box::into_raw(private_data); + } +} diff --git a/arrow-data/src/lib.rs b/arrow-data/src/lib.rs new file mode 100644 index 000000000000..cfa0dba66c35 --- /dev/null +++ b/arrow-data/src/lib.rs @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Low-level array data abstractions for [Apache Arrow Rust](https://docs.rs/arrow) +//! +//! For a higher-level, strongly-typed interface see [arrow_array](https://docs.rs/arrow_array) + +mod data; +pub use data::*; + +mod equal; +pub mod transform; + +pub use arrow_buffer::{bit_iterator, bit_mask}; +pub mod decimal; + +#[cfg(feature = "ffi")] +pub mod ffi; diff --git a/arrow/src/array/transform/boolean.rs b/arrow-data/src/transform/boolean.rs similarity index 95% rename from arrow/src/array/transform/boolean.rs rename to arrow-data/src/transform/boolean.rs index e0b6231a226e..d93fa15a4e0f 100644 --- a/arrow/src/array/transform/boolean.rs +++ b/arrow-data/src/transform/boolean.rs @@ -16,8 +16,8 @@ // under the License. use super::{Extend, _MutableArrayData, utils::resize_for_bits}; -use crate::array::ArrayData; -use crate::util::bit_mask::set_bits; +use crate::bit_mask::set_bits; +use crate::ArrayData; pub(super) fn build_extend(array: &ArrayData) -> Extend { let values = array.buffers()[0].as_slice(); diff --git a/arrow/src/array/transform/fixed_binary.rs b/arrow-data/src/transform/fixed_binary.rs similarity index 54% rename from arrow/src/array/transform/fixed_binary.rs rename to arrow-data/src/transform/fixed_binary.rs index 6d6262ca3c4e..44c6f46ebf7e 100644 --- a/arrow/src/array/transform/fixed_binary.rs +++ b/arrow-data/src/transform/fixed_binary.rs @@ -15,50 +15,28 @@ // specific language governing permissions and limitations // under the License. -use crate::{array::ArrayData, datatypes::DataType}; - use super::{Extend, _MutableArrayData}; +use crate::ArrayData; +use arrow_schema::DataType; pub(super) fn build_extend(array: &ArrayData) -> Extend { let size = match array.data_type() { DataType::FixedSizeBinary(i) => *i as usize, - DataType::Decimal256(_, _) => 32, _ => unreachable!(), }; let values = &array.buffers()[0].as_slice()[array.offset() * size..]; - if array.null_count() == 0 { - // fast case where we can copy regions without null issues - Box::new( - move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { - let buffer = &mut mutable.buffer1; - buffer.extend_from_slice(&values[start * size..(start + len) * size]); - }, - ) - } else { - Box::new( - move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { - // nulls present: append item by item, ignoring null entries - let values_buffer = &mut mutable.buffer1; - - (start..start + len).for_each(|i| { - if array.is_valid(i) { - // append value - let bytes = &values[i * size..(i + 1) * size]; - values_buffer.extend_from_slice(bytes); - } else { - values_buffer.extend_zeros(size); - } - }) - }, - ) - } + Box::new( + move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { + let buffer = &mut mutable.buffer1; + buffer.extend_from_slice(&values[start * size..(start + len) * size]); + }, + ) } pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { let size = match mutable.data_type { DataType::FixedSizeBinary(i) => i as usize, - DataType::Decimal256(_, _) => 32, _ => unreachable!(), }; diff --git a/arrow/src/array/transform/fixed_size_list.rs b/arrow-data/src/transform/fixed_size_list.rs similarity index 53% rename from arrow/src/array/transform/fixed_size_list.rs rename to arrow-data/src/transform/fixed_size_list.rs index 77912a7026fd..8eef7bce9bb3 100644 --- a/arrow/src/array/transform/fixed_size_list.rs +++ b/arrow-data/src/transform/fixed_size_list.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::array::ArrayData; -use crate::datatypes::DataType; +use crate::ArrayData; +use arrow_schema::DataType; use super::{Extend, _MutableArrayData}; @@ -26,38 +26,14 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { _ => unreachable!(), }; - if array.null_count() == 0 { - Box::new( - move |mutable: &mut _MutableArrayData, - index: usize, - start: usize, - len: usize| { - mutable.child_data.iter_mut().for_each(|child| { - child.extend(index, start * size, (start + len) * size) - }) - }, - ) - } else { - Box::new( - move |mutable: &mut _MutableArrayData, - index: usize, - start: usize, - len: usize| { - (start..start + len).for_each(|i| { - if array.is_valid(i) { - mutable.child_data.iter_mut().for_each(|child| { - child.extend(index, i * size, (i + 1) * size) - }) - } else { - mutable - .child_data - .iter_mut() - .for_each(|child| child.extend_nulls(size)) - } - }) - }, - ) - } + Box::new( + move |mutable: &mut _MutableArrayData, index: usize, start: usize, len: usize| { + mutable + .child_data + .iter_mut() + .for_each(|child| child.extend(index, start * size, (start + len) * size)) + }, + ) } pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { diff --git a/arrow-data/src/transform/list.rs b/arrow-data/src/transform/list.rs new file mode 100644 index 000000000000..d9a1c62a8e8e --- /dev/null +++ b/arrow-data/src/transform/list.rs @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::{ + Extend, _MutableArrayData, + utils::{extend_offsets, get_last_offset}, +}; +use crate::ArrayData; +use arrow_buffer::ArrowNativeType; +use num::{CheckedAdd, Integer}; + +pub(super) fn build_extend(array: &ArrayData) -> Extend { + let offsets = array.buffer::(0); + Box::new( + move |mutable: &mut _MutableArrayData, index: usize, start: usize, len: usize| { + let offset_buffer = &mut mutable.buffer1; + + // this is safe due to how offset is built. See details on `get_last_offset` + let last_offset: T = unsafe { get_last_offset(offset_buffer) }; + + // offsets + extend_offsets::(offset_buffer, last_offset, &offsets[start..start + len + 1]); + + mutable.child_data[0].extend( + index, + offsets[start].as_usize(), + offsets[start + len].as_usize(), + ) + }, + ) +} + +pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { + let offset_buffer = &mut mutable.buffer1; + + // this is safe due to how offset is built. See details on `get_last_offset` + let last_offset: T = unsafe { get_last_offset(offset_buffer) }; + + (0..len).for_each(|_| offset_buffer.push(last_offset)) +} diff --git a/arrow-data/src/transform/mod.rs b/arrow-data/src/transform/mod.rs new file mode 100644 index 000000000000..268cf10f2326 --- /dev/null +++ b/arrow-data/src/transform/mod.rs @@ -0,0 +1,694 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::{ + data::{into_buffers, new_buffers}, + ArrayData, ArrayDataBuilder, +}; +use crate::bit_mask::set_bits; +use arrow_buffer::buffer::{BooleanBuffer, NullBuffer}; +use arrow_buffer::{bit_util, i256, ArrowNativeType, MutableBuffer}; +use arrow_schema::{ArrowError, DataType, IntervalUnit, UnionMode}; +use half::f16; +use num::Integer; +use std::mem; + +mod boolean; +mod fixed_binary; +mod fixed_size_list; +mod list; +mod null; +mod primitive; +mod structure; +mod union; +mod utils; +mod variable_size; + +type ExtendNullBits<'a> = Box; +// function that extends `[start..start+len]` to the mutable array. +// this is dynamic because different data_types influence how buffers and children are extended. +type Extend<'a> = Box; + +type ExtendNulls = Box; + +/// A mutable [ArrayData] that knows how to freeze itself into an [ArrayData]. +/// This is just a data container. +#[derive(Debug)] +struct _MutableArrayData<'a> { + pub data_type: DataType, + pub null_count: usize, + + pub len: usize, + pub null_buffer: Option, + + // arrow specification only allows up to 3 buffers (2 ignoring the nulls above). + // Thus, we place them in the stack to avoid bound checks and greater data locality. + pub buffer1: MutableBuffer, + pub buffer2: MutableBuffer, + pub child_data: Vec>, +} + +impl<'a> _MutableArrayData<'a> { + fn null_buffer(&mut self) -> &mut MutableBuffer { + self.null_buffer + .as_mut() + .expect("MutableArrayData not nullable") + } + + fn freeze(self, dictionary: Option) -> ArrayDataBuilder { + let buffers = into_buffers(&self.data_type, self.buffer1, self.buffer2); + + let child_data = match self.data_type { + DataType::Dictionary(_, _) => vec![dictionary.unwrap()], + _ => { + let mut child_data = Vec::with_capacity(self.child_data.len()); + for child in self.child_data { + child_data.push(child.freeze()); + } + child_data + } + }; + + let nulls = self + .null_buffer + .map(|nulls| { + let bools = BooleanBuffer::new(nulls.into(), 0, self.len); + unsafe { NullBuffer::new_unchecked(bools, self.null_count) } + }) + .filter(|n| n.null_count() > 0); + + ArrayDataBuilder::new(self.data_type) + .offset(0) + .len(self.len) + .nulls(nulls) + .buffers(buffers) + .child_data(child_data) + } +} + +fn build_extend_null_bits(array: &ArrayData, use_nulls: bool) -> ExtendNullBits { + if let Some(nulls) = array.nulls() { + let bytes = nulls.validity(); + Box::new(move |mutable, start, len| { + let mutable_len = mutable.len; + let out = mutable.null_buffer(); + utils::resize_for_bits(out, mutable_len + len); + mutable.null_count += set_bits( + out.as_slice_mut(), + bytes, + mutable_len, + nulls.offset() + start, + len, + ); + }) + } else if use_nulls { + Box::new(|mutable, _, len| { + let mutable_len = mutable.len; + let out = mutable.null_buffer(); + utils::resize_for_bits(out, mutable_len + len); + let write_data = out.as_slice_mut(); + (0..len).for_each(|i| { + bit_util::set_bit(write_data, mutable_len + i); + }); + }) + } else { + Box::new(|_, _, _| {}) + } +} + +/// Struct to efficiently and interactively create an [ArrayData] from an existing [ArrayData] by +/// copying chunks. +/// +/// The main use case of this struct is to perform unary operations to arrays of arbitrary types, +/// such as `filter` and `take`. +pub struct MutableArrayData<'a> { + #[allow(dead_code)] + arrays: Vec<&'a ArrayData>, + // The attributes in [_MutableArrayData] cannot be in [MutableArrayData] due to + // mutability invariants (interior mutability): + // [MutableArrayData] contains a function that can only mutate [_MutableArrayData], not + // [MutableArrayData] itself + data: _MutableArrayData<'a>, + + // the child data of the `Array` in Dictionary arrays. + // This is not stored in `MutableArrayData` because these values constant and only needed + // at the end, when freezing [_MutableArrayData]. + dictionary: Option, + + // function used to extend values from arrays. This function's lifetime is bound to the array + // because it reads values from it. + extend_values: Vec>, + // function used to extend nulls from arrays. This function's lifetime is bound to the array + // because it reads nulls from it. + extend_null_bits: Vec>, + + // function used to extend nulls. + // this is independent of the arrays and therefore has no lifetime. + extend_nulls: ExtendNulls, +} + +impl<'a> std::fmt::Debug for MutableArrayData<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + // ignores the closures. + f.debug_struct("MutableArrayData") + .field("data", &self.data) + .finish() + } +} + +/// Builds an extend that adds `offset` to the source primitive +/// Additionally validates that `max` fits into the +/// the underlying primitive returning None if not +fn build_extend_dictionary(array: &ArrayData, offset: usize, max: usize) -> Option { + macro_rules! validate_and_build { + ($dt: ty) => {{ + let _: $dt = max.try_into().ok()?; + let offset: $dt = offset.try_into().ok()?; + Some(primitive::build_extend_with_offset(array, offset)) + }}; + } + match array.data_type() { + DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() { + DataType::UInt8 => validate_and_build!(u8), + DataType::UInt16 => validate_and_build!(u16), + DataType::UInt32 => validate_and_build!(u32), + DataType::UInt64 => validate_and_build!(u64), + DataType::Int8 => validate_and_build!(i8), + DataType::Int16 => validate_and_build!(i16), + DataType::Int32 => validate_and_build!(i32), + DataType::Int64 => validate_and_build!(i64), + _ => unreachable!(), + }, + _ => None, + } +} + +fn build_extend(array: &ArrayData) -> Extend { + match array.data_type() { + DataType::Null => null::build_extend(array), + DataType::Boolean => boolean::build_extend(array), + DataType::UInt8 => primitive::build_extend::(array), + DataType::UInt16 => primitive::build_extend::(array), + DataType::UInt32 => primitive::build_extend::(array), + DataType::UInt64 => primitive::build_extend::(array), + DataType::Int8 => primitive::build_extend::(array), + DataType::Int16 => primitive::build_extend::(array), + DataType::Int32 => primitive::build_extend::(array), + DataType::Int64 => primitive::build_extend::(array), + DataType::Float32 => primitive::build_extend::(array), + DataType::Float64 => primitive::build_extend::(array), + DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => { + primitive::build_extend::(array) + } + DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) + | DataType::Interval(IntervalUnit::DayTime) => primitive::build_extend::(array), + DataType::Interval(IntervalUnit::MonthDayNano) => primitive::build_extend::(array), + DataType::Decimal128(_, _) => primitive::build_extend::(array), + DataType::Decimal256(_, _) => primitive::build_extend::(array), + DataType::Utf8 | DataType::Binary => variable_size::build_extend::(array), + DataType::LargeUtf8 | DataType::LargeBinary => variable_size::build_extend::(array), + DataType::Map(_, _) | DataType::List(_) => list::build_extend::(array), + DataType::LargeList(_) => list::build_extend::(array), + DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"), + DataType::Struct(_) => structure::build_extend(array), + DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array), + DataType::Float16 => primitive::build_extend::(array), + DataType::FixedSizeList(_, _) => fixed_size_list::build_extend(array), + DataType::Union(_, mode) => match mode { + UnionMode::Sparse => union::build_extend_sparse(array), + UnionMode::Dense => union::build_extend_dense(array), + }, + DataType::RunEndEncoded(_, _) => todo!(), + } +} + +fn build_extend_nulls(data_type: &DataType) -> ExtendNulls { + Box::new(match data_type { + DataType::Null => null::extend_nulls, + DataType::Boolean => boolean::extend_nulls, + DataType::UInt8 => primitive::extend_nulls::, + DataType::UInt16 => primitive::extend_nulls::, + DataType::UInt32 => primitive::extend_nulls::, + DataType::UInt64 => primitive::extend_nulls::, + DataType::Int8 => primitive::extend_nulls::, + DataType::Int16 => primitive::extend_nulls::, + DataType::Int32 => primitive::extend_nulls::, + DataType::Int64 => primitive::extend_nulls::, + DataType::Float32 => primitive::extend_nulls::, + DataType::Float64 => primitive::extend_nulls::, + DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => { + primitive::extend_nulls:: + } + DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) + | DataType::Interval(IntervalUnit::DayTime) => primitive::extend_nulls::, + DataType::Interval(IntervalUnit::MonthDayNano) => primitive::extend_nulls::, + DataType::Decimal128(_, _) => primitive::extend_nulls::, + DataType::Decimal256(_, _) => primitive::extend_nulls::, + DataType::Utf8 | DataType::Binary => variable_size::extend_nulls::, + DataType::LargeUtf8 | DataType::LargeBinary => variable_size::extend_nulls::, + DataType::Map(_, _) | DataType::List(_) => list::extend_nulls::, + DataType::LargeList(_) => list::extend_nulls::, + DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() { + DataType::UInt8 => primitive::extend_nulls::, + DataType::UInt16 => primitive::extend_nulls::, + DataType::UInt32 => primitive::extend_nulls::, + DataType::UInt64 => primitive::extend_nulls::, + DataType::Int8 => primitive::extend_nulls::, + DataType::Int16 => primitive::extend_nulls::, + DataType::Int32 => primitive::extend_nulls::, + DataType::Int64 => primitive::extend_nulls::, + _ => unreachable!(), + }, + DataType::Struct(_) => structure::extend_nulls, + DataType::FixedSizeBinary(_) => fixed_binary::extend_nulls, + DataType::Float16 => primitive::extend_nulls::, + DataType::FixedSizeList(_, _) => fixed_size_list::extend_nulls, + DataType::Union(_, mode) => match mode { + UnionMode::Sparse => union::extend_nulls_sparse, + UnionMode::Dense => union::extend_nulls_dense, + }, + DataType::RunEndEncoded(_, _) => todo!(), + }) +} + +fn preallocate_offset_and_binary_buffer( + capacity: usize, + binary_size: usize, +) -> [MutableBuffer; 2] { + // offsets + let mut buffer = MutableBuffer::new((1 + capacity) * mem::size_of::()); + // safety: `unsafe` code assumes that this buffer is initialized with one element + buffer.push(Offset::zero()); + + [ + buffer, + MutableBuffer::new(binary_size * mem::size_of::()), + ] +} + +/// Define capacities of child data or data buffers. +#[derive(Debug, Clone)] +pub enum Capacities { + /// Binary, Utf8 and LargeUtf8 data types + /// Define + /// * the capacity of the array offsets + /// * the capacity of the binary/ str buffer + Binary(usize, Option), + /// List and LargeList data types + /// Define + /// * the capacity of the array offsets + /// * the capacity of the child data + List(usize, Option>), + /// Struct type + /// * the capacity of the array + /// * the capacities of the fields + Struct(usize, Option>), + /// Dictionary type + /// * the capacity of the array/keys + /// * the capacity of the values + Dictionary(usize, Option>), + /// Don't preallocate inner buffers and rely on array growth strategy + Array(usize), +} +impl<'a> MutableArrayData<'a> { + /// returns a new [MutableArrayData] with capacity to `capacity` slots and specialized to create an + /// [ArrayData] from multiple `arrays`. + /// + /// `use_nulls` is a flag used to optimize insertions. It should be `false` if the only source of nulls + /// are the arrays themselves and `true` if the user plans to call [MutableArrayData::extend_nulls]. + /// In other words, if `use_nulls` is `false`, calling [MutableArrayData::extend_nulls] should not be used. + pub fn new(arrays: Vec<&'a ArrayData>, use_nulls: bool, capacity: usize) -> Self { + Self::with_capacities(arrays, use_nulls, Capacities::Array(capacity)) + } + + /// Similar to [MutableArrayData::new], but lets users define the preallocated capacities of the array. + /// See also [MutableArrayData::new] for more information on the arguments. + /// + /// # Panic + /// This function panics if the given `capacities` don't match the data type of `arrays`. Or when + /// a [Capacities] variant is not yet supported. + pub fn with_capacities( + arrays: Vec<&'a ArrayData>, + use_nulls: bool, + capacities: Capacities, + ) -> Self { + let data_type = arrays[0].data_type(); + + for a in arrays.iter().skip(1) { + assert_eq!( + data_type, + a.data_type(), + "Arrays with inconsistent types passed to MutableArrayData" + ) + } + + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. + let use_nulls = use_nulls | arrays.iter().any(|array| array.null_count() > 0); + + let mut array_capacity; + + let [buffer1, buffer2] = match (data_type, &capacities) { + ( + DataType::LargeUtf8 | DataType::LargeBinary, + Capacities::Binary(capacity, Some(value_cap)), + ) => { + array_capacity = *capacity; + preallocate_offset_and_binary_buffer::(*capacity, *value_cap) + } + (DataType::Utf8 | DataType::Binary, Capacities::Binary(capacity, Some(value_cap))) => { + array_capacity = *capacity; + preallocate_offset_and_binary_buffer::(*capacity, *value_cap) + } + (_, Capacities::Array(capacity)) => { + array_capacity = *capacity; + new_buffers(data_type, *capacity) + } + (DataType::List(_) | DataType::LargeList(_), Capacities::List(capacity, _)) => { + array_capacity = *capacity; + new_buffers(data_type, *capacity) + } + _ => panic!("Capacities: {capacities:?} not yet supported"), + }; + + let child_data = match &data_type { + DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + | DataType::Null + | DataType::Boolean + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Timestamp(_, _) + | DataType::Utf8 + | DataType::Binary + | DataType::LargeUtf8 + | DataType::LargeBinary + | DataType::Interval(_) + | DataType::FixedSizeBinary(_) => vec![], + DataType::Map(_, _) | DataType::List(_) | DataType::LargeList(_) => { + let children = arrays + .iter() + .map(|array| &array.child_data()[0]) + .collect::>(); + + let capacities = + if let Capacities::List(capacity, ref child_capacities) = capacities { + child_capacities + .clone() + .map(|c| *c) + .unwrap_or(Capacities::Array(capacity)) + } else { + Capacities::Array(array_capacity) + }; + + vec![MutableArrayData::with_capacities( + children, use_nulls, capacities, + )] + } + // the dictionary type just appends keys and clones the values. + DataType::Dictionary(_, _) => vec![], + DataType::Struct(fields) => match capacities { + Capacities::Struct(capacity, Some(ref child_capacities)) => { + array_capacity = capacity; + (0..fields.len()) + .zip(child_capacities) + .map(|(i, child_cap)| { + let child_arrays = arrays + .iter() + .map(|array| &array.child_data()[i]) + .collect::>(); + MutableArrayData::with_capacities( + child_arrays, + use_nulls, + child_cap.clone(), + ) + }) + .collect::>() + } + Capacities::Struct(capacity, None) => { + array_capacity = capacity; + (0..fields.len()) + .map(|i| { + let child_arrays = arrays + .iter() + .map(|array| &array.child_data()[i]) + .collect::>(); + MutableArrayData::new(child_arrays, use_nulls, capacity) + }) + .collect::>() + } + _ => (0..fields.len()) + .map(|i| { + let child_arrays = arrays + .iter() + .map(|array| &array.child_data()[i]) + .collect::>(); + MutableArrayData::new(child_arrays, use_nulls, array_capacity) + }) + .collect::>(), + }, + DataType::RunEndEncoded(_, _) => { + let run_ends_child = arrays + .iter() + .map(|array| &array.child_data()[0]) + .collect::>(); + let value_child = arrays + .iter() + .map(|array| &array.child_data()[1]) + .collect::>(); + vec![ + MutableArrayData::new(run_ends_child, false, array_capacity), + MutableArrayData::new(value_child, use_nulls, array_capacity), + ] + } + DataType::FixedSizeList(_, _) => { + let children = arrays + .iter() + .map(|array| &array.child_data()[0]) + .collect::>(); + vec![MutableArrayData::new(children, use_nulls, array_capacity)] + } + DataType::Union(fields, _) => (0..fields.len()) + .map(|i| { + let child_arrays = arrays + .iter() + .map(|array| &array.child_data()[i]) + .collect::>(); + MutableArrayData::new(child_arrays, use_nulls, array_capacity) + }) + .collect::>(), + }; + + // Get the dictionary if any, and if it is a concatenation of multiple + let (dictionary, dict_concat) = match &data_type { + DataType::Dictionary(_, _) => { + // If more than one dictionary, concatenate dictionaries together + let dict_concat = !arrays + .windows(2) + .all(|a| a[0].child_data()[0].ptr_eq(&a[1].child_data()[0])); + + match dict_concat { + false => (Some(arrays[0].child_data()[0].clone()), false), + true => { + if let Capacities::Dictionary(_, _) = capacities { + panic!("dictionary capacity not yet supported") + } + let dictionaries: Vec<_> = + arrays.iter().map(|array| &array.child_data()[0]).collect(); + let lengths: Vec<_> = dictionaries + .iter() + .map(|dictionary| dictionary.len()) + .collect(); + let capacity = lengths.iter().sum(); + + let mut mutable = MutableArrayData::new(dictionaries, false, capacity); + + for (i, len) in lengths.iter().enumerate() { + mutable.extend(i, 0, *len) + } + + (Some(mutable.freeze()), true) + } + } + } + _ => (None, false), + }; + + let extend_nulls = build_extend_nulls(data_type); + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(array, use_nulls)) + .collect(); + + let null_buffer = use_nulls.then(|| { + let null_bytes = bit_util::ceil(array_capacity, 8); + MutableBuffer::from_len_zeroed(null_bytes) + }); + + let extend_values = match &data_type { + DataType::Dictionary(_, _) => { + let mut next_offset = 0; + let extend_values: Result, _> = arrays + .iter() + .map(|array| { + let offset = next_offset; + let dict_len = array.child_data()[0].len(); + + if dict_concat { + next_offset += dict_len; + } + + build_extend_dictionary(array, offset, offset + dict_len) + .ok_or(ArrowError::DictionaryKeyOverflowError) + }) + .collect(); + + extend_values.expect("MutableArrayData::new is infallible") + } + _ => arrays.iter().map(|array| build_extend(array)).collect(), + }; + + let data = _MutableArrayData { + data_type: data_type.clone(), + len: 0, + null_count: 0, + null_buffer, + buffer1, + buffer2, + child_data, + }; + Self { + arrays, + data, + dictionary, + extend_values, + extend_null_bits, + extend_nulls, + } + } + + /// Extends this array with a chunk of its source arrays + /// + /// # Arguments + /// * `index` - the index of array that you what to copy values from + /// * `start` - the start index of the chunk (inclusive) + /// * `end` - the end index of the chunk (exclusive) + /// + /// # Panic + /// This function panics if there is an invalid index, + /// i.e. `index` >= the number of source arrays + /// or `end` > the length of the `index`th array + pub fn extend(&mut self, index: usize, start: usize, end: usize) { + let len = end - start; + (self.extend_null_bits[index])(&mut self.data, start, len); + (self.extend_values[index])(&mut self.data, index, start, len); + self.data.len += len; + } + + /// Extends this [MutableArrayData] with null elements, disregarding the bound arrays + /// + /// # Panics + /// + /// Panics if [`MutableArrayData`] not created with `use_nulls` or nullable source arrays + /// + pub fn extend_nulls(&mut self, len: usize) { + self.data.len += len; + let bit_len = bit_util::ceil(self.data.len, 8); + let nulls = self.data.null_buffer(); + nulls.resize(bit_len, 0); + self.data.null_count += len; + (self.extend_nulls)(&mut self.data, len); + } + + /// Returns the current length + #[inline] + pub fn len(&self) -> usize { + self.data.len + } + + /// Returns true if len is 0 + #[inline] + pub fn is_empty(&self) -> bool { + self.data.len == 0 + } + + /// Returns the current null count + #[inline] + pub fn null_count(&self) -> usize { + self.data.null_count + } + + /// Creates a [ArrayData] from the pushed regions up to this point, consuming `self`. + pub fn freeze(self) -> ArrayData { + unsafe { self.data.freeze(self.dictionary).build_unchecked() } + } + + /// Creates a [ArrayDataBuilder] from the pushed regions up to this point, consuming `self`. + /// This is useful for extending the default behavior of MutableArrayData. + pub fn into_builder(self) -> ArrayDataBuilder { + self.data.freeze(self.dictionary) + } +} + +// See arrow/tests/array_transform.rs for tests of transform functionality + +#[cfg(test)] +mod test { + use super::*; + use arrow_schema::Field; + use std::sync::Arc; + + #[test] + fn test_list_append_with_capacities() { + let array = ArrayData::new_empty(&DataType::List(Arc::new(Field::new( + "element", + DataType::Int64, + false, + )))); + + let mutable = MutableArrayData::with_capacities( + vec![&array], + false, + Capacities::List(6, Some(Box::new(Capacities::Array(17)))), + ); + + // capacities are rounded up to multiples of 64 by MutableBuffer + assert_eq!(mutable.data.buffer1.capacity(), 64); + assert_eq!(mutable.data.child_data[0].data.buffer1.capacity(), 192); + } +} diff --git a/arrow/src/array/transform/null.rs b/arrow-data/src/transform/null.rs similarity index 97% rename from arrow/src/array/transform/null.rs rename to arrow-data/src/transform/null.rs index e1335e179713..5d1535564d9e 100644 --- a/arrow/src/array/transform/null.rs +++ b/arrow-data/src/transform/null.rs @@ -15,9 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::array::ArrayData; - use super::{Extend, _MutableArrayData}; +use crate::ArrayData; pub(super) fn build_extend(_: &ArrayData) -> Extend { Box::new(move |_, _, _, _| {}) diff --git a/arrow/src/array/transform/primitive.rs b/arrow-data/src/transform/primitive.rs similarity index 91% rename from arrow/src/array/transform/primitive.rs rename to arrow-data/src/transform/primitive.rs index 4c765c0c0d95..627dc00de1df 100644 --- a/arrow/src/array/transform/primitive.rs +++ b/arrow-data/src/transform/primitive.rs @@ -15,11 +15,11 @@ // specific language governing permissions and limitations // under the License. +use crate::ArrayData; +use arrow_buffer::ArrowNativeType; use std::mem::size_of; use std::ops::Add; -use crate::{array::ArrayData, datatypes::ArrowNativeType}; - use super::{Extend, _MutableArrayData}; pub(super) fn build_extend(array: &ArrayData) -> Extend { @@ -47,9 +47,6 @@ where ) } -pub(super) fn extend_nulls( - mutable: &mut _MutableArrayData, - len: usize, -) { +pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { mutable.buffer1.extend_zeros(len * size_of::()); } diff --git a/arrow-data/src/transform/structure.rs b/arrow-data/src/transform/structure.rs new file mode 100644 index 000000000000..7330dcaa3705 --- /dev/null +++ b/arrow-data/src/transform/structure.rs @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::{Extend, _MutableArrayData}; +use crate::ArrayData; + +pub(super) fn build_extend(_: &ArrayData) -> Extend { + Box::new( + move |mutable: &mut _MutableArrayData, index: usize, start: usize, len: usize| { + mutable + .child_data + .iter_mut() + .for_each(|child| child.extend(index, start, start + len)) + }, + ) +} + +pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { + mutable + .child_data + .iter_mut() + .for_each(|child| child.extend_nulls(len)) +} diff --git a/arrow/src/array/transform/union.rs b/arrow-data/src/transform/union.rs similarity index 82% rename from arrow/src/array/transform/union.rs rename to arrow-data/src/transform/union.rs index bbea508219d0..d7083588d782 100644 --- a/arrow/src/array/transform/union.rs +++ b/arrow-data/src/transform/union.rs @@ -15,9 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::array::ArrayData; - use super::{Extend, _MutableArrayData}; +use crate::ArrayData; pub(super) fn build_extend_sparse(array: &ArrayData) -> Extend { let type_ids = array.buffer::(0); @@ -40,6 +39,9 @@ pub(super) fn build_extend_sparse(array: &ArrayData) -> Extend { pub(super) fn build_extend_dense(array: &ArrayData) -> Extend { let type_ids = array.buffer::(0); let offsets = array.buffer::(1); + let arrow_schema::DataType::Union(src_fields, _) = array.data_type() else { + unreachable!(); + }; Box::new( move |mutable: &mut _MutableArrayData, index: usize, start: usize, len: usize| { @@ -49,14 +51,18 @@ pub(super) fn build_extend_dense(array: &ArrayData) -> Extend { .extend_from_slice(&type_ids[start..start + len]); (start..start + len).for_each(|i| { - let type_id = type_ids[i] as usize; + let type_id = type_ids[i]; + let child_index = src_fields + .iter() + .position(|(r, _)| r == type_id) + .expect("invalid union type ID"); let src_offset = offsets[i] as usize; - let child_data = &mut mutable.child_data[type_id]; + let child_data = &mut mutable.child_data[child_index]; let dst_offset = child_data.len(); // Extend offsets mutable.buffer2.push(dst_offset as i32); - mutable.child_data[type_id].extend(index, src_offset, src_offset + 1) + mutable.child_data[child_index].extend(index, src_offset, src_offset + 1) }) }, ) diff --git a/arrow/src/array/transform/utils.rs b/arrow-data/src/transform/utils.rs similarity index 66% rename from arrow/src/array/transform/utils.rs rename to arrow-data/src/transform/utils.rs index 68aee79c41bb..5407f68e0d0c 100644 --- a/arrow/src/array/transform/utils.rs +++ b/arrow-data/src/transform/utils.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::{array::OffsetSizeTrait, buffer::MutableBuffer, util::bit_util}; +use arrow_buffer::{bit_util, ArrowNativeType, MutableBuffer}; +use num::{CheckedAdd, Integer}; /// extends the `buffer` to be able to hold `len` bits, setting all bits of the new size to zero. #[inline] @@ -26,24 +27,25 @@ pub(super) fn resize_for_bits(buffer: &mut MutableBuffer, len: usize) { } } -pub(super) fn extend_offsets( +pub(super) fn extend_offsets( buffer: &mut MutableBuffer, mut last_offset: T, offsets: &[T], ) { - buffer.reserve(offsets.len() * std::mem::size_of::()); + buffer.reserve(std::mem::size_of_val(offsets)); offsets.windows(2).for_each(|offsets| { // compute the new offset let length = offsets[1] - offsets[0]; - last_offset += length; + // if you hit this appending to a StringArray / BinaryArray it is because you + // are trying to add more data than can fit into that type. Try breaking your data into + // smaller batches or using LargeStringArray / LargeBinaryArray + last_offset = last_offset.checked_add(&length).expect("offset overflow"); buffer.push(last_offset); }); } #[inline] -pub(super) unsafe fn get_last_offset( - offset_buffer: &MutableBuffer, -) -> T { +pub(super) unsafe fn get_last_offset(offset_buffer: &MutableBuffer) -> T { // JUSTIFICATION // Benefit // 20% performance improvement extend of variable sized arrays (see bench `mutable_array`) @@ -54,3 +56,16 @@ pub(super) unsafe fn get_last_offset( debug_assert!(prefix.is_empty() && suffix.is_empty()); *offsets.get_unchecked(offsets.len() - 1) } + +#[cfg(test)] +mod tests { + use crate::transform::utils::extend_offsets; + use arrow_buffer::MutableBuffer; + + #[test] + #[should_panic(expected = "offset overflow")] + fn test_overflow() { + let mut buffer = MutableBuffer::new(10); + extend_offsets(&mut buffer, i32::MAX - 4, &[0, 5]); + } +} diff --git a/arrow-data/src/transform/variable_size.rs b/arrow-data/src/transform/variable_size.rs new file mode 100644 index 000000000000..fa1592d973ed --- /dev/null +++ b/arrow-data/src/transform/variable_size.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::ArrayData; +use arrow_buffer::{ArrowNativeType, MutableBuffer}; +use num::traits::AsPrimitive; +use num::{CheckedAdd, Integer}; + +use super::{ + Extend, _MutableArrayData, + utils::{extend_offsets, get_last_offset}, +}; + +#[inline] +fn extend_offset_values>( + buffer: &mut MutableBuffer, + offsets: &[T], + values: &[u8], + start: usize, + len: usize, +) { + let start_values = offsets[start].as_(); + let end_values = offsets[start + len].as_(); + let new_values = &values[start_values..end_values]; + buffer.extend_from_slice(new_values); +} + +pub(super) fn build_extend>( + array: &ArrayData, +) -> Extend { + let offsets = array.buffer::(0); + let values = array.buffers()[1].as_slice(); + Box::new( + move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { + let offset_buffer = &mut mutable.buffer1; + let values_buffer = &mut mutable.buffer2; + + // this is safe due to how offset is built. See details on `get_last_offset` + let last_offset = unsafe { get_last_offset(offset_buffer) }; + + extend_offsets::(offset_buffer, last_offset, &offsets[start..start + len + 1]); + // values + extend_offset_values::(values_buffer, offsets, values, start, len); + }, + ) +} + +pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { + let offset_buffer = &mut mutable.buffer1; + + // this is safe due to how offset is built. See details on `get_last_offset` + let last_offset: T = unsafe { get_last_offset(offset_buffer) }; + + (0..len).for_each(|_| offset_buffer.push(last_offset)) +} diff --git a/arrow-flight/CONTRIBUTING.md b/arrow-flight/CONTRIBUTING.md new file mode 100644 index 000000000000..156a0b9caaed --- /dev/null +++ b/arrow-flight/CONTRIBUTING.md @@ -0,0 +1,41 @@ + + +# Flight + +## Generated Code + +The prost/tonic code can be generated by running, which in turn invokes the Rust binary located in [gen](./gen) + +This is necessary after modifying the protobuf definitions or altering the dependencies of [gen](./gen), and requires a +valid installation of [protoc](https://github.com/protocolbuffers/protobuf#protocol-compiler-installation). + +```bash +./regen.sh +``` + +### Why Vendor + +The standard approach to integrating `prost-build` / `tonic-build` is to use a `build.rs` script that automatically generates the code as part of the standard build process. + +Unfortunately this caused a lot of friction for users: + +- Requires all users to have a protoc install in order to compile the crate - [#2616](https://github.com/apache/arrow-rs/issues/2616) +- Some distributions have very old versions of protoc that don't support required functionality - [#1574](https://github.com/apache/arrow-rs/issues/1574) +- Inconsistent support within IDEs for code completion of automatically generated code diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml index ecf02625c9d3..1bea347c3037 100644 --- a/arrow-flight/Cargo.toml +++ b/arrow-flight/Cargo.toml @@ -18,37 +18,72 @@ [package] name = "arrow-flight" description = "Apache Arrow Flight" -version = "22.0.0" -edition = "2021" -rust-version = "1.62" -authors = ["Apache Arrow "] -homepage = "https://github.com/apache/arrow-rs" -repository = "https://github.com/apache/arrow-rs" -license = "Apache-2.0" +version = { workspace = true } +edition = { workspace = true } +rust-version = "1.70.0" +authors = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } [dependencies] -arrow = { path = "../arrow", version = "22.0.0", default-features = false, features = ["ipc"] } -base64 = { version = "0.13", default-features = false } -tonic = { version = "0.8", default-features = false, features = ["transport", "codegen", "prost"] } +arrow-arith = { workspace = true, optional = true } +arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +# Cast is needed to work around https://github.com/apache/arrow-rs/issues/3389 +arrow-cast = { workspace = true } +arrow-data = { workspace = true, optional = true } +arrow-ipc = { workspace = true } +arrow-ord = { workspace = true, optional = true } +arrow-row = { workspace = true, optional = true } +arrow-select = { workspace = true, optional = true } +arrow-schema = { workspace = true } +arrow-string = { workspace = true, optional = true } +base64 = { version = "0.21", default-features = false, features = ["std"] } bytes = { version = "1", default-features = false } -prost = { version = "0.11", default-features = false } -prost-types = { version = "0.11.0", default-features = false, optional = true } -prost-derive = { version = "0.11", default-features = false } +futures = { version = "0.3", default-features = false, features = ["alloc"] } +once_cell = { version = "1", optional = true } +paste = { version = "1.0" } +prost = { version = "0.12.1", default-features = false, features = ["prost-derive"] } tokio = { version = "1.0", default-features = false, features = ["macros", "rt", "rt-multi-thread"] } -futures = { version = "0.3", default-features = false, features = ["alloc"]} +tonic = { version = "0.10.0", default-features = false, features = ["transport", "codegen", "prost"] } + +# CLI-related dependencies +anyhow = { version = "1.0", optional = true } +clap = { version = "4.4.6", default-features = false, features = ["std", "derive", "env", "help", "error-context", "usage", "wrap_help", "color", "suggestions"], optional = true } +tracing-log = { version = "0.2", optional = true } +tracing-subscriber = { version = "0.3.1", default-features = false, features = ["ansi", "env-filter", "fmt"], optional = true } + +[package.metadata.docs.rs] +all-features = true [features] default = [] -flight-sql-experimental = ["prost-types"] +flight-sql-experimental = ["arrow-arith", "arrow-data", "arrow-ord", "arrow-row", "arrow-select", "arrow-string", "once_cell"] +tls = ["tonic/tls"] -[dev-dependencies] +# Enable CLI tools +cli = ["anyhow", "arrow-cast/prettyprint", "clap", "tracing-log", "tracing-subscriber", "tonic/tls-webpki-roots"] -[build-dependencies] -tonic-build = { version = "0.8", default-features = false, features = ["transport", "prost"] } -# Pin specific version of the tonic-build dependencies to avoid auto-generated -# (and checked in) arrow.flight.protocol.rs from changing -proc-macro2 = { version = ">1.0.30", default-features = false } +[dev-dependencies] +arrow-cast = { workspace = true, features = ["prettyprint"] } +assert_cmd = "2.0.8" +http = "0.2.9" +http-body = "0.4.5" +pin-project-lite = "0.2" +tempfile = "3.3" +tokio-stream = { version = "0.1", features = ["net"] } +tower = "0.4.13" [[example]] name = "flight_sql_server" -required-features = ["flight-sql-experimental"] +required-features = ["flight-sql-experimental", "tls"] + +[[bin]] +name = "flight_sql_client" +required-features = ["cli", "flight-sql-experimental", "tls"] + +[[test]] +name = "flight_sql_client_cli" +path = "tests/flight_sql_client_cli.rs" +required-features = ["cli", "flight-sql-experimental", "tls"] diff --git a/arrow-flight/README.md b/arrow-flight/README.md index 9e9a18ad4789..b80772ac927e 100644 --- a/arrow-flight/README.md +++ b/arrow-flight/README.md @@ -21,15 +21,56 @@ [![Crates.io](https://img.shields.io/crates/v/arrow-flight.svg)](https://crates.io/crates/arrow-flight) +See the [API documentation](https://docs.rs/arrow_flight/latest) for examples and the full API. + +The API documentation for most recent, unreleased code is available [here](https://arrow.apache.org/rust/arrow_flight/index.html). + ## Usage Add this to your Cargo.toml: ```toml [dependencies] -arrow-flight = "22.0.0" +arrow-flight = "39.0.0" ``` Apache Arrow Flight is a gRPC based protocol for exchanging Arrow data between processes. See the blog post [Introducing Apache Arrow Flight: A Framework for Fast Data Transport](https://arrow.apache.org/blog/2019/10/13/introducing-arrow-flight/) for more information. -This crate provides a Rust implementation of the [Flight.proto](../../format/Flight.proto) gRPC protocol and provides an example that demonstrates how to build a Flight server implemented with Tonic. +This crate provides a Rust implementation of the +[Flight.proto](../format/Flight.proto) gRPC protocol and +[examples](https://github.com/apache/arrow-rs/tree/master/arrow-flight/examples) +that demonstrate how to build a Flight server implemented with [tonic](https://docs.rs/crate/tonic/latest). + +## Feature Flags + +- `flight-sql-experimental`: Enables experimental support for + [Apache Arrow FlightSQL], a protocol for interacting with SQL databases. + +## CLI + +This crates offers a basic [Apache Arrow FlightSQL] command line interface. + +The client can be installed from the repository: + +```console +$ cargo install --features=cli,flight-sql-experimental,tls --bin=flight_sql_client --path=. --locked +``` + +The client comes with extensive help text: + +```console +$ flight_sql_client help +``` + +A query can be executed using: + +```console +$ flight_sql_client --host example.com statement-query "SELECT 1;" ++----------+ +| Int64(1) | ++----------+ +| 1 | ++----------+ +``` + +[apache arrow flightsql]: https://arrow.apache.org/docs/format/FlightSql.html diff --git a/arrow-flight/build.rs b/arrow-flight/build.rs deleted file mode 100644 index 25f034ac191b..000000000000 --- a/arrow-flight/build.rs +++ /dev/null @@ -1,100 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::{ - env, - fs::OpenOptions, - io::{Read, Write}, - path::Path, -}; - -fn main() -> Result<(), Box> { - // override the build location, in order to check in the changes to proto files - env::set_var("OUT_DIR", "src"); - - // The current working directory can vary depending on how the project is being - // built or released so we build an absolute path to the proto file - let path = Path::new("../format/Flight.proto"); - if path.exists() { - // avoid rerunning build if the file has not changed - println!("cargo:rerun-if-changed=../format/Flight.proto"); - - let proto_dir = Path::new("../format"); - let proto_path = Path::new("../format/Flight.proto"); - - tonic_build::configure() - // protoc in unbuntu builder needs this option - .protoc_arg("--experimental_allow_proto3_optional") - .compile(&[proto_path], &[proto_dir])?; - - // read file contents to string - let mut file = OpenOptions::new() - .read(true) - .open("src/arrow.flight.protocol.rs")?; - let mut buffer = String::new(); - file.read_to_string(&mut buffer)?; - // append warning that file was auto-generate - let mut file = OpenOptions::new() - .write(true) - .truncate(true) - .open("src/arrow.flight.protocol.rs")?; - file.write_all("// This file was automatically generated through the build.rs script, and should not be edited.\n\n".as_bytes())?; - file.write_all(buffer.as_bytes())?; - } - - // override the build location, in order to check in the changes to proto files - env::set_var("OUT_DIR", "src/sql"); - // The current working directory can vary depending on how the project is being - // built or released so we build an absolute path to the proto file - let path = Path::new("../format/FlightSql.proto"); - if path.exists() { - // avoid rerunning build if the file has not changed - println!("cargo:rerun-if-changed=../format/FlightSql.proto"); - - let proto_dir = Path::new("../format"); - let proto_path = Path::new("../format/FlightSql.proto"); - - tonic_build::configure() - // protoc in unbuntu builder needs this option - .protoc_arg("--experimental_allow_proto3_optional") - .compile(&[proto_path], &[proto_dir])?; - - // read file contents to string - let mut file = OpenOptions::new() - .read(true) - .open("src/sql/arrow.flight.protocol.sql.rs")?; - let mut buffer = String::new(); - file.read_to_string(&mut buffer)?; - // append warning that file was auto-generate - let mut file = OpenOptions::new() - .write(true) - .truncate(true) - .open("src/sql/arrow.flight.protocol.sql.rs")?; - file.write_all("// This file was automatically generated through the build.rs script, and should not be edited.\n\n".as_bytes())?; - file.write_all(buffer.as_bytes())?; - } - - // Prost currently generates an empty file, this was fixed but then reverted - // https://github.com/tokio-rs/prost/pull/639 - let google_protobuf_rs = Path::new("src/sql/google.protobuf.rs"); - if google_protobuf_rs.exists() && google_protobuf_rs.metadata().unwrap().len() == 0 { - std::fs::remove_file(google_protobuf_rs).unwrap(); - } - - // As the proto file is checked in, the build should not fail if the file is not found - Ok(()) -} diff --git a/arrow-flight/examples/data/ca.pem b/arrow-flight/examples/data/ca.pem new file mode 100644 index 000000000000..d81956096677 --- /dev/null +++ b/arrow-flight/examples/data/ca.pem @@ -0,0 +1,28 @@ +-----BEGIN CERTIFICATE----- +MIIE3DCCA0SgAwIBAgIRAObeYbJFiVQSGR8yk44dsOYwDQYJKoZIhvcNAQELBQAw +gYUxHjAcBgNVBAoTFW1rY2VydCBkZXZlbG9wbWVudCBDQTEtMCsGA1UECwwkbHVj +aW9ATHVjaW9zLVdvcmstTUJQIChMdWNpbyBGcmFuY28pMTQwMgYDVQQDDCtta2Nl +cnQgbHVjaW9ATHVjaW9zLVdvcmstTUJQIChMdWNpbyBGcmFuY28pMB4XDTE5MDky +OTIzMzUzM1oXDTI5MDkyOTIzMzUzM1owgYUxHjAcBgNVBAoTFW1rY2VydCBkZXZl +bG9wbWVudCBDQTEtMCsGA1UECwwkbHVjaW9ATHVjaW9zLVdvcmstTUJQIChMdWNp +byBGcmFuY28pMTQwMgYDVQQDDCtta2NlcnQgbHVjaW9ATHVjaW9zLVdvcmstTUJQ +IChMdWNpbyBGcmFuY28pMIIBojANBgkqhkiG9w0BAQEFAAOCAY8AMIIBigKCAYEA +y/vE61ItbN/1qMYt13LMf+le1svwfkCCOPsygk7nWeRXmomgUpymqn1LnWiuB0+e +4IdVH2f5E9DknWEpPhKIDMRTCbz4jTwQfHrxCb8EGj3I8oO73pJO5S/xCedM9OrZ +qWcYWwN0GQ8cO/ogazaoZf1uTrRNHyzRyQsKyb412kDBTNEeldJZ2ljKgXXvh4HO +2ZIk9K/ZAaAf6VN8K/89rlJ9/KPgRVNsyAapE+Pb8XXKtpzeFiEcUfuXVYWtkoW+ +xyn/Zu8A1L2CXMQ1sARh7P/42BTMKr5pfraYgcBGxKXLrxoySpxCO9KqeVveKy1q +fPm5FCwFsXDr0koFLrCiR58mcIO/04Q9DKKTV4Z2a+LoqDJRY37KfBSc8sDMPhw5 +k7g3WPoa6QwXRjZTCA5fHWVgLOtcwLsnju5tBE4LDxwF6s+1wPF8NI5yUfufcEjJ +Z6JBwgoWYosVj27Lx7KBNLU/57PX9ryee691zmtswt0tP0WVBAgalhYWg99RXoa3 +AgMBAAGjRTBDMA4GA1UdDwEB/wQEAwICBDASBgNVHRMBAf8ECDAGAQH/AgEAMB0G +A1UdDgQWBBQdvlE4Bdcsjc9oaxjDCRu5FiuZkzANBgkqhkiG9w0BAQsFAAOCAYEA +BP/6o1kPINksMJZSSXgNCPZskDLyGw7auUZBnQ0ocDT3W6gXQvT/27LM1Hxoj9Eh +qU1TYdEt7ppecLQSGvzQ02MExG7H75art75oLiB+A5agDira937YbK4MCjqW481d +bDhw6ixJnY1jIvwjEZxyH6g94YyL927aSPch51fys0kSnjkFzC2RmuzDADScc4XH +5P1+/3dnIm3M5yfpeUzoaOrTXNmhn8p0RDIGrZ5kA5eISIGGD3Mm8FDssUNKndtO +g4ojHUsxb14icnAYGeye1NOhGiqN6TEFcgr6MPd0XdFNZ5c0HUaBCfN6bc+JxDV5 +MKZVJdNeJsYYwilgJNHAyZgCi30JC20xeYVtTF7CEEsMrFDGJ70Kz7o/FnRiFsA1 +ZSwVVWhhkHG2VkT4vlo0O3fYeZpenYicvy+wZNTbGK83gzHWqxxNC1z3Etg5+HRJ +F9qeMWPyfA3IHYXygiMcviyLcyNGG/SJ0EhUpYBN/Gg7wI5yFkcsxUDPPzd23O0M +-----END CERTIFICATE----- diff --git a/arrow-flight/examples/data/client1.key b/arrow-flight/examples/data/client1.key new file mode 100644 index 000000000000..f4d8da2758ac --- /dev/null +++ b/arrow-flight/examples/data/client1.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCiiWrmzpENsI+c +Cz4aBpG+Pl8WOsrByfZx/ZnJdCZHO3MTYE6sCLhYssf0ygAEEGxvmkd4cxmfCfgf +xuT8u+D7Y5zQSoymkbWdU6/9jbNY6Ovtc+a96I1LGXOKROQw6KR3PuqLpUqEOJiB +l03qK+HMU0g56G1n31Od7HkJsDRvtePqy3I3LgpdcRps23sk46tCzZzhyfqIQ7Qf +J5qZx93tA+pfy+Xtb9XIUTIWKIp1/uyfh8Fp8HA0c9zJCSZzJOX2j3GH1TYqkVgP +egI2lhmdXhP5Q8vdhwy0UJaL28RJXA6UAg0tPZeWJe6pux9JiA81sI6My+Krrw8D +yibkGTTbAgMBAAECggEANCQhRym9HsclSsnQgkjZOE6J8nep08EWbjsMurOoE/He +WLjshAPIH6w6uSyUFLmwD51OkDVcYsiv8IG9s9YRtpOeGrPPqx/TQ0U1kAGFJ2CR +Tvt/aizQJudjSVgQXCBFontsgp/j58bAJdKEDDtHlGSjJvCJKGlcSa0ypwj/yVXt +frjROJNYzw9gMM7fN/IKF/cysdXSeLl/Q9RnHVIfC3jOFJutsILCK8+PC51dM8Fl +IOjmPmiZ080yV8RBcMRECwl53vLOE3OOpR3ZijfNCY1KU8zWi1oELJ1o6f4+cBye +7WPgFEoBew5XHXZ+ke8rh8cc0wth7ZTcC+xC/456AQKBgQDQr2EzBwXxYLF8qsN1 +R4zlzXILLdZN8a4bKfrS507/Gi1gDBHzfvbE7HfljeqrAkbKMdKNkbz3iS85SguH +jsM047xUGJg0PAcwBLHUedlSn1xDDcDHW6X8ginpA2Zz1+WAlhNz6XurA1wnjZmS +VcPxopH7QsuFCclqtt14MbBQ6QKBgQDHY3jcAVfQF+yhQ0YyM6GPLN342aTplgyJ +yz4uWVMeXacU4QzqGbf2L2hc9M2L28Xb37RWC3Q/by0vUefiC6qxRt+GJdRsOuQj +2F1uUibeWtAWp249fcfvxjLib276J+Eit18LI0s0mNR3ekK4GcjSe4NwSq5IrU8e +pBreet3dIwKBgQCxVuil4WkGd+I8jC0v5A7zVsR8hYZhlGkdgm45fgHevdMjlP5I +S3PPYxh8hj6O9o9L0k0Yq2nHfdgYujjUCNkQgBuR55iogv6kqsioRKgPE4fnH6/c +eqCy1bZh4tbUyPqqbF65mQfUCzXsEuQXvDSYiku+F0Q2mVuGCUJpmug3yQKBgEd3 +LeCdUp4xlQ0QEd74hpXM3RrO178pmwDgqj7uoU4m/zYKnBhkc3137I406F+SvE5c +1kRpApeh/64QS27IA7xazM9GS+cnDJKUgJiENY5JOoCELo03wiv8/EwQ6NQc6yMI +WrahRdlqVe0lEzjtdP+MacYb3nAKPmubIk5P96nFAoGAFAyrKpFTyXbNYBTw9Rab +TG6q7qkn+YTHN3+k4mo9NGGwZ3pXvmrKMYCIRhLMbqzsmTbFqCPPIxKsrmf8QYLh +xHYQjrCkbZ0wZdcdeV6yFSDsF218nF/12ZPE7CBOQMfZTCKFNWGL97uIVcmR6K5G +ojTkOvaUnwQtSFhNuzyr23I= +-----END PRIVATE KEY----- diff --git a/arrow-flight/examples/data/client1.pem b/arrow-flight/examples/data/client1.pem new file mode 100644 index 000000000000..bb3b82c40c5a --- /dev/null +++ b/arrow-flight/examples/data/client1.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDCTCCAfGgAwIBAgIQYbE9d1Rft5h4ku7FSAvWdzANBgkqhkiG9w0BAQsFADAn +MSUwIwYDVQQDExxUb25pYyBFeGFtcGxlIENsaWVudCBSb290IENBMB4XDTE5MTAx +NDEyMzkzNloXDTI0MTAxMjEyMzkzNlowEjEQMA4GA1UEAxMHY2xpZW50MTCCASIw +DQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKKJaubOkQ2wj5wLPhoGkb4+XxY6 +ysHJ9nH9mcl0Jkc7cxNgTqwIuFiyx/TKAAQQbG+aR3hzGZ8J+B/G5Py74PtjnNBK +jKaRtZ1Tr/2Ns1jo6+1z5r3ojUsZc4pE5DDopHc+6oulSoQ4mIGXTeor4cxTSDno +bWffU53seQmwNG+14+rLcjcuCl1xGmzbeyTjq0LNnOHJ+ohDtB8nmpnH3e0D6l/L +5e1v1chRMhYoinX+7J+HwWnwcDRz3MkJJnMk5faPcYfVNiqRWA96AjaWGZ1eE/lD +y92HDLRQlovbxElcDpQCDS09l5Yl7qm7H0mIDzWwjozL4quvDwPKJuQZNNsCAwEA +AaNGMEQwEwYDVR0lBAwwCgYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAfBgNVHSME +GDAWgBQV1YOR+Jpl1fbujvWLSBEoRvsDhTANBgkqhkiG9w0BAQsFAAOCAQEAfTPu +KeHXmyVTSCUrYQ1X5Mu7VzfZlRbhoytHOw7bYGgwaFwQj+ZhlPt8nFC22/bEk4IV +AoCOli0WyPIB7Lx52dZ+v9JmYOK6ca2Aa/Dkw8Q+M3XA024FQWq3nZ6qANKC32/9 +Nk+xOcb1Qd/11stpTkRf2Oj7F7K4GnlFbY6iMyNW+RFXGKEbL5QAJDTDPIT8vw1x +oYeNPwmC042uEboCZPNXmuctiK9Wt1TAxjZT/cwdIBGGJ+xrW72abfJGs7bUcJfc +O4r9V0xVv+X0iKWTW0fwd9qjNfiEP1tFCcZb2XsNQPe/DlQZ+h98P073tZEsWI/G +KJrFspGX8vOuSdIeqw== +-----END CERTIFICATE----- diff --git a/arrow-flight/examples/data/client_ca.pem b/arrow-flight/examples/data/client_ca.pem new file mode 100644 index 000000000000..aa483b931056 --- /dev/null +++ b/arrow-flight/examples/data/client_ca.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDGzCCAgOgAwIBAgIRAMNWpWRu6Q1txEYUyrkyXKEwDQYJKoZIhvcNAQELBQAw +JzElMCMGA1UEAxMcVG9uaWMgRXhhbXBsZSBDbGllbnQgUm9vdCBDQTAeFw0xOTEw +MTQxMjM5MzZaFw0yOTEwMTExMjM5MzZaMCcxJTAjBgNVBAMTHFRvbmljIEV4YW1w +bGUgQ2xpZW50IFJvb3QgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIB +AQCv8Nj4XJbMI0wWUvLbmCf7IEvJFnomodGnDurh8Y5AGMPJ8cGdZC1yo2Lgah+D +IhXdsd72Wp7MhdntJAyPrMCDBfDrFiuj6YHDgt3OhPQSYl7EWG7QjFK3B2sp1K5D +h16G5zfwUKDj9Jp3xuPGuqNFQHL02nwbhtDilqHvaTfOJKVjsFCoU8Z77mfwXSwn +sPXpPB7oOO4mWfAtcwU11rTMiHFSGFlFhgbHULU/y90DcpfRQEpEiBoiK13gkyoP +zHT9WAg3Pelwb6K7c7kJ7mp4axhbf7MkwFhDQIjbBWqus2Eu3b0mf86ALfDbAaNC +wBi8xbNH2vWaDjiwLDY5uMZDAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwICBDAPBgNV +HRMBAf8EBTADAQH/MB0GA1UdDgQWBBQV1YOR+Jpl1fbujvWLSBEoRvsDhTANBgkq +hkiG9w0BAQsFAAOCAQEAaXmM29TYkFUzZUsV7TSonAK560BjxDmbg0GJSUgLEFUJ +wpKqa9UKOSapG45LEeR2wwAmVWDJomJplkuvTD/KOabAbZKyPEfp+VMCaBUnILQF +Cxv5m7kQ3wmPS/rEL8FD809UGowW9cYqnZzUy5i/r263rx0k3OPjkkZN66Mh6+3H +ibNdaxf7ITO0JVb/Ohq9vLC9qf7ujiB1atMdJwkOWsZrLJXLygpx/D0/UhBT4fFH +OlyVOmuR27qaMbPgOs2l8DznkJY/QUfnET8iOQhFgb0Dt/Os4PYFhSDRIrgl5dJ7 +L/zZVQfZYpdxlBHJlDC1/NzVQl/1MgDnSgPGStZKPQ== +-----END CERTIFICATE----- diff --git a/arrow-flight/examples/data/server.key b/arrow-flight/examples/data/server.key new file mode 100644 index 000000000000..80984ef9000d --- /dev/null +++ b/arrow-flight/examples/data/server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDyptbMyYWztgta +t1MXLMzIkaQdeeVbs1Y/qCpAdwZe/Y5ZpbzjGIjCxbB6vNRSnEbYKpytKHPzYfM7 +8d8K8bPvpnqXIiTXFT0JQlw1OHLC1fr4e598GJumAmpMYFrtqv0fbmUFTuQGbHxe +OH2vji0bvr3NKZubMfkEZP3X4sNXXoXIuW2LaS8OMGKoJaeCBvdbszEiSGj/v9Bj +pM0yLTH89NNMX1T+FtTKnuXag5g7pr6lzJj83+MzAGy4nOjseSuUimuiyG90/C5t +A5wC0Qh5RbDnkFYhC44Kxof/i6+jnfateIPNiIIwQV+2f6G/aK1hgjekT10m/eoR +YDTf+e5ZAgMBAAECggEACODt7yRYjhDVLYaTtb9f5t7dYG67Y7WWLFIc6arxQryI +XuNfm/ej2WyeXn9WTYeGWBaHERbv1zH4UnMxNBdP/C7dQXZwXqZaS2JwOUpNeK+X +tUvgtAu6dkKUXSMRcKzXAjVp4N3YHhwOGOx8PNY49FDwZPdmyDD16aFAYIvdle6/ +PSMrj38rB1sbQQdmRob2FjJBSDZ44nsr+/nilrcOFNfNnWv7tQIWYVXNcLfdK/WJ +ZCDFhA8lr/Yon6MEq6ApTj2ZYRRGXPd6UeASJkmTZEUIUbeDcje/MO8cHkREpuRH +wm3pCjR7OdO4vc+/d/QmEvu5ns6wbTauelYnL616YQKBgQD414gJtpCHauNEUlFB +v/R3DzPI5NGp9PAqovOD8nCbI49Mw61gP/ExTIPKiR5uUX/5EL04uspaNkuohXk+ +ys0G5At0NfV7W39lzhvALEaSfleybvYxppbBrc20/q8Gvi/i30NY+1LM3RdtMiEw +hKHjU0SnFhJq0InFg3AO/iCeTQKBgQD5obkbzpOidSsa55aNsUlO2qjiUY9leq9b +irAohIZ8YnuuixYvkOeSeSz1eIrA4tECeAFSgTZxYe1Iz+USru2Xg/0xNte11dJD +rBoH/yMn2gDvBK7xQ6uFMPTeYtKG0vfvpXZYSWZzGntyrHTwFk6UV+xdrt9MBdd1 +XdSn7bwOPQKBgC9VQAko8uDvUf+C8PXiv2uONrl13PPJJY3WpR9qFEVOREnDxszS +HNzVwxPZdTJiykbkCjoqPadfQJDzopZxGQLAifU29lTamKcSx3CMe3gOFDxaovXa +zD5XAxP0hfJwZsdu1G6uj5dsTrJ0oJ+L+wc0pZBqwGIU/L/XOo9/g1DZAoGAUebL +kuH98ik7EUK2VJq8EJERI9/ailLsQb6I+WIxtZGiPqwHhWencpkrNQZtj8dbB9JT +rLwUHrMgZOlAoRafgTyez4zMzS3wJJ/Mkp8U67hM4h7JPwMSvUpIrMYDiJSjIA9L +er/qSw1/Pypx22uWMHmAZWRAgvLPtAQrB0Wqk4kCgYEAr2H1PvfbwZwkSvlMt5o8 +WLnBbxcM3AKglLRbkShxxgiZYdEP71/uOtRMiL26du5XX8evItITN0DsvmXL/kcd +h29LK7LM5uLw7efz0Qxs03G6kEyIHVkacowHi5I5Ul1qI61SoV3yMB1TjIU+bXZt +0ZjC07totO0fqPOLQxonjQg= +-----END PRIVATE KEY----- diff --git a/arrow-flight/examples/data/server.pem b/arrow-flight/examples/data/server.pem new file mode 100644 index 000000000000..4cc97bcf4b6d --- /dev/null +++ b/arrow-flight/examples/data/server.pem @@ -0,0 +1,27 @@ +-----BEGIN CERTIFICATE----- +MIIEmDCCAwCgAwIBAgIQVEJFCgU/CZk9JEwTucWPpzANBgkqhkiG9w0BAQsFADCB +hTEeMBwGA1UEChMVbWtjZXJ0IGRldmVsb3BtZW50IENBMS0wKwYDVQQLDCRsdWNp +b0BMdWNpb3MtV29yay1NQlAgKEx1Y2lvIEZyYW5jbykxNDAyBgNVBAMMK21rY2Vy +dCBsdWNpb0BMdWNpb3MtV29yay1NQlAgKEx1Y2lvIEZyYW5jbykwHhcNMTkwNjAx +MDAwMDAwWhcNMjkwOTI5MjMzNTM0WjBYMScwJQYDVQQKEx5ta2NlcnQgZGV2ZWxv +cG1lbnQgY2VydGlmaWNhdGUxLTArBgNVBAsMJGx1Y2lvQEx1Y2lvcy1Xb3JrLU1C +UCAoTHVjaW8gRnJhbmNvKTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +APKm1szJhbO2C1q3UxcszMiRpB155VuzVj+oKkB3Bl79jlmlvOMYiMLFsHq81FKc +RtgqnK0oc/Nh8zvx3wrxs++mepciJNcVPQlCXDU4csLV+vh7n3wYm6YCakxgWu2q +/R9uZQVO5AZsfF44fa+OLRu+vc0pm5sx+QRk/dfiw1dehci5bYtpLw4wYqglp4IG +91uzMSJIaP+/0GOkzTItMfz000xfVP4W1Mqe5dqDmDumvqXMmPzf4zMAbLic6Ox5 +K5SKa6LIb3T8Lm0DnALRCHlFsOeQViELjgrGh/+Lr6Od9q14g82IgjBBX7Z/ob9o +rWGCN6RPXSb96hFgNN/57lkCAwEAAaOBrzCBrDAOBgNVHQ8BAf8EBAMCBaAwEwYD +VR0lBAwwCgYIKwYBBQUHAwEwDAYDVR0TAQH/BAIwADAfBgNVHSMEGDAWgBQdvlE4 +Bdcsjc9oaxjDCRu5FiuZkzBWBgNVHREETzBNggtleGFtcGxlLmNvbYINKi5leGFt +cGxlLmNvbYIMZXhhbXBsZS50ZXN0gglsb2NhbGhvc3SHBH8AAAGHEAAAAAAAAAAA +AAAAAAAAAAEwDQYJKoZIhvcNAQELBQADggGBAKb2TJ8l+e1eraNwZWizLw5fccAf +y59J1JAWdLxZyAI/bkiTlVO3DQoPZpw7XwLhefCvILkwKAL4TtIGGVC9yTb5Q5eg +rqGO3FC0yg1fn65Kf1VpVxxUVyoiM5PQ4pFJb4AicAv88rCOLD9FFuE0PKOKU/dm +Tw0WgPStoh9wsJ1RXUuTJYZs1nd1kMBlfv9NbLilnL+cR2sLktS54X5XagsBYVlf +oapRb0JtABOoQhX3U8QMq8UF8yzceRHNTN9yfLOUrW26s9nKtlWVniNhw1uPxZw9 +RHM7w9/4+a9LXtEDYg4IP/1mm0ywBoUqy1O6hA73uId+Yi/kFBks/GyYaGjKgYcO +23B75tkPGYEdGuGZYLzZNHbXg4V0UxFQG3KA1pUiSnD3bN2Rxs+CMpzORnOeK3xi +EooKgAPYsehItoQOMPpccI2xHdSAMWtwUgOKrefUQujkx2Op+KFlspF0+WJ6AZEe +2D4hyWaEZsvvILXapwqHDCuN3/jSUlTIqUoE1w== +-----END CERTIFICATE----- diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs index aa0d407113d7..bd94d3c499ca 100644 --- a/arrow-flight/examples/flight_sql_server.rs +++ b/arrow-flight/examples/flight_sql_server.rs @@ -15,31 +15,129 @@ // specific language governing permissions and limitations // under the License. -use arrow_flight::sql::{ActionCreatePreparedStatementResult, SqlInfo}; -use arrow_flight::{Action, FlightData, HandshakeRequest, HandshakeResponse, Ticket}; -use futures::Stream; +use arrow_flight::sql::server::PeekableFlightDataStream; +use base64::prelude::BASE64_STANDARD; +use base64::Engine; +use futures::{stream, Stream, TryStreamExt}; +use once_cell::sync::Lazy; +use prost::Message; +use std::collections::HashSet; use std::pin::Pin; +use std::sync::Arc; use tonic::transport::Server; +use tonic::transport::{Certificate, Identity, ServerTlsConfig}; use tonic::{Request, Response, Status, Streaming}; +use arrow_array::builder::StringBuilder; +use arrow_array::{ArrayRef, RecordBatch}; +use arrow_flight::encode::FlightDataEncoderBuilder; +use arrow_flight::sql::metadata::{ + SqlInfoData, SqlInfoDataBuilder, XdbcTypeInfo, XdbcTypeInfoData, XdbcTypeInfoDataBuilder, +}; +use arrow_flight::sql::{ + server::FlightSqlService, ActionBeginSavepointRequest, ActionBeginSavepointResult, + ActionBeginTransactionRequest, ActionBeginTransactionResult, ActionCancelQueryRequest, + ActionCancelQueryResult, ActionClosePreparedStatementRequest, + ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, + ActionCreatePreparedSubstraitPlanRequest, ActionEndSavepointRequest, + ActionEndTransactionRequest, Any, CommandGetCatalogs, CommandGetCrossReference, + CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, + CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, + CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, + CommandStatementSubstraitPlan, CommandStatementUpdate, Nullable, ProstMessageExt, Searchable, + SqlInfo, TicketStatementQuery, XdbcDataType, +}; +use arrow_flight::utils::batches_to_flight_data; use arrow_flight::{ - flight_service_server::FlightService, - flight_service_server::FlightServiceServer, - sql::{ - server::FlightSqlService, ActionClosePreparedStatementRequest, - ActionCreatePreparedStatementRequest, CommandGetCatalogs, - CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, - CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, - CommandGetTableTypes, CommandGetTables, CommandPreparedStatementQuery, - CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementUpdate, - TicketStatementQuery, - }, - FlightDescriptor, FlightInfo, + flight_service_server::FlightService, flight_service_server::FlightServiceServer, Action, + FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, + IpcMessage, Location, SchemaAsIpc, Ticket, }; +use arrow_ipc::writer::IpcWriteOptions; +use arrow_schema::{ArrowError, DataType, Field, Schema}; + +macro_rules! status { + ($desc:expr, $err:expr) => { + Status::internal(format!("{}: {} at {}:{}", $desc, $err, file!(), line!())) + }; +} + +const FAKE_TOKEN: &str = "uuid_token"; +const FAKE_HANDLE: &str = "uuid_handle"; +const FAKE_UPDATE_RESULT: i64 = 1; + +static INSTANCE_SQL_DATA: Lazy = Lazy::new(|| { + let mut builder = SqlInfoDataBuilder::new(); + // Server information + builder.append(SqlInfo::FlightSqlServerName, "Example Flight SQL Server"); + builder.append(SqlInfo::FlightSqlServerVersion, "1"); + // 1.3 comes from https://github.com/apache/arrow/blob/f9324b79bf4fc1ec7e97b32e3cce16e75ef0f5e3/format/Schema.fbs#L24 + builder.append(SqlInfo::FlightSqlServerArrowVersion, "1.3"); + builder.build().unwrap() +}); + +static INSTANCE_XBDC_DATA: Lazy = Lazy::new(|| { + let mut builder = XdbcTypeInfoDataBuilder::new(); + builder.append(XdbcTypeInfo { + type_name: "INTEGER".into(), + data_type: XdbcDataType::XdbcInteger, + column_size: Some(32), + literal_prefix: None, + literal_suffix: None, + create_params: None, + nullable: Nullable::NullabilityNullable, + case_sensitive: false, + searchable: Searchable::Full, + unsigned_attribute: Some(false), + fixed_prec_scale: false, + auto_increment: Some(false), + local_type_name: Some("INTEGER".into()), + minimum_scale: None, + maximum_scale: None, + sql_data_type: XdbcDataType::XdbcInteger, + datetime_subcode: None, + num_prec_radix: Some(2), + interval_precision: None, + }); + builder.build().unwrap() +}); + +static TABLES: Lazy> = Lazy::new(|| vec!["flight_sql.example.table"]); #[derive(Clone)] pub struct FlightSqlServiceImpl {} +impl FlightSqlServiceImpl { + fn check_token(&self, req: &Request) -> Result<(), Status> { + let metadata = req.metadata(); + let auth = metadata.get("authorization").ok_or_else(|| { + Status::internal(format!("No authorization header! metadata = {metadata:?}")) + })?; + let str = auth + .to_str() + .map_err(|e| Status::internal(format!("Error parsing header: {e}")))?; + let authorization = str.to_string(); + let bearer = "Bearer "; + if !authorization.starts_with(bearer) { + Err(Status::internal("Invalid auth header!"))?; + } + let token = authorization[bearer.len()..].to_string(); + if token == FAKE_TOKEN { + Ok(()) + } else { + Err(Status::unauthenticated("invalid token ")) + } + } + + fn fake_result() -> Result { + let schema = Schema::new(vec![Field::new("salutation", DataType::Utf8, false)]); + let mut builder = StringBuilder::new(); + builder.append_value("Hello, FlightSQL!"); + let cols = vec![Arc::new(builder.finish()) as ArrayRef]; + RecordBatch::try_new(Arc::new(schema), cols) + } +} + #[tonic::async_trait] impl FlightSqlService for FlightSqlServiceImpl { type FlightService = FlightSqlServiceImpl; @@ -55,41 +153,59 @@ impl FlightSqlService for FlightSqlServiceImpl { let authorization = request .metadata() .get("authorization") - .ok_or(Status::invalid_argument("authorization field not present"))? + .ok_or_else(|| Status::invalid_argument("authorization field not present"))? .to_str() - .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + .map_err(|e| status!("authorization not parsable", e))?; if !authorization.starts_with(basic) { Err(Status::invalid_argument(format!( - "Auth type not implemented: {}", - authorization + "Auth type not implemented: {authorization}" )))?; } let base64 = &authorization[basic.len()..]; - let bytes = base64::decode(base64) - .map_err(|_| Status::invalid_argument("authorization not parsable"))?; - let str = String::from_utf8(bytes) - .map_err(|_| Status::invalid_argument("authorization not parsable"))?; - let parts: Vec<_> = str.split(":").collect(); - if parts.len() != 2 { - Err(Status::invalid_argument(format!( - "Invalid authorization header" - )))?; - } - let user = parts[0]; - let pass = parts[1]; - if user != "admin" || pass != "password" { + let bytes = BASE64_STANDARD + .decode(base64) + .map_err(|e| status!("authorization not decodable", e))?; + let str = String::from_utf8(bytes).map_err(|e| status!("authorization not parsable", e))?; + let parts: Vec<_> = str.split(':').collect(); + let (user, pass) = match parts.as_slice() { + [user, pass] => (user, pass), + _ => Err(Status::invalid_argument( + "Invalid authorization header".to_string(), + ))?, + }; + if user != &"admin" || pass != &"password" { Err(Status::unauthenticated("Invalid credentials!"))? } + let result = HandshakeResponse { protocol_version: 0, - payload: "random_uuid_token".as_bytes().to_vec(), + payload: FAKE_TOKEN.into(), }; let result = Ok(result); let output = futures::stream::iter(vec![result]); return Ok(Response::new(Box::pin(output))); } - // get_flight_info + async fn do_get_fallback( + &self, + request: Request, + _message: Any, + ) -> Result::DoGetStream>, Status> { + self.check_token(&request)?; + let batch = Self::fake_result().map_err(|e| status!("Could not fake a result", e))?; + let schema = batch.schema(); + let batches = vec![batch]; + let flight_data = batches_to_flight_data(schema.as_ref(), batches) + .map_err(|e| status!("Could not convert batches", e))? + .into_iter() + .map(Ok); + + let stream: Pin> + Send>> = + Box::pin(stream::iter(flight_data)); + let resp = Response::new(stream); + Ok(resp) + } + async fn get_flight_info_statement( &self, _query: CommandStatementQuery, @@ -100,44 +216,111 @@ impl FlightSqlService for FlightSqlServiceImpl { )) } - async fn get_flight_info_prepared_statement( + async fn get_flight_info_substrait_plan( &self, - _query: CommandPreparedStatementQuery, + _query: CommandStatementSubstraitPlan, _request: Request, ) -> Result, Status> { Err(Status::unimplemented( - "get_flight_info_prepared_statement not implemented", + "get_flight_info_substrait_plan not implemented", )) } + async fn get_flight_info_prepared_statement( + &self, + cmd: CommandPreparedStatementQuery, + request: Request, + ) -> Result, Status> { + self.check_token(&request)?; + let handle = std::str::from_utf8(&cmd.prepared_statement_handle) + .map_err(|e| status!("Unable to parse handle", e))?; + let batch = Self::fake_result().map_err(|e| status!("Could not fake a result", e))?; + let schema = (*batch.schema()).clone(); + let num_rows = batch.num_rows(); + let num_bytes = batch.get_array_memory_size(); + let loc = Location { + uri: "grpc+tcp://127.0.0.1".to_string(), + }; + let fetch = FetchResults { + handle: handle.to_string(), + }; + let buf = fetch.as_any().encode_to_vec().into(); + let ticket = Ticket { ticket: buf }; + let endpoint = FlightEndpoint { + ticket: Some(ticket), + location: vec![loc], + }; + let info = FlightInfo::new() + .try_with_schema(&schema) + .map_err(|e| status!("Unable to serialize schema", e))? + .with_descriptor(FlightDescriptor::new_cmd(vec![])) + .with_endpoint(endpoint) + .with_total_records(num_rows as i64) + .with_total_bytes(num_bytes as i64) + .with_ordered(false); + + let resp = Response::new(info); + Ok(resp) + } + async fn get_flight_info_catalogs( &self, - _query: CommandGetCatalogs, - _request: Request, + query: CommandGetCatalogs, + request: Request, ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_catalogs not implemented", - )) + let flight_descriptor = request.into_inner(); + let ticket = Ticket { + ticket: query.encode_to_vec().into(), + }; + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + let flight_info = FlightInfo::new() + .try_with_schema(&query.into_builder().schema()) + .map_err(|e| status!("Unable to encode schema", e))? + .with_endpoint(endpoint) + .with_descriptor(flight_descriptor); + + Ok(tonic::Response::new(flight_info)) } async fn get_flight_info_schemas( &self, - _query: CommandGetDbSchemas, - _request: Request, + query: CommandGetDbSchemas, + request: Request, ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_schemas not implemented", - )) + let flight_descriptor = request.into_inner(); + let ticket = Ticket { + ticket: query.encode_to_vec().into(), + }; + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + let flight_info = FlightInfo::new() + .try_with_schema(&query.into_builder().schema()) + .map_err(|e| status!("Unable to encode schema", e))? + .with_endpoint(endpoint) + .with_descriptor(flight_descriptor); + + Ok(tonic::Response::new(flight_info)) } async fn get_flight_info_tables( &self, - _query: CommandGetTables, - _request: Request, + query: CommandGetTables, + request: Request, ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_tables not implemented", - )) + let flight_descriptor = request.into_inner(); + let ticket = Ticket { + ticket: query.encode_to_vec().into(), + }; + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + let flight_info = FlightInfo::new() + .try_with_schema(&query.into_builder().schema()) + .map_err(|e| status!("Unable to encode schema", e))? + .with_endpoint(endpoint) + .with_descriptor(flight_descriptor); + + Ok(tonic::Response::new(flight_info)) } async fn get_flight_info_table_types( @@ -152,12 +335,20 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn get_flight_info_sql_info( &self, - _query: CommandGetSqlInfo, - _request: Request, + query: CommandGetSqlInfo, + request: Request, ) -> Result, Status> { - Err(Status::unimplemented( - "get_flight_info_sql_info not implemented", - )) + let flight_descriptor = request.into_inner(); + let ticket = Ticket::new(query.encode_to_vec()); + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + let flight_info = FlightInfo::new() + .try_with_schema(query.into_builder(&INSTANCE_SQL_DATA).schema().as_ref()) + .map_err(|e| status!("Unable to encode schema", e))? + .with_endpoint(endpoint) + .with_descriptor(flight_descriptor); + + Ok(tonic::Response::new(flight_info)) } async fn get_flight_info_primary_keys( @@ -200,6 +391,24 @@ impl FlightSqlService for FlightSqlServiceImpl { )) } + async fn get_flight_info_xdbc_type_info( + &self, + query: CommandGetXdbcTypeInfo, + request: Request, + ) -> Result, Status> { + let flight_descriptor = request.into_inner(); + let ticket = Ticket::new(query.encode_to_vec()); + let endpoint = FlightEndpoint::new().with_ticket(ticket); + + let flight_info = FlightInfo::new() + .try_with_schema(query.into_builder(&INSTANCE_XBDC_DATA).schema().as_ref()) + .map_err(|e| status!("Unable to encode schema", e))? + .with_endpoint(endpoint) + .with_descriptor(flight_descriptor); + + Ok(tonic::Response::new(flight_info)) + } + // do_get async fn do_get_statement( &self, @@ -221,26 +430,91 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_get_catalogs( &self, - _query: CommandGetCatalogs, + query: CommandGetCatalogs, _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("do_get_catalogs not implemented")) + let catalog_names = TABLES + .iter() + .map(|full_name| full_name.split('.').collect::>()[0].to_string()) + .collect::>(); + let mut builder = query.into_builder(); + for catalog_name in catalog_names { + builder.append(catalog_name); + } + let schema = builder.schema(); + let batch = builder.build(); + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(futures::stream::once(async { batch })) + .map_err(Status::from); + Ok(Response::new(Box::pin(stream))) } async fn do_get_schemas( &self, - _query: CommandGetDbSchemas, + query: CommandGetDbSchemas, _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("do_get_schemas not implemented")) + let schemas = TABLES + .iter() + .map(|full_name| { + let parts = full_name.split('.').collect::>(); + (parts[0].to_string(), parts[1].to_string()) + }) + .collect::>(); + + let mut builder = query.into_builder(); + for (catalog_name, schema_name) in schemas { + builder.append(catalog_name, schema_name); + } + + let schema = builder.schema(); + let batch = builder.build(); + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(futures::stream::once(async { batch })) + .map_err(Status::from); + Ok(Response::new(Box::pin(stream))) } async fn do_get_tables( &self, - _query: CommandGetTables, + query: CommandGetTables, _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("do_get_tables not implemented")) + let tables = TABLES + .iter() + .map(|full_name| { + let parts = full_name.split('.').collect::>(); + ( + parts[0].to_string(), + parts[1].to_string(), + parts[2].to_string(), + ) + }) + .collect::>(); + + let dummy_schema = Schema::empty(); + let mut builder = query.into_builder(); + for (catalog_name, schema_name, table_name) in tables { + builder + .append( + catalog_name, + schema_name, + table_name, + "TABLE", + &dummy_schema, + ) + .map_err(Status::from)?; + } + + let schema = builder.schema(); + let batch = builder.build(); + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(futures::stream::once(async { batch })) + .map_err(Status::from); + Ok(Response::new(Box::pin(stream))) } async fn do_get_table_types( @@ -253,10 +527,17 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_get_sql_info( &self, - _query: CommandGetSqlInfo, + query: CommandGetSqlInfo, _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("do_get_sql_info not implemented")) + let builder = query.into_builder(&INSTANCE_SQL_DATA); + let schema = builder.schema(); + let batch = builder.build(); + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(futures::stream::once(async { batch })) + .map_err(Status::from); + Ok(Response::new(Box::pin(stream))) } async fn do_get_primary_keys( @@ -297,21 +578,45 @@ impl FlightSqlService for FlightSqlServiceImpl { )) } + async fn do_get_xdbc_type_info( + &self, + query: CommandGetXdbcTypeInfo, + _request: Request, + ) -> Result::DoGetStream>, Status> { + // create a builder with pre-defined Xdbc data: + let builder = query.into_builder(&INSTANCE_XBDC_DATA); + let schema = builder.schema(); + let batch = builder.build(); + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(futures::stream::once(async { batch })) + .map_err(Status::from); + Ok(Response::new(Box::pin(stream))) + } + // do_put async fn do_put_statement_update( &self, _ticket: CommandStatementUpdate, - _request: Request>, + _request: Request, + ) -> Result { + Ok(FAKE_UPDATE_RESULT) + } + + async fn do_put_substrait_plan( + &self, + _ticket: CommandStatementSubstraitPlan, + _request: Request, ) -> Result { Err(Status::unimplemented( - "do_put_statement_update not implemented", + "do_put_substrait_plan not implemented", )) } async fn do_put_prepared_statement_query( &self, _query: CommandPreparedStatementQuery, - _request: Request>, + _request: Request, ) -> Result::DoPutStream>, Status> { Err(Status::unimplemented( "do_put_prepared_statement_query not implemented", @@ -321,27 +626,94 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_prepared_statement_update( &self, _query: CommandPreparedStatementUpdate, - _request: Request>, + _request: Request, ) -> Result { Err(Status::unimplemented( "do_put_prepared_statement_update not implemented", )) } - // do_action async fn do_action_create_prepared_statement( &self, _query: ActionCreatePreparedStatementRequest, - _request: Request, + request: Request, ) -> Result { - Err(Status::unimplemented("Not yet implemented")) + self.check_token(&request)?; + let schema = Self::fake_result() + .map_err(|e| status!("Error getting result schema", e))? + .schema(); + let message = SchemaAsIpc::new(&schema, &IpcWriteOptions::default()) + .try_into() + .map_err(|e| status!("Unable to serialize schema", e))?; + let IpcMessage(schema_bytes) = message; + let res = ActionCreatePreparedStatementResult { + prepared_statement_handle: FAKE_HANDLE.into(), + dataset_schema: schema_bytes, + parameter_schema: Default::default(), // TODO: parameters + }; + Ok(res) } + async fn do_action_close_prepared_statement( &self, _query: ActionClosePreparedStatementRequest, _request: Request, - ) { - unimplemented!("Not yet implemented") + ) -> Result<(), Status> { + Err(Status::unimplemented( + "Implement do_action_close_prepared_statement", + )) + } + + async fn do_action_create_prepared_substrait_plan( + &self, + _query: ActionCreatePreparedSubstraitPlanRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "Implement do_action_create_prepared_substrait_plan", + )) + } + + async fn do_action_begin_transaction( + &self, + _query: ActionBeginTransactionRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "Implement do_action_begin_transaction", + )) + } + + async fn do_action_end_transaction( + &self, + _query: ActionEndTransactionRequest, + _request: Request, + ) -> Result<(), Status> { + Err(Status::unimplemented("Implement do_action_end_transaction")) + } + + async fn do_action_begin_savepoint( + &self, + _query: ActionBeginSavepointRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented("Implement do_action_begin_savepoint")) + } + + async fn do_action_end_savepoint( + &self, + _query: ActionEndSavepointRequest, + _request: Request, + ) -> Result<(), Status> { + Err(Status::unimplemented("Implement do_action_end_savepoint")) + } + + async fn do_action_cancel_query( + &self, + _query: ActionCancelQueryRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented("Implement do_action_cancel_query")) } async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} @@ -356,7 +728,293 @@ async fn main() -> Result<(), Box> { println!("Listening on {:?}", addr); - Server::builder().add_service(svc).serve(addr).await?; + if std::env::var("USE_TLS").ok().is_some() { + let cert = std::fs::read_to_string("arrow-flight/examples/data/server.pem")?; + let key = std::fs::read_to_string("arrow-flight/examples/data/server.key")?; + let client_ca = std::fs::read_to_string("arrow-flight/examples/data/client_ca.pem")?; + + let tls_config = ServerTlsConfig::new() + .identity(Identity::from_pem(&cert, &key)) + .client_ca_root(Certificate::from_pem(&client_ca)); + + Server::builder() + .tls_config(tls_config)? + .add_service(svc) + .serve(addr) + .await?; + } else { + Server::builder().add_service(svc).serve(addr).await?; + } Ok(()) } + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FetchResults { + #[prost(string, tag = "1")] + pub handle: ::prost::alloc::string::String, +} + +impl ProstMessageExt for FetchResults { + fn type_url() -> &'static str { + "type.googleapis.com/arrow.flight.protocol.sql.FetchResults" + } + + fn as_any(&self) -> Any { + Any { + type_url: FetchResults::type_url().to_string(), + value: ::prost::Message::encode_to_vec(self).into(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::TryStreamExt; + use std::fs; + use std::future::Future; + use std::net::SocketAddr; + use std::time::Duration; + use tempfile::NamedTempFile; + use tokio::net::{TcpListener, UnixListener, UnixStream}; + use tokio_stream::wrappers::UnixListenerStream; + use tonic::transport::{Channel, ClientTlsConfig}; + + use arrow_cast::pretty::pretty_format_batches; + use arrow_flight::sql::client::FlightSqlServiceClient; + use tonic::transport::server::TcpIncoming; + use tonic::transport::{Certificate, Endpoint}; + use tower::service_fn; + + async fn bind_tcp() -> (TcpIncoming, SocketAddr) { + let listener = TcpListener::bind("0.0.0.0:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let incoming = TcpIncoming::from_listener(listener, true, None).unwrap(); + (incoming, addr) + } + + fn endpoint(uri: String) -> Result { + let endpoint = Endpoint::new(uri) + .map_err(|_| ArrowError::IpcError("Cannot create endpoint".to_string()))? + .connect_timeout(Duration::from_secs(20)) + .timeout(Duration::from_secs(20)) + .tcp_nodelay(true) // Disable Nagle's Algorithm since we don't want packets to wait + .tcp_keepalive(Option::Some(Duration::from_secs(3600))) + .http2_keep_alive_interval(Duration::from_secs(300)) + .keep_alive_timeout(Duration::from_secs(20)) + .keep_alive_while_idle(true); + + Ok(endpoint) + } + + async fn auth_client(client: &mut FlightSqlServiceClient) { + let token = client.handshake("admin", "password").await.unwrap(); + client.set_token(String::from_utf8(token.to_vec()).unwrap()); + } + + async fn test_uds_client(f: F) + where + F: FnOnce(FlightSqlServiceClient) -> C, + C: Future, + { + let file = NamedTempFile::new().unwrap(); + let path = file.into_temp_path().to_str().unwrap().to_string(); + let _ = fs::remove_file(path.clone()); + + let uds = UnixListener::bind(path.clone()).unwrap(); + let stream = UnixListenerStream::new(uds); + + let service = FlightSqlServiceImpl {}; + let serve_future = Server::builder() + .add_service(FlightServiceServer::new(service)) + .serve_with_incoming(stream); + + let request_future = async { + let connector = service_fn(move |_| UnixStream::connect(path.clone())); + let channel = Endpoint::try_from("http://example.com") + .unwrap() + .connect_with_connector(connector) + .await + .unwrap(); + let client = FlightSqlServiceClient::new(channel); + f(client).await + }; + + tokio::select! { + _ = serve_future => panic!("server returned first"), + _ = request_future => println!("Client finished!"), + } + } + + async fn test_http_client(f: F) + where + F: FnOnce(FlightSqlServiceClient) -> C, + C: Future, + { + let (incoming, addr) = bind_tcp().await; + let uri = format!("http://{}:{}", addr.ip(), addr.port()); + + let service = FlightSqlServiceImpl {}; + let serve_future = Server::builder() + .add_service(FlightServiceServer::new(service)) + .serve_with_incoming(incoming); + + let request_future = async { + let endpoint = endpoint(uri).unwrap(); + let channel = endpoint.connect().await.unwrap(); + let client = FlightSqlServiceClient::new(channel); + f(client).await + }; + + tokio::select! { + _ = serve_future => panic!("server returned first"), + _ = request_future => println!("Client finished!"), + } + } + + async fn test_https_client(f: F) + where + F: FnOnce(FlightSqlServiceClient) -> C, + C: Future, + { + let cert = std::fs::read_to_string("examples/data/server.pem").unwrap(); + let key = std::fs::read_to_string("examples/data/server.key").unwrap(); + let client_ca = std::fs::read_to_string("examples/data/client_ca.pem").unwrap(); + + let tls_config = ServerTlsConfig::new() + .identity(Identity::from_pem(&cert, &key)) + .client_ca_root(Certificate::from_pem(&client_ca)); + + let (incoming, addr) = bind_tcp().await; + let uri = format!("https://{}:{}", addr.ip(), addr.port()); + + let svc = FlightServiceServer::new(FlightSqlServiceImpl {}); + + let serve_future = Server::builder() + .tls_config(tls_config) + .unwrap() + .add_service(svc) + .serve_with_incoming(incoming); + + let request_future = async { + let cert = std::fs::read_to_string("examples/data/client1.pem").unwrap(); + let key = std::fs::read_to_string("examples/data/client1.key").unwrap(); + let server_ca = std::fs::read_to_string("examples/data/ca.pem").unwrap(); + + let tls_config = ClientTlsConfig::new() + .domain_name("localhost") + .ca_certificate(Certificate::from_pem(&server_ca)) + .identity(Identity::from_pem(cert, key)); + + let endpoint = endpoint(uri).unwrap().tls_config(tls_config).unwrap(); + let channel = endpoint.connect().await.unwrap(); + let client = FlightSqlServiceClient::new(channel); + f(client).await + }; + + tokio::select! { + _ = serve_future => panic!("server returned first"), + _ = request_future => println!("Client finished!"), + } + } + + async fn test_all_clients(task: F) + where + F: FnOnce(FlightSqlServiceClient) -> C + Copy, + C: Future, + { + println!("testing uds client"); + test_uds_client(task).await; + println!("======="); + + println!("testing http client"); + test_http_client(task).await; + println!("======="); + + println!("testing https client"); + test_https_client(task).await; + println!("======="); + } + + #[tokio::test] + async fn test_select() { + test_all_clients(|mut client| async move { + auth_client(&mut client).await; + + let mut stmt = client.prepare("select 1;".to_string(), None).await.unwrap(); + + let flight_info = stmt.execute().await.unwrap(); + + let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap().clone(); + let flight_data = client.do_get(ticket).await.unwrap(); + let batches: Vec<_> = flight_data.try_collect().await.unwrap(); + + let res = pretty_format_batches(batches.as_slice()).unwrap(); + let expected = r#" ++-------------------+ +| salutation | ++-------------------+ +| Hello, FlightSQL! | ++-------------------+"# + .trim() + .to_string(); + assert_eq!(res.to_string(), expected); + }) + .await + } + + #[tokio::test] + async fn test_execute_update() { + test_all_clients(|mut client| async move { + auth_client(&mut client).await; + let res = client + .execute_update("creat table test(a int);".to_string(), None) + .await + .unwrap(); + assert_eq!(res, FAKE_UPDATE_RESULT); + }) + .await + } + + #[tokio::test] + async fn test_auth() { + test_all_clients(|mut client| async move { + // no handshake + assert!(client + .prepare("select 1;".to_string(), None) + .await + .unwrap_err() + .to_string() + .contains("No authorization header")); + + // Invalid credentials + assert!(client + .handshake("admin", "password2") + .await + .unwrap_err() + .to_string() + .contains("Invalid credentials")); + + // forget to set_token + client.handshake("admin", "password").await.unwrap(); + assert!(client + .prepare("select 1;".to_string(), None) + .await + .unwrap_err() + .to_string() + .contains("No authorization header")); + + // Invalid Tokens + client.handshake("admin", "password").await.unwrap(); + client.set_token("wrong token".to_string()); + assert!(client + .prepare("select 1;".to_string(), None) + .await + .unwrap_err() + .to_string() + .contains("invalid token")); + }) + .await + } +} diff --git a/arrow-flight/examples/server.rs b/arrow-flight/examples/server.rs index 75d05378710f..85ac4ca1384c 100644 --- a/arrow-flight/examples/server.rs +++ b/arrow-flight/examples/server.rs @@ -15,16 +15,14 @@ // specific language governing permissions and limitations // under the License. -use std::pin::Pin; - -use futures::Stream; +use futures::stream::BoxStream; use tonic::transport::Server; use tonic::{Request, Response, Status, Streaming}; use arrow_flight::{ - flight_service_server::FlightService, flight_service_server::FlightServiceServer, - Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, - HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, + flight_service_server::FlightService, flight_service_server::FlightServiceServer, Action, + ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, + HandshakeResponse, PutResult, SchemaResult, Ticket, }; #[derive(Clone)] @@ -32,89 +30,75 @@ pub struct FlightServiceImpl {} #[tonic::async_trait] impl FlightService for FlightServiceImpl { - type HandshakeStream = Pin< - Box> + Send + Sync + 'static>, - >; - type ListFlightsStream = - Pin> + Send + Sync + 'static>>; - type DoGetStream = - Pin> + Send + Sync + 'static>>; - type DoPutStream = - Pin> + Send + Sync + 'static>>; - type DoActionStream = Pin< - Box< - dyn Stream> - + Send - + Sync - + 'static, - >, - >; - type ListActionsStream = - Pin> + Send + Sync + 'static>>; - type DoExchangeStream = - Pin> + Send + Sync + 'static>>; + type HandshakeStream = BoxStream<'static, Result>; + type ListFlightsStream = BoxStream<'static, Result>; + type DoGetStream = BoxStream<'static, Result>; + type DoPutStream = BoxStream<'static, Result>; + type DoActionStream = BoxStream<'static, Result>; + type ListActionsStream = BoxStream<'static, Result>; + type DoExchangeStream = BoxStream<'static, Result>; async fn handshake( &self, _request: Request>, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("Implement handshake")) } async fn list_flights( &self, _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("Implement list_flights")) } async fn get_flight_info( &self, _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("Implement get_flight_info")) } async fn get_schema( &self, _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("Implement get_schema")) } async fn do_get( &self, _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("Implement do_get")) } async fn do_put( &self, _request: Request>, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("Implement do_put")) } async fn do_action( &self, _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("Implement do_action")) } async fn list_actions( &self, _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("Implement list_actions")) } async fn do_exchange( &self, _request: Request>, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("Implement do_exchange")) } } diff --git a/arrow-flight/gen/Cargo.toml b/arrow-flight/gen/Cargo.toml new file mode 100644 index 000000000000..4f7a032f51e5 --- /dev/null +++ b/arrow-flight/gen/Cargo.toml @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "gen" +description = "Code generation for arrow-flight" +version = "0.1.0" +edition = { workspace = true } +rust-version = { workspace = true } +authors = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +publish = false + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +# Pin specific version of the tonic-build dependencies to avoid auto-generated +# (and checked in) arrow.flight.protocol.rs from changing +proc-macro2 = { version = "=1.0.70", default-features = false } +prost-build = { version = "=0.12.3", default-features = false } +tonic-build = { version = "=0.10.2", default-features = false, features = ["transport", "prost"] } diff --git a/arrow-flight/gen/src/main.rs b/arrow-flight/gen/src/main.rs new file mode 100644 index 000000000000..a3541c63b173 --- /dev/null +++ b/arrow-flight/gen/src/main.rs @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + fs::OpenOptions, + io::{Read, Write}, + path::Path, +}; + +fn main() -> Result<(), Box> { + let proto_dir = Path::new("../format"); + let proto_path = Path::new("../format/Flight.proto"); + + tonic_build::configure() + // protoc in unbuntu builder needs this option + .protoc_arg("--experimental_allow_proto3_optional") + .out_dir("src") + .compile_with_config(prost_config(), &[proto_path], &[proto_dir])?; + + // read file contents to string + let mut file = OpenOptions::new() + .read(true) + .open("src/arrow.flight.protocol.rs")?; + let mut buffer = String::new(); + file.read_to_string(&mut buffer)?; + // append warning that file was auto-generate + let mut file = OpenOptions::new() + .write(true) + .truncate(true) + .open("src/arrow.flight.protocol.rs")?; + file.write_all("// This file was automatically generated through the build.rs script, and should not be edited.\n\n".as_bytes())?; + file.write_all(buffer.as_bytes())?; + + let proto_dir = Path::new("../format"); + let proto_path = Path::new("../format/FlightSql.proto"); + + tonic_build::configure() + // protoc in ubuntu builder needs this option + .protoc_arg("--experimental_allow_proto3_optional") + .out_dir("src/sql") + .compile_with_config(prost_config(), &[proto_path], &[proto_dir])?; + + // read file contents to string + let mut file = OpenOptions::new() + .read(true) + .open("src/sql/arrow.flight.protocol.sql.rs")?; + let mut buffer = String::new(); + file.read_to_string(&mut buffer)?; + // append warning that file was auto-generate + let mut file = OpenOptions::new() + .write(true) + .truncate(true) + .open("src/sql/arrow.flight.protocol.sql.rs")?; + file.write_all("// This file was automatically generated through the build.rs script, and should not be edited.\n\n".as_bytes())?; + file.write_all(buffer.as_bytes())?; + + // Prost currently generates an empty file, this was fixed but then reverted + // https://github.com/tokio-rs/prost/pull/639 + let google_protobuf_rs = Path::new("src/sql/google.protobuf.rs"); + if google_protobuf_rs.exists() && google_protobuf_rs.metadata().unwrap().len() == 0 { + std::fs::remove_file(google_protobuf_rs).unwrap(); + } + + // As the proto file is checked in, the build should not fail if the file is not found + Ok(()) +} + +fn prost_config() -> prost_build::Config { + let mut config = prost_build::Config::new(); + config.bytes([".arrow"]); + config +} diff --git a/arrow-flight/regen.sh b/arrow-flight/regen.sh new file mode 100755 index 000000000000..d83f9d580e8d --- /dev/null +++ b/arrow-flight/regen.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd $SCRIPT_DIR && cargo run --manifest-path gen/Cargo.toml diff --git a/arrow-flight/src/arrow.flight.protocol.rs b/arrow-flight/src/arrow.flight.protocol.rs index 2b085d6d1f6b..e76013bd7c5f 100644 --- a/arrow-flight/src/arrow.flight.protocol.rs +++ b/arrow-flight/src/arrow.flight.protocol.rs @@ -1,117 +1,139 @@ // This file was automatically generated through the build.rs script, and should not be edited. /// -/// The request that a client provides to a server on handshake. +/// The request that a client provides to a server on handshake. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct HandshakeRequest { /// - /// A defined protocol version - #[prost(uint64, tag="1")] + /// A defined protocol version + #[prost(uint64, tag = "1")] pub protocol_version: u64, /// - /// Arbitrary auth/handshake info. - #[prost(bytes="vec", tag="2")] - pub payload: ::prost::alloc::vec::Vec, + /// Arbitrary auth/handshake info. + #[prost(bytes = "bytes", tag = "2")] + pub payload: ::prost::bytes::Bytes, } +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct HandshakeResponse { /// - /// A defined protocol version - #[prost(uint64, tag="1")] + /// A defined protocol version + #[prost(uint64, tag = "1")] pub protocol_version: u64, /// - /// Arbitrary auth/handshake info. - #[prost(bytes="vec", tag="2")] - pub payload: ::prost::alloc::vec::Vec, + /// Arbitrary auth/handshake info. + #[prost(bytes = "bytes", tag = "2")] + pub payload: ::prost::bytes::Bytes, } /// -/// A message for doing simple auth. +/// A message for doing simple auth. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct BasicAuth { - #[prost(string, tag="2")] + #[prost(string, tag = "2")] pub username: ::prost::alloc::string::String, - #[prost(string, tag="3")] + #[prost(string, tag = "3")] pub password: ::prost::alloc::string::String, } +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] -pub struct Empty { -} +pub struct Empty {} /// -/// Describes an available action, including both the name used for execution -/// along with a short description of the purpose of the action. +/// Describes an available action, including both the name used for execution +/// along with a short description of the purpose of the action. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ActionType { - #[prost(string, tag="1")] + #[prost(string, tag = "1")] pub r#type: ::prost::alloc::string::String, - #[prost(string, tag="2")] + #[prost(string, tag = "2")] pub description: ::prost::alloc::string::String, } /// -/// A service specific expression that can be used to return a limited set -/// of available Arrow Flight streams. +/// A service specific expression that can be used to return a limited set +/// of available Arrow Flight streams. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Criteria { - #[prost(bytes="vec", tag="1")] - pub expression: ::prost::alloc::vec::Vec, + #[prost(bytes = "bytes", tag = "1")] + pub expression: ::prost::bytes::Bytes, } /// -/// An opaque action specific for the service. +/// An opaque action specific for the service. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Action { - #[prost(string, tag="1")] + #[prost(string, tag = "1")] pub r#type: ::prost::alloc::string::String, - #[prost(bytes="vec", tag="2")] - pub body: ::prost::alloc::vec::Vec, + #[prost(bytes = "bytes", tag = "2")] + pub body: ::prost::bytes::Bytes, } /// -/// An opaque result returned after executing an action. +/// An opaque result returned after executing an action. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Result { - #[prost(bytes="vec", tag="1")] - pub body: ::prost::alloc::vec::Vec, + #[prost(bytes = "bytes", tag = "1")] + pub body: ::prost::bytes::Bytes, } /// -/// Wrap the result of a getSchema call +/// Wrap the result of a getSchema call +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct SchemaResult { - /// schema of the dataset as described in Schema.fbs::Schema. - #[prost(bytes="vec", tag="1")] - pub schema: ::prost::alloc::vec::Vec, + /// The schema of the dataset in its IPC form: + /// 4 bytes - an optional IPC_CONTINUATION_TOKEN prefix + /// 4 bytes - the byte length of the payload + /// a flatbuffer Message whose header is the Schema + #[prost(bytes = "bytes", tag = "1")] + pub schema: ::prost::bytes::Bytes, } /// -/// The name or tag for a Flight. May be used as a way to retrieve or generate -/// a flight or be used to expose a set of previously defined flights. +/// The name or tag for a Flight. May be used as a way to retrieve or generate +/// a flight or be used to expose a set of previously defined flights. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FlightDescriptor { - #[prost(enumeration="flight_descriptor::DescriptorType", tag="1")] + #[prost(enumeration = "flight_descriptor::DescriptorType", tag = "1")] pub r#type: i32, /// - /// Opaque value used to express a command. Should only be defined when - /// type = CMD. - #[prost(bytes="vec", tag="2")] - pub cmd: ::prost::alloc::vec::Vec, + /// Opaque value used to express a command. Should only be defined when + /// type = CMD. + #[prost(bytes = "bytes", tag = "2")] + pub cmd: ::prost::bytes::Bytes, /// - /// List of strings identifying a particular dataset. Should only be defined - /// when type = PATH. - #[prost(string, repeated, tag="3")] + /// List of strings identifying a particular dataset. Should only be defined + /// when type = PATH. + #[prost(string, repeated, tag = "3")] pub path: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, } /// Nested message and enum types in `FlightDescriptor`. pub mod flight_descriptor { /// - /// Describes what type of descriptor is defined. - #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] + /// Describes what type of descriptor is defined. + #[derive( + Clone, + Copy, + Debug, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + ::prost::Enumeration + )] #[repr(i32)] pub enum DescriptorType { - /// Protobuf pattern, not used. + /// Protobuf pattern, not used. Unknown = 0, /// - /// A named path that identifies a dataset. A path is composed of a string - /// or list of strings describing a particular dataset. This is conceptually + /// A named path that identifies a dataset. A path is composed of a string + /// or list of strings describing a particular dataset. This is conceptually /// similar to a path inside a filesystem. Path = 1, /// - /// An opaque command to generate a dataset. + /// An opaque command to generate a dataset. Cmd = 2, } impl DescriptorType { @@ -126,93 +148,149 @@ pub mod flight_descriptor { DescriptorType::Cmd => "CMD", } } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "UNKNOWN" => Some(Self::Unknown), + "PATH" => Some(Self::Path), + "CMD" => Some(Self::Cmd), + _ => None, + } + } } } /// -/// The access coordinates for retrieval of a dataset. With a FlightInfo, a -/// consumer is able to determine how to retrieve a dataset. +/// The access coordinates for retrieval of a dataset. With a FlightInfo, a +/// consumer is able to determine how to retrieve a dataset. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FlightInfo { - /// schema of the dataset as described in Schema.fbs::Schema. - #[prost(bytes="vec", tag="1")] - pub schema: ::prost::alloc::vec::Vec, + /// The schema of the dataset in its IPC form: + /// 4 bytes - an optional IPC_CONTINUATION_TOKEN prefix + /// 4 bytes - the byte length of the payload + /// a flatbuffer Message whose header is the Schema + #[prost(bytes = "bytes", tag = "1")] + pub schema: ::prost::bytes::Bytes, /// - /// The descriptor associated with this info. - #[prost(message, optional, tag="2")] + /// The descriptor associated with this info. + #[prost(message, optional, tag = "2")] pub flight_descriptor: ::core::option::Option, /// - /// A list of endpoints associated with the flight. To consume the whole - /// flight, all endpoints must be consumed. - #[prost(message, repeated, tag="3")] + /// A list of endpoints associated with the flight. To consume the + /// whole flight, all endpoints (and hence all Tickets) must be + /// consumed. Endpoints can be consumed in any order. + /// + /// In other words, an application can use multiple endpoints to + /// represent partitioned data. + /// + /// If the returned data has an ordering, an application can use + /// "FlightInfo.ordered = true" or should return the all data in a + /// single endpoint. Otherwise, there is no ordering defined on + /// endpoints or the data within. + /// + /// A client can read ordered data by reading data from returned + /// endpoints, in order, from front to back. + /// + /// Note that a client may ignore "FlightInfo.ordered = true". If an + /// ordering is important for an application, an application must + /// choose one of them: + /// + /// * An application requires that all clients must read data in + /// returned endpoints order. + /// * An application must return the all data in a single endpoint. + #[prost(message, repeated, tag = "3")] pub endpoint: ::prost::alloc::vec::Vec, - /// Set these to -1 if unknown. - #[prost(int64, tag="4")] + /// Set these to -1 if unknown. + #[prost(int64, tag = "4")] pub total_records: i64, - #[prost(int64, tag="5")] + #[prost(int64, tag = "5")] pub total_bytes: i64, + /// + /// FlightEndpoints are in the same order as the data. + #[prost(bool, tag = "6")] + pub ordered: bool, } /// -/// A particular stream or split associated with a flight. +/// A particular stream or split associated with a flight. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FlightEndpoint { /// - /// Token used to retrieve this stream. - #[prost(message, optional, tag="1")] + /// Token used to retrieve this stream. + #[prost(message, optional, tag = "1")] pub ticket: ::core::option::Option, /// - /// A list of URIs where this ticket can be redeemed. If the list is - /// empty, the expectation is that the ticket can only be redeemed on the - /// current service where the ticket was generated. - #[prost(message, repeated, tag="2")] + /// A list of URIs where this ticket can be redeemed via DoGet(). + /// + /// If the list is empty, the expectation is that the ticket can only + /// be redeemed on the current service where the ticket was + /// generated. + /// + /// If the list is not empty, the expectation is that the ticket can + /// be redeemed at any of the locations, and that the data returned + /// will be equivalent. In this case, the ticket may only be redeemed + /// at one of the given locations, and not (necessarily) on the + /// current service. + /// + /// In other words, an application can use multiple locations to + /// represent redundant and/or load balanced services. + #[prost(message, repeated, tag = "2")] pub location: ::prost::alloc::vec::Vec, } /// -/// A location where a Flight service will accept retrieval of a particular -/// stream given a ticket. +/// A location where a Flight service will accept retrieval of a particular +/// stream given a ticket. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Location { - #[prost(string, tag="1")] + #[prost(string, tag = "1")] pub uri: ::prost::alloc::string::String, } /// -/// An opaque identifier that the service can use to retrieve a particular -/// portion of a stream. +/// An opaque identifier that the service can use to retrieve a particular +/// portion of a stream. +/// +/// Tickets are meant to be single use. It is an error/application-defined +/// behavior to reuse a ticket. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Ticket { - #[prost(bytes="vec", tag="1")] - pub ticket: ::prost::alloc::vec::Vec, + #[prost(bytes = "bytes", tag = "1")] + pub ticket: ::prost::bytes::Bytes, } /// -/// A batch of Arrow data as part of a stream of batches. +/// A batch of Arrow data as part of a stream of batches. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FlightData { /// - /// The descriptor of the data. This is only relevant when a client is - /// starting a new DoPut stream. - #[prost(message, optional, tag="1")] + /// The descriptor of the data. This is only relevant when a client is + /// starting a new DoPut stream. + #[prost(message, optional, tag = "1")] pub flight_descriptor: ::core::option::Option, /// - /// Header for message data as described in Message.fbs::Message. - #[prost(bytes="vec", tag="2")] - pub data_header: ::prost::alloc::vec::Vec, + /// Header for message data as described in Message.fbs::Message. + #[prost(bytes = "bytes", tag = "2")] + pub data_header: ::prost::bytes::Bytes, /// - /// Application-defined metadata. - #[prost(bytes="vec", tag="3")] - pub app_metadata: ::prost::alloc::vec::Vec, + /// Application-defined metadata. + #[prost(bytes = "bytes", tag = "3")] + pub app_metadata: ::prost::bytes::Bytes, /// - /// The actual batch of Arrow data. Preferably handled with minimal-copies - /// coming last in the definition to help with sidecar patterns (it is - /// expected that some implementations will fetch this field off the wire - /// with specialized code to avoid extra memory copies). - #[prost(bytes="vec", tag="1000")] - pub data_body: ::prost::alloc::vec::Vec, + /// The actual batch of Arrow data. Preferably handled with minimal-copies + /// coming last in the definition to help with sidecar patterns (it is + /// expected that some implementations will fetch this field off the wire + /// with specialized code to avoid extra memory copies). + #[prost(bytes = "bytes", tag = "1000")] + pub data_body: ::prost::bytes::Bytes, } /// * -/// The response message associated with the submission of a DoPut. +/// The response message associated with the submission of a DoPut. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PutResult { - #[prost(bytes="vec", tag="1")] - pub app_metadata: ::prost::alloc::vec::Vec, + #[prost(bytes = "bytes", tag = "1")] + pub app_metadata: ::prost::bytes::Bytes, } /// Generated client implementations. pub mod flight_service_client { @@ -232,7 +310,7 @@ pub mod flight_service_client { /// Attempt to create a new client by connecting to a given endpoint. pub async fn connect(dst: D) -> Result where - D: std::convert::TryInto, + D: TryInto, D::Error: Into, { let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; @@ -288,6 +366,22 @@ pub mod flight_service_client { self.inner = self.inner.accept_compressed(encoding); self } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } /// /// Handshake between client and server. Depending on the server, the /// handshake may be required to determine the token that should be used for @@ -296,7 +390,7 @@ pub mod flight_service_client { pub async fn handshake( &mut self, request: impl tonic::IntoStreamingRequest, - ) -> Result< + ) -> std::result::Result< tonic::Response>, tonic::Status, > { @@ -313,7 +407,12 @@ pub mod flight_service_client { let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/Handshake", ); - self.inner.streaming(request.into_streaming_request(), path, codec).await + let mut req = request.into_streaming_request(); + req.extensions_mut() + .insert( + GrpcMethod::new("arrow.flight.protocol.FlightService", "Handshake"), + ); + self.inner.streaming(req, path, codec).await } /// /// Get a list of available streams given a particular criteria. Most flight @@ -325,7 +424,7 @@ pub mod flight_service_client { pub async fn list_flights( &mut self, request: impl tonic::IntoRequest, - ) -> Result< + ) -> std::result::Result< tonic::Response>, tonic::Status, > { @@ -342,7 +441,12 @@ pub mod flight_service_client { let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/ListFlights", ); - self.inner.server_streaming(request.into_request(), path, codec).await + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new("arrow.flight.protocol.FlightService", "ListFlights"), + ); + self.inner.server_streaming(req, path, codec).await } /// /// For a given FlightDescriptor, get information about how the flight can be @@ -358,7 +462,7 @@ pub mod flight_service_client { pub async fn get_flight_info( &mut self, request: impl tonic::IntoRequest, - ) -> Result, tonic::Status> { + ) -> std::result::Result, tonic::Status> { self.inner .ready() .await @@ -372,7 +476,15 @@ pub mod flight_service_client { let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/GetFlightInfo", ); - self.inner.unary(request.into_request(), path, codec).await + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new( + "arrow.flight.protocol.FlightService", + "GetFlightInfo", + ), + ); + self.inner.unary(req, path, codec).await } /// /// For a given FlightDescriptor, get the Schema as described in Schema.fbs::Schema @@ -382,7 +494,7 @@ pub mod flight_service_client { pub async fn get_schema( &mut self, request: impl tonic::IntoRequest, - ) -> Result, tonic::Status> { + ) -> std::result::Result, tonic::Status> { self.inner .ready() .await @@ -396,7 +508,12 @@ pub mod flight_service_client { let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/GetSchema", ); - self.inner.unary(request.into_request(), path, codec).await + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new("arrow.flight.protocol.FlightService", "GetSchema"), + ); + self.inner.unary(req, path, codec).await } /// /// Retrieve a single stream associated with a particular descriptor @@ -406,7 +523,7 @@ pub mod flight_service_client { pub async fn do_get( &mut self, request: impl tonic::IntoRequest, - ) -> Result< + ) -> std::result::Result< tonic::Response>, tonic::Status, > { @@ -423,7 +540,10 @@ pub mod flight_service_client { let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/DoGet", ); - self.inner.server_streaming(request.into_request(), path, codec).await + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("arrow.flight.protocol.FlightService", "DoGet")); + self.inner.server_streaming(req, path, codec).await } /// /// Push a stream to the flight service associated with a particular @@ -435,7 +555,7 @@ pub mod flight_service_client { pub async fn do_put( &mut self, request: impl tonic::IntoStreamingRequest, - ) -> Result< + ) -> std::result::Result< tonic::Response>, tonic::Status, > { @@ -452,7 +572,10 @@ pub mod flight_service_client { let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/DoPut", ); - self.inner.streaming(request.into_streaming_request(), path, codec).await + let mut req = request.into_streaming_request(); + req.extensions_mut() + .insert(GrpcMethod::new("arrow.flight.protocol.FlightService", "DoPut")); + self.inner.streaming(req, path, codec).await } /// /// Open a bidirectional data channel for a given descriptor. This @@ -463,7 +586,7 @@ pub mod flight_service_client { pub async fn do_exchange( &mut self, request: impl tonic::IntoStreamingRequest, - ) -> Result< + ) -> std::result::Result< tonic::Response>, tonic::Status, > { @@ -480,7 +603,12 @@ pub mod flight_service_client { let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/DoExchange", ); - self.inner.streaming(request.into_streaming_request(), path, codec).await + let mut req = request.into_streaming_request(); + req.extensions_mut() + .insert( + GrpcMethod::new("arrow.flight.protocol.FlightService", "DoExchange"), + ); + self.inner.streaming(req, path, codec).await } /// /// Flight services can support an arbitrary number of simple actions in @@ -492,7 +620,7 @@ pub mod flight_service_client { pub async fn do_action( &mut self, request: impl tonic::IntoRequest, - ) -> Result< + ) -> std::result::Result< tonic::Response>, tonic::Status, > { @@ -509,7 +637,12 @@ pub mod flight_service_client { let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/DoAction", ); - self.inner.server_streaming(request.into_request(), path, codec).await + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new("arrow.flight.protocol.FlightService", "DoAction"), + ); + self.inner.server_streaming(req, path, codec).await } /// /// A flight service exposes all of the available action types that it has @@ -518,7 +651,7 @@ pub mod flight_service_client { pub async fn list_actions( &mut self, request: impl tonic::IntoRequest, - ) -> Result< + ) -> std::result::Result< tonic::Response>, tonic::Status, > { @@ -535,7 +668,12 @@ pub mod flight_service_client { let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/ListActions", ); - self.inner.server_streaming(request.into_request(), path, codec).await + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new("arrow.flight.protocol.FlightService", "ListActions"), + ); + self.inner.server_streaming(req, path, codec).await } } } @@ -543,12 +681,12 @@ pub mod flight_service_client { pub mod flight_service_server { #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] use tonic::codegen::*; - ///Generated trait containing gRPC methods that should be implemented for use with FlightServiceServer. + /// Generated trait containing gRPC methods that should be implemented for use with FlightServiceServer. #[async_trait] pub trait FlightService: Send + Sync + 'static { - ///Server streaming response type for the Handshake method. - type HandshakeStream: futures_core::Stream< - Item = Result, + /// Server streaming response type for the Handshake method. + type HandshakeStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, > + Send + 'static; @@ -560,10 +698,10 @@ pub mod flight_service_server { async fn handshake( &self, request: tonic::Request>, - ) -> Result, tonic::Status>; - ///Server streaming response type for the ListFlights method. - type ListFlightsStream: futures_core::Stream< - Item = Result, + ) -> std::result::Result, tonic::Status>; + /// Server streaming response type for the ListFlights method. + type ListFlightsStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, > + Send + 'static; @@ -577,7 +715,10 @@ pub mod flight_service_server { async fn list_flights( &self, request: tonic::Request, - ) -> Result, tonic::Status>; + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; /// /// For a given FlightDescriptor, get information about how the flight can be /// consumed. This is a useful interface if the consumer of the interface @@ -592,7 +733,7 @@ pub mod flight_service_server { async fn get_flight_info( &self, request: tonic::Request, - ) -> Result, tonic::Status>; + ) -> std::result::Result, tonic::Status>; /// /// For a given FlightDescriptor, get the Schema as described in Schema.fbs::Schema /// This is used when a consumer needs the Schema of flight stream. Similar to @@ -601,10 +742,10 @@ pub mod flight_service_server { async fn get_schema( &self, request: tonic::Request, - ) -> Result, tonic::Status>; - ///Server streaming response type for the DoGet method. - type DoGetStream: futures_core::Stream< - Item = Result, + ) -> std::result::Result, tonic::Status>; + /// Server streaming response type for the DoGet method. + type DoGetStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, > + Send + 'static; @@ -616,10 +757,10 @@ pub mod flight_service_server { async fn do_get( &self, request: tonic::Request, - ) -> Result, tonic::Status>; - ///Server streaming response type for the DoPut method. - type DoPutStream: futures_core::Stream< - Item = Result, + ) -> std::result::Result, tonic::Status>; + /// Server streaming response type for the DoPut method. + type DoPutStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, > + Send + 'static; @@ -633,10 +774,10 @@ pub mod flight_service_server { async fn do_put( &self, request: tonic::Request>, - ) -> Result, tonic::Status>; - ///Server streaming response type for the DoExchange method. - type DoExchangeStream: futures_core::Stream< - Item = Result, + ) -> std::result::Result, tonic::Status>; + /// Server streaming response type for the DoExchange method. + type DoExchangeStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, > + Send + 'static; @@ -649,10 +790,10 @@ pub mod flight_service_server { async fn do_exchange( &self, request: tonic::Request>, - ) -> Result, tonic::Status>; - ///Server streaming response type for the DoAction method. - type DoActionStream: futures_core::Stream< - Item = Result, + ) -> std::result::Result, tonic::Status>; + /// Server streaming response type for the DoAction method. + type DoActionStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, > + Send + 'static; @@ -666,10 +807,10 @@ pub mod flight_service_server { async fn do_action( &self, request: tonic::Request, - ) -> Result, tonic::Status>; - ///Server streaming response type for the ListActions method. - type ListActionsStream: futures_core::Stream< - Item = Result, + ) -> std::result::Result, tonic::Status>; + /// Server streaming response type for the ListActions method. + type ListActionsStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, > + Send + 'static; @@ -680,7 +821,10 @@ pub mod flight_service_server { async fn list_actions( &self, request: tonic::Request, - ) -> Result, tonic::Status>; + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; } /// /// A flight service is an endpoint for retrieving or storing Arrow data. A @@ -692,6 +836,8 @@ pub mod flight_service_server { inner: _Inner, accept_compression_encodings: EnabledCompressionEncodings, send_compression_encodings: EnabledCompressionEncodings, + max_decoding_message_size: Option, + max_encoding_message_size: Option, } struct _Inner(Arc); impl FlightServiceServer { @@ -704,6 +850,8 @@ pub mod flight_service_server { inner, accept_compression_encodings: Default::default(), send_compression_encodings: Default::default(), + max_decoding_message_size: None, + max_encoding_message_size: None, } } pub fn with_interceptor( @@ -727,6 +875,22 @@ pub mod flight_service_server { self.send_compression_encodings.enable(encoding); self } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.max_decoding_message_size = Some(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.max_encoding_message_size = Some(limit); + self + } } impl tonic::codegen::Service> for FlightServiceServer where @@ -740,7 +904,7 @@ pub mod flight_service_server { fn poll_ready( &mut self, _cx: &mut Context<'_>, - ) -> Poll> { + ) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, req: http::Request) -> Self::Future { @@ -765,13 +929,17 @@ pub mod flight_service_server { tonic::Streaming, >, ) -> Self::Future { - let inner = self.0.clone(); - let fut = async move { (*inner).handshake(request).await }; + let inner = Arc::clone(&self.0); + let fut = async move { + ::handshake(&inner, request).await + }; Box::pin(fut) } } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { let inner = inner.0; @@ -781,6 +949,10 @@ pub mod flight_service_server { .apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.streaming(method, req).await; Ok(res) @@ -804,15 +976,17 @@ pub mod flight_service_server { &mut self, request: tonic::Request, ) -> Self::Future { - let inner = self.0.clone(); + let inner = Arc::clone(&self.0); let fut = async move { - (*inner).list_flights(request).await + ::list_flights(&inner, request).await }; Box::pin(fut) } } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { let inner = inner.0; @@ -822,6 +996,10 @@ pub mod flight_service_server { .apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.server_streaming(method, req).await; Ok(res) @@ -844,15 +1022,17 @@ pub mod flight_service_server { &mut self, request: tonic::Request, ) -> Self::Future { - let inner = self.0.clone(); + let inner = Arc::clone(&self.0); let fut = async move { - (*inner).get_flight_info(request).await + ::get_flight_info(&inner, request).await }; Box::pin(fut) } } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { let inner = inner.0; @@ -862,6 +1042,10 @@ pub mod flight_service_server { .apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.unary(method, req).await; Ok(res) @@ -884,13 +1068,17 @@ pub mod flight_service_server { &mut self, request: tonic::Request, ) -> Self::Future { - let inner = self.0.clone(); - let fut = async move { (*inner).get_schema(request).await }; + let inner = Arc::clone(&self.0); + let fut = async move { + ::get_schema(&inner, request).await + }; Box::pin(fut) } } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { let inner = inner.0; @@ -900,6 +1088,10 @@ pub mod flight_service_server { .apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.unary(method, req).await; Ok(res) @@ -923,13 +1115,17 @@ pub mod flight_service_server { &mut self, request: tonic::Request, ) -> Self::Future { - let inner = self.0.clone(); - let fut = async move { (*inner).do_get(request).await }; + let inner = Arc::clone(&self.0); + let fut = async move { + ::do_get(&inner, request).await + }; Box::pin(fut) } } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { let inner = inner.0; @@ -939,6 +1135,10 @@ pub mod flight_service_server { .apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.server_streaming(method, req).await; Ok(res) @@ -962,13 +1162,17 @@ pub mod flight_service_server { &mut self, request: tonic::Request>, ) -> Self::Future { - let inner = self.0.clone(); - let fut = async move { (*inner).do_put(request).await }; + let inner = Arc::clone(&self.0); + let fut = async move { + ::do_put(&inner, request).await + }; Box::pin(fut) } } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { let inner = inner.0; @@ -978,6 +1182,10 @@ pub mod flight_service_server { .apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.streaming(method, req).await; Ok(res) @@ -1001,13 +1209,17 @@ pub mod flight_service_server { &mut self, request: tonic::Request>, ) -> Self::Future { - let inner = self.0.clone(); - let fut = async move { (*inner).do_exchange(request).await }; + let inner = Arc::clone(&self.0); + let fut = async move { + ::do_exchange(&inner, request).await + }; Box::pin(fut) } } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { let inner = inner.0; @@ -1017,6 +1229,10 @@ pub mod flight_service_server { .apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.streaming(method, req).await; Ok(res) @@ -1040,13 +1256,17 @@ pub mod flight_service_server { &mut self, request: tonic::Request, ) -> Self::Future { - let inner = self.0.clone(); - let fut = async move { (*inner).do_action(request).await }; + let inner = Arc::clone(&self.0); + let fut = async move { + ::do_action(&inner, request).await + }; Box::pin(fut) } } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { let inner = inner.0; @@ -1056,6 +1276,10 @@ pub mod flight_service_server { .apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.server_streaming(method, req).await; Ok(res) @@ -1079,15 +1303,17 @@ pub mod flight_service_server { &mut self, request: tonic::Request, ) -> Self::Future { - let inner = self.0.clone(); + let inner = Arc::clone(&self.0); let fut = async move { - (*inner).list_actions(request).await + ::list_actions(&inner, request).await }; Box::pin(fut) } } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { let inner = inner.0; @@ -1097,6 +1323,10 @@ pub mod flight_service_server { .apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.server_streaming(method, req).await; Ok(res) @@ -1125,12 +1355,14 @@ pub mod flight_service_server { inner, accept_compression_encodings: self.accept_compression_encodings, send_compression_encodings: self.send_compression_encodings, + max_decoding_message_size: self.max_decoding_message_size, + max_encoding_message_size: self.max_encoding_message_size, } } } impl Clone for _Inner { fn clone(&self) -> Self { - Self(self.0.clone()) + Self(Arc::clone(&self.0)) } } impl std::fmt::Debug for _Inner { diff --git a/arrow-flight/src/bin/flight_sql_client.rs b/arrow-flight/src/bin/flight_sql_client.rs new file mode 100644 index 000000000000..296efc1c308e --- /dev/null +++ b/arrow-flight/src/bin/flight_sql_client.rs @@ -0,0 +1,351 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{sync::Arc, time::Duration}; + +use anyhow::{bail, Context, Result}; +use arrow_array::{ArrayRef, Datum, RecordBatch, StringArray}; +use arrow_cast::{cast_with_options, pretty::pretty_format_batches, CastOptions}; +use arrow_flight::{sql::client::FlightSqlServiceClient, FlightInfo}; +use arrow_schema::Schema; +use clap::{Parser, Subcommand}; +use futures::TryStreamExt; +use tonic::{ + metadata::MetadataMap, + transport::{Channel, ClientTlsConfig, Endpoint}, +}; +use tracing_log::log::info; + +/// Logging CLI config. +#[derive(Debug, Parser)] +pub struct LoggingArgs { + /// Log verbosity. + /// + /// Defaults to "warn". + /// + /// Use `-v` for "info", `-vv` for "debug", `-vvv` for "trace". + /// + /// Note you can also set logging level using `RUST_LOG` environment variable: + /// `RUST_LOG=debug`. + #[clap( + short = 'v', + long = "verbose", + action = clap::ArgAction::Count, + )] + log_verbose_count: u8, +} + +#[derive(Debug, Parser)] +struct ClientArgs { + /// Additional headers. + /// + /// Can be given multiple times. Headers and values are separated by '='. + /// + /// Example: `-H foo=bar -H baz=42` + #[clap(long = "header", short = 'H', value_parser = parse_key_val)] + headers: Vec<(String, String)>, + + /// Username. + /// + /// Optional. If given, `password` must also be set. + #[clap(long, requires = "password")] + username: Option, + + /// Password. + /// + /// Optional. If given, `username` must also be set. + #[clap(long, requires = "username")] + password: Option, + + /// Auth token. + #[clap(long)] + token: Option, + + /// Use TLS. + /// + /// If not provided, use cleartext connection. + #[clap(long)] + tls: bool, + + /// Server host. + /// + /// Required. + #[clap(long)] + host: String, + + /// Server port. + /// + /// Defaults to `443` if `tls` is set, otherwise defaults to `80`. + #[clap(long)] + port: Option, +} + +#[derive(Debug, Parser)] +struct Args { + /// Logging args. + #[clap(flatten)] + logging_args: LoggingArgs, + + /// Client args. + #[clap(flatten)] + client_args: ClientArgs, + + #[clap(subcommand)] + cmd: Command, +} + +/// Different available commands. +#[derive(Debug, Subcommand)] +enum Command { + /// Execute given statement. + StatementQuery { + /// SQL query. + /// + /// Required. + query: String, + }, + + /// Prepare given statement and then execute it. + PreparedStatementQuery { + /// SQL query. + /// + /// Required. + /// + /// Can contains placeholders like `$1`. + /// + /// Example: `SELECT * FROM t WHERE x = $1` + query: String, + + /// Additional parameters. + /// + /// Can be given multiple times. Names and values are separated by '='. Values will be + /// converted to the type that the server reported for the prepared statement. + /// + /// Example: `-p $1=42` + #[clap(short, value_parser = parse_key_val)] + params: Vec<(String, String)>, + }, +} + +#[tokio::main] +async fn main() -> Result<()> { + let args = Args::parse(); + setup_logging(args.logging_args)?; + let mut client = setup_client(args.client_args) + .await + .context("setup client")?; + + let flight_info = match args.cmd { + Command::StatementQuery { query } => client + .execute(query, None) + .await + .context("execute statement")?, + Command::PreparedStatementQuery { query, params } => { + let mut prepared_stmt = client + .prepare(query, None) + .await + .context("prepare statement")?; + + if !params.is_empty() { + prepared_stmt + .set_parameters( + construct_record_batch_from_params( + ¶ms, + prepared_stmt + .parameter_schema() + .context("get parameter schema")?, + ) + .context("construct parameters")?, + ) + .context("bind parameters")?; + } + + prepared_stmt + .execute() + .await + .context("execute prepared statement")? + } + }; + + let batches = execute_flight(&mut client, flight_info) + .await + .context("read flight data")?; + + let res = pretty_format_batches(batches.as_slice()).context("format results")?; + println!("{res}"); + + Ok(()) +} + +async fn execute_flight( + client: &mut FlightSqlServiceClient, + info: FlightInfo, +) -> Result> { + let schema = Arc::new(Schema::try_from(info.clone()).context("valid schema")?); + let mut batches = Vec::with_capacity(info.endpoint.len() + 1); + batches.push(RecordBatch::new_empty(schema)); + info!("decoded schema"); + + for endpoint in info.endpoint { + let Some(ticket) = &endpoint.ticket else { + bail!("did not get ticket"); + }; + + let mut flight_data = client.do_get(ticket.clone()).await.context("do get")?; + log_metadata(flight_data.headers(), "header"); + + let mut endpoint_batches: Vec<_> = (&mut flight_data) + .try_collect() + .await + .context("collect data stream")?; + batches.append(&mut endpoint_batches); + + if let Some(trailers) = flight_data.trailers() { + log_metadata(&trailers, "trailer"); + } + } + info!("received data"); + + Ok(batches) +} + +fn construct_record_batch_from_params( + params: &[(String, String)], + parameter_schema: &Schema, +) -> Result { + let mut items = Vec::<(&String, ArrayRef)>::new(); + + for (name, value) in params { + let field = parameter_schema.field_with_name(name)?; + let value_as_array = StringArray::new_scalar(value); + let casted = cast_with_options( + value_as_array.get().0, + field.data_type(), + &CastOptions::default(), + )?; + items.push((name, casted)) + } + + Ok(RecordBatch::try_from_iter(items)?) +} + +fn setup_logging(args: LoggingArgs) -> Result<()> { + use tracing_subscriber::{util::SubscriberInitExt, EnvFilter, FmtSubscriber}; + + tracing_log::LogTracer::init().context("tracing log init")?; + + let filter = match args.log_verbose_count { + 0 => "warn", + 1 => "info", + 2 => "debug", + _ => "trace", + }; + let filter = EnvFilter::try_new(filter).context("set up log env filter")?; + + let subscriber = FmtSubscriber::builder().with_env_filter(filter).finish(); + subscriber.try_init().context("init logging subscriber")?; + + Ok(()) +} + +async fn setup_client(args: ClientArgs) -> Result> { + let port = args.port.unwrap_or(if args.tls { 443 } else { 80 }); + + let protocol = if args.tls { "https" } else { "http" }; + + let mut endpoint = Endpoint::new(format!("{}://{}:{}", protocol, args.host, port)) + .context("create endpoint")? + .connect_timeout(Duration::from_secs(20)) + .timeout(Duration::from_secs(20)) + .tcp_nodelay(true) // Disable Nagle's Algorithm since we don't want packets to wait + .tcp_keepalive(Option::Some(Duration::from_secs(3600))) + .http2_keep_alive_interval(Duration::from_secs(300)) + .keep_alive_timeout(Duration::from_secs(20)) + .keep_alive_while_idle(true); + + if args.tls { + let tls_config = ClientTlsConfig::new(); + endpoint = endpoint + .tls_config(tls_config) + .context("create TLS endpoint")?; + } + + let channel = endpoint.connect().await.context("connect to endpoint")?; + + let mut client = FlightSqlServiceClient::new(channel); + info!("connected"); + + for (k, v) in args.headers { + client.set_header(k, v); + } + + if let Some(token) = args.token { + client.set_token(token); + info!("token set"); + } + + match (args.username, args.password) { + (None, None) => {} + (Some(username), Some(password)) => { + client + .handshake(&username, &password) + .await + .context("handshake")?; + info!("performed handshake"); + } + (Some(_), None) => { + bail!("when username is set, you also need to set a password") + } + (None, Some(_)) => { + bail!("when password is set, you also need to set a username") + } + } + + Ok(client) +} + +/// Parse a single key-value pair +fn parse_key_val(s: &str) -> Result<(String, String), String> { + let pos = s + .find('=') + .ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?; + Ok((s[..pos].to_owned(), s[pos + 1..].to_owned())) +} + +/// Log headers/trailers. +fn log_metadata(map: &MetadataMap, what: &'static str) { + for k_v in map.iter() { + match k_v { + tonic::metadata::KeyAndValueRef::Ascii(k, v) => { + info!( + "{}: {}={}", + what, + k.as_str(), + v.to_str().unwrap_or(""), + ); + } + tonic::metadata::KeyAndValueRef::Binary(k, v) => { + info!( + "{}: {}={}", + what, + k.as_str(), + String::from_utf8_lossy(v.as_ref()), + ); + } + } + } +} diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs new file mode 100644 index 000000000000..a264012c82ec --- /dev/null +++ b/arrow-flight/src/client.rs @@ -0,0 +1,550 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::task::Poll; + +use crate::{ + decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, + trailers::extract_lazy_trailers, Action, ActionType, Criteria, Empty, FlightData, + FlightDescriptor, FlightInfo, HandshakeRequest, PutResult, Ticket, +}; +use arrow_schema::Schema; +use bytes::Bytes; +use futures::{ + future::ready, + ready, + stream::{self, BoxStream}, + FutureExt, Stream, StreamExt, TryStreamExt, +}; +use tonic::{metadata::MetadataMap, transport::Channel}; + +use crate::error::{FlightError, Result}; + +/// A "Mid level" [Apache Arrow Flight](https://arrow.apache.org/docs/format/Flight.html) client. +/// +/// [`FlightClient`] is intended as a convenience for interactions +/// with Arrow Flight servers. For more direct control, such as access +/// to the response headers, use [`FlightServiceClient`] directly +/// via methods such as [`Self::inner`] or [`Self::into_inner`]. +/// +/// # Example: +/// ```no_run +/// # async fn run() { +/// # use arrow_flight::FlightClient; +/// # use bytes::Bytes; +/// use tonic::transport::Channel; +/// let channel = Channel::from_static("http://localhost:1234") +/// .connect() +/// .await +/// .expect("error connecting"); +/// +/// let mut client = FlightClient::new(channel); +/// +/// // Send 'Hi' bytes as the handshake request to the server +/// let response = client +/// .handshake(Bytes::from("Hi")) +/// .await +/// .expect("error handshaking"); +/// +/// // Expect the server responded with 'Ho' +/// assert_eq!(response, Bytes::from("Ho")); +/// # } +/// ``` +#[derive(Debug)] +pub struct FlightClient { + /// Optional grpc header metadata to include with each request + metadata: MetadataMap, + + /// The inner client + inner: FlightServiceClient, +} + +impl FlightClient { + /// Creates a client client with the provided [`Channel`] + pub fn new(channel: Channel) -> Self { + Self::new_from_inner(FlightServiceClient::new(channel)) + } + + /// Creates a new higher level client with the provided lower level client + pub fn new_from_inner(inner: FlightServiceClient) -> Self { + Self { + metadata: MetadataMap::new(), + inner, + } + } + + /// Return a reference to gRPC metadata included with each request + pub fn metadata(&self) -> &MetadataMap { + &self.metadata + } + + /// Return a reference to gRPC metadata included with each request + /// + /// These headers can be used, for example, to include + /// authorization or other application specific headers. + pub fn metadata_mut(&mut self) -> &mut MetadataMap { + &mut self.metadata + } + + /// Add the specified header with value to all subsequent + /// requests. See [`Self::metadata_mut`] for fine grained control. + pub fn add_header(&mut self, key: &str, value: &str) -> Result<()> { + let key = tonic::metadata::MetadataKey::<_>::from_bytes(key.as_bytes()) + .map_err(|e| FlightError::ExternalError(Box::new(e)))?; + + let value = value + .parse() + .map_err(|e| FlightError::ExternalError(Box::new(e)))?; + + // ignore previous value + self.metadata.insert(key, value); + + Ok(()) + } + + /// Return a reference to the underlying tonic + /// [`FlightServiceClient`] + pub fn inner(&self) -> &FlightServiceClient { + &self.inner + } + + /// Return a mutable reference to the underlying tonic + /// [`FlightServiceClient`] + pub fn inner_mut(&mut self) -> &mut FlightServiceClient { + &mut self.inner + } + + /// Consume this client and return the underlying tonic + /// [`FlightServiceClient`] + pub fn into_inner(self) -> FlightServiceClient { + self.inner + } + + /// Perform an Arrow Flight handshake with the server, sending + /// `payload` as the [`HandshakeRequest`] payload and returning + /// the [`HandshakeResponse`](crate::HandshakeResponse) + /// bytes returned from the server + /// + /// See [`FlightClient`] docs for an example. + pub async fn handshake(&mut self, payload: impl Into) -> Result { + let request = HandshakeRequest { + protocol_version: 0, + payload: payload.into(), + }; + + // apply headers, etc + let request = self.make_request(stream::once(ready(request))); + + let mut response_stream = self.inner.handshake(request).await?.into_inner(); + + if let Some(response) = response_stream.next().await.transpose()? { + // check if there is another response + if response_stream.next().await.is_some() { + return Err(FlightError::protocol( + "Got unexpected second response from handshake", + )); + } + + Ok(response.payload) + } else { + Err(FlightError::protocol("No response from handshake")) + } + } + + /// Make a `DoGet` call to the server with the provided ticket, + /// returning a [`FlightRecordBatchStream`] for reading + /// [`RecordBatch`](arrow_array::RecordBatch)es. + /// + /// # Note + /// + /// To access the returned [`FlightData`] use + /// [`FlightRecordBatchStream::into_inner()`] + /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use bytes::Bytes; + /// # use arrow_flight::FlightClient; + /// # use arrow_flight::Ticket; + /// # use arrow_array::RecordBatch; + /// # use futures::stream::TryStreamExt; + /// # let channel: tonic::transport::Channel = unimplemented!(); + /// # let ticket = Ticket { ticket: Bytes::from("foo") }; + /// let mut client = FlightClient::new(channel); + /// + /// // Invoke a do_get request on the server with a previously + /// // received Ticket + /// + /// let response = client + /// .do_get(ticket) + /// .await + /// .expect("error invoking do_get"); + /// + /// // Use try_collect to get the RecordBatches from the server + /// let batches: Vec = response + /// .try_collect() + /// .await + /// .expect("no stream errors"); + /// # } + /// ``` + pub async fn do_get(&mut self, ticket: Ticket) -> Result { + let request = self.make_request(ticket); + + let (md, response_stream, _ext) = self.inner.do_get(request).await?.into_parts(); + let (response_stream, trailers) = extract_lazy_trailers(response_stream); + + Ok(FlightRecordBatchStream::new_from_flight_data( + response_stream.map_err(FlightError::Tonic), + ) + .with_headers(md) + .with_trailers(trailers)) + } + + /// Make a `GetFlightInfo` call to the server with the provided + /// [`FlightDescriptor`] and return the [`FlightInfo`] from the + /// server. The [`FlightInfo`] can be used with [`Self::do_get`] + /// to retrieve the requested batches. + /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use arrow_flight::FlightClient; + /// # use arrow_flight::FlightDescriptor; + /// # let channel: tonic::transport::Channel = unimplemented!(); + /// let mut client = FlightClient::new(channel); + /// + /// // Send a 'CMD' request to the server + /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec()); + /// let flight_info = client + /// .get_flight_info(request) + /// .await + /// .expect("error handshaking"); + /// + /// // retrieve the first endpoint from the returned flight info + /// let ticket = flight_info + /// .endpoint[0] + /// // Extract the ticket + /// .ticket + /// .clone() + /// .expect("expected ticket"); + /// + /// // Retrieve the corresponding RecordBatch stream with do_get + /// let data = client + /// .do_get(ticket) + /// .await + /// .expect("error fetching data"); + /// # } + /// ``` + pub async fn get_flight_info(&mut self, descriptor: FlightDescriptor) -> Result { + let request = self.make_request(descriptor); + + let response = self.inner.get_flight_info(request).await?.into_inner(); + Ok(response) + } + + /// Make a `DoPut` call to the server with the provided + /// [`Stream`] of [`FlightData`] and returning a + /// stream of [`PutResult`]. + /// + /// # Note + /// + /// The input stream is [`Result`] so that this can be connected + /// to a streaming data source, such as [`FlightDataEncoder`](crate::encode::FlightDataEncoder), + /// without having to buffer. If the input stream returns an error + /// that error will not be sent to the server, instead it will be + /// placed into the result stream and the server connection + /// terminated. + /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use futures::{TryStreamExt, StreamExt}; + /// # use std::sync::Arc; + /// # use arrow_array::UInt64Array; + /// # use arrow_array::RecordBatch; + /// # use arrow_flight::{FlightClient, FlightDescriptor, PutResult}; + /// # use arrow_flight::encode::FlightDataEncoderBuilder; + /// # let batch = RecordBatch::try_from_iter(vec![ + /// # ("col2", Arc::new(UInt64Array::from_iter([10, 23, 33])) as _) + /// # ]).unwrap(); + /// # let channel: tonic::transport::Channel = unimplemented!(); + /// let mut client = FlightClient::new(channel); + /// + /// // encode the batch as a stream of `FlightData` + /// let flight_data_stream = FlightDataEncoderBuilder::new() + /// .build(futures::stream::iter(vec![Ok(batch)])); + /// + /// // send the stream and get the results as `PutResult` + /// let response: Vec= client + /// .do_put(flight_data_stream) + /// .await + /// .unwrap() + /// .try_collect() // use TryStreamExt to collect stream + /// .await + /// .expect("error calling do_put"); + /// # } + /// ``` + pub async fn do_put> + Send + 'static>( + &mut self, + request: S, + ) -> Result>> { + let (sender, mut receiver) = futures::channel::oneshot::channel(); + + // Intercepts client errors and sends them to the oneshot channel above + let mut request = Box::pin(request); // Pin to heap + let mut sender = Some(sender); // Wrap into Option so can be taken + let request_stream = futures::stream::poll_fn(move |cx| { + Poll::Ready(match ready!(request.poll_next_unpin(cx)) { + Some(Ok(data)) => Some(data), + Some(Err(e)) => { + let _ = sender.take().unwrap().send(e); + None + } + None => None, + }) + }); + + let request = self.make_request(request_stream); + let mut response_stream = self.inner.do_put(request).await?.into_inner(); + + // Forwards errors from the error oneshot with priority over responses from server + let error_stream = futures::stream::poll_fn(move |cx| { + if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) { + return Poll::Ready(Some(Err(err))); + } + let next = ready!(response_stream.poll_next_unpin(cx)); + Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic))) + }); + + // combine the response from the server and any error from the client + Ok(error_stream.boxed()) + } + + /// Make a `DoExchange` call to the server with the provided + /// [`Stream`] of [`FlightData`] and returning a + /// stream of [`FlightData`]. + /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use futures::{TryStreamExt, StreamExt}; + /// # use std::sync::Arc; + /// # use arrow_array::UInt64Array; + /// # use arrow_array::RecordBatch; + /// # use arrow_flight::{FlightClient, FlightDescriptor, PutResult}; + /// # use arrow_flight::encode::FlightDataEncoderBuilder; + /// # let batch = RecordBatch::try_from_iter(vec![ + /// # ("col2", Arc::new(UInt64Array::from_iter([10, 23, 33])) as _) + /// # ]).unwrap(); + /// # let channel: tonic::transport::Channel = unimplemented!(); + /// let mut client = FlightClient::new(channel); + /// + /// // encode the batch as a stream of `FlightData` + /// let flight_data_stream = FlightDataEncoderBuilder::new() + /// .build(futures::stream::iter(vec![Ok(batch)])) + /// // data encoder return Results, but do_exchange requires FlightData + /// .map(|batch|batch.unwrap()); + /// + /// // send the stream and get the results as `RecordBatches` + /// let response: Vec = client + /// .do_exchange(flight_data_stream) + /// .await + /// .unwrap() + /// .try_collect() // use TryStreamExt to collect stream + /// .await + /// .expect("error calling do_exchange"); + /// # } + /// ``` + pub async fn do_exchange + Send + 'static>( + &mut self, + request: S, + ) -> Result { + let request = self.make_request(request); + + let response = self + .inner + .do_exchange(request) + .await? + .into_inner() + .map_err(FlightError::Tonic); + + Ok(FlightRecordBatchStream::new_from_flight_data(response)) + } + + /// Make a `ListFlights` call to the server with the provided + /// criteria and returning a [`Stream`] of [`FlightInfo`]. + /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use futures::TryStreamExt; + /// # use bytes::Bytes; + /// # use arrow_flight::{FlightInfo, FlightClient}; + /// # let channel: tonic::transport::Channel = unimplemented!(); + /// let mut client = FlightClient::new(channel); + /// + /// // Send 'Name=Foo' bytes as the "expression" to the server + /// // and gather the returned FlightInfo + /// let responses: Vec = client + /// .list_flights(Bytes::from("Name=Foo")) + /// .await + /// .expect("error listing flights") + /// .try_collect() // use TryStreamExt to collect stream + /// .await + /// .expect("error gathering flights"); + /// # } + /// ``` + pub async fn list_flights( + &mut self, + expression: impl Into, + ) -> Result>> { + let request = Criteria { + expression: expression.into(), + }; + + let request = self.make_request(request); + + let response = self + .inner + .list_flights(request) + .await? + .into_inner() + .map_err(FlightError::Tonic); + + Ok(response.boxed()) + } + + /// Make a `GetSchema` call to the server with the provided + /// [`FlightDescriptor`] and returning the associated [`Schema`]. + /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use bytes::Bytes; + /// # use arrow_flight::{FlightDescriptor, FlightClient}; + /// # use arrow_schema::Schema; + /// # let channel: tonic::transport::Channel = unimplemented!(); + /// let mut client = FlightClient::new(channel); + /// + /// // Request the schema result of a 'CMD' request to the server + /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec()); + /// + /// let schema: Schema = client + /// .get_schema(request) + /// .await + /// .expect("error making request"); + /// # } + /// ``` + pub async fn get_schema(&mut self, flight_descriptor: FlightDescriptor) -> Result { + let request = self.make_request(flight_descriptor); + + let schema_result = self.inner.get_schema(request).await?.into_inner(); + + // attempt decode from IPC + let schema: Schema = schema_result.try_into()?; + + Ok(schema) + } + + /// Make a `ListActions` call to the server and returning a + /// [`Stream`] of [`ActionType`]. + /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use futures::TryStreamExt; + /// # use arrow_flight::{ActionType, FlightClient}; + /// # use arrow_schema::Schema; + /// # let channel: tonic::transport::Channel = unimplemented!(); + /// let mut client = FlightClient::new(channel); + /// + /// // List available actions on the server: + /// let actions: Vec = client + /// .list_actions() + /// .await + /// .expect("error listing actions") + /// .try_collect() // use TryStreamExt to collect stream + /// .await + /// .expect("error gathering actions"); + /// # } + /// ``` + pub async fn list_actions(&mut self) -> Result>> { + let request = self.make_request(Empty {}); + + let action_stream = self + .inner + .list_actions(request) + .await? + .into_inner() + .map_err(FlightError::Tonic); + + Ok(action_stream.boxed()) + } + + /// Make a `DoAction` call to the server and returning a + /// [`Stream`] of opaque [`Bytes`]. + /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use bytes::Bytes; + /// # use futures::TryStreamExt; + /// # use arrow_flight::{Action, FlightClient}; + /// # use arrow_schema::Schema; + /// # let channel: tonic::transport::Channel = unimplemented!(); + /// let mut client = FlightClient::new(channel); + /// + /// let request = Action::new("my_action", "the body"); + /// + /// // Make a request to run the action on the server + /// let results: Vec = client + /// .do_action(request) + /// .await + /// .expect("error executing acton") + /// .try_collect() // use TryStreamExt to collect stream + /// .await + /// .expect("error gathering action results"); + /// # } + /// ``` + pub async fn do_action(&mut self, action: Action) -> Result>> { + let request = self.make_request(action); + + let result_stream = self + .inner + .do_action(request) + .await? + .into_inner() + .map_err(FlightError::Tonic) + .map(|r| { + r.map(|r| { + // unwrap inner bytes + let crate::Result { body } = r; + body + }) + }); + + Ok(result_stream.boxed()) + } + + /// return a Request, adding any configured metadata + fn make_request(&self, t: T) -> tonic::Request { + // Pass along metadata + let mut request = tonic::Request::new(t); + *request.metadata_mut() = self.metadata.clone(); + request + } +} diff --git a/arrow-flight/src/decode.rs b/arrow-flight/src/decode.rs new file mode 100644 index 000000000000..95bbe2b46bb2 --- /dev/null +++ b/arrow-flight/src/decode.rs @@ -0,0 +1,434 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{trailers::LazyTrailers, utils::flight_data_to_arrow_batch, FlightData}; +use arrow_array::{ArrayRef, RecordBatch}; +use arrow_buffer::Buffer; +use arrow_schema::{Schema, SchemaRef}; +use bytes::Bytes; +use futures::{ready, stream::BoxStream, Stream, StreamExt}; +use std::{collections::HashMap, convert::TryFrom, fmt::Debug, pin::Pin, sync::Arc, task::Poll}; +use tonic::metadata::MetadataMap; + +use crate::error::{FlightError, Result}; + +/// Decodes a [Stream] of [`FlightData`] back into +/// [`RecordBatch`]es. This can be used to decode the response from an +/// Arrow Flight server +/// +/// # Note +/// To access the lower level Flight messages (e.g. to access +/// [`FlightData::app_metadata`]), you can call [`Self::into_inner`] +/// and use the [`FlightDataDecoder`] directly. +/// +/// # Example: +/// ```no_run +/// # async fn f() -> Result<(), arrow_flight::error::FlightError>{ +/// # use bytes::Bytes; +/// // make a do_get request +/// use arrow_flight::{ +/// error::Result, +/// decode::FlightRecordBatchStream, +/// Ticket, +/// flight_service_client::FlightServiceClient +/// }; +/// use tonic::transport::Channel; +/// use futures::stream::{StreamExt, TryStreamExt}; +/// +/// let client: FlightServiceClient = // make client.. +/// # unimplemented!(); +/// +/// let request = tonic::Request::new( +/// Ticket { ticket: Bytes::new() } +/// ); +/// +/// // Get a stream of FlightData; +/// let flight_data_stream = client +/// .do_get(request) +/// .await? +/// .into_inner(); +/// +/// // Decode stream of FlightData to RecordBatches +/// let record_batch_stream = FlightRecordBatchStream::new_from_flight_data( +/// // convert tonic::Status to FlightError +/// flight_data_stream.map_err(|e| e.into()) +/// ); +/// +/// // Read back RecordBatches +/// while let Some(batch) = record_batch_stream.next().await { +/// match batch { +/// Ok(batch) => { /* process batch */ }, +/// Err(e) => { /* handle error */ }, +/// }; +/// } +/// +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug)] +pub struct FlightRecordBatchStream { + /// Optional grpc header metadata. + headers: MetadataMap, + + /// Optional grpc trailer metadata. + trailers: Option, + + inner: FlightDataDecoder, +} + +impl FlightRecordBatchStream { + /// Create a new [`FlightRecordBatchStream`] from a decoded stream + pub fn new(inner: FlightDataDecoder) -> Self { + Self { + inner, + headers: MetadataMap::default(), + trailers: None, + } + } + + /// Create a new [`FlightRecordBatchStream`] from a stream of [`FlightData`] + pub fn new_from_flight_data(inner: S) -> Self + where + S: Stream> + Send + 'static, + { + Self { + inner: FlightDataDecoder::new(inner), + headers: MetadataMap::default(), + trailers: None, + } + } + + /// Record response headers. + pub fn with_headers(self, headers: MetadataMap) -> Self { + Self { headers, ..self } + } + + /// Record response trailers. + pub fn with_trailers(self, trailers: LazyTrailers) -> Self { + Self { + trailers: Some(trailers), + ..self + } + } + + /// Headers attached to this stream. + pub fn headers(&self) -> &MetadataMap { + &self.headers + } + + /// Trailers attached to this stream. + /// + /// Note that this will return `None` until the entire stream is consumed. + /// Only after calling `next()` returns `None`, might any available trailers be returned. + pub fn trailers(&self) -> Option { + self.trailers.as_ref().and_then(|trailers| trailers.get()) + } + + /// Has a message defining the schema been received yet? + #[deprecated = "use schema().is_some() instead"] + pub fn got_schema(&self) -> bool { + self.schema().is_some() + } + + /// Return schema for the stream, if it has been received + pub fn schema(&self) -> Option<&SchemaRef> { + self.inner.schema() + } + + /// Consume self and return the wrapped [`FlightDataDecoder`] + pub fn into_inner(self) -> FlightDataDecoder { + self.inner + } +} + +impl futures::Stream for FlightRecordBatchStream { + type Item = Result; + + /// Returns the next [`RecordBatch`] available in this stream, or `None` if + /// there are no further results available. + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll>> { + loop { + let had_schema = self.schema().is_some(); + let res = ready!(self.inner.poll_next_unpin(cx)); + match res { + // Inner exhausted + None => { + return Poll::Ready(None); + } + Some(Err(e)) => { + return Poll::Ready(Some(Err(e))); + } + // translate data + Some(Ok(data)) => match data.payload { + DecodedPayload::Schema(_) if had_schema => { + return Poll::Ready(Some(Err(FlightError::protocol( + "Unexpectedly saw multiple Schema messages in FlightData stream", + )))); + } + DecodedPayload::Schema(_) => { + // Need next message, poll inner again + } + DecodedPayload::RecordBatch(batch) => { + return Poll::Ready(Some(Ok(batch))); + } + DecodedPayload::None => { + // Need next message + } + }, + } + } + } +} + +/// Wrapper around a stream of [`FlightData`] that handles the details +/// of decoding low level Flight messages into [`Schema`] and +/// [`RecordBatch`]es, including details such as dictionaries. +/// +/// # Protocol Details +/// +/// The client handles flight messages as followes: +/// +/// - **None:** This message has no effect. This is useful to +/// transmit metadata without any actual payload. +/// +/// - **Schema:** The schema is (re-)set. Dictionaries are cleared and +/// the decoded schema is returned. +/// +/// - **Dictionary Batch:** A new dictionary for a given column is registered. An existing +/// dictionary for the same column will be overwritten. This +/// message is NOT visible. +/// +/// - **Record Batch:** Record batch is created based on the current +/// schema and dictionaries. This fails if no schema was transmitted +/// yet. +/// +/// All other message types (at the time of writing: e.g. tensor and +/// sparse tensor) lead to an error. +/// +/// Example usecases +/// +/// 1. Using this low level stream it is possible to receive a steam +/// of RecordBatches in FlightData that have different schemas by +/// handling multiple schema messages separately. +pub struct FlightDataDecoder { + /// Underlying data stream + response: BoxStream<'static, Result>, + /// Decoding state + state: Option, + /// Seen the end of the inner stream? + done: bool, +} + +impl Debug for FlightDataDecoder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FlightDataDecoder") + .field("response", &"") + .field("state", &self.state) + .field("done", &self.done) + .finish() + } +} + +impl FlightDataDecoder { + /// Create a new wrapper around the stream of [`FlightData`] + pub fn new(response: S) -> Self + where + S: Stream> + Send + 'static, + { + Self { + state: None, + response: response.boxed(), + done: false, + } + } + + /// Returns the current schema for this stream + pub fn schema(&self) -> Option<&SchemaRef> { + self.state.as_ref().map(|state| &state.schema) + } + + /// Extracts flight data from the next message, updating decoding + /// state as necessary. + fn extract_message(&mut self, data: FlightData) -> Result> { + use arrow_ipc::MessageHeader; + let message = arrow_ipc::root_as_message(&data.data_header[..]) + .map_err(|e| FlightError::DecodeError(format!("Error decoding root message: {e}")))?; + + match message.header_type() { + MessageHeader::NONE => Ok(Some(DecodedFlightData::new_none(data))), + MessageHeader::Schema => { + let schema = Schema::try_from(&data) + .map_err(|e| FlightError::DecodeError(format!("Error decoding schema: {e}")))?; + + let schema = Arc::new(schema); + let dictionaries_by_field = HashMap::new(); + + self.state = Some(FlightStreamState { + schema: Arc::clone(&schema), + dictionaries_by_field, + }); + Ok(Some(DecodedFlightData::new_schema(data, schema))) + } + MessageHeader::DictionaryBatch => { + let state = if let Some(state) = self.state.as_mut() { + state + } else { + return Err(FlightError::protocol( + "Received DictionaryBatch prior to Schema", + )); + }; + + let buffer = Buffer::from_bytes(data.data_body.into()); + let dictionary_batch = message.header_as_dictionary_batch().ok_or_else(|| { + FlightError::protocol( + "Could not get dictionary batch from DictionaryBatch message", + ) + })?; + + arrow_ipc::reader::read_dictionary( + &buffer, + dictionary_batch, + &state.schema, + &mut state.dictionaries_by_field, + &message.version(), + ) + .map_err(|e| { + FlightError::DecodeError(format!("Error decoding ipc dictionary: {e}")) + })?; + + // Updated internal state, but no decoded message + Ok(None) + } + MessageHeader::RecordBatch => { + let state = if let Some(state) = self.state.as_ref() { + state + } else { + return Err(FlightError::protocol( + "Received RecordBatch prior to Schema", + )); + }; + + let batch = flight_data_to_arrow_batch( + &data, + Arc::clone(&state.schema), + &state.dictionaries_by_field, + ) + .map_err(|e| { + FlightError::DecodeError(format!("Error decoding ipc RecordBatch: {e}")) + })?; + + Ok(Some(DecodedFlightData::new_record_batch(data, batch))) + } + other => { + let name = other.variant_name().unwrap_or("UNKNOWN"); + Err(FlightError::protocol(format!("Unexpected message: {name}"))) + } + } + } +} + +impl futures::Stream for FlightDataDecoder { + type Item = Result; + /// Returns the result of decoding the next [`FlightData`] message + /// from the server, or `None` if there are no further results + /// available. + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + if self.done { + return Poll::Ready(None); + } + loop { + let res = ready!(self.response.poll_next_unpin(cx)); + + return Poll::Ready(match res { + None => { + self.done = true; + None // inner is exhausted + } + Some(data) => Some(match data { + Err(e) => Err(e), + Ok(data) => match self.extract_message(data) { + Ok(Some(extracted)) => Ok(extracted), + Ok(None) => continue, // Need next input message + Err(e) => Err(e), + }, + }), + }); + } + } +} + +/// tracks the state needed to reconstruct [`RecordBatch`]es from a +/// streaming flight response. +#[derive(Debug)] +struct FlightStreamState { + schema: SchemaRef, + dictionaries_by_field: HashMap, +} + +/// FlightData and the decoded payload (Schema, RecordBatch), if any +#[derive(Debug)] +pub struct DecodedFlightData { + pub inner: FlightData, + pub payload: DecodedPayload, +} + +impl DecodedFlightData { + pub fn new_none(inner: FlightData) -> Self { + Self { + inner, + payload: DecodedPayload::None, + } + } + + pub fn new_schema(inner: FlightData, schema: SchemaRef) -> Self { + Self { + inner, + payload: DecodedPayload::Schema(schema), + } + } + + pub fn new_record_batch(inner: FlightData, batch: RecordBatch) -> Self { + Self { + inner, + payload: DecodedPayload::RecordBatch(batch), + } + } + + /// return the metadata field of the inner flight data + pub fn app_metadata(&self) -> Bytes { + self.inner.app_metadata.clone() + } +} + +/// The result of decoding [`FlightData`] +#[derive(Debug)] +pub enum DecodedPayload { + /// None (no data was sent in the corresponding FlightData) + None, + + /// A decoded Schema message + Schema(SchemaRef), + + /// A decoded Record batch. + RecordBatch(RecordBatch), +} diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs new file mode 100644 index 000000000000..e6ef9994d487 --- /dev/null +++ b/arrow-flight/src/encode.rs @@ -0,0 +1,987 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll}; + +use crate::{error::Result, FlightData, FlightDescriptor, SchemaAsIpc}; +use arrow_array::{ArrayRef, RecordBatch, RecordBatchOptions}; +use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions}; +use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef}; +use bytes::Bytes; +use futures::{ready, stream::BoxStream, Stream, StreamExt}; + +/// Creates a [`Stream`] of [`FlightData`]s from a +/// `Stream` of [`Result`]<[`RecordBatch`], [`FlightError`]>. +/// +/// This can be used to implement [`FlightService::do_get`] in an +/// Arrow Flight implementation; +/// +/// This structure encodes a stream of `Result`s rather than `RecordBatch`es to +/// propagate errors from streaming execution, where the generation of the +/// `RecordBatch`es is incremental, and an error may occur even after +/// several have already been successfully produced. +/// +/// # Caveats +/// 1. When [`DictionaryHandling`] is [`DictionaryHandling::Hydrate`], [`DictionaryArray`](arrow_array::array::DictionaryArray)s +/// are converted to their underlying types prior to transport. +/// When [`DictionaryHandling`] is [`DictionaryHandling::Resend`], Dictionary [`FlightData`] is sent with every +/// [`RecordBatch`] that contains a [`DictionaryArray`](arrow_array::array::DictionaryArray). +/// See . +/// +/// # Example +/// ```no_run +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, RecordBatch, UInt32Array}; +/// # async fn f() { +/// # let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); +/// # let batch = RecordBatch::try_from_iter(vec![ +/// # ("a", Arc::new(c1) as ArrayRef) +/// # ]) +/// # .expect("cannot create record batch"); +/// use arrow_flight::encode::FlightDataEncoderBuilder; +/// +/// // Get an input stream of Result +/// let input_stream = futures::stream::iter(vec![Ok(batch)]); +/// +/// // Build a stream of `Result` (e.g. to return for do_get) +/// let flight_data_stream = FlightDataEncoderBuilder::new() +/// .build(input_stream); +/// +/// // Create a tonic `Response` that can be returned from a Flight server +/// let response = tonic::Response::new(flight_data_stream); +/// # } +/// ``` +/// +/// # Example: Sending `Vec` +/// +/// You can create a [`Stream`] to pass to [`Self::build`] from an existing +/// `Vec` of `RecordBatch`es like this: +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, RecordBatch, UInt32Array}; +/// # async fn f() { +/// # fn make_batches() -> Vec { +/// # let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); +/// # let batch = RecordBatch::try_from_iter(vec![ +/// # ("a", Arc::new(c1) as ArrayRef) +/// # ]) +/// # .expect("cannot create record batch"); +/// # vec![batch.clone(), batch.clone()] +/// # } +/// use arrow_flight::encode::FlightDataEncoderBuilder; +/// +/// // Get batches that you want to send via Flight +/// let batches: Vec = make_batches(); +/// +/// // Create an input stream of Result +/// let input_stream = futures::stream::iter( +/// batches.into_iter().map(Ok) +/// ); +/// +/// // Build a stream of `Result` (e.g. to return for do_get) +/// let flight_data_stream = FlightDataEncoderBuilder::new() +/// .build(input_stream); +/// # } +/// ``` +/// +/// [`FlightService::do_get`]: crate::flight_service_server::FlightService::do_get +/// [`FlightError`]: crate::error::FlightError +#[derive(Debug)] +pub struct FlightDataEncoderBuilder { + /// The maximum approximate target message size in bytes + /// (see details on [`Self::with_max_flight_data_size`]). + max_flight_data_size: usize, + /// Ipc writer options + options: IpcWriteOptions, + /// Metadata to add to the schema message + app_metadata: Bytes, + /// Optional schema, if known before data. + schema: Option, + /// Optional flight descriptor, if known before data. + descriptor: Option, + /// Deterimines how `DictionaryArray`s are encoded for transport. + /// See [`DictionaryHandling`] for more information. + dictionary_handling: DictionaryHandling, +} + +/// Default target size for encoded [`FlightData`]. +/// +/// Note this value would normally be 4MB, but the size calculation is +/// somewhat inexact, so we set it to 2MB. +pub const GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES: usize = 2097152; + +impl Default for FlightDataEncoderBuilder { + fn default() -> Self { + Self { + max_flight_data_size: GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES, + options: IpcWriteOptions::default(), + app_metadata: Bytes::new(), + schema: None, + descriptor: None, + dictionary_handling: DictionaryHandling::Hydrate, + } + } +} + +impl FlightDataEncoderBuilder { + pub fn new() -> Self { + Self::default() + } + + /// Set the (approximate) maximum size, in bytes, of the + /// [`FlightData`] produced by this encoder. Defaults to 2MB. + /// + /// Since there is often a maximum message size for gRPC messages + /// (typically around 4MB), this encoder splits up [`RecordBatch`]s + /// (preserving order) into multiple [`FlightData`] objects to + /// limit the size individual messages sent via gRPC. + /// + /// The size is approximate because of the additional encoding + /// overhead on top of the underlying data buffers themselves. + pub fn with_max_flight_data_size(mut self, max_flight_data_size: usize) -> Self { + self.max_flight_data_size = max_flight_data_size; + self + } + + /// Set [`DictionaryHandling`] for encoder + pub fn with_dictionary_handling(mut self, dictionary_handling: DictionaryHandling) -> Self { + self.dictionary_handling = dictionary_handling; + self + } + + /// Specify application specific metadata included in the + /// [`FlightData::app_metadata`] field of the the first Schema + /// message + pub fn with_metadata(mut self, app_metadata: Bytes) -> Self { + self.app_metadata = app_metadata; + self + } + + /// Set the [`IpcWriteOptions`] used to encode the [`RecordBatch`]es for transport. + pub fn with_options(mut self, options: IpcWriteOptions) -> Self { + self.options = options; + self + } + + /// Specify a schema for the RecordBatches being sent. If a schema + /// is not specified, an encoded Schema message will be sent when + /// the first [`RecordBatch`], if any, is encoded. Some clients + /// expect a Schema message even if there is no data sent. + pub fn with_schema(mut self, schema: SchemaRef) -> Self { + self.schema = Some(schema); + self + } + + /// Specify a flight descriptor in the first FlightData message. + pub fn with_flight_descriptor(mut self, descriptor: Option) -> Self { + self.descriptor = descriptor; + self + } + + /// Takes a [`Stream`] of [`Result`] and returns a [`Stream`] + /// of [`FlightData`], consuming self. + /// + /// See example on [`Self`] and [`FlightDataEncoder`] for more details + pub fn build(self, input: S) -> FlightDataEncoder + where + S: Stream> + Send + 'static, + { + let Self { + max_flight_data_size, + options, + app_metadata, + schema, + descriptor, + dictionary_handling, + } = self; + + FlightDataEncoder::new( + input.boxed(), + schema, + max_flight_data_size, + options, + app_metadata, + descriptor, + dictionary_handling, + ) + } +} + +/// Stream that encodes a stream of record batches to flight data. +/// +/// See [`FlightDataEncoderBuilder`] for details and example. +pub struct FlightDataEncoder { + /// Input stream + inner: BoxStream<'static, Result>, + /// schema, set after the first batch + schema: Option, + /// Target maximum size of flight data + /// (see details on [`FlightDataEncoderBuilder::with_max_flight_data_size`]). + max_flight_data_size: usize, + /// do the encoding / tracking of dictionaries + encoder: FlightIpcEncoder, + /// optional metadata to add to schema FlightData + app_metadata: Option, + /// data queued up to send but not yet sent + queue: VecDeque, + /// Is this stream done (inner is empty or errored) + done: bool, + /// cleared after the first FlightData message is sent + descriptor: Option, + /// Deterimines how `DictionaryArray`s are encoded for transport. + /// See [`DictionaryHandling`] for more information. + dictionary_handling: DictionaryHandling, +} + +impl FlightDataEncoder { + fn new( + inner: BoxStream<'static, Result>, + schema: Option, + max_flight_data_size: usize, + options: IpcWriteOptions, + app_metadata: Bytes, + descriptor: Option, + dictionary_handling: DictionaryHandling, + ) -> Self { + let mut encoder = Self { + inner, + schema: None, + max_flight_data_size, + encoder: FlightIpcEncoder::new( + options, + dictionary_handling != DictionaryHandling::Resend, + ), + app_metadata: Some(app_metadata), + queue: VecDeque::new(), + done: false, + descriptor, + dictionary_handling, + }; + + // If schema is known up front, enqueue it immediately + if let Some(schema) = schema { + encoder.encode_schema(&schema); + } + + encoder + } + + /// Place the `FlightData` in the queue to send + fn queue_message(&mut self, mut data: FlightData) { + if let Some(descriptor) = self.descriptor.take() { + data.flight_descriptor = Some(descriptor); + } + self.queue.push_back(data); + } + + /// Place the `FlightData` in the queue to send + fn queue_messages(&mut self, datas: impl IntoIterator) { + for data in datas { + self.queue_message(data) + } + } + + /// Encodes schema as a [`FlightData`] in self.queue. + /// Updates `self.schema` and returns the new schema + fn encode_schema(&mut self, schema: &SchemaRef) -> SchemaRef { + // The first message is the schema message, and all + // batches have the same schema + let send_dictionaries = self.dictionary_handling == DictionaryHandling::Resend; + let schema = Arc::new(prepare_schema_for_flight(schema, send_dictionaries)); + let mut schema_flight_data = self.encoder.encode_schema(&schema); + + // attach any metadata requested + if let Some(app_metadata) = self.app_metadata.take() { + schema_flight_data.app_metadata = app_metadata; + } + self.queue_message(schema_flight_data); + // remember schema + self.schema = Some(schema.clone()); + schema + } + + /// Encodes batch into one or more `FlightData` messages in self.queue + fn encode_batch(&mut self, batch: RecordBatch) -> Result<()> { + let schema = match &self.schema { + Some(schema) => schema.clone(), + // encode the schema if this is the first time we have seen it + None => self.encode_schema(&batch.schema()), + }; + + // encode the batch + let send_dictionaries = self.dictionary_handling == DictionaryHandling::Resend; + let batch = prepare_batch_for_flight(&batch, schema, send_dictionaries)?; + + for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) { + let (flight_dictionaries, flight_batch) = self.encoder.encode_batch(&batch)?; + + self.queue_messages(flight_dictionaries); + self.queue_message(flight_batch); + } + + Ok(()) + } +} + +impl Stream for FlightDataEncoder { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + loop { + if self.done && self.queue.is_empty() { + return Poll::Ready(None); + } + + // Any messages queued to send? + if let Some(data) = self.queue.pop_front() { + return Poll::Ready(Some(Ok(data))); + } + + // Get next batch + let batch = ready!(self.inner.poll_next_unpin(cx)); + + match batch { + None => { + // inner is done + self.done = true; + // queue must also be empty so we are done + assert!(self.queue.is_empty()); + return Poll::Ready(None); + } + Some(Err(e)) => { + // error from inner + self.done = true; + self.queue.clear(); + return Poll::Ready(Some(Err(e))); + } + Some(Ok(batch)) => { + // had data, encode into the queue + if let Err(e) = self.encode_batch(batch) { + self.done = true; + self.queue.clear(); + return Poll::Ready(Some(Err(e))); + } + } + } + } + } +} + +/// Defines how a [`FlightDataEncoder`] encodes [`DictionaryArray`]s +/// +/// [`DictionaryArray`]: arrow_array::DictionaryArray +#[derive(Debug, PartialEq)] +pub enum DictionaryHandling { + /// Expands to the underlying type (default). This likely sends more data + /// over the network but requires less memory (dictionaries are not tracked) + /// and is more compatible with other arrow flight client implementations + /// that may not support `DictionaryEncoding` + /// + /// An IPC response, streaming or otherwise, defines its schema up front + /// which defines the mapping from dictionary IDs. It then sends these + /// dictionaries over the wire. + /// + /// This requires identifying the different dictionaries in use, assigning + /// them IDs, and sending new dictionaries, delta or otherwise, when needed + /// + /// See also: + /// * + Hydrate, + /// Send dictionary FlightData with every RecordBatch that contains a + /// [`DictionaryArray`]. See [`Self::Hydrate`] for more tradeoffs. No + /// attempt is made to skip sending the same (logical) dictionary values + /// twice. + /// + /// [`DictionaryArray`]: arrow_array::DictionaryArray + Resend, +} + +/// Prepare an arrow Schema for transport over the Arrow Flight protocol +/// +/// Convert dictionary types to underlying types +/// +/// See hydrate_dictionary for more information +fn prepare_schema_for_flight(schema: &Schema, send_dictionaries: bool) -> Schema { + let fields: Fields = schema + .fields() + .iter() + .map(|field| match field.data_type() { + DataType::Dictionary(_, value_type) if !send_dictionaries => Field::new( + field.name(), + value_type.as_ref().clone(), + field.is_nullable(), + ) + .with_metadata(field.metadata().clone()), + _ => field.as_ref().clone(), + }) + .collect(); + + Schema::new(fields).with_metadata(schema.metadata().clone()) +} + +/// Split [`RecordBatch`] so it hopefully fits into a gRPC response. +/// +/// Data is zero-copy sliced into batches. +/// +/// Note: this method does not take into account already sliced +/// arrays: +fn split_batch_for_grpc_response( + batch: RecordBatch, + max_flight_data_size: usize, +) -> Vec { + let size = batch + .columns() + .iter() + .map(|col| col.get_buffer_memory_size()) + .sum::(); + + let n_batches = + (size / max_flight_data_size + usize::from(size % max_flight_data_size != 0)).max(1); + let rows_per_batch = (batch.num_rows() / n_batches).max(1); + let mut out = Vec::with_capacity(n_batches + 1); + + let mut offset = 0; + while offset < batch.num_rows() { + let length = (rows_per_batch).min(batch.num_rows() - offset); + out.push(batch.slice(offset, length)); + + offset += length; + } + + out +} + +/// The data needed to encode a stream of flight data, holding on to +/// shared Dictionaries. +/// +/// TODO: at allow dictionaries to be flushed / avoid building them +/// +/// TODO limit on the number of dictionaries??? +struct FlightIpcEncoder { + options: IpcWriteOptions, + data_gen: IpcDataGenerator, + dictionary_tracker: DictionaryTracker, +} + +impl FlightIpcEncoder { + fn new(options: IpcWriteOptions, error_on_replacement: bool) -> Self { + Self { + options, + data_gen: IpcDataGenerator::default(), + dictionary_tracker: DictionaryTracker::new(error_on_replacement), + } + } + + /// Encode a schema as a FlightData + fn encode_schema(&self, schema: &Schema) -> FlightData { + SchemaAsIpc::new(schema, &self.options).into() + } + + /// Convert a `RecordBatch` to a Vec of `FlightData` representing + /// dictionaries and a `FlightData` representing the batch + fn encode_batch(&mut self, batch: &RecordBatch) -> Result<(Vec, FlightData)> { + let (encoded_dictionaries, encoded_batch) = + self.data_gen + .encoded_batch(batch, &mut self.dictionary_tracker, &self.options)?; + + let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); + let flight_batch = encoded_batch.into(); + + Ok((flight_dictionaries, flight_batch)) + } +} + +/// Prepares a RecordBatch for transport over the Arrow Flight protocol +/// +/// This means: +/// +/// 1. Hydrates any dictionaries to its underlying type. See +/// hydrate_dictionary for more information. +/// +fn prepare_batch_for_flight( + batch: &RecordBatch, + schema: SchemaRef, + send_dictionaries: bool, +) -> Result { + let columns = batch + .columns() + .iter() + .map(|c| hydrate_dictionary(c, send_dictionaries)) + .collect::>>()?; + + let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); + + Ok(RecordBatch::try_new_with_options( + schema, columns, &options, + )?) +} + +/// Hydrates a dictionary to its underlying type if send_dictionaries is false. If send_dictionaries +/// is true, dictionaries are sent with every batch which is not as optimal as described in [DictionaryHandling::Hydrate] above, +/// but does enable sending DictionaryArray's via Flight. +fn hydrate_dictionary(array: &ArrayRef, send_dictionaries: bool) -> Result { + let arr = match array.data_type() { + DataType::Dictionary(_, value) if !send_dictionaries => arrow_cast::cast(array, value)?, + _ => Arc::clone(array), + }; + Ok(arr) +} + +#[cfg(test)] +mod tests { + use arrow_array::*; + use arrow_array::{cast::downcast_array, types::*}; + use arrow_cast::pretty::pretty_format_batches; + use std::collections::HashMap; + + use crate::decode::{DecodedPayload, FlightDataDecoder}; + + use super::*; + + #[test] + /// ensure only the batch's used data (not the allocated data) is sent + /// + fn test_encode_flight_data() { + let options = IpcWriteOptions::default(); + let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); + + let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c1) as ArrayRef)]) + .expect("cannot create record batch"); + let schema = batch.schema(); + + let (_, baseline_flight_batch) = make_flight_data(&batch, &options); + + let big_batch = batch.slice(0, batch.num_rows() - 1); + let optimized_big_batch = prepare_batch_for_flight(&big_batch, Arc::clone(&schema), false) + .expect("failed to optimize"); + let (_, optimized_big_flight_batch) = make_flight_data(&optimized_big_batch, &options); + + assert_eq!( + baseline_flight_batch.data_body.len(), + optimized_big_flight_batch.data_body.len() + ); + + let small_batch = batch.slice(0, 1); + let optimized_small_batch = + prepare_batch_for_flight(&small_batch, Arc::clone(&schema), false) + .expect("failed to optimize"); + let (_, optimized_small_flight_batch) = make_flight_data(&optimized_small_batch, &options); + + assert!( + baseline_flight_batch.data_body.len() > optimized_small_flight_batch.data_body.len() + ); + } + + #[tokio::test] + async fn test_dictionary_hydration() { + let arr: DictionaryArray = vec!["a", "a", "b"].into_iter().collect(); + let schema = Arc::new(Schema::new(vec![Field::new_dictionary( + "dict", + DataType::UInt16, + DataType::Utf8, + false, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(arr)]).unwrap(); + let encoder = + FlightDataEncoderBuilder::default().build(futures::stream::once(async { Ok(batch) })); + let mut decoder = FlightDataDecoder::new(encoder); + let expected_schema = Schema::new(vec![Field::new("dict", DataType::Utf8, false)]); + let expected_schema = Arc::new(expected_schema); + while let Some(decoded) = decoder.next().await { + let decoded = decoded.unwrap(); + match decoded.payload { + DecodedPayload::None => {} + DecodedPayload::Schema(s) => assert_eq!(s, expected_schema), + DecodedPayload::RecordBatch(b) => { + assert_eq!(b.schema(), expected_schema); + let expected_array = StringArray::from(vec!["a", "a", "b"]); + let actual_array = b.column_by_name("dict").unwrap(); + let actual_array = downcast_array::(actual_array); + + assert_eq!(actual_array, expected_array); + } + } + } + } + + #[tokio::test] + async fn test_send_dictionaries() { + let schema = Arc::new(Schema::new(vec![Field::new_dictionary( + "dict", + DataType::UInt16, + DataType::Utf8, + false, + )])); + + let arr_one: Arc> = + Arc::new(vec!["a", "a", "b"].into_iter().collect()); + let arr_two: Arc> = + Arc::new(vec!["b", "a", "c"].into_iter().collect()); + let batch_one = RecordBatch::try_new(schema.clone(), vec![arr_one.clone()]).unwrap(); + let batch_two = RecordBatch::try_new(schema.clone(), vec![arr_two.clone()]).unwrap(); + + let encoder = FlightDataEncoderBuilder::default() + .with_dictionary_handling(DictionaryHandling::Resend) + .build(futures::stream::iter(vec![Ok(batch_one), Ok(batch_two)])); + + let mut decoder = FlightDataDecoder::new(encoder); + let mut expected_array = arr_one; + while let Some(decoded) = decoder.next().await { + let decoded = decoded.unwrap(); + match decoded.payload { + DecodedPayload::None => {} + DecodedPayload::Schema(s) => assert_eq!(s, schema), + DecodedPayload::RecordBatch(b) => { + assert_eq!(b.schema(), schema); + + let actual_array = Arc::new(downcast_array::>( + b.column_by_name("dict").unwrap(), + )); + + assert_eq!(actual_array, expected_array); + + expected_array = arr_two.clone(); + } + } + } + } + + #[test] + fn test_schema_metadata_encoded() { + let schema = Schema::new(vec![Field::new("data", DataType::Int32, false)]).with_metadata( + HashMap::from([("some_key".to_owned(), "some_value".to_owned())]), + ); + + let got = prepare_schema_for_flight(&schema, false); + assert!(got.metadata().contains_key("some_key")); + } + + #[test] + fn test_encode_no_column_batch() { + let batch = RecordBatch::try_new_with_options( + Arc::new(Schema::empty()), + vec![], + &RecordBatchOptions::new().with_row_count(Some(10)), + ) + .expect("cannot create record batch"); + + prepare_batch_for_flight(&batch, batch.schema(), false).expect("failed to optimize"); + } + + pub fn make_flight_data( + batch: &RecordBatch, + options: &IpcWriteOptions, + ) -> (Vec, FlightData) { + let data_gen = IpcDataGenerator::default(); + let mut dictionary_tracker = DictionaryTracker::new(false); + + let (encoded_dictionaries, encoded_batch) = data_gen + .encoded_batch(batch, &mut dictionary_tracker, options) + .expect("DictionaryTracker configured above to not error on replacement"); + + let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); + let flight_batch = encoded_batch.into(); + + (flight_dictionaries, flight_batch) + } + + #[test] + fn test_split_batch_for_grpc_response() { + let max_flight_data_size = 1024; + + // no split + let c = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); + let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)]) + .expect("cannot create record batch"); + let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size); + assert_eq!(split.len(), 1); + assert_eq!(batch, split[0]); + + // split once + let n_rows = max_flight_data_size + 1; + assert!(n_rows % 2 == 1, "should be an odd number"); + let c = UInt8Array::from((0..n_rows).map(|i| (i % 256) as u8).collect::>()); + let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)]) + .expect("cannot create record batch"); + let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size); + assert_eq!(split.len(), 3); + assert_eq!( + split.iter().map(|batch| batch.num_rows()).sum::(), + n_rows + ); + let a = pretty_format_batches(&split).unwrap().to_string(); + let b = pretty_format_batches(&[batch]).unwrap().to_string(); + assert_eq!(a, b); + } + + #[test] + fn test_split_batch_for_grpc_response_sizes() { + // 2000 8 byte entries into 2k pieces: 8 chunks of 250 rows + verify_split(2000, 2 * 1024, vec![250, 250, 250, 250, 250, 250, 250, 250]); + + // 2000 8 byte entries into 4k pieces: 4 chunks of 500 rows + verify_split(2000, 4 * 1024, vec![500, 500, 500, 500]); + + // 2023 8 byte entries into 3k pieces does not divide evenly + verify_split(2023, 3 * 1024, vec![337, 337, 337, 337, 337, 337, 1]); + + // 10 8 byte entries into 1 byte pieces means each rows gets its own + verify_split(10, 1, vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1]); + + // 10 8 byte entries into 1k byte pieces means one piece + verify_split(10, 1024, vec![10]); + } + + /// Creates a UInt64Array of 8 byte integers with input_rows rows + /// `max_flight_data_size_bytes` pieces and verifies the row counts in + /// those pieces + fn verify_split( + num_input_rows: u64, + max_flight_data_size_bytes: usize, + expected_sizes: Vec, + ) { + let array: UInt64Array = (0..num_input_rows).collect(); + + let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(array) as ArrayRef)]) + .expect("cannot create record batch"); + + let input_rows = batch.num_rows(); + + let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes); + let sizes: Vec<_> = split.iter().map(|batch| batch.num_rows()).collect(); + let output_rows: usize = sizes.iter().sum(); + + assert_eq!(sizes, expected_sizes, "mismatch for {batch:?}"); + assert_eq!(input_rows, output_rows, "mismatch for {batch:?}"); + } + + // test sending record batches + // test sending record batches with multiple different dictionaries + + #[tokio::test] + async fn flight_data_size_even() { + let s1 = StringArray::from_iter_values(std::iter::repeat(".10 bytes.").take(1024)); + let i1 = Int16Array::from_iter_values(0..1024); + let s2 = StringArray::from_iter_values(std::iter::repeat("6bytes").take(1024)); + let i2 = Int64Array::from_iter_values(0..1024); + + let batch = RecordBatch::try_from_iter(vec![ + ("s1", Arc::new(s1) as _), + ("i1", Arc::new(i1) as _), + ("s2", Arc::new(s2) as _), + ("i2", Arc::new(i2) as _), + ]) + .unwrap(); + + verify_encoded_split(batch, 112).await; + } + + #[tokio::test] + async fn flight_data_size_uneven_variable_lengths() { + // each row has a longer string than the last with increasing lengths 0 --> 1024 + let array = StringArray::from_iter_values((0..1024).map(|i| "*".repeat(i))); + let batch = RecordBatch::try_from_iter(vec![("data", Arc::new(array) as _)]).unwrap(); + + // overage is much higher than ideal + // https://github.com/apache/arrow-rs/issues/3478 + verify_encoded_split(batch, 4304).await; + } + + #[tokio::test] + async fn flight_data_size_large_row() { + // batch with individual that can each exceed the batch size + let array1 = StringArray::from_iter_values(vec![ + "*".repeat(500), + "*".repeat(500), + "*".repeat(500), + "*".repeat(500), + ]); + let array2 = StringArray::from_iter_values(vec![ + "*".to_string(), + "*".repeat(1000), + "*".repeat(2000), + "*".repeat(4000), + ]); + + let array3 = StringArray::from_iter_values(vec![ + "*".to_string(), + "*".to_string(), + "*".repeat(1000), + "*".repeat(2000), + ]); + + let batch = RecordBatch::try_from_iter(vec![ + ("a1", Arc::new(array1) as _), + ("a2", Arc::new(array2) as _), + ("a3", Arc::new(array3) as _), + ]) + .unwrap(); + + // 5k over limit (which is 2x larger than limit of 5k) + // overage is much higher than ideal + // https://github.com/apache/arrow-rs/issues/3478 + verify_encoded_split(batch, 5800).await; + } + + #[tokio::test] + async fn flight_data_size_string_dictionary() { + // Small dictionary (only 2 distinct values ==> 2 entries in dictionary) + let array: DictionaryArray = (1..1024) + .map(|i| match i % 3 { + 0 => Some("value0"), + 1 => Some("value1"), + _ => None, + }) + .collect(); + + let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap(); + + verify_encoded_split(batch, 160).await; + } + + #[tokio::test] + async fn flight_data_size_large_dictionary() { + // large dictionary (all distinct values ==> 1024 entries in dictionary) + let values: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect(); + + let array: DictionaryArray = values.iter().map(|s| Some(s.as_str())).collect(); + + let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap(); + + // overage is much higher than ideal + // https://github.com/apache/arrow-rs/issues/3478 + verify_encoded_split(batch, 3328).await; + } + + #[tokio::test] + async fn flight_data_size_large_dictionary_repeated_non_uniform() { + // large dictionary (1024 distinct values) that are used throughout the array + let values = StringArray::from_iter_values((0..1024).map(|i| "******".repeat(i))); + let keys = Int32Array::from_iter_values((0..3000).map(|i| (3000 - i) % 1024)); + let array = DictionaryArray::new(keys, Arc::new(values)); + + let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap(); + + // overage is much higher than ideal + // https://github.com/apache/arrow-rs/issues/3478 + verify_encoded_split(batch, 5280).await; + } + + #[tokio::test] + async fn flight_data_size_multiple_dictionaries() { + // high cardinality + let values1: Vec<_> = (1..1024).map(|i| "**".repeat(i)).collect(); + // highish cardinality + let values2: Vec<_> = (1..1024).map(|i| "**".repeat(i % 10)).collect(); + // medium cardinality + let values3: Vec<_> = (1..1024).map(|i| "**".repeat(i % 100)).collect(); + + let array1: DictionaryArray = values1.iter().map(|s| Some(s.as_str())).collect(); + let array2: DictionaryArray = values2.iter().map(|s| Some(s.as_str())).collect(); + let array3: DictionaryArray = values3.iter().map(|s| Some(s.as_str())).collect(); + + let batch = RecordBatch::try_from_iter(vec![ + ("a1", Arc::new(array1) as _), + ("a2", Arc::new(array2) as _), + ("a3", Arc::new(array3) as _), + ]) + .unwrap(); + + // overage is much higher than ideal + // https://github.com/apache/arrow-rs/issues/3478 + verify_encoded_split(batch, 4128).await; + } + + /// Return size, in memory of flight data + fn flight_data_size(d: &FlightData) -> usize { + let flight_descriptor_size = d + .flight_descriptor + .as_ref() + .map(|descriptor| { + let path_len: usize = descriptor.path.iter().map(|p| p.as_bytes().len()).sum(); + + std::mem::size_of_val(descriptor) + descriptor.cmd.len() + path_len + }) + .unwrap_or(0); + + flight_descriptor_size + d.app_metadata.len() + d.data_body.len() + d.data_header.len() + } + + /// Coverage for + /// + /// Encodes the specified batch using several values of + /// `max_flight_data_size` between 1K to 5K and ensures that the + /// resulting size of the flight data stays within the limit + /// + `allowed_overage` + /// + /// `allowed_overage` is how far off the actual data encoding is + /// from the target limit that was set. It is an improvement when + /// the allowed_overage decreses. + /// + /// Note this overhead will likely always be greater than zero to + /// account for encoding overhead such as IPC headers and padding. + /// + /// + async fn verify_encoded_split(batch: RecordBatch, allowed_overage: usize) { + let num_rows = batch.num_rows(); + + // Track the overall required maximum overage + let mut max_overage_seen = 0; + + for max_flight_data_size in [1024, 2021, 5000] { + println!("Encoding {num_rows} with a maximum size of {max_flight_data_size}"); + + let mut stream = FlightDataEncoderBuilder::new() + .with_max_flight_data_size(max_flight_data_size) + .build(futures::stream::iter([Ok(batch.clone())])); + + let mut i = 0; + while let Some(data) = stream.next().await.transpose().unwrap() { + let actual_data_size = flight_data_size(&data); + + let actual_overage = if actual_data_size > max_flight_data_size { + actual_data_size - max_flight_data_size + } else { + 0 + }; + + assert!( + actual_overage <= allowed_overage, + "encoded data[{i}]: actual size {actual_data_size}, \ + actual_overage: {actual_overage} \ + allowed_overage: {allowed_overage}" + ); + + i += 1; + + max_overage_seen = max_overage_seen.max(actual_overage) + } + } + + // ensure that the specified overage is exactly the maxmium so + // that when the splitting logic improves, the tests must be + // updated to reflect the better logic + assert_eq!( + allowed_overage, max_overage_seen, + "Specified overage was too high" + ); + } +} diff --git a/arrow-flight/src/error.rs b/arrow-flight/src/error.rs new file mode 100644 index 000000000000..e054883e965d --- /dev/null +++ b/arrow-flight/src/error.rs @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::error::Error; + +use arrow_schema::ArrowError; + +/// Errors for the Apache Arrow Flight crate +#[derive(Debug)] +pub enum FlightError { + /// Underlying arrow error + Arrow(ArrowError), + /// Returned when functionality is not yet available. + NotYetImplemented(String), + /// Error from the underlying tonic library + Tonic(tonic::Status), + /// Some unexpected message was received + ProtocolError(String), + /// An error occurred during decoding + DecodeError(String), + /// External error that can provide source of error by calling `Error::source`. + ExternalError(Box), +} + +impl FlightError { + pub fn protocol(message: impl Into) -> Self { + Self::ProtocolError(message.into()) + } + + /// Wraps an external error in an `ArrowError`. + pub fn from_external_error(error: Box) -> Self { + Self::ExternalError(error) + } +} + +impl std::fmt::Display for FlightError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // TODO better format / error + write!(f, "{self:?}") + } +} + +impl Error for FlightError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + if let Self::ExternalError(e) = self { + Some(e.as_ref()) + } else { + None + } + } +} + +impl From for FlightError { + fn from(status: tonic::Status) -> Self { + Self::Tonic(status) + } +} + +impl From for FlightError { + fn from(value: ArrowError) -> Self { + Self::Arrow(value) + } +} + +// default conversion from FlightError to tonic treats everything +// other than `Status` as an internal error +impl From for tonic::Status { + fn from(value: FlightError) -> Self { + match value { + FlightError::Arrow(e) => tonic::Status::internal(e.to_string()), + FlightError::NotYetImplemented(e) => tonic::Status::internal(e), + FlightError::Tonic(status) => status, + FlightError::ProtocolError(e) => tonic::Status::internal(e), + FlightError::DecodeError(e) => tonic::Status::internal(e), + FlightError::ExternalError(e) => tonic::Status::internal(e.to_string()), + } + } +} + +pub type Result = std::result::Result; + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn error_source() { + let e1 = FlightError::DecodeError("foo".into()); + assert!(e1.source().is_none()); + + // one level of wrapping + let e2 = FlightError::ExternalError(Box::new(e1)); + let source = e2.source().unwrap().downcast_ref::().unwrap(); + assert!(matches!(source, FlightError::DecodeError(_))); + + let e3 = FlightError::ExternalError(Box::new(e2)); + let source = e3 + .source() + .unwrap() + .downcast_ref::() + .unwrap() + .source() + .unwrap() + .downcast_ref::() + .unwrap(); + + assert!(matches!(source, FlightError::DecodeError(_))); + } + + #[test] + fn error_through_arrow() { + // flight error that wraps an arrow error that wraps a flight error + let e1 = FlightError::DecodeError("foo".into()); + let e2 = ArrowError::ExternalError(Box::new(e1)); + let e3 = FlightError::ExternalError(Box::new(e2)); + + // ensure we can find the lowest level error by following source() + let mut root_error: &dyn Error = &e3; + while let Some(source) = root_error.source() { + // walk the next level + root_error = source; + } + + let source = root_error.downcast_ref::().unwrap(); + assert!(matches!(source, FlightError::DecodeError(_))); + } +} diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs index 3f4f09855353..8d05f658703a 100644 --- a/arrow-flight/src/lib.rs +++ b/arrow-flight/src/lib.rs @@ -15,40 +15,85 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::Schema; -use arrow::error::{ArrowError, Result as ArrowResult}; -use arrow::ipc::{ - convert, size_prefixed_root_as_message, writer, writer::EncodedData, - writer::IpcWriteOptions, -}; - +//! A native Rust implementation of [Apache Arrow Flight](https://arrow.apache.org/docs/format/Flight.html) +//! for exchanging [Arrow](https://arrow.apache.org) data between processes. +//! +//! Please see the [arrow-flight crates.io](https://crates.io/crates/arrow-flight) +//! page for feature flags and more information. +//! +//! # Overview +//! +//! This crate contains: +//! +//! 1. Low level [prost] generated structs +//! for Flight gRPC protobuf messages, such as [`FlightData`], [`FlightInfo`], +//! [`Location`] and [`Ticket`]. +//! +//! 2. Low level [tonic] generated [`flight_service_client`] and +//! [`flight_service_server`]. +//! +//! 3. Experimental support for [Flight SQL] in [`sql`]. Requires the +//! `flight-sql-experimental` feature of this crate to be activated. +//! +//! [Flight SQL]: https://arrow.apache.org/docs/format/FlightSql.html +#![allow(rustdoc::invalid_html_tags)] + +use arrow_ipc::{convert, writer, writer::EncodedData, writer::IpcWriteOptions}; +use arrow_schema::{ArrowError, Schema}; + +use arrow_ipc::convert::try_schema_from_ipc_buffer; +use base64::prelude::BASE64_STANDARD; +use base64::Engine; +use bytes::Bytes; use std::{ convert::{TryFrom, TryInto}, fmt, ops::Deref, }; +type ArrowResult = std::result::Result; + #[allow(clippy::derive_partial_eq_without_eq)] + mod gen { include!("arrow.flight.protocol.rs"); } +/// Defines a `Flight` for generation or retrieval. pub mod flight_descriptor { use super::gen; pub use gen::flight_descriptor::DescriptorType; } +/// Low Level [tonic] [`FlightServiceClient`](gen::flight_service_client::FlightServiceClient). pub mod flight_service_client { use super::gen; pub use gen::flight_service_client::FlightServiceClient; } +/// Low Level [tonic] [`FlightServiceServer`](gen::flight_service_server::FlightServiceServer) +/// and [`FlightService`](gen::flight_service_server::FlightService). pub mod flight_service_server { use super::gen; pub use gen::flight_service_server::FlightService; pub use gen::flight_service_server::FlightServiceServer; } +/// Mid Level [`FlightClient`] +pub mod client; +pub use client::FlightClient; + +/// Decoder to create [`RecordBatch`](arrow_array::RecordBatch) streams from [`FlightData`] streams. +/// See [`FlightRecordBatchStream`](decode::FlightRecordBatchStream). +pub mod decode; + +/// Encoder to create [`FlightData`] streams from [`RecordBatch`](arrow_array::RecordBatch) streams. +/// See [`FlightDataEncoderBuilder`](encode::FlightDataEncoderBuilder). +pub mod encode; + +/// Common error types +pub mod error; + pub use gen::Action; pub use gen::ActionType; pub use gen::BasicAuth; @@ -66,6 +111,9 @@ pub use gen::Result; pub use gen::SchemaResult; pub use gen::Ticket; +/// Helper to extract HTTP/gRPC trailers from a tonic stream. +mod trailers; + pub mod utils; #[cfg(feature = "flight-sql-experimental")] @@ -81,21 +129,18 @@ pub struct SchemaAsIpc<'a> { /// IpcMessage represents a `Schema` in the format expected in /// `FlightInfo.schema` #[derive(Debug)] -pub struct IpcMessage(pub Vec); +pub struct IpcMessage(pub Bytes); // Useful conversion functions -fn flight_schema_as_encoded_data( - arrow_schema: &Schema, - options: &IpcWriteOptions, -) -> EncodedData { +fn flight_schema_as_encoded_data(arrow_schema: &Schema, options: &IpcWriteOptions) -> EncodedData { let data_gen = writer::IpcDataGenerator::default(); data_gen.schema_to_bytes(arrow_schema, options) } fn flight_schema_as_flatbuffer(schema: &Schema, options: &IpcWriteOptions) -> IpcMessage { let encoded_data = flight_schema_as_encoded_data(schema, options); - IpcMessage(encoded_data.ipc_message) + IpcMessage(encoded_data.ipc_message.into()) } // Implement a bunch of useful traits for various conversions, displays, @@ -104,7 +149,7 @@ fn flight_schema_as_flatbuffer(schema: &Schema, options: &IpcWriteOptions) -> Ip // Deref impl Deref for IpcMessage { - type Target = Vec; + type Target = [u8]; fn deref(&self) -> &Self::Target { &self.0 @@ -135,7 +180,7 @@ impl fmt::Display for FlightData { write!(f, "FlightData {{")?; write!(f, " descriptor: ")?; match &self.flight_descriptor { - Some(d) => write!(f, "{}", d)?, + Some(d) => write!(f, "{d}")?, None => write!(f, "None")?, }; write!(f, ", header: ")?; @@ -161,7 +206,7 @@ impl fmt::Display for FlightDescriptor { write!(f, "path: [")?; let mut sep = ""; for element in &self.path { - write!(f, "{}{}", sep, element)?; + write!(f, "{sep}{element}")?; sep = ", "; } write!(f, "]")?; @@ -179,13 +224,13 @@ impl fmt::Display for FlightEndpoint { write!(f, "FlightEndpoint {{")?; write!(f, " ticket: ")?; match &self.ticket { - Some(value) => write!(f, "{}", value), + Some(value) => write!(f, "{value}"), None => write!(f, " none"), }?; write!(f, ", location: [")?; let mut sep = ""; for location in &self.location { - write!(f, "{}{}", sep, location)?; + write!(f, "{sep}{location}")?; sep = ", "; } write!(f, "]")?; @@ -198,16 +243,16 @@ impl fmt::Display for FlightInfo { let ipc_message = IpcMessage(self.schema.clone()); let schema: Schema = ipc_message.try_into().map_err(|_err| fmt::Error)?; write!(f, "FlightInfo {{")?; - write!(f, " schema: {}", schema)?; + write!(f, " schema: {schema}")?; write!(f, ", descriptor:")?; match &self.flight_descriptor { - Some(d) => write!(f, " {}", d), + Some(d) => write!(f, " {d}"), None => write!(f, " None"), }?; write!(f, ", endpoint: [")?; let mut sep = ""; for endpoint in &self.endpoint { - write!(f, "{}{}", sep, endpoint)?; + write!(f, "{sep}{endpoint}")?; sep = ", "; } write!(f, "], total_records: {}", self.total_records)?; @@ -228,7 +273,7 @@ impl fmt::Display for Ticket { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Ticket {{")?; write!(f, " ticket: ")?; - write!(f, "{}", base64::encode(&self.ticket)) + write!(f, "{}", BASE64_STANDARD.encode(&self.ticket)) } } @@ -237,8 +282,8 @@ impl fmt::Display for Ticket { impl From for FlightData { fn from(data: EncodedData) -> Self { FlightData { - data_header: data.ipc_message, - data_body: data.arrow_data, + data_header: data.ipc_message.into(), + data_body: data.arrow_data.into(), ..Default::default() } } @@ -254,20 +299,17 @@ impl From> for FlightData { } } -impl From> for SchemaResult { - fn from(schema_ipc: SchemaAsIpc) -> Self { - let IpcMessage(vals) = flight_schema_as_flatbuffer(schema_ipc.0, schema_ipc.1); - SchemaResult { schema: vals } - } -} - -// TryFrom... - -impl TryFrom for DescriptorType { +impl TryFrom> for SchemaResult { type Error = ArrowError; - fn try_from(value: i32) -> ArrowResult { - value.try_into() + fn try_from(schema_ipc: SchemaAsIpc) -> ArrowResult { + // According to the definition from `Flight.proto` + // The schema of the dataset in its IPC form: + // 4 bytes - an optional IPC_CONTINUATION_TOKEN prefix + // 4 bytes - the byte length of the payload + // a flatbuffer Message whose header is the Schema + let IpcMessage(vals) = schema_to_ipc_format(schema_ipc)?; + Ok(SchemaResult { schema: vals }) } } @@ -275,22 +317,25 @@ impl TryFrom> for IpcMessage { type Error = ArrowError; fn try_from(schema_ipc: SchemaAsIpc) -> ArrowResult { - let pair = *schema_ipc; - let encoded_data = flight_schema_as_encoded_data(pair.0, pair.1); - - let mut schema = vec![]; - arrow::ipc::writer::write_message(&mut schema, encoded_data, pair.1)?; - Ok(IpcMessage(schema)) + schema_to_ipc_format(schema_ipc) } } +fn schema_to_ipc_format(schema_ipc: SchemaAsIpc) -> ArrowResult { + let pair = *schema_ipc; + let encoded_data = flight_schema_as_encoded_data(pair.0, pair.1); + + let mut schema = vec![]; + writer::write_message(&mut schema, encoded_data, pair.1)?; + Ok(IpcMessage(schema.into())) +} + impl TryFrom<&FlightData> for Schema { type Error = ArrowError; fn try_from(data: &FlightData) -> ArrowResult { - convert::schema_from_bytes(&data.data_header[..]).map_err(|err| { + convert::try_schema_from_flatbuffer_bytes(&data.data_header[..]).map_err(|err| { ArrowError::ParseError(format!( - "Unable to convert flight data to Arrow schema: {}", - err + "Unable to convert flight data to Arrow schema: {err}" )) }) } @@ -300,8 +345,7 @@ impl TryFrom for Schema { type Error = ArrowError; fn try_from(value: FlightInfo) -> ArrowResult { - let msg = IpcMessage(value.schema); - msg.try_into() + value.try_decode_schema() } } @@ -309,63 +353,97 @@ impl TryFrom for Schema { type Error = ArrowError; fn try_from(value: IpcMessage) -> ArrowResult { - // CONTINUATION TAKES 4 BYTES - // SIZE TAKES 4 BYTES (so read msg as size prefixed) - let msg = size_prefixed_root_as_message(&value.0[4..]).map_err(|err| { - ArrowError::ParseError(format!( - "Unable to convert flight info to a message: {}", - err - )) - })?; - let ipc_schema = msg.header_as_schema().ok_or_else(|| { - ArrowError::ParseError( - "Unable to convert flight info to a schema".to_string(), - ) - })?; - Ok(convert::fb_to_schema(ipc_schema)) + try_schema_from_ipc_buffer(&value) } } impl TryFrom<&SchemaResult> for Schema { type Error = ArrowError; fn try_from(data: &SchemaResult) -> ArrowResult { - convert::schema_from_bytes(&data.schema[..]).map_err(|err| { - ArrowError::ParseError(format!( - "Unable to convert schema result to Arrow schema: {}", - err - )) - }) + try_schema_from_ipc_buffer(&data.schema) + } +} + +impl TryFrom for Schema { + type Error = ArrowError; + fn try_from(data: SchemaResult) -> ArrowResult { + (&data).try_into() } } // FlightData, FlightDescriptor, etc.. impl FlightData { - pub fn new( - flight_descriptor: Option, - message: IpcMessage, - app_metadata: Vec, - data_body: Vec, - ) -> Self { - let IpcMessage(vals) = message; - FlightData { - flight_descriptor, - data_header: vals, - app_metadata, - data_body, - } + /// Create a new [`FlightData`]. + /// + /// # See Also + /// + /// See [`FlightDataEncoderBuilder`] for a higher level API to + /// convert a stream of [`RecordBatch`]es to [`FlightData`]s + /// + /// # Example: + /// + /// ``` + /// # use bytes::Bytes; + /// # use arrow_flight::{FlightData, FlightDescriptor}; + /// # fn encode_data() -> Bytes { Bytes::new() } // dummy data + /// // Get encoded Arrow IPC data: + /// let data_body: Bytes = encode_data(); + /// // Create the FlightData message + /// let flight_data = FlightData::new() + /// .with_descriptor(FlightDescriptor::new_cmd("the command")) + /// .with_app_metadata("My apps metadata") + /// .with_data_body(data_body); + /// ``` + /// + /// [`FlightDataEncoderBuilder`]: crate::encode::FlightDataEncoderBuilder + /// [`RecordBatch`]: arrow_array::RecordBatch + pub fn new() -> Self { + Default::default() + } + + /// Add a [`FlightDescriptor`] describing the data + pub fn with_descriptor(mut self, flight_descriptor: FlightDescriptor) -> Self { + self.flight_descriptor = Some(flight_descriptor); + self + } + + /// Add a data header + pub fn with_data_header(mut self, data_header: impl Into) -> Self { + self.data_header = data_header.into(); + self + } + + /// Add a data body. See [`IpcDataGenerator`] to create this data. + /// + /// [`IpcDataGenerator`]: arrow_ipc::writer::IpcDataGenerator + pub fn with_data_body(mut self, data_body: impl Into) -> Self { + self.data_body = data_body.into(); + self + } + + /// Add optional application specific metadata to the message + pub fn with_app_metadata(mut self, app_metadata: impl Into) -> Self { + self.app_metadata = app_metadata.into(); + self } } impl FlightDescriptor { - pub fn new_cmd(cmd: Vec) -> Self { + /// Create a new opaque command [`CMD`] `FlightDescriptor` to generate a dataset. + /// + /// [`CMD`]: https://github.com/apache/arrow/blob/6bd31f37ae66bd35594b077cb2f830be57e08acd/format/Flight.proto#L224-L227 + pub fn new_cmd(cmd: impl Into) -> Self { FlightDescriptor { r#type: DescriptorType::Cmd.into(), - cmd, + cmd: cmd.into(), ..Default::default() } } + /// Create a new named path [`PATH`] `FlightDescriptor` that identifies a dataset + /// + /// [`PATH`]: https://github.com/apache/arrow/blob/6bd31f37ae66bd35594b077cb2f830be57e08acd/format/Flight.proto#L217-L222 pub fn new_path(path: Vec) -> Self { FlightDescriptor { r#type: DescriptorType::Path.into(), @@ -376,22 +454,98 @@ impl FlightDescriptor { } impl FlightInfo { - pub fn new( - message: IpcMessage, - flight_descriptor: Option, - endpoint: Vec, - total_records: i64, - total_bytes: i64, - ) -> Self { - let IpcMessage(vals) = message; + /// Create a new, empty `FlightInfo`, describing where to fetch flight data + /// + /// + /// # Example: + /// ``` + /// # use arrow_flight::{FlightInfo, Ticket, FlightDescriptor, FlightEndpoint}; + /// # use arrow_schema::{Schema, Field, DataType}; + /// # fn get_schema() -> Schema { + /// # Schema::new(vec![ + /// # Field::new("a", DataType::Utf8, false), + /// # ]) + /// # } + /// # + /// // Create a new FlightInfo + /// let flight_info = FlightInfo::new() + /// // Encode the Arrow schema + /// .try_with_schema(&get_schema()) + /// .expect("encoding failed") + /// .with_descriptor( + /// FlightDescriptor::new_cmd("a command") + /// ) + /// .with_endpoint( + /// FlightEndpoint::new() + /// .with_ticket(Ticket::new("ticket contents") + /// ) + /// ) + /// .with_descriptor(FlightDescriptor::new_cmd("RUN QUERY")); + /// ``` + pub fn new() -> FlightInfo { FlightInfo { - schema: vals, - flight_descriptor, - endpoint, - total_records, - total_bytes, + schema: Bytes::new(), + flight_descriptor: None, + endpoint: vec![], + ordered: false, + // Flight says "Set these to -1 if unknown." + // + // https://github.com/apache/arrow-rs/blob/17ca4d51d0490f9c65f5adde144f677dbc8300e7/format/Flight.proto#L287-L289 + total_records: -1, + total_bytes: -1, } } + + /// Try and convert the data in this `FlightInfo` into a [`Schema`] + pub fn try_decode_schema(self) -> ArrowResult { + let msg = IpcMessage(self.schema); + msg.try_into() + } + + /// Specify the schema for the response. + /// + /// Note this takes the arrow [`Schema`] (not the IPC schema) and + /// encodes it using the default IPC options. + /// + /// Returns an error if `schema` can not be encoded into IPC form. + pub fn try_with_schema(mut self, schema: &Schema) -> ArrowResult { + let options = IpcWriteOptions::default(); + let IpcMessage(schema) = SchemaAsIpc::new(schema, &options).try_into()?; + self.schema = schema; + Ok(self) + } + + /// Add specific a endpoint for fetching the data + pub fn with_endpoint(mut self, endpoint: FlightEndpoint) -> Self { + self.endpoint.push(endpoint); + self + } + + /// Add a [`FlightDescriptor`] describing what this data is + pub fn with_descriptor(mut self, flight_descriptor: FlightDescriptor) -> Self { + self.flight_descriptor = Some(flight_descriptor); + self + } + + /// Set the number of records in the result, if known + pub fn with_total_records(mut self, total_records: i64) -> Self { + self.total_records = total_records; + self + } + + /// Set the number of bytes in the result, if known + pub fn with_total_bytes(mut self, total_bytes: i64) -> Self { + self.total_bytes = total_bytes; + self + } + + /// Specify if the response is [ordered] across endpoints + /// + /// [ordered]: https://github.com/apache/arrow-rs/blob/17ca4d51d0490f9c65f5adde144f677dbc8300e7/format/Flight.proto#L269-L275 + pub fn with_ordered(mut self, ordered: bool) -> Self { + self.ordered = ordered; + self + } } impl<'a> SchemaAsIpc<'a> { @@ -402,9 +556,90 @@ impl<'a> SchemaAsIpc<'a> { } } +impl Action { + /// Create a new Action with type and body + pub fn new(action_type: impl Into, body: impl Into) -> Self { + Self { + r#type: action_type.into(), + body: body.into(), + } + } +} + +impl Result { + /// Create a new Result with the specified body + pub fn new(body: impl Into) -> Self { + Self { body: body.into() } + } +} + +impl Ticket { + /// Create a new `Ticket` + /// + /// # Example + /// + /// ``` + /// # use arrow_flight::Ticket; + /// let ticket = Ticket::new("SELECT * from FOO"); + /// ``` + pub fn new(ticket: impl Into) -> Self { + Self { + ticket: ticket.into(), + } + } +} + +impl FlightEndpoint { + /// Create a new, empty `FlightEndpoint` that represents a location + /// to retrieve Flight results. + /// + /// # Example + /// ``` + /// # use arrow_flight::{FlightEndpoint, Ticket}; + /// # + /// // Specify the client should fetch results from this server + /// let endpoint = FlightEndpoint::new() + /// .with_ticket(Ticket::new("the ticket")); + /// + /// // Specify the client should fetch results from either + /// // `http://example.com` or `https://example.com` + /// let endpoint = FlightEndpoint::new() + /// .with_ticket(Ticket::new("the ticket")) + /// .with_location("http://example.com") + /// .with_location("https://example.com"); + /// ``` + pub fn new() -> FlightEndpoint { + Default::default() + } + + /// Set the [`Ticket`] used to retrieve data from the endpoint + pub fn with_ticket(mut self, ticket: Ticket) -> Self { + self.ticket = Some(ticket); + self + } + + /// Add a location `uri` to this endpoint. Note each endpoint can + /// have multiple locations. + /// + /// If no `uri` is specified, the [Flight Spec] says: + /// + /// ```text + /// * If the list is empty, the expectation is that the ticket can only + /// * be redeemed on the current service where the ticket was + /// * generated. + /// ``` + /// [Flight Spec]: https://github.com/apache/arrow-rs/blob/17ca4d51d0490f9c65f5adde144f677dbc8300e7/format/Flight.proto#L307C2-L312 + pub fn with_location(mut self, uri: impl Into) -> Self { + self.location.push(Location { uri: uri.into() }); + self + } +} + #[cfg(test)] mod tests { use super::*; + use arrow_ipc::MetadataVersion; + use arrow_schema::{DataType, Field, TimeUnit}; struct TestVector(Vec, usize); @@ -426,7 +661,7 @@ mod tests { fn it_accepts_equal_output() { let input = TestVector(vec![91; 10], 10); - let actual = format!("{}", input); + let actual = format!("{input}"); let expected = format!("{:?}", vec![91; 10]); assert_eq!(actual, expected); } @@ -435,7 +670,7 @@ mod tests { fn it_accepts_short_output() { let input = TestVector(vec![91; 6], 10); - let actual = format!("{}", input); + let actual = format!("{input}"); let expected = format!("{:?}", vec![91; 6]); assert_eq!(actual, expected); } @@ -444,8 +679,35 @@ mod tests { fn it_accepts_long_output() { let input = TestVector(vec![91; 10], 9); - let actual = format!("{}", input); + let actual = format!("{input}"); let expected = format!("{:?}", vec![91; 9]); assert_eq!(actual, expected); } + + #[test] + fn ser_deser_schema_result() { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::Float64, true), + Field::new("c3", DataType::UInt32, false), + Field::new("c4", DataType::Boolean, true), + Field::new("c5", DataType::Timestamp(TimeUnit::Millisecond, None), true), + Field::new("c6", DataType::Time32(TimeUnit::Second), false), + ]); + // V5 with write_legacy_ipc_format = false + // this will write the continuation marker + let option = IpcWriteOptions::default(); + let schema_ipc = SchemaAsIpc::new(&schema, &option); + let result: SchemaResult = schema_ipc.try_into().unwrap(); + let des_schema: Schema = (&result).try_into().unwrap(); + assert_eq!(schema, des_schema); + + // V4 with write_legacy_ipc_format = true + // this will not write the continuation marker + let option = IpcWriteOptions::try_new(8, true, MetadataVersion::V4).unwrap(); + let schema_ipc = SchemaAsIpc::new(&schema, &option); + let result: SchemaResult = schema_ipc.try_into().unwrap(); + let des_schema: Schema = (&result).try_into().unwrap(); + assert_eq!(schema, des_schema); + } } diff --git a/arrow-flight/src/sql/arrow.flight.protocol.sql.rs b/arrow-flight/src/sql/arrow.flight.protocol.sql.rs index 77221dd1a489..c7c23311e61e 100644 --- a/arrow-flight/src/sql/arrow.flight.protocol.sql.rs +++ b/arrow-flight/src/sql/arrow.flight.protocol.sql.rs @@ -1,13 +1,13 @@ // This file was automatically generated through the build.rs script, and should not be edited. /// -/// Represents a metadata request. Used in the command member of FlightDescriptor -/// for the following RPC calls: +/// Represents a metadata request. Used in the command member of FlightDescriptor +/// for the following RPC calls: /// - GetSchema: return the Arrow schema of the query. /// - GetFlightInfo: execute the metadata request. /// -/// The returned Arrow schema will be: -/// < +/// The returned Arrow schema will be: +/// < /// info_name: uint32 not null, /// value: dense_union< /// string_value: utf8, @@ -16,185 +16,267 @@ /// int32_bitmask: int32, /// string_list: list /// int32_to_int32_list_map: map> -/// > -/// where there is one row per requested piece of metadata information. +/// > +/// where there is one row per requested piece of metadata information. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetSqlInfo { /// - /// Values are modelled after ODBC's SQLGetInfo() function. This information is intended to provide - /// Flight SQL clients with basic, SQL syntax and SQL functions related information. - /// More information types can be added in future releases. - /// E.g. more SQL syntax support types, scalar functions support, type conversion support etc. + /// Values are modelled after ODBC's SQLGetInfo() function. This information is intended to provide + /// Flight SQL clients with basic, SQL syntax and SQL functions related information. + /// More information types can be added in future releases. + /// E.g. more SQL syntax support types, scalar functions support, type conversion support etc. /// - /// Note that the set of metadata may expand. + /// Note that the set of metadata may expand. /// - /// Initially, Flight SQL will support the following information types: - /// - Server Information - Range [0-500) - /// - Syntax Information - Range [500-1000) - /// Range [0-10,000) is reserved for defaults (see SqlInfo enum for default options). - /// Custom options should start at 10,000. + /// Initially, Flight SQL will support the following information types: + /// - Server Information - Range [0-500) + /// - Syntax Information - Range [500-1000) + /// Range [0-10,000) is reserved for defaults (see SqlInfo enum for default options). + /// Custom options should start at 10,000. /// - /// If omitted, then all metadata will be retrieved. - /// Flight SQL Servers may choose to include additional metadata above and beyond the specified set, however they must - /// at least return the specified set. IDs ranging from 0 to 10,000 (exclusive) are reserved for future use. - /// If additional metadata is included, the metadata IDs should start from 10,000. - #[prost(uint32, repeated, tag="1")] + /// If omitted, then all metadata will be retrieved. + /// Flight SQL Servers may choose to include additional metadata above and beyond the specified set, however they must + /// at least return the specified set. IDs ranging from 0 to 10,000 (exclusive) are reserved for future use. + /// If additional metadata is included, the metadata IDs should start from 10,000. + #[prost(uint32, repeated, tag = "1")] pub info: ::prost::alloc::vec::Vec, } /// -/// Represents a request to retrieve the list of catalogs on a Flight SQL enabled backend. -/// The definition of a catalog depends on vendor/implementation. It is usually the database itself -/// Used in the command member of FlightDescriptor for the following RPC calls: +/// Represents a request to retrieve information about data type supported on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: +/// - GetSchema: return the schema of the query. +/// - GetFlightInfo: execute the catalog metadata request. +/// +/// The returned schema will be: +/// < +/// type_name: utf8 not null (The name of the data type, for example: VARCHAR, INTEGER, etc), +/// data_type: int32 not null (The SQL data type), +/// column_size: int32 (The maximum size supported by that column. +/// In case of exact numeric types, this represents the maximum precision. +/// In case of string types, this represents the character length. +/// In case of datetime data types, this represents the length in characters of the string representation. +/// NULL is returned for data types where column size is not applicable.), +/// literal_prefix: utf8 (Character or characters used to prefix a literal, NULL is returned for +/// data types where a literal prefix is not applicable.), +/// literal_suffix: utf8 (Character or characters used to terminate a literal, +/// NULL is returned for data types where a literal suffix is not applicable.), +/// create_params: list +/// (A list of keywords corresponding to which parameters can be used when creating +/// a column for that specific type. +/// NULL is returned if there are no parameters for the data type definition.), +/// nullable: int32 not null (Shows if the data type accepts a NULL value. The possible values can be seen in the +/// Nullable enum.), +/// case_sensitive: bool not null (Shows if a character data type is case-sensitive in collations and comparisons), +/// searchable: int32 not null (Shows how the data type is used in a WHERE clause. The possible values can be seen in the +/// Searchable enum.), +/// unsigned_attribute: bool (Shows if the data type is unsigned. NULL is returned if the attribute is +/// not applicable to the data type or the data type is not numeric.), +/// fixed_prec_scale: bool not null (Shows if the data type has predefined fixed precision and scale.), +/// auto_increment: bool (Shows if the data type is auto incremental. NULL is returned if the attribute +/// is not applicable to the data type or the data type is not numeric.), +/// local_type_name: utf8 (Localized version of the data source-dependent name of the data type. NULL +/// is returned if a localized name is not supported by the data source), +/// minimum_scale: int32 (The minimum scale of the data type on the data source. +/// If a data type has a fixed scale, the MINIMUM_SCALE and MAXIMUM_SCALE +/// columns both contain this value. NULL is returned if scale is not applicable.), +/// maximum_scale: int32 (The maximum scale of the data type on the data source. +/// NULL is returned if scale is not applicable.), +/// sql_data_type: int32 not null (The value of the SQL DATA TYPE which has the same values +/// as data_type value. Except for interval and datetime, which +/// uses generic values. More info about those types can be +/// obtained through datetime_subcode. The possible values can be seen +/// in the XdbcDataType enum.), +/// datetime_subcode: int32 (Only used when the SQL DATA TYPE is interval or datetime. It contains +/// its sub types. For type different from interval and datetime, this value +/// is NULL. The possible values can be seen in the XdbcDatetimeSubcode enum.), +/// num_prec_radix: int32 (If the data type is an approximate numeric type, this column contains +/// the value 2 to indicate that COLUMN_SIZE specifies a number of bits. For +/// exact numeric types, this column contains the value 10 to indicate that +/// column size specifies a number of decimal digits. Otherwise, this column is NULL.), +/// interval_precision: int32 (If the data type is an interval data type, then this column contains the value +/// of the interval leading precision. Otherwise, this column is NULL. This fields +/// is only relevant to be used by ODBC). +/// > +/// The returned data should be ordered by data_type and then by type_name. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CommandGetXdbcTypeInfo { + /// + /// Specifies the data type to search for the info. + #[prost(int32, optional, tag = "1")] + pub data_type: ::core::option::Option, +} +/// +/// Represents a request to retrieve the list of catalogs on a Flight SQL enabled backend. +/// The definition of a catalog depends on vendor/implementation. It is usually the database itself +/// Used in the command member of FlightDescriptor for the following RPC calls: /// - GetSchema: return the Arrow schema of the query. /// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < +/// The returned Arrow schema will be: +/// < /// catalog_name: utf8 not null -/// > -/// The returned data should be ordered by catalog_name. +/// > +/// The returned data should be ordered by catalog_name. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] -pub struct CommandGetCatalogs { -} +pub struct CommandGetCatalogs {} /// -/// Represents a request to retrieve the list of database schemas on a Flight SQL enabled backend. -/// The definition of a database schema depends on vendor/implementation. It is usually a collection of tables. -/// Used in the command member of FlightDescriptor for the following RPC calls: +/// Represents a request to retrieve the list of database schemas on a Flight SQL enabled backend. +/// The definition of a database schema depends on vendor/implementation. It is usually a collection of tables. +/// Used in the command member of FlightDescriptor for the following RPC calls: /// - GetSchema: return the Arrow schema of the query. /// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < +/// The returned Arrow schema will be: +/// < /// catalog_name: utf8, /// db_schema_name: utf8 not null -/// > -/// The returned data should be ordered by catalog_name, then db_schema_name. +/// > +/// The returned data should be ordered by catalog_name, then db_schema_name. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetDbSchemas { /// - /// Specifies the Catalog to search for the tables. - /// An empty string retrieves those without a catalog. - /// If omitted the catalog name should not be used to narrow the search. - #[prost(string, optional, tag="1")] + /// Specifies the Catalog to search for the tables. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. + #[prost(string, optional, tag = "1")] pub catalog: ::core::option::Option<::prost::alloc::string::String>, /// - /// Specifies a filter pattern for schemas to search for. - /// When no db_schema_filter_pattern is provided, the pattern will not be used to narrow the search. - /// In the pattern string, two special characters can be used to denote matching rules: + /// Specifies a filter pattern for schemas to search for. + /// When no db_schema_filter_pattern is provided, the pattern will not be used to narrow the search. + /// In the pattern string, two special characters can be used to denote matching rules: /// - "%" means to match any substring with 0 or more characters. /// - "_" means to match any one character. - #[prost(string, optional, tag="2")] + #[prost(string, optional, tag = "2")] pub db_schema_filter_pattern: ::core::option::Option<::prost::alloc::string::String>, } /// -/// Represents a request to retrieve the list of tables, and optionally their schemas, on a Flight SQL enabled backend. -/// Used in the command member of FlightDescriptor for the following RPC calls: +/// Represents a request to retrieve the list of tables, and optionally their schemas, on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: /// - GetSchema: return the Arrow schema of the query. /// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < +/// The returned Arrow schema will be: +/// < /// catalog_name: utf8, /// db_schema_name: utf8, /// table_name: utf8 not null, /// table_type: utf8 not null, /// \[optional\] table_schema: bytes not null (schema of the table as described in Schema.fbs::Schema, /// it is serialized as an IPC message.) -/// > -/// The returned data should be ordered by catalog_name, db_schema_name, table_name, then table_type, followed by table_schema if requested. +/// > +/// Fields on table_schema may contain the following metadata: +/// - ARROW:FLIGHT:SQL:CATALOG_NAME - Table's catalog name +/// - ARROW:FLIGHT:SQL:DB_SCHEMA_NAME - Database schema name +/// - ARROW:FLIGHT:SQL:TABLE_NAME - Table name +/// - ARROW:FLIGHT:SQL:TYPE_NAME - The data source-specific name for the data type of the column. +/// - ARROW:FLIGHT:SQL:PRECISION - Column precision/size +/// - ARROW:FLIGHT:SQL:SCALE - Column scale/decimal digits if applicable +/// - ARROW:FLIGHT:SQL:IS_AUTO_INCREMENT - "1" indicates if the column is auto incremented, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_CASE_SENSITIVE - "1" indicates if the column is case sensitive, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_READ_ONLY - "1" indicates if the column is read only, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_SEARCHABLE - "1" indicates if the column is searchable via WHERE clause, "0" otherwise. +/// The returned data should be ordered by catalog_name, db_schema_name, table_name, then table_type, followed by table_schema if requested. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetTables { /// - /// Specifies the Catalog to search for the tables. - /// An empty string retrieves those without a catalog. - /// If omitted the catalog name should not be used to narrow the search. - #[prost(string, optional, tag="1")] + /// Specifies the Catalog to search for the tables. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. + #[prost(string, optional, tag = "1")] pub catalog: ::core::option::Option<::prost::alloc::string::String>, /// - /// Specifies a filter pattern for schemas to search for. - /// When no db_schema_filter_pattern is provided, all schemas matching other filters are searched. - /// In the pattern string, two special characters can be used to denote matching rules: + /// Specifies a filter pattern for schemas to search for. + /// When no db_schema_filter_pattern is provided, all schemas matching other filters are searched. + /// In the pattern string, two special characters can be used to denote matching rules: /// - "%" means to match any substring with 0 or more characters. /// - "_" means to match any one character. - #[prost(string, optional, tag="2")] + #[prost(string, optional, tag = "2")] pub db_schema_filter_pattern: ::core::option::Option<::prost::alloc::string::String>, /// - /// Specifies a filter pattern for tables to search for. - /// When no table_name_filter_pattern is provided, all tables matching other filters are searched. - /// In the pattern string, two special characters can be used to denote matching rules: + /// Specifies a filter pattern for tables to search for. + /// When no table_name_filter_pattern is provided, all tables matching other filters are searched. + /// In the pattern string, two special characters can be used to denote matching rules: /// - "%" means to match any substring with 0 or more characters. /// - "_" means to match any one character. - #[prost(string, optional, tag="3")] - pub table_name_filter_pattern: ::core::option::Option<::prost::alloc::string::String>, - /// - /// Specifies a filter of table types which must match. - /// The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables. - /// TABLE, VIEW, and SYSTEM TABLE are commonly supported. - #[prost(string, repeated, tag="4")] + #[prost(string, optional, tag = "3")] + pub table_name_filter_pattern: ::core::option::Option< + ::prost::alloc::string::String, + >, + /// + /// Specifies a filter of table types which must match. + /// The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables. + /// TABLE, VIEW, and SYSTEM TABLE are commonly supported. + #[prost(string, repeated, tag = "4")] pub table_types: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, - /// Specifies if the Arrow schema should be returned for found tables. - #[prost(bool, tag="5")] + /// Specifies if the Arrow schema should be returned for found tables. + #[prost(bool, tag = "5")] pub include_schema: bool, } /// -/// Represents a request to retrieve the list of table types on a Flight SQL enabled backend. -/// The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables. -/// TABLE, VIEW, and SYSTEM TABLE are commonly supported. -/// Used in the command member of FlightDescriptor for the following RPC calls: +/// Represents a request to retrieve the list of table types on a Flight SQL enabled backend. +/// The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables. +/// TABLE, VIEW, and SYSTEM TABLE are commonly supported. +/// Used in the command member of FlightDescriptor for the following RPC calls: /// - GetSchema: return the Arrow schema of the query. /// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < +/// The returned Arrow schema will be: +/// < /// table_type: utf8 not null -/// > -/// The returned data should be ordered by table_type. +/// > +/// The returned data should be ordered by table_type. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] -pub struct CommandGetTableTypes { -} +pub struct CommandGetTableTypes {} /// -/// Represents a request to retrieve the primary keys of a table on a Flight SQL enabled backend. -/// Used in the command member of FlightDescriptor for the following RPC calls: +/// Represents a request to retrieve the primary keys of a table on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: /// - GetSchema: return the Arrow schema of the query. /// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < +/// The returned Arrow schema will be: +/// < /// catalog_name: utf8, /// db_schema_name: utf8, /// table_name: utf8 not null, /// column_name: utf8 not null, /// key_name: utf8, -/// key_sequence: int not null -/// > -/// The returned data should be ordered by catalog_name, db_schema_name, table_name, key_name, then key_sequence. +/// key_sequence: int32 not null +/// > +/// The returned data should be ordered by catalog_name, db_schema_name, table_name, key_name, then key_sequence. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetPrimaryKeys { /// - /// Specifies the catalog to search for the table. - /// An empty string retrieves those without a catalog. - /// If omitted the catalog name should not be used to narrow the search. - #[prost(string, optional, tag="1")] + /// Specifies the catalog to search for the table. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. + #[prost(string, optional, tag = "1")] pub catalog: ::core::option::Option<::prost::alloc::string::String>, /// - /// Specifies the schema to search for the table. - /// An empty string retrieves those without a schema. - /// If omitted the schema name should not be used to narrow the search. - #[prost(string, optional, tag="2")] + /// Specifies the schema to search for the table. + /// An empty string retrieves those without a schema. + /// If omitted the schema name should not be used to narrow the search. + #[prost(string, optional, tag = "2")] pub db_schema: ::core::option::Option<::prost::alloc::string::String>, - /// Specifies the table to get the primary keys for. - #[prost(string, tag="3")] + /// Specifies the table to get the primary keys for. + #[prost(string, tag = "3")] pub table: ::prost::alloc::string::String, } /// -/// Represents a request to retrieve a description of the foreign key columns that reference the given table's -/// primary key columns (the foreign keys exported by a table) of a table on a Flight SQL enabled backend. -/// Used in the command member of FlightDescriptor for the following RPC calls: +/// Represents a request to retrieve a description of the foreign key columns that reference the given table's +/// primary key columns (the foreign keys exported by a table) of a table on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: /// - GetSchema: return the Arrow schema of the query. /// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < +/// The returned Arrow schema will be: +/// < /// pk_catalog_name: utf8, /// pk_db_schema_name: utf8, /// pk_table_name: utf8 not null, @@ -203,40 +285,41 @@ pub struct CommandGetPrimaryKeys { /// fk_db_schema_name: utf8, /// fk_table_name: utf8 not null, /// fk_column_name: utf8 not null, -/// key_sequence: int not null, +/// key_sequence: int32 not null, /// fk_key_name: utf8, /// pk_key_name: utf8, -/// update_rule: uint1 not null, -/// delete_rule: uint1 not null -/// > -/// The returned data should be ordered by fk_catalog_name, fk_db_schema_name, fk_table_name, fk_key_name, then key_sequence. -/// update_rule and delete_rule returns a byte that is equivalent to actions declared on UpdateDeleteRules enum. +/// update_rule: uint8 not null, +/// delete_rule: uint8 not null +/// > +/// The returned data should be ordered by fk_catalog_name, fk_db_schema_name, fk_table_name, fk_key_name, then key_sequence. +/// update_rule and delete_rule returns a byte that is equivalent to actions declared on UpdateDeleteRules enum. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetExportedKeys { /// - /// Specifies the catalog to search for the foreign key table. - /// An empty string retrieves those without a catalog. - /// If omitted the catalog name should not be used to narrow the search. - #[prost(string, optional, tag="1")] + /// Specifies the catalog to search for the foreign key table. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. + #[prost(string, optional, tag = "1")] pub catalog: ::core::option::Option<::prost::alloc::string::String>, /// - /// Specifies the schema to search for the foreign key table. - /// An empty string retrieves those without a schema. - /// If omitted the schema name should not be used to narrow the search. - #[prost(string, optional, tag="2")] + /// Specifies the schema to search for the foreign key table. + /// An empty string retrieves those without a schema. + /// If omitted the schema name should not be used to narrow the search. + #[prost(string, optional, tag = "2")] pub db_schema: ::core::option::Option<::prost::alloc::string::String>, - /// Specifies the foreign key table to get the foreign keys for. - #[prost(string, tag="3")] + /// Specifies the foreign key table to get the foreign keys for. + #[prost(string, tag = "3")] pub table: ::prost::alloc::string::String, } /// -/// Represents a request to retrieve the foreign keys of a table on a Flight SQL enabled backend. -/// Used in the command member of FlightDescriptor for the following RPC calls: +/// Represents a request to retrieve the foreign keys of a table on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: /// - GetSchema: return the Arrow schema of the query. /// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < +/// The returned Arrow schema will be: +/// < /// pk_catalog_name: utf8, /// pk_db_schema_name: utf8, /// pk_table_name: utf8 not null, @@ -245,47 +328,48 @@ pub struct CommandGetExportedKeys { /// fk_db_schema_name: utf8, /// fk_table_name: utf8 not null, /// fk_column_name: utf8 not null, -/// key_sequence: int not null, +/// key_sequence: int32 not null, /// fk_key_name: utf8, /// pk_key_name: utf8, -/// update_rule: uint1 not null, -/// delete_rule: uint1 not null -/// > -/// The returned data should be ordered by pk_catalog_name, pk_db_schema_name, pk_table_name, pk_key_name, then key_sequence. -/// update_rule and delete_rule returns a byte that is equivalent to actions: +/// update_rule: uint8 not null, +/// delete_rule: uint8 not null +/// > +/// The returned data should be ordered by pk_catalog_name, pk_db_schema_name, pk_table_name, pk_key_name, then key_sequence. +/// update_rule and delete_rule returns a byte that is equivalent to actions: /// - 0 = CASCADE /// - 1 = RESTRICT /// - 2 = SET NULL /// - 3 = NO ACTION /// - 4 = SET DEFAULT +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetImportedKeys { /// - /// Specifies the catalog to search for the primary key table. - /// An empty string retrieves those without a catalog. - /// If omitted the catalog name should not be used to narrow the search. - #[prost(string, optional, tag="1")] + /// Specifies the catalog to search for the primary key table. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. + #[prost(string, optional, tag = "1")] pub catalog: ::core::option::Option<::prost::alloc::string::String>, /// - /// Specifies the schema to search for the primary key table. - /// An empty string retrieves those without a schema. - /// If omitted the schema name should not be used to narrow the search. - #[prost(string, optional, tag="2")] + /// Specifies the schema to search for the primary key table. + /// An empty string retrieves those without a schema. + /// If omitted the schema name should not be used to narrow the search. + #[prost(string, optional, tag = "2")] pub db_schema: ::core::option::Option<::prost::alloc::string::String>, - /// Specifies the primary key table to get the foreign keys for. - #[prost(string, tag="3")] + /// Specifies the primary key table to get the foreign keys for. + #[prost(string, tag = "3")] pub table: ::prost::alloc::string::String, } /// -/// Represents a request to retrieve a description of the foreign key columns in the given foreign key table that -/// reference the primary key or the columns representing a unique constraint of the parent table (could be the same -/// or a different table) on a Flight SQL enabled backend. -/// Used in the command member of FlightDescriptor for the following RPC calls: +/// Represents a request to retrieve a description of the foreign key columns in the given foreign key table that +/// reference the primary key or the columns representing a unique constraint of the parent table (could be the same +/// or a different table) on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: /// - GetSchema: return the Arrow schema of the query. /// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < +/// The returned Arrow schema will be: +/// < /// pk_catalog_name: utf8, /// pk_db_schema_name: utf8, /// pk_table_name: utf8 not null, @@ -294,713 +378,1115 @@ pub struct CommandGetImportedKeys { /// fk_db_schema_name: utf8, /// fk_table_name: utf8 not null, /// fk_column_name: utf8 not null, -/// key_sequence: int not null, +/// key_sequence: int32 not null, /// fk_key_name: utf8, /// pk_key_name: utf8, -/// update_rule: uint1 not null, -/// delete_rule: uint1 not null -/// > -/// The returned data should be ordered by pk_catalog_name, pk_db_schema_name, pk_table_name, pk_key_name, then key_sequence. -/// update_rule and delete_rule returns a byte that is equivalent to actions: +/// update_rule: uint8 not null, +/// delete_rule: uint8 not null +/// > +/// The returned data should be ordered by pk_catalog_name, pk_db_schema_name, pk_table_name, pk_key_name, then key_sequence. +/// update_rule and delete_rule returns a byte that is equivalent to actions: /// - 0 = CASCADE /// - 1 = RESTRICT /// - 2 = SET NULL /// - 3 = NO ACTION /// - 4 = SET DEFAULT +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetCrossReference { /// * - /// The catalog name where the parent table is. - /// An empty string retrieves those without a catalog. - /// If omitted the catalog name should not be used to narrow the search. - #[prost(string, optional, tag="1")] + /// The catalog name where the parent table is. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. + #[prost(string, optional, tag = "1")] pub pk_catalog: ::core::option::Option<::prost::alloc::string::String>, /// * - /// The Schema name where the parent table is. - /// An empty string retrieves those without a schema. - /// If omitted the schema name should not be used to narrow the search. - #[prost(string, optional, tag="2")] + /// The Schema name where the parent table is. + /// An empty string retrieves those without a schema. + /// If omitted the schema name should not be used to narrow the search. + #[prost(string, optional, tag = "2")] pub pk_db_schema: ::core::option::Option<::prost::alloc::string::String>, /// * - /// The parent table name. It cannot be null. - #[prost(string, tag="3")] + /// The parent table name. It cannot be null. + #[prost(string, tag = "3")] pub pk_table: ::prost::alloc::string::String, /// * - /// The catalog name where the foreign table is. - /// An empty string retrieves those without a catalog. - /// If omitted the catalog name should not be used to narrow the search. - #[prost(string, optional, tag="4")] + /// The catalog name where the foreign table is. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. + #[prost(string, optional, tag = "4")] pub fk_catalog: ::core::option::Option<::prost::alloc::string::String>, /// * - /// The schema name where the foreign table is. - /// An empty string retrieves those without a schema. - /// If omitted the schema name should not be used to narrow the search. - #[prost(string, optional, tag="5")] + /// The schema name where the foreign table is. + /// An empty string retrieves those without a schema. + /// If omitted the schema name should not be used to narrow the search. + #[prost(string, optional, tag = "5")] pub fk_db_schema: ::core::option::Option<::prost::alloc::string::String>, /// * - /// The foreign table name. It cannot be null. - #[prost(string, tag="6")] + /// The foreign table name. It cannot be null. + #[prost(string, tag = "6")] pub fk_table: ::prost::alloc::string::String, } -// SQL Execution Action Messages - /// -/// Request message for the "CreatePreparedStatement" action on a Flight SQL enabled backend. +/// Request message for the "CreatePreparedStatement" action on a Flight SQL enabled backend. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ActionCreatePreparedStatementRequest { - /// The valid SQL string to create a prepared statement for. - #[prost(string, tag="1")] + /// The valid SQL string to create a prepared statement for. + #[prost(string, tag = "1")] pub query: ::prost::alloc::string::String, + /// Create/execute the prepared statement as part of this transaction (if + /// unset, executions of the prepared statement will be auto-committed). + #[prost(bytes = "bytes", optional, tag = "2")] + pub transaction_id: ::core::option::Option<::prost::bytes::Bytes>, +} +/// +/// An embedded message describing a Substrait plan to execute. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SubstraitPlan { + /// The serialized substrait.Plan to create a prepared statement for. + /// XXX(ARROW-16902): this is bytes instead of an embedded message + /// because Protobuf does not really support one DLL using Protobuf + /// definitions from another DLL. + #[prost(bytes = "bytes", tag = "1")] + pub plan: ::prost::bytes::Bytes, + /// The Substrait release, e.g. "0.12.0". This information is not + /// tracked in the plan itself, so this is the only way for consumers + /// to potentially know if they can handle the plan. + #[prost(string, tag = "2")] + pub version: ::prost::alloc::string::String, } /// -/// Wrap the result of a "GetPreparedStatement" action. +/// Request message for the "CreatePreparedSubstraitPlan" action on a Flight SQL enabled backend. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ActionCreatePreparedSubstraitPlanRequest { + /// The serialized substrait.Plan to create a prepared statement for. + #[prost(message, optional, tag = "1")] + pub plan: ::core::option::Option, + /// Create/execute the prepared statement as part of this transaction (if + /// unset, executions of the prepared statement will be auto-committed). + #[prost(bytes = "bytes", optional, tag = "2")] + pub transaction_id: ::core::option::Option<::prost::bytes::Bytes>, +} +/// +/// Wrap the result of a "CreatePreparedStatement" or "CreatePreparedSubstraitPlan" action. +/// +/// The resultant PreparedStatement can be closed either: +/// - Manually, through the "ClosePreparedStatement" action; +/// - Automatically, by a server timeout. /// -/// The resultant PreparedStatement can be closed either: -/// - Manually, through the "ClosePreparedStatement" action; -/// - Automatically, by a server timeout. +/// The result should be wrapped in a google.protobuf.Any message. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ActionCreatePreparedStatementResult { - /// Opaque handle for the prepared statement on the server. - #[prost(bytes="vec", tag="1")] - pub prepared_statement_handle: ::prost::alloc::vec::Vec, - /// If a result set generating query was provided, dataset_schema contains the - /// schema of the dataset as described in Schema.fbs::Schema, it is serialized as an IPC message. - #[prost(bytes="vec", tag="2")] - pub dataset_schema: ::prost::alloc::vec::Vec, - /// If the query provided contained parameters, parameter_schema contains the - /// schema of the expected parameters as described in Schema.fbs::Schema, it is serialized as an IPC message. - #[prost(bytes="vec", tag="3")] - pub parameter_schema: ::prost::alloc::vec::Vec, + /// Opaque handle for the prepared statement on the server. + #[prost(bytes = "bytes", tag = "1")] + pub prepared_statement_handle: ::prost::bytes::Bytes, + /// If a result set generating query was provided, dataset_schema contains the + /// schema of the dataset as described in Schema.fbs::Schema, it is serialized as an IPC message. + #[prost(bytes = "bytes", tag = "2")] + pub dataset_schema: ::prost::bytes::Bytes, + /// If the query provided contained parameters, parameter_schema contains the + /// schema of the expected parameters as described in Schema.fbs::Schema, it is serialized as an IPC message. + #[prost(bytes = "bytes", tag = "3")] + pub parameter_schema: ::prost::bytes::Bytes, } /// -/// Request message for the "ClosePreparedStatement" action on a Flight SQL enabled backend. -/// Closes server resources associated with the prepared statement handle. +/// Request message for the "ClosePreparedStatement" action on a Flight SQL enabled backend. +/// Closes server resources associated with the prepared statement handle. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ActionClosePreparedStatementRequest { - /// Opaque handle for the prepared statement on the server. - #[prost(bytes="vec", tag="1")] - pub prepared_statement_handle: ::prost::alloc::vec::Vec, + /// Opaque handle for the prepared statement on the server. + #[prost(bytes = "bytes", tag = "1")] + pub prepared_statement_handle: ::prost::bytes::Bytes, } -// SQL Execution Messages. - /// -/// Represents a SQL query. Used in the command member of FlightDescriptor -/// for the following RPC calls: +/// Request message for the "BeginTransaction" action. +/// Begins a transaction. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ActionBeginTransactionRequest {} +/// +/// Request message for the "BeginSavepoint" action. +/// Creates a savepoint within a transaction. +/// +/// Only supported if FLIGHT_SQL_TRANSACTION is +/// FLIGHT_SQL_TRANSACTION_SUPPORT_SAVEPOINT. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ActionBeginSavepointRequest { + /// The transaction to which a savepoint belongs. + #[prost(bytes = "bytes", tag = "1")] + pub transaction_id: ::prost::bytes::Bytes, + /// Name for the savepoint. + #[prost(string, tag = "2")] + pub name: ::prost::alloc::string::String, +} +/// +/// The result of a "BeginTransaction" action. +/// +/// The transaction can be manipulated with the "EndTransaction" action, or +/// automatically via server timeout. If the transaction times out, then it is +/// automatically rolled back. +/// +/// The result should be wrapped in a google.protobuf.Any message. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ActionBeginTransactionResult { + /// Opaque handle for the transaction on the server. + #[prost(bytes = "bytes", tag = "1")] + pub transaction_id: ::prost::bytes::Bytes, +} +/// +/// The result of a "BeginSavepoint" action. +/// +/// The transaction can be manipulated with the "EndSavepoint" action. +/// If the associated transaction is committed, rolled back, or times +/// out, then the savepoint is also invalidated. +/// +/// The result should be wrapped in a google.protobuf.Any message. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ActionBeginSavepointResult { + /// Opaque handle for the savepoint on the server. + #[prost(bytes = "bytes", tag = "1")] + pub savepoint_id: ::prost::bytes::Bytes, +} +/// +/// Request message for the "EndTransaction" action. +/// +/// Commit (COMMIT) or rollback (ROLLBACK) the transaction. +/// +/// If the action completes successfully, the transaction handle is +/// invalidated, as are all associated savepoints. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ActionEndTransactionRequest { + /// Opaque handle for the transaction on the server. + #[prost(bytes = "bytes", tag = "1")] + pub transaction_id: ::prost::bytes::Bytes, + /// Whether to commit/rollback the given transaction. + #[prost(enumeration = "action_end_transaction_request::EndTransaction", tag = "2")] + pub action: i32, +} +/// Nested message and enum types in `ActionEndTransactionRequest`. +pub mod action_end_transaction_request { + #[derive( + Clone, + Copy, + Debug, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + ::prost::Enumeration + )] + #[repr(i32)] + pub enum EndTransaction { + Unspecified = 0, + /// Commit the transaction. + Commit = 1, + /// Roll back the transaction. + Rollback = 2, + } + impl EndTransaction { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + EndTransaction::Unspecified => "END_TRANSACTION_UNSPECIFIED", + EndTransaction::Commit => "END_TRANSACTION_COMMIT", + EndTransaction::Rollback => "END_TRANSACTION_ROLLBACK", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "END_TRANSACTION_UNSPECIFIED" => Some(Self::Unspecified), + "END_TRANSACTION_COMMIT" => Some(Self::Commit), + "END_TRANSACTION_ROLLBACK" => Some(Self::Rollback), + _ => None, + } + } + } +} +/// +/// Request message for the "EndSavepoint" action. +/// +/// Release (RELEASE) the savepoint or rollback (ROLLBACK) to the +/// savepoint. +/// +/// Releasing a savepoint invalidates that savepoint. Rolling back to +/// a savepoint does not invalidate the savepoint, but invalidates all +/// savepoints created after the current savepoint. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ActionEndSavepointRequest { + /// Opaque handle for the savepoint on the server. + #[prost(bytes = "bytes", tag = "1")] + pub savepoint_id: ::prost::bytes::Bytes, + /// Whether to rollback/release the given savepoint. + #[prost(enumeration = "action_end_savepoint_request::EndSavepoint", tag = "2")] + pub action: i32, +} +/// Nested message and enum types in `ActionEndSavepointRequest`. +pub mod action_end_savepoint_request { + #[derive( + Clone, + Copy, + Debug, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + ::prost::Enumeration + )] + #[repr(i32)] + pub enum EndSavepoint { + Unspecified = 0, + /// Release the savepoint. + Release = 1, + /// Roll back to a savepoint. + Rollback = 2, + } + impl EndSavepoint { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + EndSavepoint::Unspecified => "END_SAVEPOINT_UNSPECIFIED", + EndSavepoint::Release => "END_SAVEPOINT_RELEASE", + EndSavepoint::Rollback => "END_SAVEPOINT_ROLLBACK", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "END_SAVEPOINT_UNSPECIFIED" => Some(Self::Unspecified), + "END_SAVEPOINT_RELEASE" => Some(Self::Release), + "END_SAVEPOINT_ROLLBACK" => Some(Self::Rollback), + _ => None, + } + } + } +} +/// +/// Represents a SQL query. Used in the command member of FlightDescriptor +/// for the following RPC calls: /// - GetSchema: return the Arrow schema of the query. +/// Fields on this schema may contain the following metadata: +/// - ARROW:FLIGHT:SQL:CATALOG_NAME - Table's catalog name +/// - ARROW:FLIGHT:SQL:DB_SCHEMA_NAME - Database schema name +/// - ARROW:FLIGHT:SQL:TABLE_NAME - Table name +/// - ARROW:FLIGHT:SQL:TYPE_NAME - The data source-specific name for the data type of the column. +/// - ARROW:FLIGHT:SQL:PRECISION - Column precision/size +/// - ARROW:FLIGHT:SQL:SCALE - Column scale/decimal digits if applicable +/// - ARROW:FLIGHT:SQL:IS_AUTO_INCREMENT - "1" indicates if the column is auto incremented, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_CASE_SENSITIVE - "1" indicates if the column is case sensitive, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_READ_ONLY - "1" indicates if the column is read only, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_SEARCHABLE - "1" indicates if the column is searchable via WHERE clause, "0" otherwise. /// - GetFlightInfo: execute the query. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandStatementQuery { - /// The SQL syntax. - #[prost(string, tag="1")] + /// The SQL syntax. + #[prost(string, tag = "1")] pub query: ::prost::alloc::string::String, + /// Include the query as part of this transaction (if unset, the query is auto-committed). + #[prost(bytes = "bytes", optional, tag = "2")] + pub transaction_id: ::core::option::Option<::prost::bytes::Bytes>, +} +/// +/// Represents a Substrait plan. Used in the command member of FlightDescriptor +/// for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// Fields on this schema may contain the following metadata: +/// - ARROW:FLIGHT:SQL:CATALOG_NAME - Table's catalog name +/// - ARROW:FLIGHT:SQL:DB_SCHEMA_NAME - Database schema name +/// - ARROW:FLIGHT:SQL:TABLE_NAME - Table name +/// - ARROW:FLIGHT:SQL:TYPE_NAME - The data source-specific name for the data type of the column. +/// - ARROW:FLIGHT:SQL:PRECISION - Column precision/size +/// - ARROW:FLIGHT:SQL:SCALE - Column scale/decimal digits if applicable +/// - ARROW:FLIGHT:SQL:IS_AUTO_INCREMENT - "1" indicates if the column is auto incremented, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_CASE_SENSITIVE - "1" indicates if the column is case sensitive, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_READ_ONLY - "1" indicates if the column is read only, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_SEARCHABLE - "1" indicates if the column is searchable via WHERE clause, "0" otherwise. +/// - GetFlightInfo: execute the query. +/// - DoPut: execute the query. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CommandStatementSubstraitPlan { + /// A serialized substrait.Plan + #[prost(message, optional, tag = "1")] + pub plan: ::core::option::Option, + /// Include the query as part of this transaction (if unset, the query is auto-committed). + #[prost(bytes = "bytes", optional, tag = "2")] + pub transaction_id: ::core::option::Option<::prost::bytes::Bytes>, } /// * -/// Represents a ticket resulting from GetFlightInfo with a CommandStatementQuery. -/// This should be used only once and treated as an opaque value, that is, clients should not attempt to parse this. +/// Represents a ticket resulting from GetFlightInfo with a CommandStatementQuery. +/// This should be used only once and treated as an opaque value, that is, clients should not attempt to parse this. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct TicketStatementQuery { - /// Unique identifier for the instance of the statement to execute. - #[prost(bytes="vec", tag="1")] - pub statement_handle: ::prost::alloc::vec::Vec, + /// Unique identifier for the instance of the statement to execute. + #[prost(bytes = "bytes", tag = "1")] + pub statement_handle: ::prost::bytes::Bytes, } /// -/// Represents an instance of executing a prepared statement. Used in the command member of FlightDescriptor for -/// the following RPC calls: +/// Represents an instance of executing a prepared statement. Used in the command member of FlightDescriptor for +/// the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// Fields on this schema may contain the following metadata: +/// - ARROW:FLIGHT:SQL:CATALOG_NAME - Table's catalog name +/// - ARROW:FLIGHT:SQL:DB_SCHEMA_NAME - Database schema name +/// - ARROW:FLIGHT:SQL:TABLE_NAME - Table name +/// - ARROW:FLIGHT:SQL:TYPE_NAME - The data source-specific name for the data type of the column. +/// - ARROW:FLIGHT:SQL:PRECISION - Column precision/size +/// - ARROW:FLIGHT:SQL:SCALE - Column scale/decimal digits if applicable +/// - ARROW:FLIGHT:SQL:IS_AUTO_INCREMENT - "1" indicates if the column is auto incremented, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_CASE_SENSITIVE - "1" indicates if the column is case sensitive, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_READ_ONLY - "1" indicates if the column is read only, "0" otherwise. +/// - ARROW:FLIGHT:SQL:IS_SEARCHABLE - "1" indicates if the column is searchable via WHERE clause, "0" otherwise. /// - DoPut: bind parameter values. All of the bound parameter sets will be executed as a single atomic execution. /// - GetFlightInfo: execute the prepared statement instance. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandPreparedStatementQuery { - /// Opaque handle for the prepared statement on the server. - #[prost(bytes="vec", tag="1")] - pub prepared_statement_handle: ::prost::alloc::vec::Vec, + /// Opaque handle for the prepared statement on the server. + #[prost(bytes = "bytes", tag = "1")] + pub prepared_statement_handle: ::prost::bytes::Bytes, } /// -/// Represents a SQL update query. Used in the command member of FlightDescriptor -/// for the the RPC call DoPut to cause the server to execute the included SQL update. +/// Represents a SQL update query. Used in the command member of FlightDescriptor +/// for the the RPC call DoPut to cause the server to execute the included SQL update. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandStatementUpdate { - /// The SQL syntax. - #[prost(string, tag="1")] + /// The SQL syntax. + #[prost(string, tag = "1")] pub query: ::prost::alloc::string::String, + /// Include the query as part of this transaction (if unset, the query is auto-committed). + #[prost(bytes = "bytes", optional, tag = "2")] + pub transaction_id: ::core::option::Option<::prost::bytes::Bytes>, } /// -/// Represents a SQL update query. Used in the command member of FlightDescriptor -/// for the the RPC call DoPut to cause the server to execute the included -/// prepared statement handle as an update. +/// Represents a SQL update query. Used in the command member of FlightDescriptor +/// for the the RPC call DoPut to cause the server to execute the included +/// prepared statement handle as an update. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandPreparedStatementUpdate { - /// Opaque handle for the prepared statement on the server. - #[prost(bytes="vec", tag="1")] - pub prepared_statement_handle: ::prost::alloc::vec::Vec, + /// Opaque handle for the prepared statement on the server. + #[prost(bytes = "bytes", tag = "1")] + pub prepared_statement_handle: ::prost::bytes::Bytes, } /// -/// Returned from the RPC call DoPut when a CommandStatementUpdate -/// CommandPreparedStatementUpdate was in the request, containing -/// results from the update. +/// Returned from the RPC call DoPut when a CommandStatementUpdate +/// CommandPreparedStatementUpdate was in the request, containing +/// results from the update. +#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct DoPutUpdateResult { - /// The number of records updated. A return value of -1 represents - /// an unknown updated record count. - #[prost(int64, tag="1")] + /// The number of records updated. A return value of -1 represents + /// an unknown updated record count. + #[prost(int64, tag = "1")] pub record_count: i64, } -/// Options for CommandGetSqlInfo. +/// +/// Request message for the "CancelQuery" action. +/// +/// Explicitly cancel a running query. +/// +/// This lets a single client explicitly cancel work, no matter how many clients +/// are involved/whether the query is distributed or not, given server support. +/// The transaction/statement is not rolled back; it is the application's job to +/// commit or rollback as appropriate. This only indicates the client no longer +/// wishes to read the remainder of the query results or continue submitting +/// data. +/// +/// This command is idempotent. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ActionCancelQueryRequest { + /// The result of the GetFlightInfo RPC that initiated the query. + /// XXX(ARROW-16902): this must be a serialized FlightInfo, but is + /// rendered as bytes because Protobuf does not really support one + /// DLL using Protobuf definitions from another DLL. + #[prost(bytes = "bytes", tag = "1")] + pub info: ::prost::bytes::Bytes, +} +/// +/// The result of cancelling a query. +/// +/// The result should be wrapped in a google.protobuf.Any message. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ActionCancelQueryResult { + #[prost(enumeration = "action_cancel_query_result::CancelResult", tag = "1")] + pub result: i32, +} +/// Nested message and enum types in `ActionCancelQueryResult`. +pub mod action_cancel_query_result { + #[derive( + Clone, + Copy, + Debug, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + ::prost::Enumeration + )] + #[repr(i32)] + pub enum CancelResult { + /// The cancellation status is unknown. Servers should avoid using + /// this value (send a NOT_FOUND error if the requested query is + /// not known). Clients can retry the request. + Unspecified = 0, + /// The cancellation request is complete. Subsequent requests with + /// the same payload may return CANCELLED or a NOT_FOUND error. + Cancelled = 1, + /// The cancellation request is in progress. The client may retry + /// the cancellation request. + Cancelling = 2, + /// The query is not cancellable. The client should not retry the + /// cancellation request. + NotCancellable = 3, + } + impl CancelResult { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + CancelResult::Unspecified => "CANCEL_RESULT_UNSPECIFIED", + CancelResult::Cancelled => "CANCEL_RESULT_CANCELLED", + CancelResult::Cancelling => "CANCEL_RESULT_CANCELLING", + CancelResult::NotCancellable => "CANCEL_RESULT_NOT_CANCELLABLE", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "CANCEL_RESULT_UNSPECIFIED" => Some(Self::Unspecified), + "CANCEL_RESULT_CANCELLED" => Some(Self::Cancelled), + "CANCEL_RESULT_CANCELLING" => Some(Self::Cancelling), + "CANCEL_RESULT_NOT_CANCELLABLE" => Some(Self::NotCancellable), + _ => None, + } + } + } +} +/// Options for CommandGetSqlInfo. #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum SqlInfo { - // Server Information [0-500): Provides basic information about the Flight SQL Server. - - /// Retrieves a UTF-8 string with the name of the Flight SQL Server. + /// Retrieves a UTF-8 string with the name of the Flight SQL Server. FlightSqlServerName = 0, - /// Retrieves a UTF-8 string with the native version of the Flight SQL Server. + /// Retrieves a UTF-8 string with the native version of the Flight SQL Server. FlightSqlServerVersion = 1, - /// Retrieves a UTF-8 string with the Arrow format version of the Flight SQL Server. + /// Retrieves a UTF-8 string with the Arrow format version of the Flight SQL Server. FlightSqlServerArrowVersion = 2, - /// - /// Retrieves a boolean value indicating whether the Flight SQL Server is read only. /// - /// Returns: - /// - false: if read-write - /// - true: if read only + /// Retrieves a boolean value indicating whether the Flight SQL Server is read only. + /// + /// Returns: + /// - false: if read-write + /// - true: if read only FlightSqlServerReadOnly = 3, - // SQL Syntax Information [500-1000): provides information about SQL syntax supported by the Flight SQL Server. - /// - /// Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE and DROP of catalogs. + /// Retrieves a boolean value indicating whether the Flight SQL Server supports executing + /// SQL queries. /// - /// Returns: - /// - false: if it doesn't support CREATE and DROP of catalogs. - /// - true: if it supports CREATE and DROP of catalogs. + /// Note that the absence of this info (as opposed to a false value) does not necessarily + /// mean that SQL is not supported, as this property was not originally defined. + FlightSqlServerSql = 4, + /// + /// Retrieves a boolean value indicating whether the Flight SQL Server supports executing + /// Substrait plans. + FlightSqlServerSubstrait = 5, + /// + /// Retrieves a string value indicating the minimum supported Substrait version, or null + /// if Substrait is not supported. + FlightSqlServerSubstraitMinVersion = 6, + /// + /// Retrieves a string value indicating the maximum supported Substrait version, or null + /// if Substrait is not supported. + FlightSqlServerSubstraitMaxVersion = 7, + /// + /// Retrieves an int32 indicating whether the Flight SQL Server supports the + /// BeginTransaction/EndTransaction/BeginSavepoint/EndSavepoint actions. + /// + /// Even if this is not supported, the database may still support explicit "BEGIN + /// TRANSACTION"/"COMMIT" SQL statements (see SQL_TRANSACTIONS_SUPPORTED); this property + /// is only about whether the server implements the Flight SQL API endpoints. + /// + /// The possible values are listed in `SqlSupportedTransaction`. + FlightSqlServerTransaction = 8, + /// + /// Retrieves a boolean value indicating whether the Flight SQL Server supports explicit + /// query cancellation (the CancelQuery action). + FlightSqlServerCancel = 9, + /// + /// Retrieves an int32 indicating the timeout (in milliseconds) for prepared statement handles. + /// + /// If 0, there is no timeout. Servers should reset the timeout when the handle is used in a command. + FlightSqlServerStatementTimeout = 100, + /// + /// Retrieves an int32 indicating the timeout (in milliseconds) for transactions, since transactions are not tied to a connection. + /// + /// If 0, there is no timeout. Servers should reset the timeout when the handle is used in a command. + FlightSqlServerTransactionTimeout = 101, + /// + /// Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE and DROP of catalogs. + /// + /// Returns: + /// - false: if it doesn't support CREATE and DROP of catalogs. + /// - true: if it supports CREATE and DROP of catalogs. SqlDdlCatalog = 500, /// - /// Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE and DROP of schemas. + /// Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE and DROP of schemas. /// - /// Returns: - /// - false: if it doesn't support CREATE and DROP of schemas. - /// - true: if it supports CREATE and DROP of schemas. + /// Returns: + /// - false: if it doesn't support CREATE and DROP of schemas. + /// - true: if it supports CREATE and DROP of schemas. SqlDdlSchema = 501, /// - /// Indicates whether the Flight SQL Server supports CREATE and DROP of tables. + /// Indicates whether the Flight SQL Server supports CREATE and DROP of tables. /// - /// Returns: - /// - false: if it doesn't support CREATE and DROP of tables. - /// - true: if it supports CREATE and DROP of tables. + /// Returns: + /// - false: if it doesn't support CREATE and DROP of tables. + /// - true: if it supports CREATE and DROP of tables. SqlDdlTable = 502, /// - /// Retrieves a uint32 value representing the enu uint32 ordinal for the case sensitivity of catalog, table, schema and table names. + /// Retrieves a int32 ordinal representing the case sensitivity of catalog, table, schema and table names. /// - /// The possible values are listed in `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`. + /// The possible values are listed in `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`. SqlIdentifierCase = 503, - /// Retrieves a UTF-8 string with the supported character(s) used to surround a delimited identifier. + /// Retrieves a UTF-8 string with the supported character(s) used to surround a delimited identifier. SqlIdentifierQuoteChar = 504, /// - /// Retrieves a uint32 value representing the enu uint32 ordinal for the case sensitivity of quoted identifiers. + /// Retrieves a int32 describing the case sensitivity of quoted identifiers. /// - /// The possible values are listed in `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`. + /// The possible values are listed in `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`. SqlQuotedIdentifierCase = 505, /// - /// Retrieves a boolean value indicating whether all tables are selectable. + /// Retrieves a boolean value indicating whether all tables are selectable. /// - /// Returns: - /// - false: if not all tables are selectable or if none are; - /// - true: if all tables are selectable. + /// Returns: + /// - false: if not all tables are selectable or if none are; + /// - true: if all tables are selectable. SqlAllTablesAreSelectable = 506, /// - /// Retrieves the null ordering. + /// Retrieves the null ordering. /// - /// Returns a uint32 ordinal for the null ordering being used, as described in - /// `arrow.flight.protocol.sql.SqlNullOrdering`. + /// Returns a int32 ordinal for the null ordering being used, as described in + /// `arrow.flight.protocol.sql.SqlNullOrdering`. SqlNullOrdering = 507, - /// Retrieves a UTF-8 string list with values of the supported keywords. + /// Retrieves a UTF-8 string list with values of the supported keywords. SqlKeywords = 508, - /// Retrieves a UTF-8 string list with values of the supported numeric functions. + /// Retrieves a UTF-8 string list with values of the supported numeric functions. SqlNumericFunctions = 509, - /// Retrieves a UTF-8 string list with values of the supported string functions. + /// Retrieves a UTF-8 string list with values of the supported string functions. SqlStringFunctions = 510, - /// Retrieves a UTF-8 string list with values of the supported system functions. + /// Retrieves a UTF-8 string list with values of the supported system functions. SqlSystemFunctions = 511, - /// Retrieves a UTF-8 string list with values of the supported datetime functions. + /// Retrieves a UTF-8 string list with values of the supported datetime functions. SqlDatetimeFunctions = 512, /// - /// Retrieves the UTF-8 string that can be used to escape wildcard characters. - /// This is the string that can be used to escape '_' or '%' in the catalog search parameters that are a pattern - /// (and therefore use one of the wildcard characters). - /// The '_' character represents any single character; the '%' character represents any sequence of zero or more - /// characters. + /// Retrieves the UTF-8 string that can be used to escape wildcard characters. + /// This is the string that can be used to escape '_' or '%' in the catalog search parameters that are a pattern + /// (and therefore use one of the wildcard characters). + /// The '_' character represents any single character; the '%' character represents any sequence of zero or more + /// characters. SqlSearchStringEscape = 513, /// - /// Retrieves a UTF-8 string with all the "extra" characters that can be used in unquoted identifier names - /// (those beyond a-z, A-Z, 0-9 and _). + /// Retrieves a UTF-8 string with all the "extra" characters that can be used in unquoted identifier names + /// (those beyond a-z, A-Z, 0-9 and _). SqlExtraNameCharacters = 514, /// - /// Retrieves a boolean value indicating whether column aliasing is supported. - /// If so, the SQL AS clause can be used to provide names for computed columns or to provide alias names for columns - /// as required. + /// Retrieves a boolean value indicating whether column aliasing is supported. + /// If so, the SQL AS clause can be used to provide names for computed columns or to provide alias names for columns + /// as required. /// - /// Returns: - /// - false: if column aliasing is unsupported; - /// - true: if column aliasing is supported. + /// Returns: + /// - false: if column aliasing is unsupported; + /// - true: if column aliasing is supported. SqlSupportsColumnAliasing = 515, /// - /// Retrieves a boolean value indicating whether concatenations between null and non-null values being - /// null are supported. + /// Retrieves a boolean value indicating whether concatenations between null and non-null values being + /// null are supported. /// - /// - Returns: - /// - false: if concatenations between null and non-null values being null are unsupported; - /// - true: if concatenations between null and non-null values being null are supported. + /// - Returns: + /// - false: if concatenations between null and non-null values being null are unsupported; + /// - true: if concatenations between null and non-null values being null are supported. SqlNullPlusNullIsNull = 516, /// - /// Retrieves a map where the key is the type to convert from and the value is a list with the types to convert to, - /// indicating the supported conversions. Each key and each item on the list value is a value to a predefined type on - /// SqlSupportsConvert enum. - /// The returned map will be: map> + /// Retrieves a map where the key is the type to convert from and the value is a list with the types to convert to, + /// indicating the supported conversions. Each key and each item on the list value is a value to a predefined type on + /// SqlSupportsConvert enum. + /// The returned map will be: map> SqlSupportsConvert = 517, /// - /// Retrieves a boolean value indicating whether, when table correlation names are supported, - /// they are restricted to being different from the names of the tables. + /// Retrieves a boolean value indicating whether, when table correlation names are supported, + /// they are restricted to being different from the names of the tables. /// - /// Returns: - /// - false: if table correlation names are unsupported; - /// - true: if table correlation names are supported. + /// Returns: + /// - false: if table correlation names are unsupported; + /// - true: if table correlation names are supported. SqlSupportsTableCorrelationNames = 518, /// - /// Retrieves a boolean value indicating whether, when table correlation names are supported, - /// they are restricted to being different from the names of the tables. + /// Retrieves a boolean value indicating whether, when table correlation names are supported, + /// they are restricted to being different from the names of the tables. /// - /// Returns: - /// - false: if different table correlation names are unsupported; - /// - true: if different table correlation names are supported + /// Returns: + /// - false: if different table correlation names are unsupported; + /// - true: if different table correlation names are supported SqlSupportsDifferentTableCorrelationNames = 519, /// - /// Retrieves a boolean value indicating whether expressions in ORDER BY lists are supported. + /// Retrieves a boolean value indicating whether expressions in ORDER BY lists are supported. /// - /// Returns: - /// - false: if expressions in ORDER BY are unsupported; - /// - true: if expressions in ORDER BY are supported; + /// Returns: + /// - false: if expressions in ORDER BY are unsupported; + /// - true: if expressions in ORDER BY are supported; SqlSupportsExpressionsInOrderBy = 520, /// - /// Retrieves a boolean value indicating whether using a column that is not in the SELECT statement in a GROUP BY - /// clause is supported. + /// Retrieves a boolean value indicating whether using a column that is not in the SELECT statement in a GROUP BY + /// clause is supported. /// - /// Returns: - /// - false: if using a column that is not in the SELECT statement in a GROUP BY clause is unsupported; - /// - true: if using a column that is not in the SELECT statement in a GROUP BY clause is supported. + /// Returns: + /// - false: if using a column that is not in the SELECT statement in a GROUP BY clause is unsupported; + /// - true: if using a column that is not in the SELECT statement in a GROUP BY clause is supported. SqlSupportsOrderByUnrelated = 521, /// - /// Retrieves the supported GROUP BY commands; + /// Retrieves the supported GROUP BY commands; /// - /// Returns an int32 bitmask value representing the supported commands. - /// The returned bitmask should be parsed in order to retrieve the supported commands. + /// Returns an int32 bitmask value representing the supported commands. + /// The returned bitmask should be parsed in order to retrieve the supported commands. /// - /// For instance: - /// - return 0 (\b0) => [] (GROUP BY is unsupported); - /// - return 1 (\b1) => \[SQL_GROUP_BY_UNRELATED\]; - /// - return 2 (\b10) => \[SQL_GROUP_BY_BEYOND_SELECT\]; - /// - return 3 (\b11) => [SQL_GROUP_BY_UNRELATED, SQL_GROUP_BY_BEYOND_SELECT]. - /// Valid GROUP BY types are described under `arrow.flight.protocol.sql.SqlSupportedGroupBy`. + /// For instance: + /// - return 0 (\b0) => \[\] (GROUP BY is unsupported); + /// - return 1 (\b1) => \[SQL_GROUP_BY_UNRELATED\]; + /// - return 2 (\b10) => \[SQL_GROUP_BY_BEYOND_SELECT\]; + /// - return 3 (\b11) => \[SQL_GROUP_BY_UNRELATED, SQL_GROUP_BY_BEYOND_SELECT\]. + /// Valid GROUP BY types are described under `arrow.flight.protocol.sql.SqlSupportedGroupBy`. SqlSupportedGroupBy = 522, /// - /// Retrieves a boolean value indicating whether specifying a LIKE escape clause is supported. + /// Retrieves a boolean value indicating whether specifying a LIKE escape clause is supported. /// - /// Returns: - /// - false: if specifying a LIKE escape clause is unsupported; - /// - true: if specifying a LIKE escape clause is supported. + /// Returns: + /// - false: if specifying a LIKE escape clause is unsupported; + /// - true: if specifying a LIKE escape clause is supported. SqlSupportsLikeEscapeClause = 523, /// - /// Retrieves a boolean value indicating whether columns may be defined as non-nullable. + /// Retrieves a boolean value indicating whether columns may be defined as non-nullable. /// - /// Returns: - /// - false: if columns cannot be defined as non-nullable; - /// - true: if columns may be defined as non-nullable. + /// Returns: + /// - false: if columns cannot be defined as non-nullable; + /// - true: if columns may be defined as non-nullable. SqlSupportsNonNullableColumns = 524, /// - /// Retrieves the supported SQL grammar level as per the ODBC specification. - /// - /// Returns an int32 bitmask value representing the supported SQL grammar level. - /// The returned bitmask should be parsed in order to retrieve the supported grammar levels. - /// - /// For instance: - /// - return 0 (\b0) => [] (SQL grammar is unsupported); - /// - return 1 (\b1) => \[SQL_MINIMUM_GRAMMAR\]; - /// - return 2 (\b10) => \[SQL_CORE_GRAMMAR\]; - /// - return 3 (\b11) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR]; - /// - return 4 (\b100) => \[SQL_EXTENDED_GRAMMAR\]; - /// - return 5 (\b101) => [SQL_MINIMUM_GRAMMAR, SQL_EXTENDED_GRAMMAR]; - /// - return 6 (\b110) => [SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR]; - /// - return 7 (\b111) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR]. - /// Valid SQL grammar levels are described under `arrow.flight.protocol.sql.SupportedSqlGrammar`. + /// Retrieves the supported SQL grammar level as per the ODBC specification. + /// + /// Returns an int32 bitmask value representing the supported SQL grammar level. + /// The returned bitmask should be parsed in order to retrieve the supported grammar levels. + /// + /// For instance: + /// - return 0 (\b0) => \[\] (SQL grammar is unsupported); + /// - return 1 (\b1) => \[SQL_MINIMUM_GRAMMAR\]; + /// - return 2 (\b10) => \[SQL_CORE_GRAMMAR\]; + /// - return 3 (\b11) => \[SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR\]; + /// - return 4 (\b100) => \[SQL_EXTENDED_GRAMMAR\]; + /// - return 5 (\b101) => \[SQL_MINIMUM_GRAMMAR, SQL_EXTENDED_GRAMMAR\]; + /// - return 6 (\b110) => \[SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR\]; + /// - return 7 (\b111) => \[SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR\]. + /// Valid SQL grammar levels are described under `arrow.flight.protocol.sql.SupportedSqlGrammar`. SqlSupportedGrammar = 525, /// - /// Retrieves the supported ANSI92 SQL grammar level. - /// - /// Returns an int32 bitmask value representing the supported ANSI92 SQL grammar level. - /// The returned bitmask should be parsed in order to retrieve the supported commands. - /// - /// For instance: - /// - return 0 (\b0) => [] (ANSI92 SQL grammar is unsupported); - /// - return 1 (\b1) => \[ANSI92_ENTRY_SQL\]; - /// - return 2 (\b10) => \[ANSI92_INTERMEDIATE_SQL\]; - /// - return 3 (\b11) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL]; - /// - return 4 (\b100) => \[ANSI92_FULL_SQL\]; - /// - return 5 (\b101) => [ANSI92_ENTRY_SQL, ANSI92_FULL_SQL]; - /// - return 6 (\b110) => [ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL]; - /// - return 7 (\b111) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL]. - /// Valid ANSI92 SQL grammar levels are described under `arrow.flight.protocol.sql.SupportedAnsi92SqlGrammarLevel`. + /// Retrieves the supported ANSI92 SQL grammar level. + /// + /// Returns an int32 bitmask value representing the supported ANSI92 SQL grammar level. + /// The returned bitmask should be parsed in order to retrieve the supported commands. + /// + /// For instance: + /// - return 0 (\b0) => \[\] (ANSI92 SQL grammar is unsupported); + /// - return 1 (\b1) => \[ANSI92_ENTRY_SQL\]; + /// - return 2 (\b10) => \[ANSI92_INTERMEDIATE_SQL\]; + /// - return 3 (\b11) => \[ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL\]; + /// - return 4 (\b100) => \[ANSI92_FULL_SQL\]; + /// - return 5 (\b101) => \[ANSI92_ENTRY_SQL, ANSI92_FULL_SQL\]; + /// - return 6 (\b110) => \[ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL\]; + /// - return 7 (\b111) => \[ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL\]. + /// Valid ANSI92 SQL grammar levels are described under `arrow.flight.protocol.sql.SupportedAnsi92SqlGrammarLevel`. SqlAnsi92SupportedLevel = 526, /// - /// Retrieves a boolean value indicating whether the SQL Integrity Enhancement Facility is supported. + /// Retrieves a boolean value indicating whether the SQL Integrity Enhancement Facility is supported. /// - /// Returns: - /// - false: if the SQL Integrity Enhancement Facility is supported; - /// - true: if the SQL Integrity Enhancement Facility is supported. + /// Returns: + /// - false: if the SQL Integrity Enhancement Facility is supported; + /// - true: if the SQL Integrity Enhancement Facility is supported. SqlSupportsIntegrityEnhancementFacility = 527, /// - /// Retrieves the support level for SQL OUTER JOINs. + /// Retrieves the support level for SQL OUTER JOINs. /// - /// Returns a uint3 uint32 ordinal for the SQL ordering being used, as described in - /// `arrow.flight.protocol.sql.SqlOuterJoinsSupportLevel`. + /// Returns a int32 ordinal for the SQL ordering being used, as described in + /// `arrow.flight.protocol.sql.SqlOuterJoinsSupportLevel`. SqlOuterJoinsSupportLevel = 528, - /// Retrieves a UTF-8 string with the preferred term for "schema". + /// Retrieves a UTF-8 string with the preferred term for "schema". SqlSchemaTerm = 529, - /// Retrieves a UTF-8 string with the preferred term for "procedure". + /// Retrieves a UTF-8 string with the preferred term for "procedure". SqlProcedureTerm = 530, - /// Retrieves a UTF-8 string with the preferred term for "catalog". + /// + /// Retrieves a UTF-8 string with the preferred term for "catalog". + /// If a empty string is returned its assumed that the server does NOT supports catalogs. SqlCatalogTerm = 531, /// - /// Retrieves a boolean value indicating whether a catalog appears at the start of a fully qualified table name. + /// Retrieves a boolean value indicating whether a catalog appears at the start of a fully qualified table name. /// - /// - false: if a catalog does not appear at the start of a fully qualified table name; - /// - true: if a catalog appears at the start of a fully qualified table name. + /// - false: if a catalog does not appear at the start of a fully qualified table name; + /// - true: if a catalog appears at the start of a fully qualified table name. SqlCatalogAtStart = 532, /// - /// Retrieves the supported actions for a SQL schema. - /// - /// Returns an int32 bitmask value representing the supported actions for a SQL schema. - /// The returned bitmask should be parsed in order to retrieve the supported actions for a SQL schema. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported actions for SQL schema); - /// - return 1 (\b1) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS\]; - /// - return 2 (\b10) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; - /// - return 3 (\b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS]; - /// - return 4 (\b100) => \[SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; - /// - return 5 (\b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; - /// - return 6 (\b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; - /// - return 7 (\b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]. - /// Valid actions for a SQL schema described under `arrow.flight.protocol.sql.SqlSupportedElementActions`. + /// Retrieves the supported actions for a SQL schema. + /// + /// Returns an int32 bitmask value representing the supported actions for a SQL schema. + /// The returned bitmask should be parsed in order to retrieve the supported actions for a SQL schema. + /// + /// For instance: + /// - return 0 (\b0) => \[\] (no supported actions for SQL schema); + /// - return 1 (\b1) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS\]; + /// - return 2 (\b10) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; + /// - return 3 (\b11) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; + /// - return 4 (\b100) => \[SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; + /// - return 5 (\b101) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; + /// - return 6 (\b110) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; + /// - return 7 (\b111) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]. + /// Valid actions for a SQL schema described under `arrow.flight.protocol.sql.SqlSupportedElementActions`. SqlSchemasSupportedActions = 533, /// - /// Retrieves the supported actions for a SQL schema. - /// - /// Returns an int32 bitmask value representing the supported actions for a SQL catalog. - /// The returned bitmask should be parsed in order to retrieve the supported actions for a SQL catalog. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported actions for SQL catalog); - /// - return 1 (\b1) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS\]; - /// - return 2 (\b10) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; - /// - return 3 (\b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS]; - /// - return 4 (\b100) => \[SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; - /// - return 5 (\b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; - /// - return 6 (\b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; - /// - return 7 (\b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]. - /// Valid actions for a SQL catalog are described under `arrow.flight.protocol.sql.SqlSupportedElementActions`. + /// Retrieves the supported actions for a SQL schema. + /// + /// Returns an int32 bitmask value representing the supported actions for a SQL catalog. + /// The returned bitmask should be parsed in order to retrieve the supported actions for a SQL catalog. + /// + /// For instance: + /// - return 0 (\b0) => \[\] (no supported actions for SQL catalog); + /// - return 1 (\b1) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS\]; + /// - return 2 (\b10) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; + /// - return 3 (\b11) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; + /// - return 4 (\b100) => \[SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; + /// - return 5 (\b101) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; + /// - return 6 (\b110) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; + /// - return 7 (\b111) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]. + /// Valid actions for a SQL catalog are described under `arrow.flight.protocol.sql.SqlSupportedElementActions`. SqlCatalogsSupportedActions = 534, /// - /// Retrieves the supported SQL positioned commands. + /// Retrieves the supported SQL positioned commands. /// - /// Returns an int32 bitmask value representing the supported SQL positioned commands. - /// The returned bitmask should be parsed in order to retrieve the supported SQL positioned commands. + /// Returns an int32 bitmask value representing the supported SQL positioned commands. + /// The returned bitmask should be parsed in order to retrieve the supported SQL positioned commands. /// - /// For instance: - /// - return 0 (\b0) => [] (no supported SQL positioned commands); - /// - return 1 (\b1) => \[SQL_POSITIONED_DELETE\]; - /// - return 2 (\b10) => \[SQL_POSITIONED_UPDATE\]; - /// - return 3 (\b11) => [SQL_POSITIONED_DELETE, SQL_POSITIONED_UPDATE]. - /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlSupportedPositionedCommands`. + /// For instance: + /// - return 0 (\b0) => \[\] (no supported SQL positioned commands); + /// - return 1 (\b1) => \[SQL_POSITIONED_DELETE\]; + /// - return 2 (\b10) => \[SQL_POSITIONED_UPDATE\]; + /// - return 3 (\b11) => \[SQL_POSITIONED_DELETE, SQL_POSITIONED_UPDATE\]. + /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlSupportedPositionedCommands`. SqlSupportedPositionedCommands = 535, /// - /// Retrieves a boolean value indicating whether SELECT FOR UPDATE statements are supported. + /// Retrieves a boolean value indicating whether SELECT FOR UPDATE statements are supported. /// - /// Returns: - /// - false: if SELECT FOR UPDATE statements are unsupported; - /// - true: if SELECT FOR UPDATE statements are supported. + /// Returns: + /// - false: if SELECT FOR UPDATE statements are unsupported; + /// - true: if SELECT FOR UPDATE statements are supported. SqlSelectForUpdateSupported = 536, /// - /// Retrieves a boolean value indicating whether stored procedure calls that use the stored procedure escape syntax - /// are supported. + /// Retrieves a boolean value indicating whether stored procedure calls that use the stored procedure escape syntax + /// are supported. /// - /// Returns: - /// - false: if stored procedure calls that use the stored procedure escape syntax are unsupported; - /// - true: if stored procedure calls that use the stored procedure escape syntax are supported. + /// Returns: + /// - false: if stored procedure calls that use the stored procedure escape syntax are unsupported; + /// - true: if stored procedure calls that use the stored procedure escape syntax are supported. SqlStoredProceduresSupported = 537, /// - /// Retrieves the supported SQL subqueries. - /// - /// Returns an int32 bitmask value representing the supported SQL subqueries. - /// The returned bitmask should be parsed in order to retrieve the supported SQL subqueries. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported SQL subqueries); - /// - return 1 (\b1) => \[SQL_SUBQUERIES_IN_COMPARISONS\]; - /// - return 2 (\b10) => \[SQL_SUBQUERIES_IN_EXISTS\]; - /// - return 3 (\b11) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS]; - /// - return 4 (\b100) => \[SQL_SUBQUERIES_IN_INS\]; - /// - return 5 (\b101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS]; - /// - return 6 (\b110) => [SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_EXISTS]; - /// - return 7 (\b111) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS]; - /// - return 8 (\b1000) => \[SQL_SUBQUERIES_IN_QUANTIFIEDS\]; - /// - return 9 (\b1001) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 10 (\b1010) => [SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 11 (\b1011) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 12 (\b1100) => [SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 13 (\b1101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 14 (\b1110) => [SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 15 (\b1111) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - ... - /// Valid SQL subqueries are described under `arrow.flight.protocol.sql.SqlSupportedSubqueries`. + /// Retrieves the supported SQL subqueries. + /// + /// Returns an int32 bitmask value representing the supported SQL subqueries. + /// The returned bitmask should be parsed in order to retrieve the supported SQL subqueries. + /// + /// For instance: + /// - return 0 (\b0) => \[\] (no supported SQL subqueries); + /// - return 1 (\b1) => \[SQL_SUBQUERIES_IN_COMPARISONS\]; + /// - return 2 (\b10) => \[SQL_SUBQUERIES_IN_EXISTS\]; + /// - return 3 (\b11) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS\]; + /// - return 4 (\b100) => \[SQL_SUBQUERIES_IN_INS\]; + /// - return 5 (\b101) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS\]; + /// - return 6 (\b110) => \[SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_EXISTS\]; + /// - return 7 (\b111) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS\]; + /// - return 8 (\b1000) => \[SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 9 (\b1001) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 10 (\b1010) => \[SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 11 (\b1011) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 12 (\b1100) => \[SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 13 (\b1101) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 14 (\b1110) => \[SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 15 (\b1111) => \[SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - ... + /// Valid SQL subqueries are described under `arrow.flight.protocol.sql.SqlSupportedSubqueries`. SqlSupportedSubqueries = 538, /// - /// Retrieves a boolean value indicating whether correlated subqueries are supported. + /// Retrieves a boolean value indicating whether correlated subqueries are supported. /// - /// Returns: - /// - false: if correlated subqueries are unsupported; - /// - true: if correlated subqueries are supported. + /// Returns: + /// - false: if correlated subqueries are unsupported; + /// - true: if correlated subqueries are supported. SqlCorrelatedSubqueriesSupported = 539, /// - /// Retrieves the supported SQL UNIONs. + /// Retrieves the supported SQL UNIONs. /// - /// Returns an int32 bitmask value representing the supported SQL UNIONs. - /// The returned bitmask should be parsed in order to retrieve the supported SQL UNIONs. + /// Returns an int32 bitmask value representing the supported SQL UNIONs. + /// The returned bitmask should be parsed in order to retrieve the supported SQL UNIONs. /// - /// For instance: - /// - return 0 (\b0) => [] (no supported SQL positioned commands); - /// - return 1 (\b1) => \[SQL_UNION\]; - /// - return 2 (\b10) => \[SQL_UNION_ALL\]; - /// - return 3 (\b11) => [SQL_UNION, SQL_UNION_ALL]. - /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlSupportedUnions`. + /// For instance: + /// - return 0 (\b0) => \[\] (no supported SQL positioned commands); + /// - return 1 (\b1) => \[SQL_UNION\]; + /// - return 2 (\b10) => \[SQL_UNION_ALL\]; + /// - return 3 (\b11) => \[SQL_UNION, SQL_UNION_ALL\]. + /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlSupportedUnions`. SqlSupportedUnions = 540, - /// Retrieves a uint32 value representing the maximum number of hex characters allowed in an inline binary literal. + /// Retrieves a int64 value representing the maximum number of hex characters allowed in an inline binary literal. SqlMaxBinaryLiteralLength = 541, - /// Retrieves a uint32 value representing the maximum number of characters allowed for a character literal. + /// Retrieves a int64 value representing the maximum number of characters allowed for a character literal. SqlMaxCharLiteralLength = 542, - /// Retrieves a uint32 value representing the maximum number of characters allowed for a column name. + /// Retrieves a int64 value representing the maximum number of characters allowed for a column name. SqlMaxColumnNameLength = 543, - /// Retrieves a uint32 value representing the the maximum number of columns allowed in a GROUP BY clause. + /// Retrieves a int64 value representing the the maximum number of columns allowed in a GROUP BY clause. SqlMaxColumnsInGroupBy = 544, - /// Retrieves a uint32 value representing the maximum number of columns allowed in an index. + /// Retrieves a int64 value representing the maximum number of columns allowed in an index. SqlMaxColumnsInIndex = 545, - /// Retrieves a uint32 value representing the maximum number of columns allowed in an ORDER BY clause. + /// Retrieves a int64 value representing the maximum number of columns allowed in an ORDER BY clause. SqlMaxColumnsInOrderBy = 546, - /// Retrieves a uint32 value representing the maximum number of columns allowed in a SELECT list. + /// Retrieves a int64 value representing the maximum number of columns allowed in a SELECT list. SqlMaxColumnsInSelect = 547, - /// Retrieves a uint32 value representing the maximum number of columns allowed in a table. + /// Retrieves a int64 value representing the maximum number of columns allowed in a table. SqlMaxColumnsInTable = 548, - /// Retrieves a uint32 value representing the maximum number of concurrent connections possible. + /// Retrieves a int64 value representing the maximum number of concurrent connections possible. SqlMaxConnections = 549, - /// Retrieves a uint32 value the maximum number of characters allowed in a cursor name. + /// Retrieves a int64 value the maximum number of characters allowed in a cursor name. SqlMaxCursorNameLength = 550, /// - /// Retrieves a uint32 value representing the maximum number of bytes allowed for an index, - /// including all of the parts of the index. + /// Retrieves a int64 value representing the maximum number of bytes allowed for an index, + /// including all of the parts of the index. SqlMaxIndexLength = 551, - /// Retrieves a uint32 value representing the maximum number of characters allowed in a schema name. + /// Retrieves a int64 value representing the maximum number of characters allowed in a schema name. SqlDbSchemaNameLength = 552, - /// Retrieves a uint32 value representing the maximum number of characters allowed in a procedure name. + /// Retrieves a int64 value representing the maximum number of characters allowed in a procedure name. SqlMaxProcedureNameLength = 553, - /// Retrieves a uint32 value representing the maximum number of characters allowed in a catalog name. + /// Retrieves a int64 value representing the maximum number of characters allowed in a catalog name. SqlMaxCatalogNameLength = 554, - /// Retrieves a uint32 value representing the maximum number of bytes allowed in a single row. + /// Retrieves a int64 value representing the maximum number of bytes allowed in a single row. SqlMaxRowSize = 555, /// - /// Retrieves a boolean indicating whether the return value for the JDBC method getMaxRowSize includes the SQL - /// data types LONGVARCHAR and LONGVARBINARY. + /// Retrieves a boolean indicating whether the return value for the JDBC method getMaxRowSize includes the SQL + /// data types LONGVARCHAR and LONGVARBINARY. /// - /// Returns: - /// - false: if return value for the JDBC method getMaxRowSize does + /// Returns: + /// - false: if return value for the JDBC method getMaxRowSize does /// not include the SQL data types LONGVARCHAR and LONGVARBINARY; - /// - true: if return value for the JDBC method getMaxRowSize includes + /// - true: if return value for the JDBC method getMaxRowSize includes /// the SQL data types LONGVARCHAR and LONGVARBINARY. SqlMaxRowSizeIncludesBlobs = 556, /// - /// Retrieves a uint32 value representing the maximum number of characters allowed for an SQL statement; - /// a result of 0 (zero) means that there is no limit or the limit is not known. + /// Retrieves a int64 value representing the maximum number of characters allowed for an SQL statement; + /// a result of 0 (zero) means that there is no limit or the limit is not known. SqlMaxStatementLength = 557, - /// Retrieves a uint32 value representing the maximum number of active statements that can be open at the same time. + /// Retrieves a int64 value representing the maximum number of active statements that can be open at the same time. SqlMaxStatements = 558, - /// Retrieves a uint32 value representing the maximum number of characters allowed in a table name. + /// Retrieves a int64 value representing the maximum number of characters allowed in a table name. SqlMaxTableNameLength = 559, - /// Retrieves a uint32 value representing the maximum number of tables allowed in a SELECT statement. + /// Retrieves a int64 value representing the maximum number of tables allowed in a SELECT statement. SqlMaxTablesInSelect = 560, - /// Retrieves a uint32 value representing the maximum number of characters allowed in a user name. + /// Retrieves a int64 value representing the maximum number of characters allowed in a user name. SqlMaxUsernameLength = 561, /// - /// Retrieves this database's default transaction isolation level as described in - /// `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. + /// Retrieves this database's default transaction isolation level as described in + /// `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. /// - /// Returns a uint32 ordinal for the SQL transaction isolation level. + /// Returns a int32 ordinal for the SQL transaction isolation level. SqlDefaultTransactionIsolation = 562, /// - /// Retrieves a boolean value indicating whether transactions are supported. If not, invoking the method commit is a - /// noop, and the isolation level is `arrow.flight.protocol.sql.SqlTransactionIsolationLevel.TRANSACTION_NONE`. + /// Retrieves a boolean value indicating whether transactions are supported. If not, invoking the method commit is a + /// noop, and the isolation level is `arrow.flight.protocol.sql.SqlTransactionIsolationLevel.TRANSACTION_NONE`. /// - /// Returns: - /// - false: if transactions are unsupported; - /// - true: if transactions are supported. + /// Returns: + /// - false: if transactions are unsupported; + /// - true: if transactions are supported. SqlTransactionsSupported = 563, /// - /// Retrieves the supported transactions isolation levels. - /// - /// Returns an int32 bitmask value representing the supported transactions isolation levels. - /// The returned bitmask should be parsed in order to retrieve the supported transactions isolation levels. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported SQL transactions isolation levels); - /// - return 1 (\b1) => \[SQL_TRANSACTION_NONE\]; - /// - return 2 (\b10) => \[SQL_TRANSACTION_READ_UNCOMMITTED\]; - /// - return 3 (\b11) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED]; - /// - return 4 (\b100) => \[SQL_TRANSACTION_REPEATABLE_READ\]; - /// - return 5 (\b101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 6 (\b110) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 7 (\b111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 8 (\b1000) => \[SQL_TRANSACTION_REPEATABLE_READ\]; - /// - return 9 (\b1001) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 10 (\b1010) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 11 (\b1011) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 12 (\b1100) => [SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 13 (\b1101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 14 (\b1110) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 15 (\b1111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 16 (\b10000) => \[SQL_TRANSACTION_SERIALIZABLE\]; - /// - ... - /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. + /// Retrieves the supported transactions isolation levels. + /// + /// Returns an int32 bitmask value representing the supported transactions isolation levels. + /// The returned bitmask should be parsed in order to retrieve the supported transactions isolation levels. + /// + /// For instance: + /// - return 0 (\b0) => \[\] (no supported SQL transactions isolation levels); + /// - return 1 (\b1) => \[SQL_TRANSACTION_NONE\]; + /// - return 2 (\b10) => \[SQL_TRANSACTION_READ_UNCOMMITTED\]; + /// - return 3 (\b11) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED\]; + /// - return 4 (\b100) => \[SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 5 (\b101) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 6 (\b110) => \[SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 7 (\b111) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 8 (\b1000) => \[SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 9 (\b1001) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 10 (\b1010) => \[SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 11 (\b1011) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 12 (\b1100) => \[SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 13 (\b1101) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 14 (\b1110) => \[SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 15 (\b1111) => \[SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 16 (\b10000) => \[SQL_TRANSACTION_SERIALIZABLE\]; + /// - ... + /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. SqlSupportedTransactionsIsolationLevels = 564, /// - /// Retrieves a boolean value indicating whether a data definition statement within a transaction forces - /// the transaction to commit. + /// Retrieves a boolean value indicating whether a data definition statement within a transaction forces + /// the transaction to commit. /// - /// Returns: - /// - false: if a data definition statement within a transaction does not force the transaction to commit; - /// - true: if a data definition statement within a transaction forces the transaction to commit. + /// Returns: + /// - false: if a data definition statement within a transaction does not force the transaction to commit; + /// - true: if a data definition statement within a transaction forces the transaction to commit. SqlDataDefinitionCausesTransactionCommit = 565, /// - /// Retrieves a boolean value indicating whether a data definition statement within a transaction is ignored. + /// Retrieves a boolean value indicating whether a data definition statement within a transaction is ignored. /// - /// Returns: - /// - false: if a data definition statement within a transaction is taken into account; - /// - true: a data definition statement within a transaction is ignored. + /// Returns: + /// - false: if a data definition statement within a transaction is taken into account; + /// - true: a data definition statement within a transaction is ignored. SqlDataDefinitionsInTransactionsIgnored = 566, /// - /// Retrieves an int32 bitmask value representing the supported result set types. - /// The returned bitmask should be parsed in order to retrieve the supported result set types. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported result set types); - /// - return 1 (\b1) => \[SQL_RESULT_SET_TYPE_UNSPECIFIED\]; - /// - return 2 (\b10) => \[SQL_RESULT_SET_TYPE_FORWARD_ONLY\]; - /// - return 3 (\b11) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY]; - /// - return 4 (\b100) => \[SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE\]; - /// - return 5 (\b101) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; - /// - return 6 (\b110) => [SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; - /// - return 7 (\b111) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; - /// - return 8 (\b1000) => \[SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE\]; - /// - ... - /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetType`. + /// Retrieves an int32 bitmask value representing the supported result set types. + /// The returned bitmask should be parsed in order to retrieve the supported result set types. + /// + /// For instance: + /// - return 0 (\b0) => \[\] (no supported result set types); + /// - return 1 (\b1) => \[SQL_RESULT_SET_TYPE_UNSPECIFIED\]; + /// - return 2 (\b10) => \[SQL_RESULT_SET_TYPE_FORWARD_ONLY\]; + /// - return 3 (\b11) => \[SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY\]; + /// - return 4 (\b100) => \[SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE\]; + /// - return 5 (\b101) => \[SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE\]; + /// - return 6 (\b110) => \[SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE\]; + /// - return 7 (\b111) => \[SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE\]; + /// - return 8 (\b1000) => \[SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE\]; + /// - ... + /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetType`. SqlSupportedResultSetTypes = 567, /// - /// Returns an int32 bitmask value concurrency types supported for - /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_UNSPECIFIED`. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) - /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] - /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] - /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] - /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] - /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + /// Returns an int32 bitmask value concurrency types supported for + /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_UNSPECIFIED`. + /// + /// For instance: + /// - return 0 (\b0) => \[\] (no supported concurrency types for this result set type) + /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] + /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 3 (\b11) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 5 (\b101) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 6 (\b110) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 7 (\b111) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. SqlSupportedConcurrenciesForResultSetUnspecified = 568, /// - /// Returns an int32 bitmask value concurrency types supported for - /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_FORWARD_ONLY`. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) - /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] - /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] - /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] - /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] - /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + /// Returns an int32 bitmask value concurrency types supported for + /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_FORWARD_ONLY`. + /// + /// For instance: + /// - return 0 (\b0) => \[\] (no supported concurrency types for this result set type) + /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] + /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 3 (\b11) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 5 (\b101) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 6 (\b110) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 7 (\b111) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. SqlSupportedConcurrenciesForResultSetForwardOnly = 569, /// - /// Returns an int32 bitmask value concurrency types supported for - /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE`. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) - /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] - /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] - /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] - /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] - /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + /// Returns an int32 bitmask value concurrency types supported for + /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE`. + /// + /// For instance: + /// - return 0 (\b0) => \[\] (no supported concurrency types for this result set type) + /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] + /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 3 (\b11) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 5 (\b101) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 6 (\b110) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 7 (\b111) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. SqlSupportedConcurrenciesForResultSetScrollSensitive = 570, /// - /// Returns an int32 bitmask value concurrency types supported for - /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE`. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) - /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] - /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] - /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] - /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] - /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + /// Returns an int32 bitmask value concurrency types supported for + /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE`. + /// + /// For instance: + /// - return 0 (\b0) => \[\] (no supported concurrency types for this result set type) + /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] + /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 3 (\b11) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 5 (\b101) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 6 (\b110) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 7 (\b111) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. SqlSupportedConcurrenciesForResultSetScrollInsensitive = 571, /// - /// Retrieves a boolean value indicating whether this database supports batch updates. + /// Retrieves a boolean value indicating whether this database supports batch updates. /// - /// - false: if this database does not support batch updates; - /// - true: if this database supports batch updates. + /// - false: if this database does not support batch updates; + /// - true: if this database supports batch updates. SqlBatchUpdatesSupported = 572, /// - /// Retrieves a boolean value indicating whether this database supports savepoints. + /// Retrieves a boolean value indicating whether this database supports savepoints. /// - /// Returns: - /// - false: if this database does not support savepoints; - /// - true: if this database supports savepoints. + /// Returns: + /// - false: if this database does not support savepoints; + /// - true: if this database supports savepoints. SqlSavepointsSupported = 573, /// - /// Retrieves a boolean value indicating whether named parameters are supported in callable statements. + /// Retrieves a boolean value indicating whether named parameters are supported in callable statements. /// - /// Returns: - /// - false: if named parameters in callable statements are unsupported; - /// - true: if named parameters in callable statements are supported. + /// Returns: + /// - false: if named parameters in callable statements are unsupported; + /// - true: if named parameters in callable statements are supported. SqlNamedParametersSupported = 574, /// - /// Retrieves a boolean value indicating whether updates made to a LOB are made on a copy or directly to the LOB. + /// Retrieves a boolean value indicating whether updates made to a LOB are made on a copy or directly to the LOB. /// - /// Returns: - /// - false: if updates made to a LOB are made directly to the LOB; - /// - true: if updates made to a LOB are made on a copy. + /// Returns: + /// - false: if updates made to a LOB are made directly to the LOB; + /// - true: if updates made to a LOB are made on a copy. SqlLocatorsUpdateCopy = 575, /// - /// Retrieves a boolean value indicating whether invoking user-defined or vendor functions - /// using the stored procedure escape syntax is supported. + /// Retrieves a boolean value indicating whether invoking user-defined or vendor functions + /// using the stored procedure escape syntax is supported. /// - /// Returns: - /// - false: if invoking user-defined or vendor functions using the stored procedure escape syntax is unsupported; - /// - true: if invoking user-defined or vendor functions using the stored procedure escape syntax is supported. + /// Returns: + /// - false: if invoking user-defined or vendor functions using the stored procedure escape syntax is unsupported; + /// - true: if invoking user-defined or vendor functions using the stored procedure escape syntax is supported. SqlStoredFunctionsUsingCallSyntaxSupported = 576, } impl SqlInfo { @@ -1014,6 +1500,22 @@ impl SqlInfo { SqlInfo::FlightSqlServerVersion => "FLIGHT_SQL_SERVER_VERSION", SqlInfo::FlightSqlServerArrowVersion => "FLIGHT_SQL_SERVER_ARROW_VERSION", SqlInfo::FlightSqlServerReadOnly => "FLIGHT_SQL_SERVER_READ_ONLY", + SqlInfo::FlightSqlServerSql => "FLIGHT_SQL_SERVER_SQL", + SqlInfo::FlightSqlServerSubstrait => "FLIGHT_SQL_SERVER_SUBSTRAIT", + SqlInfo::FlightSqlServerSubstraitMinVersion => { + "FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION" + } + SqlInfo::FlightSqlServerSubstraitMaxVersion => { + "FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION" + } + SqlInfo::FlightSqlServerTransaction => "FLIGHT_SQL_SERVER_TRANSACTION", + SqlInfo::FlightSqlServerCancel => "FLIGHT_SQL_SERVER_CANCEL", + SqlInfo::FlightSqlServerStatementTimeout => { + "FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT" + } + SqlInfo::FlightSqlServerTransactionTimeout => { + "FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT" + } SqlInfo::SqlDdlCatalog => "SQL_DDL_CATALOG", SqlInfo::SqlDdlSchema => "SQL_DDL_SCHEMA", SqlInfo::SqlDdlTable => "SQL_DDL_TABLE", @@ -1032,16 +1534,24 @@ impl SqlInfo { SqlInfo::SqlSupportsColumnAliasing => "SQL_SUPPORTS_COLUMN_ALIASING", SqlInfo::SqlNullPlusNullIsNull => "SQL_NULL_PLUS_NULL_IS_NULL", SqlInfo::SqlSupportsConvert => "SQL_SUPPORTS_CONVERT", - SqlInfo::SqlSupportsTableCorrelationNames => "SQL_SUPPORTS_TABLE_CORRELATION_NAMES", - SqlInfo::SqlSupportsDifferentTableCorrelationNames => "SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES", - SqlInfo::SqlSupportsExpressionsInOrderBy => "SQL_SUPPORTS_EXPRESSIONS_IN_ORDER_BY", + SqlInfo::SqlSupportsTableCorrelationNames => { + "SQL_SUPPORTS_TABLE_CORRELATION_NAMES" + } + SqlInfo::SqlSupportsDifferentTableCorrelationNames => { + "SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES" + } + SqlInfo::SqlSupportsExpressionsInOrderBy => { + "SQL_SUPPORTS_EXPRESSIONS_IN_ORDER_BY" + } SqlInfo::SqlSupportsOrderByUnrelated => "SQL_SUPPORTS_ORDER_BY_UNRELATED", SqlInfo::SqlSupportedGroupBy => "SQL_SUPPORTED_GROUP_BY", SqlInfo::SqlSupportsLikeEscapeClause => "SQL_SUPPORTS_LIKE_ESCAPE_CLAUSE", SqlInfo::SqlSupportsNonNullableColumns => "SQL_SUPPORTS_NON_NULLABLE_COLUMNS", SqlInfo::SqlSupportedGrammar => "SQL_SUPPORTED_GRAMMAR", SqlInfo::SqlAnsi92SupportedLevel => "SQL_ANSI92_SUPPORTED_LEVEL", - SqlInfo::SqlSupportsIntegrityEnhancementFacility => "SQL_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY", + SqlInfo::SqlSupportsIntegrityEnhancementFacility => { + "SQL_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY" + } SqlInfo::SqlOuterJoinsSupportLevel => "SQL_OUTER_JOINS_SUPPORT_LEVEL", SqlInfo::SqlSchemaTerm => "SQL_SCHEMA_TERM", SqlInfo::SqlProcedureTerm => "SQL_PROCEDURE_TERM", @@ -1049,11 +1559,15 @@ impl SqlInfo { SqlInfo::SqlCatalogAtStart => "SQL_CATALOG_AT_START", SqlInfo::SqlSchemasSupportedActions => "SQL_SCHEMAS_SUPPORTED_ACTIONS", SqlInfo::SqlCatalogsSupportedActions => "SQL_CATALOGS_SUPPORTED_ACTIONS", - SqlInfo::SqlSupportedPositionedCommands => "SQL_SUPPORTED_POSITIONED_COMMANDS", + SqlInfo::SqlSupportedPositionedCommands => { + "SQL_SUPPORTED_POSITIONED_COMMANDS" + } SqlInfo::SqlSelectForUpdateSupported => "SQL_SELECT_FOR_UPDATE_SUPPORTED", SqlInfo::SqlStoredProceduresSupported => "SQL_STORED_PROCEDURES_SUPPORTED", SqlInfo::SqlSupportedSubqueries => "SQL_SUPPORTED_SUBQUERIES", - SqlInfo::SqlCorrelatedSubqueriesSupported => "SQL_CORRELATED_SUBQUERIES_SUPPORTED", + SqlInfo::SqlCorrelatedSubqueriesSupported => { + "SQL_CORRELATED_SUBQUERIES_SUPPORTED" + } SqlInfo::SqlSupportedUnions => "SQL_SUPPORTED_UNIONS", SqlInfo::SqlMaxBinaryLiteralLength => "SQL_MAX_BINARY_LITERAL_LENGTH", SqlInfo::SqlMaxCharLiteralLength => "SQL_MAX_CHAR_LITERAL_LENGTH", @@ -1076,21 +1590,211 @@ impl SqlInfo { SqlInfo::SqlMaxTableNameLength => "SQL_MAX_TABLE_NAME_LENGTH", SqlInfo::SqlMaxTablesInSelect => "SQL_MAX_TABLES_IN_SELECT", SqlInfo::SqlMaxUsernameLength => "SQL_MAX_USERNAME_LENGTH", - SqlInfo::SqlDefaultTransactionIsolation => "SQL_DEFAULT_TRANSACTION_ISOLATION", + SqlInfo::SqlDefaultTransactionIsolation => { + "SQL_DEFAULT_TRANSACTION_ISOLATION" + } SqlInfo::SqlTransactionsSupported => "SQL_TRANSACTIONS_SUPPORTED", - SqlInfo::SqlSupportedTransactionsIsolationLevels => "SQL_SUPPORTED_TRANSACTIONS_ISOLATION_LEVELS", - SqlInfo::SqlDataDefinitionCausesTransactionCommit => "SQL_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT", - SqlInfo::SqlDataDefinitionsInTransactionsIgnored => "SQL_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED", + SqlInfo::SqlSupportedTransactionsIsolationLevels => { + "SQL_SUPPORTED_TRANSACTIONS_ISOLATION_LEVELS" + } + SqlInfo::SqlDataDefinitionCausesTransactionCommit => { + "SQL_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT" + } + SqlInfo::SqlDataDefinitionsInTransactionsIgnored => { + "SQL_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED" + } SqlInfo::SqlSupportedResultSetTypes => "SQL_SUPPORTED_RESULT_SET_TYPES", - SqlInfo::SqlSupportedConcurrenciesForResultSetUnspecified => "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_UNSPECIFIED", - SqlInfo::SqlSupportedConcurrenciesForResultSetForwardOnly => "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_FORWARD_ONLY", - SqlInfo::SqlSupportedConcurrenciesForResultSetScrollSensitive => "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_SENSITIVE", - SqlInfo::SqlSupportedConcurrenciesForResultSetScrollInsensitive => "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_INSENSITIVE", + SqlInfo::SqlSupportedConcurrenciesForResultSetUnspecified => { + "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_UNSPECIFIED" + } + SqlInfo::SqlSupportedConcurrenciesForResultSetForwardOnly => { + "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_FORWARD_ONLY" + } + SqlInfo::SqlSupportedConcurrenciesForResultSetScrollSensitive => { + "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_SENSITIVE" + } + SqlInfo::SqlSupportedConcurrenciesForResultSetScrollInsensitive => { + "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_INSENSITIVE" + } SqlInfo::SqlBatchUpdatesSupported => "SQL_BATCH_UPDATES_SUPPORTED", SqlInfo::SqlSavepointsSupported => "SQL_SAVEPOINTS_SUPPORTED", SqlInfo::SqlNamedParametersSupported => "SQL_NAMED_PARAMETERS_SUPPORTED", SqlInfo::SqlLocatorsUpdateCopy => "SQL_LOCATORS_UPDATE_COPY", - SqlInfo::SqlStoredFunctionsUsingCallSyntaxSupported => "SQL_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED", + SqlInfo::SqlStoredFunctionsUsingCallSyntaxSupported => { + "SQL_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED" + } + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "FLIGHT_SQL_SERVER_NAME" => Some(Self::FlightSqlServerName), + "FLIGHT_SQL_SERVER_VERSION" => Some(Self::FlightSqlServerVersion), + "FLIGHT_SQL_SERVER_ARROW_VERSION" => Some(Self::FlightSqlServerArrowVersion), + "FLIGHT_SQL_SERVER_READ_ONLY" => Some(Self::FlightSqlServerReadOnly), + "FLIGHT_SQL_SERVER_SQL" => Some(Self::FlightSqlServerSql), + "FLIGHT_SQL_SERVER_SUBSTRAIT" => Some(Self::FlightSqlServerSubstrait), + "FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION" => { + Some(Self::FlightSqlServerSubstraitMinVersion) + } + "FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION" => { + Some(Self::FlightSqlServerSubstraitMaxVersion) + } + "FLIGHT_SQL_SERVER_TRANSACTION" => Some(Self::FlightSqlServerTransaction), + "FLIGHT_SQL_SERVER_CANCEL" => Some(Self::FlightSqlServerCancel), + "FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT" => { + Some(Self::FlightSqlServerStatementTimeout) + } + "FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT" => { + Some(Self::FlightSqlServerTransactionTimeout) + } + "SQL_DDL_CATALOG" => Some(Self::SqlDdlCatalog), + "SQL_DDL_SCHEMA" => Some(Self::SqlDdlSchema), + "SQL_DDL_TABLE" => Some(Self::SqlDdlTable), + "SQL_IDENTIFIER_CASE" => Some(Self::SqlIdentifierCase), + "SQL_IDENTIFIER_QUOTE_CHAR" => Some(Self::SqlIdentifierQuoteChar), + "SQL_QUOTED_IDENTIFIER_CASE" => Some(Self::SqlQuotedIdentifierCase), + "SQL_ALL_TABLES_ARE_SELECTABLE" => Some(Self::SqlAllTablesAreSelectable), + "SQL_NULL_ORDERING" => Some(Self::SqlNullOrdering), + "SQL_KEYWORDS" => Some(Self::SqlKeywords), + "SQL_NUMERIC_FUNCTIONS" => Some(Self::SqlNumericFunctions), + "SQL_STRING_FUNCTIONS" => Some(Self::SqlStringFunctions), + "SQL_SYSTEM_FUNCTIONS" => Some(Self::SqlSystemFunctions), + "SQL_DATETIME_FUNCTIONS" => Some(Self::SqlDatetimeFunctions), + "SQL_SEARCH_STRING_ESCAPE" => Some(Self::SqlSearchStringEscape), + "SQL_EXTRA_NAME_CHARACTERS" => Some(Self::SqlExtraNameCharacters), + "SQL_SUPPORTS_COLUMN_ALIASING" => Some(Self::SqlSupportsColumnAliasing), + "SQL_NULL_PLUS_NULL_IS_NULL" => Some(Self::SqlNullPlusNullIsNull), + "SQL_SUPPORTS_CONVERT" => Some(Self::SqlSupportsConvert), + "SQL_SUPPORTS_TABLE_CORRELATION_NAMES" => { + Some(Self::SqlSupportsTableCorrelationNames) + } + "SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES" => { + Some(Self::SqlSupportsDifferentTableCorrelationNames) + } + "SQL_SUPPORTS_EXPRESSIONS_IN_ORDER_BY" => { + Some(Self::SqlSupportsExpressionsInOrderBy) + } + "SQL_SUPPORTS_ORDER_BY_UNRELATED" => Some(Self::SqlSupportsOrderByUnrelated), + "SQL_SUPPORTED_GROUP_BY" => Some(Self::SqlSupportedGroupBy), + "SQL_SUPPORTS_LIKE_ESCAPE_CLAUSE" => Some(Self::SqlSupportsLikeEscapeClause), + "SQL_SUPPORTS_NON_NULLABLE_COLUMNS" => { + Some(Self::SqlSupportsNonNullableColumns) + } + "SQL_SUPPORTED_GRAMMAR" => Some(Self::SqlSupportedGrammar), + "SQL_ANSI92_SUPPORTED_LEVEL" => Some(Self::SqlAnsi92SupportedLevel), + "SQL_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY" => { + Some(Self::SqlSupportsIntegrityEnhancementFacility) + } + "SQL_OUTER_JOINS_SUPPORT_LEVEL" => Some(Self::SqlOuterJoinsSupportLevel), + "SQL_SCHEMA_TERM" => Some(Self::SqlSchemaTerm), + "SQL_PROCEDURE_TERM" => Some(Self::SqlProcedureTerm), + "SQL_CATALOG_TERM" => Some(Self::SqlCatalogTerm), + "SQL_CATALOG_AT_START" => Some(Self::SqlCatalogAtStart), + "SQL_SCHEMAS_SUPPORTED_ACTIONS" => Some(Self::SqlSchemasSupportedActions), + "SQL_CATALOGS_SUPPORTED_ACTIONS" => Some(Self::SqlCatalogsSupportedActions), + "SQL_SUPPORTED_POSITIONED_COMMANDS" => { + Some(Self::SqlSupportedPositionedCommands) + } + "SQL_SELECT_FOR_UPDATE_SUPPORTED" => Some(Self::SqlSelectForUpdateSupported), + "SQL_STORED_PROCEDURES_SUPPORTED" => Some(Self::SqlStoredProceduresSupported), + "SQL_SUPPORTED_SUBQUERIES" => Some(Self::SqlSupportedSubqueries), + "SQL_CORRELATED_SUBQUERIES_SUPPORTED" => { + Some(Self::SqlCorrelatedSubqueriesSupported) + } + "SQL_SUPPORTED_UNIONS" => Some(Self::SqlSupportedUnions), + "SQL_MAX_BINARY_LITERAL_LENGTH" => Some(Self::SqlMaxBinaryLiteralLength), + "SQL_MAX_CHAR_LITERAL_LENGTH" => Some(Self::SqlMaxCharLiteralLength), + "SQL_MAX_COLUMN_NAME_LENGTH" => Some(Self::SqlMaxColumnNameLength), + "SQL_MAX_COLUMNS_IN_GROUP_BY" => Some(Self::SqlMaxColumnsInGroupBy), + "SQL_MAX_COLUMNS_IN_INDEX" => Some(Self::SqlMaxColumnsInIndex), + "SQL_MAX_COLUMNS_IN_ORDER_BY" => Some(Self::SqlMaxColumnsInOrderBy), + "SQL_MAX_COLUMNS_IN_SELECT" => Some(Self::SqlMaxColumnsInSelect), + "SQL_MAX_COLUMNS_IN_TABLE" => Some(Self::SqlMaxColumnsInTable), + "SQL_MAX_CONNECTIONS" => Some(Self::SqlMaxConnections), + "SQL_MAX_CURSOR_NAME_LENGTH" => Some(Self::SqlMaxCursorNameLength), + "SQL_MAX_INDEX_LENGTH" => Some(Self::SqlMaxIndexLength), + "SQL_DB_SCHEMA_NAME_LENGTH" => Some(Self::SqlDbSchemaNameLength), + "SQL_MAX_PROCEDURE_NAME_LENGTH" => Some(Self::SqlMaxProcedureNameLength), + "SQL_MAX_CATALOG_NAME_LENGTH" => Some(Self::SqlMaxCatalogNameLength), + "SQL_MAX_ROW_SIZE" => Some(Self::SqlMaxRowSize), + "SQL_MAX_ROW_SIZE_INCLUDES_BLOBS" => Some(Self::SqlMaxRowSizeIncludesBlobs), + "SQL_MAX_STATEMENT_LENGTH" => Some(Self::SqlMaxStatementLength), + "SQL_MAX_STATEMENTS" => Some(Self::SqlMaxStatements), + "SQL_MAX_TABLE_NAME_LENGTH" => Some(Self::SqlMaxTableNameLength), + "SQL_MAX_TABLES_IN_SELECT" => Some(Self::SqlMaxTablesInSelect), + "SQL_MAX_USERNAME_LENGTH" => Some(Self::SqlMaxUsernameLength), + "SQL_DEFAULT_TRANSACTION_ISOLATION" => { + Some(Self::SqlDefaultTransactionIsolation) + } + "SQL_TRANSACTIONS_SUPPORTED" => Some(Self::SqlTransactionsSupported), + "SQL_SUPPORTED_TRANSACTIONS_ISOLATION_LEVELS" => { + Some(Self::SqlSupportedTransactionsIsolationLevels) + } + "SQL_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT" => { + Some(Self::SqlDataDefinitionCausesTransactionCommit) + } + "SQL_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED" => { + Some(Self::SqlDataDefinitionsInTransactionsIgnored) + } + "SQL_SUPPORTED_RESULT_SET_TYPES" => Some(Self::SqlSupportedResultSetTypes), + "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_UNSPECIFIED" => { + Some(Self::SqlSupportedConcurrenciesForResultSetUnspecified) + } + "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_FORWARD_ONLY" => { + Some(Self::SqlSupportedConcurrenciesForResultSetForwardOnly) + } + "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_SENSITIVE" => { + Some(Self::SqlSupportedConcurrenciesForResultSetScrollSensitive) + } + "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_INSENSITIVE" => { + Some(Self::SqlSupportedConcurrenciesForResultSetScrollInsensitive) + } + "SQL_BATCH_UPDATES_SUPPORTED" => Some(Self::SqlBatchUpdatesSupported), + "SQL_SAVEPOINTS_SUPPORTED" => Some(Self::SqlSavepointsSupported), + "SQL_NAMED_PARAMETERS_SUPPORTED" => Some(Self::SqlNamedParametersSupported), + "SQL_LOCATORS_UPDATE_COPY" => Some(Self::SqlLocatorsUpdateCopy), + "SQL_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED" => { + Some(Self::SqlStoredFunctionsUsingCallSyntaxSupported) + } + _ => None, + } + } +} +/// The level of support for Flight SQL transaction RPCs. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum SqlSupportedTransaction { + /// Unknown/not indicated/no support + None = 0, + /// Transactions, but not savepoints. + /// A savepoint is a mark within a transaction that can be individually + /// rolled back to. Not all databases support savepoints. + Transaction = 1, + /// Transactions and savepoints + Savepoint = 2, +} +impl SqlSupportedTransaction { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + SqlSupportedTransaction::None => "SQL_SUPPORTED_TRANSACTION_NONE", + SqlSupportedTransaction::Transaction => { + "SQL_SUPPORTED_TRANSACTION_TRANSACTION" + } + SqlSupportedTransaction::Savepoint => "SQL_SUPPORTED_TRANSACTION_SAVEPOINT", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_SUPPORTED_TRANSACTION_NONE" => Some(Self::None), + "SQL_SUPPORTED_TRANSACTION_TRANSACTION" => Some(Self::Transaction), + "SQL_SUPPORTED_TRANSACTION_SAVEPOINT" => Some(Self::Savepoint), + _ => None, } } } @@ -1109,10 +1813,30 @@ impl SqlSupportedCaseSensitivity { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - SqlSupportedCaseSensitivity::SqlCaseSensitivityUnknown => "SQL_CASE_SENSITIVITY_UNKNOWN", - SqlSupportedCaseSensitivity::SqlCaseSensitivityCaseInsensitive => "SQL_CASE_SENSITIVITY_CASE_INSENSITIVE", - SqlSupportedCaseSensitivity::SqlCaseSensitivityUppercase => "SQL_CASE_SENSITIVITY_UPPERCASE", - SqlSupportedCaseSensitivity::SqlCaseSensitivityLowercase => "SQL_CASE_SENSITIVITY_LOWERCASE", + SqlSupportedCaseSensitivity::SqlCaseSensitivityUnknown => { + "SQL_CASE_SENSITIVITY_UNKNOWN" + } + SqlSupportedCaseSensitivity::SqlCaseSensitivityCaseInsensitive => { + "SQL_CASE_SENSITIVITY_CASE_INSENSITIVE" + } + SqlSupportedCaseSensitivity::SqlCaseSensitivityUppercase => { + "SQL_CASE_SENSITIVITY_UPPERCASE" + } + SqlSupportedCaseSensitivity::SqlCaseSensitivityLowercase => { + "SQL_CASE_SENSITIVITY_LOWERCASE" + } + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_CASE_SENSITIVITY_UNKNOWN" => Some(Self::SqlCaseSensitivityUnknown), + "SQL_CASE_SENSITIVITY_CASE_INSENSITIVE" => { + Some(Self::SqlCaseSensitivityCaseInsensitive) + } + "SQL_CASE_SENSITIVITY_UPPERCASE" => Some(Self::SqlCaseSensitivityUppercase), + "SQL_CASE_SENSITIVITY_LOWERCASE" => Some(Self::SqlCaseSensitivityLowercase), + _ => None, } } } @@ -1137,6 +1861,16 @@ impl SqlNullOrdering { SqlNullOrdering::SqlNullsSortedAtEnd => "SQL_NULLS_SORTED_AT_END", } } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_NULLS_SORTED_HIGH" => Some(Self::SqlNullsSortedHigh), + "SQL_NULLS_SORTED_LOW" => Some(Self::SqlNullsSortedLow), + "SQL_NULLS_SORTED_AT_START" => Some(Self::SqlNullsSortedAtStart), + "SQL_NULLS_SORTED_AT_END" => Some(Self::SqlNullsSortedAtEnd), + _ => None, + } + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] @@ -1157,6 +1891,15 @@ impl SupportedSqlGrammar { SupportedSqlGrammar::SqlExtendedGrammar => "SQL_EXTENDED_GRAMMAR", } } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_MINIMUM_GRAMMAR" => Some(Self::SqlMinimumGrammar), + "SQL_CORE_GRAMMAR" => Some(Self::SqlCoreGrammar), + "SQL_EXTENDED_GRAMMAR" => Some(Self::SqlExtendedGrammar), + _ => None, + } + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] @@ -1173,10 +1916,21 @@ impl SupportedAnsi92SqlGrammarLevel { pub fn as_str_name(&self) -> &'static str { match self { SupportedAnsi92SqlGrammarLevel::Ansi92EntrySql => "ANSI92_ENTRY_SQL", - SupportedAnsi92SqlGrammarLevel::Ansi92IntermediateSql => "ANSI92_INTERMEDIATE_SQL", + SupportedAnsi92SqlGrammarLevel::Ansi92IntermediateSql => { + "ANSI92_INTERMEDIATE_SQL" + } SupportedAnsi92SqlGrammarLevel::Ansi92FullSql => "ANSI92_FULL_SQL", } } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "ANSI92_ENTRY_SQL" => Some(Self::Ansi92EntrySql), + "ANSI92_INTERMEDIATE_SQL" => Some(Self::Ansi92IntermediateSql), + "ANSI92_FULL_SQL" => Some(Self::Ansi92FullSql), + _ => None, + } + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] @@ -1197,6 +1951,15 @@ impl SqlOuterJoinsSupportLevel { SqlOuterJoinsSupportLevel::SqlFullOuterJoins => "SQL_FULL_OUTER_JOINS", } } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_JOINS_UNSUPPORTED" => Some(Self::SqlJoinsUnsupported), + "SQL_LIMITED_OUTER_JOINS" => Some(Self::SqlLimitedOuterJoins), + "SQL_FULL_OUTER_JOINS" => Some(Self::SqlFullOuterJoins), + _ => None, + } + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] @@ -1215,6 +1978,14 @@ impl SqlSupportedGroupBy { SqlSupportedGroupBy::SqlGroupByBeyondSelect => "SQL_GROUP_BY_BEYOND_SELECT", } } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_GROUP_BY_UNRELATED" => Some(Self::SqlGroupByUnrelated), + "SQL_GROUP_BY_BEYOND_SELECT" => Some(Self::SqlGroupByBeyondSelect), + _ => None, + } + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] @@ -1230,9 +2001,28 @@ impl SqlSupportedElementActions { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - SqlSupportedElementActions::SqlElementInProcedureCalls => "SQL_ELEMENT_IN_PROCEDURE_CALLS", - SqlSupportedElementActions::SqlElementInIndexDefinitions => "SQL_ELEMENT_IN_INDEX_DEFINITIONS", - SqlSupportedElementActions::SqlElementInPrivilegeDefinitions => "SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS", + SqlSupportedElementActions::SqlElementInProcedureCalls => { + "SQL_ELEMENT_IN_PROCEDURE_CALLS" + } + SqlSupportedElementActions::SqlElementInIndexDefinitions => { + "SQL_ELEMENT_IN_INDEX_DEFINITIONS" + } + SqlSupportedElementActions::SqlElementInPrivilegeDefinitions => { + "SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS" + } + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_ELEMENT_IN_PROCEDURE_CALLS" => Some(Self::SqlElementInProcedureCalls), + "SQL_ELEMENT_IN_INDEX_DEFINITIONS" => { + Some(Self::SqlElementInIndexDefinitions) + } + "SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS" => { + Some(Self::SqlElementInPrivilegeDefinitions) + } + _ => None, } } } @@ -1249,8 +2039,20 @@ impl SqlSupportedPositionedCommands { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - SqlSupportedPositionedCommands::SqlPositionedDelete => "SQL_POSITIONED_DELETE", - SqlSupportedPositionedCommands::SqlPositionedUpdate => "SQL_POSITIONED_UPDATE", + SqlSupportedPositionedCommands::SqlPositionedDelete => { + "SQL_POSITIONED_DELETE" + } + SqlSupportedPositionedCommands::SqlPositionedUpdate => { + "SQL_POSITIONED_UPDATE" + } + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_POSITIONED_DELETE" => Some(Self::SqlPositionedDelete), + "SQL_POSITIONED_UPDATE" => Some(Self::SqlPositionedUpdate), + _ => None, } } } @@ -1269,10 +2071,24 @@ impl SqlSupportedSubqueries { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - SqlSupportedSubqueries::SqlSubqueriesInComparisons => "SQL_SUBQUERIES_IN_COMPARISONS", + SqlSupportedSubqueries::SqlSubqueriesInComparisons => { + "SQL_SUBQUERIES_IN_COMPARISONS" + } SqlSupportedSubqueries::SqlSubqueriesInExists => "SQL_SUBQUERIES_IN_EXISTS", SqlSupportedSubqueries::SqlSubqueriesInIns => "SQL_SUBQUERIES_IN_INS", - SqlSupportedSubqueries::SqlSubqueriesInQuantifieds => "SQL_SUBQUERIES_IN_QUANTIFIEDS", + SqlSupportedSubqueries::SqlSubqueriesInQuantifieds => { + "SQL_SUBQUERIES_IN_QUANTIFIEDS" + } + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_SUBQUERIES_IN_COMPARISONS" => Some(Self::SqlSubqueriesInComparisons), + "SQL_SUBQUERIES_IN_EXISTS" => Some(Self::SqlSubqueriesInExists), + "SQL_SUBQUERIES_IN_INS" => Some(Self::SqlSubqueriesInIns), + "SQL_SUBQUERIES_IN_QUANTIFIEDS" => Some(Self::SqlSubqueriesInQuantifieds), + _ => None, } } } @@ -1293,6 +2109,14 @@ impl SqlSupportedUnions { SqlSupportedUnions::SqlUnionAll => "SQL_UNION_ALL", } } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_UNION" => Some(Self::SqlUnion), + "SQL_UNION_ALL" => Some(Self::SqlUnionAll), + _ => None, + } + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] @@ -1311,10 +2135,31 @@ impl SqlTransactionIsolationLevel { pub fn as_str_name(&self) -> &'static str { match self { SqlTransactionIsolationLevel::SqlTransactionNone => "SQL_TRANSACTION_NONE", - SqlTransactionIsolationLevel::SqlTransactionReadUncommitted => "SQL_TRANSACTION_READ_UNCOMMITTED", - SqlTransactionIsolationLevel::SqlTransactionReadCommitted => "SQL_TRANSACTION_READ_COMMITTED", - SqlTransactionIsolationLevel::SqlTransactionRepeatableRead => "SQL_TRANSACTION_REPEATABLE_READ", - SqlTransactionIsolationLevel::SqlTransactionSerializable => "SQL_TRANSACTION_SERIALIZABLE", + SqlTransactionIsolationLevel::SqlTransactionReadUncommitted => { + "SQL_TRANSACTION_READ_UNCOMMITTED" + } + SqlTransactionIsolationLevel::SqlTransactionReadCommitted => { + "SQL_TRANSACTION_READ_COMMITTED" + } + SqlTransactionIsolationLevel::SqlTransactionRepeatableRead => { + "SQL_TRANSACTION_REPEATABLE_READ" + } + SqlTransactionIsolationLevel::SqlTransactionSerializable => { + "SQL_TRANSACTION_SERIALIZABLE" + } + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_TRANSACTION_NONE" => Some(Self::SqlTransactionNone), + "SQL_TRANSACTION_READ_UNCOMMITTED" => { + Some(Self::SqlTransactionReadUncommitted) + } + "SQL_TRANSACTION_READ_COMMITTED" => Some(Self::SqlTransactionReadCommitted), + "SQL_TRANSACTION_REPEATABLE_READ" => Some(Self::SqlTransactionRepeatableRead), + "SQL_TRANSACTION_SERIALIZABLE" => Some(Self::SqlTransactionSerializable), + _ => None, } } } @@ -1332,9 +2177,28 @@ impl SqlSupportedTransactions { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - SqlSupportedTransactions::SqlTransactionUnspecified => "SQL_TRANSACTION_UNSPECIFIED", - SqlSupportedTransactions::SqlDataDefinitionTransactions => "SQL_DATA_DEFINITION_TRANSACTIONS", - SqlSupportedTransactions::SqlDataManipulationTransactions => "SQL_DATA_MANIPULATION_TRANSACTIONS", + SqlSupportedTransactions::SqlTransactionUnspecified => { + "SQL_TRANSACTION_UNSPECIFIED" + } + SqlSupportedTransactions::SqlDataDefinitionTransactions => { + "SQL_DATA_DEFINITION_TRANSACTIONS" + } + SqlSupportedTransactions::SqlDataManipulationTransactions => { + "SQL_DATA_MANIPULATION_TRANSACTIONS" + } + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_TRANSACTION_UNSPECIFIED" => Some(Self::SqlTransactionUnspecified), + "SQL_DATA_DEFINITION_TRANSACTIONS" => { + Some(Self::SqlDataDefinitionTransactions) + } + "SQL_DATA_MANIPULATION_TRANSACTIONS" => { + Some(Self::SqlDataManipulationTransactions) + } + _ => None, } } } @@ -1353,10 +2217,32 @@ impl SqlSupportedResultSetType { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - SqlSupportedResultSetType::SqlResultSetTypeUnspecified => "SQL_RESULT_SET_TYPE_UNSPECIFIED", - SqlSupportedResultSetType::SqlResultSetTypeForwardOnly => "SQL_RESULT_SET_TYPE_FORWARD_ONLY", - SqlSupportedResultSetType::SqlResultSetTypeScrollInsensitive => "SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE", - SqlSupportedResultSetType::SqlResultSetTypeScrollSensitive => "SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE", + SqlSupportedResultSetType::SqlResultSetTypeUnspecified => { + "SQL_RESULT_SET_TYPE_UNSPECIFIED" + } + SqlSupportedResultSetType::SqlResultSetTypeForwardOnly => { + "SQL_RESULT_SET_TYPE_FORWARD_ONLY" + } + SqlSupportedResultSetType::SqlResultSetTypeScrollInsensitive => { + "SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE" + } + SqlSupportedResultSetType::SqlResultSetTypeScrollSensitive => { + "SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE" + } + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_RESULT_SET_TYPE_UNSPECIFIED" => Some(Self::SqlResultSetTypeUnspecified), + "SQL_RESULT_SET_TYPE_FORWARD_ONLY" => Some(Self::SqlResultSetTypeForwardOnly), + "SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE" => { + Some(Self::SqlResultSetTypeScrollInsensitive) + } + "SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE" => { + Some(Self::SqlResultSetTypeScrollSensitive) + } + _ => None, } } } @@ -1374,9 +2260,30 @@ impl SqlSupportedResultSetConcurrency { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - SqlSupportedResultSetConcurrency::SqlResultSetConcurrencyUnspecified => "SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED", - SqlSupportedResultSetConcurrency::SqlResultSetConcurrencyReadOnly => "SQL_RESULT_SET_CONCURRENCY_READ_ONLY", - SqlSupportedResultSetConcurrency::SqlResultSetConcurrencyUpdatable => "SQL_RESULT_SET_CONCURRENCY_UPDATABLE", + SqlSupportedResultSetConcurrency::SqlResultSetConcurrencyUnspecified => { + "SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED" + } + SqlSupportedResultSetConcurrency::SqlResultSetConcurrencyReadOnly => { + "SQL_RESULT_SET_CONCURRENCY_READ_ONLY" + } + SqlSupportedResultSetConcurrency::SqlResultSetConcurrencyUpdatable => { + "SQL_RESULT_SET_CONCURRENCY_UPDATABLE" + } + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED" => { + Some(Self::SqlResultSetConcurrencyUnspecified) + } + "SQL_RESULT_SET_CONCURRENCY_READ_ONLY" => { + Some(Self::SqlResultSetConcurrencyReadOnly) + } + "SQL_RESULT_SET_CONCURRENCY_UPDATABLE" => { + Some(Self::SqlResultSetConcurrencyUpdatable) + } + _ => None, } } } @@ -1419,8 +2326,12 @@ impl SqlSupportsConvert { SqlSupportsConvert::SqlConvertDecimal => "SQL_CONVERT_DECIMAL", SqlSupportsConvert::SqlConvertFloat => "SQL_CONVERT_FLOAT", SqlSupportsConvert::SqlConvertInteger => "SQL_CONVERT_INTEGER", - SqlSupportsConvert::SqlConvertIntervalDayTime => "SQL_CONVERT_INTERVAL_DAY_TIME", - SqlSupportsConvert::SqlConvertIntervalYearMonth => "SQL_CONVERT_INTERVAL_YEAR_MONTH", + SqlSupportsConvert::SqlConvertIntervalDayTime => { + "SQL_CONVERT_INTERVAL_DAY_TIME" + } + SqlSupportsConvert::SqlConvertIntervalYearMonth => { + "SQL_CONVERT_INTERVAL_YEAR_MONTH" + } SqlSupportsConvert::SqlConvertLongvarbinary => "SQL_CONVERT_LONGVARBINARY", SqlSupportsConvert::SqlConvertLongvarchar => "SQL_CONVERT_LONGVARCHAR", SqlSupportsConvert::SqlConvertNumeric => "SQL_CONVERT_NUMERIC", @@ -1433,6 +2344,352 @@ impl SqlSupportsConvert { SqlSupportsConvert::SqlConvertVarchar => "SQL_CONVERT_VARCHAR", } } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SQL_CONVERT_BIGINT" => Some(Self::SqlConvertBigint), + "SQL_CONVERT_BINARY" => Some(Self::SqlConvertBinary), + "SQL_CONVERT_BIT" => Some(Self::SqlConvertBit), + "SQL_CONVERT_CHAR" => Some(Self::SqlConvertChar), + "SQL_CONVERT_DATE" => Some(Self::SqlConvertDate), + "SQL_CONVERT_DECIMAL" => Some(Self::SqlConvertDecimal), + "SQL_CONVERT_FLOAT" => Some(Self::SqlConvertFloat), + "SQL_CONVERT_INTEGER" => Some(Self::SqlConvertInteger), + "SQL_CONVERT_INTERVAL_DAY_TIME" => Some(Self::SqlConvertIntervalDayTime), + "SQL_CONVERT_INTERVAL_YEAR_MONTH" => Some(Self::SqlConvertIntervalYearMonth), + "SQL_CONVERT_LONGVARBINARY" => Some(Self::SqlConvertLongvarbinary), + "SQL_CONVERT_LONGVARCHAR" => Some(Self::SqlConvertLongvarchar), + "SQL_CONVERT_NUMERIC" => Some(Self::SqlConvertNumeric), + "SQL_CONVERT_REAL" => Some(Self::SqlConvertReal), + "SQL_CONVERT_SMALLINT" => Some(Self::SqlConvertSmallint), + "SQL_CONVERT_TIME" => Some(Self::SqlConvertTime), + "SQL_CONVERT_TIMESTAMP" => Some(Self::SqlConvertTimestamp), + "SQL_CONVERT_TINYINT" => Some(Self::SqlConvertTinyint), + "SQL_CONVERT_VARBINARY" => Some(Self::SqlConvertVarbinary), + "SQL_CONVERT_VARCHAR" => Some(Self::SqlConvertVarchar), + _ => None, + } + } +} +/// * +/// The JDBC/ODBC-defined type of any object. +/// All the values here are the sames as in the JDBC and ODBC specs. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum XdbcDataType { + XdbcUnknownType = 0, + XdbcChar = 1, + XdbcNumeric = 2, + XdbcDecimal = 3, + XdbcInteger = 4, + XdbcSmallint = 5, + XdbcFloat = 6, + XdbcReal = 7, + XdbcDouble = 8, + XdbcDatetime = 9, + XdbcInterval = 10, + XdbcVarchar = 12, + XdbcDate = 91, + XdbcTime = 92, + XdbcTimestamp = 93, + XdbcLongvarchar = -1, + XdbcBinary = -2, + XdbcVarbinary = -3, + XdbcLongvarbinary = -4, + XdbcBigint = -5, + XdbcTinyint = -6, + XdbcBit = -7, + XdbcWchar = -8, + XdbcWvarchar = -9, +} +impl XdbcDataType { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + XdbcDataType::XdbcUnknownType => "XDBC_UNKNOWN_TYPE", + XdbcDataType::XdbcChar => "XDBC_CHAR", + XdbcDataType::XdbcNumeric => "XDBC_NUMERIC", + XdbcDataType::XdbcDecimal => "XDBC_DECIMAL", + XdbcDataType::XdbcInteger => "XDBC_INTEGER", + XdbcDataType::XdbcSmallint => "XDBC_SMALLINT", + XdbcDataType::XdbcFloat => "XDBC_FLOAT", + XdbcDataType::XdbcReal => "XDBC_REAL", + XdbcDataType::XdbcDouble => "XDBC_DOUBLE", + XdbcDataType::XdbcDatetime => "XDBC_DATETIME", + XdbcDataType::XdbcInterval => "XDBC_INTERVAL", + XdbcDataType::XdbcVarchar => "XDBC_VARCHAR", + XdbcDataType::XdbcDate => "XDBC_DATE", + XdbcDataType::XdbcTime => "XDBC_TIME", + XdbcDataType::XdbcTimestamp => "XDBC_TIMESTAMP", + XdbcDataType::XdbcLongvarchar => "XDBC_LONGVARCHAR", + XdbcDataType::XdbcBinary => "XDBC_BINARY", + XdbcDataType::XdbcVarbinary => "XDBC_VARBINARY", + XdbcDataType::XdbcLongvarbinary => "XDBC_LONGVARBINARY", + XdbcDataType::XdbcBigint => "XDBC_BIGINT", + XdbcDataType::XdbcTinyint => "XDBC_TINYINT", + XdbcDataType::XdbcBit => "XDBC_BIT", + XdbcDataType::XdbcWchar => "XDBC_WCHAR", + XdbcDataType::XdbcWvarchar => "XDBC_WVARCHAR", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "XDBC_UNKNOWN_TYPE" => Some(Self::XdbcUnknownType), + "XDBC_CHAR" => Some(Self::XdbcChar), + "XDBC_NUMERIC" => Some(Self::XdbcNumeric), + "XDBC_DECIMAL" => Some(Self::XdbcDecimal), + "XDBC_INTEGER" => Some(Self::XdbcInteger), + "XDBC_SMALLINT" => Some(Self::XdbcSmallint), + "XDBC_FLOAT" => Some(Self::XdbcFloat), + "XDBC_REAL" => Some(Self::XdbcReal), + "XDBC_DOUBLE" => Some(Self::XdbcDouble), + "XDBC_DATETIME" => Some(Self::XdbcDatetime), + "XDBC_INTERVAL" => Some(Self::XdbcInterval), + "XDBC_VARCHAR" => Some(Self::XdbcVarchar), + "XDBC_DATE" => Some(Self::XdbcDate), + "XDBC_TIME" => Some(Self::XdbcTime), + "XDBC_TIMESTAMP" => Some(Self::XdbcTimestamp), + "XDBC_LONGVARCHAR" => Some(Self::XdbcLongvarchar), + "XDBC_BINARY" => Some(Self::XdbcBinary), + "XDBC_VARBINARY" => Some(Self::XdbcVarbinary), + "XDBC_LONGVARBINARY" => Some(Self::XdbcLongvarbinary), + "XDBC_BIGINT" => Some(Self::XdbcBigint), + "XDBC_TINYINT" => Some(Self::XdbcTinyint), + "XDBC_BIT" => Some(Self::XdbcBit), + "XDBC_WCHAR" => Some(Self::XdbcWchar), + "XDBC_WVARCHAR" => Some(Self::XdbcWvarchar), + _ => None, + } + } +} +/// * +/// Detailed subtype information for XDBC_TYPE_DATETIME and XDBC_TYPE_INTERVAL. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum XdbcDatetimeSubcode { + XdbcSubcodeUnknown = 0, + XdbcSubcodeYear = 1, + XdbcSubcodeTime = 2, + XdbcSubcodeTimestamp = 3, + XdbcSubcodeTimeWithTimezone = 4, + XdbcSubcodeTimestampWithTimezone = 5, + XdbcSubcodeSecond = 6, + XdbcSubcodeYearToMonth = 7, + XdbcSubcodeDayToHour = 8, + XdbcSubcodeDayToMinute = 9, + XdbcSubcodeDayToSecond = 10, + XdbcSubcodeHourToMinute = 11, + XdbcSubcodeHourToSecond = 12, + XdbcSubcodeMinuteToSecond = 13, + XdbcSubcodeIntervalYear = 101, + XdbcSubcodeIntervalMonth = 102, + XdbcSubcodeIntervalDay = 103, + XdbcSubcodeIntervalHour = 104, + XdbcSubcodeIntervalMinute = 105, + XdbcSubcodeIntervalSecond = 106, + XdbcSubcodeIntervalYearToMonth = 107, + XdbcSubcodeIntervalDayToHour = 108, + XdbcSubcodeIntervalDayToMinute = 109, + XdbcSubcodeIntervalDayToSecond = 110, + XdbcSubcodeIntervalHourToMinute = 111, + XdbcSubcodeIntervalHourToSecond = 112, + XdbcSubcodeIntervalMinuteToSecond = 113, +} +impl XdbcDatetimeSubcode { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + XdbcDatetimeSubcode::XdbcSubcodeUnknown => "XDBC_SUBCODE_UNKNOWN", + XdbcDatetimeSubcode::XdbcSubcodeYear => "XDBC_SUBCODE_YEAR", + XdbcDatetimeSubcode::XdbcSubcodeTime => "XDBC_SUBCODE_TIME", + XdbcDatetimeSubcode::XdbcSubcodeTimestamp => "XDBC_SUBCODE_TIMESTAMP", + XdbcDatetimeSubcode::XdbcSubcodeTimeWithTimezone => { + "XDBC_SUBCODE_TIME_WITH_TIMEZONE" + } + XdbcDatetimeSubcode::XdbcSubcodeTimestampWithTimezone => { + "XDBC_SUBCODE_TIMESTAMP_WITH_TIMEZONE" + } + XdbcDatetimeSubcode::XdbcSubcodeSecond => "XDBC_SUBCODE_SECOND", + XdbcDatetimeSubcode::XdbcSubcodeYearToMonth => "XDBC_SUBCODE_YEAR_TO_MONTH", + XdbcDatetimeSubcode::XdbcSubcodeDayToHour => "XDBC_SUBCODE_DAY_TO_HOUR", + XdbcDatetimeSubcode::XdbcSubcodeDayToMinute => "XDBC_SUBCODE_DAY_TO_MINUTE", + XdbcDatetimeSubcode::XdbcSubcodeDayToSecond => "XDBC_SUBCODE_DAY_TO_SECOND", + XdbcDatetimeSubcode::XdbcSubcodeHourToMinute => "XDBC_SUBCODE_HOUR_TO_MINUTE", + XdbcDatetimeSubcode::XdbcSubcodeHourToSecond => "XDBC_SUBCODE_HOUR_TO_SECOND", + XdbcDatetimeSubcode::XdbcSubcodeMinuteToSecond => { + "XDBC_SUBCODE_MINUTE_TO_SECOND" + } + XdbcDatetimeSubcode::XdbcSubcodeIntervalYear => "XDBC_SUBCODE_INTERVAL_YEAR", + XdbcDatetimeSubcode::XdbcSubcodeIntervalMonth => { + "XDBC_SUBCODE_INTERVAL_MONTH" + } + XdbcDatetimeSubcode::XdbcSubcodeIntervalDay => "XDBC_SUBCODE_INTERVAL_DAY", + XdbcDatetimeSubcode::XdbcSubcodeIntervalHour => "XDBC_SUBCODE_INTERVAL_HOUR", + XdbcDatetimeSubcode::XdbcSubcodeIntervalMinute => { + "XDBC_SUBCODE_INTERVAL_MINUTE" + } + XdbcDatetimeSubcode::XdbcSubcodeIntervalSecond => { + "XDBC_SUBCODE_INTERVAL_SECOND" + } + XdbcDatetimeSubcode::XdbcSubcodeIntervalYearToMonth => { + "XDBC_SUBCODE_INTERVAL_YEAR_TO_MONTH" + } + XdbcDatetimeSubcode::XdbcSubcodeIntervalDayToHour => { + "XDBC_SUBCODE_INTERVAL_DAY_TO_HOUR" + } + XdbcDatetimeSubcode::XdbcSubcodeIntervalDayToMinute => { + "XDBC_SUBCODE_INTERVAL_DAY_TO_MINUTE" + } + XdbcDatetimeSubcode::XdbcSubcodeIntervalDayToSecond => { + "XDBC_SUBCODE_INTERVAL_DAY_TO_SECOND" + } + XdbcDatetimeSubcode::XdbcSubcodeIntervalHourToMinute => { + "XDBC_SUBCODE_INTERVAL_HOUR_TO_MINUTE" + } + XdbcDatetimeSubcode::XdbcSubcodeIntervalHourToSecond => { + "XDBC_SUBCODE_INTERVAL_HOUR_TO_SECOND" + } + XdbcDatetimeSubcode::XdbcSubcodeIntervalMinuteToSecond => { + "XDBC_SUBCODE_INTERVAL_MINUTE_TO_SECOND" + } + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "XDBC_SUBCODE_UNKNOWN" => Some(Self::XdbcSubcodeUnknown), + "XDBC_SUBCODE_YEAR" => Some(Self::XdbcSubcodeYear), + "XDBC_SUBCODE_TIME" => Some(Self::XdbcSubcodeTime), + "XDBC_SUBCODE_TIMESTAMP" => Some(Self::XdbcSubcodeTimestamp), + "XDBC_SUBCODE_TIME_WITH_TIMEZONE" => Some(Self::XdbcSubcodeTimeWithTimezone), + "XDBC_SUBCODE_TIMESTAMP_WITH_TIMEZONE" => { + Some(Self::XdbcSubcodeTimestampWithTimezone) + } + "XDBC_SUBCODE_SECOND" => Some(Self::XdbcSubcodeSecond), + "XDBC_SUBCODE_YEAR_TO_MONTH" => Some(Self::XdbcSubcodeYearToMonth), + "XDBC_SUBCODE_DAY_TO_HOUR" => Some(Self::XdbcSubcodeDayToHour), + "XDBC_SUBCODE_DAY_TO_MINUTE" => Some(Self::XdbcSubcodeDayToMinute), + "XDBC_SUBCODE_DAY_TO_SECOND" => Some(Self::XdbcSubcodeDayToSecond), + "XDBC_SUBCODE_HOUR_TO_MINUTE" => Some(Self::XdbcSubcodeHourToMinute), + "XDBC_SUBCODE_HOUR_TO_SECOND" => Some(Self::XdbcSubcodeHourToSecond), + "XDBC_SUBCODE_MINUTE_TO_SECOND" => Some(Self::XdbcSubcodeMinuteToSecond), + "XDBC_SUBCODE_INTERVAL_YEAR" => Some(Self::XdbcSubcodeIntervalYear), + "XDBC_SUBCODE_INTERVAL_MONTH" => Some(Self::XdbcSubcodeIntervalMonth), + "XDBC_SUBCODE_INTERVAL_DAY" => Some(Self::XdbcSubcodeIntervalDay), + "XDBC_SUBCODE_INTERVAL_HOUR" => Some(Self::XdbcSubcodeIntervalHour), + "XDBC_SUBCODE_INTERVAL_MINUTE" => Some(Self::XdbcSubcodeIntervalMinute), + "XDBC_SUBCODE_INTERVAL_SECOND" => Some(Self::XdbcSubcodeIntervalSecond), + "XDBC_SUBCODE_INTERVAL_YEAR_TO_MONTH" => { + Some(Self::XdbcSubcodeIntervalYearToMonth) + } + "XDBC_SUBCODE_INTERVAL_DAY_TO_HOUR" => { + Some(Self::XdbcSubcodeIntervalDayToHour) + } + "XDBC_SUBCODE_INTERVAL_DAY_TO_MINUTE" => { + Some(Self::XdbcSubcodeIntervalDayToMinute) + } + "XDBC_SUBCODE_INTERVAL_DAY_TO_SECOND" => { + Some(Self::XdbcSubcodeIntervalDayToSecond) + } + "XDBC_SUBCODE_INTERVAL_HOUR_TO_MINUTE" => { + Some(Self::XdbcSubcodeIntervalHourToMinute) + } + "XDBC_SUBCODE_INTERVAL_HOUR_TO_SECOND" => { + Some(Self::XdbcSubcodeIntervalHourToSecond) + } + "XDBC_SUBCODE_INTERVAL_MINUTE_TO_SECOND" => { + Some(Self::XdbcSubcodeIntervalMinuteToSecond) + } + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum Nullable { + /// * + /// Indicates that the fields does not allow the use of null values. + NullabilityNoNulls = 0, + /// * + /// Indicates that the fields allow the use of null values. + NullabilityNullable = 1, + /// * + /// Indicates that nullability of the fields can not be determined. + NullabilityUnknown = 2, +} +impl Nullable { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Nullable::NullabilityNoNulls => "NULLABILITY_NO_NULLS", + Nullable::NullabilityNullable => "NULLABILITY_NULLABLE", + Nullable::NullabilityUnknown => "NULLABILITY_UNKNOWN", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "NULLABILITY_NO_NULLS" => Some(Self::NullabilityNoNulls), + "NULLABILITY_NULLABLE" => Some(Self::NullabilityNullable), + "NULLABILITY_UNKNOWN" => Some(Self::NullabilityUnknown), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum Searchable { + /// * + /// Indicates that column can not be used in a WHERE clause. + None = 0, + /// * + /// Indicates that the column can be used in a WHERE clause if it is using a + /// LIKE operator. + Char = 1, + /// * + /// Indicates that the column can be used In a WHERE clause with any + /// operator other than LIKE. + /// + /// - Allowed operators: comparison, quantified comparison, BETWEEN, + /// DISTINCT, IN, MATCH, and UNIQUE. + Basic = 2, + /// * + /// Indicates that the column can be used in a WHERE clause using any operator. + Full = 3, +} +impl Searchable { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Searchable::None => "SEARCHABLE_NONE", + Searchable::Char => "SEARCHABLE_CHAR", + Searchable::Basic => "SEARCHABLE_BASIC", + Searchable::Full => "SEARCHABLE_FULL", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SEARCHABLE_NONE" => Some(Self::None), + "SEARCHABLE_CHAR" => Some(Self::Char), + "SEARCHABLE_BASIC" => Some(Self::Basic), + "SEARCHABLE_FULL" => Some(Self::Full), + _ => None, + } + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] @@ -1457,4 +2714,15 @@ impl UpdateDeleteRules { UpdateDeleteRules::SetDefault => "SET_DEFAULT", } } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "CASCADE" => Some(Self::Cascade), + "RESTRICT" => Some(Self::Restrict), + "SET_NULL" => Some(Self::SetNull), + "NO_ACTION" => Some(Self::NoAction), + "SET_DEFAULT" => Some(Self::SetDefault), + _ => None, + } + } } diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs new file mode 100644 index 000000000000..133df5b044cf --- /dev/null +++ b/arrow-flight/src/sql/client.rs @@ -0,0 +1,637 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A FlightSQL Client [`FlightSqlServiceClient`] + +use base64::prelude::BASE64_STANDARD; +use base64::Engine; +use bytes::Bytes; +use std::collections::HashMap; +use std::str::FromStr; +use tonic::metadata::AsciiMetadataKey; + +use crate::decode::FlightRecordBatchStream; +use crate::encode::FlightDataEncoderBuilder; +use crate::error::FlightError; +use crate::flight_service_client::FlightServiceClient; +use crate::sql::server::{CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT}; +use crate::sql::{ + ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, + ActionCreatePreparedStatementResult, Any, CommandGetCatalogs, CommandGetCrossReference, + CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, + CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, + CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, + CommandStatementUpdate, DoPutUpdateResult, ProstMessageExt, SqlInfo, +}; +use crate::trailers::extract_lazy_trailers; +use crate::{ + Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, + IpcMessage, PutResult, Ticket, +}; +use arrow_array::RecordBatch; +use arrow_buffer::Buffer; +use arrow_ipc::convert::fb_to_schema; +use arrow_ipc::reader::read_record_batch; +use arrow_ipc::{root_as_message, MessageHeader}; +use arrow_schema::{ArrowError, Schema, SchemaRef}; +use futures::{stream, TryStreamExt}; +use prost::Message; +use tonic::transport::Channel; +use tonic::{IntoRequest, Streaming}; + +/// A FlightSQLServiceClient is an endpoint for retrieving or storing Arrow data +/// by FlightSQL protocol. +#[derive(Debug, Clone)] +pub struct FlightSqlServiceClient { + token: Option, + headers: HashMap, + flight_client: FlightServiceClient, +} + +/// A FlightSql protocol client that can run queries against FlightSql servers +/// This client is in the "experimental" stage. It is not guaranteed to follow the spec in all instances. +/// Github issues are welcomed. +impl FlightSqlServiceClient { + /// Creates a new FlightSql client that connects to a server over an arbitrary tonic `Channel` + pub fn new(channel: Channel) -> Self { + let flight_client = FlightServiceClient::new(channel); + FlightSqlServiceClient { + token: None, + flight_client, + headers: HashMap::default(), + } + } + + /// Return a reference to the underlying [`FlightServiceClient`] + pub fn inner(&self) -> &FlightServiceClient { + &self.flight_client + } + + /// Return a mutable reference to the underlying [`FlightServiceClient`] + pub fn inner_mut(&mut self) -> &mut FlightServiceClient { + &mut self.flight_client + } + + /// Consume this client and return the underlying [`FlightServiceClient`] + pub fn into_inner(self) -> FlightServiceClient { + self.flight_client + } + + /// Set auth token to the given value. + pub fn set_token(&mut self, token: String) { + self.token = Some(token); + } + + /// Set header value. + pub fn set_header(&mut self, key: impl Into, value: impl Into) { + let key: String = key.into(); + let value: String = value.into(); + self.headers.insert(key, value); + } + + async fn get_flight_info_for_command( + &mut self, + cmd: M, + ) -> Result { + let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); + let req = self.set_request_headers(descriptor.into_request())?; + let fi = self + .flight_client + .get_flight_info(req) + .await + .map_err(status_to_arrow_error)? + .into_inner(); + Ok(fi) + } + + /// Execute a query on the server. + pub async fn execute( + &mut self, + query: String, + transaction_id: Option, + ) -> Result { + let cmd = CommandStatementQuery { + query, + transaction_id, + }; + self.get_flight_info_for_command(cmd).await + } + + /// Perform a `handshake` with the server, passing credentials and establishing a session + /// Returns arbitrary auth/handshake info binary blob + pub async fn handshake(&mut self, username: &str, password: &str) -> Result { + let cmd = HandshakeRequest { + protocol_version: 0, + payload: Default::default(), + }; + let mut req = tonic::Request::new(stream::iter(vec![cmd])); + let val = BASE64_STANDARD.encode(format!("{username}:{password}")); + let val = format!("Basic {val}") + .parse() + .map_err(|_| ArrowError::ParseError("Cannot parse header".to_string()))?; + req.metadata_mut().insert("authorization", val); + let req = self.set_request_headers(req)?; + let resp = self + .flight_client + .handshake(req) + .await + .map_err(|e| ArrowError::IpcError(format!("Can't handshake {e}")))?; + if let Some(auth) = resp.metadata().get("authorization") { + let auth = auth + .to_str() + .map_err(|_| ArrowError::ParseError("Can't read auth header".to_string()))?; + let bearer = "Bearer "; + if !auth.starts_with(bearer) { + Err(ArrowError::ParseError("Invalid auth header!".to_string()))?; + } + let auth = auth[bearer.len()..].to_string(); + self.token = Some(auth); + } + let responses: Vec = resp + .into_inner() + .try_collect() + .await + .map_err(|_| ArrowError::ParseError("Can't collect responses".to_string()))?; + let resp = match responses.as_slice() { + [resp] => resp.payload.clone(), + [] => Bytes::new(), + _ => Err(ArrowError::ParseError( + "Multiple handshake responses".to_string(), + ))?, + }; + Ok(resp) + } + + /// Execute a update query on the server, and return the number of records affected + pub async fn execute_update( + &mut self, + query: String, + transaction_id: Option, + ) -> Result { + let cmd = CommandStatementUpdate { + query, + transaction_id, + }; + let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); + let req = self.set_request_headers( + stream::iter(vec![FlightData { + flight_descriptor: Some(descriptor), + ..Default::default() + }]) + .into_request(), + )?; + let mut result = self + .flight_client + .do_put(req) + .await + .map_err(status_to_arrow_error)? + .into_inner(); + let result = result + .message() + .await + .map_err(status_to_arrow_error)? + .unwrap(); + let any = Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; + let result: DoPutUpdateResult = any.unpack()?.unwrap(); + Ok(result.record_count) + } + + /// Request a list of catalogs as tabular FlightInfo results + pub async fn get_catalogs(&mut self) -> Result { + self.get_flight_info_for_command(CommandGetCatalogs {}) + .await + } + + /// Request a list of database schemas as tabular FlightInfo results + pub async fn get_db_schemas( + &mut self, + request: CommandGetDbSchemas, + ) -> Result { + self.get_flight_info_for_command(request).await + } + + /// Given a flight ticket, request to be sent the stream. Returns record batch stream reader + pub async fn do_get( + &mut self, + ticket: impl IntoRequest, + ) -> Result { + let req = self.set_request_headers(ticket.into_request())?; + + let (md, response_stream, _ext) = self + .flight_client + .do_get(req) + .await + .map_err(status_to_arrow_error)? + .into_parts(); + let (response_stream, trailers) = extract_lazy_trailers(response_stream); + + Ok(FlightRecordBatchStream::new_from_flight_data( + response_stream.map_err(FlightError::Tonic), + ) + .with_headers(md) + .with_trailers(trailers)) + } + + /// Push a stream to the flight service associated with a particular flight stream. + pub async fn do_put( + &mut self, + request: impl tonic::IntoStreamingRequest, + ) -> Result, ArrowError> { + let req = self.set_request_headers(request.into_streaming_request())?; + Ok(self + .flight_client + .do_put(req) + .await + .map_err(status_to_arrow_error)? + .into_inner()) + } + + /// DoAction allows a flight client to do a specific action against a flight service + pub async fn do_action( + &mut self, + request: impl IntoRequest, + ) -> Result, ArrowError> { + let req = self.set_request_headers(request.into_request())?; + Ok(self + .flight_client + .do_action(req) + .await + .map_err(status_to_arrow_error)? + .into_inner()) + } + + /// Request a list of tables. + pub async fn get_tables( + &mut self, + request: CommandGetTables, + ) -> Result { + self.get_flight_info_for_command(request).await + } + + /// Request the primary keys for a table. + pub async fn get_primary_keys( + &mut self, + request: CommandGetPrimaryKeys, + ) -> Result { + self.get_flight_info_for_command(request).await + } + + /// Retrieves a description about the foreign key columns that reference the + /// primary key columns of the given table. + pub async fn get_exported_keys( + &mut self, + request: CommandGetExportedKeys, + ) -> Result { + self.get_flight_info_for_command(request).await + } + + /// Retrieves the foreign key columns for the given table. + pub async fn get_imported_keys( + &mut self, + request: CommandGetImportedKeys, + ) -> Result { + self.get_flight_info_for_command(request).await + } + + /// Retrieves a description of the foreign key columns in the given foreign key + /// table that reference the primary key or the columns representing a unique + /// constraint of the parent table (could be the same or a different table). + pub async fn get_cross_reference( + &mut self, + request: CommandGetCrossReference, + ) -> Result { + self.get_flight_info_for_command(request).await + } + + /// Request a list of table types. + pub async fn get_table_types(&mut self) -> Result { + self.get_flight_info_for_command(CommandGetTableTypes {}) + .await + } + + /// Request a list of SQL information. + pub async fn get_sql_info( + &mut self, + sql_infos: Vec, + ) -> Result { + let request = CommandGetSqlInfo { + info: sql_infos.iter().map(|sql_info| *sql_info as u32).collect(), + }; + self.get_flight_info_for_command(request).await + } + + /// Request XDBC SQL information. + pub async fn get_xdbc_type_info( + &mut self, + request: CommandGetXdbcTypeInfo, + ) -> Result { + self.get_flight_info_for_command(request).await + } + + /// Create a prepared statement object. + pub async fn prepare( + &mut self, + query: String, + transaction_id: Option, + ) -> Result, ArrowError> { + let cmd = ActionCreatePreparedStatementRequest { + query, + transaction_id, + }; + let action = Action { + r#type: CREATE_PREPARED_STATEMENT.to_string(), + body: cmd.as_any().encode_to_vec().into(), + }; + let req = self.set_request_headers(action.into_request())?; + let mut result = self + .flight_client + .do_action(req) + .await + .map_err(status_to_arrow_error)? + .into_inner(); + let result = result + .message() + .await + .map_err(status_to_arrow_error)? + .unwrap(); + let any = Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?; + let prepared_result: ActionCreatePreparedStatementResult = any.unpack()?.unwrap(); + let dataset_schema = match prepared_result.dataset_schema.len() { + 0 => Schema::empty(), + _ => Schema::try_from(IpcMessage(prepared_result.dataset_schema))?, + }; + let parameter_schema = match prepared_result.parameter_schema.len() { + 0 => Schema::empty(), + _ => Schema::try_from(IpcMessage(prepared_result.parameter_schema))?, + }; + Ok(PreparedStatement::new( + self.clone(), + prepared_result.prepared_statement_handle, + dataset_schema, + parameter_schema, + )) + } + + /// Explicitly shut down and clean up the client. + pub async fn close(&mut self) -> Result<(), ArrowError> { + Ok(()) + } + + fn set_request_headers( + &self, + mut req: tonic::Request, + ) -> Result, ArrowError> { + for (k, v) in &self.headers { + let k = AsciiMetadataKey::from_str(k.as_str()).map_err(|e| { + ArrowError::ParseError(format!("Cannot convert header key \"{k}\": {e}")) + })?; + let v = v.parse().map_err(|e| { + ArrowError::ParseError(format!("Cannot convert header value \"{v}\": {e}")) + })?; + req.metadata_mut().insert(k, v); + } + if let Some(token) = &self.token { + let val = format!("Bearer {token}").parse().map_err(|e| { + ArrowError::ParseError(format!("Cannot convert token to header value: {e}")) + })?; + req.metadata_mut().insert("authorization", val); + } + Ok(req) + } +} + +/// A PreparedStatement +#[derive(Debug, Clone)] +pub struct PreparedStatement { + flight_sql_client: FlightSqlServiceClient, + parameter_binding: Option, + handle: Bytes, + dataset_schema: Schema, + parameter_schema: Schema, +} + +impl PreparedStatement { + pub(crate) fn new( + flight_client: FlightSqlServiceClient, + handle: impl Into, + dataset_schema: Schema, + parameter_schema: Schema, + ) -> Self { + PreparedStatement { + flight_sql_client: flight_client, + parameter_binding: None, + handle: handle.into(), + dataset_schema, + parameter_schema, + } + } + + /// Executes the prepared statement query on the server. + pub async fn execute(&mut self) -> Result { + self.write_bind_params().await?; + + let cmd = CommandPreparedStatementQuery { + prepared_statement_handle: self.handle.clone(), + }; + + let result = self + .flight_sql_client + .get_flight_info_for_command(cmd) + .await?; + Ok(result) + } + + /// Executes the prepared statement update query on the server. + pub async fn execute_update(&mut self) -> Result { + self.write_bind_params().await?; + + let cmd = CommandPreparedStatementUpdate { + prepared_statement_handle: self.handle.clone(), + }; + let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); + let mut result = self + .flight_sql_client + .do_put(stream::iter(vec![FlightData { + flight_descriptor: Some(descriptor), + ..Default::default() + }])) + .await?; + let result = result + .message() + .await + .map_err(status_to_arrow_error)? + .unwrap(); + let any = Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; + let result: DoPutUpdateResult = any.unpack()?.unwrap(); + Ok(result.record_count) + } + + /// Retrieve the parameter schema from the query. + pub fn parameter_schema(&self) -> Result<&Schema, ArrowError> { + Ok(&self.parameter_schema) + } + + /// Retrieve the ResultSet schema from the query. + pub fn dataset_schema(&self) -> Result<&Schema, ArrowError> { + Ok(&self.dataset_schema) + } + + /// Set a RecordBatch that contains the parameters that will be bind. + pub fn set_parameters(&mut self, parameter_binding: RecordBatch) -> Result<(), ArrowError> { + self.parameter_binding = Some(parameter_binding); + Ok(()) + } + + /// Submit parameters to the server, if any have been set on this prepared statement instance + async fn write_bind_params(&mut self) -> Result<(), ArrowError> { + if let Some(ref params_batch) = self.parameter_binding { + let cmd = CommandPreparedStatementQuery { + prepared_statement_handle: self.handle.clone(), + }; + + let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); + let flight_stream_builder = FlightDataEncoderBuilder::new() + .with_flight_descriptor(Some(descriptor)) + .with_schema(params_batch.schema()); + let flight_data = flight_stream_builder + .build(futures::stream::iter( + self.parameter_binding.clone().map(Ok), + )) + .try_collect::>() + .await + .map_err(flight_error_to_arrow_error)?; + + self.flight_sql_client + .do_put(stream::iter(flight_data)) + .await? + .try_collect::>() + .await + .map_err(status_to_arrow_error)?; + } + + Ok(()) + } + + /// Close the prepared statement, so that this PreparedStatement can not used + /// anymore and server can free up any resources. + pub async fn close(mut self) -> Result<(), ArrowError> { + let cmd = ActionClosePreparedStatementRequest { + prepared_statement_handle: self.handle.clone(), + }; + let action = Action { + r#type: CLOSE_PREPARED_STATEMENT.to_string(), + body: cmd.as_any().encode_to_vec().into(), + }; + let _ = self.flight_sql_client.do_action(action).await?; + Ok(()) + } +} + +fn decode_error_to_arrow_error(err: prost::DecodeError) -> ArrowError { + ArrowError::IpcError(err.to_string()) +} + +fn status_to_arrow_error(status: tonic::Status) -> ArrowError { + ArrowError::IpcError(format!("{status:?}")) +} + +fn flight_error_to_arrow_error(err: FlightError) -> ArrowError { + match err { + FlightError::Arrow(e) => e, + e => ArrowError::ExternalError(Box::new(e)), + } +} + +// A polymorphic structure to natively represent different types of data contained in `FlightData` +pub enum ArrowFlightData { + RecordBatch(RecordBatch), + Schema(Schema), +} + +/// Extract `Schema` or `RecordBatch`es from the `FlightData` wire representation +pub fn arrow_data_from_flight_data( + flight_data: FlightData, + arrow_schema_ref: &SchemaRef, +) -> Result { + let ipc_message = root_as_message(&flight_data.data_header[..]) + .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))?; + + match ipc_message.header_type() { + MessageHeader::RecordBatch => { + let ipc_record_batch = ipc_message.header_as_record_batch().ok_or_else(|| { + ArrowError::ComputeError( + "Unable to convert flight data header to a record batch".to_string(), + ) + })?; + + let dictionaries_by_field = HashMap::new(); + let record_batch = read_record_batch( + &Buffer::from_bytes(flight_data.data_body.into()), + ipc_record_batch, + arrow_schema_ref.clone(), + &dictionaries_by_field, + None, + &ipc_message.version(), + )?; + Ok(ArrowFlightData::RecordBatch(record_batch)) + } + MessageHeader::Schema => { + let ipc_schema = ipc_message.header_as_schema().ok_or_else(|| { + ArrowError::ComputeError( + "Unable to convert flight data header to a schema".to_string(), + ) + })?; + + let arrow_schema = fb_to_schema(ipc_schema); + Ok(ArrowFlightData::Schema(arrow_schema)) + } + MessageHeader::DictionaryBatch => { + let _ = ipc_message.header_as_dictionary_batch().ok_or_else(|| { + ArrowError::ComputeError( + "Unable to convert flight data header to a dictionary batch".to_string(), + ) + })?; + Err(ArrowError::NotYetImplemented( + "no idea on how to convert an ipc dictionary batch to an arrow type".to_string(), + )) + } + MessageHeader::Tensor => { + let _ = ipc_message.header_as_tensor().ok_or_else(|| { + ArrowError::ComputeError( + "Unable to convert flight data header to a tensor".to_string(), + ) + })?; + Err(ArrowError::NotYetImplemented( + "no idea on how to convert an ipc tensor to an arrow type".to_string(), + )) + } + MessageHeader::SparseTensor => { + let _ = ipc_message.header_as_sparse_tensor().ok_or_else(|| { + ArrowError::ComputeError( + "Unable to convert flight data header to a sparse tensor".to_string(), + ) + })?; + Err(ArrowError::NotYetImplemented( + "no idea on how to convert an ipc sparse tensor to an arrow type".to_string(), + )) + } + _ => Err(ArrowError::ComputeError(format!( + "Unable to convert message with header_type: '{:?}' to arrow data", + ipc_message.header_type() + ))), + } +} diff --git a/arrow-flight/src/sql/metadata/catalogs.rs b/arrow-flight/src/sql/metadata/catalogs.rs new file mode 100644 index 000000000000..327fed81077b --- /dev/null +++ b/arrow-flight/src/sql/metadata/catalogs.rs @@ -0,0 +1,100 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow_array::{RecordBatch, StringArray}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use once_cell::sync::Lazy; + +use crate::error::Result; +use crate::sql::CommandGetCatalogs; + +/// A builder for a [`CommandGetCatalogs`] response. +/// +/// Builds rows like this: +/// +/// * catalog_name: utf8, +pub struct GetCatalogsBuilder { + catalogs: Vec, +} + +impl CommandGetCatalogs { + /// Create a builder suitable for constructing a response + pub fn into_builder(self) -> GetCatalogsBuilder { + self.into() + } +} + +impl From for GetCatalogsBuilder { + fn from(_: CommandGetCatalogs) -> Self { + Self::new() + } +} + +impl Default for GetCatalogsBuilder { + fn default() -> Self { + Self::new() + } +} + +impl GetCatalogsBuilder { + /// Create a new instance of [`GetCatalogsBuilder`] + pub fn new() -> Self { + Self { + catalogs: Vec::new(), + } + } + + /// Append a row + pub fn append(&mut self, catalog_name: impl Into) { + self.catalogs.push(catalog_name.into()); + } + + /// builds a `RecordBatch` with the correct schema for a + /// [`CommandGetCatalogs`] response + pub fn build(self) -> Result { + let Self { catalogs } = self; + + let batch = RecordBatch::try_new( + Arc::clone(&GET_CATALOG_SCHEMA), + vec![Arc::new(StringArray::from_iter_values(catalogs)) as _], + )?; + + Ok(batch) + } + + /// Returns the schema that will result from [`CommandGetCatalogs`] + /// + /// [`CommandGetCatalogs`]: crate::sql::CommandGetCatalogs + pub fn schema(&self) -> SchemaRef { + get_catalogs_schema() + } +} + +fn get_catalogs_schema() -> SchemaRef { + Arc::clone(&GET_CATALOG_SCHEMA) +} + +/// The schema for GetCatalogs +static GET_CATALOG_SCHEMA: Lazy = Lazy::new(|| { + Arc::new(Schema::new(vec![Field::new( + "catalog_name", + DataType::Utf8, + false, + )])) +}); diff --git a/arrow-flight/src/sql/metadata/db_schemas.rs b/arrow-flight/src/sql/metadata/db_schemas.rs new file mode 100644 index 000000000000..303d11cd74ca --- /dev/null +++ b/arrow-flight/src/sql/metadata/db_schemas.rs @@ -0,0 +1,286 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`GetDbSchemasBuilder`] for building responses to [`CommandGetDbSchemas`] queries. +//! +//! [`CommandGetDbSchemas`]: crate::sql::CommandGetDbSchemas + +use std::sync::Arc; + +use arrow_arith::boolean::and; +use arrow_array::{builder::StringBuilder, ArrayRef, RecordBatch, StringArray}; +use arrow_ord::cmp::eq; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use arrow_select::{filter::filter_record_batch, take::take}; +use arrow_string::like::like; +use once_cell::sync::Lazy; + +use super::lexsort_to_indices; +use crate::error::*; +use crate::sql::CommandGetDbSchemas; + +/// A builder for a [`CommandGetDbSchemas`] response. +/// +/// Builds rows like this: +/// +/// * catalog_name: utf8, +/// * db_schema_name: utf8, +pub struct GetDbSchemasBuilder { + // Specifies the Catalog to search for the tables. + // - An empty string retrieves those without a catalog. + // - If omitted the catalog name is not used to narrow the search. + catalog_filter: Option, + // Optional filters to apply + db_schema_filter_pattern: Option, + // array builder for catalog names + catalog_name: StringBuilder, + // array builder for schema names + db_schema_name: StringBuilder, +} + +impl CommandGetDbSchemas { + /// Create a builder suitable for constructing a response + pub fn into_builder(self) -> GetDbSchemasBuilder { + self.into() + } +} + +impl From for GetDbSchemasBuilder { + fn from(value: CommandGetDbSchemas) -> Self { + Self::new(value.catalog, value.db_schema_filter_pattern) + } +} + +impl GetDbSchemasBuilder { + /// Create a new instance of [`GetDbSchemasBuilder`] + /// + /// # Parameters + /// + /// - `catalog`: Specifies the Catalog to search for the tables. + /// - An empty string retrieves those without a catalog. + /// - If omitted the catalog name is not used to narrow the search. + /// - `db_schema_filter_pattern`: Specifies a filter pattern for schemas to search for. + /// When no pattern is provided, the pattern will not be used to narrow the search. + /// In the pattern string, two special characters can be used to denote matching rules: + /// - "%" means to match any substring with 0 or more characters. + /// - "_" means to match any one character. + /// + /// [`CommandGetDbSchemas`]: crate::sql::CommandGetDbSchemas + pub fn new( + catalog: Option>, + db_schema_filter_pattern: Option>, + ) -> Self { + Self { + catalog_filter: catalog.map(|v| v.into()), + db_schema_filter_pattern: db_schema_filter_pattern.map(|v| v.into()), + catalog_name: StringBuilder::new(), + db_schema_name: StringBuilder::new(), + } + } + + /// Append a row + /// + /// In case the catalog should be considered as empty, pass in an empty string '""'. + pub fn append(&mut self, catalog_name: impl AsRef, schema_name: impl AsRef) { + self.catalog_name.append_value(catalog_name); + self.db_schema_name.append_value(schema_name); + } + + /// builds a `RecordBatch` with the correct schema for a `CommandGetDbSchemas` response + pub fn build(self) -> Result { + let schema = self.schema(); + let Self { + catalog_filter, + db_schema_filter_pattern, + mut catalog_name, + mut db_schema_name, + } = self; + + // Make the arrays + let catalog_name = catalog_name.finish(); + let db_schema_name = db_schema_name.finish(); + + let mut filters = vec![]; + + if let Some(db_schema_filter_pattern) = db_schema_filter_pattern { + // use like kernel to get wildcard matching + let scalar = StringArray::new_scalar(db_schema_filter_pattern); + filters.push(like(&db_schema_name, &scalar)?) + } + + if let Some(catalog_filter_name) = catalog_filter { + let scalar = StringArray::new_scalar(catalog_filter_name); + filters.push(eq(&catalog_name, &scalar)?); + } + + // `AND` any filters together + let mut total_filter = None; + while let Some(filter) = filters.pop() { + let new_filter = match total_filter { + Some(total_filter) => and(&total_filter, &filter)?, + None => filter, + }; + total_filter = Some(new_filter); + } + + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(catalog_name) as ArrayRef, + Arc::new(db_schema_name) as ArrayRef, + ], + )?; + + // Apply the filters if needed + let filtered_batch = if let Some(filter) = total_filter { + filter_record_batch(&batch, &filter)? + } else { + batch + }; + + // Order filtered results by catalog_name, then db_schema_name + let indices = lexsort_to_indices(filtered_batch.columns()); + let columns = filtered_batch + .columns() + .iter() + .map(|c| take(c, &indices, None)) + .collect::, _>>()?; + + Ok(RecordBatch::try_new(filtered_batch.schema(), columns)?) + } + + /// Return the schema of the RecordBatch that will be returned + /// from [`CommandGetDbSchemas`] + pub fn schema(&self) -> SchemaRef { + get_db_schemas_schema() + } +} + +fn get_db_schemas_schema() -> SchemaRef { + Arc::clone(&GET_DB_SCHEMAS_SCHEMA) +} + +/// The schema for GetDbSchemas +static GET_DB_SCHEMAS_SCHEMA: Lazy = Lazy::new(|| { + Arc::new(Schema::new(vec![ + Field::new("catalog_name", DataType::Utf8, false), + Field::new("db_schema_name", DataType::Utf8, false), + ])) +}); + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{StringArray, UInt32Array}; + + fn get_ref_batch() -> RecordBatch { + RecordBatch::try_new( + get_db_schemas_schema(), + vec![ + Arc::new(StringArray::from(vec![ + "a_catalog", + "a_catalog", + "b_catalog", + "b_catalog", + ])) as ArrayRef, + Arc::new(StringArray::from(vec![ + "a_schema", "b_schema", "a_schema", "b_schema", + ])) as ArrayRef, + ], + ) + .unwrap() + } + + #[test] + fn test_schemas_are_filtered() { + let ref_batch = get_ref_batch(); + + let mut builder = GetDbSchemasBuilder::new(None::, None::); + builder.append("a_catalog", "a_schema"); + builder.append("a_catalog", "b_schema"); + builder.append("b_catalog", "a_schema"); + builder.append("b_catalog", "b_schema"); + let schema_batch = builder.build().unwrap(); + + assert_eq!(schema_batch, ref_batch); + + let mut builder = GetDbSchemasBuilder::new(None::, Some("a%")); + builder.append("a_catalog", "a_schema"); + builder.append("a_catalog", "b_schema"); + builder.append("b_catalog", "a_schema"); + builder.append("b_catalog", "b_schema"); + let schema_batch = builder.build().unwrap(); + + let indices = UInt32Array::from(vec![0, 2]); + let ref_filtered = RecordBatch::try_new( + get_db_schemas_schema(), + ref_batch + .columns() + .iter() + .map(|c| take(c, &indices, None)) + .collect::, _>>() + .unwrap(), + ) + .unwrap(); + + assert_eq!(schema_batch, ref_filtered); + } + + #[test] + fn test_schemas_are_sorted() { + let ref_batch = get_ref_batch(); + + let mut builder = GetDbSchemasBuilder::new(None::, None::); + builder.append("a_catalog", "b_schema"); + builder.append("b_catalog", "a_schema"); + builder.append("a_catalog", "a_schema"); + builder.append("b_catalog", "b_schema"); + let schema_batch = builder.build().unwrap(); + + assert_eq!(schema_batch, ref_batch) + } + + #[test] + fn test_builder_from_query() { + let ref_batch = get_ref_batch(); + let query = CommandGetDbSchemas { + catalog: Some("a_catalog".into()), + db_schema_filter_pattern: Some("b%".into()), + }; + + let mut builder = query.into_builder(); + builder.append("a_catalog", "a_schema"); + builder.append("a_catalog", "b_schema"); + builder.append("b_catalog", "a_schema"); + builder.append("b_catalog", "b_schema"); + let schema_batch = builder.build().unwrap(); + + let indices = UInt32Array::from(vec![1]); + let ref_filtered = RecordBatch::try_new( + get_db_schemas_schema(), + ref_batch + .columns() + .iter() + .map(|c| take(c, &indices, None)) + .collect::, _>>() + .unwrap(), + ) + .unwrap(); + + assert_eq!(schema_batch, ref_filtered); + } +} diff --git a/arrow-flight/src/sql/metadata/mod.rs b/arrow-flight/src/sql/metadata/mod.rs new file mode 100644 index 000000000000..1e9881ffa70e --- /dev/null +++ b/arrow-flight/src/sql/metadata/mod.rs @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Builders and function for building responses to FlightSQL metadata +//! / information schema requests. +//! +//! - [`GetCatalogsBuilder`] for building responses to [`CommandGetCatalogs`] queries. +//! - [`GetDbSchemasBuilder`] for building responses to [`CommandGetDbSchemas`] queries. +//! - [`GetTablesBuilder`]for building responses to [`CommandGetTables`] queries. +//! - [`SqlInfoDataBuilder`]for building responses to [`CommandGetSqlInfo`] queries. +//! - [`XdbcTypeInfoDataBuilder`]for building responses to [`CommandGetXdbcTypeInfo`] queries. +//! +//! [`CommandGetCatalogs`]: crate::sql::CommandGetCatalogs +//! [`CommandGetDbSchemas`]: crate::sql::CommandGetDbSchemas +//! [`CommandGetTables`]: crate::sql::CommandGetTables +//! [`CommandGetSqlInfo`]: crate::sql::CommandGetSqlInfo +//! [`CommandGetXdbcTypeInfo`]: crate::sql::CommandGetXdbcTypeInfo + +mod catalogs; +mod db_schemas; +mod sql_info; +mod tables; +mod xdbc_info; + +pub use catalogs::GetCatalogsBuilder; +pub use db_schemas::GetDbSchemasBuilder; +pub use sql_info::{SqlInfoData, SqlInfoDataBuilder}; +pub use tables::GetTablesBuilder; +pub use xdbc_info::{XdbcTypeInfo, XdbcTypeInfoData, XdbcTypeInfoDataBuilder}; + +use arrow_array::ArrayRef; +use arrow_array::UInt32Array; +use arrow_row::RowConverter; +use arrow_row::SortField; + +/// Helper function to sort all the columns in an array +fn lexsort_to_indices(arrays: &[ArrayRef]) -> UInt32Array { + let fields = arrays + .iter() + .map(|a| SortField::new(a.data_type().clone())) + .collect(); + let converter = RowConverter::new(fields).unwrap(); + let rows = converter.convert_columns(arrays).unwrap(); + let mut sort: Vec<_> = rows.iter().enumerate().collect(); + sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b)); + UInt32Array::from_iter_values(sort.iter().map(|(i, _)| *i as u32)) +} + +#[cfg(test)] +mod tests { + use arrow_array::RecordBatch; + use arrow_cast::pretty::pretty_format_batches; + pub fn assert_batches_eq(batches: &[RecordBatch], expected_lines: &[&str]) { + let formatted = pretty_format_batches(batches).unwrap().to_string(); + let actual_lines: Vec<_> = formatted.trim().lines().collect(); + assert_eq!( + &actual_lines, expected_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); + } +} diff --git a/arrow-flight/src/sql/metadata/sql_info.rs b/arrow-flight/src/sql/metadata/sql_info.rs new file mode 100644 index 000000000000..d4584f4a6827 --- /dev/null +++ b/arrow-flight/src/sql/metadata/sql_info.rs @@ -0,0 +1,561 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Helpers for building responses to [`CommandGetSqlInfo`] metadata requests. +//! +//! - [`SqlInfoDataBuilder`] - a builder for collecting sql infos +//! and building a conformant `RecordBatch` with sql info server metadata. +//! - [`SqlInfoData`] - a helper type wrapping a `RecordBatch` +//! used for storing sql info server metadata. +//! - [`GetSqlInfoBuilder`] - a builder for consructing [`CommandGetSqlInfo`] responses. +//! + +use std::collections::{BTreeMap, HashMap}; +use std::sync::Arc; + +use arrow_arith::boolean::or; +use arrow_array::array::{Array, UInt32Array, UnionArray}; +use arrow_array::builder::{ + ArrayBuilder, BooleanBuilder, Int32Builder, Int64Builder, Int8Builder, ListBuilder, MapBuilder, + StringBuilder, UInt32Builder, +}; +use arrow_array::{RecordBatch, Scalar}; +use arrow_data::ArrayData; +use arrow_ord::cmp::eq; +use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef, UnionFields, UnionMode}; +use arrow_select::filter::filter_record_batch; +use once_cell::sync::Lazy; + +use crate::error::Result; +use crate::sql::{CommandGetSqlInfo, SqlInfo}; + +/// Represents a dynamic value +#[derive(Debug, Clone, PartialEq)] +pub enum SqlInfoValue { + String(String), + Bool(bool), + BigInt(i64), + Bitmask(i32), + StringList(Vec), + ListMap(BTreeMap>), +} + +impl From<&str> for SqlInfoValue { + fn from(value: &str) -> Self { + Self::String(value.to_string()) + } +} + +impl From for SqlInfoValue { + fn from(value: bool) -> Self { + Self::Bool(value) + } +} + +impl From for SqlInfoValue { + fn from(value: i32) -> Self { + Self::Bitmask(value) + } +} + +impl From for SqlInfoValue { + fn from(value: i64) -> Self { + Self::BigInt(value) + } +} + +impl From<&[&str]> for SqlInfoValue { + fn from(values: &[&str]) -> Self { + let values = values.iter().map(|s| s.to_string()).collect(); + Self::StringList(values) + } +} + +impl From> for SqlInfoValue { + fn from(values: Vec) -> Self { + Self::StringList(values) + } +} + +impl From>> for SqlInfoValue { + fn from(value: BTreeMap>) -> Self { + Self::ListMap(value) + } +} + +impl From>> for SqlInfoValue { + fn from(value: HashMap>) -> Self { + Self::ListMap(value.into_iter().collect()) + } +} + +impl From<&HashMap>> for SqlInfoValue { + fn from(value: &HashMap>) -> Self { + Self::ListMap( + value + .iter() + .map(|(k, v)| (k.to_owned(), v.to_owned())) + .collect(), + ) + } +} + +/// Something that can be converted into u32 (the represenation of a [`SqlInfo`] name) +pub trait SqlInfoName { + fn as_u32(&self) -> u32; +} + +impl SqlInfoName for SqlInfo { + fn as_u32(&self) -> u32 { + // SqlInfos are u32 in the flight spec, but for some reason + // SqlInfo repr is an i32, so convert between them + u32::try_from(i32::from(*self)).expect("SqlInfo fit into u32") + } +} + +// Allow passing u32 directly into to with_sql_info +impl SqlInfoName for u32 { + fn as_u32(&self) -> u32 { + *self + } +} + +/// Handles creating the dense [`UnionArray`] described by [flightsql] +/// +/// incrementally build types/offset of the dense union. See [Union Spec] for details. +/// +/// ```text +/// * value: dense_union< +/// * string_value: utf8, +/// * bool_value: bool, +/// * bigint_value: int64, +/// * int32_bitmask: int32, +/// * string_list: list +/// * int32_to_int32_list_map: map> +/// * > +/// ``` +///[flightsql]: (https://github.com/apache/arrow/blob/f9324b79bf4fc1ec7e97b32e3cce16e75ef0f5e3/format/FlightSql.proto#L32-L43 +///[Union Spec]: https://arrow.apache.org/docs/format/Columnar.html#dense-union +struct SqlInfoUnionBuilder { + // Values for each child type + string_values: StringBuilder, + bool_values: BooleanBuilder, + bigint_values: Int64Builder, + int32_bitmask_values: Int32Builder, + string_list_values: ListBuilder, + int32_to_int32_list_map_values: MapBuilder>, + type_ids: Int8Builder, + offsets: Int32Builder, +} + +/// [`DataType`] for the output union array +static UNION_TYPE: Lazy = Lazy::new(|| { + let fields = vec![ + Field::new("string_value", DataType::Utf8, false), + Field::new("bool_value", DataType::Boolean, false), + Field::new("bigint_value", DataType::Int64, false), + Field::new("int32_bitmask", DataType::Int32, false), + // treat list as nullable b/c that is what the builders make + Field::new( + "string_list", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + true, + ), + Field::new( + "int32_to_int32_list_map", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("keys", DataType::Int32, false), + Field::new( + "values", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ), + ])), + false, + )), + false, + ), + true, + ), + ]; + + // create "type ids", one for each type, assume they go from 0 .. num_fields + let type_ids: Vec = (0..fields.len()).map(|v| v as i8).collect(); + + DataType::Union(UnionFields::new(type_ids, fields), UnionMode::Dense) +}); + +impl SqlInfoUnionBuilder { + pub fn new() -> Self { + Self { + string_values: StringBuilder::new(), + bool_values: BooleanBuilder::new(), + bigint_values: Int64Builder::new(), + int32_bitmask_values: Int32Builder::new(), + string_list_values: ListBuilder::new(StringBuilder::new()), + int32_to_int32_list_map_values: MapBuilder::new( + None, + Int32Builder::new(), + ListBuilder::new(Int32Builder::new()), + ), + type_ids: Int8Builder::new(), + offsets: Int32Builder::new(), + } + } + + /// Returns the DataType created by this builder + pub fn schema() -> &'static DataType { + &UNION_TYPE + } + + /// Append the specified value to this builder + pub fn append_value(&mut self, v: &SqlInfoValue) -> Result<()> { + // typeid is which child and len is the child array's length + // *after* adding the value + let (type_id, len) = match v { + SqlInfoValue::String(v) => { + self.string_values.append_value(v); + (0, self.string_values.len()) + } + SqlInfoValue::Bool(v) => { + self.bool_values.append_value(*v); + (1, self.bool_values.len()) + } + SqlInfoValue::BigInt(v) => { + self.bigint_values.append_value(*v); + (2, self.bigint_values.len()) + } + SqlInfoValue::Bitmask(v) => { + self.int32_bitmask_values.append_value(*v); + (3, self.int32_bitmask_values.len()) + } + SqlInfoValue::StringList(values) => { + // build list + for v in values { + self.string_list_values.values().append_value(v); + } + // complete the list + self.string_list_values.append(true); + (4, self.string_list_values.len()) + } + SqlInfoValue::ListMap(values) => { + // build map + for (k, v) in values.clone() { + self.int32_to_int32_list_map_values.keys().append_value(k); + self.int32_to_int32_list_map_values + .values() + .append_value(v.into_iter().map(Some)); + } + // complete the list + self.int32_to_int32_list_map_values.append(true)?; + (5, self.int32_to_int32_list_map_values.len()) + } + }; + + self.type_ids.append_value(type_id); + let len = i32::try_from(len).expect("offset fit in i32"); + self.offsets.append_value(len - 1); + Ok(()) + } + + /// Complete the construction and build the [`UnionArray`] + pub fn finish(self) -> UnionArray { + let Self { + mut string_values, + mut bool_values, + mut bigint_values, + mut int32_bitmask_values, + mut string_list_values, + mut int32_to_int32_list_map_values, + mut type_ids, + mut offsets, + } = self; + let type_ids = type_ids.finish(); + let offsets = offsets.finish(); + + // form the correct ArrayData + + let len = offsets.len(); + let null_bit_buffer = None; + let offset = 0; + + let buffers = vec![ + type_ids.into_data().buffers()[0].clone(), + offsets.into_data().buffers()[0].clone(), + ]; + + let child_data = vec![ + string_values.finish().into_data(), + bool_values.finish().into_data(), + bigint_values.finish().into_data(), + int32_bitmask_values.finish().into_data(), + string_list_values.finish().into_data(), + int32_to_int32_list_map_values.finish().into_data(), + ]; + + let data = ArrayData::try_new( + UNION_TYPE.clone(), + len, + null_bit_buffer, + offset, + buffers, + child_data, + ) + .expect("Correctly created UnionArray"); + + UnionArray::from(data) + } +} + +/// Helper to create [`CommandGetSqlInfo`] responses. +/// +/// [`CommandGetSqlInfo`] are metadata requests used by a Flight SQL +/// server to communicate supported capabilities to Flight SQL clients. +/// +/// Servers constuct - usually static - [`SqlInfoData`] via the [`SqlInfoDataBuilder`], +/// and build responses using [`CommandGetSqlInfo::into_builder`] +#[derive(Debug, Clone, PartialEq)] +pub struct SqlInfoDataBuilder { + /// Use BTreeMap to ensure the values are sorted by value as + /// to make output consistent + /// + /// Use u32 to support "custom" sql info values that are not + /// part of the SqlInfo enum + infos: BTreeMap, +} + +impl Default for SqlInfoDataBuilder { + fn default() -> Self { + Self::new() + } +} + +impl SqlInfoDataBuilder { + pub fn new() -> Self { + Self { + infos: BTreeMap::new(), + } + } + + /// register the specific sql metadata item + pub fn append(&mut self, name: impl SqlInfoName, value: impl Into) { + self.infos.insert(name.as_u32(), value.into()); + } + + /// Encode the contents of this list according to the [FlightSQL spec] + /// + /// [FlightSQL spec]: (https://github.com/apache/arrow/blob/f9324b79bf4fc1ec7e97b32e3cce16e75ef0f5e3/format/FlightSql.proto#L32-L43 + pub fn build(self) -> Result { + let mut name_builder = UInt32Builder::new(); + let mut value_builder = SqlInfoUnionBuilder::new(); + + let mut names: Vec<_> = self.infos.keys().cloned().collect(); + names.sort_unstable(); + + for key in names { + let (name, value) = self.infos.get_key_value(&key).unwrap(); + name_builder.append_value(*name); + value_builder.append_value(value)? + } + + let batch = RecordBatch::try_from_iter(vec![ + ("info_name", Arc::new(name_builder.finish()) as _), + ("value", Arc::new(value_builder.finish()) as _), + ])?; + + Ok(SqlInfoData { batch }) + } + + /// Return the [`Schema`] for a GetSchema RPC call with [`crate::sql::CommandGetSqlInfo`] + pub fn schema() -> &'static Schema { + // It is always the same + &SQL_INFO_SCHEMA + } +} + +/// A builder for [`SqlInfoData`] which is used to create [`CommandGetSqlInfo`] responses. +/// +/// # Example +/// ``` +/// # use arrow_flight::sql::{metadata::SqlInfoDataBuilder, SqlInfo, SqlSupportedTransaction}; +/// // Create the list of metadata describing the server +/// let mut builder = SqlInfoDataBuilder::new(); +/// builder.append(SqlInfo::FlightSqlServerName, "server name"); +/// // ... add other SqlInfo here .. +/// builder.append( +/// SqlInfo::FlightSqlServerTransaction, +/// SqlSupportedTransaction::Transaction as i32, +/// ); +/// +/// // Create the batch to send back to the client +/// let info_data = builder.build().unwrap(); +/// ``` +/// +/// [protos]: https://github.com/apache/arrow/blob/6d3d2fca2c9693231fa1e52c142ceef563fc23f9/format/FlightSql.proto#L71-L820 +pub struct SqlInfoData { + batch: RecordBatch, +} + +impl SqlInfoData { + /// Return a [`RecordBatch`] containing only the requested `u32`, if any + /// from [`CommandGetSqlInfo`] + pub fn record_batch(&self, info: impl IntoIterator) -> Result { + let arr = self.batch.column(0); + let type_filter = info + .into_iter() + .map(|tt| { + let s = UInt32Array::from(vec![tt]); + eq(arr, &Scalar::new(&s)) + }) + .collect::, _>>()? + .into_iter() + // We know the arrays are of same length as they are produced from the same root array + .reduce(|filter, arr| or(&filter, &arr).unwrap()); + if let Some(filter) = type_filter { + Ok(filter_record_batch(&self.batch, &filter)?) + } else { + Ok(self.batch.clone()) + } + } + + /// Return the schema of the RecordBatch that will be returned + /// from [`CommandGetSqlInfo`] + pub fn schema(&self) -> SchemaRef { + self.batch.schema() + } +} + +/// A builder for a [`CommandGetSqlInfo`] response. +pub struct GetSqlInfoBuilder<'a> { + /// requested `SqlInfo`s. If empty means return all infos. + info: Vec, + infos: &'a SqlInfoData, +} + +impl CommandGetSqlInfo { + /// Create a builder suitable for constructing a response + pub fn into_builder(self, infos: &SqlInfoData) -> GetSqlInfoBuilder { + GetSqlInfoBuilder { + info: self.info, + infos, + } + } +} + +impl GetSqlInfoBuilder<'_> { + /// Builds a `RecordBatch` with the correct schema for a [`CommandGetSqlInfo`] response + pub fn build(self) -> Result { + self.infos.record_batch(self.info) + } + + /// Return the schema of the RecordBatch that will be returned + /// from [`CommandGetSqlInfo`] + pub fn schema(&self) -> SchemaRef { + self.infos.schema() + } +} + +// The schema produced by [`SqlInfoData`] +static SQL_INFO_SCHEMA: Lazy = Lazy::new(|| { + Schema::new(vec![ + Field::new("info_name", DataType::UInt32, false), + Field::new("value", SqlInfoUnionBuilder::schema().clone(), false), + ]) +}); + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use super::SqlInfoDataBuilder; + use crate::sql::metadata::tests::assert_batches_eq; + use crate::sql::{SqlInfo, SqlNullOrdering, SqlSupportedTransaction, SqlSupportsConvert}; + + #[test] + fn test_sql_infos() { + let mut convert: HashMap> = HashMap::new(); + convert.insert( + SqlSupportsConvert::SqlConvertInteger as i32, + vec![ + SqlSupportsConvert::SqlConvertFloat as i32, + SqlSupportsConvert::SqlConvertReal as i32, + ], + ); + + let mut builder = SqlInfoDataBuilder::new(); + // str + builder.append(SqlInfo::SqlIdentifierQuoteChar, r#"""#); + // bool + builder.append(SqlInfo::SqlDdlCatalog, false); + // i32 + builder.append( + SqlInfo::SqlNullOrdering, + SqlNullOrdering::SqlNullsSortedHigh as i32, + ); + // i64 + builder.append(SqlInfo::SqlMaxBinaryLiteralLength, i32::MAX as i64); + // [str] + builder.append(SqlInfo::SqlKeywords, &["SELECT", "DELETE"] as &[&str]); + builder.append(SqlInfo::SqlSupportsConvert, &convert); + + let batch = builder.build().unwrap().record_batch(None).unwrap(); + + let expected = vec![ + "+-----------+----------------------------------------+", + "| info_name | value |", + "+-----------+----------------------------------------+", + "| 500 | {bool_value=false} |", + "| 504 | {string_value=\"} |", + "| 507 | {int32_bitmask=0} |", + "| 508 | {string_list=[SELECT, DELETE]} |", + "| 517 | {int32_to_int32_list_map={7: [6, 13]}} |", + "| 541 | {bigint_value=2147483647} |", + "+-----------+----------------------------------------+", + ]; + + assert_batches_eq(&[batch], &expected); + } + + #[test] + fn test_filter_sql_infos() { + let mut builder = SqlInfoDataBuilder::new(); + builder.append(SqlInfo::FlightSqlServerName, "server name"); + builder.append( + SqlInfo::FlightSqlServerTransaction, + SqlSupportedTransaction::Transaction as i32, + ); + let data = builder.build().unwrap(); + + let batch = data.record_batch(None).unwrap(); + assert_eq!(batch.num_rows(), 2); + + let batch = data + .record_batch([SqlInfo::FlightSqlServerTransaction as u32]) + .unwrap(); + let mut ref_builder = SqlInfoDataBuilder::new(); + ref_builder.append( + SqlInfo::FlightSqlServerTransaction, + SqlSupportedTransaction::Transaction as i32, + ); + let ref_batch = ref_builder.build().unwrap().record_batch(None).unwrap(); + + assert_eq!(batch, ref_batch); + } +} diff --git a/arrow-flight/src/sql/metadata/tables.rs b/arrow-flight/src/sql/metadata/tables.rs new file mode 100644 index 000000000000..7ffb76fa1d5f --- /dev/null +++ b/arrow-flight/src/sql/metadata/tables.rs @@ -0,0 +1,476 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`GetTablesBuilder`] for building responses to [`CommandGetTables`] queries. +//! +//! [`CommandGetTables`]: crate::sql::CommandGetTables + +use std::sync::Arc; + +use arrow_arith::boolean::{and, or}; +use arrow_array::builder::{BinaryBuilder, StringBuilder}; +use arrow_array::{ArrayRef, RecordBatch, StringArray}; +use arrow_ord::cmp::eq; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use arrow_select::{filter::filter_record_batch, take::take}; +use arrow_string::like::like; +use once_cell::sync::Lazy; + +use super::lexsort_to_indices; +use crate::error::*; +use crate::sql::CommandGetTables; +use crate::{IpcMessage, IpcWriteOptions, SchemaAsIpc}; + +/// A builder for a [`CommandGetTables`] response. +/// +/// Builds rows like this: +/// +/// * catalog_name: utf8, +/// * db_schema_name: utf8, +/// * table_name: utf8 not null, +/// * table_type: utf8 not null, +/// * (optional) table_schema: bytes not null (schema of the table as described +/// in Schema.fbs::Schema it is serialized as an IPC message.) +pub struct GetTablesBuilder { + catalog_filter: Option, + table_types_filter: Vec, + // Optional filters to apply to schemas + db_schema_filter_pattern: Option, + // Optional filters to apply to tables + table_name_filter_pattern: Option, + // array builder for catalog names + catalog_name: StringBuilder, + // array builder for db schema names + db_schema_name: StringBuilder, + // array builder for tables names + table_name: StringBuilder, + // array builder for table types + table_type: StringBuilder, + // array builder for table schemas + table_schema: Option, +} + +impl CommandGetTables { + /// Create a builder suitable for constructing a response + pub fn into_builder(self) -> GetTablesBuilder { + self.into() + } +} + +impl From for GetTablesBuilder { + fn from(value: CommandGetTables) -> Self { + Self::new( + value.catalog, + value.db_schema_filter_pattern, + value.table_name_filter_pattern, + value.table_types, + value.include_schema, + ) + } +} + +impl GetTablesBuilder { + /// Create a new instance of [`GetTablesBuilder`] + /// + /// # Parameters + /// + /// - `catalog`: Specifies the Catalog to search for the tables. + /// - An empty string retrieves those without a catalog. + /// - If omitted the catalog name is not used to narrow the search. + /// - `db_schema_filter_pattern`: Specifies a filter pattern for schemas to search for. + /// When no pattern is provided, the pattern will not be used to narrow the search. + /// In the pattern string, two special characters can be used to denote matching rules: + /// - "%" means to match any substring with 0 or more characters. + /// - "_" means to match any one character. + /// - `table_name_filter_pattern`: Specifies a filter pattern for tables to search for. + /// When no pattern is provided, all tables matching other filters are searched. + /// In the pattern string, two special characters can be used to denote matching rules: + /// - "%" means to match any substring with 0 or more characters. + /// - "_" means to match any one character. + /// - `table_types`: Specifies a filter of table types which must match. + /// An empy Vec matches all table types. + /// - `include_schema`: Specifies if the Arrow schema should be returned for found tables. + /// + /// [`CommandGetTables`]: crate::sql::CommandGetTables + pub fn new( + catalog: Option>, + db_schema_filter_pattern: Option>, + table_name_filter_pattern: Option>, + table_types: impl IntoIterator>, + include_schema: bool, + ) -> Self { + let table_schema = if include_schema { + Some(BinaryBuilder::new()) + } else { + None + }; + Self { + catalog_filter: catalog.map(|s| s.into()), + table_types_filter: table_types.into_iter().map(|tt| tt.into()).collect(), + db_schema_filter_pattern: db_schema_filter_pattern.map(|s| s.into()), + table_name_filter_pattern: table_name_filter_pattern.map(|t| t.into()), + catalog_name: StringBuilder::new(), + db_schema_name: StringBuilder::new(), + table_name: StringBuilder::new(), + table_type: StringBuilder::new(), + table_schema, + } + } + + /// Append a row + pub fn append( + &mut self, + catalog_name: impl AsRef, + schema_name: impl AsRef, + table_name: impl AsRef, + table_type: impl AsRef, + table_schema: &Schema, + ) -> Result<()> { + self.catalog_name.append_value(catalog_name); + self.db_schema_name.append_value(schema_name); + self.table_name.append_value(table_name); + self.table_type.append_value(table_type); + if let Some(self_table_schema) = self.table_schema.as_mut() { + let options = IpcWriteOptions::default(); + // encode the schema into the correct form + let message: std::result::Result = + SchemaAsIpc::new(table_schema, &options).try_into(); + let IpcMessage(schema) = message?; + self_table_schema.append_value(schema); + } + + Ok(()) + } + + /// builds a `RecordBatch` for `CommandGetTables` + pub fn build(self) -> Result { + let schema = self.schema(); + let Self { + catalog_filter, + table_types_filter, + db_schema_filter_pattern, + table_name_filter_pattern, + + mut catalog_name, + mut db_schema_name, + mut table_name, + mut table_type, + table_schema, + } = self; + + // Make the arrays + let catalog_name = catalog_name.finish(); + let db_schema_name = db_schema_name.finish(); + let table_name = table_name.finish(); + let table_type = table_type.finish(); + let table_schema = table_schema.map(|mut table_schema| table_schema.finish()); + + // apply any filters, getting a BooleanArray that represents + // the rows that passed the filter + let mut filters = vec![]; + + if let Some(catalog_filter_name) = catalog_filter { + let scalar = StringArray::new_scalar(catalog_filter_name); + filters.push(eq(&catalog_name, &scalar)?); + } + + let tt_filter = table_types_filter + .into_iter() + .map(|tt| eq(&table_type, &StringArray::new_scalar(tt))) + .collect::, _>>()? + .into_iter() + // We know the arrays are of same length as they are produced fromn the same root array + .reduce(|filter, arr| or(&filter, &arr).unwrap()); + if let Some(filter) = tt_filter { + filters.push(filter); + } + + if let Some(db_schema_filter_pattern) = db_schema_filter_pattern { + // use like kernel to get wildcard matching + let scalar = StringArray::new_scalar(db_schema_filter_pattern); + filters.push(like(&db_schema_name, &scalar)?) + } + + if let Some(table_name_filter_pattern) = table_name_filter_pattern { + // use like kernel to get wildcard matching + let scalar = StringArray::new_scalar(table_name_filter_pattern); + filters.push(like(&table_name, &scalar)?) + } + + let batch = if let Some(table_schema) = table_schema { + RecordBatch::try_new( + schema, + vec![ + Arc::new(catalog_name) as ArrayRef, + Arc::new(db_schema_name) as ArrayRef, + Arc::new(table_name) as ArrayRef, + Arc::new(table_type) as ArrayRef, + Arc::new(table_schema) as ArrayRef, + ], + ) + } else { + RecordBatch::try_new( + // schema is different if table_schema is none + schema, + vec![ + Arc::new(catalog_name) as ArrayRef, + Arc::new(db_schema_name) as ArrayRef, + Arc::new(table_name) as ArrayRef, + Arc::new(table_type) as ArrayRef, + ], + ) + }?; + + // `AND` any filters together + let mut total_filter = None; + while let Some(filter) = filters.pop() { + let new_filter = match total_filter { + Some(total_filter) => and(&total_filter, &filter)?, + None => filter, + }; + total_filter = Some(new_filter); + } + + // Apply the filters if needed + let filtered_batch = if let Some(total_filter) = total_filter { + filter_record_batch(&batch, &total_filter)? + } else { + batch + }; + + // Order filtered results by catalog_name, then db_schema_name, then table_name, then table_type + // https://github.com/apache/arrow/blob/130f9e981aa98c25de5f5bfe55185db270cec313/format/FlightSql.proto#LL1202C1-L1202C1 + let sort_cols = filtered_batch.project(&[0, 1, 2, 3])?; + let indices = lexsort_to_indices(sort_cols.columns()); + let columns = filtered_batch + .columns() + .iter() + .map(|c| take(c, &indices, None)) + .collect::, _>>()?; + + Ok(RecordBatch::try_new(filtered_batch.schema(), columns)?) + } + + /// Return the schema of the RecordBatch that will be returned from [`CommandGetTables`] + /// + /// Note the schema differs based on the values of `include_schema + /// + /// [`CommandGetTables`]: crate::sql::CommandGetTables + pub fn schema(&self) -> SchemaRef { + get_tables_schema(self.include_schema()) + } + + /// Should the "schema" column be included + pub fn include_schema(&self) -> bool { + self.table_schema.is_some() + } +} + +fn get_tables_schema(include_schema: bool) -> SchemaRef { + if include_schema { + Arc::clone(&GET_TABLES_SCHEMA_WITH_TABLE_SCHEMA) + } else { + Arc::clone(&GET_TABLES_SCHEMA_WITHOUT_TABLE_SCHEMA) + } +} + +/// The schema for GetTables without `table_schema` column +static GET_TABLES_SCHEMA_WITHOUT_TABLE_SCHEMA: Lazy = Lazy::new(|| { + Arc::new(Schema::new(vec![ + Field::new("catalog_name", DataType::Utf8, false), + Field::new("db_schema_name", DataType::Utf8, false), + Field::new("table_name", DataType::Utf8, false), + Field::new("table_type", DataType::Utf8, false), + ])) +}); + +/// The schema for GetTables with `table_schema` column +static GET_TABLES_SCHEMA_WITH_TABLE_SCHEMA: Lazy = Lazy::new(|| { + Arc::new(Schema::new(vec![ + Field::new("catalog_name", DataType::Utf8, false), + Field::new("db_schema_name", DataType::Utf8, false), + Field::new("table_name", DataType::Utf8, false), + Field::new("table_type", DataType::Utf8, false), + Field::new("table_schema", DataType::Binary, false), + ])) +}); + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{StringArray, UInt32Array}; + + fn get_ref_batch() -> RecordBatch { + RecordBatch::try_new( + get_tables_schema(false), + vec![ + Arc::new(StringArray::from(vec![ + "a_catalog", + "a_catalog", + "a_catalog", + "a_catalog", + "b_catalog", + "b_catalog", + "b_catalog", + "b_catalog", + ])) as ArrayRef, + Arc::new(StringArray::from(vec![ + "a_schema", "a_schema", "b_schema", "b_schema", "a_schema", "a_schema", + "b_schema", "b_schema", + ])) as ArrayRef, + Arc::new(StringArray::from(vec![ + "a_table", "b_table", "a_table", "b_table", "a_table", "a_table", "b_table", + "b_table", + ])) as ArrayRef, + Arc::new(StringArray::from(vec![ + "TABLE", "TABLE", "TABLE", "TABLE", "TABLE", "VIEW", "TABLE", "VIEW", + ])) as ArrayRef, + ], + ) + .unwrap() + } + + fn get_ref_builder( + catalog: Option<&str>, + db_schema_filter_pattern: Option<&str>, + table_name_filter_pattern: Option<&str>, + table_types: Vec<&str>, + include_schema: bool, + ) -> GetTablesBuilder { + let dummy_schema = Schema::empty(); + let tables = [ + ("a_catalog", "a_schema", "a_table", "TABLE"), + ("a_catalog", "a_schema", "b_table", "TABLE"), + ("a_catalog", "b_schema", "a_table", "TABLE"), + ("a_catalog", "b_schema", "b_table", "TABLE"), + ("b_catalog", "a_schema", "a_table", "TABLE"), + ("b_catalog", "a_schema", "a_table", "VIEW"), + ("b_catalog", "b_schema", "b_table", "TABLE"), + ("b_catalog", "b_schema", "b_table", "VIEW"), + ]; + let mut builder = GetTablesBuilder::new( + catalog, + db_schema_filter_pattern, + table_name_filter_pattern, + table_types, + include_schema, + ); + for (catalog_name, schema_name, table_name, table_type) in tables { + builder + .append( + catalog_name, + schema_name, + table_name, + table_type, + &dummy_schema, + ) + .unwrap(); + } + builder + } + + #[test] + fn test_tables_are_filtered() { + let ref_batch = get_ref_batch(); + + let builder = get_ref_builder(None, None, None, Vec::new(), false); + let table_batch = builder.build().unwrap(); + assert_eq!(table_batch, ref_batch); + + let builder = get_ref_builder(None, Some("a%"), Some("a%"), Vec::new(), false); + let table_batch = builder.build().unwrap(); + let indices = UInt32Array::from(vec![0, 4, 5]); + let ref_filtered = RecordBatch::try_new( + get_tables_schema(false), + ref_batch + .columns() + .iter() + .map(|c| take(c, &indices, None)) + .collect::, _>>() + .unwrap(), + ) + .unwrap(); + assert_eq!(table_batch, ref_filtered); + + let builder = get_ref_builder(Some("a_catalog"), None, None, Vec::new(), false); + let table_batch = builder.build().unwrap(); + let indices = UInt32Array::from(vec![0, 1, 2, 3]); + let ref_filtered = RecordBatch::try_new( + get_tables_schema(false), + ref_batch + .columns() + .iter() + .map(|c| take(c, &indices, None)) + .collect::, _>>() + .unwrap(), + ) + .unwrap(); + assert_eq!(table_batch, ref_filtered); + + let builder = get_ref_builder(None, None, None, vec!["VIEW"], false); + let table_batch = builder.build().unwrap(); + let indices = UInt32Array::from(vec![5, 7]); + let ref_filtered = RecordBatch::try_new( + get_tables_schema(false), + ref_batch + .columns() + .iter() + .map(|c| take(c, &indices, None)) + .collect::, _>>() + .unwrap(), + ) + .unwrap(); + assert_eq!(table_batch, ref_filtered); + } + + #[test] + fn test_tables_are_sorted() { + let ref_batch = get_ref_batch(); + let dummy_schema = Schema::empty(); + + let tables = [ + ("b_catalog", "a_schema", "a_table", "TABLE"), + ("b_catalog", "b_schema", "b_table", "TABLE"), + ("b_catalog", "b_schema", "b_table", "VIEW"), + ("b_catalog", "a_schema", "a_table", "VIEW"), + ("a_catalog", "a_schema", "a_table", "TABLE"), + ("a_catalog", "b_schema", "a_table", "TABLE"), + ("a_catalog", "b_schema", "b_table", "TABLE"), + ("a_catalog", "a_schema", "b_table", "TABLE"), + ]; + let mut builder = GetTablesBuilder::new( + None::, + None::, + None::, + None::, + false, + ); + for (catalog_name, schema_name, table_name, table_type) in tables { + builder + .append( + catalog_name, + schema_name, + table_name, + table_type, + &dummy_schema, + ) + .unwrap(); + } + let table_batch = builder.build().unwrap(); + assert_eq!(table_batch, ref_batch); + } +} diff --git a/arrow-flight/src/sql/metadata/xdbc_info.rs b/arrow-flight/src/sql/metadata/xdbc_info.rs new file mode 100644 index 000000000000..2e635d3037bc --- /dev/null +++ b/arrow-flight/src/sql/metadata/xdbc_info.rs @@ -0,0 +1,428 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Helpers for [`CommandGetXdbcTypeInfo`] metadata requests. +//! +//! - [`XdbcTypeInfo`] - a typed struct that holds the xdbc info corresponding to expected schema. +//! - [`XdbcTypeInfoDataBuilder`] - a builder for collecting type infos +//! and building a conformant `RecordBatch`. +//! - [`XdbcTypeInfoData`] - a helper type wrapping a `RecordBatch` +//! used for storing xdbc server metadata. +//! - [`GetXdbcTypeInfoBuilder`] - a builder for consructing [`CommandGetXdbcTypeInfo`] responses. +//! +use std::sync::Arc; + +use arrow_array::builder::{BooleanBuilder, Int32Builder, ListBuilder, StringBuilder}; +use arrow_array::{ArrayRef, Int32Array, ListArray, RecordBatch, Scalar}; +use arrow_ord::cmp::eq; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use arrow_select::filter::filter_record_batch; +use arrow_select::take::take; +use once_cell::sync::Lazy; + +use super::lexsort_to_indices; +use crate::error::*; +use crate::sql::{CommandGetXdbcTypeInfo, Nullable, Searchable, XdbcDataType, XdbcDatetimeSubcode}; + +/// Data structure representing type information for xdbc types. +#[derive(Debug, Clone, Default)] +pub struct XdbcTypeInfo { + pub type_name: String, + pub data_type: XdbcDataType, + pub column_size: Option, + pub literal_prefix: Option, + pub literal_suffix: Option, + pub create_params: Option>, + pub nullable: Nullable, + pub case_sensitive: bool, + pub searchable: Searchable, + pub unsigned_attribute: Option, + pub fixed_prec_scale: bool, + pub auto_increment: Option, + pub local_type_name: Option, + pub minimum_scale: Option, + pub maximum_scale: Option, + pub sql_data_type: XdbcDataType, + pub datetime_subcode: Option, + pub num_prec_radix: Option, + pub interval_precision: Option, +} + +/// Helper to create [`CommandGetXdbcTypeInfo`] responses. +/// +/// [`CommandGetXdbcTypeInfo`] are metadata requests used by a Flight SQL +/// server to communicate supported capabilities to Flight SQL clients. +/// +/// Servers constuct - usually static - [`XdbcTypeInfoData`] via the [`XdbcTypeInfoDataBuilder`], +/// and build responses using [`CommandGetXdbcTypeInfo::into_builder`]. +pub struct XdbcTypeInfoData { + batch: RecordBatch, +} + +impl XdbcTypeInfoData { + /// Return the raw (not encoded) RecordBatch that will be returned + /// from [`CommandGetXdbcTypeInfo`] + pub fn record_batch(&self, data_type: impl Into>) -> Result { + if let Some(dt) = data_type.into() { + let scalar = Int32Array::from(vec![dt]); + let filter = eq(self.batch.column(1), &Scalar::new(&scalar))?; + Ok(filter_record_batch(&self.batch, &filter)?) + } else { + Ok(self.batch.clone()) + } + } + + /// Return the schema of the RecordBatch that will be returned + /// from [`CommandGetXdbcTypeInfo`] + pub fn schema(&self) -> SchemaRef { + self.batch.schema() + } +} + +pub struct XdbcTypeInfoDataBuilder { + infos: Vec, +} + +impl Default for XdbcTypeInfoDataBuilder { + fn default() -> Self { + Self::new() + } +} + +/// A builder for [`XdbcTypeInfoData`] which is used to create [`CommandGetXdbcTypeInfo`] responses. +/// +/// # Example +/// ``` +/// use arrow_flight::sql::{Nullable, Searchable, XdbcDataType}; +/// use arrow_flight::sql::metadata::{XdbcTypeInfo, XdbcTypeInfoDataBuilder}; +/// // Create the list of metadata describing the server. Since this would not change at +/// // runtime, using once_cell::Lazy or similar patterns to constuct the list is a common approach. +/// let mut builder = XdbcTypeInfoDataBuilder::new(); +/// builder.append(XdbcTypeInfo { +/// type_name: "INTEGER".into(), +/// data_type: XdbcDataType::XdbcInteger, +/// column_size: Some(32), +/// literal_prefix: None, +/// literal_suffix: None, +/// create_params: None, +/// nullable: Nullable::NullabilityNullable, +/// case_sensitive: false, +/// searchable: Searchable::Full, +/// unsigned_attribute: Some(false), +/// fixed_prec_scale: false, +/// auto_increment: Some(false), +/// local_type_name: Some("INTEGER".into()), +/// minimum_scale: None, +/// maximum_scale: None, +/// sql_data_type: XdbcDataType::XdbcInteger, +/// datetime_subcode: None, +/// num_prec_radix: Some(2), +/// interval_precision: None, +/// }); +/// let info_list = builder.build().unwrap(); +/// +/// // to access the underlying record batch +/// let batch = info_list.record_batch(None); +/// ``` +impl XdbcTypeInfoDataBuilder { + /// Create a new instance of [`XdbcTypeInfoDataBuilder`]. + pub fn new() -> Self { + Self { infos: Vec::new() } + } + + /// Append a new row + pub fn append(&mut self, info: XdbcTypeInfo) { + self.infos.push(info); + } + + /// Create helper structure for handling xdbc metadata requests. + pub fn build(self) -> Result { + let mut type_name_builder = StringBuilder::new(); + let mut data_type_builder = Int32Builder::new(); + let mut column_size_builder = Int32Builder::new(); + let mut literal_prefix_builder = StringBuilder::new(); + let mut literal_suffix_builder = StringBuilder::new(); + let mut create_params_builder = ListBuilder::new(StringBuilder::new()); + let mut nullable_builder = Int32Builder::new(); + let mut case_sensitive_builder = BooleanBuilder::new(); + let mut searchable_builder = Int32Builder::new(); + let mut unsigned_attribute_builder = BooleanBuilder::new(); + let mut fixed_prec_scale_builder = BooleanBuilder::new(); + let mut auto_increment_builder = BooleanBuilder::new(); + let mut local_type_name_builder = StringBuilder::new(); + let mut minimum_scale_builder = Int32Builder::new(); + let mut maximum_scale_builder = Int32Builder::new(); + let mut sql_data_type_builder = Int32Builder::new(); + let mut datetime_subcode_builder = Int32Builder::new(); + let mut num_prec_radix_builder = Int32Builder::new(); + let mut interval_precision_builder = Int32Builder::new(); + + self.infos.into_iter().for_each(|info| { + type_name_builder.append_value(info.type_name); + data_type_builder.append_value(info.data_type as i32); + column_size_builder.append_option(info.column_size); + literal_prefix_builder.append_option(info.literal_prefix); + literal_suffix_builder.append_option(info.literal_suffix); + if let Some(params) = info.create_params { + if !params.is_empty() { + for param in params { + create_params_builder.values().append_value(param); + } + create_params_builder.append(true); + } else { + create_params_builder.append_null(); + } + } else { + create_params_builder.append_null(); + } + nullable_builder.append_value(info.nullable as i32); + case_sensitive_builder.append_value(info.case_sensitive); + searchable_builder.append_value(info.searchable as i32); + unsigned_attribute_builder.append_option(info.unsigned_attribute); + fixed_prec_scale_builder.append_value(info.fixed_prec_scale); + auto_increment_builder.append_option(info.auto_increment); + local_type_name_builder.append_option(info.local_type_name); + minimum_scale_builder.append_option(info.minimum_scale); + maximum_scale_builder.append_option(info.maximum_scale); + sql_data_type_builder.append_value(info.sql_data_type as i32); + datetime_subcode_builder.append_option(info.datetime_subcode.map(|code| code as i32)); + num_prec_radix_builder.append_option(info.num_prec_radix); + interval_precision_builder.append_option(info.interval_precision); + }); + + let type_name = Arc::new(type_name_builder.finish()); + let data_type = Arc::new(data_type_builder.finish()); + let column_size = Arc::new(column_size_builder.finish()); + let literal_prefix = Arc::new(literal_prefix_builder.finish()); + let literal_suffix = Arc::new(literal_suffix_builder.finish()); + let (field, offsets, values, nulls) = create_params_builder.finish().into_parts(); + // Re-defined the field to be non-nullable + let new_field = Arc::new(field.as_ref().clone().with_nullable(false)); + let create_params = Arc::new(ListArray::new(new_field, offsets, values, nulls)) as ArrayRef; + let nullable = Arc::new(nullable_builder.finish()); + let case_sensitive = Arc::new(case_sensitive_builder.finish()); + let searchable = Arc::new(searchable_builder.finish()); + let unsigned_attribute = Arc::new(unsigned_attribute_builder.finish()); + let fixed_prec_scale = Arc::new(fixed_prec_scale_builder.finish()); + let auto_increment = Arc::new(auto_increment_builder.finish()); + let local_type_name = Arc::new(local_type_name_builder.finish()); + let minimum_scale = Arc::new(minimum_scale_builder.finish()); + let maximum_scale = Arc::new(maximum_scale_builder.finish()); + let sql_data_type = Arc::new(sql_data_type_builder.finish()); + let datetime_subcode = Arc::new(datetime_subcode_builder.finish()); + let num_prec_radix = Arc::new(num_prec_radix_builder.finish()); + let interval_precision = Arc::new(interval_precision_builder.finish()); + + let batch = RecordBatch::try_new( + Arc::clone(&GET_XDBC_INFO_SCHEMA), + vec![ + type_name, + data_type, + column_size, + literal_prefix, + literal_suffix, + create_params, + nullable, + case_sensitive, + searchable, + unsigned_attribute, + fixed_prec_scale, + auto_increment, + local_type_name, + minimum_scale, + maximum_scale, + sql_data_type, + datetime_subcode, + num_prec_radix, + interval_precision, + ], + )?; + + // Order batch by data_type and then by type_name + let sort_cols = batch.project(&[1, 0])?; + let indices = lexsort_to_indices(sort_cols.columns()); + let columns = batch + .columns() + .iter() + .map(|c| take(c, &indices, None)) + .collect::, _>>()?; + + Ok(XdbcTypeInfoData { + batch: RecordBatch::try_new(batch.schema(), columns)?, + }) + } + + /// Return the [`Schema`] for a GetSchema RPC call with [`CommandGetXdbcTypeInfo`] + pub fn schema(&self) -> SchemaRef { + Arc::clone(&GET_XDBC_INFO_SCHEMA) + } +} + +/// A builder for a [`CommandGetXdbcTypeInfo`] response. +pub struct GetXdbcTypeInfoBuilder<'a> { + data_type: Option, + infos: &'a XdbcTypeInfoData, +} + +impl CommandGetXdbcTypeInfo { + /// Create a builder suitable for constructing a response + pub fn into_builder(self, infos: &XdbcTypeInfoData) -> GetXdbcTypeInfoBuilder { + GetXdbcTypeInfoBuilder { + data_type: self.data_type, + infos, + } + } +} + +impl GetXdbcTypeInfoBuilder<'_> { + /// Builds a `RecordBatch` with the correct schema for a [`CommandGetXdbcTypeInfo`] response + pub fn build(self) -> Result { + self.infos.record_batch(self.data_type) + } + + /// Return the schema of the RecordBatch that will be returned + /// from [`CommandGetXdbcTypeInfo`] + pub fn schema(&self) -> SchemaRef { + self.infos.schema() + } +} + +/// The schema for GetXdbcTypeInfo +static GET_XDBC_INFO_SCHEMA: Lazy = Lazy::new(|| { + Arc::new(Schema::new(vec![ + Field::new("type_name", DataType::Utf8, false), + Field::new("data_type", DataType::Int32, false), + Field::new("column_size", DataType::Int32, true), + Field::new("literal_prefix", DataType::Utf8, true), + Field::new("literal_suffix", DataType::Utf8, true), + Field::new( + "create_params", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, false))), + true, + ), + Field::new("nullable", DataType::Int32, false), + Field::new("case_sensitive", DataType::Boolean, false), + Field::new("searchable", DataType::Int32, false), + Field::new("unsigned_attribute", DataType::Boolean, true), + Field::new("fixed_prec_scale", DataType::Boolean, false), + Field::new("auto_increment", DataType::Boolean, true), + Field::new("local_type_name", DataType::Utf8, true), + Field::new("minimum_scale", DataType::Int32, true), + Field::new("maximum_scale", DataType::Int32, true), + Field::new("sql_data_type", DataType::Int32, false), + Field::new("datetime_subcode", DataType::Int32, true), + Field::new("num_prec_radix", DataType::Int32, true), + Field::new("interval_precision", DataType::Int32, true), + ])) +}); + +#[cfg(test)] +mod tests { + use super::*; + use crate::sql::metadata::tests::assert_batches_eq; + + #[test] + fn test_create_batch() { + let mut builder = XdbcTypeInfoDataBuilder::new(); + builder.append(XdbcTypeInfo { + type_name: "VARCHAR".into(), + data_type: XdbcDataType::XdbcVarchar, + column_size: Some(i32::MAX), + literal_prefix: Some("'".into()), + literal_suffix: Some("'".into()), + create_params: Some(vec!["length".into()]), + nullable: Nullable::NullabilityNullable, + case_sensitive: true, + searchable: Searchable::Full, + unsigned_attribute: None, + fixed_prec_scale: false, + auto_increment: None, + local_type_name: Some("VARCHAR".into()), + minimum_scale: None, + maximum_scale: None, + sql_data_type: XdbcDataType::XdbcVarchar, + datetime_subcode: None, + num_prec_radix: None, + interval_precision: None, + }); + builder.append(XdbcTypeInfo { + type_name: "INTEGER".into(), + data_type: XdbcDataType::XdbcInteger, + column_size: Some(32), + literal_prefix: None, + literal_suffix: None, + create_params: None, + nullable: Nullable::NullabilityNullable, + case_sensitive: false, + searchable: Searchable::Full, + unsigned_attribute: Some(false), + fixed_prec_scale: false, + auto_increment: Some(false), + local_type_name: Some("INTEGER".into()), + minimum_scale: None, + maximum_scale: None, + sql_data_type: XdbcDataType::XdbcInteger, + datetime_subcode: None, + num_prec_radix: Some(2), + interval_precision: None, + }); + builder.append(XdbcTypeInfo { + type_name: "INTERVAL".into(), + data_type: XdbcDataType::XdbcInterval, + column_size: Some(i32::MAX), + literal_prefix: Some("'".into()), + literal_suffix: Some("'".into()), + create_params: None, + nullable: Nullable::NullabilityNullable, + case_sensitive: false, + searchable: Searchable::Full, + unsigned_attribute: None, + fixed_prec_scale: false, + auto_increment: None, + local_type_name: Some("INTERVAL".into()), + minimum_scale: None, + maximum_scale: None, + sql_data_type: XdbcDataType::XdbcInterval, + datetime_subcode: Some(XdbcDatetimeSubcode::XdbcSubcodeUnknown), + num_prec_radix: None, + interval_precision: None, + }); + let infos = builder.build().unwrap(); + + let batch = infos.record_batch(None).unwrap(); + let expected = vec![ + "+-----------+-----------+-------------+----------------+----------------+---------------+----------+----------------+------------+--------------------+------------------+----------------+-----------------+---------------+---------------+---------------+------------------+----------------+--------------------+", + "| type_name | data_type | column_size | literal_prefix | literal_suffix | create_params | nullable | case_sensitive | searchable | unsigned_attribute | fixed_prec_scale | auto_increment | local_type_name | minimum_scale | maximum_scale | sql_data_type | datetime_subcode | num_prec_radix | interval_precision |", + "+-----------+-----------+-------------+----------------+----------------+---------------+----------+----------------+------------+--------------------+------------------+----------------+-----------------+---------------+---------------+---------------+------------------+----------------+--------------------+", + "| INTEGER | 4 | 32 | | | | 1 | false | 3 | false | false | false | INTEGER | | | 4 | | 2 | |", + "| INTERVAL | 10 | 2147483647 | ' | ' | | 1 | false | 3 | | false | | INTERVAL | | | 10 | 0 | | |", + "| VARCHAR | 12 | 2147483647 | ' | ' | [length] | 1 | true | 3 | | false | | VARCHAR | | | 12 | | | |", + "+-----------+-----------+-------------+----------------+----------------+---------------+----------+----------------+------------+--------------------+------------------+----------------+-----------------+---------------+---------------+---------------+------------------+----------------+--------------------+", + ]; + assert_batches_eq(&[batch], &expected); + + let batch = infos.record_batch(Some(10)).unwrap(); + let expected = vec![ + "+-----------+-----------+-------------+----------------+----------------+---------------+----------+----------------+------------+--------------------+------------------+----------------+-----------------+---------------+---------------+---------------+------------------+----------------+--------------------+", + "| type_name | data_type | column_size | literal_prefix | literal_suffix | create_params | nullable | case_sensitive | searchable | unsigned_attribute | fixed_prec_scale | auto_increment | local_type_name | minimum_scale | maximum_scale | sql_data_type | datetime_subcode | num_prec_radix | interval_precision |", + "+-----------+-----------+-------------+----------------+----------------+---------------+----------+----------------+------------+--------------------+------------------+----------------+-----------------+---------------+---------------+---------------+------------------+----------------+--------------------+", + "| INTERVAL | 10 | 2147483647 | ' | ' | | 1 | false | 3 | | false | | INTERVAL | | | 10 | 0 | | |", + "+-----------+-----------+-------------+----------------+----------------+---------------+----------+----------------+------------+--------------------+------------------+----------------+-----------------+---------------+---------------+---------------+------------------+----------------+--------------------+", + ]; + assert_batches_eq(&[batch], &expected); + } +} diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs index cd198a1401d1..97645ae7840d 100644 --- a/arrow-flight/src/sql/mod.rs +++ b/arrow-flight/src/sql/mod.rs @@ -15,7 +15,32 @@ // specific language governing permissions and limitations // under the License. -use arrow::error::{ArrowError, Result as ArrowResult}; +//! Support for execute SQL queries using [Apache Arrow] [Flight SQL]. +//! +//! [Flight SQL] is built on top of Arrow Flight RPC framework, by +//! defining specific messages, encoded using the protobuf format, +//! sent in the[`FlightDescriptor::cmd`] field to [`FlightService`] +//! endpoints such as[`get_flight_info`] and [`do_get`]. +//! +//! This module contains: +//! 1. [prost] generated structs for FlightSQL messages such as [`CommandStatementQuery`] +//! 2. Helpers for encoding and decoding FlightSQL messages: [`Any`] and [`Command`] +//! 3. A [`FlightSqlServiceClient`] for interacting with FlightSQL servers. +//! 4. A [`FlightSqlService`] to help building FlightSQL servers from [`FlightService`]. +//! 5. Helpers to build responses for FlightSQL metadata APIs: [`metadata`] +//! +//! [Flight SQL]: https://arrow.apache.org/docs/format/FlightSql.html +//! [Apache Arrow]: https://arrow.apache.org +//! [`FlightDescriptor::cmd`]: crate::FlightDescriptor::cmd +//! [`FlightService`]: crate::flight_service_server::FlightService +//! [`get_flight_info`]: crate::flight_service_server::FlightService::get_flight_info +//! [`do_get`]: crate::flight_service_server::FlightService::do_get +//! [`FlightSqlServiceClient`]: client::FlightSqlServiceClient +//! [`FlightSqlService`]: server::FlightSqlService +//! [`metadata`]: crate::sql::metadata +use arrow_schema::ArrowError; +use bytes::Bytes; +use paste::paste; use prost::Message; mod gen { @@ -23,9 +48,18 @@ mod gen { include!("arrow.flight.protocol.sql.rs"); } +pub use gen::ActionBeginSavepointRequest; +pub use gen::ActionBeginSavepointResult; +pub use gen::ActionBeginTransactionRequest; +pub use gen::ActionBeginTransactionResult; +pub use gen::ActionCancelQueryRequest; +pub use gen::ActionCancelQueryResult; pub use gen::ActionClosePreparedStatementRequest; pub use gen::ActionCreatePreparedStatementRequest; pub use gen::ActionCreatePreparedStatementResult; +pub use gen::ActionCreatePreparedSubstraitPlanRequest; +pub use gen::ActionEndSavepointRequest; +pub use gen::ActionEndTransactionRequest; pub use gen::CommandGetCatalogs; pub use gen::CommandGetCrossReference; pub use gen::CommandGetDbSchemas; @@ -35,11 +69,15 @@ pub use gen::CommandGetPrimaryKeys; pub use gen::CommandGetSqlInfo; pub use gen::CommandGetTableTypes; pub use gen::CommandGetTables; +pub use gen::CommandGetXdbcTypeInfo; pub use gen::CommandPreparedStatementQuery; pub use gen::CommandPreparedStatementUpdate; pub use gen::CommandStatementQuery; +pub use gen::CommandStatementSubstraitPlan; pub use gen::CommandStatementUpdate; pub use gen::DoPutUpdateResult; +pub use gen::Nullable; +pub use gen::Searchable; pub use gen::SqlInfo; pub use gen::SqlNullOrdering; pub use gen::SqlOuterJoinsSupportLevel; @@ -50,14 +88,20 @@ pub use gen::SqlSupportedPositionedCommands; pub use gen::SqlSupportedResultSetConcurrency; pub use gen::SqlSupportedResultSetType; pub use gen::SqlSupportedSubqueries; +pub use gen::SqlSupportedTransaction; pub use gen::SqlSupportedTransactions; pub use gen::SqlSupportedUnions; pub use gen::SqlSupportsConvert; pub use gen::SqlTransactionIsolationLevel; +pub use gen::SubstraitPlan; pub use gen::SupportedSqlGrammar; pub use gen::TicketStatementQuery; pub use gen::UpdateDeleteRules; +pub use gen::XdbcDataType; +pub use gen::XdbcDatetimeSubcode; +pub mod client; +pub mod metadata; pub mod server; /// ProstMessageExt are useful utility methods for prost::Message types @@ -65,34 +109,132 @@ pub trait ProstMessageExt: prost::Message + Default { /// type_url for this Message fn type_url() -> &'static str; - /// Convert this Message to prost_types::Any - fn as_any(&self) -> prost_types::Any; + /// Convert this Message to [`Any`] + fn as_any(&self) -> Any; +} + +/// Macro to coerce a token to an item, specifically +/// to build the `Commands` enum. +/// +/// See: +macro_rules! as_item { + ($i:item) => { + $i + }; } macro_rules! prost_message_ext { - ($($name:ty,)*) => { - $( - impl ProstMessageExt for $name { - fn type_url() -> &'static str { - concat!("type.googleapis.com/arrow.flight.protocol.sql.", stringify!($name)) + ($($name:tt,)*) => { + paste! { + $( + const [<$name:snake:upper _TYPE_URL>]: &'static str = concat!("type.googleapis.com/arrow.flight.protocol.sql.", stringify!($name)); + )* + + as_item! { + /// Helper to convert to/from protobuf [`Any`] message + /// to a specific FlightSQL command message. + /// + /// # Example + /// ```rust + /// # use arrow_flight::sql::{Any, CommandStatementQuery, Command}; + /// let flightsql_message = CommandStatementQuery { + /// query: "SELECT * FROM foo".to_string(), + /// transaction_id: None, + /// }; + /// + /// // Given a packed FlightSQL Any message + /// let any_message = Any::pack(&flightsql_message).unwrap(); + /// + /// // decode it to Command: + /// match Command::try_from(any_message).unwrap() { + /// Command::CommandStatementQuery(decoded) => { + /// assert_eq!(flightsql_message, decoded); + /// } + /// _ => panic!("Unexpected decoded message"), + /// } + /// ``` + #[derive(Clone, Debug, PartialEq)] + pub enum Command { + $($name($name),)* + + /// Any message that is not any FlightSQL command. + Unknown(Any), + } + } + + impl Command { + /// Convert the command to [`Any`]. + pub fn into_any(self) -> Any { + match self { + $( + Self::$name(cmd) => cmd.as_any(), + )* + Self::Unknown(any) => any, + } + } + + /// Get the URL for the command. + pub fn type_url(&self) -> &str { + match self { + $( + Self::$name(_) => [<$name:snake:upper _TYPE_URL>], + )* + Self::Unknown(any) => any.type_url.as_str(), + } } + } + + impl TryFrom for Command { + type Error = ArrowError; - fn as_any(&self) -> prost_types::Any { - prost_types::Any { - type_url: <$name>::type_url().to_string(), - value: self.encode_to_vec(), + fn try_from(any: Any) -> Result { + match any.type_url.as_str() { + $( + [<$name:snake:upper _TYPE_URL>] + => { + let m: $name = Message::decode(&*any.value).map_err(|err| { + ArrowError::ParseError(format!("Unable to decode Any value: {err}")) + })?; + Ok(Self::$name(m)) + } + )* + _ => Ok(Self::Unknown(any)), } } } - )* + + $( + impl ProstMessageExt for $name { + fn type_url() -> &'static str { + [<$name:snake:upper _TYPE_URL>] + } + + fn as_any(&self) -> Any { + Any { + type_url: <$name>::type_url().to_string(), + value: self.encode_to_vec().into(), + } + } + } + )* + } }; } // Implement ProstMessageExt for all structs defined in FlightSql.proto prost_message_ext!( + ActionBeginSavepointRequest, + ActionBeginSavepointResult, + ActionBeginTransactionRequest, + ActionBeginTransactionResult, + ActionCancelQueryRequest, + ActionCancelQueryResult, ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, + ActionCreatePreparedSubstraitPlanRequest, + ActionEndSavepointRequest, + ActionEndTransactionRequest, CommandGetCatalogs, CommandGetCrossReference, CommandGetDbSchemas, @@ -102,48 +244,63 @@ prost_message_ext!( CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, + CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, + CommandStatementSubstraitPlan, CommandStatementUpdate, DoPutUpdateResult, TicketStatementQuery, ); -/// ProstAnyExt are useful utility methods for prost_types::Any -/// The API design is inspired by [rust-protobuf](https://github.com/stepancheg/rust-protobuf/blob/master/protobuf/src/well_known_types_util/any.rs) -pub trait ProstAnyExt { - /// Check if `Any` contains a message of given type. - fn is(&self) -> bool; - - /// Extract a message from this `Any`. - /// - /// # Returns - /// - /// * `Ok(None)` when message type mismatch - /// * `Err` when parse failed - fn unpack(&self) -> ArrowResult>; - - /// Pack any message into `prost_types::Any` value. - fn pack(message: &M) -> ArrowResult; +/// An implementation of the protobuf [`Any`] message type +/// +/// Encoded protobuf messages are not self-describing, nor contain any information +/// on the schema of the encoded payload. Consequently to decode a protobuf a client +/// must know the exact schema of the message. +/// +/// This presents a problem for loosely typed APIs, where the exact message payloads +/// are not enumerable, and therefore cannot be enumerated as variants in a [oneof]. +/// +/// One solution is [`Any`] where the encoded payload is paired with a `type_url` +/// identifying the type of encoded message, and the resulting combination encoded. +/// +/// Clients can then decode the outer [`Any`], inspect the `type_url` and if it is +/// a type they recognise, proceed to decode the embedded message `value` +/// +/// [`Any`]: https://developers.google.com/protocol-buffers/docs/proto3#any +/// [oneof]: https://developers.google.com/protocol-buffers/docs/proto3#oneof +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Any { + /// A URL/resource name that uniquely identifies the type of the serialized + /// protocol buffer message. This string must contain at least + /// one "/" character. The last segment of the URL's path must represent + /// the fully qualified name of the type (as in + /// `path/google.protobuf.Duration`). The name should be in a canonical form + /// (e.g., leading "." is not accepted). + #[prost(string, tag = "1")] + pub type_url: String, + /// Must be a valid serialized protocol buffer of the above specified type. + #[prost(bytes = "bytes", tag = "2")] + pub value: Bytes, } -impl ProstAnyExt for prost_types::Any { - fn is(&self) -> bool { +impl Any { + pub fn is(&self) -> bool { M::type_url() == self.type_url } - fn unpack(&self) -> ArrowResult> { + pub fn unpack(&self) -> Result, ArrowError> { if !self.is::() { return Ok(None); } - let m = prost::Message::decode(&*self.value).map_err(|err| { - ArrowError::ParseError(format!("Unable to decode Any value: {}", err)) - })?; + let m = Message::decode(&*self.value) + .map_err(|err| ArrowError::ParseError(format!("Unable to decode Any value: {err}")))?; Ok(Some(m)) } - fn pack(message: &M) -> ArrowResult { + pub fn pack(message: &M) -> Result { Ok(message.as_any()) } } @@ -165,14 +322,38 @@ mod tests { } #[test] - fn test_prost_any_pack_unpack() -> ArrowResult<()> { + fn test_prost_any_pack_unpack() { let query = CommandStatementQuery { query: "select 1".to_string(), + transaction_id: None, }; - let any = prost_types::Any::pack(&query)?; + let any = Any::pack(&query).unwrap(); assert!(any.is::()); - let unpack_query: CommandStatementQuery = any.unpack()?.unwrap(); + let unpack_query: CommandStatementQuery = any.unpack().unwrap().unwrap(); assert_eq!(query, unpack_query); - Ok(()) + } + + #[test] + fn test_command() { + let query = CommandStatementQuery { + query: "select 1".to_string(), + transaction_id: None, + }; + let any = Any::pack(&query).unwrap(); + let cmd: Command = any.try_into().unwrap(); + + assert!(matches!(cmd, Command::CommandStatementQuery(_))); + assert_eq!(cmd.type_url(), COMMAND_STATEMENT_QUERY_TYPE_URL); + + // Unknown variant + + let any = Any { + type_url: "fake_url".to_string(), + value: Default::default(), + }; + + let cmd: Command = any.try_into().unwrap(); + assert!(matches!(cmd, Command::Unknown(_))); + assert_eq!(cmd.type_url(), "fake_url"); } } diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index f3208d376497..f1656aca882a 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -15,35 +15,44 @@ // specific language governing permissions and limitations // under the License. +//! Helper trait [`FlightSqlService`] for implementing a [`FlightService`] that implements FlightSQL. + use std::pin::Pin; -use futures::Stream; +use futures::{stream::Peekable, Stream, StreamExt}; use prost::Message; use tonic::{Request, Response, Status, Streaming}; use super::{ - super::{ - flight_service_server::FlightService, Action, ActionType, Criteria, Empty, - FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, - PutResult, SchemaResult, Ticket, - }, + ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest, + ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult, ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, - ActionCreatePreparedStatementResult, CommandGetCatalogs, CommandGetCrossReference, - CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, + ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest, + ActionEndSavepointRequest, ActionEndTransactionRequest, Any, Command, CommandGetCatalogs, + CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, - CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, - CommandStatementUpdate, DoPutUpdateResult, ProstAnyExt, ProstMessageExt, SqlInfo, - TicketStatementQuery, + CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, + CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, + DoPutUpdateResult, ProstMessageExt, SqlInfo, TicketStatementQuery, +}; +use crate::{ + flight_service_server::FlightService, Action, ActionType, Criteria, Empty, FlightData, + FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, + Ticket, }; -static CREATE_PREPARED_STATEMENT: &str = "CreatePreparedStatement"; -static CLOSE_PREPARED_STATEMENT: &str = "ClosePreparedStatement"; +pub(crate) static CREATE_PREPARED_STATEMENT: &str = "CreatePreparedStatement"; +pub(crate) static CLOSE_PREPARED_STATEMENT: &str = "ClosePreparedStatement"; +pub(crate) static CREATE_PREPARED_SUBSTRAIT_PLAN: &str = "CreatePreparedSubstraitPlan"; +pub(crate) static BEGIN_TRANSACTION: &str = "BeginTransaction"; +pub(crate) static END_TRANSACTION: &str = "EndTransaction"; +pub(crate) static BEGIN_SAVEPOINT: &str = "BeginSavepoint"; +pub(crate) static END_SAVEPOINT: &str = "EndSavepoint"; +pub(crate) static CANCEL_QUERY: &str = "CancelQuery"; /// Implements FlightSqlService to handle the flight sql protocol #[tonic::async_trait] -pub trait FlightSqlService: - std::marker::Sync + std::marker::Send + std::marker::Sized + 'static -{ +pub trait FlightSqlService: Sync + Send + Sized + 'static { /// When impl FlightSqlService, you can always set FlightService to Self type FlightService: FlightService; @@ -65,7 +74,7 @@ pub trait FlightSqlService: async fn do_get_fallback( &self, _request: Request, - message: prost_types::Any, + message: Any, ) -> Result::DoGetStream>, Status> { Err(Status::unimplemented(format!( "do_get: The defined request is invalid: {}", @@ -76,197 +85,465 @@ pub trait FlightSqlService: /// Get a FlightInfo for executing a SQL query. async fn get_flight_info_statement( &self, - query: CommandStatementQuery, - request: Request, - ) -> Result, Status>; + _query: CommandStatementQuery, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_statement has no default implementation", + )) + } + + /// Get a FlightInfo for executing a substrait plan. + async fn get_flight_info_substrait_plan( + &self, + _query: CommandStatementSubstraitPlan, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_substrait_plan has no default implementation", + )) + } /// Get a FlightInfo for executing an already created prepared statement. async fn get_flight_info_prepared_statement( &self, - query: CommandPreparedStatementQuery, - request: Request, - ) -> Result, Status>; + _query: CommandPreparedStatementQuery, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_prepared_statement has no default implementation", + )) + } /// Get a FlightInfo for listing catalogs. async fn get_flight_info_catalogs( &self, - query: CommandGetCatalogs, - request: Request, - ) -> Result, Status>; + _query: CommandGetCatalogs, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_catalogs has no default implementation", + )) + } /// Get a FlightInfo for listing schemas. async fn get_flight_info_schemas( &self, - query: CommandGetDbSchemas, - request: Request, - ) -> Result, Status>; + _query: CommandGetDbSchemas, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_schemas has no default implementation", + )) + } /// Get a FlightInfo for listing tables. async fn get_flight_info_tables( &self, - query: CommandGetTables, - request: Request, - ) -> Result, Status>; + _query: CommandGetTables, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_tables has no default implementation", + )) + } /// Get a FlightInfo to extract information about the table types. async fn get_flight_info_table_types( &self, - query: CommandGetTableTypes, - request: Request, - ) -> Result, Status>; + _query: CommandGetTableTypes, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_table_types has no default implementation", + )) + } /// Get a FlightInfo for retrieving other information (See SqlInfo). async fn get_flight_info_sql_info( &self, - query: CommandGetSqlInfo, - request: Request, - ) -> Result, Status>; + _query: CommandGetSqlInfo, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_sql_info has no default implementation", + )) + } /// Get a FlightInfo to extract information about primary and foreign keys. async fn get_flight_info_primary_keys( &self, - query: CommandGetPrimaryKeys, - request: Request, - ) -> Result, Status>; + _query: CommandGetPrimaryKeys, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_primary_keys has no default implementation", + )) + } /// Get a FlightInfo to extract information about exported keys. async fn get_flight_info_exported_keys( &self, - query: CommandGetExportedKeys, - request: Request, - ) -> Result, Status>; + _query: CommandGetExportedKeys, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_exported_keys has no default implementation", + )) + } /// Get a FlightInfo to extract information about imported keys. async fn get_flight_info_imported_keys( &self, - query: CommandGetImportedKeys, - request: Request, - ) -> Result, Status>; + _query: CommandGetImportedKeys, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_imported_keys has no default implementation", + )) + } /// Get a FlightInfo to extract information about cross reference. async fn get_flight_info_cross_reference( &self, - query: CommandGetCrossReference, - request: Request, - ) -> Result, Status>; + _query: CommandGetCrossReference, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_cross_reference has no default implementation", + )) + } + + /// Get a FlightInfo to extract information about the supported XDBC types. + async fn get_flight_info_xdbc_type_info( + &self, + _query: CommandGetXdbcTypeInfo, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_xdbc_type_info has no default implementation", + )) + } + + /// Implementors may override to handle additional calls to get_flight_info() + async fn get_flight_info_fallback( + &self, + cmd: Command, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented(format!( + "get_flight_info: The defined request is invalid: {}", + cmd.type_url() + ))) + } // do_get /// Get a FlightDataStream containing the query results. async fn do_get_statement( &self, - ticket: TicketStatementQuery, - request: Request, - ) -> Result::DoGetStream>, Status>; + _ticket: TicketStatementQuery, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_statement has no default implementation", + )) + } /// Get a FlightDataStream containing the prepared statement query results. async fn do_get_prepared_statement( &self, - query: CommandPreparedStatementQuery, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandPreparedStatementQuery, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_prepared_statement has no default implementation", + )) + } /// Get a FlightDataStream containing the list of catalogs. async fn do_get_catalogs( &self, - query: CommandGetCatalogs, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetCatalogs, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_catalogs has no default implementation", + )) + } /// Get a FlightDataStream containing the list of schemas. async fn do_get_schemas( &self, - query: CommandGetDbSchemas, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetDbSchemas, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_schemas has no default implementation", + )) + } /// Get a FlightDataStream containing the list of tables. async fn do_get_tables( &self, - query: CommandGetTables, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetTables, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_tables has no default implementation", + )) + } /// Get a FlightDataStream containing the data related to the table types. async fn do_get_table_types( &self, - query: CommandGetTableTypes, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetTableTypes, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_table_types has no default implementation", + )) + } /// Get a FlightDataStream containing the list of SqlInfo results. async fn do_get_sql_info( &self, - query: CommandGetSqlInfo, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetSqlInfo, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_sql_info has no default implementation", + )) + } /// Get a FlightDataStream containing the data related to the primary and foreign keys. async fn do_get_primary_keys( &self, - query: CommandGetPrimaryKeys, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetPrimaryKeys, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_primary_keys has no default implementation", + )) + } /// Get a FlightDataStream containing the data related to the exported keys. async fn do_get_exported_keys( &self, - query: CommandGetExportedKeys, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetExportedKeys, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_exported_keys has no default implementation", + )) + } /// Get a FlightDataStream containing the data related to the imported keys. async fn do_get_imported_keys( &self, - query: CommandGetImportedKeys, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetImportedKeys, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_imported_keys has no default implementation", + )) + } /// Get a FlightDataStream containing the data related to the cross reference. async fn do_get_cross_reference( &self, - query: CommandGetCrossReference, - request: Request, - ) -> Result::DoGetStream>, Status>; + _query: CommandGetCrossReference, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_cross_reference has no default implementation", + )) + } + + /// Get a FlightDataStream containing the data related to the supported XDBC types. + async fn do_get_xdbc_type_info( + &self, + _query: CommandGetXdbcTypeInfo, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_xdbc_type_info has no default implementation", + )) + } // do_put + /// Implementors may override to handle additional calls to do_put() + async fn do_put_fallback( + &self, + _request: Request, + message: Any, + ) -> Result::DoPutStream>, Status> { + Err(Status::unimplemented(format!( + "do_put: The defined request is invalid: {}", + message.type_url + ))) + } + /// Execute an update SQL statement. async fn do_put_statement_update( &self, - ticket: CommandStatementUpdate, - request: Request>, - ) -> Result; + _ticket: CommandStatementUpdate, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_put_statement_update has no default implementation", + )) + } /// Bind parameters to given prepared statement. async fn do_put_prepared_statement_query( &self, - query: CommandPreparedStatementQuery, - request: Request>, - ) -> Result::DoPutStream>, Status>; + _query: CommandPreparedStatementQuery, + _request: Request, + ) -> Result::DoPutStream>, Status> { + Err(Status::unimplemented( + "do_put_prepared_statement_query has no default implementation", + )) + } /// Execute an update SQL prepared statement. async fn do_put_prepared_statement_update( &self, - query: CommandPreparedStatementUpdate, - request: Request>, - ) -> Result; + _query: CommandPreparedStatementUpdate, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_put_prepared_statement_update has no default implementation", + )) + } + + /// Execute a substrait plan + async fn do_put_substrait_plan( + &self, + _query: CommandStatementSubstraitPlan, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_put_substrait_plan has no default implementation", + )) + } // do_action + /// Implementors may override to handle additional calls to do_action() + async fn do_action_fallback( + &self, + request: Request, + ) -> Result::DoActionStream>, Status> { + Err(Status::invalid_argument(format!( + "do_action: The defined request is invalid: {:?}", + request.get_ref().r#type + ))) + } + + /// Add custom actions to list_actions() result + async fn list_custom_actions(&self) -> Option>> { + None + } + /// Create a prepared statement from given SQL statement. async fn do_action_create_prepared_statement( &self, - query: ActionCreatePreparedStatementRequest, - request: Request, - ) -> Result; + _query: ActionCreatePreparedStatementRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_action_create_prepared_statement has no default implementation", + )) + } /// Close a prepared statement. async fn do_action_close_prepared_statement( &self, - query: ActionClosePreparedStatementRequest, - request: Request, - ); + _query: ActionClosePreparedStatementRequest, + _request: Request, + ) -> Result<(), Status> { + Err(Status::unimplemented( + "do_action_close_prepared_statement has no default implementation", + )) + } + + /// Create a prepared substrait plan. + async fn do_action_create_prepared_substrait_plan( + &self, + _query: ActionCreatePreparedSubstraitPlanRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_action_create_prepared_substrait_plan has no default implementation", + )) + } + + /// Begin a transaction + async fn do_action_begin_transaction( + &self, + _query: ActionBeginTransactionRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_action_begin_transaction has no default implementation", + )) + } + + /// End a transaction + async fn do_action_end_transaction( + &self, + _query: ActionEndTransactionRequest, + _request: Request, + ) -> Result<(), Status> { + Err(Status::unimplemented( + "do_action_end_transaction has no default implementation", + )) + } + + /// Begin a savepoint + async fn do_action_begin_savepoint( + &self, + _query: ActionBeginSavepointRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_action_begin_savepoint has no default implementation", + )) + } + + /// End a savepoint + async fn do_action_end_savepoint( + &self, + _query: ActionEndSavepointRequest, + _request: Request, + ) -> Result<(), Status> { + Err(Status::unimplemented( + "do_action_end_savepoint has no default implementation", + )) + } + + /// Cancel a query + async fn do_action_cancel_query( + &self, + _query: ActionCancelQueryRequest, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_action_cancel_query has no default implementation", + )) + } + + /// do_exchange + + /// Implementors may override to handle additional calls to do_exchange() + async fn do_exchange_fallback( + &self, + _request: Request>, + ) -> Result::DoExchangeStream>, Status> { + Err(Status::unimplemented("Not yet implemented")) + } /// Register a new SqlInfo result, making it available when calling GetSqlInfo. async fn register_sql_info(&self, id: i32, result: &SqlInfo); @@ -276,19 +553,16 @@ pub trait FlightSqlService: #[tonic::async_trait] impl FlightService for T where - T: FlightSqlService + std::marker::Send, + T: FlightSqlService + Send, { type HandshakeStream = Pin> + Send + 'static>>; type ListFlightsStream = Pin> + Send + 'static>>; - type DoGetStream = - Pin> + Send + 'static>>; - type DoPutStream = - Pin> + Send + 'static>>; - type DoActionStream = Pin< - Box> + Send + 'static>, - >; + type DoGetStream = Pin> + Send + 'static>>; + type DoPutStream = Pin> + Send + 'static>>; + type DoActionStream = + Pin> + Send + 'static>>; type ListActionsStream = Pin> + Send + 'static>>; type DoExchangeStream = @@ -313,93 +587,49 @@ where &self, request: Request, ) -> Result, Status> { - let message: prost_types::Any = - Message::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?; + let message = Any::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?; - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_statement(token, request).await; - } - if message.is::() { - let handle = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self - .get_flight_info_prepared_statement(handle, request) - .await; - } - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_catalogs(token, request).await; + match Command::try_from(message).map_err(arrow_error_to_status)? { + Command::CommandStatementQuery(token) => { + self.get_flight_info_statement(token, request).await + } + Command::CommandPreparedStatementQuery(handle) => { + self.get_flight_info_prepared_statement(handle, request) + .await + } + Command::CommandStatementSubstraitPlan(handle) => { + self.get_flight_info_substrait_plan(handle, request).await + } + Command::CommandGetCatalogs(token) => { + self.get_flight_info_catalogs(token, request).await + } + Command::CommandGetDbSchemas(token) => { + return self.get_flight_info_schemas(token, request).await + } + Command::CommandGetTables(token) => self.get_flight_info_tables(token, request).await, + Command::CommandGetTableTypes(token) => { + self.get_flight_info_table_types(token, request).await + } + Command::CommandGetSqlInfo(token) => { + self.get_flight_info_sql_info(token, request).await + } + Command::CommandGetPrimaryKeys(token) => { + self.get_flight_info_primary_keys(token, request).await + } + Command::CommandGetExportedKeys(token) => { + self.get_flight_info_exported_keys(token, request).await + } + Command::CommandGetImportedKeys(token) => { + self.get_flight_info_imported_keys(token, request).await + } + Command::CommandGetCrossReference(token) => { + self.get_flight_info_cross_reference(token, request).await + } + Command::CommandGetXdbcTypeInfo(token) => { + self.get_flight_info_xdbc_type_info(token, request).await + } + cmd => self.get_flight_info_fallback(cmd, request).await, } - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_schemas(token, request).await; - } - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_tables(token, request).await; - } - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_table_types(token, request).await; - } - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_sql_info(token, request).await; - } - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_primary_keys(token, request).await; - } - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_exported_keys(token, request).await; - } - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_imported_keys(token, request).await; - } - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.get_flight_info_cross_reference(token, request).await; - } - - Err(Status::unimplemented(format!( - "get_flight_info: The defined request is invalid: {}", - message.type_url - ))) } async fn get_schema( @@ -413,98 +643,87 @@ where &self, request: Request, ) -> Result, Status> { - let msg: prost_types::Any = prost::Message::decode(&*request.get_ref().ticket) - .map_err(decode_error_to_status)?; - - fn unpack(msg: prost_types::Any) -> Result { - msg.unpack() - .map_err(arrow_error_to_status)? - .ok_or_else(|| Status::internal("Expected a command, but found none.")) - } + let msg: Any = + Message::decode(&*request.get_ref().ticket).map_err(decode_error_to_status)?; - if msg.is::() { - return self.do_get_statement(unpack(msg)?, request).await; - } - if msg.is::() { - return self.do_get_prepared_statement(unpack(msg)?, request).await; - } - if msg.is::() { - return self.do_get_catalogs(unpack(msg)?, request).await; - } - if msg.is::() { - return self.do_get_schemas(unpack(msg)?, request).await; - } - if msg.is::() { - return self.do_get_tables(unpack(msg)?, request).await; - } - if msg.is::() { - return self.do_get_table_types(unpack(msg)?, request).await; - } - if msg.is::() { - return self.do_get_sql_info(unpack(msg)?, request).await; + match Command::try_from(msg).map_err(arrow_error_to_status)? { + Command::TicketStatementQuery(command) => self.do_get_statement(command, request).await, + Command::CommandPreparedStatementQuery(command) => { + self.do_get_prepared_statement(command, request).await + } + Command::CommandGetCatalogs(command) => self.do_get_catalogs(command, request).await, + Command::CommandGetDbSchemas(command) => self.do_get_schemas(command, request).await, + Command::CommandGetTables(command) => self.do_get_tables(command, request).await, + Command::CommandGetTableTypes(command) => { + self.do_get_table_types(command, request).await + } + Command::CommandGetSqlInfo(command) => self.do_get_sql_info(command, request).await, + Command::CommandGetPrimaryKeys(command) => { + self.do_get_primary_keys(command, request).await + } + Command::CommandGetExportedKeys(command) => { + self.do_get_exported_keys(command, request).await + } + Command::CommandGetImportedKeys(command) => { + self.do_get_imported_keys(command, request).await + } + Command::CommandGetCrossReference(command) => { + self.do_get_cross_reference(command, request).await + } + Command::CommandGetXdbcTypeInfo(command) => { + self.do_get_xdbc_type_info(command, request).await + } + cmd => self.do_get_fallback(request, cmd.into_any()).await, } - if msg.is::() { - return self.do_get_primary_keys(unpack(msg)?, request).await; - } - if msg.is::() { - return self.do_get_exported_keys(unpack(msg)?, request).await; - } - if msg.is::() { - return self.do_get_imported_keys(unpack(msg)?, request).await; - } - if msg.is::() { - return self.do_get_cross_reference(unpack(msg)?, request).await; - } - - self.do_get_fallback(request, msg).await } async fn do_put( &self, - mut request: Request>, + request: Request>, ) -> Result, Status> { - let cmd = request.get_mut().message().await?.unwrap(); - let message: prost_types::Any = - prost::Message::decode(&*cmd.flight_descriptor.unwrap().cmd) - .map_err(decode_error_to_status)?; - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - let record_count = self.do_put_statement_update(token, request).await?; - let result = DoPutUpdateResult { record_count }; - let output = futures::stream::iter(vec![Ok(super::super::gen::PutResult { - app_metadata: result.encode_to_vec(), - })]); - return Ok(Response::new(Box::pin(output))); - } - if message.is::() { - let token = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_put_prepared_statement_query(token, request).await; - } - if message.is::() { - let handle = message - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - let record_count = self - .do_put_prepared_statement_update(handle, request) - .await?; - let result = DoPutUpdateResult { record_count }; - let output = futures::stream::iter(vec![Ok(super::super::gen::PutResult { - app_metadata: result.encode_to_vec(), - })]); - return Ok(Response::new(Box::pin(output))); - } + // See issue #4658: https://github.com/apache/arrow-rs/issues/4658 + // To dispatch to the correct `do_put` method, we cannot discard the first message, + // as it may contain the Arrow schema, which the `do_put` handler may need. + // To allow the first message to be reused by the `do_put` handler, + // we wrap this stream in a `Peekable` one, which allows us to peek at + // the first message without discarding it. + let mut request = request.map(PeekableFlightDataStream::new); + let cmd = Pin::new(request.get_mut()).peek().await.unwrap().clone()?; - Err(Status::invalid_argument(format!( - "do_put: The defined request is invalid: {}", - message.type_url - ))) + let message = + Any::decode(&*cmd.flight_descriptor.unwrap().cmd).map_err(decode_error_to_status)?; + match Command::try_from(message).map_err(arrow_error_to_status)? { + Command::CommandStatementUpdate(command) => { + let record_count = self.do_put_statement_update(command, request).await?; + let result = DoPutUpdateResult { record_count }; + let output = futures::stream::iter(vec![Ok(PutResult { + app_metadata: result.as_any().encode_to_vec().into(), + })]); + Ok(Response::new(Box::pin(output))) + } + Command::CommandPreparedStatementQuery(command) => { + self.do_put_prepared_statement_query(command, request).await + } + Command::CommandStatementSubstraitPlan(command) => { + let record_count = self.do_put_substrait_plan(command, request).await?; + let result = DoPutUpdateResult { record_count }; + let output = futures::stream::iter(vec![Ok(PutResult { + app_metadata: result.as_any().encode_to_vec().into(), + })]); + Ok(Response::new(Box::pin(output))) + } + Command::CommandPreparedStatementUpdate(command) => { + let record_count = self + .do_put_prepared_statement_update(command, request) + .await?; + let result = DoPutUpdateResult { record_count }; + let output = futures::stream::iter(vec![Ok(PutResult { + app_metadata: result.as_any().encode_to_vec().into(), + })]); + Ok(Response::new(Box::pin(output))) + } + cmd => self.do_put_fallback(request, cmd.into_any()).await, + } } async fn list_actions( @@ -525,10 +744,63 @@ where Response Message: N/A" .into(), }; - let actions: Vec> = vec![ + let create_prepared_substrait_plan_action_type = ActionType { + r#type: CREATE_PREPARED_SUBSTRAIT_PLAN.to_string(), + description: "Creates a reusable prepared substrait plan resource on the server.\n + Request Message: ActionCreatePreparedSubstraitPlanRequest\n + Response Message: ActionCreatePreparedStatementResult" + .into(), + }; + let begin_transaction_action_type = ActionType { + r#type: BEGIN_TRANSACTION.to_string(), + description: "Begins a transaction.\n + Request Message: ActionBeginTransactionRequest\n + Response Message: ActionBeginTransactionResult" + .into(), + }; + let end_transaction_action_type = ActionType { + r#type: END_TRANSACTION.to_string(), + description: "Ends a transaction\n + Request Message: ActionEndTransactionRequest\n + Response Message: N/A" + .into(), + }; + let begin_savepoint_action_type = ActionType { + r#type: BEGIN_SAVEPOINT.to_string(), + description: "Begins a savepoint.\n + Request Message: ActionBeginSavepointRequest\n + Response Message: ActionBeginSavepointResult" + .into(), + }; + let end_savepoint_action_type = ActionType { + r#type: END_SAVEPOINT.to_string(), + description: "Ends a savepoint\n + Request Message: ActionEndSavepointRequest\n + Response Message: N/A" + .into(), + }; + let cancel_query_action_type = ActionType { + r#type: CANCEL_QUERY.to_string(), + description: "Cancels a query\n + Request Message: ActionCancelQueryRequest\n + Response Message: ActionCancelQueryResult" + .into(), + }; + let mut actions: Vec> = vec![ Ok(create_prepared_statement_action_type), Ok(close_prepared_statement_action_type), + Ok(create_prepared_substrait_plan_action_type), + Ok(begin_transaction_action_type), + Ok(end_transaction_action_type), + Ok(begin_savepoint_action_type), + Ok(end_savepoint_action_type), + Ok(cancel_query_action_type), ]; + + if let Some(mut custom_actions) = self.list_custom_actions().await { + actions.append(&mut custom_actions); + } + let output = futures::stream::iter(actions); Ok(Response::new(Box::pin(output) as Self::ListActionsStream)) } @@ -538,8 +810,7 @@ where request: Request, ) -> Result, Status> { if request.get_ref().r#type == CREATE_PREPARED_STATEMENT { - let any: prost_types::Any = Message::decode(&*request.get_ref().body) - .map_err(decode_error_to_status)?; + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; let cmd: ActionCreatePreparedStatementRequest = any .unpack() @@ -553,13 +824,11 @@ where .do_action_create_prepared_statement(cmd, request) .await?; let output = futures::stream::iter(vec![Ok(super::super::gen::Result { - body: stmt.as_any().encode_to_vec(), + body: stmt.as_any().encode_to_vec().into(), })]); return Ok(Response::new(Box::pin(output))); - } - if request.get_ref().r#type == CLOSE_PREPARED_STATEMENT { - let any: prost_types::Any = Message::decode(&*request.get_ref().body) - .map_err(decode_error_to_status)?; + } else if request.get_ref().r#type == CLOSE_PREPARED_STATEMENT { + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; let cmd: ActionClosePreparedStatementRequest = any .unpack() @@ -569,28 +838,190 @@ where "Unable to unpack ActionClosePreparedStatementRequest.", ) })?; - self.do_action_close_prepared_statement(cmd, request).await; + self.do_action_close_prepared_statement(cmd, request) + .await?; + return Ok(Response::new(Box::pin(futures::stream::empty()))); + } else if request.get_ref().r#type == CREATE_PREPARED_SUBSTRAIT_PLAN { + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; + + let cmd: ActionCreatePreparedSubstraitPlanRequest = any + .unpack() + .map_err(arrow_error_to_status)? + .ok_or_else(|| { + Status::invalid_argument( + "Unable to unpack ActionCreatePreparedSubstraitPlanRequest.", + ) + })?; + self.do_action_create_prepared_substrait_plan(cmd, request) + .await?; + return Ok(Response::new(Box::pin(futures::stream::empty()))); + } else if request.get_ref().r#type == BEGIN_TRANSACTION { + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; + + let cmd: ActionBeginTransactionRequest = any + .unpack() + .map_err(arrow_error_to_status)? + .ok_or_else(|| { + Status::invalid_argument("Unable to unpack ActionBeginTransactionRequest.") + })?; + let stmt = self.do_action_begin_transaction(cmd, request).await?; + let output = futures::stream::iter(vec![Ok(super::super::gen::Result { + body: stmt.as_any().encode_to_vec().into(), + })]); + return Ok(Response::new(Box::pin(output))); + } else if request.get_ref().r#type == END_TRANSACTION { + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; + + let cmd: ActionEndTransactionRequest = any + .unpack() + .map_err(arrow_error_to_status)? + .ok_or_else(|| { + Status::invalid_argument("Unable to unpack ActionEndTransactionRequest.") + })?; + self.do_action_end_transaction(cmd, request).await?; return Ok(Response::new(Box::pin(futures::stream::empty()))); + } else if request.get_ref().r#type == BEGIN_SAVEPOINT { + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; + + let cmd: ActionBeginSavepointRequest = any + .unpack() + .map_err(arrow_error_to_status)? + .ok_or_else(|| { + Status::invalid_argument("Unable to unpack ActionBeginSavepointRequest.") + })?; + let stmt = self.do_action_begin_savepoint(cmd, request).await?; + let output = futures::stream::iter(vec![Ok(super::super::gen::Result { + body: stmt.as_any().encode_to_vec().into(), + })]); + return Ok(Response::new(Box::pin(output))); + } else if request.get_ref().r#type == END_SAVEPOINT { + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; + + let cmd: ActionEndSavepointRequest = any + .unpack() + .map_err(arrow_error_to_status)? + .ok_or_else(|| { + Status::invalid_argument("Unable to unpack ActionEndSavepointRequest.") + })?; + self.do_action_end_savepoint(cmd, request).await?; + return Ok(Response::new(Box::pin(futures::stream::empty()))); + } else if request.get_ref().r#type == CANCEL_QUERY { + let any = Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?; + + let cmd: ActionCancelQueryRequest = any + .unpack() + .map_err(arrow_error_to_status)? + .ok_or_else(|| { + Status::invalid_argument("Unable to unpack ActionCancelQueryRequest.") + })?; + let stmt = self.do_action_cancel_query(cmd, request).await?; + let output = futures::stream::iter(vec![Ok(super::super::gen::Result { + body: stmt.as_any().encode_to_vec().into(), + })]); + return Ok(Response::new(Box::pin(output))); } - Err(Status::invalid_argument(format!( - "do_action: The defined request is invalid: {:?}", - request.get_ref().r#type - ))) + self.do_action_fallback(request).await } async fn do_exchange( &self, - _request: Request>, + request: Request>, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + self.do_exchange_fallback(request).await } } -fn decode_error_to_status(err: prost::DecodeError) -> tonic::Status { - tonic::Status::invalid_argument(format!("{:?}", err)) +fn decode_error_to_status(err: prost::DecodeError) -> Status { + Status::invalid_argument(format!("{err:?}")) +} + +fn arrow_error_to_status(err: arrow_schema::ArrowError) -> Status { + Status::internal(format!("{err:?}")) +} + +/// A wrapper around [`Streaming`] that allows "peeking" at the +/// message at the front of the stream without consuming it. +/// This is needed because sometimes the first message in the stream will contain +/// a [`FlightDescriptor`] in addition to potentially any data, and the dispatch logic +/// must inspect this information. +/// +/// # Example +/// +/// [`PeekableFlightDataStream::peek`] can be used to peek at the first message without +/// discarding it; otherwise, `PeekableFlightDataStream` can be used as a regular stream. +/// See the following example: +/// +/// ```no_run +/// use arrow_array::RecordBatch; +/// use arrow_flight::decode::FlightRecordBatchStream; +/// use arrow_flight::FlightDescriptor; +/// use arrow_flight::error::FlightError; +/// use arrow_flight::sql::server::PeekableFlightDataStream; +/// use tonic::{Request, Status}; +/// use futures::TryStreamExt; +/// +/// #[tokio::main] +/// async fn main() -> Result<(), Status> { +/// let request: Request = todo!(); +/// let stream: PeekableFlightDataStream = request.into_inner(); +/// +/// // The first message contains the flight descriptor and the schema. +/// // Read the flight descriptor without discarding the schema: +/// let flight_descriptor: FlightDescriptor = stream +/// .peek() +/// .await +/// .cloned() +/// .transpose()? +/// .and_then(|data| data.flight_descriptor) +/// .expect("first message should contain flight descriptor"); +/// +/// // Pass the stream through a decoder +/// let batches: Vec = FlightRecordBatchStream::new_from_flight_data( +/// request.into_inner().map_err(|e| e.into()), +/// ) +/// .try_collect() +/// .await?; +/// } +/// ``` +pub struct PeekableFlightDataStream { + inner: Peekable>, } -fn arrow_error_to_status(err: arrow::error::ArrowError) -> tonic::Status { - tonic::Status::internal(format!("{:?}", err)) +impl PeekableFlightDataStream { + fn new(stream: Streaming) -> Self { + Self { + inner: stream.peekable(), + } + } + + /// Convert this stream into a `Streaming`. + /// Any messages observed through [`Self::peek`] will be lost + /// after the conversion. + pub fn into_inner(self) -> Streaming { + self.inner.into_inner() + } + + /// Convert this stream into a `Peekable>`. + /// Preserves the state of the stream, so that calls to [`Self::peek`] + /// and [`Self::poll_next`] are the same. + pub fn into_peekable(self) -> Peekable> { + self.inner + } + + /// Peek at the head of this stream without advancing it. + pub async fn peek(&mut self) -> Option<&Result> { + Pin::new(&mut self.inner).peek().await + } +} + +impl Stream for PeekableFlightDataStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_next_unpin(cx) + } } diff --git a/arrow-flight/src/trailers.rs b/arrow-flight/src/trailers.rs new file mode 100644 index 000000000000..73136379d69f --- /dev/null +++ b/arrow-flight/src/trailers.rs @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, +}; + +use futures::{ready, FutureExt, Stream, StreamExt}; +use tonic::{metadata::MetadataMap, Status, Streaming}; + +/// Extract [`LazyTrailers`] from [`Streaming`] [tonic] response. +/// +/// Note that [`LazyTrailers`] has inner mutability and will only hold actual data after [`ExtractTrailersStream`] is +/// fully consumed (dropping it is not required though). +pub fn extract_lazy_trailers(s: Streaming) -> (ExtractTrailersStream, LazyTrailers) { + let trailers: SharedTrailers = Default::default(); + let stream = ExtractTrailersStream { + inner: s, + trailers: Arc::clone(&trailers), + }; + let lazy_trailers = LazyTrailers { trailers }; + (stream, lazy_trailers) +} + +type SharedTrailers = Arc>>; + +/// [Stream] that stores the gRPC trailers into [`LazyTrailers`]. +/// +/// See [`extract_lazy_trailers`] for construction. +#[derive(Debug)] +pub struct ExtractTrailersStream { + inner: Streaming, + trailers: SharedTrailers, +} + +impl Stream for ExtractTrailersStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let res = ready!(self.inner.poll_next_unpin(cx)); + + if res.is_none() { + // stream exhausted => trailers should available + if let Some(trailers) = self + .inner + .trailers() + .now_or_never() + .and_then(|res| res.ok()) + .flatten() + { + *self.trailers.lock().expect("poisoned") = Some(trailers); + } + } + + Poll::Ready(res) + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +/// gRPC trailers that are extracted by [`ExtractTrailersStream`]. +/// +/// See [`extract_lazy_trailers`] for construction. +#[derive(Debug)] +pub struct LazyTrailers { + trailers: SharedTrailers, +} + +impl LazyTrailers { + /// gRPC trailers that are known at the end of a stream. + pub fn get(&self) -> Option { + self.trailers.lock().expect("poisoned").clone() + } +} diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs index 21a5a8572246..b75d61d200cb 100644 --- a/arrow-flight/src/utils.rs +++ b/arrow-flight/src/utils.rs @@ -18,18 +18,22 @@ //! Utilities to assist with reading and writing Arrow data as Flight messages use crate::{FlightData, IpcMessage, SchemaAsIpc, SchemaResult}; +use bytes::Bytes; use std::collections::HashMap; +use std::sync::Arc; -use arrow::array::ArrayRef; -use arrow::buffer::Buffer; -use arrow::datatypes::{Schema, SchemaRef}; -use arrow::error::{ArrowError, Result}; -use arrow::ipc::{reader, writer, writer::IpcWriteOptions}; -use arrow::record_batch::RecordBatch; -use std::convert::TryInto; +use arrow_array::{ArrayRef, RecordBatch}; +use arrow_buffer::Buffer; +use arrow_ipc::convert::fb_to_schema; +use arrow_ipc::{reader, root_as_message, writer, writer::IpcWriteOptions}; +use arrow_schema::{ArrowError, Schema, SchemaRef}; /// Convert a `RecordBatch` to a vector of `FlightData` representing the bytes of the dictionaries /// and a `FlightData` representing the bytes of the batch's values +#[deprecated( + since = "30.0.0", + note = "Use IpcDataGenerator directly with DictionaryTracker to avoid re-sending dictionaries" +)] pub fn flight_data_from_arrow_batch( batch: &RecordBatch, options: &IpcWriteOptions, @@ -47,16 +51,38 @@ pub fn flight_data_from_arrow_batch( (flight_dictionaries, flight_batch) } +/// Convert a slice of wire protocol `FlightData`s into a vector of `RecordBatch`es +pub fn flight_data_to_batches(flight_data: &[FlightData]) -> Result, ArrowError> { + let schema = flight_data.get(0).ok_or_else(|| { + ArrowError::CastError("Need at least one FlightData for schema".to_string()) + })?; + let message = root_as_message(&schema.data_header[..]) + .map_err(|_| ArrowError::CastError("Cannot get root as message".to_string()))?; + + let ipc_schema: arrow_ipc::Schema = message + .header_as_schema() + .ok_or_else(|| ArrowError::CastError("Cannot get header as Schema".to_string()))?; + let schema = fb_to_schema(ipc_schema); + let schema = Arc::new(schema); + + let mut batches = vec![]; + let dictionaries_by_id = HashMap::new(); + for datum in flight_data[1..].iter() { + let batch = flight_data_to_arrow_batch(datum, schema.clone(), &dictionaries_by_id)?; + batches.push(batch); + } + Ok(batches) +} + /// Convert `FlightData` (with supplied schema and dictionaries) to an arrow `RecordBatch`. pub fn flight_data_to_arrow_batch( data: &FlightData, schema: SchemaRef, dictionaries_by_id: &HashMap, -) -> Result { +) -> Result { // check that the data_header is a record batch message - let message = arrow::ipc::root_as_message(&data.data_header[..]).map_err(|err| { - ArrowError::ParseError(format!("Unable to get root as message: {:?}", err)) - })?; + let message = arrow_ipc::root_as_message(&data.data_header[..]) + .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))?; message .header_as_record_batch() @@ -80,13 +106,13 @@ pub fn flight_data_to_arrow_batch( /// Convert a `Schema` to `SchemaResult` by converting to an IPC message #[deprecated( since = "4.4.0", - note = "Use From trait, e.g.: SchemaAsIpc::new(schema, options).into()" + note = "Use From trait, e.g.: SchemaAsIpc::new(schema, options).try_into()" )] pub fn flight_schema_from_arrow_schema( schema: &Schema, options: &IpcWriteOptions, -) -> SchemaResult { - SchemaAsIpc::new(schema, options).into() +) -> Result { + SchemaAsIpc::new(schema, options).try_into() } /// Convert a `Schema` to `FlightData` by converting to an IPC message @@ -94,10 +120,7 @@ pub fn flight_schema_from_arrow_schema( since = "4.4.0", note = "Use From trait, e.g.: SchemaAsIpc::new(schema, options).into()" )] -pub fn flight_data_from_arrow_schema( - schema: &Schema, - options: &IpcWriteOptions, -) -> FlightData { +pub fn flight_data_from_arrow_schema(schema: &Schema, options: &IpcWriteOptions) -> FlightData { SchemaAsIpc::new(schema, options).into() } @@ -109,8 +132,35 @@ pub fn flight_data_from_arrow_schema( pub fn ipc_message_from_arrow_schema( schema: &Schema, options: &IpcWriteOptions, -) -> Result> { +) -> Result { let message = SchemaAsIpc::new(schema, options).try_into()?; let IpcMessage(vals) = message; Ok(vals) } + +/// Convert `RecordBatch`es to wire protocol `FlightData`s +pub fn batches_to_flight_data( + schema: &Schema, + batches: Vec, +) -> Result, ArrowError> { + let options = IpcWriteOptions::default(); + let schema_flight_data: FlightData = SchemaAsIpc::new(schema, &options).into(); + let mut dictionaries = vec![]; + let mut flight_data = vec![]; + + let data_gen = writer::IpcDataGenerator::default(); + let mut dictionary_tracker = writer::DictionaryTracker::new(false); + + for batch in batches.iter() { + let (encoded_dictionaries, encoded_batch) = + data_gen.encoded_batch(batch, &mut dictionary_tracker, &options)?; + + dictionaries.extend(encoded_dictionaries.into_iter().map(Into::into)); + flight_data.push(encoded_batch.into()); + } + let mut stream = vec![schema_flight_data]; + stream.extend(dictionaries); + stream.extend(flight_data); + let flight_data: Vec<_> = stream.into_iter().collect(); + Ok(flight_data) +} diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs new file mode 100644 index 000000000000..3ad9ee7a45ca --- /dev/null +++ b/arrow-flight/tests/client.rs @@ -0,0 +1,1011 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Integration test for "mid level" Client + +mod common { + pub mod server; + pub mod trailers_layer; +} +use arrow_array::{RecordBatch, UInt64Array}; +use arrow_flight::{ + decode::FlightRecordBatchStream, encode::FlightDataEncoderBuilder, error::FlightError, Action, + ActionType, Criteria, Empty, FlightClient, FlightData, FlightDescriptor, FlightInfo, + HandshakeRequest, HandshakeResponse, PutResult, Ticket, +}; +use arrow_schema::{DataType, Field, Schema}; +use bytes::Bytes; +use common::{server::TestFlightServer, trailers_layer::TrailersLayer}; +use futures::{Future, StreamExt, TryStreamExt}; +use tokio::{net::TcpListener, task::JoinHandle}; +use tonic::{ + transport::{Channel, Uri}, + Status, +}; + +use std::{net::SocketAddr, sync::Arc, time::Duration}; + +const DEFAULT_TIMEOUT_SECONDS: u64 = 30; + +#[tokio::test] +async fn test_handshake() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + let request_payload = Bytes::from("foo-request-payload"); + let response_payload = Bytes::from("bar-response-payload"); + + let request = HandshakeRequest { + payload: request_payload.clone(), + protocol_version: 0, + }; + + let response = HandshakeResponse { + payload: response_payload.clone(), + protocol_version: 0, + }; + + test_server.set_handshake_response(Ok(response)); + let response = client.handshake(request_payload).await.unwrap(); + assert_eq!(response, response_payload); + assert_eq!(test_server.take_handshake_request(), Some(request)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_handshake_error() { + do_test(|test_server, mut client| async move { + let request_payload = "foo-request-payload".to_string().into_bytes(); + let e = Status::unauthenticated("DENIED"); + test_server.set_handshake_response(Err(e.clone())); + + let response = client.handshake(request_payload).await.unwrap_err(); + expect_status(response, e); + }) + .await; +} + +/// Verifies that all headers sent from the the client are in the request_metadata +fn ensure_metadata(client: &FlightClient, test_server: &TestFlightServer) { + let client_metadata = client.metadata().clone().into_headers(); + assert!(!client_metadata.is_empty()); + let metadata = test_server + .take_last_request_metadata() + .expect("No headers in server") + .into_headers(); + + for (k, v) in &client_metadata { + assert_eq!( + metadata.get(k).as_ref(), + Some(&v), + "Missing / Mismatched metadata {k:?} sent {client_metadata:?} got {metadata:?}" + ); + } +} + +fn test_flight_info(request: &FlightDescriptor) -> FlightInfo { + FlightInfo { + schema: Bytes::new(), + endpoint: vec![], + flight_descriptor: Some(request.clone()), + total_bytes: 123, + total_records: 456, + ordered: false, + } +} + +#[tokio::test] +async fn test_get_flight_info() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + let request = FlightDescriptor::new_cmd(b"My Command".to_vec()); + + let expected_response = test_flight_info(&request); + test_server.set_get_flight_info_response(Ok(expected_response.clone())); + + let response = client.get_flight_info(request.clone()).await.unwrap(); + + assert_eq!(response, expected_response); + assert_eq!(test_server.take_get_flight_info_request(), Some(request)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_get_flight_info_error() { + do_test(|test_server, mut client| async move { + let request = FlightDescriptor::new_cmd(b"My Command".to_vec()); + + let e = Status::unauthenticated("DENIED"); + test_server.set_get_flight_info_response(Err(e.clone())); + + let response = client.get_flight_info(request.clone()).await.unwrap_err(); + expect_status(response, e); + }) + .await; +} + +// TODO more negative tests (like if there are endpoints defined, etc) + +#[tokio::test] +async fn test_do_get() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + let ticket = Ticket { + ticket: Bytes::from("my awesome flight ticket"), + }; + + let batch = RecordBatch::try_from_iter(vec![( + "col", + Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _, + )]) + .unwrap(); + + let response = vec![Ok(batch.clone())]; + test_server.set_do_get_response(response); + let mut response_stream = client + .do_get(ticket.clone()) + .await + .expect("error making request"); + + assert_eq!( + response_stream + .headers() + .get("test-resp-header") + .expect("header exists") + .to_str() + .unwrap(), + "some_val", + ); + + // trailers are not available before stream exhaustion + assert!(response_stream.trailers().is_none()); + + let expected_response = vec![batch]; + let response: Vec<_> = (&mut response_stream) + .try_collect() + .await + .expect("Error streaming data"); + assert_eq!(response, expected_response); + + assert_eq!( + response_stream + .trailers() + .expect("stream exhausted") + .get("test-trailer") + .expect("trailer exists") + .to_str() + .unwrap(), + "trailer_val", + ); + + assert_eq!(test_server.take_do_get_request(), Some(ticket)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_get_error() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + let ticket = Ticket { + ticket: Bytes::from("my awesome flight ticket"), + }; + + let response = client.do_get(ticket.clone()).await.unwrap_err(); + + let e = Status::internal("No do_get response configured"); + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_do_get_request(), Some(ticket)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_get_error_in_record_batch_stream() { + do_test(|test_server, mut client| async move { + let ticket = Ticket { + ticket: Bytes::from("my awesome flight ticket"), + }; + + let batch = RecordBatch::try_from_iter(vec![( + "col", + Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _, + )]) + .unwrap(); + + let e = Status::data_loss("she's dead jim"); + + let expected_response = vec![Ok(batch), Err(e.clone())]; + + test_server.set_do_get_response(expected_response); + + let response_stream = client + .do_get(ticket.clone()) + .await + .expect("error making request"); + + let response: Result, FlightError> = response_stream.try_collect().await; + + let response = response.unwrap_err(); + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_do_get_request(), Some(ticket)); + }) + .await; +} + +#[tokio::test] +async fn test_do_put() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + // encode the batch as a stream of FlightData + let input_flight_data = test_flight_data().await; + + let expected_response = vec![ + PutResult { + app_metadata: Bytes::from("foo-metadata1"), + }, + PutResult { + app_metadata: Bytes::from("bar-metadata2"), + }, + ]; + + test_server.set_do_put_response(expected_response.clone().into_iter().map(Ok).collect()); + + let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok); + + let response_stream = client + .do_put(input_stream) + .await + .expect("error making request"); + + let response: Vec<_> = response_stream + .try_collect() + .await + .expect("Error streaming data"); + + assert_eq!(response, expected_response); + assert_eq!(test_server.take_do_put_request(), Some(input_flight_data)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_put_error_server() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let input_flight_data = test_flight_data().await; + + let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok); + + let response = client.do_put(input_stream).await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + let e = Status::internal("No do_put response configured"); + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_do_put_request(), Some(input_flight_data)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_put_error_stream_server() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let input_flight_data = test_flight_data().await; + + let e = Status::invalid_argument("bad arg"); + + let response = vec![ + Ok(PutResult { + app_metadata: Bytes::from("foo-metadata"), + }), + Err(e.clone()), + ]; + + test_server.set_do_put_response(response); + + let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok); + + let response_stream = client + .do_put(input_stream) + .await + .expect("error making request"); + + let response: Result, _> = response_stream.try_collect().await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_do_put_request(), Some(input_flight_data)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_put_error_client() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let e = Status::invalid_argument("bad arg: client"); + + // input stream to client sends good FlightData followed by an error + let input_flight_data = test_flight_data().await; + let input_stream = futures::stream::iter(input_flight_data.clone()) + .map(Ok) + .chain(futures::stream::iter(vec![Err(FlightError::from( + e.clone(), + ))])); + + // server responds with one good message + let response = vec![Ok(PutResult { + app_metadata: Bytes::from("foo-metadata"), + })]; + test_server.set_do_put_response(response); + + let response_stream = client + .do_put(input_stream) + .await + .expect("error making request"); + + let response: Result, _> = response_stream.try_collect().await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + // expect to the error made from the client + expect_status(response, e); + // server still got the request messages until the client sent the error + assert_eq!(test_server.take_do_put_request(), Some(input_flight_data)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_put_error_client_and_server() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let e_client = Status::invalid_argument("bad arg: client"); + let e_server = Status::invalid_argument("bad arg: server"); + + // input stream to client sends good FlightData followed by an error + let input_flight_data = test_flight_data().await; + let input_stream = futures::stream::iter(input_flight_data.clone()) + .map(Ok) + .chain(futures::stream::iter(vec![Err(FlightError::from( + e_client.clone(), + ))])); + + // server responds with an error (e.g. because it got truncated data) + let response = vec![Err(e_server)]; + test_server.set_do_put_response(response); + + let response_stream = client + .do_put(input_stream) + .await + .expect("error making request"); + + let response: Result, _> = response_stream.try_collect().await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + // expect to the error made from the client (not the server) + expect_status(response, e_client); + // server still got the request messages until the client sent the error + assert_eq!(test_server.take_do_put_request(), Some(input_flight_data)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_exchange() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + // encode the batch as a stream of FlightData + let input_flight_data = test_flight_data().await; + let output_flight_data = test_flight_data2().await; + + test_server + .set_do_exchange_response(output_flight_data.clone().into_iter().map(Ok).collect()); + + let response_stream = client + .do_exchange(futures::stream::iter(input_flight_data.clone())) + .await + .expect("error making request"); + + let response: Vec<_> = response_stream + .try_collect() + .await + .expect("Error streaming data"); + + let expected_stream = futures::stream::iter(output_flight_data).map(Ok); + + let expected_batches: Vec<_> = + FlightRecordBatchStream::new_from_flight_data(expected_stream) + .try_collect() + .await + .unwrap(); + + assert_eq!(response, expected_batches); + assert_eq!( + test_server.take_do_exchange_request(), + Some(input_flight_data) + ); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_exchange_error() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let input_flight_data = test_flight_data().await; + + let response = client + .do_exchange(futures::stream::iter(input_flight_data.clone())) + .await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + let e = Status::internal("No do_exchange response configured"); + expect_status(response, e); + // server still got the request + assert_eq!( + test_server.take_do_exchange_request(), + Some(input_flight_data) + ); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_exchange_error_stream() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let input_flight_data = test_flight_data().await; + + let e = Status::invalid_argument("the error"); + let response = test_flight_data2() + .await + .into_iter() + .enumerate() + .map(|(i, m)| { + if i == 0 { + Ok(m) + } else { + // make all messages after the first an error + Err(e.clone()) + } + }) + .collect(); + + test_server.set_do_exchange_response(response); + + let response_stream = client + .do_exchange(futures::stream::iter(input_flight_data.clone())) + .await + .expect("error making request"); + + let response: Result, _> = response_stream.try_collect().await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + expect_status(response, e); + // server still got the request + assert_eq!( + test_server.take_do_exchange_request(), + Some(input_flight_data) + ); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_get_schema() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let schema = Schema::new(vec![Field::new("foo", DataType::Int64, true)]); + + let request = FlightDescriptor::new_cmd("my command"); + test_server.set_get_schema_response(Ok(schema.clone())); + + let response = client + .get_schema(request.clone()) + .await + .expect("error making request"); + + let expected_schema = schema; + let expected_request = request; + + assert_eq!(response, expected_schema); + assert_eq!( + test_server.take_get_schema_request(), + Some(expected_request) + ); + + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_get_schema_error() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + let request = FlightDescriptor::new_cmd("my command"); + + let e = Status::unauthenticated("DENIED"); + test_server.set_get_schema_response(Err(e.clone())); + + let response = client.get_schema(request).await.unwrap_err(); + expect_status(response, e); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_list_flights() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let infos = vec![ + test_flight_info(&FlightDescriptor::new_cmd("foo")), + test_flight_info(&FlightDescriptor::new_cmd("bar")), + ]; + + let response = infos.iter().map(|i| Ok(i.clone())).collect(); + test_server.set_list_flights_response(response); + + let response_stream = client + .list_flights("query") + .await + .expect("error making request"); + + let expected_response = infos; + let response: Vec<_> = response_stream + .try_collect() + .await + .expect("Error streaming data"); + + let expected_request = Some(Criteria { + expression: "query".into(), + }); + + assert_eq!(response, expected_response); + assert_eq!(test_server.take_list_flights_request(), expected_request); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_list_flights_error() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let response = client.list_flights("query").await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + let e = Status::internal("No list_flights response configured"); + expect_status(response, e); + // server still got the request + let expected_request = Some(Criteria { + expression: "query".into(), + }); + assert_eq!(test_server.take_list_flights_request(), expected_request); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_list_flights_error_in_stream() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let e = Status::data_loss("she's dead jim"); + + let response = vec![ + Ok(test_flight_info(&FlightDescriptor::new_cmd("foo"))), + Err(e.clone()), + ]; + test_server.set_list_flights_response(response); + + let response_stream = client + .list_flights("other query") + .await + .expect("error making request"); + + let response: Result, FlightError> = response_stream.try_collect().await; + + let response = response.unwrap_err(); + expect_status(response, e); + // server still got the request + let expected_request = Some(Criteria { + expression: "other query".into(), + }); + assert_eq!(test_server.take_list_flights_request(), expected_request); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_list_actions() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let actions = vec![ + ActionType { + r#type: "type 1".into(), + description: "awesomeness".into(), + }, + ActionType { + r#type: "type 2".into(), + description: "more awesomeness".into(), + }, + ]; + + let response = actions.iter().map(|i| Ok(i.clone())).collect(); + test_server.set_list_actions_response(response); + + let response_stream = client.list_actions().await.expect("error making request"); + + let expected_response = actions; + let response: Vec<_> = response_stream + .try_collect() + .await + .expect("Error streaming data"); + + assert_eq!(response, expected_response); + assert_eq!(test_server.take_list_actions_request(), Some(Empty {})); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_list_actions_error() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let response = client.list_actions().await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + let e = Status::internal("No list_actions response configured"); + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_list_actions_request(), Some(Empty {})); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_list_actions_error_in_stream() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let e = Status::data_loss("she's dead jim"); + + let response = vec![ + Ok(ActionType { + r#type: "type 1".into(), + description: "awesomeness".into(), + }), + Err(e.clone()), + ]; + test_server.set_list_actions_response(response); + + let response_stream = client.list_actions().await.expect("error making request"); + + let response: Result, FlightError> = response_stream.try_collect().await; + + let response = response.unwrap_err(); + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_list_actions_request(), Some(Empty {})); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_action() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let bytes = vec![Bytes::from("foo"), Bytes::from("blarg")]; + + let response = bytes + .iter() + .cloned() + .map(arrow_flight::Result::new) + .map(Ok) + .collect(); + test_server.set_do_action_response(response); + + let request = Action::new("action type", "action body"); + + let response_stream = client + .do_action(request.clone()) + .await + .expect("error making request"); + + let expected_response = bytes; + let response: Vec<_> = response_stream + .try_collect() + .await + .expect("Error streaming data"); + + assert_eq!(response, expected_response); + assert_eq!(test_server.take_do_action_request(), Some(request)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_action_error() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let request = Action::new("action type", "action body"); + + let response = client.do_action(request.clone()).await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + let e = Status::internal("No do_action response configured"); + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_do_action_request(), Some(request)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_action_error_in_stream() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let e = Status::data_loss("she's dead jim"); + + let request = Action::new("action type", "action body"); + + let response = vec![Ok(arrow_flight::Result::new("foo")), Err(e.clone())]; + test_server.set_do_action_response(response); + + let response_stream = client + .do_action(request.clone()) + .await + .expect("error making request"); + + let response: Result, FlightError> = response_stream.try_collect().await; + + let response = response.unwrap_err(); + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_do_action_request(), Some(request)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +async fn test_flight_data() -> Vec { + let batch = RecordBatch::try_from_iter(vec![( + "col", + Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _, + )]) + .unwrap(); + + // encode the batch as a stream of FlightData + FlightDataEncoderBuilder::new() + .build(futures::stream::iter(vec![Ok(batch)])) + .try_collect() + .await + .unwrap() +} + +async fn test_flight_data2() -> Vec { + let batch = RecordBatch::try_from_iter(vec![( + "col2", + Arc::new(UInt64Array::from_iter([10, 23, 33])) as _, + )]) + .unwrap(); + + // encode the batch as a stream of FlightData + FlightDataEncoderBuilder::new() + .build(futures::stream::iter(vec![Ok(batch)])) + .try_collect() + .await + .unwrap() +} + +/// Runs the future returned by the function, passing it a test server and client +async fn do_test(f: F) +where + F: Fn(TestFlightServer, FlightClient) -> Fut, + Fut: Future, +{ + let test_server = TestFlightServer::new(); + let fixture = TestFixture::new(&test_server).await; + let client = FlightClient::new(fixture.channel().await); + + // run the test function + f(test_server, client).await; + + // cleanly shutdown the test fixture + fixture.shutdown_and_wait().await +} + +fn expect_status(error: FlightError, expected: Status) { + let status = if let FlightError::Tonic(status) = error { + status + } else { + panic!("Expected FlightError::Tonic, got: {error:?}"); + }; + + assert_eq!( + status.code(), + expected.code(), + "Got {status:?} want {expected:?}" + ); + assert_eq!( + status.message(), + expected.message(), + "Got {status:?} want {expected:?}" + ); + assert_eq!( + status.details(), + expected.details(), + "Got {status:?} want {expected:?}" + ); +} + +/// Creates and manages a running TestServer with a background task +struct TestFixture { + /// channel to send shutdown command + shutdown: Option>, + + /// Address the server is listening on + addr: SocketAddr, + + // handle for the server task + handle: Option>>, +} + +impl TestFixture { + /// create a new test fixture from the server + pub async fn new(test_server: &TestFlightServer) -> Self { + // let OS choose a a free port + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + println!("Listening on {addr}"); + + // prepare the shutdown channel + let (tx, rx) = tokio::sync::oneshot::channel(); + + let server_timeout = Duration::from_secs(DEFAULT_TIMEOUT_SECONDS); + + let shutdown_future = async move { + rx.await.ok(); + }; + + let serve_future = tonic::transport::Server::builder() + .timeout(server_timeout) + .layer(TrailersLayer) + .add_service(test_server.service()) + .serve_with_incoming_shutdown( + tokio_stream::wrappers::TcpListenerStream::new(listener), + shutdown_future, + ); + + // Run the server in its own background task + let handle = tokio::task::spawn(serve_future); + + Self { + shutdown: Some(tx), + addr, + handle: Some(handle), + } + } + + /// Return a [`Channel`] connected to the TestServer + pub async fn channel(&self) -> Channel { + let url = format!("http://{}", self.addr); + let uri: Uri = url.parse().expect("Valid URI"); + Channel::builder(uri) + .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECONDS)) + .connect() + .await + .expect("error connecting to server") + } + + /// Stops the test server and waits for the server to shutdown + pub async fn shutdown_and_wait(mut self) { + if let Some(shutdown) = self.shutdown.take() { + shutdown.send(()).expect("server quit early"); + } + if let Some(handle) = self.handle.take() { + println!("Waiting on server to finish"); + handle + .await + .expect("task join error (panic?)") + .expect("Server Error found at shutdown"); + } + } +} + +impl Drop for TestFixture { + fn drop(&mut self) { + if let Some(shutdown) = self.shutdown.take() { + shutdown.send(()).ok(); + } + if self.handle.is_some() { + // tests should properly clean up TestFixture + println!("TestFixture::Drop called prior to `shutdown_and_wait`"); + } + } +} diff --git a/arrow-flight/tests/common/server.rs b/arrow-flight/tests/common/server.rs new file mode 100644 index 000000000000..8b162d398c4b --- /dev/null +++ b/arrow-flight/tests/common/server.rs @@ -0,0 +1,448 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::{Arc, Mutex}; + +use arrow_array::RecordBatch; +use arrow_schema::Schema; +use futures::{stream::BoxStream, StreamExt, TryStreamExt}; +use tonic::{metadata::MetadataMap, Request, Response, Status, Streaming}; + +use arrow_flight::{ + encode::FlightDataEncoderBuilder, + flight_service_server::{FlightService, FlightServiceServer}, + Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, + HandshakeRequest, HandshakeResponse, PutResult, SchemaAsIpc, SchemaResult, Ticket, +}; + +#[derive(Debug, Clone)] +/// Flight server for testing, with configurable responses +pub struct TestFlightServer { + /// Shared state to configure responses + state: Arc>, +} + +impl TestFlightServer { + /// Create a `TestFlightServer` + pub fn new() -> Self { + Self { + state: Arc::new(Mutex::new(State::new())), + } + } + + /// Return an [`FlightServiceServer`] that can be used with a + /// [`Server`](tonic::transport::Server) + pub fn service(&self) -> FlightServiceServer { + // wrap up tonic goop + FlightServiceServer::new(self.clone()) + } + + /// Specify the response returned from the next call to handshake + pub fn set_handshake_response(&self, response: Result) { + let mut state = self.state.lock().expect("mutex not poisoned"); + + state.handshake_response.replace(response); + } + + /// Take and return last handshake request send to the server, + pub fn take_handshake_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .handshake_request + .take() + } + + /// Specify the response returned from the next call to handshake + pub fn set_get_flight_info_response(&self, response: Result) { + let mut state = self.state.lock().expect("mutex not poisoned"); + + state.get_flight_info_response.replace(response); + } + + /// Take and return last get_flight_info request send to the server, + pub fn take_get_flight_info_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .get_flight_info_request + .take() + } + + /// Specify the response returned from the next call to `do_get` + pub fn set_do_get_response(&self, response: Vec>) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.do_get_response.replace(response); + } + + /// Take and return last do_get request send to the server, + pub fn take_do_get_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .do_get_request + .take() + } + + /// Specify the response returned from the next call to `do_put` + pub fn set_do_put_response(&self, response: Vec>) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.do_put_response.replace(response); + } + + /// Take and return last do_put request send to the server, + pub fn take_do_put_request(&self) -> Option> { + self.state + .lock() + .expect("mutex not poisoned") + .do_put_request + .take() + } + + /// Specify the response returned from the next call to `do_exchange` + pub fn set_do_exchange_response(&self, response: Vec>) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.do_exchange_response.replace(response); + } + + /// Take and return last do_exchange request send to the server, + pub fn take_do_exchange_request(&self) -> Option> { + self.state + .lock() + .expect("mutex not poisoned") + .do_exchange_request + .take() + } + + /// Specify the response returned from the next call to `list_flights` + pub fn set_list_flights_response(&self, response: Vec>) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.list_flights_response.replace(response); + } + + /// Take and return last list_flights request send to the server, + pub fn take_list_flights_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .list_flights_request + .take() + } + + /// Specify the response returned from the next call to `get_schema` + pub fn set_get_schema_response(&self, response: Result) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.get_schema_response.replace(response); + } + + /// Take and return last get_schema request send to the server, + pub fn take_get_schema_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .get_schema_request + .take() + } + + /// Specify the response returned from the next call to `list_actions` + pub fn set_list_actions_response(&self, response: Vec>) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.list_actions_response.replace(response); + } + + /// Take and return last list_actions request send to the server, + pub fn take_list_actions_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .list_actions_request + .take() + } + + /// Specify the response returned from the next call to `do_action` + pub fn set_do_action_response(&self, response: Vec>) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.do_action_response.replace(response); + } + + /// Take and return last do_action request send to the server, + pub fn take_do_action_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .do_action_request + .take() + } + + /// Returns the last metadata from a request received by the server + pub fn take_last_request_metadata(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .last_request_metadata + .take() + } + + /// Save the last request's metadatacom + fn save_metadata(&self, request: &Request) { + let metadata = request.metadata().clone(); + let mut state = self.state.lock().expect("mutex not poisoned"); + state.last_request_metadata = Some(metadata); + } +} + +/// mutable state for the TestFlightServer, captures requests and provides responses +#[derive(Debug, Default)] +struct State { + /// The last handshake request that was received + pub handshake_request: Option, + /// The next response to return from `handshake()` + pub handshake_response: Option>, + /// The last `get_flight_info` request received + pub get_flight_info_request: Option, + /// the next response to return from `get_flight_info` + pub get_flight_info_response: Option>, + /// The last do_get request received + pub do_get_request: Option, + /// The next response returned from `do_get` + pub do_get_response: Option>>, + /// The last do_put request received + pub do_put_request: Option>, + /// The next response returned from `do_put` + pub do_put_response: Option>>, + /// The last do_exchange request received + pub do_exchange_request: Option>, + /// The next response returned from `do_exchange` + pub do_exchange_response: Option>>, + /// The last list_flights request received + pub list_flights_request: Option, + /// The next response returned from `list_flights` + pub list_flights_response: Option>>, + /// The last get_schema request received + pub get_schema_request: Option, + /// The next response returned from `get_schema` + pub get_schema_response: Option>, + /// The last list_actions request received + pub list_actions_request: Option, + /// The next response returned from `list_actions` + pub list_actions_response: Option>>, + /// The last do_action request received + pub do_action_request: Option, + /// The next response returned from `do_action` + pub do_action_response: Option>>, + /// The last request headers received + pub last_request_metadata: Option, +} + +impl State { + fn new() -> Self { + Default::default() + } +} + +/// Implement the FlightService trait +#[tonic::async_trait] +impl FlightService for TestFlightServer { + type HandshakeStream = BoxStream<'static, Result>; + type ListFlightsStream = BoxStream<'static, Result>; + type DoGetStream = BoxStream<'static, Result>; + type DoPutStream = BoxStream<'static, Result>; + type DoActionStream = BoxStream<'static, Result>; + type ListActionsStream = BoxStream<'static, Result>; + type DoExchangeStream = BoxStream<'static, Result>; + + async fn handshake( + &self, + request: Request>, + ) -> Result, Status> { + self.save_metadata(&request); + let handshake_request = request.into_inner().message().await?.unwrap(); + + let mut state = self.state.lock().expect("mutex not poisoned"); + state.handshake_request = Some(handshake_request); + + let response = state + .handshake_response + .take() + .unwrap_or_else(|| Err(Status::internal("No handshake response configured")))?; + + // turn into a streaming response + let output = futures::stream::iter(std::iter::once(Ok(response))); + Ok(Response::new(output.boxed())) + } + + async fn list_flights( + &self, + request: Request, + ) -> Result, Status> { + self.save_metadata(&request); + let mut state = self.state.lock().expect("mutex not poisoned"); + + state.list_flights_request = Some(request.into_inner()); + + let flights: Vec<_> = state + .list_flights_response + .take() + .ok_or_else(|| Status::internal("No list_flights response configured"))?; + + let flights_stream = futures::stream::iter(flights); + + Ok(Response::new(flights_stream.boxed())) + } + + async fn get_flight_info( + &self, + request: Request, + ) -> Result, Status> { + self.save_metadata(&request); + let mut state = self.state.lock().expect("mutex not poisoned"); + state.get_flight_info_request = Some(request.into_inner()); + let response = state + .get_flight_info_response + .take() + .unwrap_or_else(|| Err(Status::internal("No get_flight_info response configured")))?; + Ok(Response::new(response)) + } + + async fn get_schema( + &self, + request: Request, + ) -> Result, Status> { + self.save_metadata(&request); + let mut state = self.state.lock().expect("mutex not poisoned"); + state.get_schema_request = Some(request.into_inner()); + let schema = state + .get_schema_response + .take() + .unwrap_or_else(|| Err(Status::internal("No get_schema response configured")))?; + + // encode the schema + let options = arrow_ipc::writer::IpcWriteOptions::default(); + let response: SchemaResult = SchemaAsIpc::new(&schema, &options) + .try_into() + .expect("Error encoding schema"); + + Ok(Response::new(response)) + } + + async fn do_get( + &self, + request: Request, + ) -> Result, Status> { + self.save_metadata(&request); + let mut state = self.state.lock().expect("mutex not poisoned"); + + state.do_get_request = Some(request.into_inner()); + + let batches: Vec<_> = state + .do_get_response + .take() + .ok_or_else(|| Status::internal("No do_get response configured"))?; + + let batch_stream = futures::stream::iter(batches).map_err(Into::into); + + let stream = FlightDataEncoderBuilder::new() + .build(batch_stream) + .map_err(Into::into); + + let mut resp = Response::new(stream.boxed()); + resp.metadata_mut() + .insert("test-resp-header", "some_val".parse().unwrap()); + + Ok(resp) + } + + async fn do_put( + &self, + request: Request>, + ) -> Result, Status> { + self.save_metadata(&request); + let do_put_request: Vec<_> = request.into_inner().try_collect().await?; + + let mut state = self.state.lock().expect("mutex not poisoned"); + + state.do_put_request = Some(do_put_request); + + let response = state + .do_put_response + .take() + .ok_or_else(|| Status::internal("No do_put response configured"))?; + + let stream = futures::stream::iter(response).map_err(Into::into); + + Ok(Response::new(stream.boxed())) + } + + async fn do_action( + &self, + request: Request, + ) -> Result, Status> { + self.save_metadata(&request); + let mut state = self.state.lock().expect("mutex not poisoned"); + + state.do_action_request = Some(request.into_inner()); + + let results: Vec<_> = state + .do_action_response + .take() + .ok_or_else(|| Status::internal("No do_action response configured"))?; + + let results_stream = futures::stream::iter(results); + + Ok(Response::new(results_stream.boxed())) + } + + async fn list_actions( + &self, + request: Request, + ) -> Result, Status> { + self.save_metadata(&request); + let mut state = self.state.lock().expect("mutex not poisoned"); + + state.list_actions_request = Some(request.into_inner()); + + let actions: Vec<_> = state + .list_actions_response + .take() + .ok_or_else(|| Status::internal("No list_actions response configured"))?; + + let action_stream = futures::stream::iter(actions); + + Ok(Response::new(action_stream.boxed())) + } + + async fn do_exchange( + &self, + request: Request>, + ) -> Result, Status> { + self.save_metadata(&request); + let do_exchange_request: Vec<_> = request.into_inner().try_collect().await?; + + let mut state = self.state.lock().expect("mutex not poisoned"); + + state.do_exchange_request = Some(do_exchange_request); + + let response = state + .do_exchange_response + .take() + .ok_or_else(|| Status::internal("No do_exchange response configured"))?; + + let stream = futures::stream::iter(response).map_err(Into::into); + + Ok(Response::new(stream.boxed())) + } +} diff --git a/arrow-flight/tests/common/trailers_layer.rs b/arrow-flight/tests/common/trailers_layer.rs new file mode 100644 index 000000000000..b2ab74f7d925 --- /dev/null +++ b/arrow-flight/tests/common/trailers_layer.rs @@ -0,0 +1,136 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::ready; +use http::{HeaderValue, Request, Response}; +use http_body::SizeHint; +use pin_project_lite::pin_project; +use tower::{Layer, Service}; + +#[derive(Debug, Copy, Clone, Default)] +pub struct TrailersLayer; + +impl Layer for TrailersLayer { + type Service = TrailersService; + + fn layer(&self, service: S) -> Self::Service { + TrailersService { service } + } +} + +#[derive(Debug, Clone)] +pub struct TrailersService { + service: S, +} + +impl Service> for TrailersService +where + S: Service, Response = Response>, + ResBody: http_body::Body, +{ + type Response = Response>; + type Error = S::Error; + type Future = WrappedFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, request: Request) -> Self::Future { + WrappedFuture { + inner: self.service.call(request), + } + } +} + +pin_project! { + #[derive(Debug)] + pub struct WrappedFuture { + #[pin] + inner: F, + } +} + +impl Future for WrappedFuture +where + F: Future, Error>>, + ResBody: http_body::Body, +{ + type Output = Result>, Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let result: Result, Error> = + ready!(self.as_mut().project().inner.poll(cx)); + + match result { + Ok(response) => Poll::Ready(Ok(response.map(|body| WrappedBody { inner: body }))), + Err(e) => Poll::Ready(Err(e)), + } + } +} + +pin_project! { + #[derive(Debug)] + pub struct WrappedBody { + #[pin] + inner: B, + } +} + +impl http_body::Body for WrappedBody { + type Data = B::Data; + type Error = B::Error; + + fn poll_data( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + self.as_mut().project().inner.poll_data(cx) + } + + fn poll_trailers( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + let result: Result, Self::Error> = + ready!(self.as_mut().project().inner.poll_trailers(cx)); + + let mut trailers = http::header::HeaderMap::new(); + trailers.insert("test-trailer", HeaderValue::from_static("trailer_val")); + + match result { + Ok(Some(mut existing)) => { + existing.extend(trailers.iter().map(|(k, v)| (k.clone(), v.clone()))); + Poll::Ready(Ok(Some(existing))) + } + Ok(None) => Poll::Ready(Ok(Some(trailers))), + Err(e) => Poll::Ready(Err(e)), + } + } + + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + + fn size_hint(&self) -> SizeHint { + self.inner.size_hint() + } +} diff --git a/arrow-flight/tests/encode_decode.rs b/arrow-flight/tests/encode_decode.rs new file mode 100644 index 000000000000..f4741d743e57 --- /dev/null +++ b/arrow-flight/tests/encode_decode.rs @@ -0,0 +1,539 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for round trip encoding / decoding + +use std::{collections::HashMap, sync::Arc}; + +use arrow_array::types::Int32Type; +use arrow_array::{ArrayRef, DictionaryArray, Float64Array, RecordBatch, UInt8Array}; +use arrow_cast::pretty::pretty_format_batches; +use arrow_flight::flight_descriptor::DescriptorType; +use arrow_flight::FlightDescriptor; +use arrow_flight::{ + decode::{DecodedPayload, FlightDataDecoder, FlightRecordBatchStream}, + encode::FlightDataEncoderBuilder, + error::FlightError, +}; +use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef}; +use bytes::Bytes; +use futures::{StreamExt, TryStreamExt}; + +#[tokio::test] +async fn test_empty() { + roundtrip(vec![]).await; +} + +#[tokio::test] +async fn test_empty_batch() { + let batch = make_primitive_batch(5); + let empty = RecordBatch::new_empty(batch.schema()); + roundtrip(vec![empty]).await; +} + +#[tokio::test] +async fn test_error() { + let input_batch_stream = + futures::stream::iter(vec![Err(FlightError::NotYetImplemented("foo".into()))]); + + let encoder = FlightDataEncoderBuilder::default(); + let encode_stream = encoder.build(input_batch_stream); + + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream); + let result: Result, _> = decode_stream.try_collect().await; + + let result = result.unwrap_err(); + assert_eq!(result.to_string(), r#"NotYetImplemented("foo")"#); +} + +#[tokio::test] +async fn test_primitive_one() { + roundtrip(vec![make_primitive_batch(5)]).await; +} + +#[tokio::test] +async fn test_schema_metadata() { + let batch = make_primitive_batch(5); + let metadata = HashMap::from([("some_key".to_owned(), "some_value".to_owned())]); + + // create a batch that has schema level metadata + let schema = Arc::new(batch.schema().as_ref().clone().with_metadata(metadata)); + let batch = RecordBatch::try_new(schema, batch.columns().to_vec()).unwrap(); + + roundtrip(vec![batch]).await; +} + +#[tokio::test] +async fn test_primitive_many() { + roundtrip(vec![ + make_primitive_batch(1), + make_primitive_batch(7), + make_primitive_batch(32), + ]) + .await; +} + +#[tokio::test] +async fn test_primitive_empty() { + let batch = make_primitive_batch(5); + let empty = RecordBatch::new_empty(batch.schema()); + + roundtrip(vec![batch, empty]).await; +} + +#[tokio::test] +async fn test_dictionary_one() { + roundtrip_dictionary(vec![make_dictionary_batch(5)]).await; +} + +#[tokio::test] +async fn test_dictionary_many() { + roundtrip_dictionary(vec![ + make_dictionary_batch(5), + make_dictionary_batch(9), + make_dictionary_batch(5), + make_dictionary_batch(5), + ]) + .await; +} + +#[tokio::test] +async fn test_zero_batches_no_schema() { + let stream = FlightDataEncoderBuilder::default().build(futures::stream::iter(vec![])); + + let mut decoder = FlightRecordBatchStream::new_from_flight_data(stream); + assert!(decoder.schema().is_none()); + // No batches come out + assert!(decoder.next().await.is_none()); + // schema has not been received + assert!(decoder.schema().is_none()); +} + +#[tokio::test] +async fn test_zero_batches_schema_specified() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let stream = FlightDataEncoderBuilder::default() + .with_schema(schema.clone()) + .build(futures::stream::iter(vec![])); + + let mut decoder = FlightRecordBatchStream::new_from_flight_data(stream); + assert!(decoder.schema().is_none()); + // No batches come out + assert!(decoder.next().await.is_none()); + // But schema has been received correctly + assert_eq!(decoder.schema(), Some(&schema)); +} + +#[tokio::test] +async fn test_with_flight_descriptor() { + let stream = futures::stream::iter(vec![Ok(make_dictionary_batch(5))]); + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); + + let descriptor = Some(FlightDescriptor { + r#type: DescriptorType::Path.into(), + path: vec!["table_name".to_string()], + cmd: Bytes::default(), + }); + + let encoder = FlightDataEncoderBuilder::default() + .with_schema(schema.clone()) + .with_flight_descriptor(descriptor.clone()); + + let mut encoder = encoder.build(stream); + + // First batch should be the schema + let first_batch = encoder.next().await.unwrap().unwrap(); + + assert_eq!(first_batch.flight_descriptor, descriptor); +} + +#[tokio::test] +async fn test_zero_batches_dictionary_schema_specified() { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new_dictionary("b", DataType::Int32, DataType::Utf8, false), + ])); + + // Expect dictionary to be hydrated in output (#3389) + let expected_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Utf8, false), + ])); + let stream = FlightDataEncoderBuilder::default() + .with_schema(schema.clone()) + .build(futures::stream::iter(vec![])); + + let mut decoder = FlightRecordBatchStream::new_from_flight_data(stream); + assert!(decoder.schema().is_none()); + // No batches come out + assert!(decoder.next().await.is_none()); + // But schema has been received correctly + assert_eq!(decoder.schema(), Some(&expected_schema)); +} + +#[tokio::test] +async fn test_app_metadata() { + let input_batch_stream = futures::stream::iter(vec![Ok(make_primitive_batch(78))]); + + let app_metadata = Bytes::from("My Metadata"); + let encoder = FlightDataEncoderBuilder::default().with_metadata(app_metadata.clone()); + + let encode_stream = encoder.build(input_batch_stream); + + // use lower level stream to get access to app metadata + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream).into_inner(); + + let mut messages: Vec<_> = decode_stream.try_collect().await.expect("encode fails"); + + println!("{messages:#?}"); + + // expect that the app metadata made it through on the schema message + assert_eq!(messages.len(), 2); + let message2 = messages.pop().unwrap(); + let message1 = messages.pop().unwrap(); + + assert_eq!(message1.app_metadata(), app_metadata); + assert!(matches!(message1.payload, DecodedPayload::Schema(_))); + + // but not on the data + assert_eq!(message2.app_metadata(), Bytes::new()); + assert!(matches!(message2.payload, DecodedPayload::RecordBatch(_))); +} + +#[tokio::test] +async fn test_max_message_size() { + let input_batch_stream = futures::stream::iter(vec![Ok(make_primitive_batch(5))]); + + // 5 input rows, with a very small limit should result in 5 batch messages + let encoder = FlightDataEncoderBuilder::default().with_max_flight_data_size(1); + + let encode_stream = encoder.build(input_batch_stream); + + // use lower level stream to get access to app metadata + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream).into_inner(); + + let messages: Vec<_> = decode_stream.try_collect().await.expect("encode fails"); + + println!("{messages:#?}"); + + assert_eq!(messages.len(), 6); + assert!(matches!(messages[0].payload, DecodedPayload::Schema(_))); + for message in messages.iter().skip(1) { + assert!(matches!(message.payload, DecodedPayload::RecordBatch(_))); + } +} + +#[tokio::test] +async fn test_max_message_size_fuzz() { + // send through batches of varying sizes with various max + // batch sizes and ensure the data gets through ok + let input = vec![ + make_primitive_batch(123), + make_primitive_batch(17), + make_primitive_batch(201), + make_primitive_batch(2), + make_primitive_batch(1), + make_primitive_batch(11), + make_primitive_batch(127), + ]; + + for max_message_size_bytes in [10, 1024, 2048, 6400, 3211212] { + let encoder = + FlightDataEncoderBuilder::default().with_max_flight_data_size(max_message_size_bytes); + + let input_batch_stream = futures::stream::iter(input.clone()).map(Ok); + + let encode_stream = encoder.build(input_batch_stream); + + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream); + let output: Vec<_> = decode_stream.try_collect().await.expect("encode / decode"); + + for b in &output { + assert_eq!(b.schema(), input[0].schema()); + } + + let a = pretty_format_batches(&input).unwrap().to_string(); + let b = pretty_format_batches(&output).unwrap().to_string(); + assert_eq!(a, b); + } +} + +#[tokio::test] +async fn test_mismatched_record_batch_schema() { + // send 2 batches with different schemas + let input_batch_stream = futures::stream::iter(vec![ + Ok(make_primitive_batch(5)), + Ok(make_dictionary_batch(3)), + ]); + + let encoder = FlightDataEncoderBuilder::default(); + let encode_stream = encoder.build(input_batch_stream); + + let result: Result, FlightError> = encode_stream.try_collect().await; + let err = result.unwrap_err(); + assert_eq!( + err.to_string(), + "Arrow(InvalidArgumentError(\"number of columns(1) must match number of fields(2) in schema\"))" + ); +} + +#[tokio::test] +async fn test_chained_streams_batch_decoder() { + let batch1 = make_primitive_batch(5); + let batch2 = make_dictionary_batch(3); + + // Model sending two flight streams back to back, with different schemas + let encode_stream1 = + FlightDataEncoderBuilder::default().build(futures::stream::iter(vec![Ok(batch1.clone())])); + let encode_stream2 = + FlightDataEncoderBuilder::default().build(futures::stream::iter(vec![Ok(batch2.clone())])); + + // append the two streams (so they will have two different schema messages) + let encode_stream = encode_stream1.chain(encode_stream2); + + // FlightRecordBatchStream errors if the schema changes + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream); + let result: Result, FlightError> = decode_stream.try_collect().await; + + let err = result.unwrap_err(); + assert_eq!( + err.to_string(), + "ProtocolError(\"Unexpectedly saw multiple Schema messages in FlightData stream\")" + ); +} + +#[tokio::test] +async fn test_chained_streams_data_decoder() { + let batch1 = make_primitive_batch(5); + let batch2 = make_dictionary_batch(3); + + // Model sending two flight streams back to back, with different schemas + let encode_stream1 = + FlightDataEncoderBuilder::default().build(futures::stream::iter(vec![Ok(batch1.clone())])); + let encode_stream2 = + FlightDataEncoderBuilder::default().build(futures::stream::iter(vec![Ok(batch2.clone())])); + + // append the two streams (so they will have two different schema messages) + let encode_stream = encode_stream1.chain(encode_stream2); + + // lower level decode stream can handle multiple schema messages + let decode_stream = FlightDataDecoder::new(encode_stream); + + let decoded_data: Vec<_> = decode_stream.try_collect().await.expect("encode / decode"); + + println!("decoded data: {decoded_data:#?}"); + + // expect two schema messages with the data + assert_eq!(decoded_data.len(), 4); + assert!(matches!(decoded_data[0].payload, DecodedPayload::Schema(_))); + assert!(matches!( + decoded_data[1].payload, + DecodedPayload::RecordBatch(_) + )); + assert!(matches!(decoded_data[2].payload, DecodedPayload::Schema(_))); + assert!(matches!( + decoded_data[3].payload, + DecodedPayload::RecordBatch(_) + )); +} + +#[tokio::test] +async fn test_mismatched_schema_message() { + // Model sending schema that is mismatched with the data + // and expect an error + async fn do_test(batch1: RecordBatch, batch2: RecordBatch, expected: &str) { + let encode_stream1 = FlightDataEncoderBuilder::default() + .build(futures::stream::iter(vec![Ok(batch1.clone())])) + // take only schema message from first stream + .take(1); + let encode_stream2 = FlightDataEncoderBuilder::default() + .build(futures::stream::iter(vec![Ok(batch2.clone())])) + // take only data message from second + .skip(1); + + // append the two streams + let encode_stream = encode_stream1.chain(encode_stream2); + + // FlightRecordBatchStream errors if the schema changes + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream); + let result: Result, FlightError> = decode_stream.try_collect().await; + + let err = result.unwrap_err().to_string(); + assert!( + err.contains(expected), + "could not find '{expected}' in '{err}'" + ); + } + + // primitive batch first (has more columns) + do_test( + make_primitive_batch(5), + make_dictionary_batch(3), + "Error decoding ipc RecordBatch: Schema error: Invalid data for schema", + ) + .await; + + // dictionary batch first + do_test( + make_dictionary_batch(3), + make_primitive_batch(5), + "Error decoding ipc RecordBatch: Invalid argument error", + ) + .await; +} + +/// Make a primitive batch for testing +/// +/// Example: +/// i: 0, 1, None, 3, 4 +/// f: 5.0, 4.0, None, 2.0, 1.0 +fn make_primitive_batch(num_rows: usize) -> RecordBatch { + let i: UInt8Array = (0..num_rows) + .map(|i| { + if i == num_rows / 2 { + None + } else { + Some(i.try_into().unwrap()) + } + }) + .collect(); + + let f: Float64Array = (0..num_rows) + .map(|i| { + if i == num_rows / 2 { + None + } else { + Some((num_rows - i) as f64) + } + }) + .collect(); + + RecordBatch::try_from_iter(vec![("i", Arc::new(i) as ArrayRef), ("f", Arc::new(f))]).unwrap() +} + +/// Make a dictionary batch for testing +/// +/// Example: +/// a: value0, value1, value2, None, value1, value2 +fn make_dictionary_batch(num_rows: usize) -> RecordBatch { + let values: Vec<_> = (0..num_rows) + .map(|i| { + if i == num_rows / 2 { + None + } else { + // repeat some values for low cardinality + let v = i / 3; + Some(format!("value{v}")) + } + }) + .collect(); + + let a: DictionaryArray = values + .iter() + .map(|s| s.as_ref().map(|s| s.as_str())) + .collect(); + + RecordBatch::try_from_iter(vec![("a", Arc::new(a) as ArrayRef)]).unwrap() +} + +/// Encodes input as a FlightData stream, and then decodes it using +/// FlightRecordBatchStream and valides the decoded record batches +/// match the input. +async fn roundtrip(input: Vec) { + let expected_output = input.clone(); + roundtrip_with_encoder(FlightDataEncoderBuilder::default(), input, expected_output).await +} + +/// Encodes input as a FlightData stream, and then decodes it using +/// FlightRecordBatchStream and valides the decoded record batches +/// match the expected input. +/// +/// When is resolved, +/// it should be possible to use `roundtrip` +async fn roundtrip_dictionary(input: Vec) { + let schema = Arc::new(prepare_schema_for_flight(&input[0].schema())); + let expected_output: Vec<_> = input + .iter() + .map(|batch| prepare_batch_for_flight(batch, schema.clone()).unwrap()) + .collect(); + roundtrip_with_encoder(FlightDataEncoderBuilder::default(), input, expected_output).await +} + +async fn roundtrip_with_encoder( + encoder: FlightDataEncoderBuilder, + input_batches: Vec, + expected_batches: Vec, +) { + println!("Round tripping with encoder:\n{encoder:#?}"); + + let input_batch_stream = futures::stream::iter(input_batches.clone()).map(Ok); + + let encode_stream = encoder.build(input_batch_stream); + + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream); + let output_batches: Vec<_> = decode_stream.try_collect().await.expect("encode / decode"); + + // remove any empty batches from input as they are not transmitted + let expected_batches: Vec<_> = expected_batches + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect(); + + assert_eq!(expected_batches, output_batches); +} + +/// Workaround for https://github.com/apache/arrow-rs/issues/1206 +fn prepare_schema_for_flight(schema: &Schema) -> Schema { + let fields: Fields = schema + .fields() + .iter() + .map(|field| match field.data_type() { + DataType::Dictionary(_, value_type) => Field::new( + field.name(), + value_type.as_ref().clone(), + field.is_nullable(), + ) + .with_metadata(field.metadata().clone()), + _ => field.as_ref().clone(), + }) + .collect(); + + Schema::new(fields) +} + +/// Workaround for https://github.com/apache/arrow-rs/issues/1206 +fn prepare_batch_for_flight( + batch: &RecordBatch, + schema: SchemaRef, +) -> Result { + let columns = batch + .columns() + .iter() + .map(hydrate_dictionary) + .collect::, _>>()?; + + Ok(RecordBatch::try_new(schema, columns)?) +} + +fn hydrate_dictionary(array: &ArrayRef) -> Result { + let arr = if let DataType::Dictionary(_, value) = array.data_type() { + arrow_cast::cast(array, value)? + } else { + Arc::clone(array) + }; + Ok(arr) +} diff --git a/arrow-flight/tests/flight_sql_client_cli.rs b/arrow-flight/tests/flight_sql_client_cli.rs new file mode 100644 index 000000000000..a28080450bc2 --- /dev/null +++ b/arrow-flight/tests/flight_sql_client_cli.rs @@ -0,0 +1,722 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{net::SocketAddr, pin::Pin, sync::Arc, time::Duration}; + +use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray}; +use arrow_flight::{ + decode::FlightRecordBatchStream, + flight_service_server::{FlightService, FlightServiceServer}, + sql::{ + server::{FlightSqlService, PeekableFlightDataStream}, + ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest, + ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult, + ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, + ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest, + ActionEndSavepointRequest, ActionEndTransactionRequest, Any, CommandGetCatalogs, + CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, + CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, + CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, + CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementSubstraitPlan, + CommandStatementUpdate, ProstMessageExt, SqlInfo, TicketStatementQuery, + }, + utils::batches_to_flight_data, + Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, + HandshakeResponse, IpcMessage, PutResult, SchemaAsIpc, Ticket, +}; +use arrow_ipc::writer::IpcWriteOptions; +use arrow_schema::{ArrowError, DataType, Field, Schema}; +use assert_cmd::Command; +use bytes::Bytes; +use futures::{Stream, StreamExt, TryStreamExt}; +use prost::Message; +use tokio::{net::TcpListener, task::JoinHandle}; +use tonic::{Request, Response, Status, Streaming}; + +const QUERY: &str = "SELECT * FROM table;"; + +#[tokio::test] +async fn test_simple() { + let test_server = FlightSqlServiceImpl {}; + let fixture = TestFixture::new(&test_server).await; + let addr = fixture.addr; + + let stdout = tokio::task::spawn_blocking(move || { + Command::cargo_bin("flight_sql_client") + .unwrap() + .env_clear() + .env("RUST_BACKTRACE", "1") + .env("RUST_LOG", "warn") + .arg("--host") + .arg(addr.ip().to_string()) + .arg("--port") + .arg(addr.port().to_string()) + .arg("statement-query") + .arg(QUERY) + .assert() + .success() + .get_output() + .stdout + .clone() + }) + .await + .unwrap(); + + fixture.shutdown_and_wait().await; + + assert_eq!( + std::str::from_utf8(&stdout).unwrap().trim(), + "+--------------+-----------+\ + \n| field_string | field_int |\ + \n+--------------+-----------+\ + \n| Hello | 42 |\ + \n| lovely | |\ + \n| FlightSQL! | 1337 |\ + \n+--------------+-----------+", + ); +} + +const PREPARED_QUERY: &str = "SELECT * FROM table WHERE field = $1"; +const PREPARED_STATEMENT_HANDLE: &str = "prepared_statement_handle"; + +#[tokio::test] +async fn test_do_put_prepared_statement() { + let test_server = FlightSqlServiceImpl {}; + let fixture = TestFixture::new(&test_server).await; + let addr = fixture.addr; + + let stdout = tokio::task::spawn_blocking(move || { + Command::cargo_bin("flight_sql_client") + .unwrap() + .env_clear() + .env("RUST_BACKTRACE", "1") + .env("RUST_LOG", "warn") + .arg("--host") + .arg(addr.ip().to_string()) + .arg("--port") + .arg(addr.port().to_string()) + .arg("prepared-statement-query") + .arg(PREPARED_QUERY) + .args(["-p", "$1=string"]) + .args(["-p", "$2=64"]) + .assert() + .success() + .get_output() + .stdout + .clone() + }) + .await + .unwrap(); + + fixture.shutdown_and_wait().await; + + assert_eq!( + std::str::from_utf8(&stdout).unwrap().trim(), + "+--------------+-----------+\ + \n| field_string | field_int |\ + \n+--------------+-----------+\ + \n| Hello | 42 |\ + \n| lovely | |\ + \n| FlightSQL! | 1337 |\ + \n+--------------+-----------+", + ); +} + +/// All tests must complete within this many seconds or else the test server is shutdown +const DEFAULT_TIMEOUT_SECONDS: u64 = 30; + +#[derive(Clone, Default)] +pub struct FlightSqlServiceImpl {} + +impl FlightSqlServiceImpl { + /// Return an [`FlightServiceServer`] that can be used with a + /// [`Server`](tonic::transport::Server) + pub fn service(&self) -> FlightServiceServer { + // wrap up tonic goop + FlightServiceServer::new(self.clone()) + } + + fn fake_result() -> Result { + let schema = Schema::new(vec![ + Field::new("field_string", DataType::Utf8, false), + Field::new("field_int", DataType::Int64, true), + ]); + + let string_array = StringArray::from(vec!["Hello", "lovely", "FlightSQL!"]); + let int_array = Int64Array::from(vec![Some(42), None, Some(1337)]); + + let cols = vec![ + Arc::new(string_array) as ArrayRef, + Arc::new(int_array) as ArrayRef, + ]; + RecordBatch::try_new(Arc::new(schema), cols) + } + + fn create_fake_prepared_stmt() -> Result { + let handle = PREPARED_STATEMENT_HANDLE.to_string(); + let schema = Schema::new(vec![ + Field::new("field_string", DataType::Utf8, false), + Field::new("field_int", DataType::Int64, true), + ]); + + let parameter_schema = Schema::new(vec![ + Field::new("$1", DataType::Utf8, false), + Field::new("$2", DataType::Int64, true), + ]); + + Ok(ActionCreatePreparedStatementResult { + prepared_statement_handle: handle.into(), + dataset_schema: serialize_schema(&schema)?, + parameter_schema: serialize_schema(¶meter_schema)?, + }) + } + + fn fake_flight_info(&self) -> Result { + let batch = Self::fake_result()?; + + Ok(FlightInfo::new() + .try_with_schema(&batch.schema()) + .expect("encoding schema") + .with_endpoint( + FlightEndpoint::new().with_ticket(Ticket::new( + FetchResults { + handle: String::from("part_1"), + } + .as_any() + .encode_to_vec(), + )), + ) + .with_endpoint( + FlightEndpoint::new().with_ticket(Ticket::new( + FetchResults { + handle: String::from("part_2"), + } + .as_any() + .encode_to_vec(), + )), + ) + .with_total_records(batch.num_rows() as i64) + .with_total_bytes(batch.get_array_memory_size() as i64) + .with_ordered(false)) + } +} + +fn serialize_schema(schema: &Schema) -> Result { + Ok(IpcMessage::try_from(SchemaAsIpc::new(schema, &IpcWriteOptions::default()))?.0) +} + +#[tonic::async_trait] +impl FlightSqlService for FlightSqlServiceImpl { + type FlightService = FlightSqlServiceImpl; + + async fn do_handshake( + &self, + _request: Request>, + ) -> Result< + Response> + Send>>>, + Status, + > { + Err(Status::unimplemented("do_handshake not implemented")) + } + + async fn do_get_fallback( + &self, + _request: Request, + message: Any, + ) -> Result::DoGetStream>, Status> { + let part = message.unpack::().unwrap().unwrap().handle; + let batch = Self::fake_result().unwrap(); + let batch = match part.as_str() { + "part_1" => batch.slice(0, 2), + "part_2" => batch.slice(2, 1), + ticket => panic!("Invalid ticket: {ticket:?}"), + }; + let schema = batch.schema(); + let batches = vec![batch]; + let flight_data = batches_to_flight_data(schema.as_ref(), batches) + .unwrap() + .into_iter() + .map(Ok); + + let stream: Pin> + Send>> = + Box::pin(futures::stream::iter(flight_data)); + let resp = Response::new(stream); + Ok(resp) + } + + async fn get_flight_info_statement( + &self, + query: CommandStatementQuery, + _request: Request, + ) -> Result, Status> { + assert_eq!(query.query, QUERY); + + let resp = Response::new(self.fake_flight_info().unwrap()); + Ok(resp) + } + + async fn get_flight_info_prepared_statement( + &self, + cmd: CommandPreparedStatementQuery, + _request: Request, + ) -> Result, Status> { + assert_eq!( + cmd.prepared_statement_handle, + PREPARED_STATEMENT_HANDLE.as_bytes() + ); + let resp = Response::new(self.fake_flight_info().unwrap()); + Ok(resp) + } + + async fn get_flight_info_substrait_plan( + &self, + _query: CommandStatementSubstraitPlan, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_substrait_plan not implemented", + )) + } + + async fn get_flight_info_catalogs( + &self, + _query: CommandGetCatalogs, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_catalogs not implemented", + )) + } + + async fn get_flight_info_schemas( + &self, + _query: CommandGetDbSchemas, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_schemas not implemented", + )) + } + + async fn get_flight_info_tables( + &self, + _query: CommandGetTables, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_tables not implemented", + )) + } + + async fn get_flight_info_table_types( + &self, + _query: CommandGetTableTypes, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_table_types not implemented", + )) + } + + async fn get_flight_info_sql_info( + &self, + _query: CommandGetSqlInfo, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_sql_info not implemented", + )) + } + + async fn get_flight_info_primary_keys( + &self, + _query: CommandGetPrimaryKeys, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_primary_keys not implemented", + )) + } + + async fn get_flight_info_exported_keys( + &self, + _query: CommandGetExportedKeys, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_exported_keys not implemented", + )) + } + + async fn get_flight_info_imported_keys( + &self, + _query: CommandGetImportedKeys, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_imported_keys not implemented", + )) + } + + async fn get_flight_info_cross_reference( + &self, + _query: CommandGetCrossReference, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_imported_keys not implemented", + )) + } + + async fn get_flight_info_xdbc_type_info( + &self, + _query: CommandGetXdbcTypeInfo, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "get_flight_info_xdbc_type_info not implemented", + )) + } + + // do_get + async fn do_get_statement( + &self, + _ticket: TicketStatementQuery, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented("do_get_statement not implemented")) + } + + async fn do_get_prepared_statement( + &self, + _query: CommandPreparedStatementQuery, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_prepared_statement not implemented", + )) + } + + async fn do_get_catalogs( + &self, + _query: CommandGetCatalogs, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented("do_get_catalogs not implemented")) + } + + async fn do_get_schemas( + &self, + _query: CommandGetDbSchemas, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented("do_get_schemas not implemented")) + } + + async fn do_get_tables( + &self, + _query: CommandGetTables, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented("do_get_tables not implemented")) + } + + async fn do_get_table_types( + &self, + _query: CommandGetTableTypes, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented("do_get_table_types not implemented")) + } + + async fn do_get_sql_info( + &self, + _query: CommandGetSqlInfo, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented("do_get_sql_info not implemented")) + } + + async fn do_get_primary_keys( + &self, + _query: CommandGetPrimaryKeys, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented("do_get_primary_keys not implemented")) + } + + async fn do_get_exported_keys( + &self, + _query: CommandGetExportedKeys, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_exported_keys not implemented", + )) + } + + async fn do_get_imported_keys( + &self, + _query: CommandGetImportedKeys, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_imported_keys not implemented", + )) + } + + async fn do_get_cross_reference( + &self, + _query: CommandGetCrossReference, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_cross_reference not implemented", + )) + } + + async fn do_get_xdbc_type_info( + &self, + _query: CommandGetXdbcTypeInfo, + _request: Request, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented( + "do_get_xdbc_type_info not implemented", + )) + } + + // do_put + async fn do_put_statement_update( + &self, + _ticket: CommandStatementUpdate, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_put_statement_update not implemented", + )) + } + + async fn do_put_substrait_plan( + &self, + _ticket: CommandStatementSubstraitPlan, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_put_substrait_plan not implemented", + )) + } + + async fn do_put_prepared_statement_query( + &self, + _query: CommandPreparedStatementQuery, + request: Request, + ) -> Result::DoPutStream>, Status> { + // just make sure decoding the parameters works + let parameters = FlightRecordBatchStream::new_from_flight_data( + request.into_inner().map_err(|e| e.into()), + ) + .try_collect::>() + .await?; + + for (left, right) in parameters[0].schema().all_fields().iter().zip(vec![ + Field::new("$1", DataType::Utf8, false), + Field::new("$2", DataType::Int64, true), + ]) { + if left.name() != right.name() || left.data_type() != right.data_type() { + return Err(Status::invalid_argument(format!( + "Parameters did not match parameter schema\ngot {}", + parameters[0].schema(), + ))); + } + } + + Ok(Response::new( + futures::stream::once(async { Ok(PutResult::default()) }).boxed(), + )) + } + + async fn do_put_prepared_statement_update( + &self, + _query: CommandPreparedStatementUpdate, + _request: Request, + ) -> Result { + Err(Status::unimplemented( + "do_put_prepared_statement_update not implemented", + )) + } + + async fn do_action_create_prepared_statement( + &self, + _query: ActionCreatePreparedStatementRequest, + _request: Request, + ) -> Result { + Self::create_fake_prepared_stmt() + .map_err(|e| Status::internal(format!("Unable to serialize schema: {e}"))) + } + + async fn do_action_close_prepared_statement( + &self, + _query: ActionClosePreparedStatementRequest, + _request: Request, + ) -> Result<(), Status> { + unimplemented!("Implement do_action_close_prepared_statement") + } + + async fn do_action_create_prepared_substrait_plan( + &self, + _query: ActionCreatePreparedSubstraitPlanRequest, + _request: Request, + ) -> Result { + unimplemented!("Implement do_action_create_prepared_substrait_plan") + } + + async fn do_action_begin_transaction( + &self, + _query: ActionBeginTransactionRequest, + _request: Request, + ) -> Result { + unimplemented!("Implement do_action_begin_transaction") + } + + async fn do_action_end_transaction( + &self, + _query: ActionEndTransactionRequest, + _request: Request, + ) -> Result<(), Status> { + unimplemented!("Implement do_action_end_transaction") + } + + async fn do_action_begin_savepoint( + &self, + _query: ActionBeginSavepointRequest, + _request: Request, + ) -> Result { + unimplemented!("Implement do_action_begin_savepoint") + } + + async fn do_action_end_savepoint( + &self, + _query: ActionEndSavepointRequest, + _request: Request, + ) -> Result<(), Status> { + unimplemented!("Implement do_action_end_savepoint") + } + + async fn do_action_cancel_query( + &self, + _query: ActionCancelQueryRequest, + _request: Request, + ) -> Result { + unimplemented!("Implement do_action_cancel_query") + } + + async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} +} + +/// Creates and manages a running TestServer with a background task +struct TestFixture { + /// channel to send shutdown command + shutdown: Option>, + + /// Address the server is listening on + addr: SocketAddr, + + // handle for the server task + handle: Option>>, +} + +impl TestFixture { + /// create a new test fixture from the server + pub async fn new(test_server: &FlightSqlServiceImpl) -> Self { + // let OS choose a a free port + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + println!("Listening on {addr}"); + + // prepare the shutdown channel + let (tx, rx) = tokio::sync::oneshot::channel(); + + let server_timeout = Duration::from_secs(DEFAULT_TIMEOUT_SECONDS); + + let shutdown_future = async move { + rx.await.ok(); + }; + + let serve_future = tonic::transport::Server::builder() + .timeout(server_timeout) + .add_service(test_server.service()) + .serve_with_incoming_shutdown( + tokio_stream::wrappers::TcpListenerStream::new(listener), + shutdown_future, + ); + + // Run the server in its own background task + let handle = tokio::task::spawn(serve_future); + + Self { + shutdown: Some(tx), + addr, + handle: Some(handle), + } + } + + /// Stops the test server and waits for the server to shutdown + pub async fn shutdown_and_wait(mut self) { + if let Some(shutdown) = self.shutdown.take() { + shutdown.send(()).expect("server quit early"); + } + if let Some(handle) = self.handle.take() { + println!("Waiting on server to finish"); + handle + .await + .expect("task join error (panic?)") + .expect("Server Error found at shutdown"); + } + } +} + +impl Drop for TestFixture { + fn drop(&mut self) { + if let Some(shutdown) = self.shutdown.take() { + shutdown.send(()).ok(); + } + if self.handle.is_some() { + // tests should properly clean up TestFixture + println!("TestFixture::Drop called prior to `shutdown_and_wait`"); + } + } +} + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FetchResults { + #[prost(string, tag = "1")] + pub handle: ::prost::alloc::string::String, +} + +impl ProstMessageExt for FetchResults { + fn type_url() -> &'static str { + "type.googleapis.com/arrow.flight.protocol.sql.FetchResults" + } + + fn as_any(&self) -> Any { + Any { + type_url: FetchResults::type_url().to_string(), + value: ::prost::Message::encode_to_vec(self).into(), + } + } +} diff --git a/arrow-integration-test/Cargo.toml b/arrow-integration-test/Cargo.toml new file mode 100644 index 000000000000..8afbfacff7c3 --- /dev/null +++ b/arrow-integration-test/Cargo.toml @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "arrow-integration-test" +version = { workspace = true } +description = "Support for the Apache Arrow JSON test data format" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "arrow_integration_test" +path = "src/lib.rs" +bench = false + +[dependencies] +arrow = { workspace = true } +arrow-buffer = { workspace = true } +hex = { version = "0.4", default-features = false, features = ["std"] } +serde = { version = "1.0", default-features = false, features = ["rc", "derive"] } +serde_json = { version = "1.0", default-features = false, features = ["std"] } +num = { version = "0.4", default-features = false, features = ["std"] } + +[build-dependencies] diff --git a/integration-testing/data/integration.json b/arrow-integration-test/data/integration.json similarity index 100% rename from integration-testing/data/integration.json rename to arrow-integration-test/data/integration.json diff --git a/arrow-integration-test/src/datatype.rs b/arrow-integration-test/src/datatype.rs new file mode 100644 index 000000000000..42ac71fbbd7e --- /dev/null +++ b/arrow-integration-test/src/datatype.rs @@ -0,0 +1,367 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit, UnionMode}; +use arrow::error::{ArrowError, Result}; +use std::sync::Arc; + +/// Parse a data type from a JSON representation. +pub fn data_type_from_json(json: &serde_json::Value) -> Result { + use serde_json::Value; + let default_field = Arc::new(Field::new("", DataType::Boolean, true)); + match *json { + Value::Object(ref map) => match map.get("name") { + Some(s) if s == "null" => Ok(DataType::Null), + Some(s) if s == "bool" => Ok(DataType::Boolean), + Some(s) if s == "binary" => Ok(DataType::Binary), + Some(s) if s == "largebinary" => Ok(DataType::LargeBinary), + Some(s) if s == "utf8" => Ok(DataType::Utf8), + Some(s) if s == "largeutf8" => Ok(DataType::LargeUtf8), + Some(s) if s == "fixedsizebinary" => { + // return a list with any type as its child isn't defined in the map + if let Some(Value::Number(size)) = map.get("byteWidth") { + Ok(DataType::FixedSizeBinary(size.as_i64().unwrap() as i32)) + } else { + Err(ArrowError::ParseError( + "Expecting a byteWidth for fixedsizebinary".to_string(), + )) + } + } + Some(s) if s == "decimal" => { + // return a list with any type as its child isn't defined in the map + let precision = match map.get("precision") { + Some(p) => Ok(p.as_u64().unwrap().try_into().unwrap()), + None => Err(ArrowError::ParseError( + "Expecting a precision for decimal".to_string(), + )), + }?; + let scale = match map.get("scale") { + Some(s) => Ok(s.as_u64().unwrap().try_into().unwrap()), + _ => Err(ArrowError::ParseError( + "Expecting a scale for decimal".to_string(), + )), + }?; + let bit_width: usize = match map.get("bitWidth") { + Some(b) => b.as_u64().unwrap() as usize, + _ => 128, // Default bit width + }; + + if bit_width == 128 { + Ok(DataType::Decimal128(precision, scale)) + } else if bit_width == 256 { + Ok(DataType::Decimal256(precision, scale)) + } else { + Err(ArrowError::ParseError( + "Decimal bit_width invalid".to_string(), + )) + } + } + Some(s) if s == "floatingpoint" => match map.get("precision") { + Some(p) if p == "HALF" => Ok(DataType::Float16), + Some(p) if p == "SINGLE" => Ok(DataType::Float32), + Some(p) if p == "DOUBLE" => Ok(DataType::Float64), + _ => Err(ArrowError::ParseError( + "floatingpoint precision missing or invalid".to_string(), + )), + }, + Some(s) if s == "timestamp" => { + let unit = match map.get("unit") { + Some(p) if p == "SECOND" => Ok(TimeUnit::Second), + Some(p) if p == "MILLISECOND" => Ok(TimeUnit::Millisecond), + Some(p) if p == "MICROSECOND" => Ok(TimeUnit::Microsecond), + Some(p) if p == "NANOSECOND" => Ok(TimeUnit::Nanosecond), + _ => Err(ArrowError::ParseError( + "timestamp unit missing or invalid".to_string(), + )), + }; + let tz = match map.get("timezone") { + None => Ok(None), + Some(Value::String(tz)) => Ok(Some(tz.as_str().into())), + _ => Err(ArrowError::ParseError( + "timezone must be a string".to_string(), + )), + }; + Ok(DataType::Timestamp(unit?, tz?)) + } + Some(s) if s == "date" => match map.get("unit") { + Some(p) if p == "DAY" => Ok(DataType::Date32), + Some(p) if p == "MILLISECOND" => Ok(DataType::Date64), + _ => Err(ArrowError::ParseError( + "date unit missing or invalid".to_string(), + )), + }, + Some(s) if s == "time" => { + let unit = match map.get("unit") { + Some(p) if p == "SECOND" => Ok(TimeUnit::Second), + Some(p) if p == "MILLISECOND" => Ok(TimeUnit::Millisecond), + Some(p) if p == "MICROSECOND" => Ok(TimeUnit::Microsecond), + Some(p) if p == "NANOSECOND" => Ok(TimeUnit::Nanosecond), + _ => Err(ArrowError::ParseError( + "time unit missing or invalid".to_string(), + )), + }; + match map.get("bitWidth") { + Some(p) if p == 32 => Ok(DataType::Time32(unit?)), + Some(p) if p == 64 => Ok(DataType::Time64(unit?)), + _ => Err(ArrowError::ParseError( + "time bitWidth missing or invalid".to_string(), + )), + } + } + Some(s) if s == "duration" => match map.get("unit") { + Some(p) if p == "SECOND" => Ok(DataType::Duration(TimeUnit::Second)), + Some(p) if p == "MILLISECOND" => Ok(DataType::Duration(TimeUnit::Millisecond)), + Some(p) if p == "MICROSECOND" => Ok(DataType::Duration(TimeUnit::Microsecond)), + Some(p) if p == "NANOSECOND" => Ok(DataType::Duration(TimeUnit::Nanosecond)), + _ => Err(ArrowError::ParseError( + "time unit missing or invalid".to_string(), + )), + }, + Some(s) if s == "interval" => match map.get("unit") { + Some(p) if p == "DAY_TIME" => Ok(DataType::Interval(IntervalUnit::DayTime)), + Some(p) if p == "YEAR_MONTH" => Ok(DataType::Interval(IntervalUnit::YearMonth)), + Some(p) if p == "MONTH_DAY_NANO" => { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } + _ => Err(ArrowError::ParseError( + "interval unit missing or invalid".to_string(), + )), + }, + Some(s) if s == "int" => match map.get("isSigned") { + Some(&Value::Bool(true)) => match map.get("bitWidth") { + Some(Value::Number(n)) => match n.as_u64() { + Some(8) => Ok(DataType::Int8), + Some(16) => Ok(DataType::Int16), + Some(32) => Ok(DataType::Int32), + Some(64) => Ok(DataType::Int64), + _ => Err(ArrowError::ParseError( + "int bitWidth missing or invalid".to_string(), + )), + }, + _ => Err(ArrowError::ParseError( + "int bitWidth missing or invalid".to_string(), + )), + }, + Some(&Value::Bool(false)) => match map.get("bitWidth") { + Some(Value::Number(n)) => match n.as_u64() { + Some(8) => Ok(DataType::UInt8), + Some(16) => Ok(DataType::UInt16), + Some(32) => Ok(DataType::UInt32), + Some(64) => Ok(DataType::UInt64), + _ => Err(ArrowError::ParseError( + "int bitWidth missing or invalid".to_string(), + )), + }, + _ => Err(ArrowError::ParseError( + "int bitWidth missing or invalid".to_string(), + )), + }, + _ => Err(ArrowError::ParseError( + "int signed missing or invalid".to_string(), + )), + }, + Some(s) if s == "list" => { + // return a list with any type as its child isn't defined in the map + Ok(DataType::List(default_field)) + } + Some(s) if s == "largelist" => { + // return a largelist with any type as its child isn't defined in the map + Ok(DataType::LargeList(default_field)) + } + Some(s) if s == "fixedsizelist" => { + // return a list with any type as its child isn't defined in the map + if let Some(Value::Number(size)) = map.get("listSize") { + Ok(DataType::FixedSizeList( + default_field, + size.as_i64().unwrap() as i32, + )) + } else { + Err(ArrowError::ParseError( + "Expecting a listSize for fixedsizelist".to_string(), + )) + } + } + Some(s) if s == "struct" => { + // return an empty `struct` type as its children aren't defined in the map + Ok(DataType::Struct(Fields::empty())) + } + Some(s) if s == "map" => { + if let Some(Value::Bool(keys_sorted)) = map.get("keysSorted") { + // Return a map with an empty type as its children aren't defined in the map + Ok(DataType::Map(default_field, *keys_sorted)) + } else { + Err(ArrowError::ParseError( + "Expecting a keysSorted for map".to_string(), + )) + } + } + Some(s) if s == "union" => { + if let Some(Value::String(mode)) = map.get("mode") { + let union_mode = if mode == "SPARSE" { + UnionMode::Sparse + } else if mode == "DENSE" { + UnionMode::Dense + } else { + return Err(ArrowError::ParseError(format!( + "Unknown union mode {mode:?} for union" + ))); + }; + if let Some(values) = map.get("typeIds") { + let values = values.as_array().unwrap(); + let fields = values + .iter() + .map(|t| (t.as_i64().unwrap() as i8, default_field.clone())) + .collect(); + + Ok(DataType::Union(fields, union_mode)) + } else { + Err(ArrowError::ParseError( + "Expecting a typeIds for union ".to_string(), + )) + } + } else { + Err(ArrowError::ParseError( + "Expecting a mode for union".to_string(), + )) + } + } + Some(other) => Err(ArrowError::ParseError(format!( + "invalid or unsupported type name: {other} in {json:?}" + ))), + None => Err(ArrowError::ParseError("type name missing".to_string())), + }, + _ => Err(ArrowError::ParseError( + "invalid json value type".to_string(), + )), + } +} + +/// Generate a JSON representation of the data type. +pub fn data_type_to_json(data_type: &DataType) -> serde_json::Value { + use serde_json::json; + match data_type { + DataType::Null => json!({"name": "null"}), + DataType::Boolean => json!({"name": "bool"}), + DataType::Int8 => json!({"name": "int", "bitWidth": 8, "isSigned": true}), + DataType::Int16 => json!({"name": "int", "bitWidth": 16, "isSigned": true}), + DataType::Int32 => json!({"name": "int", "bitWidth": 32, "isSigned": true}), + DataType::Int64 => json!({"name": "int", "bitWidth": 64, "isSigned": true}), + DataType::UInt8 => json!({"name": "int", "bitWidth": 8, "isSigned": false}), + DataType::UInt16 => json!({"name": "int", "bitWidth": 16, "isSigned": false}), + DataType::UInt32 => json!({"name": "int", "bitWidth": 32, "isSigned": false}), + DataType::UInt64 => json!({"name": "int", "bitWidth": 64, "isSigned": false}), + DataType::Float16 => json!({"name": "floatingpoint", "precision": "HALF"}), + DataType::Float32 => json!({"name": "floatingpoint", "precision": "SINGLE"}), + DataType::Float64 => json!({"name": "floatingpoint", "precision": "DOUBLE"}), + DataType::Utf8 => json!({"name": "utf8"}), + DataType::LargeUtf8 => json!({"name": "largeutf8"}), + DataType::Binary => json!({"name": "binary"}), + DataType::LargeBinary => json!({"name": "largebinary"}), + DataType::FixedSizeBinary(byte_width) => { + json!({"name": "fixedsizebinary", "byteWidth": byte_width}) + } + DataType::Struct(_) => json!({"name": "struct"}), + DataType::Union(_, _) => json!({"name": "union"}), + DataType::List(_) => json!({ "name": "list"}), + DataType::LargeList(_) => json!({ "name": "largelist"}), + DataType::FixedSizeList(_, length) => { + json!({"name":"fixedsizelist", "listSize": length}) + } + DataType::Time32(unit) => { + json!({"name": "time", "bitWidth": 32, "unit": match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }}) + } + DataType::Time64(unit) => { + json!({"name": "time", "bitWidth": 64, "unit": match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }}) + } + DataType::Date32 => { + json!({"name": "date", "unit": "DAY"}) + } + DataType::Date64 => { + json!({"name": "date", "unit": "MILLISECOND"}) + } + DataType::Timestamp(unit, None) => { + json!({"name": "timestamp", "unit": match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }}) + } + DataType::Timestamp(unit, Some(tz)) => { + json!({"name": "timestamp", "unit": match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }, "timezone": tz}) + } + DataType::Interval(unit) => json!({"name": "interval", "unit": match unit { + IntervalUnit::YearMonth => "YEAR_MONTH", + IntervalUnit::DayTime => "DAY_TIME", + IntervalUnit::MonthDayNano => "MONTH_DAY_NANO", + }}), + DataType::Duration(unit) => json!({"name": "duration", "unit": match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }}), + DataType::Dictionary(_, _) => json!({ "name": "dictionary"}), + DataType::Decimal128(precision, scale) => { + json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 128}) + } + DataType::Decimal256(precision, scale) => { + json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 256}) + } + DataType::Map(_, keys_sorted) => { + json!({"name": "map", "keysSorted": keys_sorted}) + } + DataType::RunEndEncoded(_, _) => todo!(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::Value; + + #[test] + fn parse_utf8_from_json() { + let json = "{\"name\":\"utf8\"}"; + let value: Value = serde_json::from_str(json).unwrap(); + let dt = data_type_from_json(&value).unwrap(); + assert_eq!(DataType::Utf8, dt); + } + + #[test] + fn parse_int32_from_json() { + let json = "{\"name\": \"int\", \"isSigned\": true, \"bitWidth\": 32}"; + let value: Value = serde_json::from_str(json).unwrap(); + let dt = data_type_from_json(&value).unwrap(); + assert_eq!(DataType::Int32, dt); + } +} diff --git a/arrow-integration-test/src/field.rs b/arrow-integration-test/src/field.rs new file mode 100644 index 000000000000..32edc4165938 --- /dev/null +++ b/arrow-integration-test/src/field.rs @@ -0,0 +1,568 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{data_type_from_json, data_type_to_json}; +use arrow::datatypes::{DataType, Field}; +use arrow::error::{ArrowError, Result}; +use std::collections::HashMap; +use std::sync::Arc; + +/// Parse a `Field` definition from a JSON representation. +pub fn field_from_json(json: &serde_json::Value) -> Result { + use serde_json::Value; + match *json { + Value::Object(ref map) => { + let name = match map.get("name") { + Some(Value::String(name)) => name.to_string(), + _ => { + return Err(ArrowError::ParseError( + "Field missing 'name' attribute".to_string(), + )); + } + }; + let nullable = match map.get("nullable") { + Some(&Value::Bool(b)) => b, + _ => { + return Err(ArrowError::ParseError( + "Field missing 'nullable' attribute".to_string(), + )); + } + }; + let data_type = match map.get("type") { + Some(t) => data_type_from_json(t)?, + _ => { + return Err(ArrowError::ParseError( + "Field missing 'type' attribute".to_string(), + )); + } + }; + + // Referenced example file: testing/data/arrow-ipc-stream/integration/1.0.0-littleendian/generated_custom_metadata.json.gz + let metadata = match map.get("metadata") { + Some(Value::Array(values)) => { + let mut res: HashMap = HashMap::default(); + for value in values { + match value.as_object() { + Some(map) => { + if map.len() != 2 { + return Err(ArrowError::ParseError( + "Field 'metadata' must have exact two entries for each key-value map".to_string(), + )); + } + if let (Some(k), Some(v)) = (map.get("key"), map.get("value")) { + if let (Some(k_str), Some(v_str)) = (k.as_str(), v.as_str()) { + res.insert( + k_str.to_string().clone(), + v_str.to_string().clone(), + ); + } else { + return Err(ArrowError::ParseError( + "Field 'metadata' must have map value of string type" + .to_string(), + )); + } + } else { + return Err(ArrowError::ParseError("Field 'metadata' lacks map keys named \"key\" or \"value\"".to_string())); + } + } + _ => { + return Err(ArrowError::ParseError( + "Field 'metadata' contains non-object key-value pair" + .to_string(), + )); + } + } + } + res + } + // We also support map format, because Schema's metadata supports this. + // See https://github.com/apache/arrow/pull/5907 + Some(Value::Object(values)) => { + let mut res: HashMap = HashMap::default(); + for (k, v) in values { + if let Some(str_value) = v.as_str() { + res.insert(k.clone(), str_value.to_string().clone()); + } else { + return Err(ArrowError::ParseError(format!( + "Field 'metadata' contains non-string value for key {k}" + ))); + } + } + res + } + Some(_) => { + return Err(ArrowError::ParseError( + "Field `metadata` is not json array".to_string(), + )); + } + _ => HashMap::default(), + }; + + // if data_type is a struct or list, get its children + let data_type = match data_type { + DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) => { + match map.get("children") { + Some(Value::Array(values)) => { + if values.len() != 1 { + return Err(ArrowError::ParseError( + "Field 'children' must have one element for a list data type" + .to_string(), + )); + } + match data_type { + DataType::List(_) => { + DataType::List(Arc::new(field_from_json(&values[0])?)) + } + DataType::LargeList(_) => { + DataType::LargeList(Arc::new(field_from_json(&values[0])?)) + } + DataType::FixedSizeList(_, int) => DataType::FixedSizeList( + Arc::new(field_from_json(&values[0])?), + int, + ), + _ => unreachable!( + "Data type should be a list, largelist or fixedsizelist" + ), + } + } + Some(_) => { + return Err(ArrowError::ParseError( + "Field 'children' must be an array".to_string(), + )) + } + None => { + return Err(ArrowError::ParseError( + "Field missing 'children' attribute".to_string(), + )); + } + } + } + DataType::Struct(_) => match map.get("children") { + Some(Value::Array(values)) => { + DataType::Struct(values.iter().map(field_from_json).collect::>()?) + } + Some(_) => { + return Err(ArrowError::ParseError( + "Field 'children' must be an array".to_string(), + )) + } + None => { + return Err(ArrowError::ParseError( + "Field missing 'children' attribute".to_string(), + )); + } + }, + DataType::Map(_, keys_sorted) => { + match map.get("children") { + Some(Value::Array(values)) if values.len() == 1 => { + let child = field_from_json(&values[0])?; + // child must be a struct + match child.data_type() { + DataType::Struct(map_fields) if map_fields.len() == 2 => { + DataType::Map(Arc::new(child), keys_sorted) + } + t => { + return Err(ArrowError::ParseError(format!( + "Map children should be a struct with 2 fields, found {t:?}" + ))) + } + } + } + Some(_) => { + return Err(ArrowError::ParseError( + "Field 'children' must be an array with 1 element".to_string(), + )) + } + None => { + return Err(ArrowError::ParseError( + "Field missing 'children' attribute".to_string(), + )); + } + } + } + DataType::Union(fields, mode) => match map.get("children") { + Some(Value::Array(values)) => { + let fields = fields + .iter() + .zip(values) + .map(|((id, _), value)| Ok((id, Arc::new(field_from_json(value)?)))) + .collect::>()?; + + DataType::Union(fields, mode) + } + Some(_) => { + return Err(ArrowError::ParseError( + "Field 'children' must be an array".to_string(), + )) + } + None => { + return Err(ArrowError::ParseError( + "Field missing 'children' attribute".to_string(), + )); + } + }, + _ => data_type, + }; + + let mut dict_id = 0; + let mut dict_is_ordered = false; + + let data_type = match map.get("dictionary") { + Some(dictionary) => { + let index_type = match dictionary.get("indexType") { + Some(t) => data_type_from_json(t)?, + _ => { + return Err(ArrowError::ParseError( + "Field missing 'indexType' attribute".to_string(), + )); + } + }; + dict_id = match dictionary.get("id") { + Some(Value::Number(n)) => n.as_i64().unwrap(), + _ => { + return Err(ArrowError::ParseError( + "Field missing 'id' attribute".to_string(), + )); + } + }; + dict_is_ordered = match dictionary.get("isOrdered") { + Some(&Value::Bool(n)) => n, + _ => { + return Err(ArrowError::ParseError( + "Field missing 'isOrdered' attribute".to_string(), + )); + } + }; + DataType::Dictionary(Box::new(index_type), Box::new(data_type)) + } + _ => data_type, + }; + + let mut field = Field::new_dict(name, data_type, nullable, dict_id, dict_is_ordered); + field.set_metadata(metadata); + Ok(field) + } + _ => Err(ArrowError::ParseError( + "Invalid json value type for field".to_string(), + )), + } +} + +/// Generate a JSON representation of the `Field`. +pub fn field_to_json(field: &Field) -> serde_json::Value { + let children: Vec = match field.data_type() { + DataType::Struct(fields) => fields.iter().map(|x| field_to_json(x.as_ref())).collect(), + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) + | DataType::Map(field, _) => vec![field_to_json(field)], + _ => vec![], + }; + + match field.data_type() { + DataType::Dictionary(ref index_type, ref value_type) => serde_json::json!({ + "name": field.name(), + "nullable": field.is_nullable(), + "type": data_type_to_json(value_type), + "children": children, + "dictionary": { + "id": field.dict_id().unwrap(), + "indexType": data_type_to_json(index_type), + "isOrdered": field.dict_is_ordered().unwrap(), + } + }), + _ => serde_json::json!({ + "name": field.name(), + "nullable": field.is_nullable(), + "type": data_type_to_json(field.data_type()), + "children": children + }), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::UnionMode; + use serde_json::Value; + + #[test] + fn struct_field_to_json() { + let f = Field::new_struct( + "address", + vec![ + Field::new("street", DataType::Utf8, false), + Field::new("zip", DataType::UInt16, false), + ], + false, + ); + let value: Value = serde_json::from_str( + r#"{ + "name": "address", + "nullable": false, + "type": { + "name": "struct" + }, + "children": [ + { + "name": "street", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + }, + { + "name": "zip", + "nullable": false, + "type": { + "name": "int", + "bitWidth": 16, + "isSigned": false + }, + "children": [] + } + ] + }"#, + ) + .unwrap(); + assert_eq!(value, field_to_json(&f)); + } + + #[test] + fn map_field_to_json() { + let f = Field::new_map( + "my_map", + "my_entries", + Field::new("my_keys", DataType::Utf8, false), + Field::new("my_values", DataType::UInt16, true), + true, + false, + ); + let value: Value = serde_json::from_str( + r#"{ + "name": "my_map", + "nullable": false, + "type": { + "name": "map", + "keysSorted": true + }, + "children": [ + { + "name": "my_entries", + "nullable": false, + "type": { + "name": "struct" + }, + "children": [ + { + "name": "my_keys", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + }, + { + "name": "my_values", + "nullable": true, + "type": { + "name": "int", + "bitWidth": 16, + "isSigned": false + }, + "children": [] + } + ] + } + ] + }"#, + ) + .unwrap(); + assert_eq!(value, field_to_json(&f)); + } + + #[test] + fn primitive_field_to_json() { + let f = Field::new("first_name", DataType::Utf8, false); + let value: Value = serde_json::from_str( + r#"{ + "name": "first_name", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + }"#, + ) + .unwrap(); + assert_eq!(value, field_to_json(&f)); + } + #[test] + fn parse_struct_from_json() { + let json = r#" + { + "name": "address", + "type": { + "name": "struct" + }, + "nullable": false, + "children": [ + { + "name": "street", + "type": { + "name": "utf8" + }, + "nullable": false, + "children": [] + }, + { + "name": "zip", + "type": { + "name": "int", + "isSigned": false, + "bitWidth": 16 + }, + "nullable": false, + "children": [] + } + ] + } + "#; + let value: Value = serde_json::from_str(json).unwrap(); + let dt = field_from_json(&value).unwrap(); + + let expected = Field::new_struct( + "address", + vec![ + Field::new("street", DataType::Utf8, false), + Field::new("zip", DataType::UInt16, false), + ], + false, + ); + + assert_eq!(expected, dt); + } + + #[test] + fn parse_map_from_json() { + let json = r#" + { + "name": "my_map", + "nullable": false, + "type": { + "name": "map", + "keysSorted": true + }, + "children": [ + { + "name": "my_entries", + "nullable": false, + "type": { + "name": "struct" + }, + "children": [ + { + "name": "my_keys", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + }, + { + "name": "my_values", + "nullable": true, + "type": { + "name": "int", + "bitWidth": 16, + "isSigned": false + }, + "children": [] + } + ] + } + ] + } + "#; + let value: Value = serde_json::from_str(json).unwrap(); + let dt = field_from_json(&value).unwrap(); + + let expected = Field::new_map( + "my_map", + "my_entries", + Field::new("my_keys", DataType::Utf8, false), + Field::new("my_values", DataType::UInt16, true), + true, + false, + ); + + assert_eq!(expected, dt); + } + + #[test] + fn parse_union_from_json() { + let json = r#" + { + "name": "my_union", + "nullable": false, + "type": { + "name": "union", + "mode": "SPARSE", + "typeIds": [ + 5, + 7 + ] + }, + "children": [ + { + "name": "f1", + "type": { + "name": "int", + "isSigned": true, + "bitWidth": 32 + }, + "nullable": true, + "children": [] + }, + { + "name": "f2", + "type": { + "name": "utf8" + }, + "nullable": true, + "children": [] + } + ] + } + "#; + let value: Value = serde_json::from_str(json).unwrap(); + let dt = field_from_json(&value).unwrap(); + + let expected = Field::new_union( + "my_union", + vec![5, 7], + vec![ + Field::new("f1", DataType::Int32, true), + Field::new("f2", DataType::Utf8, true), + ], + UnionMode::Sparse, + ); + + assert_eq!(expected, dt); + } +} diff --git a/integration-testing/src/util.rs b/arrow-integration-test/src/lib.rs similarity index 77% rename from integration-testing/src/util.rs rename to arrow-integration-test/src/lib.rs index e098c4e1491a..7b797aa07061 100644 --- a/integration-testing/src/util.rs +++ b/arrow-integration-test/src/lib.rs @@ -15,9 +15,11 @@ // specific language governing permissions and limitations // under the License. -//! Utils for JSON integration testing +//! Support for the [Apache Arrow JSON test data format](https://github.com/apache/arrow/blob/master/docs/source/format/Integration.rst#json-test-data-format) //! //! These utilities define structs that read the integration JSON format for integration testing purposes. +//! +//! This is not a canonical format, but provides a human-readable way of verifying language implementations use hex::decode; use num::BigInt; @@ -34,9 +36,19 @@ use arrow::datatypes::*; use arrow::error::{ArrowError, Result}; use arrow::record_batch::{RecordBatch, RecordBatchReader}; use arrow::util::bit_util; -use arrow::util::decimal::Decimal256; +use arrow_buffer::i256; + +mod datatype; +mod field; +mod schema; + +pub use datatype::*; +pub use field::*; +pub use schema::*; /// A struct that represents an Arrow file with a schema and record batches +/// +/// See #[derive(Deserialize, Serialize, Debug)] pub struct ArrowJson { pub schema: ArrowJsonSchema, @@ -69,12 +81,18 @@ pub struct ArrowJsonField { pub metadata: Option, } +impl From<&FieldRef> for ArrowJsonField { + fn from(value: &FieldRef) -> Self { + Self::from(value.as_ref()) + } +} + impl From<&Field> for ArrowJsonField { fn from(field: &Field) -> Self { - let metadata_value = match field.metadata() { - Some(kv_list) => { + let metadata_value = match field.metadata().is_empty() { + false => { let mut array = Vec::new(); - for (k, v) in kv_list { + for (k, v) in field.metadata() { let mut kv_map = SJMap::new(); kv_map.insert(k.clone(), Value::String(v.clone())); array.push(Value::Object(kv_map)); @@ -90,7 +108,7 @@ impl From<&Field> for ArrowJsonField { Self { name: field.name().to_string(), - field_type: field.data_type().to_json(), + field_type: data_type_to_json(field.data_type()), nullable: field.is_nullable(), children: vec![], dictionary: None, // TODO: not enough info @@ -160,12 +178,13 @@ impl ArrowJson { match batch { Some(Ok(batch)) => { if json_batch != batch { - println!("json: {:?}", json_batch); - println!("batch: {:?}", batch); + println!("json: {json_batch:?}"); + println!("batch: {batch:?}"); return Ok(false); } } - _ => return Ok(false), + Some(Err(e)) => return Err(e), + None => return Ok(false), } } @@ -242,10 +261,7 @@ impl ArrowJsonField { true } Err(e) => { - eprintln!( - "Encountered error while converting JSON field to Arrow field: {:?}", - e - ); + eprintln!("Encountered error while converting JSON field to Arrow field: {e:?}"); false } } @@ -255,8 +271,9 @@ impl ArrowJsonField { /// TODO: convert to use an Into fn to_arrow_field(&self) -> Result { // a bit regressive, but we have to convert the field to JSON in order to convert it - let field = serde_json::to_value(self)?; - Field::from(&field) + let field = + serde_json::to_value(self).map_err(|error| ArrowError::JsonError(error.to_string()))?; + field_from_json(&field) } } @@ -310,10 +327,7 @@ pub fn array_from_json( { match is_valid { 1 => b.append_value(value.as_i64().ok_or_else(|| { - ArrowError::JsonError(format!( - "Unable to get {:?} as int64", - value - )) + ArrowError::JsonError(format!("Unable to get {value:?} as int64")) })? as i8), _ => b.append_null(), }; @@ -373,12 +387,9 @@ pub fn array_from_json( match is_valid { 1 => b.append_value(match value { Value::Number(n) => n.as_i64().unwrap(), - Value::String(s) => { - s.parse().expect("Unable to parse string as i64") - } + Value::String(s) => s.parse().expect("Unable to parse string as i64"), Value::Object(ref map) - if map.contains_key("days") - && map.contains_key("milliseconds") => + if map.contains_key("days") && map.contains_key("milliseconds") => { match field.data_type() { DataType::Interval(IntervalUnit::DayTime) => { @@ -388,28 +399,22 @@ pub fn array_from_json( match (days, milliseconds) { (Value::Number(d), Value::Number(m)) => { let mut bytes = [0_u8; 8]; - let m = (m.as_i64().unwrap() as i32) - .to_le_bytes(); - let d = (d.as_i64().unwrap() as i32) - .to_le_bytes(); + let m = (m.as_i64().unwrap() as i32).to_le_bytes(); + let d = (d.as_i64().unwrap() as i32).to_le_bytes(); let c = [d, m].concat(); bytes.copy_from_slice(c.as_slice()); i64::from_le_bytes(bytes) } - _ => panic!( - "Unable to parse {:?} as interval daytime", - value - ), + _ => { + panic!("Unable to parse {value:?} as interval daytime") + } } } - _ => panic!( - "Unable to parse {:?} as interval daytime", - value - ), + _ => panic!("Unable to parse {value:?} as interval daytime"), } } - _ => panic!("Unable to parse {:?} as number", value), + _ => panic!("Unable to parse {value:?} as number"), }), _ => b.append_null(), }; @@ -485,11 +490,9 @@ pub fn array_from_json( .expect("Unable to parse string as u64"), ) } else if value.is_number() { - b.append_value( - value.as_u64().expect("Unable to read number as u64"), - ) + b.append_value(value.as_u64().expect("Unable to read number as u64")) } else { - panic!("Unable to parse value {:?} as u64", value) + panic!("Unable to parse value {value:?} as u64") } } _ => b.append_null(), @@ -521,19 +524,18 @@ pub fn array_from_json( let months = months.as_i64().unwrap() as i32; let days = days.as_i64().unwrap() as i32; let nanoseconds = nanoseconds.as_i64().unwrap(); - let months_days_ns: i128 = ((nanoseconds as i128) - & 0xFFFFFFFFFFFFFFFF) - << 64 - | ((days as i128) & 0xFFFFFFFF) << 32 - | ((months as i128) & 0xFFFFFFFF); + let months_days_ns: i128 = + ((nanoseconds as i128) & 0xFFFFFFFFFFFFFFFF) << 64 + | ((days as i128) & 0xFFFFFFFF) << 32 + | ((months as i128) & 0xFFFFFFFF); months_days_ns } (_, _, _) => { - panic!("Unable to parse {:?} as MonthDayNano", v) + panic!("Unable to parse {v:?} as MonthDayNano") } } } - _ => panic!("Unable to parse {:?} as MonthDayNano", value), + _ => panic!("Unable to parse {value:?} as MonthDayNano"), }), _ => b.append_null(), }; @@ -664,11 +666,8 @@ pub fn array_from_json( DataType::List(child_field) => { let null_buf = create_null_buf(&json_col); let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; + let child_array = + array_from_json(child_field, children.get(0).unwrap().clone(), dictionaries)?; let offsets: Vec = json_col .offset .unwrap() @@ -688,11 +687,8 @@ pub fn array_from_json( DataType::LargeList(child_field) => { let null_buf = create_null_buf(&json_col); let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; + let child_array = + array_from_json(child_field, children.get(0).unwrap().clone(), dictionaries)?; let offsets: Vec = json_col .offset .unwrap() @@ -715,11 +711,8 @@ pub fn array_from_json( } DataType::FixedSizeList(child_field, _) => { let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; + let child_array = + array_from_json(child_field, children.get(0).unwrap().clone(), dictionaries)?; let null_buf = create_null_buf(&json_col); let list_data = ArrayData::builder(field.data_type().clone()) .len(json_col.count) @@ -746,17 +739,13 @@ pub fn array_from_json( } DataType::Dictionary(key_type, value_type) => { let dict_id = field.dict_id().ok_or_else(|| { - ArrowError::JsonError(format!( - "Unable to find dict_id for field {:?}", - field - )) + ArrowError::JsonError(format!("Unable to find dict_id for field {field:?}")) })?; // find dictionary let dictionary = dictionaries .ok_or_else(|| { ArrowError::JsonError(format!( - "Unable to find any dictionaries for field {:?}", - field + "Unable to find any dictionaries for field {field:?}" )) })? .get(&dict_id); @@ -770,18 +759,12 @@ pub fn array_from_json( dictionaries, ), None => Err(ArrowError::JsonError(format!( - "Unable to find dictionary for field {:?}", - field + "Unable to find dictionary for field {field:?}" ))), } } DataType::Decimal128(precision, scale) => { - let mut b = - Decimal128Builder::with_capacity(json_col.count, *precision, *scale); - // C++ interop tests involve incompatible decimal values - unsafe { - b.disable_value_validation(); - } + let mut b = Decimal128Builder::with_capacity(json_col.count); for (is_valid, value) in json_col .validity .as_ref() @@ -790,21 +773,16 @@ pub fn array_from_json( .zip(json_col.data.unwrap()) { match is_valid { - 1 => { - b.append_value(value.as_str().unwrap().parse::().unwrap())? - } + 1 => b.append_value(value.as_str().unwrap().parse::().unwrap()), _ => b.append_null(), }; } - Ok(Arc::new(b.finish())) + Ok(Arc::new( + b.finish().with_precision_and_scale(*precision, *scale)?, + )) } DataType::Decimal256(precision, scale) => { - let mut b = - Decimal256Builder::with_capacity(json_col.count, *precision, *scale); - // C++ interop tests involve incompatible decimal values - unsafe { - b.disable_value_validation(); - } + let mut b = Decimal256Builder::with_capacity(json_col.count); for (is_valid, value) in json_col .validity .as_ref() @@ -822,26 +800,21 @@ pub fn array_from_json( } else { [255_u8; 32] }; - bytes[0..integer_bytes.len()] - .copy_from_slice(integer_bytes.as_slice()); - let decimal = - Decimal256::try_new_from_bytes(*precision, *scale, &bytes) - .unwrap(); - b.append_value(&decimal)?; + bytes[0..integer_bytes.len()].copy_from_slice(integer_bytes.as_slice()); + b.append_value(i256::from_le_bytes(bytes)); } _ => b.append_null(), } } - Ok(Arc::new(b.finish())) + Ok(Arc::new( + b.finish().with_precision_and_scale(*precision, *scale)?, + )) } DataType::Map(child_field, _) => { let null_buf = create_null_buf(&json_col); let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; + let child_array = + array_from_json(child_field, children.get(0).unwrap().clone(), dictionaries)?; let offsets: Vec = json_col .offset .unwrap() @@ -859,7 +832,7 @@ pub fn array_from_json( let array = MapArray::from(array_data); Ok(Arc::new(array)) } - DataType::Union(fields, field_type_ids, _) => { + DataType::Union(fields, _) => { let type_ids = if let Some(type_id) = json_col.type_id { type_id } else { @@ -875,13 +848,14 @@ pub fn array_from_json( }); let mut children: Vec<(Field, Arc)> = vec![]; - for (field, col) in fields.iter().zip(json_col.children.unwrap()) { + for ((_, field), col) in fields.iter().zip(json_col.children.unwrap()) { let array = array_from_json(field, col, dictionaries)?; - children.push((field.clone(), array)); + children.push((field.as_ref().clone(), array)); } + let field_type_ids = fields.iter().map(|(id, _)| id).collect::>(); let array = UnionArray::try_new( - field_type_ids, + &field_type_ids, Buffer::from(&type_ids.to_byte_slice()), offset, children, @@ -890,8 +864,7 @@ pub fn array_from_json( Ok(Arc::new(array)) } t => Err(ArrowError::JsonError(format!( - "data type {:?} not supported", - t + "data type {t:?} not supported" ))), } } @@ -939,16 +912,14 @@ pub fn dictionary_array_from_json( // convert key and value to dictionary data let dict_data = ArrayData::builder(field.data_type().clone()) .len(keys.len()) - .add_buffer(keys.data().buffers()[0].clone()) + .add_buffer(keys.to_data().buffers()[0].clone()) .null_bit_buffer(Some(null_buf)) .add_child_data(values.into_data()) .build() .unwrap(); let array = match dict_key { - DataType::Int8 => { - Arc::new(Int8DictionaryArray::from(dict_data)) as ArrayRef - } + DataType::Int8 => Arc::new(Int8DictionaryArray::from(dict_data)) as ArrayRef, DataType::Int16 => Arc::new(Int16DictionaryArray::from(dict_data)), DataType::Int32 => Arc::new(Int32DictionaryArray::from(dict_data)), DataType::Int64 => Arc::new(Int64DictionaryArray::from(dict_data)), @@ -961,13 +932,12 @@ pub fn dictionary_array_from_json( Ok(array) } _ => Err(ArrowError::JsonError(format!( - "Dictionary key type {:?} not supported", - dict_key + "Dictionary key type {dict_key:?} not supported" ))), } } -/// A helper to create a null buffer from a Vec +/// A helper to create a null buffer from a `Vec` fn create_null_buf(json_col: &ArrowJsonColumn) -> Buffer { let num_bytes = bit_util::ceil(json_col.count, 8); let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false); @@ -1100,11 +1070,7 @@ mod tests { Field::new("c3", DataType::Utf8, true), Field::new( "c4", - DataType::List(Box::new(Field::new( - "custom_item", - DataType::Int32, - false, - ))), + DataType::List(Arc::new(Field::new("custom_item", DataType::Int32, false))), true, ), ]); @@ -1113,100 +1079,95 @@ mod tests { #[test] fn test_arrow_data_equality() { - let secs_tz = Some("Europe/Budapest".to_string()); - let millis_tz = Some("America/New_York".to_string()); - let micros_tz = Some("UTC".to_string()); - let nanos_tz = Some("Africa/Johannesburg".to_string()); - - let schema = - Schema::new(vec![ - Field::new("bools-with-metadata-map", DataType::Boolean, true) - .with_metadata(Some( - [("k".to_string(), "v".to_string())] - .iter() - .cloned() - .collect(), - )), - Field::new("bools-with-metadata-vec", DataType::Boolean, true) - .with_metadata(Some( - [("k2".to_string(), "v2".to_string())] - .iter() - .cloned() - .collect(), - )), - Field::new("bools", DataType::Boolean, true), - Field::new("int8s", DataType::Int8, true), - Field::new("int16s", DataType::Int16, true), - Field::new("int32s", DataType::Int32, true), - Field::new("int64s", DataType::Int64, true), - Field::new("uint8s", DataType::UInt8, true), - Field::new("uint16s", DataType::UInt16, true), - Field::new("uint32s", DataType::UInt32, true), - Field::new("uint64s", DataType::UInt64, true), - Field::new("float32s", DataType::Float32, true), - Field::new("float64s", DataType::Float64, true), - Field::new("date_days", DataType::Date32, true), - Field::new("date_millis", DataType::Date64, true), - Field::new("time_secs", DataType::Time32(TimeUnit::Second), true), - Field::new("time_millis", DataType::Time32(TimeUnit::Millisecond), true), - Field::new("time_micros", DataType::Time64(TimeUnit::Microsecond), true), - Field::new("time_nanos", DataType::Time64(TimeUnit::Nanosecond), true), - Field::new("ts_secs", DataType::Timestamp(TimeUnit::Second, None), true), - Field::new( - "ts_millis", - DataType::Timestamp(TimeUnit::Millisecond, None), - true, - ), - Field::new( - "ts_micros", - DataType::Timestamp(TimeUnit::Microsecond, None), - true, - ), - Field::new( - "ts_nanos", - DataType::Timestamp(TimeUnit::Nanosecond, None), - true, - ), - Field::new( - "ts_secs_tz", - DataType::Timestamp(TimeUnit::Second, secs_tz.clone()), - true, - ), - Field::new( - "ts_millis_tz", - DataType::Timestamp(TimeUnit::Millisecond, millis_tz.clone()), - true, - ), - Field::new( - "ts_micros_tz", - DataType::Timestamp(TimeUnit::Microsecond, micros_tz.clone()), - true, - ), - Field::new( - "ts_nanos_tz", - DataType::Timestamp(TimeUnit::Nanosecond, nanos_tz.clone()), - true, - ), - Field::new("utf8s", DataType::Utf8, true), - Field::new( - "lists", - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), - true, - ), - Field::new( - "structs", - DataType::Struct(vec![ - Field::new("int32s", DataType::Int32, true), - Field::new("utf8s", DataType::Utf8, true), - ]), - true, - ), - ]); + let secs_tz = Some("Europe/Budapest".into()); + let millis_tz = Some("America/New_York".into()); + let micros_tz = Some("UTC".into()); + let nanos_tz = Some("Africa/Johannesburg".into()); - let bools_with_metadata_map = - BooleanArray::from(vec![Some(true), None, Some(false)]); - let bools_with_metadata_vec = - BooleanArray::from(vec![Some(true), None, Some(false)]); + let schema = Schema::new(vec![ + Field::new("bools-with-metadata-map", DataType::Boolean, true).with_metadata( + [("k".to_string(), "v".to_string())] + .iter() + .cloned() + .collect(), + ), + Field::new("bools-with-metadata-vec", DataType::Boolean, true).with_metadata( + [("k2".to_string(), "v2".to_string())] + .iter() + .cloned() + .collect(), + ), + Field::new("bools", DataType::Boolean, true), + Field::new("int8s", DataType::Int8, true), + Field::new("int16s", DataType::Int16, true), + Field::new("int32s", DataType::Int32, true), + Field::new("int64s", DataType::Int64, true), + Field::new("uint8s", DataType::UInt8, true), + Field::new("uint16s", DataType::UInt16, true), + Field::new("uint32s", DataType::UInt32, true), + Field::new("uint64s", DataType::UInt64, true), + Field::new("float32s", DataType::Float32, true), + Field::new("float64s", DataType::Float64, true), + Field::new("date_days", DataType::Date32, true), + Field::new("date_millis", DataType::Date64, true), + Field::new("time_secs", DataType::Time32(TimeUnit::Second), true), + Field::new("time_millis", DataType::Time32(TimeUnit::Millisecond), true), + Field::new("time_micros", DataType::Time64(TimeUnit::Microsecond), true), + Field::new("time_nanos", DataType::Time64(TimeUnit::Nanosecond), true), + Field::new("ts_secs", DataType::Timestamp(TimeUnit::Second, None), true), + Field::new( + "ts_millis", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + ), + Field::new( + "ts_micros", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "ts_nanos", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ), + Field::new( + "ts_secs_tz", + DataType::Timestamp(TimeUnit::Second, secs_tz.clone()), + true, + ), + Field::new( + "ts_millis_tz", + DataType::Timestamp(TimeUnit::Millisecond, millis_tz.clone()), + true, + ), + Field::new( + "ts_micros_tz", + DataType::Timestamp(TimeUnit::Microsecond, micros_tz.clone()), + true, + ), + Field::new( + "ts_nanos_tz", + DataType::Timestamp(TimeUnit::Nanosecond, nanos_tz.clone()), + true, + ), + Field::new("utf8s", DataType::Utf8, true), + Field::new( + "lists", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ), + Field::new( + "structs", + DataType::Struct(Fields::from(vec![ + Field::new("int32s", DataType::Int32, true), + Field::new("utf8s", DataType::Utf8, true), + ])), + true, + ), + ]); + + let bools_with_metadata_map = BooleanArray::from(vec![Some(true), None, Some(false)]); + let bools_with_metadata_vec = BooleanArray::from(vec![Some(true), None, Some(false)]); let bools = BooleanArray::from(vec![Some(true), None, Some(false)]); let int8s = Int8Array::from(vec![Some(1), None, Some(3)]); let int16s = Int16Array::from(vec![Some(1), None, Some(3)]); @@ -1224,54 +1185,32 @@ mod tests { Some(29923997007884), Some(30612271819236), ]); - let time_secs = - Time32SecondArray::from(vec![Some(27974), Some(78592), Some(43207)]); - let time_millis = Time32MillisecondArray::from(vec![ - Some(6613125), - Some(74667230), - Some(52260079), - ]); - let time_micros = - Time64MicrosecondArray::from(vec![Some(62522958593), None, None]); - let time_nanos = Time64NanosecondArray::from(vec![ - Some(73380123595985), - None, - Some(16584393546415), - ]); - let ts_secs = TimestampSecondArray::from_opt_vec( - vec![None, Some(193438817552), None], - None, - ); - let ts_millis = TimestampMillisecondArray::from_opt_vec( - vec![None, Some(38606916383008), Some(58113709376587)], - None, - ); - let ts_micros = - TimestampMicrosecondArray::from_opt_vec(vec![None, None, None], None); - let ts_nanos = TimestampNanosecondArray::from_opt_vec( - vec![None, None, Some(-6473623571954960143)], - None, - ); - let ts_secs_tz = TimestampSecondArray::from_opt_vec( - vec![None, Some(193438817552), None], - secs_tz, - ); - let ts_millis_tz = TimestampMillisecondArray::from_opt_vec( - vec![None, Some(38606916383008), Some(58113709376587)], - millis_tz, - ); + let time_secs = Time32SecondArray::from(vec![Some(27974), Some(78592), Some(43207)]); + let time_millis = + Time32MillisecondArray::from(vec![Some(6613125), Some(74667230), Some(52260079)]); + let time_micros = Time64MicrosecondArray::from(vec![Some(62522958593), None, None]); + let time_nanos = + Time64NanosecondArray::from(vec![Some(73380123595985), None, Some(16584393546415)]); + let ts_secs = TimestampSecondArray::from(vec![None, Some(193438817552), None]); + let ts_millis = + TimestampMillisecondArray::from(vec![None, Some(38606916383008), Some(58113709376587)]); + let ts_micros = TimestampMicrosecondArray::from(vec![None, None, None]); + let ts_nanos = TimestampNanosecondArray::from(vec![None, None, Some(-6473623571954960143)]); + let ts_secs_tz = TimestampSecondArray::from(vec![None, Some(193438817552), None]) + .with_timezone_opt(secs_tz); + let ts_millis_tz = + TimestampMillisecondArray::from(vec![None, Some(38606916383008), Some(58113709376587)]) + .with_timezone_opt(millis_tz); let ts_micros_tz = - TimestampMicrosecondArray::from_opt_vec(vec![None, None, None], micros_tz); - let ts_nanos_tz = TimestampNanosecondArray::from_opt_vec( - vec![None, None, Some(-6473623571954960143)], - nanos_tz, - ); + TimestampMicrosecondArray::from(vec![None, None, None]).with_timezone_opt(micros_tz); + let ts_nanos_tz = + TimestampNanosecondArray::from(vec![None, None, Some(-6473623571954960143)]) + .with_timezone_opt(nanos_tz); let utf8s = StringArray::from(vec![Some("aa"), None, Some("bbb")]); let value_data = Int32Array::from(vec![None, Some(2), None, None]); - let value_offsets = Buffer::from_slice_ref(&[0, 3, 4, 4]); - let list_data_type = - DataType::List(Box::new(Field::new("item", DataType::Int32, true))); + let value_offsets = Buffer::from_slice_ref([0, 3, 4, 4]); + let list_data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); let list_data = ArrayData::builder(list_data_type) .len(3) .add_buffer(value_offsets) @@ -1283,14 +1222,14 @@ mod tests { let structs_int32s = Int32Array::from(vec![None, Some(-2), None]); let structs_utf8s = StringArray::from(vec![None, None, Some("aaaaaa")]); - let struct_data_type = DataType::Struct(vec![ + let struct_data_type = DataType::Struct(Fields::from(vec![ Field::new("int32s", DataType::Int32, true), Field::new("utf8s", DataType::Utf8, true), - ]); + ])); let struct_data = ArrayData::builder(struct_data_type) .len(3) - .add_child_data(structs_int32s.data().clone()) - .add_child_data(structs_utf8s.data().clone()) + .add_child_data(structs_int32s.into_data()) + .add_child_data(structs_utf8s.into_data()) .null_bit_buffer(Some(Buffer::from([0b00000011]))) .build() .unwrap(); diff --git a/arrow-integration-test/src/schema.rs b/arrow-integration-test/src/schema.rs new file mode 100644 index 000000000000..b5f6c5e86b38 --- /dev/null +++ b/arrow-integration-test/src/schema.rs @@ -0,0 +1,728 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{field_from_json, field_to_json}; +use arrow::datatypes::{Fields, Schema}; +use arrow::error::{ArrowError, Result}; +use std::collections::HashMap; + +/// Generate a JSON representation of the `Schema`. +pub fn schema_to_json(schema: &Schema) -> serde_json::Value { + serde_json::json!({ + "fields": schema.fields().iter().map(|f| field_to_json(f.as_ref())).collect::>(), + "metadata": serde_json::to_value(schema.metadata()).unwrap() + }) +} + +/// Parse a `Schema` definition from a JSON representation. +pub fn schema_from_json(json: &serde_json::Value) -> Result { + use serde_json::Value; + match *json { + Value::Object(ref schema) => { + let fields: Fields = match schema.get("fields") { + Some(Value::Array(fields)) => { + fields.iter().map(field_from_json).collect::>()? + } + _ => { + return Err(ArrowError::ParseError( + "Schema fields should be an array".to_string(), + )) + } + }; + + let metadata = if let Some(value) = schema.get("metadata") { + from_metadata(value)? + } else { + HashMap::default() + }; + + Ok(Schema::new_with_metadata(fields, metadata)) + } + _ => Err(ArrowError::ParseError( + "Invalid json value type for schema".to_string(), + )), + } +} + +/// Parse a `metadata` definition from a JSON representation. +/// The JSON can either be an Object or an Array of Objects. +fn from_metadata(json: &serde_json::Value) -> Result> { + use serde_json::Value; + match json { + Value::Array(_) => { + let mut hashmap = HashMap::new(); + let values: Vec = + serde_json::from_value(json.clone()).map_err(|_| { + ArrowError::JsonError("Unable to parse object into key-value pair".to_string()) + })?; + for meta in values { + hashmap.insert(meta.key.clone(), meta.value); + } + Ok(hashmap) + } + Value::Object(md) => md + .iter() + .map(|(k, v)| { + if let Value::String(v) = v { + Ok((k.to_string(), v.to_string())) + } else { + Err(ArrowError::ParseError( + "metadata `value` field must be a string".to_string(), + )) + } + }) + .collect::>(), + _ => Err(ArrowError::ParseError( + "`metadata` field must be an object".to_string(), + )), + } +} + +#[derive(serde::Deserialize)] +struct MetadataKeyValue { + key: String, + value: String, +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; + use serde_json::Value; + use std::sync::Arc; + + #[test] + fn schema_json() { + // Add some custom metadata + let metadata: HashMap = [("Key".to_string(), "Value".to_string())] + .iter() + .cloned() + .collect(); + + let schema = Schema::new_with_metadata( + vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::Binary, false), + Field::new("c3", DataType::FixedSizeBinary(3), false), + Field::new("c4", DataType::Boolean, false), + Field::new("c5", DataType::Date32, false), + Field::new("c6", DataType::Date64, false), + Field::new("c7", DataType::Time32(TimeUnit::Second), false), + Field::new("c8", DataType::Time32(TimeUnit::Millisecond), false), + Field::new("c9", DataType::Time32(TimeUnit::Microsecond), false), + Field::new("c10", DataType::Time32(TimeUnit::Nanosecond), false), + Field::new("c11", DataType::Time64(TimeUnit::Second), false), + Field::new("c12", DataType::Time64(TimeUnit::Millisecond), false), + Field::new("c13", DataType::Time64(TimeUnit::Microsecond), false), + Field::new("c14", DataType::Time64(TimeUnit::Nanosecond), false), + Field::new("c15", DataType::Timestamp(TimeUnit::Second, None), false), + Field::new( + "c16", + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), + false, + ), + Field::new( + "c17", + DataType::Timestamp(TimeUnit::Microsecond, Some("Africa/Johannesburg".into())), + false, + ), + Field::new( + "c18", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + Field::new("c19", DataType::Interval(IntervalUnit::DayTime), false), + Field::new("c20", DataType::Interval(IntervalUnit::YearMonth), false), + Field::new("c21", DataType::Interval(IntervalUnit::MonthDayNano), false), + Field::new( + "c22", + DataType::List(Arc::new(Field::new("item", DataType::Boolean, true))), + false, + ), + Field::new( + "c23", + DataType::FixedSizeList( + Arc::new(Field::new("bools", DataType::Boolean, false)), + 5, + ), + false, + ), + Field::new( + "c24", + DataType::List(Arc::new(Field::new( + "inner_list", + DataType::List(Arc::new(Field::new( + "struct", + DataType::Struct(Fields::empty()), + true, + ))), + false, + ))), + true, + ), + Field::new( + "c25", + DataType::Struct(Fields::from(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::UInt16, false), + ])), + false, + ), + Field::new("c26", DataType::Interval(IntervalUnit::YearMonth), true), + Field::new("c27", DataType::Interval(IntervalUnit::DayTime), true), + Field::new("c28", DataType::Interval(IntervalUnit::MonthDayNano), true), + Field::new("c29", DataType::Duration(TimeUnit::Second), false), + Field::new("c30", DataType::Duration(TimeUnit::Millisecond), false), + Field::new("c31", DataType::Duration(TimeUnit::Microsecond), false), + Field::new("c32", DataType::Duration(TimeUnit::Nanosecond), false), + Field::new_dict( + "c33", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + 123, + true, + ), + Field::new("c34", DataType::LargeBinary, true), + Field::new("c35", DataType::LargeUtf8, true), + Field::new( + "c36", + DataType::LargeList(Arc::new(Field::new( + "inner_large_list", + DataType::LargeList(Arc::new(Field::new( + "struct", + DataType::Struct(Fields::empty()), + false, + ))), + true, + ))), + true, + ), + Field::new( + "c37", + DataType::Map( + Arc::new(Field::new( + "my_entries", + DataType::Struct(Fields::from(vec![ + Field::new("my_keys", DataType::Utf8, false), + Field::new("my_values", DataType::UInt16, true), + ])), + false, + )), + true, + ), + false, + ), + ], + metadata, + ); + + let expected = schema_to_json(&schema); + let json = r#"{ + "fields": [ + { + "name": "c1", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + }, + { + "name": "c2", + "nullable": false, + "type": { + "name": "binary" + }, + "children": [] + }, + { + "name": "c3", + "nullable": false, + "type": { + "name": "fixedsizebinary", + "byteWidth": 3 + }, + "children": [] + }, + { + "name": "c4", + "nullable": false, + "type": { + "name": "bool" + }, + "children": [] + }, + { + "name": "c5", + "nullable": false, + "type": { + "name": "date", + "unit": "DAY" + }, + "children": [] + }, + { + "name": "c6", + "nullable": false, + "type": { + "name": "date", + "unit": "MILLISECOND" + }, + "children": [] + }, + { + "name": "c7", + "nullable": false, + "type": { + "name": "time", + "bitWidth": 32, + "unit": "SECOND" + }, + "children": [] + }, + { + "name": "c8", + "nullable": false, + "type": { + "name": "time", + "bitWidth": 32, + "unit": "MILLISECOND" + }, + "children": [] + }, + { + "name": "c9", + "nullable": false, + "type": { + "name": "time", + "bitWidth": 32, + "unit": "MICROSECOND" + }, + "children": [] + }, + { + "name": "c10", + "nullable": false, + "type": { + "name": "time", + "bitWidth": 32, + "unit": "NANOSECOND" + }, + "children": [] + }, + { + "name": "c11", + "nullable": false, + "type": { + "name": "time", + "bitWidth": 64, + "unit": "SECOND" + }, + "children": [] + }, + { + "name": "c12", + "nullable": false, + "type": { + "name": "time", + "bitWidth": 64, + "unit": "MILLISECOND" + }, + "children": [] + }, + { + "name": "c13", + "nullable": false, + "type": { + "name": "time", + "bitWidth": 64, + "unit": "MICROSECOND" + }, + "children": [] + }, + { + "name": "c14", + "nullable": false, + "type": { + "name": "time", + "bitWidth": 64, + "unit": "NANOSECOND" + }, + "children": [] + }, + { + "name": "c15", + "nullable": false, + "type": { + "name": "timestamp", + "unit": "SECOND" + }, + "children": [] + }, + { + "name": "c16", + "nullable": false, + "type": { + "name": "timestamp", + "unit": "MILLISECOND", + "timezone": "UTC" + }, + "children": [] + }, + { + "name": "c17", + "nullable": false, + "type": { + "name": "timestamp", + "unit": "MICROSECOND", + "timezone": "Africa/Johannesburg" + }, + "children": [] + }, + { + "name": "c18", + "nullable": false, + "type": { + "name": "timestamp", + "unit": "NANOSECOND" + }, + "children": [] + }, + { + "name": "c19", + "nullable": false, + "type": { + "name": "interval", + "unit": "DAY_TIME" + }, + "children": [] + }, + { + "name": "c20", + "nullable": false, + "type": { + "name": "interval", + "unit": "YEAR_MONTH" + }, + "children": [] + }, + { + "name": "c21", + "nullable": false, + "type": { + "name": "interval", + "unit": "MONTH_DAY_NANO" + }, + "children": [] + }, + { + "name": "c22", + "nullable": false, + "type": { + "name": "list" + }, + "children": [ + { + "name": "item", + "nullable": true, + "type": { + "name": "bool" + }, + "children": [] + } + ] + }, + { + "name": "c23", + "nullable": false, + "type": { + "name": "fixedsizelist", + "listSize": 5 + }, + "children": [ + { + "name": "bools", + "nullable": false, + "type": { + "name": "bool" + }, + "children": [] + } + ] + }, + { + "name": "c24", + "nullable": true, + "type": { + "name": "list" + }, + "children": [ + { + "name": "inner_list", + "nullable": false, + "type": { + "name": "list" + }, + "children": [ + { + "name": "struct", + "nullable": true, + "type": { + "name": "struct" + }, + "children": [] + } + ] + } + ] + }, + { + "name": "c25", + "nullable": false, + "type": { + "name": "struct" + }, + "children": [ + { + "name": "a", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + }, + { + "name": "b", + "nullable": false, + "type": { + "name": "int", + "bitWidth": 16, + "isSigned": false + }, + "children": [] + } + ] + }, + { + "name": "c26", + "nullable": true, + "type": { + "name": "interval", + "unit": "YEAR_MONTH" + }, + "children": [] + }, + { + "name": "c27", + "nullable": true, + "type": { + "name": "interval", + "unit": "DAY_TIME" + }, + "children": [] + }, + { + "name": "c28", + "nullable": true, + "type": { + "name": "interval", + "unit": "MONTH_DAY_NANO" + }, + "children": [] + }, + { + "name": "c29", + "nullable": false, + "type": { + "name": "duration", + "unit": "SECOND" + }, + "children": [] + }, + { + "name": "c30", + "nullable": false, + "type": { + "name": "duration", + "unit": "MILLISECOND" + }, + "children": [] + }, + { + "name": "c31", + "nullable": false, + "type": { + "name": "duration", + "unit": "MICROSECOND" + }, + "children": [] + }, + { + "name": "c32", + "nullable": false, + "type": { + "name": "duration", + "unit": "NANOSECOND" + }, + "children": [] + }, + { + "name": "c33", + "nullable": true, + "children": [], + "type": { + "name": "utf8" + }, + "dictionary": { + "id": 123, + "indexType": { + "name": "int", + "bitWidth": 32, + "isSigned": true + }, + "isOrdered": true + } + }, + { + "name": "c34", + "nullable": true, + "type": { + "name": "largebinary" + }, + "children": [] + }, + { + "name": "c35", + "nullable": true, + "type": { + "name": "largeutf8" + }, + "children": [] + }, + { + "name": "c36", + "nullable": true, + "type": { + "name": "largelist" + }, + "children": [ + { + "name": "inner_large_list", + "nullable": true, + "type": { + "name": "largelist" + }, + "children": [ + { + "name": "struct", + "nullable": false, + "type": { + "name": "struct" + }, + "children": [] + } + ] + } + ] + }, + { + "name": "c37", + "nullable": false, + "type": { + "name": "map", + "keysSorted": true + }, + "children": [ + { + "name": "my_entries", + "nullable": false, + "type": { + "name": "struct" + }, + "children": [ + { + "name": "my_keys", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + }, + { + "name": "my_values", + "nullable": true, + "type": { + "name": "int", + "bitWidth": 16, + "isSigned": false + }, + "children": [] + } + ] + } + ] + } + ], + "metadata" : { + "Key": "Value" + } + }"#; + let value: Value = serde_json::from_str(json).unwrap(); + assert_eq!(expected, value); + + // convert back to a schema + let value: Value = serde_json::from_str(json).unwrap(); + let schema2 = schema_from_json(&value).unwrap(); + + assert_eq!(schema, schema2); + + // Check that empty metadata produces empty value in JSON and can be parsed + let json = r#"{ + "fields": [ + { + "name": "c1", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + } + ], + "metadata": {} + }"#; + let value: Value = serde_json::from_str(json).unwrap(); + let schema = schema_from_json(&value).unwrap(); + assert!(schema.metadata.is_empty()); + + // Check that metadata field is not required in the JSON. + let json = r#"{ + "fields": [ + { + "name": "c1", + "nullable": false, + "type": { + "name": "utf8" + }, + "children": [] + } + ] + }"#; + let value: Value = serde_json::from_str(json).unwrap(); + let schema = schema_from_json(&value).unwrap(); + assert!(schema.metadata.is_empty()); + } +} diff --git a/integration-testing/Cargo.toml b/arrow-integration-testing/Cargo.toml similarity index 69% rename from integration-testing/Cargo.toml rename to arrow-integration-testing/Cargo.toml index b9f6cf81855e..c29860f09d64 100644 --- a/integration-testing/Cargo.toml +++ b/arrow-integration-testing/Cargo.toml @@ -17,31 +17,36 @@ [package] name = "arrow-integration-testing" -description = "Binaries used in the Arrow integration tests" -version = "22.0.0" -homepage = "https://github.com/apache/arrow-rs" -repository = "https://github.com/apache/arrow-rs" -authors = ["Apache Arrow "] -license = "Apache-2.0" -edition = "2021" +description = "Binaries used in the Arrow integration tests (NOT PUBLISHED TO crates.io)" +version = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +edition = { workspace = true } publish = false -rust-version = "1.62" +rust-version = { workspace = true } + +[lib] +crate-type = ["lib", "cdylib"] [features] logging = ["tracing-subscriber"] [dependencies] -arrow = { path = "../arrow", default-features = false, features = ["test_utils", "ipc", "ipc_compression", "json"] } +arrow = { path = "../arrow", default-features = false, features = ["test_utils", "ipc", "ipc_compression", "json", "ffi"] } arrow-flight = { path = "../arrow-flight", default-features = false } +arrow-buffer = { path = "../arrow-buffer", default-features = false } +arrow-integration-test = { path = "../arrow-integration-test", default-features = false } async-trait = { version = "0.1.41", default-features = false } -clap = { version = "3", default-features = false, features = ["std", "derive"] } +clap = { version = "4", default-features = false, features = ["std", "derive", "help", "error-context", "usage"] } futures = { version = "0.3", default-features = false } hex = { version = "0.4", default-features = false, features = ["std"] } -prost = { version = "0.11", default-features = false } +prost = { version = "0.12", default-features = false } serde = { version = "1.0", default-features = false, features = ["rc", "derive"] } serde_json = { version = "1.0", default-features = false, features = ["std"] } tokio = { version = "1.0", default-features = false } -tonic = { version = "0.8", default-features = false } +tonic = { version = "0.10", default-features = false } tracing-subscriber = { version = "0.3.1", default-features = false, features = ["fmt"], optional = true } num = { version = "0.4", default-features = false, features = ["std"] } flate2 = { version = "1", default-features = false, features = ["rust_backend"] } diff --git a/integration-testing/README.md b/arrow-integration-testing/README.md similarity index 99% rename from integration-testing/README.md rename to arrow-integration-testing/README.md index e82591e6b139..dcf39c27fbc5 100644 --- a/integration-testing/README.md +++ b/arrow-integration-testing/README.md @@ -48,7 +48,7 @@ ln -s arrow/rust ```shell cd arrow -pip install -e dev/archery[docker] +pip install -e dev/archery[integration] ``` ### Build the C++ binaries: diff --git a/integration-testing/src/bin/arrow-file-to-stream.rs b/arrow-integration-testing/src/bin/arrow-file-to-stream.rs similarity index 97% rename from integration-testing/src/bin/arrow-file-to-stream.rs rename to arrow-integration-testing/src/bin/arrow-file-to-stream.rs index e939fe4f0bf7..3e027faef91f 100644 --- a/integration-testing/src/bin/arrow-file-to-stream.rs +++ b/arrow-integration-testing/src/bin/arrow-file-to-stream.rs @@ -30,7 +30,7 @@ struct Args { fn main() -> Result<()> { let args = Args::parse(); - let f = File::open(&args.file_name)?; + let f = File::open(args.file_name)?; let reader = BufReader::new(f); let mut reader = FileReader::try_new(reader, None)?; let schema = reader.schema(); diff --git a/integration-testing/src/bin/arrow-json-integration-test.rs b/arrow-integration-testing/src/bin/arrow-json-integration-test.rs similarity index 66% rename from integration-testing/src/bin/arrow-json-integration-test.rs rename to arrow-integration-testing/src/bin/arrow-json-integration-test.rs index a7d7cf6ee7cb..9f1abb16a668 100644 --- a/integration-testing/src/bin/arrow-json-integration-test.rs +++ b/arrow-integration-testing/src/bin/arrow-json-integration-test.rs @@ -15,16 +15,15 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::Schema; -use arrow::datatypes::{DataType, Field}; use arrow::error::{ArrowError, Result}; use arrow::ipc::reader::FileReader; use arrow::ipc::writer::FileWriter; -use arrow_integration_testing::{read_json_file, util::*}; +use arrow_integration_test::*; +use arrow_integration_testing::{canonicalize_schema, open_json_file}; use clap::Parser; use std::fs::File; -#[derive(clap::ArgEnum, Debug, Clone)] +#[derive(clap::ValueEnum, Debug, Clone)] #[clap(rename_all = "SCREAMING_SNAKE_CASE")] enum Mode { ArrowToJson, @@ -41,7 +40,7 @@ struct Args { arrow: String, #[clap(short, long, help("Path to JSON file"))] json: String, - #[clap(arg_enum, short, long, default_value_t = Mode::Validate, help="Mode of integration testing tool")] + #[clap(value_enum, short, long, default_value_t = Mode::Validate, help="Mode of integration testing tool")] mode: Mode, #[clap(short, long)] verbose: bool, @@ -61,15 +60,15 @@ fn main() -> Result<()> { fn json_to_arrow(json_name: &str, arrow_name: &str, verbose: bool) -> Result<()> { if verbose { - eprintln!("Converting {} to {}", json_name, arrow_name); + eprintln!("Converting {json_name} to {arrow_name}"); } - let json_file = read_json_file(json_name)?; + let json_file = open_json_file(json_name)?; let arrow_file = File::create(arrow_name)?; let mut writer = FileWriter::try_new(arrow_file, &json_file.schema)?; - for b in json_file.batches { + for b in json_file.read_batches()? { writer.write(&b)?; } @@ -80,7 +79,7 @@ fn json_to_arrow(json_name: &str, arrow_name: &str, verbose: bool) -> Result<()> fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { if verbose { - eprintln!("Converting {} to {}", arrow_name, json_name); + eprintln!("Converting {arrow_name} to {json_name}"); } let arrow_file = File::open(arrow_name)?; @@ -111,54 +110,13 @@ fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> Ok(()) } -fn canonicalize_schema(schema: &Schema) -> Schema { - let fields = schema - .fields() - .iter() - .map(|field| match field.data_type() { - DataType::Map(child_field, sorted) => match child_field.data_type() { - DataType::Struct(fields) if fields.len() == 2 => { - let first_field = fields.get(0).unwrap(); - let key_field = Field::new( - "key", - first_field.data_type().clone(), - first_field.is_nullable(), - ); - let second_field = fields.get(1).unwrap(); - let value_field = Field::new( - "value", - second_field.data_type().clone(), - second_field.is_nullable(), - ); - - let struct_type = DataType::Struct(vec![key_field, value_field]); - let child_field = - Field::new("entries", struct_type, child_field.is_nullable()); - - Field::new( - field.name().as_str(), - DataType::Map(Box::new(child_field), *sorted), - field.is_nullable(), - ) - } - _ => panic!( - "The child field of Map type should be Struct type with 2 fields." - ), - }, - _ => field.clone(), - }) - .collect::>(); - - Schema::new(fields).with_metadata(schema.metadata().clone()) -} - fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { if verbose { - eprintln!("Validating {} and {}", arrow_name, json_name); + eprintln!("Validating {arrow_name} and {json_name}"); } // open JSON file - let json_file = read_json_file(json_name)?; + let json_file = open_json_file(json_name)?; // open Arrow file let arrow_file = File::open(arrow_name)?; @@ -173,7 +131,7 @@ fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { ))); } - let json_batches = &json_file.batches; + let json_batches = json_file.read_batches()?; // compare number of batches assert!( @@ -197,8 +155,8 @@ fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { for i in 0..num_columns { assert_eq!( - arrow_batch.column(i).data(), - json_batch.column(i).data(), + arrow_batch.column(i).as_ref(), + json_batch.column(i).as_ref(), "Arrow and JSON batch columns not the same" ); } diff --git a/integration-testing/src/bin/arrow-stream-to-file.rs b/arrow-integration-testing/src/bin/arrow-stream-to-file.rs similarity index 100% rename from integration-testing/src/bin/arrow-stream-to-file.rs rename to arrow-integration-testing/src/bin/arrow-stream-to-file.rs diff --git a/integration-testing/src/bin/flight-test-integration-client.rs b/arrow-integration-testing/src/bin/flight-test-integration-client.rs similarity index 95% rename from integration-testing/src/bin/flight-test-integration-client.rs rename to arrow-integration-testing/src/bin/flight-test-integration-client.rs index fa99b424e378..b8bbb952837b 100644 --- a/integration-testing/src/bin/flight-test-integration-client.rs +++ b/arrow-integration-testing/src/bin/flight-test-integration-client.rs @@ -20,7 +20,7 @@ use clap::Parser; type Error = Box; type Result = std::result::Result; -#[derive(clap::ArgEnum, Debug, Clone)] +#[derive(clap::ValueEnum, Debug, Clone)] enum Scenario { Middleware, #[clap(name = "auth:basic_proto")] @@ -40,7 +40,7 @@ struct Args { help = "path to the descriptor file, only used when scenario is not provided. See https://arrow.apache.org/docs/format/Integration.html#json-test-data-format" )] path: Option, - #[clap(long, arg_enum)] + #[clap(long, value_enum)] scenario: Option, } @@ -62,8 +62,7 @@ async fn main() -> Result { } None => { let path = args.path.expect("No path is given"); - flight_client_scenarios::integration_test::run_scenario(&host, port, &path) - .await?; + flight_client_scenarios::integration_test::run_scenario(&host, port, &path).await?; } } diff --git a/integration-testing/src/bin/flight-test-integration-server.rs b/arrow-integration-testing/src/bin/flight-test-integration-server.rs similarity index 96% rename from integration-testing/src/bin/flight-test-integration-server.rs rename to arrow-integration-testing/src/bin/flight-test-integration-server.rs index 6ed22ad81d90..5310d07d4f8e 100644 --- a/integration-testing/src/bin/flight-test-integration-server.rs +++ b/arrow-integration-testing/src/bin/flight-test-integration-server.rs @@ -21,7 +21,7 @@ use clap::Parser; type Error = Box; type Result = std::result::Result; -#[derive(clap::ArgEnum, Debug, Clone)] +#[derive(clap::ValueEnum, Debug, Clone)] enum Scenario { Middleware, #[clap(name = "auth:basic_proto")] @@ -33,7 +33,7 @@ enum Scenario { struct Args { #[clap(long)] port: u16, - #[clap(long, arg_enum)] + #[clap(long, value_enum)] scenario: Option, } diff --git a/integration-testing/src/flight_client_scenarios.rs b/arrow-integration-testing/src/flight_client_scenarios.rs similarity index 100% rename from integration-testing/src/flight_client_scenarios.rs rename to arrow-integration-testing/src/flight_client_scenarios.rs diff --git a/integration-testing/src/flight_client_scenarios/auth_basic_proto.rs b/arrow-integration-testing/src/flight_client_scenarios/auth_basic_proto.rs similarity index 84% rename from integration-testing/src/flight_client_scenarios/auth_basic_proto.rs rename to arrow-integration-testing/src/flight_client_scenarios/auth_basic_proto.rs index ab398d3d2e7b..376e31e15553 100644 --- a/integration-testing/src/flight_client_scenarios/auth_basic_proto.rs +++ b/arrow-integration-testing/src/flight_client_scenarios/auth_basic_proto.rs @@ -17,9 +17,7 @@ use crate::{AUTH_PASSWORD, AUTH_USERNAME}; -use arrow_flight::{ - flight_service_client::FlightServiceClient, BasicAuth, HandshakeRequest, -}; +use arrow_flight::{flight_service_client::FlightServiceClient, BasicAuth, HandshakeRequest}; use futures::{stream, StreamExt}; use prost::Message; use tonic::{metadata::MetadataValue, Request, Status}; @@ -30,7 +28,7 @@ type Result = std::result::Result; type Client = FlightServiceClient; pub async fn run_scenario(host: &str, port: u16) -> Result { - let url = format!("http://{}:{}", host, port); + let url = format!("http://{host}:{port}"); let mut client = FlightServiceClient::connect(url).await?; let action = arrow_flight::Action::default(); @@ -41,15 +39,13 @@ pub async fn run_scenario(host: &str, port: u16) -> Result { Err(e) => { if e.code() != tonic::Code::Unauthenticated { return Err(Box::new(Status::internal(format!( - "Expected UNAUTHENTICATED but got {:?}", - e + "Expected UNAUTHENTICATED but got {e:?}" )))); } } Ok(other) => { return Err(Box::new(Status::internal(format!( - "Expected UNAUTHENTICATED but got {:?}", - other + "Expected UNAUTHENTICATED but got {other:?}" )))); } } @@ -74,17 +70,13 @@ pub async fn run_scenario(host: &str, port: u16) -> Result { .expect("No response received") .expect("Invalid response received"); - let body = String::from_utf8(r.body).unwrap(); + let body = std::str::from_utf8(&r.body).unwrap(); assert_eq!(body, AUTH_USERNAME); Ok(()) } -async fn authenticate( - client: &mut Client, - username: &str, - password: &str, -) -> Result { +async fn authenticate(client: &mut Client, username: &str, password: &str) -> Result { let auth = BasicAuth { username: username.into(), password: password.into(), @@ -94,7 +86,7 @@ async fn authenticate( let req = stream::once(async { HandshakeRequest { - payload, + payload: payload.into(), ..HandshakeRequest::default() } }); @@ -105,5 +97,5 @@ async fn authenticate( let r = rx.next().await.expect("must respond from handshake")?; assert!(rx.next().await.is_none(), "must not respond a second time"); - Ok(String::from_utf8(r.payload).unwrap()) + Ok(std::str::from_utf8(&r.payload).unwrap().into()) } diff --git a/integration-testing/src/flight_client_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs similarity index 81% rename from integration-testing/src/flight_client_scenarios/integration_test.rs rename to arrow-integration-testing/src/flight_client_scenarios/integration_test.rs index c01baa09a1f7..c6b5a72ca6e2 100644 --- a/integration-testing/src/flight_client_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::{read_json_file, ArrowFile}; +use crate::open_json_file; use std::collections::HashMap; use arrow::{ @@ -27,8 +27,7 @@ use arrow::{ }; use arrow_flight::{ flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient, - utils::flight_data_to_arrow_batch, FlightData, FlightDescriptor, Location, - SchemaAsIpc, Ticket, + utils::flight_data_to_arrow_batch, FlightData, FlightDescriptor, Location, SchemaAsIpc, Ticket, }; use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt}; use tonic::{Request, Streaming}; @@ -42,27 +41,20 @@ type Result = std::result::Result; type Client = FlightServiceClient; pub async fn run_scenario(host: &str, port: u16, path: &str) -> Result { - let url = format!("http://{}:{}", host, port); + let url = format!("http://{host}:{port}"); let client = FlightServiceClient::connect(url).await?; - let ArrowFile { - schema, batches, .. - } = read_json_file(path)?; + let json_file = open_json_file(path)?; - let schema = Arc::new(schema); + let batches = json_file.read_batches()?; + let schema = Arc::new(json_file.schema); let mut descriptor = FlightDescriptor::default(); descriptor.set_type(DescriptorType::Path); descriptor.path = vec![path.to_string()]; - upload_data( - client.clone(), - schema.clone(), - descriptor.clone(), - batches.clone(), - ) - .await?; + upload_data(client.clone(), schema, descriptor.clone(), batches.clone()).await?; verify_data(client, descriptor, &batches).await?; Ok(()) @@ -130,15 +122,23 @@ async fn send_batch( batch: &RecordBatch, options: &writer::IpcWriteOptions, ) -> Result { - let (dictionary_flight_data, mut batch_flight_data) = - arrow_flight::utils::flight_data_from_arrow_batch(batch, options); + let data_gen = writer::IpcDataGenerator::default(); + let mut dictionary_tracker = writer::DictionaryTracker::new(false); + + let (encoded_dictionaries, encoded_batch) = data_gen + .encoded_batch(batch, &mut dictionary_tracker, options) + .expect("DictionaryTracker configured above to not error on replacement"); + + let dictionary_flight_data: Vec = + encoded_dictionaries.into_iter().map(Into::into).collect(); + let mut batch_flight_data: FlightData = encoded_batch.into(); upload_tx .send_all(&mut stream::iter(dictionary_flight_data).map(Ok)) .await?; // Only the record batch's FlightData gets app_metadata - batch_flight_data.app_metadata = metadata.to_vec(); + batch_flight_data.app_metadata = metadata.to_vec().into(); upload_tx.send(batch_flight_data).await?; Ok(()) } @@ -195,19 +195,16 @@ async fn consume_flight_location( let mut dictionaries_by_id = HashMap::new(); for (counter, expected_batch) in expected_data.iter().enumerate() { - let data = receive_batch_flight_data( - &mut resp, - actual_schema.clone(), - &mut dictionaries_by_id, - ) - .await - .unwrap_or_else(|| { - panic!( - "Got fewer batches than expected, received so far: {} expected: {}", - counter, - expected_data.len(), - ) - }); + let data = + receive_batch_flight_data(&mut resp, actual_schema.clone(), &mut dictionaries_by_id) + .await + .unwrap_or_else(|| { + panic!( + "Got fewer batches than expected, received so far: {} expected: {}", + counter, + expected_data.len(), + ) + }); let metadata = counter.to_string().into_bytes(); assert_eq!(metadata, data.app_metadata); @@ -224,10 +221,10 @@ async fn consume_flight_location( let field = schema.field(i); let field_name = field.name(); - let expected_data = expected_batch.column(i).data(); - let actual_data = actual_batch.column(i).data(); + let expected_data = expected_batch.column(i).as_ref(); + let actual_data = actual_batch.column(i).as_ref(); - assert_eq!(expected_data, actual_data, "Data for field {}", field_name); + assert_eq!(expected_data, actual_data, "Data for field {field_name}"); } } @@ -242,8 +239,8 @@ async fn consume_flight_location( async fn receive_schema_flight_data(resp: &mut Streaming) -> Option { let data = resp.next().await?.ok()?; - let message = arrow::ipc::root_as_message(&data.data_header[..]) - .expect("Error parsing message"); + let message = + arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing message"); // message header is a Schema, so read it let ipc_schema: ipc::Schema = message @@ -260,8 +257,8 @@ async fn receive_batch_flight_data( dictionaries_by_id: &mut HashMap, ) -> Option { let mut data = resp.next().await?.ok()?; - let mut message = arrow::ipc::root_as_message(&data.data_header[..]) - .expect("Error parsing first message"); + let mut message = + arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing first message"); while message.header_type() == ipc::MessageHeader::DictionaryBatch { reader::read_dictionary( @@ -276,8 +273,8 @@ async fn receive_batch_flight_data( .expect("Error reading dictionary"); data = resp.next().await?.ok()?; - message = arrow::ipc::root_as_message(&data.data_header[..]) - .expect("Error parsing message"); + message = + arrow::ipc::root_as_message(&data.data_header[..]).expect("Error parsing message"); } Some(data) diff --git a/integration-testing/src/flight_client_scenarios/middleware.rs b/arrow-integration-testing/src/flight_client_scenarios/middleware.rs similarity index 90% rename from integration-testing/src/flight_client_scenarios/middleware.rs rename to arrow-integration-testing/src/flight_client_scenarios/middleware.rs index db8c42cc081c..3b71edf446a3 100644 --- a/integration-testing/src/flight_client_scenarios/middleware.rs +++ b/arrow-integration-testing/src/flight_client_scenarios/middleware.rs @@ -16,22 +16,22 @@ // under the License. use arrow_flight::{ - flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient, - FlightDescriptor, + flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient, FlightDescriptor, }; +use prost::bytes::Bytes; use tonic::{Request, Status}; type Error = Box; type Result = std::result::Result; pub async fn run_scenario(host: &str, port: u16) -> Result { - let url = format!("http://{}:{}", host, port); + let url = format!("http://{host}:{port}"); let conn = tonic::transport::Endpoint::new(url)?.connect().await?; let mut client = FlightServiceClient::with_interceptor(conn, middleware_interceptor); let mut descriptor = FlightDescriptor::default(); descriptor.set_type(DescriptorType::Cmd); - descriptor.cmd = b"".to_vec(); + descriptor.cmd = Bytes::from_static(b""); // This call is expected to fail. match client @@ -47,8 +47,7 @@ pub async fn run_scenario(host: &str, port: u16) -> Result { if value != "expected value" { let msg = format!( "On failing call: Expected to receive header 'x-middleware: expected value', \ - but instead got: '{}'", - value + but instead got: '{value}'" ); return Err(Box::new(Status::internal(msg))); } @@ -56,7 +55,7 @@ pub async fn run_scenario(host: &str, port: u16) -> Result { } // This call should succeed - descriptor.cmd = b"success".to_vec(); + descriptor.cmd = Bytes::from_static(b"success"); let resp = client.get_flight_info(Request::new(descriptor)).await?; let headers = resp.metadata(); @@ -66,8 +65,7 @@ pub async fn run_scenario(host: &str, port: u16) -> Result { if value != "expected value" { let msg = format!( "On success call: Expected to receive header 'x-middleware: expected value', \ - but instead got: '{}'", - value + but instead got: '{value}'" ); return Err(Box::new(Status::internal(msg))); } diff --git a/integration-testing/src/flight_server_scenarios.rs b/arrow-integration-testing/src/flight_server_scenarios.rs similarity index 92% rename from integration-testing/src/flight_server_scenarios.rs rename to arrow-integration-testing/src/flight_server_scenarios.rs index e56252f1dfbf..9034776c68d4 100644 --- a/integration-testing/src/flight_server_scenarios.rs +++ b/arrow-integration-testing/src/flight_server_scenarios.rs @@ -28,7 +28,7 @@ type Error = Box; type Result = std::result::Result; pub async fn listen_on(port: u16) -> Result { - let addr: SocketAddr = format!("0.0.0.0:{}", port).parse()?; + let addr: SocketAddr = format!("0.0.0.0:{port}").parse()?; let listener = TcpListener::bind(addr).await?; let addr = listener.local_addr()?; @@ -39,7 +39,7 @@ pub async fn listen_on(port: u16) -> Result { pub fn endpoint(ticket: &str, location_uri: impl Into) -> FlightEndpoint { FlightEndpoint { ticket: Some(Ticket { - ticket: ticket.as_bytes().to_vec(), + ticket: ticket.as_bytes().to_vec().into(), }), location: vec![Location { uri: location_uri.into(), diff --git a/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs b/arrow-integration-testing/src/flight_server_scenarios/auth_basic_proto.rs similarity index 90% rename from integration-testing/src/flight_server_scenarios/auth_basic_proto.rs rename to arrow-integration-testing/src/flight_server_scenarios/auth_basic_proto.rs index 68a4a0d3b4ad..ff4fc12f2523 100644 --- a/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs +++ b/arrow-integration-testing/src/flight_server_scenarios/auth_basic_proto.rs @@ -19,15 +19,13 @@ use std::pin::Pin; use std::sync::Arc; use arrow_flight::{ - flight_service_server::FlightService, flight_service_server::FlightServiceServer, - Action, ActionType, BasicAuth, Criteria, Empty, FlightData, FlightDescriptor, - FlightInfo, HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, + flight_service_server::FlightService, flight_service_server::FlightServiceServer, Action, + ActionType, BasicAuth, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, + HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, }; use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt}; use tokio::sync::Mutex; -use tonic::{ - metadata::MetadataMap, transport::Server, Request, Response, Status, Streaming, -}; +use tonic::{metadata::MetadataMap, transport::Server, Request, Response, Status, Streaming}; type TonicStream = Pin + Send + Sync + 'static>>; type Error = Box; @@ -63,10 +61,7 @@ pub struct AuthBasicProtoScenarioImpl { } impl AuthBasicProtoScenarioImpl { - async fn check_auth( - &self, - metadata: &MetadataMap, - ) -> Result { + async fn check_auth(&self, metadata: &MetadataMap) -> Result { let token = metadata .get_bin("auth-token-bin") .and_then(|v| v.to_bytes().ok()) @@ -74,10 +69,7 @@ impl AuthBasicProtoScenarioImpl { self.is_valid(token).await } - async fn is_valid( - &self, - token: Option, - ) -> Result { + async fn is_valid(&self, token: Option) -> Result { match token { Some(t) if t == *self.username => Ok(GrpcServerCallContext { peer_identity: self.username.to_string(), @@ -142,14 +134,12 @@ impl FlightService for AuthBasicProtoScenarioImpl { let req = req.expect("Error reading handshake request"); let HandshakeRequest { payload, .. } = req; - let auth = BasicAuth::decode(&*payload) - .expect("Error parsing handshake request"); + let auth = + BasicAuth::decode(&*payload).expect("Error parsing handshake request"); - let resp = if *auth.username == *username - && *auth.password == *password - { + let resp = if *auth.username == *username && *auth.password == *password { Ok(HandshakeResponse { - payload: username.as_bytes().to_vec(), + payload: username.as_bytes().to_vec().into(), ..HandshakeResponse::default() }) } else { @@ -203,7 +193,7 @@ impl FlightService for AuthBasicProtoScenarioImpl { ) -> Result, Status> { let flight_context = self.check_auth(request.metadata()).await?; // Respond with the authenticated username. - let buf = flight_context.peer_identity().as_bytes().to_vec(); + let buf = flight_context.peer_identity().as_bytes().to_vec().into(); let result = arrow_flight::Result { body: buf }; let output = futures::stream::once(async { Ok(result) }); Ok(Response::new(Box::pin(output) as Self::DoActionStream)) diff --git a/integration-testing/src/flight_server_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs similarity index 84% rename from integration-testing/src/flight_server_scenarios/integration_test.rs rename to arrow-integration-testing/src/flight_server_scenarios/integration_test.rs index dee2fda3be3d..2011031e921a 100644 --- a/integration-testing/src/flight_server_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs @@ -25,14 +25,14 @@ use arrow::{ buffer::Buffer, datatypes::Schema, datatypes::SchemaRef, - ipc::{self, reader}, + ipc::{self, reader, writer}, record_batch::RecordBatch, }; use arrow_flight::{ flight_descriptor::DescriptorType, flight_service_server::FlightService, - flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, - FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, - HandshakeResponse, IpcMessage, PutResult, SchemaAsIpc, SchemaResult, Ticket, + flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, FlightData, + FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, IpcMessage, + PutResult, SchemaAsIpc, SchemaResult, Ticket, }; use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt}; use std::convert::TryInto; @@ -48,7 +48,7 @@ pub async fn scenario_setup(port: u16) -> Result { let addr = super::listen_on(port).await?; let service = FlightServiceImpl { - server_location: format!("grpc+tcp://{}", addr), + server_location: format!("grpc+tcp://{addr}"), ..Default::default() }; let svc = FlightServiceServer::new(service); @@ -103,33 +103,38 @@ impl FlightService for FlightServiceImpl { let ticket = request.into_inner(); let key = String::from_utf8(ticket.ticket.to_vec()) - .map_err(|e| Status::invalid_argument(format!("Invalid ticket: {:?}", e)))?; + .map_err(|e| Status::invalid_argument(format!("Invalid ticket: {e:?}")))?; let uploaded_chunks = self.uploaded_chunks.lock().await; - let flight = uploaded_chunks.get(&key).ok_or_else(|| { - Status::not_found(format!("Could not find flight. {}", key)) - })?; + let flight = uploaded_chunks + .get(&key) + .ok_or_else(|| Status::not_found(format!("Could not find flight. {key}")))?; let options = arrow::ipc::writer::IpcWriteOptions::default(); - let schema = - std::iter::once(Ok(SchemaAsIpc::new(&flight.schema, &options).into())); + let schema = std::iter::once(Ok(SchemaAsIpc::new(&flight.schema, &options).into())); let batches = flight .chunks .iter() .enumerate() .flat_map(|(counter, batch)| { - let (dictionary_flight_data, mut batch_flight_data) = - arrow_flight::utils::flight_data_from_arrow_batch(batch, &options); + let data_gen = writer::IpcDataGenerator::default(); + let mut dictionary_tracker = writer::DictionaryTracker::new(false); + + let (encoded_dictionaries, encoded_batch) = data_gen + .encoded_batch(batch, &mut dictionary_tracker, &options) + .expect("DictionaryTracker configured above to not error on replacement"); + + let dictionary_flight_data = encoded_dictionaries.into_iter().map(Into::into); + let mut batch_flight_data: FlightData = encoded_batch.into(); // Only the record batch's FlightData gets app_metadata - let metadata = counter.to_string().into_bytes(); + let metadata = counter.to_string().into(); batch_flight_data.app_metadata = metadata; dictionary_flight_data - .into_iter() .chain(std::iter::once(batch_flight_data)) .map(Ok) }); @@ -173,8 +178,7 @@ impl FlightService for FlightServiceImpl { let endpoint = self.endpoint_from_path(&path[0]); - let total_records: usize = - flight.chunks.iter().map(|chunk| chunk.num_rows()).sum(); + let total_records: usize = flight.chunks.iter().map(|chunk| chunk.num_rows()).sum(); let options = arrow::ipc::writer::IpcWriteOptions::default(); let message = SchemaAsIpc::new(&flight.schema, &options) @@ -191,11 +195,12 @@ impl FlightService for FlightServiceImpl { endpoint: vec![endpoint], total_records: total_records as i64, total_bytes: -1, + ordered: false, }; Ok(Response::new(info)) } - other => Err(Status::unimplemented(format!("Request type: {}", other))), + other => Err(Status::unimplemented(format!("Request type: {other}"))), } } @@ -214,15 +219,14 @@ impl FlightService for FlightServiceImpl { .clone() .ok_or_else(|| Status::invalid_argument("Must have a descriptor"))?; - if descriptor.r#type != DescriptorType::Path as i32 || descriptor.path.is_empty() - { + if descriptor.r#type != DescriptorType::Path as i32 || descriptor.path.is_empty() { return Err(Status::invalid_argument("Must specify a path")); } let key = descriptor.path[0].clone(); let schema = Schema::try_from(&flight_data) - .map_err(|e| Status::invalid_argument(format!("Invalid schema: {:?}", e)))?; + .map_err(|e| Status::invalid_argument(format!("Invalid schema: {e:?}")))?; let schema_ref = Arc::new(schema.clone()); let (response_tx, response_rx) = mpsc::channel(10); @@ -275,10 +279,10 @@ async fn send_app_metadata( app_metadata: &[u8], ) -> Result<(), Status> { tx.send(Ok(PutResult { - app_metadata: app_metadata.to_vec(), + app_metadata: app_metadata.to_vec().into(), })) .await - .map_err(|e| Status::internal(format!("Could not send PutResult: {:?}", e))) + .map_err(|e| Status::internal(format!("Could not send PutResult: {e:?}"))) } async fn record_batch_from_message( @@ -287,9 +291,9 @@ async fn record_batch_from_message( schema_ref: SchemaRef, dictionaries_by_id: &HashMap, ) -> Result { - let ipc_batch = message.header_as_record_batch().ok_or_else(|| { - Status::internal("Could not parse message header as record batch") - })?; + let ipc_batch = message + .header_as_record_batch() + .ok_or_else(|| Status::internal("Could not parse message header as record batch"))?; let arrow_batch_result = reader::read_record_batch( data_body, @@ -300,9 +304,8 @@ async fn record_batch_from_message( &message.version(), ); - arrow_batch_result.map_err(|e| { - Status::internal(format!("Could not convert to RecordBatch: {:?}", e)) - }) + arrow_batch_result + .map_err(|e| Status::internal(format!("Could not convert to RecordBatch: {e:?}"))) } async fn dictionary_from_message( @@ -311,9 +314,9 @@ async fn dictionary_from_message( schema_ref: SchemaRef, dictionaries_by_id: &mut HashMap, ) -> Result<(), Status> { - let ipc_batch = message.header_as_dictionary_batch().ok_or_else(|| { - Status::internal("Could not parse message header as dictionary batch") - })?; + let ipc_batch = message + .header_as_dictionary_batch() + .ok_or_else(|| Status::internal("Could not parse message header as dictionary batch"))?; let dictionary_batch_result = reader::read_dictionary( data_body, @@ -322,9 +325,8 @@ async fn dictionary_from_message( dictionaries_by_id, &message.version(), ); - dictionary_batch_result.map_err(|e| { - Status::internal(format!("Could not convert to Dictionary: {:?}", e)) - }) + dictionary_batch_result + .map_err(|e| Status::internal(format!("Could not convert to Dictionary: {e:?}"))) } async fn save_uploaded_chunks( @@ -342,7 +344,7 @@ async fn save_uploaded_chunks( while let Some(Ok(data)) = input_stream.next().await { let message = arrow::ipc::root_as_message(&data.data_header[..]) - .map_err(|e| Status::internal(format!("Could not parse message: {:?}", e)))?; + .map_err(|e| Status::internal(format!("Could not parse message: {e:?}")))?; match message.header_type() { ipc::MessageHeader::Schema => { @@ -375,8 +377,7 @@ async fn save_uploaded_chunks( t => { return Err(Status::internal(format!( "Reading types other than record batches not yet supported, \ - unable to read {:?}", - t + unable to read {t:?}" ))); } } diff --git a/integration-testing/src/flight_server_scenarios/middleware.rs b/arrow-integration-testing/src/flight_server_scenarios/middleware.rs similarity index 96% rename from integration-testing/src/flight_server_scenarios/middleware.rs rename to arrow-integration-testing/src/flight_server_scenarios/middleware.rs index 5876ac9bfe6d..68d871b528a6 100644 --- a/integration-testing/src/flight_server_scenarios/middleware.rs +++ b/arrow-integration-testing/src/flight_server_scenarios/middleware.rs @@ -19,9 +19,9 @@ use std::pin::Pin; use arrow_flight::{ flight_descriptor::DescriptorType, flight_service_server::FlightService, - flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, - FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, - PutResult, SchemaResult, Ticket, + flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, FlightData, + FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, + Ticket, }; use futures::Stream; use tonic::{transport::Server, Request, Response, Status, Streaming}; @@ -93,7 +93,7 @@ impl FlightService for MiddlewareScenarioImpl { let descriptor = request.into_inner(); - if descriptor.r#type == DescriptorType::Cmd as i32 && descriptor.cmd == b"success" + if descriptor.r#type == DescriptorType::Cmd as i32 && descriptor.cmd.as_ref() == b"success" { // Return a fake location - the test doesn't read it let endpoint = super::endpoint("foo", "grpc+tcp://localhost:10010"); diff --git a/arrow-integration-testing/src/lib.rs b/arrow-integration-testing/src/lib.rs new file mode 100644 index 000000000000..553e69b0a1a0 --- /dev/null +++ b/arrow-integration-testing/src/lib.rs @@ -0,0 +1,302 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Common code used in the integration test binaries + +use serde_json::Value; + +use arrow::array::{Array, StructArray}; +use arrow::datatypes::{DataType, Field, Fields, Schema}; +use arrow::error::{ArrowError, Result}; +use arrow::ffi::{from_ffi_and_data_type, FFI_ArrowArray, FFI_ArrowSchema}; +use arrow::record_batch::RecordBatch; +use arrow::util::test_util::arrow_test_data; +use arrow_integration_test::*; +use std::collections::HashMap; +use std::ffi::{c_int, CStr, CString}; +use std::fs::File; +use std::io::BufReader; +use std::iter::zip; +use std::ptr; +use std::sync::Arc; + +/// The expected username for the basic auth integration test. +pub const AUTH_USERNAME: &str = "arrow"; +/// The expected password for the basic auth integration test. +pub const AUTH_PASSWORD: &str = "flight"; + +pub mod flight_client_scenarios; +pub mod flight_server_scenarios; + +pub struct ArrowFile { + pub schema: Schema, + // we can evolve this into a concrete Arrow type + // this is temporarily not being read from + dictionaries: HashMap, + arrow_json: Value, +} + +impl ArrowFile { + pub fn read_batch(&self, batch_num: usize) -> Result { + let b = self.arrow_json["batches"].get(batch_num).unwrap(); + let json_batch: ArrowJsonBatch = serde_json::from_value(b.clone()).unwrap(); + record_batch_from_json(&self.schema, json_batch, Some(&self.dictionaries)) + } + + pub fn read_batches(&self) -> Result> { + self.arrow_json["batches"] + .as_array() + .unwrap() + .iter() + .map(|b| { + let json_batch: ArrowJsonBatch = serde_json::from_value(b.clone()).unwrap(); + record_batch_from_json(&self.schema, json_batch, Some(&self.dictionaries)) + }) + .collect() + } +} + +// Canonicalize the names of map fields in a schema +pub fn canonicalize_schema(schema: &Schema) -> Schema { + let fields = schema + .fields() + .iter() + .map(|field| match field.data_type() { + DataType::Map(child_field, sorted) => match child_field.data_type() { + DataType::Struct(fields) if fields.len() == 2 => { + let first_field = fields.get(0).unwrap(); + let key_field = + Arc::new(Field::new("key", first_field.data_type().clone(), false)); + let second_field = fields.get(1).unwrap(); + let value_field = Arc::new(Field::new( + "value", + second_field.data_type().clone(), + second_field.is_nullable(), + )); + + let fields = Fields::from([key_field, value_field]); + let struct_type = DataType::Struct(fields); + let child_field = Field::new("entries", struct_type, false); + + Arc::new(Field::new( + field.name().as_str(), + DataType::Map(Arc::new(child_field), *sorted), + field.is_nullable(), + )) + } + _ => panic!("The child field of Map type should be Struct type with 2 fields."), + }, + _ => field.clone(), + }) + .collect::(); + + Schema::new(fields).with_metadata(schema.metadata().clone()) +} + +pub fn open_json_file(json_name: &str) -> Result { + let json_file = File::open(json_name)?; + let reader = BufReader::new(json_file); + let arrow_json: Value = serde_json::from_reader(reader).unwrap(); + let schema = schema_from_json(&arrow_json["schema"])?; + // read dictionaries + let mut dictionaries = HashMap::new(); + if let Some(dicts) = arrow_json.get("dictionaries") { + for d in dicts + .as_array() + .expect("Unable to get dictionaries as array") + { + let json_dict: ArrowJsonDictionaryBatch = + serde_json::from_value(d.clone()).expect("Unable to get dictionary from JSON"); + // TODO: convert to a concrete Arrow type + dictionaries.insert(json_dict.id, json_dict); + } + } + Ok(ArrowFile { + schema, + dictionaries, + arrow_json, + }) +} + +/// Read gzipped JSON test file +/// +/// For example given the input: +/// version = `0.17.1` +/// path = `generated_union` +/// +/// Returns the contents of +/// `arrow-ipc-stream/integration/0.17.1/generated_union.json.gz` +pub fn read_gzip_json(version: &str, path: &str) -> ArrowJson { + use flate2::read::GzDecoder; + use std::io::Read; + + let testdata = arrow_test_data(); + let file = File::open(format!( + "{testdata}/arrow-ipc-stream/integration/{version}/{path}.json.gz" + )) + .unwrap(); + let mut gz = GzDecoder::new(&file); + let mut s = String::new(); + gz.read_to_string(&mut s).unwrap(); + // convert to Arrow JSON + let arrow_json: ArrowJson = serde_json::from_str(&s).unwrap(); + arrow_json +} + +// +// C Data Integration entrypoints +// + +fn cdata_integration_export_schema_from_json( + c_json_name: *const i8, + out: *mut FFI_ArrowSchema, +) -> Result<()> { + let json_name = unsafe { CStr::from_ptr(c_json_name) }; + let f = open_json_file(json_name.to_str()?)?; + let c_schema = FFI_ArrowSchema::try_from(&f.schema)?; + // Move exported schema into output struct + unsafe { ptr::write(out, c_schema) }; + Ok(()) +} + +fn cdata_integration_export_batch_from_json( + c_json_name: *const i8, + batch_num: c_int, + out: *mut FFI_ArrowArray, +) -> Result<()> { + let json_name = unsafe { CStr::from_ptr(c_json_name) }; + let b = open_json_file(json_name.to_str()?)?.read_batch(batch_num.try_into().unwrap())?; + let a = StructArray::from(b).into_data(); + let c_array = FFI_ArrowArray::new(&a); + // Move exported array into output struct + unsafe { ptr::write(out, c_array) }; + Ok(()) +} + +fn cdata_integration_import_schema_and_compare_to_json( + c_json_name: *const i8, + c_schema: *mut FFI_ArrowSchema, +) -> Result<()> { + let json_name = unsafe { CStr::from_ptr(c_json_name) }; + let json_schema = open_json_file(json_name.to_str()?)?.schema; + + // The source ArrowSchema will be released when this is dropped + let imported_schema = unsafe { FFI_ArrowSchema::from_raw(c_schema) }; + let imported_schema = Schema::try_from(&imported_schema)?; + + // compare schemas + if canonicalize_schema(&json_schema) != canonicalize_schema(&imported_schema) { + return Err(ArrowError::ComputeError(format!( + "Schemas do not match.\n- JSON: {:?}\n- Imported: {:?}", + json_schema, imported_schema + ))); + } + Ok(()) +} + +fn compare_batches(a: &RecordBatch, b: &RecordBatch) -> Result<()> { + if a.num_columns() != b.num_columns() { + return Err(ArrowError::InvalidArgumentError( + "batches do not have the same number of columns".to_string(), + )); + } + for (a_column, b_column) in zip(a.columns(), b.columns()) { + if a_column != b_column { + return Err(ArrowError::InvalidArgumentError( + "batch columns are not the same".to_string(), + )); + } + } + Ok(()) +} + +fn cdata_integration_import_batch_and_compare_to_json( + c_json_name: *const i8, + batch_num: c_int, + c_array: *mut FFI_ArrowArray, +) -> Result<()> { + let json_name = unsafe { CStr::from_ptr(c_json_name) }; + let json_batch = + open_json_file(json_name.to_str()?)?.read_batch(batch_num.try_into().unwrap())?; + let schema = json_batch.schema(); + + let data_type_for_import = DataType::Struct(schema.fields.clone()); + let imported_array = unsafe { FFI_ArrowArray::from_raw(c_array) }; + let imported_array = unsafe { from_ffi_and_data_type(imported_array, data_type_for_import) }?; + imported_array.validate_full()?; + let imported_batch = RecordBatch::from(StructArray::from(imported_array)); + + compare_batches(&json_batch, &imported_batch) +} + +// If Result is an error, then export a const char* to its string display, otherwise NULL +fn result_to_c_error(result: &std::result::Result) -> *mut i8 { + match result { + Ok(_) => ptr::null_mut(), + Err(e) => CString::new(format!("{}", e)).unwrap().into_raw(), + } +} + +/// Release a const char* exported by result_to_c_error() +/// +/// # Safety +/// +/// The pointer is assumed to have been obtained using CString::into_raw. +#[no_mangle] +pub unsafe extern "C" fn arrow_rs_free_error(c_error: *mut i8) { + if !c_error.is_null() { + drop(unsafe { CString::from_raw(c_error) }); + } +} + +#[no_mangle] +pub extern "C" fn arrow_rs_cdata_integration_export_schema_from_json( + c_json_name: *const i8, + out: *mut FFI_ArrowSchema, +) -> *mut i8 { + let r = cdata_integration_export_schema_from_json(c_json_name, out); + result_to_c_error(&r) +} + +#[no_mangle] +pub extern "C" fn arrow_rs_cdata_integration_import_schema_and_compare_to_json( + c_json_name: *const i8, + c_schema: *mut FFI_ArrowSchema, +) -> *mut i8 { + let r = cdata_integration_import_schema_and_compare_to_json(c_json_name, c_schema); + result_to_c_error(&r) +} + +#[no_mangle] +pub extern "C" fn arrow_rs_cdata_integration_export_batch_from_json( + c_json_name: *const i8, + batch_num: c_int, + out: *mut FFI_ArrowArray, +) -> *mut i8 { + let r = cdata_integration_export_batch_from_json(c_json_name, batch_num, out); + result_to_c_error(&r) +} + +#[no_mangle] +pub extern "C" fn arrow_rs_cdata_integration_import_batch_and_compare_to_json( + c_json_name: *const i8, + batch_num: c_int, + c_array: *mut FFI_ArrowArray, +) -> *mut i8 { + let r = cdata_integration_import_batch_and_compare_to_json(c_json_name, batch_num, c_array); + result_to_c_error(&r) +} diff --git a/arrow-integration-testing/tests/ipc_reader.rs b/arrow-integration-testing/tests/ipc_reader.rs new file mode 100644 index 000000000000..11b8fa84534e --- /dev/null +++ b/arrow-integration-testing/tests/ipc_reader.rs @@ -0,0 +1,214 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for reading the content of [`FileReader`] and [`StreamReader`] +//! in `testing/arrow-ipc-stream/integration/...` + +use arrow::ipc::reader::{FileReader, StreamReader}; +use arrow::util::test_util::arrow_test_data; +use arrow_integration_testing::read_gzip_json; +use std::fs::File; + +#[test] +fn read_0_1_4() { + let testdata = arrow_test_data(); + let version = "0.14.1"; + let paths = [ + "generated_interval", + "generated_datetime", + "generated_dictionary", + "generated_map", + "generated_nested", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_primitive", + "generated_decimal", + ]; + paths.iter().for_each(|path| { + verify_arrow_file(&testdata, version, path); + verify_arrow_stream(&testdata, version, path); + }); +} + +#[test] +fn read_0_1_7() { + let testdata = arrow_test_data(); + let version = "0.17.1"; + let paths = ["generated_union"]; + paths.iter().for_each(|path| { + verify_arrow_file(&testdata, version, path); + verify_arrow_stream(&testdata, version, path); + }); +} + +#[test] +#[should_panic(expected = "Big Endian is not supported for Decimal!")] +fn read_1_0_0_bigendian_decimal_should_panic() { + let testdata = arrow_test_data(); + verify_arrow_file(&testdata, "1.0.0-bigendian", "generated_decimal"); +} + +#[test] +#[should_panic(expected = "Last offset 687865856 of Utf8 is larger than values length 41")] +fn read_1_0_0_bigendian_dictionary_should_panic() { + // The offsets are not translated for big-endian files + // https://github.com/apache/arrow-rs/issues/859 + let testdata = arrow_test_data(); + verify_arrow_file(&testdata, "1.0.0-bigendian", "generated_dictionary"); +} + +#[test] +fn read_1_0_0_bigendian() { + let testdata = arrow_test_data(); + let paths = [ + "generated_interval", + "generated_datetime", + "generated_map", + "generated_nested", + "generated_null_trivial", + "generated_null", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_primitive", + ]; + paths.iter().for_each(|path| { + let file = File::open(format!( + "{testdata}/arrow-ipc-stream/integration/1.0.0-bigendian/{path}.arrow_file" + )) + .unwrap(); + + FileReader::try_new(file, None).unwrap(); + + // While the the reader doesn't error but the values are not + // read correctly on little endian platforms so verifying the + // contents fails + // + // https://github.com/apache/arrow-rs/issues/3459 + //verify_arrow_file(&testdata, "1.0.0-bigendian", path); + }); +} + +#[test] +fn read_1_0_0_littleendian() { + let testdata = arrow_test_data(); + let version = "1.0.0-littleendian"; + let paths = vec![ + "generated_datetime", + "generated_custom_metadata", + "generated_decimal", + "generated_decimal256", + "generated_dictionary", + "generated_dictionary_unsigned", + "generated_duplicate_fieldnames", + "generated_extension", + "generated_interval", + "generated_map", + // https://github.com/apache/arrow-rs/issues/3460 + //"generated_map_non_canonical", + "generated_nested", + "generated_nested_dictionary", + "generated_nested_large_offsets", + "generated_null", + "generated_null_trivial", + "generated_primitive", + "generated_primitive_large_offsets", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_recursive_nested", + "generated_union", + ]; + paths.iter().for_each(|path| { + verify_arrow_file(&testdata, version, path); + verify_arrow_stream(&testdata, version, path); + }); +} + +#[test] +fn read_2_0_0_compression() { + let testdata = arrow_test_data(); + let version = "2.0.0-compression"; + + // the test is repetitive, thus we can read all supported files at once + let paths = ["generated_lz4", "generated_zstd"]; + paths.iter().for_each(|path| { + verify_arrow_file(&testdata, version, path); + verify_arrow_stream(&testdata, version, path); + }); +} + +/// Verifies the arrow file format integration test +/// +/// Input file: +/// `arrow-ipc-stream/integration//.arrow_file +/// +/// Verification json file +/// `arrow-ipc-stream/integration//.json.gz +fn verify_arrow_file(testdata: &str, version: &str, path: &str) { + let filename = format!("{testdata}/arrow-ipc-stream/integration/{version}/{path}.arrow_file"); + println!("Verifying {filename}"); + + // Compare contents to the expected output format in JSON + { + println!(" verifying content"); + let file = File::open(&filename).unwrap(); + let mut reader = FileReader::try_new(file, None).unwrap(); + + // read expected JSON output + let arrow_json = read_gzip_json(version, path); + assert!(arrow_json.equals_reader(&mut reader).unwrap()); + } + + // Verify that projection works by selecting the first column + { + println!(" verifying projection"); + let file = File::open(&filename).unwrap(); + let reader = FileReader::try_new(file, Some(vec![0])).unwrap(); + let datatype_0 = reader.schema().fields()[0].data_type().clone(); + reader.for_each(|batch| { + let batch = batch.unwrap(); + assert_eq!(batch.columns().len(), 1); + assert_eq!(datatype_0, batch.schema().fields()[0].data_type().clone()); + }); + } +} + +/// Verifies the arrow stream integration test +/// +/// Input file: +/// `arrow-ipc-stream/integration//.stream +/// +/// Verification json file +/// `arrow-ipc-stream/integration//.json.gz +fn verify_arrow_stream(testdata: &str, version: &str, path: &str) { + let filename = format!("{testdata}/arrow-ipc-stream/integration/{version}/{path}.stream"); + println!("Verifying {filename}"); + + // Compare contents to the expected output format in JSON + { + println!(" verifying content"); + let file = File::open(&filename).unwrap(); + let mut reader = StreamReader::try_new(file, None).unwrap(); + + // read expected JSON output + let arrow_json = read_gzip_json(version, path); + assert!(arrow_json.equals_reader(&mut reader).unwrap()); + // the next batch must be empty + assert!(reader.next().is_none()); + // the stream must indicate that it's finished + assert!(reader.is_finished()); + } +} diff --git a/arrow-integration-testing/tests/ipc_writer.rs b/arrow-integration-testing/tests/ipc_writer.rs new file mode 100644 index 000000000000..d780eb2ee0b5 --- /dev/null +++ b/arrow-integration-testing/tests/ipc_writer.rs @@ -0,0 +1,256 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::ipc; +use arrow::ipc::reader::{FileReader, StreamReader}; +use arrow::ipc::writer::{FileWriter, IpcWriteOptions, StreamWriter}; +use arrow::util::test_util::arrow_test_data; +use arrow_integration_testing::read_gzip_json; +use std::fs::File; +use std::io::Seek; + +#[test] +fn write_0_1_4() { + let testdata = arrow_test_data(); + let version = "0.14.1"; + let paths = [ + "generated_interval", + "generated_datetime", + "generated_dictionary", + "generated_map", + "generated_nested", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_primitive", + "generated_decimal", + ]; + paths.iter().for_each(|path| { + roundtrip_arrow_file(&testdata, version, path); + roundtrip_arrow_stream(&testdata, version, path); + }); +} + +#[test] +fn write_0_1_7() { + let testdata = arrow_test_data(); + let version = "0.17.1"; + let paths = ["generated_union"]; + paths.iter().for_each(|path| { + roundtrip_arrow_file(&testdata, version, path); + roundtrip_arrow_stream(&testdata, version, path); + }); +} + +#[test] +fn write_1_0_0_littleendian() { + let testdata = arrow_test_data(); + let version = "1.0.0-littleendian"; + let paths = [ + "generated_datetime", + "generated_custom_metadata", + "generated_decimal", + "generated_decimal256", + "generated_dictionary", + "generated_dictionary_unsigned", + "generated_duplicate_fieldnames", + "generated_extension", + "generated_interval", + "generated_map", + // https://github.com/apache/arrow-rs/issues/3460 + // "generated_map_non_canonical", + "generated_nested", + "generated_nested_dictionary", + "generated_nested_large_offsets", + "generated_null", + "generated_null_trivial", + "generated_primitive", + "generated_primitive_large_offsets", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_recursive_nested", + "generated_union", + ]; + paths.iter().for_each(|path| { + roundtrip_arrow_file(&testdata, version, path); + roundtrip_arrow_stream(&testdata, version, path); + }); +} + +#[test] +fn write_2_0_0_compression() { + let testdata = arrow_test_data(); + let version = "2.0.0-compression"; + let paths = ["generated_lz4", "generated_zstd"]; + + // writer options for each compression type + let all_options = [ + IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5) + .unwrap() + .try_with_compression(Some(ipc::CompressionType::LZ4_FRAME)) + .unwrap(), + // write IPC version 5 with zstd + IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5) + .unwrap() + .try_with_compression(Some(ipc::CompressionType::ZSTD)) + .unwrap(), + ]; + + paths.iter().for_each(|path| { + for options in &all_options { + println!("Using options {options:?}"); + roundtrip_arrow_file_with_options(&testdata, version, path, options.clone()); + roundtrip_arrow_stream_with_options(&testdata, version, path, options.clone()); + } + }); +} + +/// Verifies the arrow file writer by reading the contents of an +/// arrow_file, writing it to a file, and then ensuring the contents +/// match the expected json contents. It also verifies that +/// RecordBatches read from the new file matches the original. +/// +/// Input file: +/// `arrow-ipc-stream/integration//.arrow_file +/// +/// Verification json file +/// `arrow-ipc-stream/integration//.json.gz +fn roundtrip_arrow_file(testdata: &str, version: &str, path: &str) { + roundtrip_arrow_file_with_options(testdata, version, path, IpcWriteOptions::default()) +} + +fn roundtrip_arrow_file_with_options( + testdata: &str, + version: &str, + path: &str, + options: IpcWriteOptions, +) { + let filename = format!("{testdata}/arrow-ipc-stream/integration/{version}/{path}.arrow_file"); + println!("Verifying {filename}"); + + let mut tempfile = tempfile::tempfile().unwrap(); + + { + println!(" writing to tempfile {tempfile:?}"); + let file = File::open(&filename).unwrap(); + let mut reader = FileReader::try_new(file, None).unwrap(); + + // read and rewrite the file to a temp location + { + let mut writer = + FileWriter::try_new_with_options(&mut tempfile, &reader.schema(), options).unwrap(); + while let Some(Ok(batch)) = reader.next() { + writer.write(&batch).unwrap(); + } + writer.finish().unwrap(); + } + } + + { + println!(" checking rewrite to with json"); + tempfile.rewind().unwrap(); + let mut reader = FileReader::try_new(&tempfile, None).unwrap(); + + let arrow_json = read_gzip_json(version, path); + assert!(arrow_json.equals_reader(&mut reader).unwrap()); + } + + { + println!(" checking rewrite with original"); + let file = File::open(&filename).unwrap(); + let reader = FileReader::try_new(file, None).unwrap(); + + tempfile.rewind().unwrap(); + let rewrite_reader = FileReader::try_new(&tempfile, None).unwrap(); + + // Compare to original reader + reader + .into_iter() + .zip(rewrite_reader) + .for_each(|(batch1, batch2)| { + assert_eq!(batch1.unwrap(), batch2.unwrap()); + }); + } +} + +/// Verifies the arrow file writer by reading the contents of an +/// arrow_file, writing it to a file, and then ensuring the contents +/// match the expected json contents. It also verifies that +/// RecordBatches read from the new file matches the original. +/// +/// Input file: +/// `arrow-ipc-stream/integration//.stream +/// +/// Verification json file +/// `arrow-ipc-stream/integration//.json.gz +fn roundtrip_arrow_stream(testdata: &str, version: &str, path: &str) { + roundtrip_arrow_stream_with_options(testdata, version, path, IpcWriteOptions::default()) +} + +fn roundtrip_arrow_stream_with_options( + testdata: &str, + version: &str, + path: &str, + options: IpcWriteOptions, +) { + let filename = format!("{testdata}/arrow-ipc-stream/integration/{version}/{path}.stream"); + println!("Verifying {filename}"); + + let mut tempfile = tempfile::tempfile().unwrap(); + + { + println!(" writing to tempfile {tempfile:?}"); + let file = File::open(&filename).unwrap(); + let mut reader = StreamReader::try_new(file, None).unwrap(); + + // read and rewrite the file to a temp location + { + let mut writer = + StreamWriter::try_new_with_options(&mut tempfile, &reader.schema(), options) + .unwrap(); + while let Some(Ok(batch)) = reader.next() { + writer.write(&batch).unwrap(); + } + writer.finish().unwrap(); + } + } + + { + println!(" checking rewrite to with json"); + tempfile.rewind().unwrap(); + let mut reader = StreamReader::try_new(&tempfile, None).unwrap(); + + let arrow_json = read_gzip_json(version, path); + assert!(arrow_json.equals_reader(&mut reader).unwrap()); + } + + { + println!(" checking rewrite with original"); + let file = File::open(&filename).unwrap(); + let reader = StreamReader::try_new(file, None).unwrap(); + + tempfile.rewind().unwrap(); + let rewrite_reader = StreamReader::try_new(&tempfile, None).unwrap(); + + // Compare to original reader + reader + .into_iter() + .zip(rewrite_reader) + .for_each(|(batch1, batch2)| { + assert_eq!(batch1.unwrap(), batch2.unwrap()); + }); + } +} diff --git a/arrow-ipc/CONTRIBUTING.md b/arrow-ipc/CONTRIBUTING.md new file mode 100644 index 000000000000..5e14760f19df --- /dev/null +++ b/arrow-ipc/CONTRIBUTING.md @@ -0,0 +1,37 @@ + + +## Developer's guide + +# IPC + +The expected flatc version is 1.12.0+, built from [flatbuffers](https://github.com/google/flatbuffers) +master at fixed commit ID, by regen.sh. + +The IPC flatbuffer code was generated by running this command from the root of the project: + +```bash +./regen.sh +``` + +The above script will run the `flatc` compiler and perform some adjustments to the source code: + +- Replace `type__` with `type_` +- Remove `org::apache::arrow::flatbuffers` namespace +- Add includes to each generated file diff --git a/arrow-ipc/Cargo.toml b/arrow-ipc/Cargo.toml new file mode 100644 index 000000000000..83ad044d25e7 --- /dev/null +++ b/arrow-ipc/Cargo.toml @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "arrow-ipc" +version = { workspace = true } +description = "Support for the Arrow IPC format" +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "arrow_ipc" +path = "src/lib.rs" +bench = false + +[dependencies] +arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +arrow-cast = { workspace = true } +arrow-data = { workspace = true } +arrow-schema = { workspace = true } +flatbuffers = { version = "23.1.21", default-features = false } +lz4_flex = { version = "0.11", default-features = false, features = ["std", "frame"], optional = true } +zstd = { version = "0.13.0", default-features = false, optional = true } + +[features] +default = [] +lz4 = ["lz4_flex"] + +[dev-dependencies] +tempfile = "3.3" diff --git a/arrow/regen.sh b/arrow-ipc/regen.sh similarity index 83% rename from arrow/regen.sh rename to arrow-ipc/regen.sh index 9d384b6b63b6..8d8862ccc7f4 100755 --- a/arrow/regen.sh +++ b/arrow-ipc/regen.sh @@ -18,15 +18,13 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -# Change to the toplevel Rust directory -pushd $DIR/../../ +# Change to the toplevel `arrow-rs` directory +pushd $DIR/../ echo "Build flatc from source ..." FB_URL="https://github.com/google/flatbuffers" -# https://github.com/google/flatbuffers/pull/6393 -FB_COMMIT="408cf5802415e1dea65fef7489a6c2f3740fb381" -FB_DIR="rust/arrow/.flatbuffers" +FB_DIR="arrow/.flatbuffers" FLATC="$FB_DIR/bazel-bin/flatc" if [ -z $(which bazel) ]; then @@ -44,28 +42,21 @@ else git -C $FB_DIR pull fi -echo "hard reset to $FB_COMMIT" -git -C $FB_DIR reset --hard $FB_COMMIT - pushd $FB_DIR echo "run: bazel build :flatc ..." bazel build :flatc popd -FB_PATCH="rust/arrow/format-0ed34c83.patch" -echo "Patch flatbuffer files with ${FB_PATCH} for cargo doc" -echo "NOTE: the patch MAY need update in case of changes in format/*.fbs" -git apply --check ${FB_PATCH} && git apply ${FB_PATCH} # Execute the code generation: -$FLATC --filename-suffix "" --rust -o rust/arrow/src/ipc/gen/ format/*.fbs +$FLATC --filename-suffix "" --rust -o arrow-ipc/src/gen/ format/*.fbs # Reset changes to format/ git checkout -- format # Now the files are wrongly named so we have to change that. popd -pushd $DIR/src/ipc/gen +pushd $DIR/src/gen PREFIX=$(cat <<'HEREDOC' // Licensed to the Apache Software Foundation (ASF) under one @@ -94,9 +85,9 @@ use flatbuffers::EndianScalar; HEREDOC ) -SCHEMA_IMPORT="\nuse crate::ipc::gen::Schema::*;" -SPARSE_TENSOR_IMPORT="\nuse crate::ipc::gen::SparseTensor::*;" -TENSOR_IMPORT="\nuse crate::ipc::gen::Tensor::*;" +SCHEMA_IMPORT="\nuse crate::gen::Schema::*;" +SPARSE_TENSOR_IMPORT="\nuse crate::gen::SparseTensor::*;" +TENSOR_IMPORT="\nuse crate::gen::Tensor::*;" # For flatbuffer(1.12.0+), remove: use crate::${name}::\*; names=("File" "Message" "Schema" "SparseTensor" "Tensor") @@ -119,8 +110,9 @@ for f in `ls *.rs`; do sed -i '' '/} \/\/ pub mod arrow/d' $f sed -i '' '/} \/\/ pub mod apache/d' $f sed -i '' '/} \/\/ pub mod org/d' $f - sed -i '' '/use std::mem;/d' $f - sed -i '' '/use std::cmp::Ordering;/d' $f + sed -i '' '/use core::mem;/d' $f + sed -i '' '/use core::cmp::Ordering;/d' $f + sed -i '' '/use self::flatbuffers::{EndianScalar, Follow};/d' $f # required by flatc 1.12.0+ sed -i '' "/\#\!\[allow(unused_imports, dead_code)\]/d" $f @@ -150,7 +142,7 @@ done # Return back to base directory popd -cargo +stable fmt -- src/ipc/gen/* +cargo +stable fmt -- src/gen/* echo "DONE!" echo "Please run 'cargo doc' and 'cargo test' with nightly and stable, " diff --git a/arrow/src/ipc/compression/codec.rs b/arrow-ipc/src/compression.rs similarity index 54% rename from arrow/src/ipc/compression/codec.rs rename to arrow-ipc/src/compression.rs index 58ba8cb86585..0d8b7b4c1bd4 100644 --- a/arrow/src/ipc/compression/codec.rs +++ b/arrow-ipc/src/compression.rs @@ -15,16 +15,15 @@ // specific language governing permissions and limitations // under the License. -use crate::buffer::Buffer; -use crate::error::{ArrowError, Result}; -use crate::ipc::CompressionType; -use std::io::{Read, Write}; +use crate::CompressionType; +use arrow_buffer::Buffer; +use arrow_schema::ArrowError; const LENGTH_NO_COMPRESSED_DATA: i64 = -1; const LENGTH_OF_PREFIX_DATA: i64 = 8; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] /// Represents compressing a ipc stream using a particular compression algorithm +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum CompressionCodec { Lz4Frame, Zstd, @@ -33,13 +32,12 @@ pub enum CompressionCodec { impl TryFrom for CompressionCodec { type Error = ArrowError; - fn try_from(compression_type: CompressionType) -> Result { + fn try_from(compression_type: CompressionType) -> Result { match compression_type { CompressionType::ZSTD => Ok(CompressionCodec::Zstd), CompressionType::LZ4_FRAME => Ok(CompressionCodec::Lz4Frame), other_type => Err(ArrowError::NotYetImplemented(format!( - "compression type {:?} not supported ", - other_type + "compression type {other_type:?} not supported " ))), } } @@ -60,7 +58,7 @@ impl CompressionCodec { &self, input: &[u8], output: &mut Vec, - ) -> Result { + ) -> Result { let uncompressed_data_len = input.len(); let original_output_len = output.len(); @@ -71,7 +69,7 @@ impl CompressionCodec { output.extend_from_slice(&uncompressed_data_len.to_le_bytes()); self.compress(input, output)?; - let compression_len = output.len(); + let compression_len = output.len() - original_output_len; if compression_len > uncompressed_data_len { // length of compressed data was larger than // uncompressed data, use the uncompressed data with @@ -92,73 +90,123 @@ impl CompressionCodec { /// [8 bytes]: uncompressed length /// [remaining bytes]: compressed data stream /// ``` - pub(crate) fn decompress_to_buffer(&self, input: &Buffer) -> Result { + pub(crate) fn decompress_to_buffer(&self, input: &Buffer) -> Result { // read the first 8 bytes to determine if the data is // compressed let decompressed_length = read_uncompressed_size(input); let buffer = if decompressed_length == 0 { - // emtpy + // empty Buffer::from([]) } else if decompressed_length == LENGTH_NO_COMPRESSED_DATA { // no compression input.slice(LENGTH_OF_PREFIX_DATA as usize) - } else { + } else if let Ok(decompressed_length) = usize::try_from(decompressed_length) { // decompress data using the codec - let mut uncompressed_buffer = - Vec::with_capacity(decompressed_length as usize); let input_data = &input[(LENGTH_OF_PREFIX_DATA as usize)..]; - self.decompress(input_data, &mut uncompressed_buffer)?; - Buffer::from(uncompressed_buffer) + self.decompress(input_data, decompressed_length as _)? + .into() + } else { + return Err(ArrowError::IpcError(format!( + "Invalid uncompressed length: {decompressed_length}" + ))); }; Ok(buffer) } /// Compress the data in input buffer and write to output buffer /// using the specified compression - fn compress(&self, input: &[u8], output: &mut Vec) -> Result<()> { + fn compress(&self, input: &[u8], output: &mut Vec) -> Result<(), ArrowError> { match self { - CompressionCodec::Lz4Frame => { - let mut encoder = lz4::EncoderBuilder::new().build(output)?; - encoder.write_all(input)?; - match encoder.finish().1 { - Ok(_) => Ok(()), - Err(e) => Err(e.into()), - } - } - CompressionCodec::Zstd => { - let mut encoder = zstd::Encoder::new(output, 0)?; - encoder.write_all(input)?; - match encoder.finish() { - Ok(_) => Ok(()), - Err(e) => Err(e.into()), - } - } + CompressionCodec::Lz4Frame => compress_lz4(input, output), + CompressionCodec::Zstd => compress_zstd(input, output), } } /// Decompress the data in input buffer and write to output buffer /// using the specified compression - fn decompress(&self, input: &[u8], output: &mut Vec) -> Result { - let result: Result = match self { - CompressionCodec::Lz4Frame => { - let mut decoder = lz4::Decoder::new(input)?; - match decoder.read_to_end(output) { - Ok(size) => Ok(size), - Err(e) => Err(e.into()), - } - } - CompressionCodec::Zstd => { - let mut decoder = zstd::Decoder::new(input)?; - match decoder.read_to_end(output) { - Ok(size) => Ok(size), - Err(e) => Err(e.into()), - } - } + fn decompress(&self, input: &[u8], decompressed_size: usize) -> Result, ArrowError> { + let ret = match self { + CompressionCodec::Lz4Frame => decompress_lz4(input, decompressed_size)?, + CompressionCodec::Zstd => decompress_zstd(input, decompressed_size)?, }; - result + if ret.len() != decompressed_size { + return Err(ArrowError::IpcError(format!( + "Expected compressed length of {decompressed_size} got {}", + ret.len() + ))); + } + Ok(ret) } } +#[cfg(feature = "lz4")] +fn compress_lz4(input: &[u8], output: &mut Vec) -> Result<(), ArrowError> { + use std::io::Write; + let mut encoder = lz4_flex::frame::FrameEncoder::new(output); + encoder.write_all(input)?; + encoder + .finish() + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + Ok(()) +} + +#[cfg(not(feature = "lz4"))] +#[allow(clippy::ptr_arg)] +fn compress_lz4(_input: &[u8], _output: &mut Vec) -> Result<(), ArrowError> { + Err(ArrowError::InvalidArgumentError( + "lz4 IPC compression requires the lz4 feature".to_string(), + )) +} + +#[cfg(feature = "lz4")] +fn decompress_lz4(input: &[u8], decompressed_size: usize) -> Result, ArrowError> { + use std::io::Read; + let mut output = Vec::with_capacity(decompressed_size); + lz4_flex::frame::FrameDecoder::new(input).read_to_end(&mut output)?; + Ok(output) +} + +#[cfg(not(feature = "lz4"))] +#[allow(clippy::ptr_arg)] +fn decompress_lz4(_input: &[u8], _decompressed_size: usize) -> Result, ArrowError> { + Err(ArrowError::InvalidArgumentError( + "lz4 IPC decompression requires the lz4 feature".to_string(), + )) +} + +#[cfg(feature = "zstd")] +fn compress_zstd(input: &[u8], output: &mut Vec) -> Result<(), ArrowError> { + use std::io::Write; + let mut encoder = zstd::Encoder::new(output, 0)?; + encoder.write_all(input)?; + encoder.finish()?; + Ok(()) +} + +#[cfg(not(feature = "zstd"))] +#[allow(clippy::ptr_arg)] +fn compress_zstd(_input: &[u8], _output: &mut Vec) -> Result<(), ArrowError> { + Err(ArrowError::InvalidArgumentError( + "zstd IPC compression requires the zstd feature".to_string(), + )) +} + +#[cfg(feature = "zstd")] +fn decompress_zstd(input: &[u8], decompressed_size: usize) -> Result, ArrowError> { + use std::io::Read; + let mut output = Vec::with_capacity(decompressed_size); + zstd::Decoder::with_buffer(input)?.read_to_end(&mut output)?; + Ok(output) +} + +#[cfg(not(feature = "zstd"))] +#[allow(clippy::ptr_arg)] +fn decompress_zstd(_input: &[u8], _decompressed_size: usize) -> Result, ArrowError> { + Err(ArrowError::InvalidArgumentError( + "zstd IPC decompression requires the zstd feature".to_string(), + )) +} + /// Get the uncompressed length /// Notes: /// LENGTH_NO_COMPRESSED_DATA: indicate that the data that follows is not compressed @@ -173,31 +221,29 @@ fn read_uncompressed_size(buffer: &[u8]) -> i64 { #[cfg(test)] mod tests { - use super::*; - #[test] + #[cfg(feature = "lz4")] fn test_lz4_compression() { - let input_bytes = "hello lz4".as_bytes(); - let codec: CompressionCodec = CompressionCodec::Lz4Frame; + let input_bytes = b"hello lz4"; + let codec = super::CompressionCodec::Lz4Frame; let mut output_bytes: Vec = Vec::new(); codec.compress(input_bytes, &mut output_bytes).unwrap(); - let mut result_output_bytes: Vec = Vec::new(); - codec - .decompress(output_bytes.as_slice(), &mut result_output_bytes) + let result = codec + .decompress(output_bytes.as_slice(), input_bytes.len()) .unwrap(); - assert_eq!(input_bytes, result_output_bytes.as_slice()); + assert_eq!(input_bytes, result.as_slice()); } #[test] + #[cfg(feature = "zstd")] fn test_zstd_compression() { - let input_bytes = "hello zstd".as_bytes(); - let codec: CompressionCodec = CompressionCodec::Zstd; + let input_bytes = b"hello zstd"; + let codec = super::CompressionCodec::Zstd; let mut output_bytes: Vec = Vec::new(); codec.compress(input_bytes, &mut output_bytes).unwrap(); - let mut result_output_bytes: Vec = Vec::new(); - codec - .decompress(output_bytes.as_slice(), &mut result_output_bytes) + let result = codec + .decompress(output_bytes.as_slice(), input_bytes.len()) .unwrap(); - assert_eq!(input_bytes, result_output_bytes.as_slice()); + assert_eq!(input_bytes, result.as_slice()); } } diff --git a/arrow/src/ipc/convert.rs b/arrow-ipc/src/convert.rs similarity index 55% rename from arrow/src/ipc/convert.rs rename to arrow-ipc/src/convert.rs index 00503d50e338..b290a09acf5d 100644 --- a/arrow/src/ipc/convert.rs +++ b/arrow-ipc/src/convert.rs @@ -17,15 +17,12 @@ //! Utilities for converting between IPC types and native Arrow types -use crate::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit, UnionMode}; -use crate::error::{ArrowError, Result}; -use crate::ipc; - -use flatbuffers::{ - FlatBufferBuilder, ForwardsUOffset, UnionWIPOffset, Vector, WIPOffset, -}; -use std::collections::{BTreeMap, HashMap}; +use arrow_schema::*; +use flatbuffers::{FlatBufferBuilder, ForwardsUOffset, UnionWIPOffset, Vector, WIPOffset}; +use std::collections::HashMap; +use std::sync::Arc; +use crate::{size_prefixed_root_as_message, KeyValue, CONTINUATION_MARKER}; use DataType::*; /// Serialize a schema in IPC format @@ -39,39 +36,50 @@ pub fn schema_to_fb(schema: &Schema) -> FlatBufferBuilder { fbb } -pub fn schema_to_fb_offset<'a>( +pub fn metadata_to_fb<'a>( fbb: &mut FlatBufferBuilder<'a>, - schema: &Schema, -) -> WIPOffset> { - let mut fields = vec![]; - for field in schema.fields() { - let fb_field = build_field(fbb, field); - fields.push(fb_field); - } - - let mut custom_metadata = vec![]; - for (k, v) in schema.metadata() { - let fb_key_name = fbb.create_string(k.as_str()); - let fb_val_name = fbb.create_string(v.as_str()); + metadata: &HashMap, +) -> WIPOffset>>> { + let custom_metadata = metadata + .iter() + .map(|(k, v)| { + let fb_key_name = fbb.create_string(k); + let fb_val_name = fbb.create_string(v); - let mut kv_builder = ipc::KeyValueBuilder::new(fbb); - kv_builder.add_key(fb_key_name); - kv_builder.add_value(fb_val_name); - custom_metadata.push(kv_builder.finish()); - } + let mut kv_builder = crate::KeyValueBuilder::new(fbb); + kv_builder.add_key(fb_key_name); + kv_builder.add_value(fb_val_name); + kv_builder.finish() + }) + .collect::>(); + fbb.create_vector(&custom_metadata) +} +pub fn schema_to_fb_offset<'a>( + fbb: &mut FlatBufferBuilder<'a>, + schema: &Schema, +) -> WIPOffset> { + let fields = schema + .fields() + .iter() + .map(|field| build_field(fbb, field)) + .collect::>(); let fb_field_list = fbb.create_vector(&fields); - let fb_metadata_list = fbb.create_vector(&custom_metadata); - let mut builder = ipc::SchemaBuilder::new(fbb); + let fb_metadata_list = + (!schema.metadata().is_empty()).then(|| metadata_to_fb(fbb, schema.metadata())); + + let mut builder = crate::SchemaBuilder::new(fbb); builder.add_fields(fb_field_list); - builder.add_custom_metadata(fb_metadata_list); + if let Some(fb_metadata_list) = fb_metadata_list { + builder.add_custom_metadata(fb_metadata_list); + } builder.finish() } /// Convert an IPC Field to Arrow Field -impl<'a> From> for Field { - fn from(field: ipc::Field) -> Field { +impl<'a> From> for Field { + fn from(field: crate::Field) -> Field { let arrow_field = if let Some(dictionary) = field.dictionary() { Field::new_dict( field.name().unwrap(), @@ -88,30 +96,28 @@ impl<'a> From> for Field { ) }; - let mut metadata = None; + let mut metadata_map = HashMap::default(); if let Some(list) = field.custom_metadata() { - let mut metadata_map = BTreeMap::default(); for kv in list { if let (Some(k), Some(v)) = (kv.key(), kv.value()) { metadata_map.insert(k.to_string(), v.to_string()); } } - metadata = Some(metadata_map); } - arrow_field.with_metadata(metadata) + arrow_field.with_metadata(metadata_map) } } -/// Deserialize a Schema table from IPC format to Schema data type -pub fn fb_to_schema(fb: ipc::Schema) -> Schema { +/// Deserialize a Schema table from flat buffer format to Schema data type +pub fn fb_to_schema(fb: crate::Schema) -> Schema { let mut fields: Vec = vec![]; let c_fields = fb.fields().unwrap(); let len = c_fields.len(); for i in 0..len { - let c_field: ipc::Field = c_fields.get(i); + let c_field: crate::Field = c_fields.get(i); match c_field.type_type() { - ipc::Type::Decimal if fb.endianness() == ipc::Endianness::Big => { + crate::Type::Decimal if fb.endianness() == crate::Endianness::Big => { unimplemented!("Big Endian is not supported for Decimal!") } _ => (), @@ -136,25 +142,64 @@ pub fn fb_to_schema(fb: ipc::Schema) -> Schema { Schema::new_with_metadata(fields, metadata) } -/// Deserialize an IPC message into a schema -pub fn schema_from_bytes(bytes: &[u8]) -> Result { - if let Ok(ipc) = ipc::root_as_message(bytes) { +/// Try deserialize flat buffer format bytes into a schema +pub fn try_schema_from_flatbuffer_bytes(bytes: &[u8]) -> Result { + if let Ok(ipc) = crate::root_as_message(bytes) { if let Some(schema) = ipc.header_as_schema().map(fb_to_schema) { Ok(schema) } else { - Err(ArrowError::IoError( + Err(ArrowError::ParseError( "Unable to get head as schema".to_string(), )) } } else { - Err(ArrowError::IoError( + Err(ArrowError::ParseError( "Unable to get root as message".to_string(), )) } } +/// Try deserialize the IPC format bytes into a schema +pub fn try_schema_from_ipc_buffer(buffer: &[u8]) -> Result { + // There are two protocol types: https://issues.apache.org/jira/browse/ARROW-6313 + // The original protocol is: + // 4 bytes - the byte length of the payload + // a flatbuffer Message whose header is the Schema + // The latest version of protocol is: + // The schema of the dataset in its IPC form: + // 4 bytes - an optional IPC_CONTINUATION_TOKEN prefix + // 4 bytes - the byte length of the payload + // a flatbuffer Message whose header is the Schema + if buffer.len() >= 4 { + // check continuation maker + let continuation_maker = &buffer[0..4]; + let begin_offset: usize = if continuation_maker.eq(&CONTINUATION_MARKER) { + // 4 bytes: CONTINUATION_MARKER + // 4 bytes: length + // buffer + 4 + } else { + // backward compatibility for buffer without the continuation maker + // 4 bytes: length + // buffer + 0 + }; + let msg = size_prefixed_root_as_message(&buffer[begin_offset..]).map_err(|err| { + ArrowError::ParseError(format!("Unable to convert flight info to a message: {err}")) + })?; + let ipc_schema = msg.header_as_schema().ok_or_else(|| { + ArrowError::ParseError("Unable to convert flight info to a schema".to_string()) + })?; + Ok(fb_to_schema(ipc_schema)) + } else { + Err(ArrowError::ParseError( + "The buffer length is less than 4 and missing the continuation maker or length of buffer".to_string() + )) + } +} + /// Get the Arrow data type from the flatbuffer Field table -pub(crate) fn get_data_type(field: ipc::Field, may_be_dictionary: bool) -> DataType { +pub(crate) fn get_data_type(field: crate::Field, may_be_dictionary: bool) -> DataType { if let Some(dictionary) = field.dictionary() { if may_be_dictionary { let int = dictionary.indexType().unwrap(); @@ -177,9 +222,9 @@ pub(crate) fn get_data_type(field: ipc::Field, may_be_dictionary: bool) -> DataT } match field.type_type() { - ipc::Type::Null => DataType::Null, - ipc::Type::Bool => DataType::Boolean, - ipc::Type::Int => { + crate::Type::Null => DataType::Null, + crate::Type::Bool => DataType::Boolean, + crate::Type::Int => { let int = field.type_as_int().unwrap(); match (int.bitWidth(), int.is_signed()) { (8, true) => DataType::Int8, @@ -196,129 +241,130 @@ pub(crate) fn get_data_type(field: ipc::Field, may_be_dictionary: bool) -> DataT ), } } - ipc::Type::Binary => DataType::Binary, - ipc::Type::LargeBinary => DataType::LargeBinary, - ipc::Type::Utf8 => DataType::Utf8, - ipc::Type::LargeUtf8 => DataType::LargeUtf8, - ipc::Type::FixedSizeBinary => { + crate::Type::Binary => DataType::Binary, + crate::Type::LargeBinary => DataType::LargeBinary, + crate::Type::Utf8 => DataType::Utf8, + crate::Type::LargeUtf8 => DataType::LargeUtf8, + crate::Type::FixedSizeBinary => { let fsb = field.type_as_fixed_size_binary().unwrap(); DataType::FixedSizeBinary(fsb.byteWidth()) } - ipc::Type::FloatingPoint => { + crate::Type::FloatingPoint => { let float = field.type_as_floating_point().unwrap(); match float.precision() { - ipc::Precision::HALF => DataType::Float16, - ipc::Precision::SINGLE => DataType::Float32, - ipc::Precision::DOUBLE => DataType::Float64, - z => panic!("FloatingPoint type with precision of {:?} not supported", z), + crate::Precision::HALF => DataType::Float16, + crate::Precision::SINGLE => DataType::Float32, + crate::Precision::DOUBLE => DataType::Float64, + z => panic!("FloatingPoint type with precision of {z:?} not supported"), } } - ipc::Type::Date => { + crate::Type::Date => { let date = field.type_as_date().unwrap(); match date.unit() { - ipc::DateUnit::DAY => DataType::Date32, - ipc::DateUnit::MILLISECOND => DataType::Date64, - z => panic!("Date type with unit of {:?} not supported", z), + crate::DateUnit::DAY => DataType::Date32, + crate::DateUnit::MILLISECOND => DataType::Date64, + z => panic!("Date type with unit of {z:?} not supported"), } } - ipc::Type::Time => { + crate::Type::Time => { let time = field.type_as_time().unwrap(); match (time.bitWidth(), time.unit()) { - (32, ipc::TimeUnit::SECOND) => DataType::Time32(TimeUnit::Second), - (32, ipc::TimeUnit::MILLISECOND) => { - DataType::Time32(TimeUnit::Millisecond) - } - (64, ipc::TimeUnit::MICROSECOND) => { - DataType::Time64(TimeUnit::Microsecond) - } - (64, ipc::TimeUnit::NANOSECOND) => DataType::Time64(TimeUnit::Nanosecond), + (32, crate::TimeUnit::SECOND) => DataType::Time32(TimeUnit::Second), + (32, crate::TimeUnit::MILLISECOND) => DataType::Time32(TimeUnit::Millisecond), + (64, crate::TimeUnit::MICROSECOND) => DataType::Time64(TimeUnit::Microsecond), + (64, crate::TimeUnit::NANOSECOND) => DataType::Time64(TimeUnit::Nanosecond), z => panic!( "Time type with bit width of {} and unit of {:?} not supported", z.0, z.1 ), } } - ipc::Type::Timestamp => { + crate::Type::Timestamp => { let timestamp = field.type_as_timestamp().unwrap(); - let timezone: Option = timestamp.timezone().map(|tz| tz.to_string()); + let timezone: Option<_> = timestamp.timezone().map(|tz| tz.into()); match timestamp.unit() { - ipc::TimeUnit::SECOND => DataType::Timestamp(TimeUnit::Second, timezone), - ipc::TimeUnit::MILLISECOND => { + crate::TimeUnit::SECOND => DataType::Timestamp(TimeUnit::Second, timezone), + crate::TimeUnit::MILLISECOND => { DataType::Timestamp(TimeUnit::Millisecond, timezone) } - ipc::TimeUnit::MICROSECOND => { + crate::TimeUnit::MICROSECOND => { DataType::Timestamp(TimeUnit::Microsecond, timezone) } - ipc::TimeUnit::NANOSECOND => { - DataType::Timestamp(TimeUnit::Nanosecond, timezone) - } - z => panic!("Timestamp type with unit of {:?} not supported", z), + crate::TimeUnit::NANOSECOND => DataType::Timestamp(TimeUnit::Nanosecond, timezone), + z => panic!("Timestamp type with unit of {z:?} not supported"), } } - ipc::Type::Interval => { + crate::Type::Interval => { let interval = field.type_as_interval().unwrap(); match interval.unit() { - ipc::IntervalUnit::YEAR_MONTH => { - DataType::Interval(IntervalUnit::YearMonth) - } - ipc::IntervalUnit::DAY_TIME => DataType::Interval(IntervalUnit::DayTime), - ipc::IntervalUnit::MONTH_DAY_NANO => { + crate::IntervalUnit::YEAR_MONTH => DataType::Interval(IntervalUnit::YearMonth), + crate::IntervalUnit::DAY_TIME => DataType::Interval(IntervalUnit::DayTime), + crate::IntervalUnit::MONTH_DAY_NANO => { DataType::Interval(IntervalUnit::MonthDayNano) } - z => panic!("Interval type with unit of {:?} unsupported", z), + z => panic!("Interval type with unit of {z:?} unsupported"), } } - ipc::Type::Duration => { + crate::Type::Duration => { let duration = field.type_as_duration().unwrap(); match duration.unit() { - ipc::TimeUnit::SECOND => DataType::Duration(TimeUnit::Second), - ipc::TimeUnit::MILLISECOND => DataType::Duration(TimeUnit::Millisecond), - ipc::TimeUnit::MICROSECOND => DataType::Duration(TimeUnit::Microsecond), - ipc::TimeUnit::NANOSECOND => DataType::Duration(TimeUnit::Nanosecond), - z => panic!("Duration type with unit of {:?} unsupported", z), + crate::TimeUnit::SECOND => DataType::Duration(TimeUnit::Second), + crate::TimeUnit::MILLISECOND => DataType::Duration(TimeUnit::Millisecond), + crate::TimeUnit::MICROSECOND => DataType::Duration(TimeUnit::Microsecond), + crate::TimeUnit::NANOSECOND => DataType::Duration(TimeUnit::Nanosecond), + z => panic!("Duration type with unit of {z:?} unsupported"), } } - ipc::Type::List => { + crate::Type::List => { let children = field.children().unwrap(); if children.len() != 1 { panic!("expect a list to have one child") } - DataType::List(Box::new(children.get(0).into())) + DataType::List(Arc::new(children.get(0).into())) } - ipc::Type::LargeList => { + crate::Type::LargeList => { let children = field.children().unwrap(); if children.len() != 1 { panic!("expect a large list to have one child") } - DataType::LargeList(Box::new(children.get(0).into())) + DataType::LargeList(Arc::new(children.get(0).into())) } - ipc::Type::FixedSizeList => { + crate::Type::FixedSizeList => { let children = field.children().unwrap(); if children.len() != 1 { panic!("expect a list to have one child") } let fsl = field.type_as_fixed_size_list().unwrap(); - DataType::FixedSizeList(Box::new(children.get(0).into()), fsl.listSize()) + DataType::FixedSizeList(Arc::new(children.get(0).into()), fsl.listSize()) } - ipc::Type::Struct_ => { - let mut fields = vec![]; - if let Some(children) = field.children() { - for i in 0..children.len() { - fields.push(children.get(i).into()); - } + crate::Type::Struct_ => { + let fields = match field.children() { + Some(children) => children.iter().map(Field::from).collect(), + None => Fields::empty(), }; - DataType::Struct(fields) } - ipc::Type::Map => { + crate::Type::RunEndEncoded => { + let children = field.children().unwrap(); + if children.len() != 2 { + panic!( + "RunEndEncoded type should have exactly two children. Found {}", + children.len() + ) + } + let run_ends_field = children.get(0).into(); + let values_field = children.get(1).into(); + DataType::RunEndEncoded(Arc::new(run_ends_field), Arc::new(values_field)) + } + crate::Type::Map => { let map = field.type_as_map().unwrap(); let children = field.children().unwrap(); if children.len() != 1 { panic!("expect a map to have one child") } - DataType::Map(Box::new(children.get(0).into()), map.keysSorted()) + DataType::Map(Arc::new(children.get(0).into()), map.keysSorted()) } - ipc::Type::Decimal => { + crate::Type::Decimal => { let fsb = field.type_as_decimal().unwrap(); let bit_width = fsb.bitWidth(); if bit_width == 128 { @@ -332,66 +378,55 @@ pub(crate) fn get_data_type(field: ipc::Field, may_be_dictionary: bool) -> DataT fsb.scale().try_into().unwrap(), ) } else { - panic!("Unexpected decimal bit width {}", bit_width) + panic!("Unexpected decimal bit width {bit_width}") } } - ipc::Type::Union => { + crate::Type::Union => { let union = field.type_as_union().unwrap(); let union_mode = match union.mode() { - ipc::UnionMode::Dense => UnionMode::Dense, - ipc::UnionMode::Sparse => UnionMode::Sparse, - mode => panic!("Unexpected union mode: {:?}", mode), + crate::UnionMode::Dense => UnionMode::Dense, + crate::UnionMode::Sparse => UnionMode::Sparse, + mode => panic!("Unexpected union mode: {mode:?}"), }; let mut fields = vec![]; if let Some(children) = field.children() { for i in 0..children.len() { - fields.push(children.get(i).into()); + fields.push(Field::from(children.get(i))); } }; - let type_ids: Vec = match union.typeIds() { - None => (0_i8..fields.len() as i8).collect(), - Some(ids) => ids.iter().map(|i| i as i8).collect(), + let fields = match union.typeIds() { + None => UnionFields::new(0_i8..fields.len() as i8, fields), + Some(ids) => UnionFields::new(ids.iter().map(|i| i as i8), fields), }; - DataType::Union(fields, type_ids, union_mode) + DataType::Union(fields, union_mode) } t => unimplemented!("Type {:?} not supported", t), } } pub(crate) struct FBFieldType<'b> { - pub(crate) type_type: ipc::Type, + pub(crate) type_type: crate::Type, pub(crate) type_: WIPOffset, - pub(crate) children: Option>>>>, + pub(crate) children: Option>>>>, } /// Create an IPC Field from an Arrow Field pub(crate) fn build_field<'a>( fbb: &mut FlatBufferBuilder<'a>, field: &Field, -) -> WIPOffset> { +) -> WIPOffset> { // Optional custom metadata. let mut fb_metadata = None; - if let Some(metadata) = field.metadata() { - if !metadata.is_empty() { - let mut kv_vec = vec![]; - for (k, v) in metadata { - let kv_args = ipc::KeyValueArgs { - key: Some(fbb.create_string(k.as_str())), - value: Some(fbb.create_string(v.as_str())), - }; - let kv_offset = ipc::KeyValue::create(fbb, &kv_args); - kv_vec.push(kv_offset); - } - fb_metadata = Some(fbb.create_vector(&kv_vec)); - } + if !field.metadata().is_empty() { + fb_metadata = Some(metadata_to_fb(fbb, field.metadata())); }; let fb_field_name = fbb.create_string(field.name().as_str()); - let field_type = get_fb_field_type(field.data_type(), field.is_nullable(), fbb); + let field_type = get_fb_field_type(field.data_type(), fbb); let fb_dictionary = if let Dictionary(index_type, _) = field.data_type() { Some(get_fb_dictionary( @@ -408,7 +443,7 @@ pub(crate) fn build_field<'a>( None }; - let mut field_builder = ipc::FieldBuilder::new(fbb); + let mut field_builder = crate::FieldBuilder::new(fbb); field_builder.add_name(fb_field_name); if let Some(dictionary) = fb_dictionary { field_builder.add_dictionary(dictionary) @@ -431,26 +466,25 @@ pub(crate) fn build_field<'a>( /// Get the IPC type of a data type pub(crate) fn get_fb_field_type<'a>( data_type: &DataType, - is_nullable: bool, fbb: &mut FlatBufferBuilder<'a>, ) -> FBFieldType<'a> { // some IPC implementations expect an empty list for child data, instead of a null value. // An empty field list is thus returned for primitive types - let empty_fields: Vec> = vec![]; + let empty_fields: Vec> = vec![]; match data_type { Null => FBFieldType { - type_type: ipc::Type::Null, - type_: ipc::NullBuilder::new(fbb).finish().as_union_value(), + type_type: crate::Type::Null, + type_: crate::NullBuilder::new(fbb).finish().as_union_value(), children: Some(fbb.create_vector(&empty_fields[..])), }, Boolean => FBFieldType { - type_type: ipc::Type::Bool, - type_: ipc::BoolBuilder::new(fbb).finish().as_union_value(), + type_type: crate::Type::Bool, + type_: crate::BoolBuilder::new(fbb).finish().as_union_value(), children: Some(fbb.create_vector(&empty_fields[..])), }, UInt8 | UInt16 | UInt32 | UInt64 => { let children = fbb.create_vector(&empty_fields[..]); - let mut builder = ipc::IntBuilder::new(fbb); + let mut builder = crate::IntBuilder::new(fbb); builder.add_is_signed(false); match data_type { UInt8 => builder.add_bitWidth(8), @@ -460,14 +494,14 @@ pub(crate) fn get_fb_field_type<'a>( _ => {} }; FBFieldType { - type_type: ipc::Type::Int, + type_type: crate::Type::Int, type_: builder.finish().as_union_value(), children: Some(children), } } Int8 | Int16 | Int32 | Int64 => { let children = fbb.create_vector(&empty_fields[..]); - let mut builder = ipc::IntBuilder::new(fbb); + let mut builder = crate::IntBuilder::new(fbb); builder.add_is_signed(true); match data_type { Int8 => builder.add_bitWidth(8), @@ -477,144 +511,146 @@ pub(crate) fn get_fb_field_type<'a>( _ => {} }; FBFieldType { - type_type: ipc::Type::Int, + type_type: crate::Type::Int, type_: builder.finish().as_union_value(), children: Some(children), } } Float16 | Float32 | Float64 => { let children = fbb.create_vector(&empty_fields[..]); - let mut builder = ipc::FloatingPointBuilder::new(fbb); + let mut builder = crate::FloatingPointBuilder::new(fbb); match data_type { - Float16 => builder.add_precision(ipc::Precision::HALF), - Float32 => builder.add_precision(ipc::Precision::SINGLE), - Float64 => builder.add_precision(ipc::Precision::DOUBLE), + Float16 => builder.add_precision(crate::Precision::HALF), + Float32 => builder.add_precision(crate::Precision::SINGLE), + Float64 => builder.add_precision(crate::Precision::DOUBLE), _ => {} }; FBFieldType { - type_type: ipc::Type::FloatingPoint, + type_type: crate::Type::FloatingPoint, type_: builder.finish().as_union_value(), children: Some(children), } } Binary => FBFieldType { - type_type: ipc::Type::Binary, - type_: ipc::BinaryBuilder::new(fbb).finish().as_union_value(), + type_type: crate::Type::Binary, + type_: crate::BinaryBuilder::new(fbb).finish().as_union_value(), children: Some(fbb.create_vector(&empty_fields[..])), }, LargeBinary => FBFieldType { - type_type: ipc::Type::LargeBinary, - type_: ipc::LargeBinaryBuilder::new(fbb).finish().as_union_value(), + type_type: crate::Type::LargeBinary, + type_: crate::LargeBinaryBuilder::new(fbb) + .finish() + .as_union_value(), children: Some(fbb.create_vector(&empty_fields[..])), }, Utf8 => FBFieldType { - type_type: ipc::Type::Utf8, - type_: ipc::Utf8Builder::new(fbb).finish().as_union_value(), + type_type: crate::Type::Utf8, + type_: crate::Utf8Builder::new(fbb).finish().as_union_value(), children: Some(fbb.create_vector(&empty_fields[..])), }, LargeUtf8 => FBFieldType { - type_type: ipc::Type::LargeUtf8, - type_: ipc::LargeUtf8Builder::new(fbb).finish().as_union_value(), + type_type: crate::Type::LargeUtf8, + type_: crate::LargeUtf8Builder::new(fbb).finish().as_union_value(), children: Some(fbb.create_vector(&empty_fields[..])), }, FixedSizeBinary(len) => { - let mut builder = ipc::FixedSizeBinaryBuilder::new(fbb); - builder.add_byteWidth(*len as i32); + let mut builder = crate::FixedSizeBinaryBuilder::new(fbb); + builder.add_byteWidth(*len); FBFieldType { - type_type: ipc::Type::FixedSizeBinary, + type_type: crate::Type::FixedSizeBinary, type_: builder.finish().as_union_value(), children: Some(fbb.create_vector(&empty_fields[..])), } } Date32 => { - let mut builder = ipc::DateBuilder::new(fbb); - builder.add_unit(ipc::DateUnit::DAY); + let mut builder = crate::DateBuilder::new(fbb); + builder.add_unit(crate::DateUnit::DAY); FBFieldType { - type_type: ipc::Type::Date, + type_type: crate::Type::Date, type_: builder.finish().as_union_value(), children: Some(fbb.create_vector(&empty_fields[..])), } } Date64 => { - let mut builder = ipc::DateBuilder::new(fbb); - builder.add_unit(ipc::DateUnit::MILLISECOND); + let mut builder = crate::DateBuilder::new(fbb); + builder.add_unit(crate::DateUnit::MILLISECOND); FBFieldType { - type_type: ipc::Type::Date, + type_type: crate::Type::Date, type_: builder.finish().as_union_value(), children: Some(fbb.create_vector(&empty_fields[..])), } } Time32(unit) | Time64(unit) => { - let mut builder = ipc::TimeBuilder::new(fbb); + let mut builder = crate::TimeBuilder::new(fbb); match unit { TimeUnit::Second => { builder.add_bitWidth(32); - builder.add_unit(ipc::TimeUnit::SECOND); + builder.add_unit(crate::TimeUnit::SECOND); } TimeUnit::Millisecond => { builder.add_bitWidth(32); - builder.add_unit(ipc::TimeUnit::MILLISECOND); + builder.add_unit(crate::TimeUnit::MILLISECOND); } TimeUnit::Microsecond => { builder.add_bitWidth(64); - builder.add_unit(ipc::TimeUnit::MICROSECOND); + builder.add_unit(crate::TimeUnit::MICROSECOND); } TimeUnit::Nanosecond => { builder.add_bitWidth(64); - builder.add_unit(ipc::TimeUnit::NANOSECOND); + builder.add_unit(crate::TimeUnit::NANOSECOND); } } FBFieldType { - type_type: ipc::Type::Time, + type_type: crate::Type::Time, type_: builder.finish().as_union_value(), children: Some(fbb.create_vector(&empty_fields[..])), } } Timestamp(unit, tz) => { - let tz = tz.clone().unwrap_or_default(); - let tz_str = fbb.create_string(tz.as_str()); - let mut builder = ipc::TimestampBuilder::new(fbb); + let tz = tz.as_deref().unwrap_or_default(); + let tz_str = fbb.create_string(tz); + let mut builder = crate::TimestampBuilder::new(fbb); let time_unit = match unit { - TimeUnit::Second => ipc::TimeUnit::SECOND, - TimeUnit::Millisecond => ipc::TimeUnit::MILLISECOND, - TimeUnit::Microsecond => ipc::TimeUnit::MICROSECOND, - TimeUnit::Nanosecond => ipc::TimeUnit::NANOSECOND, + TimeUnit::Second => crate::TimeUnit::SECOND, + TimeUnit::Millisecond => crate::TimeUnit::MILLISECOND, + TimeUnit::Microsecond => crate::TimeUnit::MICROSECOND, + TimeUnit::Nanosecond => crate::TimeUnit::NANOSECOND, }; builder.add_unit(time_unit); if !tz.is_empty() { builder.add_timezone(tz_str); } FBFieldType { - type_type: ipc::Type::Timestamp, + type_type: crate::Type::Timestamp, type_: builder.finish().as_union_value(), children: Some(fbb.create_vector(&empty_fields[..])), } } Interval(unit) => { - let mut builder = ipc::IntervalBuilder::new(fbb); + let mut builder = crate::IntervalBuilder::new(fbb); let interval_unit = match unit { - IntervalUnit::YearMonth => ipc::IntervalUnit::YEAR_MONTH, - IntervalUnit::DayTime => ipc::IntervalUnit::DAY_TIME, - IntervalUnit::MonthDayNano => ipc::IntervalUnit::MONTH_DAY_NANO, + IntervalUnit::YearMonth => crate::IntervalUnit::YEAR_MONTH, + IntervalUnit::DayTime => crate::IntervalUnit::DAY_TIME, + IntervalUnit::MonthDayNano => crate::IntervalUnit::MONTH_DAY_NANO, }; builder.add_unit(interval_unit); FBFieldType { - type_type: ipc::Type::Interval, + type_type: crate::Type::Interval, type_: builder.finish().as_union_value(), children: Some(fbb.create_vector(&empty_fields[..])), } } Duration(unit) => { - let mut builder = ipc::DurationBuilder::new(fbb); + let mut builder = crate::DurationBuilder::new(fbb); let time_unit = match unit { - TimeUnit::Second => ipc::TimeUnit::SECOND, - TimeUnit::Millisecond => ipc::TimeUnit::MILLISECOND, - TimeUnit::Microsecond => ipc::TimeUnit::MICROSECOND, - TimeUnit::Nanosecond => ipc::TimeUnit::NANOSECOND, + TimeUnit::Second => crate::TimeUnit::SECOND, + TimeUnit::Millisecond => crate::TimeUnit::MILLISECOND, + TimeUnit::Microsecond => crate::TimeUnit::MICROSECOND, + TimeUnit::Nanosecond => crate::TimeUnit::NANOSECOND, }; builder.add_unit(time_unit); FBFieldType { - type_type: ipc::Type::Duration, + type_type: crate::Type::Duration, type_: builder.finish().as_union_value(), children: Some(fbb.create_vector(&empty_fields[..])), } @@ -622,25 +658,25 @@ pub(crate) fn get_fb_field_type<'a>( List(ref list_type) => { let child = build_field(fbb, list_type); FBFieldType { - type_type: ipc::Type::List, - type_: ipc::ListBuilder::new(fbb).finish().as_union_value(), + type_type: crate::Type::List, + type_: crate::ListBuilder::new(fbb).finish().as_union_value(), children: Some(fbb.create_vector(&[child])), } } LargeList(ref list_type) => { let child = build_field(fbb, list_type); FBFieldType { - type_type: ipc::Type::LargeList, - type_: ipc::LargeListBuilder::new(fbb).finish().as_union_value(), + type_type: crate::Type::LargeList, + type_: crate::LargeListBuilder::new(fbb).finish().as_union_value(), children: Some(fbb.create_vector(&[child])), } } FixedSizeList(ref list_type, len) => { let child = build_field(fbb, list_type); - let mut builder = ipc::FixedSizeListBuilder::new(fbb); - builder.add_listSize(*len as i32); + let mut builder = crate::FixedSizeListBuilder::new(fbb); + builder.add_listSize(*len); FBFieldType { - type_type: ipc::Type::FixedSizeList, + type_type: crate::Type::FixedSizeList, type_: builder.finish().as_union_value(), children: Some(fbb.create_vector(&[child])), } @@ -652,17 +688,29 @@ pub(crate) fn get_fb_field_type<'a>( children.push(build_field(fbb, field)); } FBFieldType { - type_type: ipc::Type::Struct_, - type_: ipc::Struct_Builder::new(fbb).finish().as_union_value(), + type_type: crate::Type::Struct_, + type_: crate::Struct_Builder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&children[..])), + } + } + RunEndEncoded(run_ends, values) => { + let run_ends_field = build_field(fbb, run_ends); + let values_field = build_field(fbb, values); + let children = [run_ends_field, values_field]; + FBFieldType { + type_type: crate::Type::RunEndEncoded, + type_: crate::RunEndEncodedBuilder::new(fbb) + .finish() + .as_union_value(), children: Some(fbb.create_vector(&children[..])), } } Map(map_field, keys_sorted) => { let child = build_field(fbb, map_field); - let mut field_type = ipc::MapBuilder::new(fbb); + let mut field_type = crate::MapBuilder::new(fbb); field_type.add_keysSorted(*keys_sorted); FBFieldType { - type_type: ipc::Type::Map, + type_type: crate::Type::Map, type_: field_type.finish().as_union_value(), children: Some(fbb.create_vector(&[child])), } @@ -671,49 +719,49 @@ pub(crate) fn get_fb_field_type<'a>( // In this library, the dictionary "type" is a logical construct. Here we // pass through to the value type, as we've already captured the index // type in the DictionaryEncoding metadata in the parent field - get_fb_field_type(value_type, is_nullable, fbb) + get_fb_field_type(value_type, fbb) } Decimal128(precision, scale) => { - let mut builder = ipc::DecimalBuilder::new(fbb); + let mut builder = crate::DecimalBuilder::new(fbb); builder.add_precision(*precision as i32); builder.add_scale(*scale as i32); builder.add_bitWidth(128); FBFieldType { - type_type: ipc::Type::Decimal, + type_type: crate::Type::Decimal, type_: builder.finish().as_union_value(), children: Some(fbb.create_vector(&empty_fields[..])), } } Decimal256(precision, scale) => { - let mut builder = ipc::DecimalBuilder::new(fbb); + let mut builder = crate::DecimalBuilder::new(fbb); builder.add_precision(*precision as i32); builder.add_scale(*scale as i32); builder.add_bitWidth(256); FBFieldType { - type_type: ipc::Type::Decimal, + type_type: crate::Type::Decimal, type_: builder.finish().as_union_value(), children: Some(fbb.create_vector(&empty_fields[..])), } } - Union(fields, type_ids, mode) => { + Union(fields, mode) => { let mut children = vec![]; - for field in fields { + for (_, field) in fields.iter() { children.push(build_field(fbb, field)); } let union_mode = match mode { - UnionMode::Sparse => ipc::UnionMode::Sparse, - UnionMode::Dense => ipc::UnionMode::Dense, + UnionMode::Sparse => crate::UnionMode::Sparse, + UnionMode::Dense => crate::UnionMode::Dense, }; - let fbb_type_ids = fbb - .create_vector(&type_ids.iter().map(|t| *t as i32).collect::>()); - let mut builder = ipc::UnionBuilder::new(fbb); + let fbb_type_ids = + fbb.create_vector(&fields.iter().map(|(t, _)| t as i32).collect::>()); + let mut builder = crate::UnionBuilder::new(fbb); builder.add_mode(union_mode); builder.add_typeIds(fbb_type_ids); FBFieldType { - type_type: ipc::Type::Union, + type_type: crate::Type::Union, type_: builder.finish().as_union_value(), children: Some(fbb.create_vector(&children[..])), } @@ -727,10 +775,10 @@ pub(crate) fn get_fb_dictionary<'a>( dict_id: i64, dict_is_ordered: bool, fbb: &mut FlatBufferBuilder<'a>, -) -> WIPOffset> { +) -> WIPOffset> { // We assume that the dictionary index type (as an integer) has already been // validated elsewhere, and can safely assume we are dealing with integers - let mut index_builder = ipc::IntBuilder::new(fbb); + let mut index_builder = crate::IntBuilder::new(fbb); match *index_type { Int8 | Int16 | Int32 | Int64 => index_builder.add_is_signed(true), @@ -748,7 +796,7 @@ pub(crate) fn get_fb_dictionary<'a>( let index_builder = index_builder.finish(); - let mut builder = ipc::DictionaryEncodingBuilder::new(fbb); + let mut builder = crate::DictionaryEncodingBuilder::new(fbb); builder.add_id(dict_id); builder.add_indexType(index_builder); builder.add_isOrdered(dict_is_ordered); @@ -759,7 +807,6 @@ pub(crate) fn get_fb_dictionary<'a>( #[cfg(test)] mod tests { use super::*; - use crate::datatypes::{DataType, Field, Schema, UnionMode}; #[test] fn convert_schema_round_trip() { @@ -767,13 +814,13 @@ mod tests { .iter() .cloned() .collect(); - let field_md: BTreeMap = [("k".to_string(), "v".to_string())] + let field_md: HashMap = [("k".to_string(), "v".to_string())] .iter() .cloned() .collect(); let schema = Schema::new_with_metadata( vec![ - Field::new("uint8", DataType::UInt8, false).with_metadata(Some(field_md)), + Field::new("uint8", DataType::UInt8, false).with_metadata(field_md), Field::new("uint16", DataType::UInt16, true), Field::new("uint32", DataType::UInt32, false), Field::new("uint64", DataType::UInt64, true), @@ -804,10 +851,7 @@ mod tests { ), Field::new( "timestamp[us]", - DataType::Timestamp( - TimeUnit::Microsecond, - Some("Africa/Johannesburg".to_string()), - ), + DataType::Timestamp(TimeUnit::Microsecond, Some("Africa/Johannesburg".into())), false, ), Field::new( @@ -832,141 +876,123 @@ mod tests { ), Field::new("utf8", DataType::Utf8, false), Field::new("binary", DataType::Binary, false), - Field::new( - "list[u8]", - DataType::List(Box::new(Field::new("item", DataType::UInt8, false))), - true, - ), - Field::new( + Field::new_list("list[u8]", Field::new("item", DataType::UInt8, false), true), + Field::new_list( "list[struct]", - DataType::List(Box::new(Field::new( + Field::new_struct( "struct", - DataType::Struct(vec![ - Field::new("float32", DataType::UInt8, false), - Field::new("int32", DataType::Int32, true), - Field::new("bool", DataType::Boolean, true), - ]), + vec![ + Field::new("float32", UInt8, false), + Field::new("int32", Int32, true), + Field::new("bool", Boolean, true), + ], true, - ))), + ), false, ), - Field::new( + Field::new_struct( "struct>", - DataType::Struct(vec![Field::new( + vec![Field::new( "dictionary", - DataType::Dictionary( - Box::new(DataType::Int32), - Box::new(DataType::Utf8), - ), + Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), false, - )]), + )], false, ), - Field::new( + Field::new_struct( "struct]>]>", - DataType::Struct(vec![ + vec![ Field::new("int64", DataType::Int64, true), - Field::new( + Field::new_list( "list[struct]>]", - DataType::List(Box::new(Field::new( + Field::new_struct( "struct", - DataType::Struct(vec![ + vec![ Field::new("date32", DataType::Date32, true), - Field::new( + Field::new_list( "list[struct<>]", - DataType::List(Box::new(Field::new( + Field::new( "struct", - DataType::Struct(vec![]), + DataType::Struct(Fields::empty()), false, - ))), + ), false, ), - ]), + ], false, - ))), + ), false, ), - ]), + ], false, ), - Field::new( + Field::new_union( "union]>]>", - DataType::Union( - vec![ - Field::new("int64", DataType::Int64, true), - Field::new( - "list[union]>]", - DataType::List(Box::new(Field::new( - "union]>", - DataType::Union( - vec![ - Field::new("date32", DataType::Date32, true), - Field::new( - "list[union<>]", - DataType::List(Box::new(Field::new( - "union", - DataType::Union( - vec![], - vec![], - UnionMode::Sparse, - ), - false, - ))), - false, + vec![0, 1], + vec![ + Field::new("int64", DataType::Int64, true), + Field::new_list( + "list[union]>]", + Field::new_union( + "union]>", + vec![0, 1], + vec![ + Field::new("date32", DataType::Date32, true), + Field::new_list( + "list[union<>]", + Field::new( + "union", + DataType::Union( + UnionFields::empty(), + UnionMode::Sparse, ), - ], - vec![0, 1], - UnionMode::Dense, + false, + ), + false, ), - false, - ))), - false, + ], + UnionMode::Dense, ), - ], - vec![0, 1], - UnionMode::Sparse, - ), - false, + false, + ), + ], + UnionMode::Sparse, ), - Field::new("struct<>", DataType::Struct(vec![]), true), + Field::new("struct<>", DataType::Struct(Fields::empty()), true), Field::new( "union<>", - DataType::Union(vec![], vec![], UnionMode::Dense), + DataType::Union(UnionFields::empty(), UnionMode::Dense), true, ), Field::new( "union<>", - DataType::Union(vec![], vec![], UnionMode::Sparse), + DataType::Union(UnionFields::empty(), UnionMode::Sparse), true, ), Field::new( "union", DataType::Union( - vec![ - Field::new("int32", DataType::Int32, true), - Field::new("utf8", DataType::Utf8, true), - ], - vec![2, 3], // non-default type ids + UnionFields::new( + vec![2, 3], // non-default type ids + vec![ + Field::new("int32", DataType::Int32, true), + Field::new("utf8", DataType::Utf8, true), + ], + ), UnionMode::Dense, ), true, ), Field::new_dict( "dictionary", - DataType::Dictionary( - Box::new(DataType::Int32), - Box::new(DataType::Utf8), - ), + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), true, 123, true, ), Field::new_dict( "dictionary", - DataType::Dictionary( - Box::new(DataType::UInt8), - Box::new(DataType::UInt32), - ), + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::UInt32)), true, 123, true, @@ -979,39 +1005,50 @@ mod tests { let fb = schema_to_fb(&schema); // read back fields - let ipc = ipc::root_as_schema(fb.finished_data()).unwrap(); + let ipc = crate::root_as_schema(fb.finished_data()).unwrap(); let schema2 = fb_to_schema(ipc); assert_eq!(schema, schema2); } #[test] fn schema_from_bytes() { - // bytes of a schema generated from python (0.14.0), saved as an `ipc::Message`. - // the schema is: Field("field1", DataType::UInt32, false) + // Bytes of a schema generated via following python code, using pyarrow 10.0.1: + // + // import pyarrow as pa + // schema = pa.schema([pa.field('field1', pa.uint32(), nullable=False)]) + // sink = pa.BufferOutputStream() + // with pa.ipc.new_stream(sink, schema) as writer: + // pass + // # stripping continuation & length prefix & suffix bytes to get only schema bytes + // [x for x in sink.getvalue().to_pybytes()][8:-8] let bytes: Vec = vec![ - 16, 0, 0, 0, 0, 0, 10, 0, 12, 0, 6, 0, 5, 0, 8, 0, 10, 0, 0, 0, 0, 1, 3, 0, - 12, 0, 0, 0, 8, 0, 8, 0, 0, 0, 4, 0, 8, 0, 0, 0, 4, 0, 0, 0, 1, 0, 0, 0, 20, - 0, 0, 0, 16, 0, 20, 0, 8, 0, 0, 0, 7, 0, 12, 0, 0, 0, 16, 0, 16, 0, 0, 0, 0, - 0, 0, 2, 32, 0, 0, 0, 20, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 8, 0, - 4, 0, 6, 0, 0, 0, 32, 0, 0, 0, 6, 0, 0, 0, 102, 105, 101, 108, 100, 49, 0, 0, - 0, 0, 0, 0, + 16, 0, 0, 0, 0, 0, 10, 0, 12, 0, 6, 0, 5, 0, 8, 0, 10, 0, 0, 0, 0, 1, 4, 0, 12, 0, 0, + 0, 8, 0, 8, 0, 0, 0, 4, 0, 8, 0, 0, 0, 4, 0, 0, 0, 1, 0, 0, 0, 20, 0, 0, 0, 16, 0, 20, + 0, 8, 0, 0, 0, 7, 0, 12, 0, 0, 0, 16, 0, 16, 0, 0, 0, 0, 0, 0, 2, 16, 0, 0, 0, 32, 0, + 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 102, 105, 101, 108, 100, 49, 0, 0, 0, 0, 6, + 0, 8, 0, 4, 0, 6, 0, 0, 0, 32, 0, 0, 0, ]; - let ipc = ipc::root_as_message(&bytes[..]).unwrap(); + let ipc = crate::root_as_message(&bytes).unwrap(); let schema = ipc.header_as_schema().unwrap(); - // a message generated from Rust, same as the Python one - let bytes: Vec = vec![ - 16, 0, 0, 0, 0, 0, 10, 0, 14, 0, 12, 0, 11, 0, 4, 0, 10, 0, 0, 0, 20, 0, 0, - 0, 0, 0, 0, 1, 3, 0, 10, 0, 12, 0, 0, 0, 8, 0, 4, 0, 10, 0, 0, 0, 8, 0, 0, 0, - 8, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 16, 0, 0, 0, 12, 0, 18, 0, 12, 0, 0, 0, - 11, 0, 4, 0, 12, 0, 0, 0, 20, 0, 0, 0, 0, 0, 0, 2, 20, 0, 0, 0, 0, 0, 6, 0, - 8, 0, 4, 0, 6, 0, 0, 0, 32, 0, 0, 0, 6, 0, 0, 0, 102, 105, 101, 108, 100, 49, - 0, 0, - ]; - let ipc2 = ipc::root_as_message(&bytes[..]).unwrap(); - let schema2 = ipc.header_as_schema().unwrap(); + // generate same message with Rust + let data_gen = crate::writer::IpcDataGenerator::default(); + let arrow_schema = Schema::new(vec![Field::new("field1", DataType::UInt32, false)]); + let bytes = data_gen + .schema_to_bytes(&arrow_schema, &crate::writer::IpcWriteOptions::default()) + .ipc_message; + + let ipc2 = crate::root_as_message(&bytes).unwrap(); + let schema2 = ipc2.header_as_schema().unwrap(); + + // can't compare schema directly as it compares the underlying bytes, which can differ + assert!(schema.custom_metadata().is_none()); + assert!(schema2.custom_metadata().is_none()); + assert_eq!(schema.endianness(), schema2.endianness()); + assert!(schema.features().is_none()); + assert!(schema2.features().is_none()); + assert_eq!(fb_to_schema(schema), fb_to_schema(schema2)); - assert_eq!(schema, schema2); assert_eq!(ipc.version(), ipc2.version()); assert_eq!(ipc.header_type(), ipc2.header_type()); assert_eq!(ipc.bodyLength(), ipc2.bodyLength()); diff --git a/arrow/src/ipc/gen/File.rs b/arrow-ipc/src/gen/File.rs similarity index 69% rename from arrow/src/ipc/gen/File.rs rename to arrow-ipc/src/gen/File.rs index 04cbc6441377..c0c2fb183237 100644 --- a/arrow/src/ipc/gen/File.rs +++ b/arrow-ipc/src/gen/File.rs @@ -18,7 +18,7 @@ #![allow(dead_code)] #![allow(unused_imports)] -use crate::ipc::gen::Schema::*; +use crate::gen::Schema::*; use flatbuffers::EndianScalar; use std::{cmp::Ordering, mem}; // automatically generated by the FlatBuffers compiler, do not modify @@ -27,8 +27,13 @@ use std::{cmp::Ordering, mem}; #[repr(transparent)] #[derive(Clone, Copy, PartialEq)] pub struct Block(pub [u8; 24]); -impl std::fmt::Debug for Block { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +impl Default for Block { + fn default() -> Self { + Self([0; 24]) + } +} +impl core::fmt::Debug for Block { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { f.debug_struct("Block") .field("offset", &self.offset()) .field("metaDataLength", &self.metaDataLength()) @@ -38,39 +43,25 @@ impl std::fmt::Debug for Block { } impl flatbuffers::SimpleToVerifyInSlice for Block {} -impl flatbuffers::SafeSliceAccess for Block {} impl<'a> flatbuffers::Follow<'a> for Block { type Inner = &'a Block; #[inline] - fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { + unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { <&'a Block>::follow(buf, loc) } } impl<'a> flatbuffers::Follow<'a> for &'a Block { type Inner = &'a Block; #[inline] - fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { + unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { flatbuffers::follow_cast_ref::(buf, loc) } } impl<'b> flatbuffers::Push for Block { type Output = Block; #[inline] - fn push(&self, dst: &mut [u8], _rest: &[u8]) { - let src = unsafe { - ::std::slice::from_raw_parts(self as *const Block as *const u8, Self::size()) - }; - dst.copy_from_slice(src); - } -} -impl<'b> flatbuffers::Push for &'b Block { - type Output = Block; - - #[inline] - fn push(&self, dst: &mut [u8], _rest: &[u8]) { - let src = unsafe { - ::std::slice::from_raw_parts(*self as *const Block as *const u8, Self::size()) - }; + unsafe fn push(&self, dst: &mut [u8], _written_len: usize) { + let src = ::core::slice::from_raw_parts(self as *const Block as *const u8, Self::size()); dst.copy_from_slice(src); } } @@ -85,7 +76,8 @@ impl<'a> flatbuffers::Verifiable for Block { v.in_buffer::(pos) } } -impl Block { + +impl<'a> Block { #[allow(clippy::too_many_arguments)] pub fn new(offset: i64, metaDataLength: i32, bodyLength: i64) -> Self { let mut s = Self([0; 24]); @@ -97,50 +89,60 @@ impl Block { /// Index to the start of the RecordBlock (note this is past the Message header) pub fn offset(&self) -> i64 { - let mut mem = core::mem::MaybeUninit::::uninit(); - unsafe { + let mut mem = core::mem::MaybeUninit::<::Scalar>::uninit(); + // Safety: + // Created from a valid Table for this object + // Which contains a valid value in this slot + EndianScalar::from_little_endian(unsafe { core::ptr::copy_nonoverlapping( self.0[0..].as_ptr(), mem.as_mut_ptr() as *mut u8, - core::mem::size_of::(), + core::mem::size_of::<::Scalar>(), ); mem.assume_init() - } - .from_little_endian() + }) } pub fn set_offset(&mut self, x: i64) { let x_le = x.to_little_endian(); + // Safety: + // Created from a valid Table for this object + // Which contains a valid value in this slot unsafe { core::ptr::copy_nonoverlapping( - &x_le as *const i64 as *const u8, + &x_le as *const _ as *const u8, self.0[0..].as_mut_ptr(), - core::mem::size_of::(), + core::mem::size_of::<::Scalar>(), ); } } /// Length of the metadata pub fn metaDataLength(&self) -> i32 { - let mut mem = core::mem::MaybeUninit::::uninit(); - unsafe { + let mut mem = core::mem::MaybeUninit::<::Scalar>::uninit(); + // Safety: + // Created from a valid Table for this object + // Which contains a valid value in this slot + EndianScalar::from_little_endian(unsafe { core::ptr::copy_nonoverlapping( self.0[8..].as_ptr(), mem.as_mut_ptr() as *mut u8, - core::mem::size_of::(), + core::mem::size_of::<::Scalar>(), ); mem.assume_init() - } - .from_little_endian() + }) } pub fn set_metaDataLength(&mut self, x: i32) { let x_le = x.to_little_endian(); + // Safety: + // Created from a valid Table for this object + // Which contains a valid value in this slot unsafe { core::ptr::copy_nonoverlapping( - &x_le as *const i32 as *const u8, + &x_le as *const _ as *const u8, self.0[8..].as_mut_ptr(), - core::mem::size_of::(), + core::mem::size_of::<::Scalar>(), ); } } @@ -148,25 +150,30 @@ impl Block { /// Length of the data (this is aligned so there can be a gap between this and /// the metadata). pub fn bodyLength(&self) -> i64 { - let mut mem = core::mem::MaybeUninit::::uninit(); - unsafe { + let mut mem = core::mem::MaybeUninit::<::Scalar>::uninit(); + // Safety: + // Created from a valid Table for this object + // Which contains a valid value in this slot + EndianScalar::from_little_endian(unsafe { core::ptr::copy_nonoverlapping( self.0[16..].as_ptr(), mem.as_mut_ptr() as *mut u8, - core::mem::size_of::(), + core::mem::size_of::<::Scalar>(), ); mem.assume_init() - } - .from_little_endian() + }) } pub fn set_bodyLength(&mut self, x: i64) { let x_le = x.to_little_endian(); + // Safety: + // Created from a valid Table for this object + // Which contains a valid value in this slot unsafe { core::ptr::copy_nonoverlapping( - &x_le as *const i64 as *const u8, + &x_le as *const _ as *const u8, self.0[16..].as_mut_ptr(), - core::mem::size_of::(), + core::mem::size_of::<::Scalar>(), ); } } @@ -185,16 +192,22 @@ pub struct Footer<'a> { impl<'a> flatbuffers::Follow<'a> for Footer<'a> { type Inner = Footer<'a>; #[inline] - fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { + unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { Self { - _tab: flatbuffers::Table { buf, loc }, + _tab: flatbuffers::Table::new(buf, loc), } } } impl<'a> Footer<'a> { + pub const VT_VERSION: flatbuffers::VOffsetT = 4; + pub const VT_SCHEMA: flatbuffers::VOffsetT = 6; + pub const VT_DICTIONARIES: flatbuffers::VOffsetT = 8; + pub const VT_RECORDBATCHES: flatbuffers::VOffsetT = 10; + pub const VT_CUSTOM_METADATA: flatbuffers::VOffsetT = 12; + #[inline] - pub fn init_from_table(table: flatbuffers::Table<'a>) -> Self { + pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self { Footer { _tab: table } } #[allow(unused_mut)] @@ -219,49 +232,66 @@ impl<'a> Footer<'a> { builder.finish() } - pub const VT_VERSION: flatbuffers::VOffsetT = 4; - pub const VT_SCHEMA: flatbuffers::VOffsetT = 6; - pub const VT_DICTIONARIES: flatbuffers::VOffsetT = 8; - pub const VT_RECORDBATCHES: flatbuffers::VOffsetT = 10; - pub const VT_CUSTOM_METADATA: flatbuffers::VOffsetT = 12; - #[inline] pub fn version(&self) -> MetadataVersion { - self._tab - .get::(Footer::VT_VERSION, Some(MetadataVersion::V1)) - .unwrap() + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::(Footer::VT_VERSION, Some(MetadataVersion::V1)) + .unwrap() + } } #[inline] pub fn schema(&self) -> Option> { - self._tab - .get::>(Footer::VT_SCHEMA, None) + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::>(Footer::VT_SCHEMA, None) + } } #[inline] - pub fn dictionaries(&self) -> Option<&'a [Block]> { - self._tab - .get::>>( - Footer::VT_DICTIONARIES, - None, - ) - .map(|v| v.safe_slice()) + pub fn dictionaries(&self) -> Option> { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::>>( + Footer::VT_DICTIONARIES, + None, + ) + } } #[inline] - pub fn recordBatches(&self) -> Option<&'a [Block]> { - self._tab - .get::>>( - Footer::VT_RECORDBATCHES, - None, - ) - .map(|v| v.safe_slice()) + pub fn recordBatches(&self) -> Option> { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::>>( + Footer::VT_RECORDBATCHES, + None, + ) + } } /// User-defined metadata #[inline] pub fn custom_metadata( &self, ) -> Option>>> { - self._tab.get::>, - >>(Footer::VT_CUSTOM_METADATA, None) + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab.get::>, + >>(Footer::VT_CUSTOM_METADATA, None) + } } } @@ -273,25 +303,21 @@ impl flatbuffers::Verifiable for Footer<'_> { ) -> Result<(), flatbuffers::InvalidFlatbuffer> { use flatbuffers::Verifiable; v.visit_table(pos)? - .visit_field::(&"version", Self::VT_VERSION, false)? - .visit_field::>( - &"schema", - Self::VT_SCHEMA, - false, - )? + .visit_field::("version", Self::VT_VERSION, false)? + .visit_field::>("schema", Self::VT_SCHEMA, false)? .visit_field::>>( - &"dictionaries", + "dictionaries", Self::VT_DICTIONARIES, false, )? .visit_field::>>( - &"recordBatches", + "recordBatches", Self::VT_RECORDBATCHES, false, )? .visit_field::>, - >>(&"custom_metadata", Self::VT_CUSTOM_METADATA, false)? + >>("custom_metadata", Self::VT_CUSTOM_METADATA, false)? .finish(); Ok(()) } @@ -302,9 +328,7 @@ pub struct FooterArgs<'a> { pub dictionaries: Option>>, pub recordBatches: Option>>, pub custom_metadata: Option< - flatbuffers::WIPOffset< - flatbuffers::Vector<'a, flatbuffers::ForwardsUOffset>>, - >, + flatbuffers::WIPOffset>>>, >, } impl<'a> Default for FooterArgs<'a> { @@ -319,6 +343,7 @@ impl<'a> Default for FooterArgs<'a> { } } } + pub struct FooterBuilder<'a: 'b, 'b> { fbb_: &'b mut flatbuffers::FlatBufferBuilder<'a>, start_: flatbuffers::WIPOffset, @@ -326,39 +351,29 @@ pub struct FooterBuilder<'a: 'b, 'b> { impl<'a: 'b, 'b> FooterBuilder<'a, 'b> { #[inline] pub fn add_version(&mut self, version: MetadataVersion) { - self.fbb_.push_slot::( - Footer::VT_VERSION, - version, - MetadataVersion::V1, - ); + self.fbb_ + .push_slot::(Footer::VT_VERSION, version, MetadataVersion::V1); } #[inline] pub fn add_schema(&mut self, schema: flatbuffers::WIPOffset>) { self.fbb_ - .push_slot_always::>( - Footer::VT_SCHEMA, - schema, - ); + .push_slot_always::>(Footer::VT_SCHEMA, schema); } #[inline] pub fn add_dictionaries( &mut self, dictionaries: flatbuffers::WIPOffset>, ) { - self.fbb_.push_slot_always::>( - Footer::VT_DICTIONARIES, - dictionaries, - ); + self.fbb_ + .push_slot_always::>(Footer::VT_DICTIONARIES, dictionaries); } #[inline] pub fn add_recordBatches( &mut self, recordBatches: flatbuffers::WIPOffset>, ) { - self.fbb_.push_slot_always::>( - Footer::VT_RECORDBATCHES, - recordBatches, - ); + self.fbb_ + .push_slot_always::>(Footer::VT_RECORDBATCHES, recordBatches); } #[inline] pub fn add_custom_metadata( @@ -373,9 +388,7 @@ impl<'a: 'b, 'b> FooterBuilder<'a, 'b> { ); } #[inline] - pub fn new( - _fbb: &'b mut flatbuffers::FlatBufferBuilder<'a>, - ) -> FooterBuilder<'a, 'b> { + pub fn new(_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a>) -> FooterBuilder<'a, 'b> { let start = _fbb.start_table(); FooterBuilder { fbb_: _fbb, @@ -389,8 +402,8 @@ impl<'a: 'b, 'b> FooterBuilder<'a, 'b> { } } -impl std::fmt::Debug for Footer<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Debug for Footer<'_> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { let mut ds = f.debug_struct("Footer"); ds.field("version", &self.version()); ds.field("schema", &self.schema()); @@ -400,18 +413,6 @@ impl std::fmt::Debug for Footer<'_> { ds.finish() } } -#[inline] -#[deprecated(since = "2.0.0", note = "Deprecated in favor of `root_as...` methods.")] -pub fn get_root_as_footer<'a>(buf: &'a [u8]) -> Footer<'a> { - unsafe { flatbuffers::root_unchecked::>(buf) } -} - -#[inline] -#[deprecated(since = "2.0.0", note = "Deprecated in favor of `root_as...` methods.")] -pub fn get_size_prefixed_root_as_footer<'a>(buf: &'a [u8]) -> Footer<'a> { - unsafe { flatbuffers::size_prefixed_root_unchecked::>(buf) } -} - #[inline] /// Verifies that a buffer of bytes contains a `Footer` /// and returns it. @@ -429,9 +430,7 @@ pub fn root_as_footer(buf: &[u8]) -> Result Result { +pub fn size_prefixed_root_as_footer(buf: &[u8]) -> Result { flatbuffers::size_prefixed_root::